package middleware import ( "github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2/model" "github.com/casbin/xorm-adapter/v2" _ "github.com/go-sql-driver/mysql" "github.com/gofiber/fiber/v2" "gofiber.study.skcks.cn/common/errorx" "gofiber.study.skcks.cn/common/logger" "gofiber.study.skcks.cn/common/utils" "gofiber.study.skcks.cn/global" "gofiber.study.skcks.cn/model/casbin_model/models" models2 "gofiber.study.skcks.cn/model/generic/models" "xorm.io/xorm" ) const ( DefaultSystem = "WEB" CasBinSeparator = "::" ) func NewCasbin(app *fiber.App, engine *xorm.EngineGroup) { casbinModels := make([]models.CasbinModel, 0) err := engine.Find(&casbinModels) if err != nil { logger.Log.Fatalf("[x] [casbin] model 加载失败: %s", err) } m := model.NewModel() for _, casbinModel := range casbinModels { m.AddDef(casbinModel.Type, casbinModel.Name, casbinModel.Rule) } // CasBin Xorm 适配器 // 若 不指定库名 且 不存在数据库 则会自动创建 casbin 数据库 和 casbin_rule 表 // 若 指定库名 三个参数 为 true 且 必须保证数据库已经存在 如果 不存在 casbin_rule 表 则会自动创建 //a, _ := xormadapter.NewAdapter("mysql", "root:12341234@tcp(10.10.10.100:3306)/") a, _ := xormadapter.NewAdapter("mysql", engine.DataSourceName(), true) e, _ := casbin.NewEnforcer(m, a) // 启用自动保存选项。 e.EnableAutoSave(true) err = e.LoadPolicy() if err != nil { logger.Log.Fatalf("[x] [casbin] policy 加载失败: %s", err) } // 初始化 role::root 组权限 initRootGroupPermission(app, e) global.Enforcer = e utils.MainAppExec(func() { logger.Log.Infof("[√] [casbin] 加载成功") }) } func initRootGroupPermission(app *fiber.App, e *casbin.Enforcer) { // 添加/修改 组策略 // role::root 组 继承 root 的权限 _, _ = e.AddGroupingPolicy("role::root", "user::root") for _, routes := range app.Stack() { for _, route := range routes { _, _ = e.AddPolicy("role::root", DefaultSystem, route.Path, route.Method) } } } func CasbinMiddleWare(c *fiber.Ctx) error { headers := c.GetReqHeaders() logger.Log.Debugf("headers %v", headers) token := headers["Token"] logger.Log.Debugf("token %s", token) userClaim, err := global.ParseToken(token) if err != nil { logger.Log.Errorf("[CasbinMiddleWare] 认证错误 %s", err) return c.JSON(errorx.NewErrorWithCode(fiber.StatusUnauthorized, "认证失效, 请重新登录")) } user := &models2.User{ Id: userClaim.Id, Account: userClaim.Account, } exist, err := global.DataSources.Get(user) if !exist { logger.Log.Errorf("[CasbinMiddleWare] 用户不存在 %s", err) return c.JSON(errorx.NewErrorWithCode(fiber.StatusUnauthorized, "认证失效, 请重新登录")) } if !user.Active { logger.Log.Errorf("[CasbinMiddleWare] 账号未启用 id: %s, account: %s", user.Id, user.Account) return c.JSON(errorx.NewErrorWithCode(fiber.StatusUnauthorized, "该账号暂未启用")) } identify := userClaim.Identify system := headers["System"] if len(system) == 0 { system = DefaultSystem } uri := string(c.Request().URI().Path()) method := c.Method() logger.Log.Debugf("[CasbinMiddleWare] identify=>%s, system=>%s, uri=>%s, method=>%s", identify, system, uri, method) hasPermission, err := global.Enforcer.Enforce(identify, system, uri, method) if err != nil { return c.JSON(errorx.ParseError(err)) } if hasPermission { return c.Next() } else { return c.JSON(errorx.NewErrorWithCode(fiber.StatusForbidden, "无权访问")) } }