diff --git a/go.mod b/go.mod index 379257d..1fe5e15 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module github.com/ecodeclub/ginx go 1.21.1 require ( + github.com/ecodeclub/ekit v0.0.8 github.com/gin-gonic/gin v1.9.1 + github.com/golang-jwt/jwt/v5 v5.0.0 github.com/redis/go-redis/v9 v9.1.0 github.com/stretchr/testify v1.8.3 go.uber.org/mock v0.3.0 diff --git a/go.sum b/go.sum index 6376821..1d68aa8 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/ecodeclub/ekit v0.0.8 h1:861Aot0GvD5ueREEYDVYc1oIhDuFyg6MTxIyiOa4Pvw= +github.com/ecodeclub/ekit v0.0.8/go.mod h1:OqTojKeKFTxeeAAUwNIPKu339SRkX6KAuoK/8A5BCEs= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -31,6 +33,8 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -40,6 +44,8 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= @@ -49,6 +55,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -90,8 +98,9 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/middlewares/jwt/claims.go b/middlewares/jwt/claims.go new file mode 100644 index 0000000..44ae07a --- /dev/null +++ b/middlewares/jwt/claims.go @@ -0,0 +1,69 @@ +package jwt + +import ( + "time" + + "github.com/ecodeclub/ekit/bean/option" + "github.com/golang-jwt/jwt/v5" +) + +type RegisteredClaims[T any] struct { + Data T `json:"data"` + jwt.RegisteredClaims +} + +type Options struct { + Expire time.Duration // 有效期 + EncryptionKey string // 加密密钥 + DecryptKey string // 解密密钥 + Method jwt.SigningMethod // 签名方式 + Issuer string // 签发人 + genIDFn func() string // 生成 JWT ID (jti) 的函数 +} + +// NewOptions 定义一个 JWT 配置. +// DecryptKey: 默认与 EncryptionKey 相同. +// Method: 默认使用 jwt.SigningMethodHS256 签名方式. +func NewOptions(expire time.Duration, encryptionKey string, + opts ...option.Option[Options]) *Options { + dOpts := Options{ + Expire: expire, + EncryptionKey: encryptionKey, + DecryptKey: encryptionKey, + Method: jwt.SigningMethodHS256, + genIDFn: func() string { return "" }, + } + + option.Apply[Options](&dOpts, opts...) + + return &dOpts +} + +// WithDecryptKey 设置解密密钥. +func WithDecryptKey(decryptKey string) option.Option[Options] { + return func(o *Options) { + o.DecryptKey = decryptKey + } +} + +// WithMethod 设置 JWT 的签名方法. +func WithMethod(method jwt.SigningMethod) option.Option[Options] { + return func(o *Options) { + o.Method = method + } +} + +// WithIssuer 设置签发人. +func WithIssuer(issuer string) option.Option[Options] { + return func(o *Options) { + o.Issuer = issuer + } +} + +// WithGenIDFunc 设置生成 JWT ID 的函数. +// 可以设置成 WithGenIDFunc(uuid.NewString). +func WithGenIDFunc(fn func() string) option.Option[Options] { + return func(o *Options) { + o.genIDFn = fn + } +} diff --git a/middlewares/jwt/claims_test.go b/middlewares/jwt/claims_test.go new file mode 100644 index 0000000..7fbccfd --- /dev/null +++ b/middlewares/jwt/claims_test.go @@ -0,0 +1,184 @@ +package jwt + +import ( + "testing" + "time" + + "github.com/ecodeclub/ekit/bean/option" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func TestNewOptions(t *testing.T) { + var genIDFn func() string + tests := []struct { + name string + expire time.Duration + encryptionKey string + want *Options + }{ + { + name: "normal", + expire: 10 * time.Minute, + encryptionKey: "sign key", + want: &Options{ + Expire: 10 * time.Minute, + EncryptionKey: "sign key", + DecryptKey: "sign key", + Method: jwt.SigningMethodHS256, + genIDFn: genIDFn, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewOptions(tt.expire, tt.encryptionKey) + got.genIDFn = genIDFn + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithDecryptKey(t *testing.T) { + tests := []struct { + name string + fn func() option.Option[Options] + want string + }{ + { + name: "normal", + fn: func() option.Option[Options] { + return nil + }, + want: encryptionKey, + }, + { + name: "set_another_key", + fn: func() option.Option[Options] { + return WithDecryptKey("another sign key") + }, + want: "another sign key", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewOptions(defaultExpire, encryptionKey). + DecryptKey + } else { + got = NewOptions(defaultExpire, encryptionKey, + tt.fn()).DecryptKey + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithMethod(t *testing.T) { + tests := []struct { + name string + fn func() option.Option[Options] + want jwt.SigningMethod + }{ + { + name: "normal", + fn: func() option.Option[Options] { + return nil + }, + want: jwt.SigningMethodHS256, + }, + { + name: "set_another_method", + fn: func() option.Option[Options] { + return WithMethod(jwt.SigningMethodHS384) + }, + want: jwt.SigningMethodHS384, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got jwt.SigningMethod + if tt.fn() == nil { + got = NewOptions(defaultExpire, encryptionKey). + Method + } else { + got = NewOptions(defaultExpire, encryptionKey, + tt.fn()).Method + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithIssuer(t *testing.T) { + tests := []struct { + name string + fn func() option.Option[Options] + want string + }{ + { + name: "normal", + fn: func() option.Option[Options] { + return nil + }, + }, + { + name: "set_another_issuer", + fn: func() option.Option[Options] { + return WithIssuer("foo") + }, + want: "foo", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewOptions(defaultExpire, encryptionKey). + Issuer + } else { + got = NewOptions(defaultExpire, encryptionKey, + tt.fn()).Issuer + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithGenIDFunc(t *testing.T) { + tests := []struct { + name string + fn func() option.Option[Options] + want string + }{ + { + name: "normal", + fn: func() option.Option[Options] { + return nil + }, + }, + { + name: "set_another_gen_id_func", + fn: func() option.Option[Options] { + return WithGenIDFunc(func() string { + return "unique id" + }) + }, + want: "unique id", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewOptions(defaultExpire, encryptionKey). + genIDFn() + } else { + got = NewOptions(defaultExpire, encryptionKey, + tt.fn()).genIDFn() + } + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/middlewares/jwt/jwt.go b/middlewares/jwt/jwt.go new file mode 100644 index 0000000..9bf6272 --- /dev/null +++ b/middlewares/jwt/jwt.go @@ -0,0 +1,292 @@ +package jwt + +import ( + "errors" + "fmt" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/ecodeclub/ekit/bean/option" + "github.com/ecodeclub/ekit/set" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" +) + +const bearerPrefix = "Bearer" + +var ( + ErrEmptyRefreshOpts = errors.New("refreshJWTOptions are nil") +) + +type Management[T any] struct { + ignorePath func(path string) bool // Middleware 方法中忽略认证的路径 + allowTokenHeader string // 认证的请求头(存放 token 的请求头 key) + exposeAccessHeader string // 暴露到外部的资源请求头 + exposeRefreshHeader string // 暴露到外部的刷新请求头 + + accessJWTOptions *Options // 资源 token 选项 + refreshJWTOptions *Options // 刷新 token 选项 + rotateRefreshToken bool // 轮换刷新令牌 + + nowFunc func() time.Time // 控制 jwt 的时间 +} + +// NewManagement 定义一个 Management. +// ignorePath: 默认使用 func(path string) bool { return false } 也就是全部不忽略. +// allowTokenHeader: 默认使用 authorization 为认证请求头. +// exposeAccessHeader: 默认使用 x-access-token 为暴露外部的资源请求头. +// exposeRefreshHeader: 默认使用 x-refresh-token 为暴露外部的刷新请求头. +// refreshJWTOptions: 默认使用 nil 为刷新 token 的配置, +// 如要使用 refresh 相关功能则需要使用 WithRefreshJWTOptions 添加相关配置. +// rotateRefreshToken: 默认不轮换刷新令牌. +// 该配置需要设置 refreshJWTOptions 才有效. +func NewManagement[T any](accessJWTOptions *Options, + opts ...option.Option[Management[T]]) *Management[T] { + + if accessJWTOptions == nil { + panic("accessJWTOptions 不允许为 nil") + } + dOpts := defaultManagementOptions[T]() + dOpts.accessJWTOptions = accessJWTOptions + option.Apply[Management[T]](&dOpts, opts...) + + return &dOpts +} + +func defaultManagementOptions[T any]() Management[T] { + return Management[T]{ + ignorePath: func(path string) bool { return false }, + allowTokenHeader: "authorization", + exposeAccessHeader: "x-access-token", + exposeRefreshHeader: "x-refresh-token", + rotateRefreshToken: false, + nowFunc: time.Now, + } +} + +// WithIgnorePath 设置忽略资源令牌认证的路径. +func WithIgnorePath[T any](fn func(path string) bool) option.Option[Management[T]] { + return func(m *Management[T]) { + m.ignorePath = fn + } +} + +// StaticIgnorePaths 设置静态忽略的路径. +func StaticIgnorePaths(paths ...string) func(path string) bool { + s := set.NewMapSet[string](len(paths)) + for _, path := range paths { + s.Add(path) + } + return func(path string) bool { + if s.Exist(path) { + return true + } + return false + } +} + +// WithAllowTokenHeader 设置允许 token 的请求头. +func WithAllowTokenHeader[T any](header string) option.Option[Management[T]] { + return func(m *Management[T]) { + m.allowTokenHeader = header + } +} + +// WithExposeAccessHeader 设置公开资源令牌的请求头. +func WithExposeAccessHeader[T any](header string) option.Option[Management[T]] { + return func(m *Management[T]) { + m.exposeAccessHeader = header + } +} + +// WithExposeRefreshHeader 设置公开刷新令牌的请求头. +func WithExposeRefreshHeader[T any](header string) option.Option[Management[T]] { + return func(m *Management[T]) { + m.exposeRefreshHeader = header + } +} + +// WithRefreshJWTOptions 设置刷新令牌相关的配置. +func WithRefreshJWTOptions[T any](refreshOpts *Options) option.Option[Management[T]] { + return func(m *Management[T]) { + m.refreshJWTOptions = refreshOpts + } +} + +// WithRotateRefreshToken 设置轮换刷新令牌. +func WithRotateRefreshToken[T any](isRotate bool) option.Option[Management[T]] { + return func(m *Management[T]) { + m.rotateRefreshToken = isRotate + } +} + +// WithNowFunc 设置当前时间. +// 一般用于测试固定 jwt. +func WithNowFunc[T any](nowFunc func() time.Time) option.Option[Management[T]] { + return func(m *Management[T]) { + m.nowFunc = nowFunc + } +} + +// Refresh 刷新 token 的 gin.HandlerFunc. +func (m *Management[T]) Refresh(ctx *gin.Context) { + if m.refreshJWTOptions == nil { + slog.Error("refreshJWTOptions 为 nil, 请使用 WithRefreshJWTOptions 设置 refresh 相关的配置") + ctx.Status(http.StatusInternalServerError) + return + } + + tokenStr := m.extractTokenString(ctx) + clm, err := m.VerifyRefreshToken(tokenStr, + jwt.WithTimeFunc(m.nowFunc)) + if err != nil { + slog.Debug("refresh token verification failed") + ctx.Status(http.StatusUnauthorized) + return + } + accessToken, err := m.GenerateAccessToken(clm.Data) + if err != nil { + slog.Error("failed to generate access token") + ctx.Status(http.StatusInternalServerError) + return + } + ctx.Header(m.exposeAccessHeader, accessToken) + + // 轮换刷新令牌 + if m.rotateRefreshToken { + refreshToken, err := m.GenerateRefreshToken(clm.Data) + if err != nil { + slog.Error("failed to generate refresh token") + ctx.Status(http.StatusInternalServerError) + return + } + ctx.Header(m.exposeRefreshHeader, refreshToken) + } + ctx.Status(http.StatusNoContent) +} + +// Middleware 登录认证的中间件. +func (m *Management[T]) Middleware() gin.HandlerFunc { + return func(ctx *gin.Context) { + // 不需要校验 + if m.ignorePath(ctx.Request.URL.Path) { + return + } + + // 提取 token + tokenStr := m.extractTokenString(ctx) + if tokenStr == "" { + slog.Debug("failed to extract token") + ctx.AbortWithStatus(http.StatusUnauthorized) + return + } + + // 校验 token + clm, err := m.VerifyAccessToken(tokenStr, + jwt.WithTimeFunc(m.nowFunc)) + if err != nil { + slog.Debug("access token verification failed") + ctx.AbortWithStatus(http.StatusUnauthorized) + return + } + + // 设置 claims + m.SetClaims(ctx, clm) + } +} + +// extractTokenString 提取 token 字符串. +func (m *Management[T]) extractTokenString(ctx *gin.Context) string { + authCode := ctx.GetHeader(m.allowTokenHeader) + if authCode == "" { + return "" + } + var b strings.Builder + b.WriteString(bearerPrefix) + b.WriteString(" ") + prefix := b.String() + if strings.HasPrefix(authCode, prefix) { + return authCode[len(prefix):] + } + return "" +} + +// GenerateAccessToken 生成资源 token. +func (m *Management[T]) GenerateAccessToken(data T) (string, error) { + nowTime := m.nowFunc() + claims := RegisteredClaims[T]{ + Data: data, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: m.accessJWTOptions.Issuer, + ExpiresAt: jwt.NewNumericDate(nowTime.Add(m.accessJWTOptions.Expire)), + IssuedAt: jwt.NewNumericDate(nowTime), + ID: m.accessJWTOptions.genIDFn(), + }, + } + + token := jwt.NewWithClaims(m.accessJWTOptions.Method, claims) + return token.SignedString([]byte(m.accessJWTOptions.EncryptionKey)) +} + +// VerifyAccessToken 校验资源 token. +func (m *Management[T]) VerifyAccessToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) { + t, err := jwt.ParseWithClaims(token, &RegisteredClaims[T]{}, + func(*jwt.Token) (interface{}, error) { + return []byte(m.accessJWTOptions.DecryptKey), nil + }, + opts..., + ) + if err != nil || !t.Valid { + return RegisteredClaims[T]{}, fmt.Errorf("验证失败: %v", err) + } + clm, _ := t.Claims.(*RegisteredClaims[T]) + return *clm, nil +} + +// GenerateRefreshToken 生成刷新 token. +// 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. +func (m *Management[T]) GenerateRefreshToken(data T) (string, error) { + if m.refreshJWTOptions == nil { + return "", ErrEmptyRefreshOpts + } + + nowTime := m.nowFunc() + claims := RegisteredClaims[T]{ + Data: data, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: m.refreshJWTOptions.Issuer, + ExpiresAt: jwt.NewNumericDate(nowTime.Add(m.refreshJWTOptions.Expire)), + IssuedAt: jwt.NewNumericDate(nowTime), + ID: m.refreshJWTOptions.genIDFn(), + }, + } + + token := jwt.NewWithClaims(m.refreshJWTOptions.Method, claims) + return token.SignedString([]byte(m.refreshJWTOptions.EncryptionKey)) +} + +// VerifyRefreshToken 校验刷新 token. +// 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. +func (m *Management[T]) VerifyRefreshToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) { + if m.refreshJWTOptions == nil { + return RegisteredClaims[T]{}, ErrEmptyRefreshOpts + } + t, err := jwt.ParseWithClaims(token, &RegisteredClaims[T]{}, + func(*jwt.Token) (interface{}, error) { + return []byte(m.refreshJWTOptions.DecryptKey), nil + }, + opts..., + ) + if err != nil || !t.Valid { + return RegisteredClaims[T]{}, fmt.Errorf("验证失败: %v", err) + } + clm, _ := t.Claims.(*RegisteredClaims[T]) + return *clm, nil +} + +// SetClaims 设置 claims 到 key=`claims` 的 gin.Context 中. +func (m *Management[T]) SetClaims(ctx *gin.Context, claims RegisteredClaims[T]) { + ctx.Set("claims", claims) +} diff --git a/middlewares/jwt/jwt_test.go b/middlewares/jwt/jwt_test.go new file mode 100644 index 0000000..9f01891 --- /dev/null +++ b/middlewares/jwt/jwt_test.go @@ -0,0 +1,975 @@ +package jwt + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/ecodeclub/ekit/bean/option" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +type data struct { + Foo string `json:"foo"` +} + +var ( + defaultExpire = 10 * time.Minute + defaultClaims = RegisteredClaims[data]{ + Data: data{Foo: "1"}, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(nowTime.Add(defaultExpire)), + IssuedAt: jwt.NewNumericDate(nowTime), + }, + } + encryptionKey = "sign key" + nowTime = time.UnixMilli(1695571200000) + defaultOption = NewOptions(defaultExpire, encryptionKey) + defaultIgnorePaths = func(path string) bool { + ignorePaths := []string{"/login", "/signup"} + for _, ignorePath := range ignorePaths { + if path == ignorePath { + return true + } + } + return false + } + defaultManagement = NewManagement[data](defaultOption, + WithNowFunc[data](func() time.Time { + return nowTime + }), + ) +) + +func TestManagement_Middleware(t *testing.T) { + type testCase[T any] struct { + name string + m *Management[T] + reqBuilder func(t *testing.T) *http.Request + wantCode int + } + tests := []testCase[data]{ + { + // 验证失败 + name: "verify_failed", + m: NewManagement[data](defaultOption, + WithIgnorePath[data](StaticIgnorePaths("/login"))), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ") + return req + }, + wantCode: http.StatusUnauthorized, + }, + { + // 提取 token 失败 + name: "extract_token_failed", + m: NewManagement[data](defaultOption, + WithIgnorePath[data](StaticIgnorePaths("/login"))), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer ") + return req + }, + wantCode: http.StatusUnauthorized, + }, + { + // 无需认证直接通过 + name: "pass_without_authentication", + m: NewManagement[data](defaultOption, + WithIgnorePath[data](StaticIgnorePaths("/login"))), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/login", nil) + if err != nil { + t.Fatal(err) + } + return req + }, + wantCode: http.StatusOK, + }, + { + // 验证通过 + name: "pass_the_verification", + m: NewManagement[data](defaultOption, + WithIgnorePath[data](StaticIgnorePaths("/login")), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695571500000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ") + return req + }, + wantCode: http.StatusOK, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := gin.Default() + server.Use(tt.m.Middleware()) + tt.m.registerRoutes(server) + + req := tt.reqBuilder(t) + recorder := httptest.NewRecorder() + + server.ServeHTTP(recorder, req) + assert.Equal(t, tt.wantCode, recorder.Code) + }) + } +} + +func TestManagement_Refresh(t *testing.T) { + type testCase[T any] struct { + name string + m *Management[T] + reqBuilder func(t *testing.T) *http.Request + wantCode int + wantAccessToken string + wantRefreshToken string + } + tests := []testCase[data]{ + { + // 更新资源令牌并轮换刷新令牌 + name: "refresh_access_token_and_rotate_refresh_token", + m: NewManagement[data](defaultOption, + WithRefreshJWTOptions[data]( + NewOptions(24*60*time.Minute, + "refresh sign key", + )), + WithRotateRefreshToken[data](true), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695623000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") + return req + }, + wantCode: http.StatusNoContent, + wantAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjIzNjAwLCJpYXQiOjE2OTU2MjMwMDB9.i4kCx4-s5EM0a8w2o0usSfkMTLmzUSuEe-inlzg6ru0", + wantRefreshToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NzA5NDAwLCJpYXQiOjE2OTU2MjMwMDB9.IzPgEwXgoAwaFK-eby4uMl0GYBQwdfZYRi2Bhk3iE_8", + }, + { + // 更新资源令牌但轮换刷新令牌生成失败 + name: "refresh_access_token_but_gen_rotate_refresh_token_failed", + m: NewManagement[data](defaultOption, + WithRefreshJWTOptions[data]( + NewOptions(24*60*time.Minute, + "refresh sign key", + WithMethod(jwt.SigningMethodRS256), + )), + WithRotateRefreshToken[data](true), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695623000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") + return req + }, + wantCode: http.StatusInternalServerError, + }, + { + // 更新资源令牌 + name: "refresh_access_token", + m: NewManagement[data](defaultOption, + WithRefreshJWTOptions[data]( + NewOptions(24*60*time.Minute, + "refresh sign key", + )), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695623000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") + return req + }, + wantCode: http.StatusNoContent, + wantAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjIzNjAwLCJpYXQiOjE2OTU2MjMwMDB9.i4kCx4-s5EM0a8w2o0usSfkMTLmzUSuEe-inlzg6ru0", + }, + { + // 生成资源令牌失败 + name: "gen_access_token_failed", + m: NewManagement[data]( + &Options{ + Expire: 10 * time.Minute, + EncryptionKey: encryptionKey, + DecryptKey: encryptionKey, + Method: jwt.SigningMethodRS256, + genIDFn: func() string { return "" }, + }, + WithRefreshJWTOptions[data]( + NewOptions(24*60*time.Minute, + "refresh sign key", + )), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695623000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") + return req + }, + wantCode: http.StatusInternalServerError, + }, + { + // 刷新令牌认证失败 + name: "refresh_token_verify_failed", + m: NewManagement[data]( + defaultOption, + WithRefreshJWTOptions[data]( + NewOptions(24*60*time.Minute, + "refresh sign key", + )), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695723000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") + return req + }, + wantCode: http.StatusUnauthorized, + }, + { + // 没有设置刷新令牌选项 + name: "not_set_refreshJWTOptions", + m: NewManagement[data]( + defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695723000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") + return req + }, + wantCode: http.StatusInternalServerError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := gin.Default() + tt.m.registerRoutes(server) + + req := tt.reqBuilder(t) + recorder := httptest.NewRecorder() + + server.ServeHTTP(recorder, req) + assert.Equal(t, tt.wantCode, recorder.Code) + if tt.wantCode != http.StatusOK { + return + } + assert.Equal(t, tt.wantAccessToken, + recorder.Header().Get("x-access-token")) + assert.Equal(t, tt.wantRefreshToken, + recorder.Header().Get("x-refresh-token")) + }) + } +} + +func TestManagement_GenerateAccessToken(t *testing.T) { + m := defaultManagement + type testCase[T any] struct { + name string + data T + want string + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + data: data{Foo: "1"}, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := m.GenerateAccessToken(tt.data) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestManagement_VerifyAccessToken(t *testing.T) { + type testCase[T any] struct { + name string + m *Management[T] + token string + want RegisteredClaims[T] + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + m: defaultManagement, + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ", + want: defaultClaims, + }, + { + // token 过期了 + name: "token_expired", + m: NewManagement[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695671200000) + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), + }, + { + // token 签名错误 + name: "bad_sign_key", + m: NewManagement[data]( + defaultOption, + WithNowFunc[data](func() time.Time { + return nowTime + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.pnP991l48s_j4fkiZnmh48gjgDGult9Or_wLChHvYp0", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), + }, + { + // 错误的 token + name: "bad_token", + m: defaultManagement, + token: "bad_token", + wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", + jwt.ErrTokenMalformed), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.m.VerifyAccessToken(tt.token, + jwt.WithTimeFunc(tt.m.nowFunc)) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestManagement_GenerateRefreshToken(t *testing.T) { + m := defaultManagement + type testCase[T any] struct { + name string + refreshJWTOptions *Options + data T + want string + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + refreshJWTOptions: NewOptions(24*60*time.Minute, "refresh sign key"), + data: data{Foo: "1"}, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", + }, + { + name: "mistake", + data: data{Foo: "1"}, + want: "", + wantErr: ErrEmptyRefreshOpts, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m.refreshJWTOptions = tt.refreshJWTOptions + got, err := m.GenerateRefreshToken(tt.data) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestManagement_VerifyRefreshToken(t *testing.T) { + defaultRefOpts := &Options{ + Expire: 24 * 60 * time.Minute, + EncryptionKey: "refresh sign key", + DecryptKey: "refresh sign key", + Method: jwt.SigningMethodHS256, + } + type testCase[T any] struct { + name string + m *Management[T] + token string + want RegisteredClaims[T] + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + m: NewManagement[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695601200000) + }), + WithRefreshJWTOptions[data](defaultRefOpts), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", + want: RegisteredClaims[data]{ + Data: data{Foo: "1"}, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(nowTime.Add(24 * 60 * time.Minute)), + IssuedAt: jwt.NewNumericDate(nowTime), + }, + }, + }, + { + // token 过期了 + name: "token_expired", + m: NewManagement[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695701200000) + }), + WithRefreshJWTOptions[data](defaultRefOpts), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), + }, + { + // token 签名错误 + name: "bad_sign_key", + m: NewManagement[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695601200000) + }), + WithRefreshJWTOptions[data](defaultRefOpts), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.yZ_ZlD1jE-0b3qd0bicTDLSdwGsenv6tRmOEqMCM2uw", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), + }, + { + // 错误的 token + name: "bad_token", + m: NewManagement[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695601200000) + }), + WithRefreshJWTOptions[data](defaultRefOpts), + ), + token: "bad_token", + wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", + jwt.ErrTokenMalformed), + }, + { + name: "no_refresh_options", + m: NewManagement[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695601200000) + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", + wantErr: ErrEmptyRefreshOpts, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.m.VerifyRefreshToken(tt.token, + jwt.WithTimeFunc(tt.m.nowFunc)) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestManagement_SetClaims(t *testing.T) { + m := defaultManagement + type testCase[T any] struct { + name string + claims RegisteredClaims[T] + want RegisteredClaims[T] + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + claims: defaultClaims, + want: defaultClaims, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) + m.SetClaims(ctx, tt.claims) + v, ok := ctx.Get("claims") + if !ok { + t.Errorf("claims not found") + } + clm, ok := v.(RegisteredClaims[data]) + if !ok { + t.Errorf("claims type error") + } + assert.Equal(t, tt.want, clm) + }) + } +} + +func TestManagement_extractTokenString(t *testing.T) { + m := defaultManagement + type header struct { + key string + value string + } + type testCase[T any] struct { + name string + header header + want string + } + tests := []testCase[data]{ + { + name: "normal", + header: header{ + key: "authorization", + value: "Bearer token", + }, + want: "token", + }, + { + name: "mistake_prefix", + header: header{ + key: "authorization", + value: "bearer token", + }, + }, + { + name: "no_allow_token_header", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req, err := http.NewRequest(http.MethodGet, "", nil) + req.Header.Add(tt.header.key, tt.header.value) + if err != nil { + t.Fatal(err) + } + ctx.Request = req + + got := m.extractTokenString(ctx) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestNewManagement(t *testing.T) { + type testCase[T any] struct { + name string + accessJWTOptions *Options + wantPanic bool + } + tests := []testCase[data]{ + { + name: "normal", + accessJWTOptions: defaultOption, + wantPanic: false, + }, + { + name: "accessJWTOptions_are_nil", + accessJWTOptions: nil, + wantPanic: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if err := recover(); err != nil { + if !tt.wantPanic { + t.Errorf("期望出现 painc ,但没有") + } + } + }() + NewManagement[data](tt.accessJWTOptions) + }) + } +} + +func TestWithIgnorePath(t *testing.T) { + type testCase[T any] struct { + name string + fn func() option.Option[Management[T]] + paths []string + want []bool + } + tests := []testCase[data]{ + { + name: "default", + fn: func() option.Option[Management[data]] { + return nil + }, + paths: []string{"profile", "abc"}, + want: []bool{false, false}, + }, + { + name: "all_exists_paths", + fn: func() option.Option[Management[data]] { + return WithIgnorePath[data](defaultIgnorePaths) + }, + paths: []string{"/login", "/signup"}, + want: []bool{true, true}, + }, + { + name: "one_path_does_not_exist", + fn: func() option.Option[Management[data]] { + return WithIgnorePath[data](defaultIgnorePaths) + }, + paths: []string{"/login", "/profile", "/signup"}, + want: []bool{true, false, true}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ignoreFn func(path string) bool + if tt.fn() == nil { + ignoreFn = NewManagement[data]( + defaultOption, + ).ignorePath + } else { + ignoreFn = NewManagement[data]( + defaultOption, + tt.fn(), + ).ignorePath + } + exists := make([]bool, 0, len(tt.paths)) + for _, path := range tt.paths { + exists = append(exists, ignoreFn(path)) + } + assert.Equal(t, tt.want, exists) + }) + } +} + +func TestStaticIgnorePaths(t *testing.T) { + tests := []struct { + name string + paths []string + requestPaths []string + want []bool + }{ + { + name: "normal", + paths: []string{"login", "signup"}, + requestPaths: []string{"profile", "login", "info", "signup"}, + want: []bool{false, true, false, true}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotBool := make([]bool, 0, len(tt.want)) + fn := StaticIgnorePaths(tt.paths...) + for _, path := range tt.requestPaths { + gotBool = append(gotBool, fn(path)) + } + assert.Equal(t, tt.want, gotBool) + }) + } +} + +func TestWithAllowTokenHeader(t *testing.T) { + type testCase[T any] struct { + name string + fn func() option.Option[Management[T]] + want string + } + tests := []testCase[data]{ + { + name: "default", + fn: func() option.Option[Management[data]] { + return nil + }, + want: "authorization", + }, + { + name: "set_another_header", + fn: func() option.Option[Management[data]] { + return WithAllowTokenHeader[data]("jwt") + }, + want: "jwt", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewManagement[data]( + defaultOption, + ).allowTokenHeader + } else { + got = NewManagement[data]( + defaultOption, + tt.fn(), + ).allowTokenHeader + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithExposeAccessHeader(t *testing.T) { + type testCase[T any] struct { + name string + fn func() option.Option[Management[T]] + want string + } + tests := []testCase[data]{ + { + name: "default", + fn: func() option.Option[Management[data]] { + return nil + }, + want: "x-access-token", + }, + { + name: "set_another_header", + fn: func() option.Option[Management[data]] { + return WithExposeAccessHeader[data]("token") + }, + want: "token", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewManagement[data]( + defaultOption, + ).exposeAccessHeader + } else { + got = NewManagement[data]( + defaultOption, + tt.fn(), + ).exposeAccessHeader + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithExposeRefreshHeader(t *testing.T) { + type testCase[T any] struct { + name string + fn func() option.Option[Management[T]] + want string + } + tests := []testCase[data]{ + { + name: "default", + fn: func() option.Option[Management[data]] { + return nil + }, + want: "x-refresh-token", + }, + { + name: "set_another_header", + fn: func() option.Option[Management[data]] { + return WithExposeRefreshHeader[data]("refresh-token") + }, + want: "refresh-token", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewManagement[data]( + defaultOption, + ).exposeRefreshHeader + } else { + got = NewManagement[data]( + defaultOption, + tt.fn(), + ).exposeRefreshHeader + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithRotateRefreshToken(t *testing.T) { + type testCase[T any] struct { + name string + fn func() option.Option[Management[T]] + want bool + } + tests := []testCase[data]{ + { + name: "default", + fn: func() option.Option[Management[data]] { + return nil + }, + want: false, + }, + { + name: "set_another_header", + fn: func() option.Option[Management[data]] { + return WithRotateRefreshToken[data](true) + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got bool + if tt.fn() == nil { + got = NewManagement[data]( + defaultOption, + ).rotateRefreshToken + } else { + got = NewManagement[data]( + defaultOption, + tt.fn(), + ).rotateRefreshToken + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithNowFunc(t *testing.T) { + type testCase[T any] struct { + name string + fn func() option.Option[Management[T]] + want time.Time + } + tests := []testCase[data]{ + { + name: "default", + fn: func() option.Option[Management[data]] { + return nil + }, + want: time.Now(), + }, + { + name: "set_another_now_func", + fn: func() option.Option[Management[data]] { + return WithNowFunc[data](func() time.Time { + return nowTime + }) + }, + want: nowTime, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got time.Time + if tt.fn() == nil { + got = NewManagement[data]( + defaultOption, + ).nowFunc() + } else { + got = NewManagement[data]( + defaultOption, + tt.fn(), + ).nowFunc() + } + assert.Equal(t, tt.want.Unix(), got.Unix()) + }) + } +} + +func TestWithRefreshJWTOptions(t *testing.T) { + var genIDFn func() string + type testCase[T any] struct { + name string + fn func() option.Option[Management[T]] + want *Options + } + tests := []testCase[data]{ + { + name: "default", + fn: func() option.Option[Management[data]] { + return nil + }, + want: nil, + }, + { + name: "set_refresh_jwt_options", + fn: func() option.Option[Management[data]] { + return WithRefreshJWTOptions[data]( + NewOptions( + 24*60*time.Minute, + "refresh sign key", + WithGenIDFunc(genIDFn), + ), + ) + }, + want: &Options{ + Expire: 24 * 60 * time.Minute, + EncryptionKey: "refresh sign key", + DecryptKey: "refresh sign key", + Method: jwt.SigningMethodHS256, + genIDFn: genIDFn, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got *Options + if tt.fn() == nil { + got = NewManagement[data]( + defaultOption, + ).refreshJWTOptions + } else { + got = NewManagement[data]( + defaultOption, + tt.fn(), + ).refreshJWTOptions + } + assert.Equal(t, tt.want, got) + }) + } +} + +func (m *Management[T]) registerRoutes(server *gin.Engine) { + server.GET("/", func(ctx *gin.Context) { + ctx.Status(http.StatusOK) + }) + server.GET("/login", func(ctx *gin.Context) { + ctx.Status(http.StatusOK) + }) + server.GET("refresh", m.Refresh) +} diff --git a/middlewares/jwt/types.go b/middlewares/jwt/types.go new file mode 100644 index 0000000..6662d4f --- /dev/null +++ b/middlewares/jwt/types.go @@ -0,0 +1,33 @@ +package jwt + +import ( + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" +) + +// Manager jwt 管理器. +type Manager[T any] interface { + // Middleware 登录认证的中间件. + Middleware() gin.HandlerFunc + + // Refresh 刷新 token 的 gin.HandlerFunc. + // 需要设置 refreshJWTOptions 否则会出现 500 的 http 状态码. + Refresh(ctx *gin.Context) + + // GenerateAccessToken 生成资源 token. + GenerateAccessToken(data T) (string, error) + + // VerifyAccessToken 校验资源 token. + VerifyAccessToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) + + // GenerateRefreshToken 生成刷新 token. + // 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. + GenerateRefreshToken(data T) (string, error) + + // VerifyRefreshToken 校验刷新 token. + // 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. + VerifyRefreshToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) + + // SetClaims 设置 claims 到 key=`claims` 的 gin.Context 中. + SetClaims(ctx *gin.Context, claims RegisteredClaims[T]) +}