diff --git a/controller/types/type.go b/controller/types/type.go index ae4b381..2c69838 100644 --- a/controller/types/type.go +++ b/controller/types/type.go @@ -12,8 +12,13 @@ type Controller struct { Router fiber.Router } -func NewController(app *fiber.App, group string) *Controller { +func NewController(app *fiber.App, group string, middleware ...interface{}) *Controller { + router := app.Group(group) + for _, m := range middleware { + router.Use(m) + } + return &Controller{ - Router: app.Group(group), + Router: router, } } diff --git a/controller/user/user.go b/controller/user/user.go index 5983540..b8a5a79 100644 --- a/controller/user/user.go +++ b/controller/user/user.go @@ -5,6 +5,7 @@ import ( "gofiber.study.skcks.cn/common/errorx" "gofiber.study.skcks.cn/common/response" "gofiber.study.skcks.cn/controller/types" + "gofiber.study.skcks.cn/middleware" "gofiber.study.skcks.cn/services/user" ) @@ -18,7 +19,7 @@ func (c *Controller) GetRouter() fiber.Router { func NewController(app *fiber.App) *Controller { return &Controller{ - Controller: types.NewController(app, "/user"), + Controller: types.NewController(app, "/user", middleware.CasbinMiddleWare), } } diff --git a/global/jwt.go b/global/jwt.go index 34f1eb8..ed3eaf2 100644 --- a/global/jwt.go +++ b/global/jwt.go @@ -13,8 +13,9 @@ var JwtConfig *config.JwtConfig // @Param id body string true // @Param account body string true type UserClaims struct { - Id string `json:"id"` - Account string `json:"account"` + Id string `json:"id"` + Account string `json:"account"` + Identify string `json:"identify"` jwt.RegisteredClaims `swaggerignore:"true"` } diff --git a/middleware/casbin.go b/middleware/casbin.go index 549bec4..ff7ea95 100644 --- a/middleware/casbin.go +++ b/middleware/casbin.go @@ -5,6 +5,8 @@ import ( "github.com/casbin/casbin/v2/model" "github.com/casbin/xorm-adapter/v2" _ "github.com/go-sql-driver/mysql" + "github.com/gofiber/fiber/v2" + "gofiber.study.skcks.cn/common/errorx" "gofiber.study.skcks.cn/common/logger" "gofiber.study.skcks.cn/common/utils" "gofiber.study.skcks.cn/global" @@ -12,6 +14,11 @@ import ( "xorm.io/xorm" ) +const ( + DefaultSystem = "WEB" + CasBinSeparator = "::" +) + func NewCasbin(engine *xorm.EngineGroup) { casbinModels := make([]models.CasbinModel, 0) err := engine.Find(&casbinModels) @@ -54,3 +61,38 @@ func initRootGroupPermission(e *casbin.Enforcer) { // role::root 组 继承 root 的权限 _, _ = e.AddGroupingPolicy("role::root", "user::root") } + +func CasbinMiddleWare(c *fiber.Ctx) error { + headers := c.GetReqHeaders() + logger.Log.Debugf("headers %v", headers) + token := headers["Token"] + logger.Log.Debugf("token %s", token) + + userClaim, err := global.ParseToken(token) + if err != nil { + logger.Log.Errorf("[CasbinMiddleWare] 认证错误 %s", err) + return c.JSON(errorx.NewErrorWithCode(fiber.StatusUnauthorized, "认证失效, 请重新登录")) + } + + identify := userClaim.Identify + system := headers["System"] + if len(system) == 0 { + system = DefaultSystem + } + + uri := string(c.Request().URI().Path()) + method := c.Method() + + logger.Log.Debugf("identify=>%s, system=>%s, uri=>%s, method=>%s", identify, system, uri, method) + + hasPermission, err := global.Enforcer.Enforce(identify, system, uri, method) + if err != nil { + return c.JSON(errorx.ParseError(err)) + } + + if hasPermission { + return c.Next() + } else { + return c.JSON(errorx.NewErrorWithCode(fiber.StatusForbidden, "无权访问")) + } +} diff --git a/services/auth/auth.go b/services/auth/auth.go index 75a1950..0b5c06e 100644 --- a/services/auth/auth.go +++ b/services/auth/auth.go @@ -8,6 +8,7 @@ import ( "gofiber.study.skcks.cn/common/logger" "gofiber.study.skcks.cn/common/utils" "gofiber.study.skcks.cn/global" + "gofiber.study.skcks.cn/middleware" "gofiber.study.skcks.cn/model/dto" "gofiber.study.skcks.cn/model/generic/models" "gofiber.study.skcks.cn/model/vo" @@ -82,6 +83,14 @@ func (s *Service) fromRefreshTokenGetUserCache(refreshTokenB64 string) (user *mo return } +func (s *Service) getUserClaims(user *models.User) global.UserClaims { + return global.UserClaims{ + Id: user.Id, + Account: user.Account, + Identify: "user" + middleware.CasBinSeparator + user.Account, + } +} + func (s *Service) Login(login *dto.Login) (result *vo.Login, err error) { err = global.ValidateStruct(login) if err != nil { @@ -98,10 +107,7 @@ func (s *Service) Login(login *dto.Login) (result *vo.Login, err error) { return nil, Failed } - token, err := global.GetToken(global.UserClaims{ - Id: user.Id, - Account: user.Account, - }) + token, err := global.GetToken(s.getUserClaims(user)) if err != nil { return @@ -123,10 +129,7 @@ func (s *Service) RefreshToken(refreshTokenB64 string) (result *vo.Login, err er return } - token, err := global.GetToken(global.UserClaims{ - Id: user.Id, - Account: user.Account, - }) + token, err := global.GetToken(s.getUserClaims(user)) if err != nil { return