casbin认证中间件

This commit is contained in:
Shikong 2022-10-22 18:18:45 +08:00
parent 26baabe76f
commit 26da36367c
5 changed files with 65 additions and 13 deletions

View File

@ -12,8 +12,13 @@ type Controller struct {
Router fiber.Router
}
func NewController(app *fiber.App, group string) *Controller {
func NewController(app *fiber.App, group string, middleware ...interface{}) *Controller {
router := app.Group(group)
for _, m := range middleware {
router.Use(m)
}
return &Controller{
Router: app.Group(group),
Router: router,
}
}

View File

@ -5,6 +5,7 @@ import (
"gofiber.study.skcks.cn/common/errorx"
"gofiber.study.skcks.cn/common/response"
"gofiber.study.skcks.cn/controller/types"
"gofiber.study.skcks.cn/middleware"
"gofiber.study.skcks.cn/services/user"
)
@ -18,7 +19,7 @@ func (c *Controller) GetRouter() fiber.Router {
func NewController(app *fiber.App) *Controller {
return &Controller{
Controller: types.NewController(app, "/user"),
Controller: types.NewController(app, "/user", middleware.CasbinMiddleWare),
}
}

View File

@ -13,8 +13,9 @@ var JwtConfig *config.JwtConfig
// @Param id body string true
// @Param account body string true
type UserClaims struct {
Id string `json:"id"`
Account string `json:"account"`
Id string `json:"id"`
Account string `json:"account"`
Identify string `json:"identify"`
jwt.RegisteredClaims `swaggerignore:"true"`
}

View File

@ -5,6 +5,8 @@ import (
"github.com/casbin/casbin/v2/model"
"github.com/casbin/xorm-adapter/v2"
_ "github.com/go-sql-driver/mysql"
"github.com/gofiber/fiber/v2"
"gofiber.study.skcks.cn/common/errorx"
"gofiber.study.skcks.cn/common/logger"
"gofiber.study.skcks.cn/common/utils"
"gofiber.study.skcks.cn/global"
@ -12,6 +14,11 @@ import (
"xorm.io/xorm"
)
const (
DefaultSystem = "WEB"
CasBinSeparator = "::"
)
func NewCasbin(engine *xorm.EngineGroup) {
casbinModels := make([]models.CasbinModel, 0)
err := engine.Find(&casbinModels)
@ -54,3 +61,38 @@ func initRootGroupPermission(e *casbin.Enforcer) {
// role::root 组 继承 root 的权限
_, _ = e.AddGroupingPolicy("role::root", "user::root")
}
func CasbinMiddleWare(c *fiber.Ctx) error {
headers := c.GetReqHeaders()
logger.Log.Debugf("headers %v", headers)
token := headers["Token"]
logger.Log.Debugf("token %s", token)
userClaim, err := global.ParseToken(token)
if err != nil {
logger.Log.Errorf("[CasbinMiddleWare] 认证错误 %s", err)
return c.JSON(errorx.NewErrorWithCode(fiber.StatusUnauthorized, "认证失效, 请重新登录"))
}
identify := userClaim.Identify
system := headers["System"]
if len(system) == 0 {
system = DefaultSystem
}
uri := string(c.Request().URI().Path())
method := c.Method()
logger.Log.Debugf("identify=>%s, system=>%s, uri=>%s, method=>%s", identify, system, uri, method)
hasPermission, err := global.Enforcer.Enforce(identify, system, uri, method)
if err != nil {
return c.JSON(errorx.ParseError(err))
}
if hasPermission {
return c.Next()
} else {
return c.JSON(errorx.NewErrorWithCode(fiber.StatusForbidden, "无权访问"))
}
}

View File

@ -8,6 +8,7 @@ import (
"gofiber.study.skcks.cn/common/logger"
"gofiber.study.skcks.cn/common/utils"
"gofiber.study.skcks.cn/global"
"gofiber.study.skcks.cn/middleware"
"gofiber.study.skcks.cn/model/dto"
"gofiber.study.skcks.cn/model/generic/models"
"gofiber.study.skcks.cn/model/vo"
@ -82,6 +83,14 @@ func (s *Service) fromRefreshTokenGetUserCache(refreshTokenB64 string) (user *mo
return
}
func (s *Service) getUserClaims(user *models.User) global.UserClaims {
return global.UserClaims{
Id: user.Id,
Account: user.Account,
Identify: "user" + middleware.CasBinSeparator + user.Account,
}
}
func (s *Service) Login(login *dto.Login) (result *vo.Login, err error) {
err = global.ValidateStruct(login)
if err != nil {
@ -98,10 +107,7 @@ func (s *Service) Login(login *dto.Login) (result *vo.Login, err error) {
return nil, Failed
}
token, err := global.GetToken(global.UserClaims{
Id: user.Id,
Account: user.Account,
})
token, err := global.GetToken(s.getUserClaims(user))
if err != nil {
return
@ -123,10 +129,7 @@ func (s *Service) RefreshToken(refreshTokenB64 string) (result *vo.Login, err er
return
}
token, err := global.GetToken(global.UserClaims{
Id: user.Id,
Account: user.Account,
})
token, err := global.GetToken(s.getUserClaims(user))
if err != nil {
return