// Package waf provides a small web-application-firewall service: a database
// backed IP blacklist check plus a Redis backed per-IP rate limiter.
package waf

import (
	"context"
	"errors"
	"time"

	"github.com/go-redis/redis/v8"
	"gofiber.study.skcks.cn/common/config"
	"gofiber.study.skcks.cn/common/logger"
	"gofiber.study.skcks.cn/global"
	"gofiber.study.skcks.cn/model/generic/models"
	"xorm.io/xorm"
)

// Service is the package-level Waf instance; it is nil until InitService runs.
var Service *Waf

const (
	// StoreName is the prefix of every Redis key written by this package.
	StoreName = "waf"
	// Separator joins the segments of a Redis key.
	Separator = ":"
)

const (
	// accessKeyTTL is how long a per-IP access list survives after its last write.
	accessKeyTTL = 300 * time.Second
	// accessWindowSeconds is the length, in seconds, of the rate-limit window:
	// an IP is rejected when its oldest retained access is younger than this.
	accessWindowSeconds = 60
)

// Waf bundles the dependencies used to decide whether an IP may access the
// application.
type Waf struct {
	db     *xorm.EngineGroup // queried for blacklist records (models.Waf)
	store  *redis.Client     // holds the per-IP lists of access timestamps
	config config.Waf        // supplies RateLimit
}

// InitService builds the package-level Service from the shared data sources,
// Redis client and configuration found in the global package.
func InitService() {
	Service = &Waf{
		db:     global.DataSources,
		store:  global.Redis,
		config: global.Config.Waf,
	}

	// Disabled startup diagnostic (kept verbatim so it can be re-enabled):
	// it read the whole StoreName hash from Redis and logged it.
	//
	//ctx := context.Background()
	//table, err := Service.store.HGetAll(ctx, StoreName).Result()
	//utils.MainAppExec(func() {
	//	if err != nil {
	//		logger.Log.Errorf("[waf] 初始化出错 %s", err)
	//	}
	//
	//	logger.Log.Debugf("[waf] \n%v", table)
	//})
}

// Access reports whether a request from ip is allowed.
//
// The decision is made in two steps:
//
//  1. Blacklist: a models.Waf record is looked up by Ip; if one exists the
//     request is rejected.
//  2. Rate limit: Unix timestamps of accepted requests are kept in the Redis
//     list "waf:access:<ip>", newest first. While the list holds fewer than
//     config.RateLimit entries the request is accepted. Once it is full, the
//     request is accepted only if the oldest retained entry is at least
//     accessWindowSeconds old.
//
// Any database or Redis read error is logged and the request is rejected.
// A failure to record an accepted request is logged but does not reject it.
//
// NOTE(review): with an empty ip the blacklist lookup presumably carries no
// condition and may match any stored record — confirm callers never pass "".
func (w *Waf) Access(ip string) bool {
	banned, err := w.db.Get(&models.Waf{Ip: ip})
	if err != nil {
		logger.Log.Errorf("[waf] access 出错 %s", err)
		return false
	}
	if banned {
		logger.Log.Infof("[waf] 阻止黑名单 ip:%s 访问", ip)
		return false
	}

	ctx := context.Background()
	key := StoreName + Separator + "access" + Separator + ip

	num, err := w.store.LLen(ctx, key).Result()
	if err != nil {
		logger.Log.Errorf("[waf] access 出错 %s", err)
		return false
	}

	// Read the clock once so the comparison and the stored value agree.
	now := time.Now().Unix()

	if num >= w.config.RateLimit {
		// The list is full: the tail (index -1) is the oldest retained access.
		oldest, err := w.store.LIndex(ctx, key, -1).Int64()
		switch {
		case errors.Is(err, redis.Nil):
			// The key expired between LLen and LIndex, so the window is empty.
		case err != nil:
			logger.Log.Errorf("[waf] access 出错 %s", err)
			return false
		case now-oldest < accessWindowSeconds:
			logger.Log.Warnf("[waf] ip:%s 访问频率超过限制 %d", ip, w.config.RateLimit)
			return false
		}
	}

	w.recordAccess(ctx, key, now)
	return true
}

// recordAccess pushes now onto the access list at key, trims the list to
// config.RateLimit entries and refreshes its TTL, all in one MULTI/EXEC
// transaction so the list can never be left without an expiry.
func (w *Waf) recordAccess(ctx context.Context, key string, now int64) {
	limit := w.config.RateLimit
	_, err := w.store.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.LPush(ctx, key, now)
		// A no-op while the list is under the limit; otherwise drops the oldest entries.
		pipe.LTrim(ctx, key, 0, limit-1)
		pipe.Expire(ctx, key, accessKeyTTL)
		return nil
	})
	if err != nil {
		logger.Log.Errorf("[waf] access 出错 %s", err)
	}
}