diff --git a/services/auth/auth.go b/services/auth/auth.go index 93c95c8..75a1950 100644 --- a/services/auth/auth.go +++ b/services/auth/auth.go @@ -2,6 +2,7 @@ package auth import ( "context" + "encoding/base64" "errors" "github.com/goccy/go-json" "gofiber.study.skcks.cn/common/logger" @@ -36,15 +37,48 @@ func InitService() { // generateAndSaveRefreshToken // // 生成并保存 refreshToken 刷新令牌 -func (s *Service) generateAndSaveRefreshToken(user *models.User) (refreshToken string, err error) { - refreshToken, err = global.GetNanoId() +func (s *Service) generateAndSaveRefreshToken(user *models.User) (refreshTokenB64 string, err error) { + refreshTokenB64, err = global.GetNanoId() if err != nil { return } + refreshTokenB64 = user.Id + Separator + refreshTokenB64 + expire := time.Duration(global.Config.Jwt.Expire*2) * time.Second ctx := context.Background() - global.Redis.Set(ctx, RefreshTokenPrefix+refreshToken, utils.Json(user), expire) + global.Redis.Set(ctx, RefreshTokenPrefix+refreshTokenB64, utils.Json(user), expire) + + // base64 加密 + return base64.StdEncoding.EncodeToString([]byte(refreshTokenB64)), err +} + +func (s *Service) fromRefreshTokenGetUserCache(refreshTokenB64 string) (user *models.User, err error) { + refreshTokenBytes, err := base64.StdEncoding.DecodeString(refreshTokenB64) + if err != nil { + return nil, InvalidRefreshToken + } + + refreshToken := RefreshTokenPrefix + string(refreshTokenBytes) + + ctx := context.Background() + data, err := global.Redis.Get(ctx, refreshToken).Result() + if err != nil { + return nil, InvalidRefreshToken + } + global.Redis.Del(ctx, refreshToken) + cache := &models.User{} + err = json.Unmarshal([]byte(data), cache) + if err != nil { + return nil, InvalidRefreshToken + } + + user = &models.User{Id: cache.Id, Account: cache.Account} + exist, err := global.DataSources.Get(user) + if !exist { + logger.Log.Infof("未能从 %s 找到用户信息", refreshToken) + return nil, InvalidRefreshToken + } return } @@ -83,28 +117,10 @@ func (s *Service) Login(login *dto.Login) (result *vo.Login, err error) { }, err } -func (s *Service) RefreshToken(refreshToken string) (result *vo.Login, err error) { - refreshToken = RefreshTokenPrefix + refreshToken - - ctx := context.Background() - data, err := global.Redis.Get(ctx, refreshToken).Result() +func (s *Service) RefreshToken(refreshTokenB64 string) (result *vo.Login, err error) { + user, err := s.fromRefreshTokenGetUserCache(refreshTokenB64) if err != nil { - return nil, InvalidRefreshToken - } - - global.Redis.Del(ctx, refreshToken) - - cache := &models.User{} - err = json.Unmarshal([]byte(data), cache) - if err != nil { - return nil, InvalidRefreshToken - } - - user := &models.User{Id: cache.Id, Account: cache.Account} - exist, err := global.DataSources.Get(user) - if !exist { - logger.Log.Infof("未能从 %s 找到用户信息", refreshToken) - return nil, InvalidRefreshToken + return } token, err := global.GetToken(global.UserClaims{ @@ -116,7 +132,7 @@ func (s *Service) RefreshToken(refreshToken string) (result *vo.Login, err error return } - refreshToken, err = s.generateAndSaveRefreshToken(user) + refreshToken, err := s.generateAndSaveRefreshToken(user) return &vo.Login{ Token: token,