From 52af35d8a980ca363c38ddae196d481fd3f45a76 Mon Sep 17 00:00:00 2001 From: joil Date: Thu, 28 Sep 2023 18:28:58 +0800 Subject: [PATCH 1/4] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0JWT=E4=B8=AD?= =?UTF-8?q?=E9=97=B4=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 2 + go.sum | 11 +- middlewares/auth/auth.go | 133 ++++++++++ middlewares/auth/auth_test.go | 473 ++++++++++++++++++++++++++++++++++ middlewares/auth/jwt.go | 76 ++++++ middlewares/auth/jwt_test.go | 169 ++++++++++++ middlewares/auth/types.go | 13 + middlewares/token/jwt.go | 97 +++++++ middlewares/token/jwt_test.go | 170 ++++++++++++ middlewares/token/types.go | 10 + 10 files changed, 1153 insertions(+), 1 deletion(-) create mode 100644 middlewares/auth/auth.go create mode 100644 middlewares/auth/auth_test.go create mode 100644 middlewares/auth/jwt.go create mode 100644 middlewares/auth/jwt_test.go create mode 100644 middlewares/auth/types.go create mode 100644 middlewares/token/jwt.go create mode 100644 middlewares/token/jwt_test.go create mode 100644 middlewares/token/types.go diff --git a/go.mod b/go.mod index 379257d..1fe5e15 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module github.com/ecodeclub/ginx go 1.21.1 require ( + github.com/ecodeclub/ekit v0.0.8 github.com/gin-gonic/gin v1.9.1 + github.com/golang-jwt/jwt/v5 v5.0.0 github.com/redis/go-redis/v9 v9.1.0 github.com/stretchr/testify v1.8.3 go.uber.org/mock v0.3.0 diff --git a/go.sum b/go.sum index 6376821..1d68aa8 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/ecodeclub/ekit v0.0.8 h1:861Aot0GvD5ueREEYDVYc1oIhDuFyg6MTxIyiOa4Pvw= +github.com/ecodeclub/ekit v0.0.8/go.mod h1:OqTojKeKFTxeeAAUwNIPKu339SRkX6KAuoK/8A5BCEs= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -31,6 +33,8 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -40,6 +44,8 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= @@ -49,6 +55,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -90,8 +98,9 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/middlewares/auth/auth.go b/middlewares/auth/auth.go new file mode 100644 index 0000000..d52cb90 --- /dev/null +++ b/middlewares/auth/auth.go @@ -0,0 +1,133 @@ +package auth + +import ( + "strings" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + + "github.com/ecodeclub/ginx/middlewares/token" +) + +type authHandler[T jwt.Claims] struct { + allowTokenHeader string + bearerPrefix string + claimsCTXKey string + exposeAccessHeader string + exposeRefreshHeader string + token token.Token[T] +} + +func NewAuthHandler[T jwt.Claims](token token.Token[T], + opts ...authHdlOption[T]) Handler[T] { + dOpts := defaultAuthHdlOption[T]() + dOpts.token = token + + for _, opt := range opts { + opt.apply(&dOpts) + } + + return &dOpts +} + +type authHdlOption[T jwt.Claims] interface { + apply(*authHandler[T]) +} + +type funcAuthHdlOption[T jwt.Claims] struct { + f func(handler *authHandler[T]) +} + +func (fdo *funcAuthHdlOption[T]) apply(do *authHandler[T]) { + fdo.f(do) +} + +func newFuncAuthHdlOption[T jwt.Claims]( + f func(handler *authHandler[T])) *funcAuthHdlOption[T] { + return &funcAuthHdlOption[T]{ + f: f, + } +} + +func defaultAuthHdlOption[T jwt.Claims]() authHandler[T] { + return authHandler[T]{ + allowTokenHeader: "authorization", + bearerPrefix: "Bearer", + claimsCTXKey: "claims", + exposeAccessHeader: "x-access-token", + exposeRefreshHeader: "x-refresh-token", + } +} + +func WithAllowTokenHeader[T jwt.Claims](header string) authHdlOption[T] { + return newFuncAuthHdlOption[T](func(h *authHandler[T]) { + h.allowTokenHeader = header + }) +} + +func WithBearerPrefix[T jwt.Claims](prefix string) authHdlOption[T] { + return newFuncAuthHdlOption[T](func(h *authHandler[T]) { + h.bearerPrefix = prefix + }) +} + +func WithClaimsCTXKey[T jwt.Claims](key string) authHdlOption[T] { + return newFuncAuthHdlOption[T](func(h *authHandler[T]) { + h.claimsCTXKey = key + }) +} + +func WithExposeAccessHeader[T jwt.Claims](header string) authHdlOption[T] { + return newFuncAuthHdlOption[T](func(h *authHandler[T]) { + h.exposeAccessHeader = header + }) +} + +func WithExposeRefreshHeader[T jwt.Claims](header string) authHdlOption[T] { + return newFuncAuthHdlOption[T](func(h *authHandler[T]) { + h.exposeRefreshHeader = header + }) +} + +// ExtractTokenString 提取 token +func (a *authHandler[T]) ExtractTokenString(ctx *gin.Context) string { + authCode := ctx.GetHeader(a.allowTokenHeader) + if authCode == "" { + return "" + } + var b strings.Builder + b.WriteString(a.bearerPrefix) + b.WriteString(" ") + prefix := b.String() + if strings.HasPrefix(authCode, prefix) { + return authCode[len(prefix):] + } + return "" +} + +func (a *authHandler[T]) VerifyToken(ctx *gin.Context, token string) error { + claims, err := a.token.Verify(token) + if err != nil { + return err + } + ctx.Set(a.claimsCTXKey, claims) + return nil +} + +func (a *authHandler[T]) SetAccessToken(ctx *gin.Context, claims T) error { + tokenStr, err := a.token.Generate(claims) + if err != nil { + return err + } + ctx.Header(a.exposeAccessHeader, tokenStr) + return nil +} + +func (a *authHandler[T]) SetRefreshToken(ctx *gin.Context, claims T) error { + tokenStr, err := a.token.Generate(claims) + if err != nil { + return err + } + ctx.Header(a.exposeRefreshHeader, tokenStr) + return nil +} diff --git a/middlewares/auth/auth_test.go b/middlewares/auth/auth_test.go new file mode 100644 index 0000000..ac695e6 --- /dev/null +++ b/middlewares/auth/auth_test.go @@ -0,0 +1,473 @@ +package auth + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + + "github.com/ecodeclub/ginx/middlewares/token" +) + +type myClaims struct { + Foo string `json:"foo"` + jwt.RegisteredClaims +} + +var jwtToken = token.NewJWTToken[myClaims]("foo") + +func TestNewAuthHandler(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + token token.Token[T] + opts []authHdlOption[T] + want Handler[T] + } + tests := []testCase[myClaims]{ + { + name: "normal_default_creates", + token: jwtToken, + opts: []authHdlOption[myClaims]{}, + want: &authHandler[myClaims]{ + allowTokenHeader: "authorization", + bearerPrefix: "Bearer", + claimsCTXKey: "claims", + exposeAccessHeader: "x-access-token", + exposeRefreshHeader: "x-refresh-token", + token: jwtToken, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewAuthHandler(tt.token, tt.opts...) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithAllowTokenHeader(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + header string + want Handler[T] + } + tests := []testCase[myClaims]{ + { + name: "normal_set_allow_token_handler", + header: "auth", + want: &authHandler[myClaims]{ + allowTokenHeader: "auth", + bearerPrefix: "Bearer", + claimsCTXKey: "claims", + exposeAccessHeader: "x-access-token", + exposeRefreshHeader: "x-refresh-token", + token: jwtToken, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewAuthHandler[myClaims](jwtToken, + WithAllowTokenHeader[myClaims](tt.header)) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithBearerPrefix(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + prefix string + want Handler[T] + } + tests := []testCase[myClaims]{ + { + name: "normal_set_bearer_prefix", + prefix: "jwt", + want: &authHandler[myClaims]{ + allowTokenHeader: "authorization", + bearerPrefix: "jwt", + claimsCTXKey: "claims", + exposeAccessHeader: "x-access-token", + exposeRefreshHeader: "x-refresh-token", + token: jwtToken, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewAuthHandler[myClaims](jwtToken, + WithBearerPrefix[myClaims](tt.prefix)) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithClaimsCTXKey(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + claimsCTXKey string + want Handler[T] + } + tests := []testCase[myClaims]{ + { + name: "normal_set_claims_ctx_key", + claimsCTXKey: "clm", + want: &authHandler[myClaims]{ + allowTokenHeader: "authorization", + bearerPrefix: "Bearer", + claimsCTXKey: "clm", + exposeAccessHeader: "x-access-token", + exposeRefreshHeader: "x-refresh-token", + token: jwtToken, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewAuthHandler[myClaims](jwtToken, + WithClaimsCTXKey[myClaims](tt.claimsCTXKey)) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithExposeAccessHeader(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + exposeAccessHeader string + want Handler[T] + } + tests := []testCase[myClaims]{ + { + name: "normal_set_expose_access_header", + exposeAccessHeader: "access", + want: &authHandler[myClaims]{ + allowTokenHeader: "authorization", + bearerPrefix: "Bearer", + claimsCTXKey: "claims", + exposeAccessHeader: "access", + exposeRefreshHeader: "x-refresh-token", + token: jwtToken, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewAuthHandler[myClaims](jwtToken, + WithExposeAccessHeader[myClaims](tt.exposeAccessHeader)) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithExposeRefreshHeader(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + exposeRefreshHeader string + want Handler[T] + } + tests := []testCase[myClaims]{ + { + name: "normal_set_expose_refresh_Header", + exposeRefreshHeader: "refresh", + want: &authHandler[myClaims]{ + allowTokenHeader: "authorization", + bearerPrefix: "Bearer", + claimsCTXKey: "claims", + exposeAccessHeader: "x-access-token", + exposeRefreshHeader: "refresh", + token: jwtToken, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewAuthHandler[myClaims](jwtToken, + WithExposeRefreshHeader[myClaims](tt.exposeRefreshHeader)) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_authHandler_ExtractTokenString(t *testing.T) { + a := NewAuthHandler[myClaims](jwtToken) + type testCase[T jwt.Claims] struct { + name string + reqBuilder func(t *testing.T) *http.Request + want string + } + tests := []testCase[myClaims]{ + { + name: "normal_extract_token", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ") + return req + }, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ", + }, + { + name: "bad_extract_token", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer_eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ") + return req + }, + want: "", + }, + { + name: "header_value_not_found", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + return req + }, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ctx.Request = tt.reqBuilder(t) + + got := a.ExtractTokenString(ctx) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_authHandler_SetAccessToken(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + jwtToken token.Token[T] + claims T + want string + wantErr error + } + tests := []testCase[myClaims]{ + { + name: "normal_set_access_token", + jwtToken: token.NewJWTToken[myClaims]("foo", + token.WithNowFunc[myClaims](func() time.Time { + return time.UnixMilli(1695571200000) + })), + claims: myClaims{ + Foo: "bar", + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "bar", + Subject: "1", + IssuedAt: jwt.NewNumericDate( + time.UnixMilli(1695571200000)), + ExpiresAt: jwt.NewNumericDate( + time.UnixMilli(1695571200000). + Add(10 * time.Minute)), + }, + }, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY", + }, + { + name: "bad_claims", + jwtToken: token.NewJWTToken[myClaims]("foo", + token.WithSigningMethod[myClaims](jwt.SigningMethodRS512)), + claims: myClaims{ + Foo: "bar", + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "bar", + Subject: "1", + IssuedAt: jwt.NewNumericDate( + time.UnixMilli(1695571200000)), + ExpiresAt: jwt.NewNumericDate( + time.UnixMilli(1695571000000)), + }, + }, + wantErr: errors.New("key is invalid"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := NewAuthHandler[myClaims](tt.jwtToken) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + ctx.Request = req + + err = a.SetAccessToken(ctx, tt.claims) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, + recorder.Header().Get("x-access-token")) + }) + } +} + +func Test_authHandler_SetRefreshToken(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + jwtToken token.Token[T] + claims T + want string + wantErr error + } + tests := []testCase[myClaims]{ + { + name: "normal_set_refresh_token", + jwtToken: token.NewJWTToken[myClaims]("foo", + token.WithNowFunc[myClaims](func() time.Time { + return time.UnixMilli(1695571200000) + })), + claims: myClaims{ + Foo: "bar", + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "bar", + Subject: "2", + IssuedAt: jwt.NewNumericDate( + time.UnixMilli(1695571200000)), + ExpiresAt: jwt.NewNumericDate( + time.UnixMilli(1695571200000). + Add(10 * time.Minute)), + }, + }, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIyIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.gc4-zm430YUBtQIJi07uAxMiMCG1tclhOODNM20fZlM", + }, + { + name: "bad_claims", + jwtToken: token.NewJWTToken[myClaims]("foo", + token.WithSigningMethod[myClaims](jwt.SigningMethodRS512)), + claims: myClaims{ + Foo: "bar", + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "bar", + Subject: "2", + IssuedAt: jwt.NewNumericDate( + time.UnixMilli(1695571200000)), + ExpiresAt: jwt.NewNumericDate( + time.UnixMilli(1695571000000)), + }, + }, + wantErr: errors.New("key is invalid"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := NewAuthHandler[myClaims](tt.jwtToken) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + ctx.Request = req + + err = a.SetRefreshToken(ctx, tt.claims) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, + recorder.Header().Get("x-refresh-token")) + }) + } +} + +func Test_authHandler_VerifyToken(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + jwtToken token.Token[T] + token string + want T + wantErr error + } + tests := []testCase[myClaims]{ + { + name: "normal_set_claims", + jwtToken: token.NewJWTToken[myClaims]("foo", + token.WithNowFunc[myClaims](func() time.Time { + return time.UnixMilli(1695571500000) + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY", + want: myClaims{ + Foo: "bar", + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "bar", + Subject: "1", + IssuedAt: jwt.NewNumericDate( + time.UnixMilli(1695571200000)), + ExpiresAt: jwt.NewNumericDate( + time.UnixMilli(1695571200000). + Add(10 * time.Minute)), + }, + }, + }, + { + name: "token_expired", + jwtToken: token.NewJWTToken[myClaims]("foo", + token.WithNowFunc[myClaims](func() time.Time { + return time.UnixMilli(1695572500000) + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), + }, + { + name: "wrong_signature", + jwtToken: token.NewJWTToken[myClaims]("foo", + token.WithNowFunc[myClaims](func() time.Time { + return time.UnixMilli(1695571500000) + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.5AgQBNdf08M3vUIi_N2fVQrlNdrIbMvRw-8smkXATWc", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), + }, + { + name: "bad_token", + jwtToken: token.NewJWTToken[myClaims]("foo", + token.WithNowFunc[myClaims](func() time.Time { + return time.UnixMilli(1695571500000) + }), + ), + token: "bad_token", + wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", + jwt.ErrTokenMalformed), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := NewAuthHandler[myClaims](tt.jwtToken) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + + err := a.VerifyToken(ctx, tt.token) + assert.Equal(t, tt.wantErr, err) + if err != nil { + return + } + claims, ok := ctx.Get("claims") + if !ok { + t.Errorf("claims 设置失败") + } + clm, ok := claims.(myClaims) + if !ok { + t.Errorf("claims 类型错误") + } + assert.Equal(t, tt.want, clm) + }) + } +} diff --git a/middlewares/auth/jwt.go b/middlewares/auth/jwt.go new file mode 100644 index 0000000..896bd91 --- /dev/null +++ b/middlewares/auth/jwt.go @@ -0,0 +1,76 @@ +package auth + +import ( + "net/http" + + "github.com/ecodeclub/ekit/set" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" +) + +type JWTBuilder[T jwt.Claims] struct { + publicPaths set.Set[string] + Handler[T] +} + +func NewJWTBuilder[T jwt.Claims](handler Handler[T], opts ...BuilderOption[T]) *JWTBuilder[T] { + dOpts := JWTBuilder[T]{ + publicPaths: set.NewMapSet[string](0), + Handler: handler, + } + + for _, opt := range opts { + opt.apply(&dOpts) + } + + return &dOpts +} + +type BuilderOption[T jwt.Claims] interface { + apply(*JWTBuilder[T]) +} + +type funcBuilderOption[T jwt.Claims] struct { + f func(*JWTBuilder[T]) +} + +func (fdo *funcBuilderOption[T]) apply(do *JWTBuilder[T]) { + fdo.f(do) +} + +func newFuncBuilderOption[T jwt.Claims](f func(*JWTBuilder[T])) *funcBuilderOption[T] { + return &funcBuilderOption[T]{ + f: f, + } +} + +func WithIgnorePaths[T jwt.Claims](paths ...string) BuilderOption[T] { + s := set.NewMapSet[string](len(paths)) + for _, path := range paths { + s.Add(path) + } + return newFuncBuilderOption[T](func(b *JWTBuilder[T]) { + b.publicPaths = s + }) +} + +func (b *JWTBuilder[T]) Build() gin.HandlerFunc { + return func(ctx *gin.Context) { + // 不需要校验 + if b.publicPaths.Exist(ctx.Request.URL.Path) { + return + } + + tokenStr := b.ExtractTokenString(ctx) + if tokenStr == "" { + ctx.AbortWithStatus(http.StatusUnauthorized) + return + } + + err := b.VerifyToken(ctx, tokenStr) + if err != nil { + ctx.AbortWithStatus(http.StatusUnauthorized) + return + } + } +} diff --git a/middlewares/auth/jwt_test.go b/middlewares/auth/jwt_test.go new file mode 100644 index 0000000..44714f0 --- /dev/null +++ b/middlewares/auth/jwt_test.go @@ -0,0 +1,169 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/ecodeclub/ekit/set" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + + "github.com/ecodeclub/ginx/middlewares/token" +) + +var authHdl = NewAuthHandler[myClaims](jwtToken) + +func TestNewJWTBuilder(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + handler Handler[T] + want *JWTBuilder[T] + } + tests := []testCase[myClaims]{ + { + name: "normal", + handler: authHdl, + want: &JWTBuilder[myClaims]{ + Handler: authHdl, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewJWTBuilder[myClaims](tt.handler) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithIgnorePaths(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + paths []string + want func() *JWTBuilder[T] + } + tests := []testCase[myClaims]{ + { + name: "normal", + paths: []string{ + "/login", + "/signup", + }, + want: func() *JWTBuilder[myClaims] { + pathSet := set.NewMapSet[string](2) + pathSet.Add("/login") + pathSet.Add("/signup") + + return &JWTBuilder[myClaims]{ + publicPaths: pathSet, + Handler: authHdl, + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewJWTBuilder(authHdl, + WithIgnorePaths[myClaims](tt.paths...)) + assert.Equal(t, tt.want(), got) + }) + } +} + +func TestJWTBuilder_Build(t *testing.T) { + type testCase[T jwt.Claims] struct { + name string + b *JWTBuilder[T] + reqBuilder func(t *testing.T) *http.Request + wantCode int + } + tests := []testCase[myClaims]{ + { + name: "normal", + b: NewJWTBuilder[myClaims]( + NewAuthHandler[myClaims]( + token.NewJWTToken[myClaims]("foo", + token.WithNowFunc[myClaims](func() time.Time { + return time.UnixMilli(1695571500000) + })))), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY") + return req + }, + wantCode: http.StatusOK, + }, + { + name: "verification_failed", + b: NewJWTBuilder[myClaims](authHdl), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY") + return req + }, + wantCode: http.StatusUnauthorized, + }, + { + name: "extract_token_failed", + b: NewJWTBuilder[myClaims]( + NewAuthHandler[myClaims]( + token.NewJWTToken[myClaims]("foo", + token.WithNowFunc[myClaims](func() time.Time { + return time.UnixMilli(1695571500000) + })))), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer_eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY") + return req + }, + wantCode: http.StatusUnauthorized, + }, + { + name: "verification_failed", + b: NewJWTBuilder[myClaims](authHdl, + WithIgnorePaths[myClaims]("/login")), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/login", nil) + if err != nil { + t.Fatal(err) + } + return req + }, + wantCode: http.StatusOK, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := gin.Default() + server.Use(tt.b.Build()) + tt.b.registerRoutes(server) + + req := tt.reqBuilder(t) + recorder := httptest.NewRecorder() + + server.ServeHTTP(recorder, req) + assert.Equal(t, tt.wantCode, recorder.Code) + }) + } +} + +func (b *JWTBuilder[T]) registerRoutes(server *gin.Engine) { + server.GET("/", func(ctx *gin.Context) { + ctx.Status(http.StatusOK) + }) + server.GET("/login", func(ctx *gin.Context) { + ctx.Status(http.StatusOK) + }) +} diff --git a/middlewares/auth/types.go b/middlewares/auth/types.go new file mode 100644 index 0000000..6efb6ef --- /dev/null +++ b/middlewares/auth/types.go @@ -0,0 +1,13 @@ +package auth + +import ( + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" +) + +type Handler[T jwt.Claims] interface { + ExtractTokenString(ctx *gin.Context) string + VerifyToken(ctx *gin.Context, token string) error + SetAccessToken(ctx *gin.Context, claims T) error + SetRefreshToken(ctx *gin.Context, claims T) error +} diff --git a/middlewares/token/jwt.go b/middlewares/token/jwt.go new file mode 100644 index 0000000..661a481 --- /dev/null +++ b/middlewares/token/jwt.go @@ -0,0 +1,97 @@ +package token + +import ( + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +type JWTToken[T jwt.Claims] struct { + encryptionKey string // 加密密钥 + decryptKey string // 解密密钥 + nowFunc func() time.Time + method jwt.SigningMethod +} + +// NewJWTToken +// method: 默认签名加密方式使用 SH256 +// decryptKey: 因默认使用对称加密所以与 encryptionKey 相同 +func NewJWTToken[T jwt.Claims](encryptionKey string, opts ...Option[T]) *JWTToken[T] { + dOpts := defaultOption[T]() + dOpts.encryptionKey = encryptionKey + dOpts.decryptKey = encryptionKey + + for _, opt := range opts { + opt.apply(&dOpts) + } + + return &dOpts +} + +type Option[T jwt.Claims] interface { + apply(*JWTToken[T]) +} + +type funcOption[T jwt.Claims] struct { + f func(*JWTToken[T]) +} + +func (fdo *funcOption[T]) apply(do *JWTToken[T]) { + fdo.f(do) +} + +func newFuncOption[T jwt.Claims](f func(*JWTToken[T])) *funcOption[T] { + return &funcOption[T]{ + f: f, + } +} + +func defaultOption[T jwt.Claims]() JWTToken[T] { + return JWTToken[T]{ + nowFunc: time.Now, + method: jwt.SigningMethodHS256, + } +} + +func WithDecryptKey[T jwt.Claims](decryptKey string) Option[T] { + return newFuncOption(func(j *JWTToken[T]) { + j.decryptKey = decryptKey + }) +} + +func WithNowFunc[T jwt.Claims](nowFunc func() time.Time) Option[T] { + return newFuncOption(func(j *JWTToken[T]) { + j.nowFunc = nowFunc + }) +} + +func WithSigningMethod[T jwt.Claims](method jwt.SigningMethod) Option[T] { + return newFuncOption(func(j *JWTToken[T]) { + j.method = method + }) +} + +// Generate 生成 jwt token. +func (j *JWTToken[T]) Generate(claims T) (string, error) { + token := jwt.NewWithClaims(j.method, claims) + return token.SignedString([]byte(j.encryptionKey)) +} + +// Verify 验证token.验证不通过则返回 error. +func (j *JWTToken[T]) Verify(token string) (T, error) { + var claimsZero T + claims := claimsZero + var claimsPtr any = &claims + t, err := jwt.ParseWithClaims(token, claimsPtr.(jwt.Claims), + func(*jwt.Token) (interface{}, error) { + return []byte(j.decryptKey), nil + }, + jwt.WithTimeFunc(j.nowFunc), + ) + if err != nil || !t.Valid { + return claimsZero, fmt.Errorf("验证失败: %v", err) + } + + return claims, nil +} diff --git a/middlewares/token/jwt_test.go b/middlewares/token/jwt_test.go new file mode 100644 index 0000000..4eed689 --- /dev/null +++ b/middlewares/token/jwt_test.go @@ -0,0 +1,170 @@ +package token + +import ( + "fmt" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func TestJWTToken_Generate(t *testing.T) { + j := NewJWTToken[jwt.RegisteredClaims]("foo") + nowTime := time.UnixMilli(1695571200000) + type testCase[T jwt.Claims] struct { + name string + claims jwt.RegisteredClaims + want string + wantErr error + } + tests := []testCase[jwt.RegisteredClaims]{ + { + name: "生成token", + claims: jwt.RegisteredClaims{ + Issuer: "bar", + Subject: "1", + IssuedAt: jwt.NewNumericDate(nowTime), + ExpiresAt: jwt.NewNumericDate(nowTime.Add(10 * time.Minute)), + }, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := j.Generate(tt.claims) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestJWTToken_Verify(t *testing.T) { + j := NewJWTToken[jwt.RegisteredClaims]("foo") + type testCase[T jwt.Claims] struct { + name string + nowFunc func() time.Time + token string + want jwt.RegisteredClaims + wantErr error + } + tests := []testCase[jwt.RegisteredClaims]{ + { + name: "验证通过", + nowFunc: func() time.Time { + return time.UnixMilli(1695571500000) + }, + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ", + want: jwt.RegisteredClaims{ + Issuer: "bar", + Subject: "1", + IssuedAt: jwt.NewNumericDate(time.UnixMilli(1695571200000)), + ExpiresAt: jwt.NewNumericDate(time.UnixMilli(1695571200000).Add(10 * time.Minute)), + }, + }, + { + name: "token过期了", + nowFunc: func() time.Time { + return time.UnixMilli(1695572500000) + }, + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), + }, + { + name: "token签名错误", + nowFunc: func() time.Time { + return time.UnixMilli(1695571500000) + }, + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.5OeEzR5tNTGmXwvloac2wYdZvlt8U5UmFdsnpBJ_zb4", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), + }, + { + name: "错误的token", + nowFunc: func() time.Time { + return time.UnixMilli(1695571500000) + }, + token: "bad_token", + wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", + jwt.ErrTokenMalformed), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + j.nowFunc = tt.nowFunc + got, err := j.Verify(tt.token) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithSigningMethod(t *testing.T) { + type jwtT = jwt.RegisteredClaims + type testCase[T jwt.Claims] struct { + name string + method jwt.SigningMethod + want jwt.SigningMethod + } + tests := []testCase[jwtT]{ + { + name: "设置成功", + method: jwt.SigningMethodHS512, + want: jwt.SigningMethodHS512, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewJWTToken[jwtT]("foo", WithSigningMethod[jwtT](tt.method)).method + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithDecryptKey(t *testing.T) { + type jwtT = jwt.RegisteredClaims + type testCase[T jwt.Claims] struct { + name string + decryptKey string + want string + } + tests := []testCase[jwtT]{ + { + name: "设置解密密钥成功", + decryptKey: "decryptKey", + want: "decryptKey", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewJWTToken[jwtT]("foo", WithDecryptKey[jwtT](tt.decryptKey)).decryptKey + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithNowFunc(t *testing.T) { + type jwtT = jwt.RegisteredClaims + type testCase[T jwt.Claims] struct { + name string + nowFunc func() time.Time + want time.Time + } + tests := []testCase[jwtT]{ + { + name: "设置解密密钥成功", + nowFunc: func() time.Time { + return time.UnixMilli(1695571200000) + }, + want: time.UnixMilli(1695571200000), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewJWTToken[jwtT]("foo", + WithNowFunc[jwtT](tt.nowFunc)).nowFunc() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/middlewares/token/types.go b/middlewares/token/types.go new file mode 100644 index 0000000..bd10856 --- /dev/null +++ b/middlewares/token/types.go @@ -0,0 +1,10 @@ +package token + +import ( + "github.com/golang-jwt/jwt/v5" +) + +type Token[T jwt.Claims] interface { + Generate(claims T) (string, error) + Verify(token string) (T, error) +} From 264cdf95d050820c0002a7fe4992f2452eb56ad7 Mon Sep 17 00:00:00 2001 From: joil Date: Fri, 29 Sep 2023 13:14:56 +0800 Subject: [PATCH 2/4] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=20JWT=20?= =?UTF-8?q?=E4=B8=80=E9=94=AE=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- handler/refresh_token.go | 42 ++++++++++ handler/refresh_token_test.go | 151 ++++++++++++++++++++++++++++++++++ middlewares/auth/auth.go | 20 +++-- middlewares/auth/auth_test.go | 4 +- 4 files changed, 208 insertions(+), 9 deletions(-) create mode 100644 handler/refresh_token.go create mode 100644 handler/refresh_token_test.go diff --git a/handler/refresh_token.go b/handler/refresh_token.go new file mode 100644 index 0000000..bcfb3e8 --- /dev/null +++ b/handler/refresh_token.go @@ -0,0 +1,42 @@ +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + + "github.com/ecodeclub/ginx/middlewares/auth" +) + +type TokenHandler[T jwt.Claims] interface { + Refresh(ctx *gin.Context) +} + +type tokenHandler[T jwt.Claims] struct { + accessClaims T + auth.Handler[T] +} + +func NewTokenHandler[T jwt.Claims]( + accessClaims T, handler auth.Handler[T]) TokenHandler[T] { + return &tokenHandler[T]{ + accessClaims: accessClaims, + Handler: handler, + } +} + +func (t *tokenHandler[T]) Refresh(ctx *gin.Context) { + tokenStr := t.ExtractTokenString(ctx) + err := t.VerifyToken(ctx, tokenStr) + if err != nil { + ctx.AbortWithStatus(http.StatusUnauthorized) + return + } + err = t.SetAccessToken(ctx, t.accessClaims) + if err != nil { + ctx.AbortWithStatus(http.StatusInternalServerError) + return + } + ctx.Status(http.StatusOK) +} diff --git a/handler/refresh_token_test.go b/handler/refresh_token_test.go new file mode 100644 index 0000000..184b1cf --- /dev/null +++ b/handler/refresh_token_test.go @@ -0,0 +1,151 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + + "github.com/ecodeclub/ginx/middlewares/auth" + "github.com/ecodeclub/ginx/middlewares/token" +) + +type myClaims struct { + Foo string `json:"foo"` + jwt.RegisteredClaims +} + +func Test_token_Refresh(t *testing.T) { + nowTime := time.UnixMilli(1695571500000) + type testCase[T jwt.Claims] struct { + name string + hdl auth.Handler[T] + reqBuilder func(t *testing.T) *http.Request + accessClaims T + wantCode int + wantToken string + } + tests := []testCase[myClaims]{ + { + name: "normal", + hdl: auth.NewAuthHandler[myClaims]( + token.NewJWTToken[myClaims]("access-token-key", + token.WithNowFunc[myClaims]( + func() time.Time { + return nowTime + }, + ), + token.WithDecryptKey[myClaims]("refresh-token-key"), + ), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJyZWZyZXNoIiwic3ViIjoiMSIsImV4cCI6MTY5NTU3MTgwMCwiaWF0IjoxNjk1NTcxMjAwfQ.8_LyHqansmkqcXJ1INVJDPI2XUAzd12keCrSltqnCJQ") + return req + }, + accessClaims: myClaims{ + Foo: "bar", + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "access", + Subject: "1", + IssuedAt: jwt.NewNumericDate(nowTime), + ExpiresAt: jwt.NewNumericDate( + nowTime.Add(10 * time.Minute)), + }, + }, + wantCode: http.StatusOK, + wantToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJhY2Nlc3MiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcyMTAwLCJpYXQiOjE2OTU1NzE1MDB9.rE74rZg00AtSwvFpVMMYQggfPpgsrK6oiil3PjKKpcA", + }, + { + name: "set_access_token_failed", + hdl: auth.NewAuthHandler[myClaims]( + token.NewJWTToken[myClaims]("access-token-key", + token.WithNowFunc[myClaims]( + func() time.Time { + return nowTime + }, + ), + token.WithSigningMethod[myClaims](jwt.SigningMethodRS256), + token.WithDecryptKey[myClaims]("refresh-token-key"), + ), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJyZWZyZXNoIiwic3ViIjoiMSIsImV4cCI6MTY5NTU3MTgwMCwiaWF0IjoxNjk1NTcxMjAwfQ.8_LyHqansmkqcXJ1INVJDPI2XUAzd12keCrSltqnCJQ") + return req + }, + accessClaims: myClaims{ + Foo: "bar", + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "access", + Subject: "1", + IssuedAt: jwt.NewNumericDate(nowTime), + ExpiresAt: jwt.NewNumericDate( + nowTime.Add(10 * time.Minute)), + }, + }, + wantCode: http.StatusInternalServerError, + }, + { + name: "verify_failed", + hdl: auth.NewAuthHandler[myClaims]( + token.NewJWTToken[myClaims]("access-token-key", + token.WithNowFunc[myClaims]( + func() time.Time { + return nowTime + }, + ), + token.WithDecryptKey[myClaims]("mistake-refresh-token-key"), + ), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJyZWZyZXNoIiwic3ViIjoiMSIsImV4cCI6MTY5NTU3MTgwMCwiaWF0IjoxNjk1NTcxMjAwfQ.8_LyHqansmkqcXJ1INVJDPI2XUAzd12keCrSltqnCJQ") + return req + }, + accessClaims: myClaims{ + Foo: "bar", + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "access", + Subject: "1", + IssuedAt: jwt.NewNumericDate(nowTime), + ExpiresAt: jwt.NewNumericDate( + nowTime.Add(10 * time.Minute)), + }, + }, + wantCode: http.StatusUnauthorized, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := NewTokenHandler[myClaims](tt.accessClaims, tt.hdl) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = tt.reqBuilder(t) + svc.Refresh(ctx) + assert.Equal(t, tt.wantCode, recorder.Code) + if recorder.Code != http.StatusOK { + return + } + assert.Equal(t, tt.wantToken, + recorder.Header().Get("x-access-token")) + }) + } +} + +func (t *tokenHandler[T]) registerRoutes(server *gin.Engine) { + server.GET("/refresh", t.Refresh) +} diff --git a/middlewares/auth/auth.go b/middlewares/auth/auth.go index d52cb90..8fbe1ed 100644 --- a/middlewares/auth/auth.go +++ b/middlewares/auth/auth.go @@ -18,8 +18,14 @@ type authHandler[T jwt.Claims] struct { token token.Token[T] } +// NewAuthHandler +// allowTokenHeader: 默认使用 authorization 为认证请求头. +// bearerPrefix: 默认使用 Bearer 拼接 token. +// claimsCTXKey: 默认使用 claims 为设置到 gin.Context 的key +// exposeAccessHeader: 默认使用 x-access-token 为暴露外部的资源请求头. +// exposeRefreshHeader: 默认使用 x-refresh-token 为暴露外部的刷新请求头. func NewAuthHandler[T jwt.Claims](token token.Token[T], - opts ...authHdlOption[T]) Handler[T] { + opts ...AuthHdlOption[T]) Handler[T] { dOpts := defaultAuthHdlOption[T]() dOpts.token = token @@ -30,7 +36,7 @@ func NewAuthHandler[T jwt.Claims](token token.Token[T], return &dOpts } -type authHdlOption[T jwt.Claims] interface { +type AuthHdlOption[T jwt.Claims] interface { apply(*authHandler[T]) } @@ -59,31 +65,31 @@ func defaultAuthHdlOption[T jwt.Claims]() authHandler[T] { } } -func WithAllowTokenHeader[T jwt.Claims](header string) authHdlOption[T] { +func WithAllowTokenHeader[T jwt.Claims](header string) AuthHdlOption[T] { return newFuncAuthHdlOption[T](func(h *authHandler[T]) { h.allowTokenHeader = header }) } -func WithBearerPrefix[T jwt.Claims](prefix string) authHdlOption[T] { +func WithBearerPrefix[T jwt.Claims](prefix string) AuthHdlOption[T] { return newFuncAuthHdlOption[T](func(h *authHandler[T]) { h.bearerPrefix = prefix }) } -func WithClaimsCTXKey[T jwt.Claims](key string) authHdlOption[T] { +func WithClaimsCTXKey[T jwt.Claims](key string) AuthHdlOption[T] { return newFuncAuthHdlOption[T](func(h *authHandler[T]) { h.claimsCTXKey = key }) } -func WithExposeAccessHeader[T jwt.Claims](header string) authHdlOption[T] { +func WithExposeAccessHeader[T jwt.Claims](header string) AuthHdlOption[T] { return newFuncAuthHdlOption[T](func(h *authHandler[T]) { h.exposeAccessHeader = header }) } -func WithExposeRefreshHeader[T jwt.Claims](header string) authHdlOption[T] { +func WithExposeRefreshHeader[T jwt.Claims](header string) AuthHdlOption[T] { return newFuncAuthHdlOption[T](func(h *authHandler[T]) { h.exposeRefreshHeader = header }) diff --git a/middlewares/auth/auth_test.go b/middlewares/auth/auth_test.go index ac695e6..120d095 100644 --- a/middlewares/auth/auth_test.go +++ b/middlewares/auth/auth_test.go @@ -26,14 +26,14 @@ func TestNewAuthHandler(t *testing.T) { type testCase[T jwt.Claims] struct { name string token token.Token[T] - opts []authHdlOption[T] + opts []AuthHdlOption[T] want Handler[T] } tests := []testCase[myClaims]{ { name: "normal_default_creates", token: jwtToken, - opts: []authHdlOption[myClaims]{}, + opts: []AuthHdlOption[myClaims]{}, want: &authHandler[myClaims]{ allowTokenHeader: "authorization", bearerPrefix: "Bearer", From 77f433a344f05ef68483c3fa2e8793bccdc8ec93 Mon Sep 17 00:00:00 2001 From: joil Date: Mon, 9 Oct 2023 15:36:27 +0800 Subject: [PATCH 3/4] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=20JWT=20?= =?UTF-8?q?=E4=B8=80=E9=94=AE=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- handler/refresh_token.go | 42 -- handler/refresh_token_test.go | 151 ----- middlewares/auth/auth.go | 139 ----- middlewares/auth/auth_test.go | 473 --------------- middlewares/auth/jwt.go | 76 --- middlewares/auth/jwt_test.go | 169 ------ middlewares/auth/types.go | 13 - middlewares/jwt/claims.go | 79 +++ middlewares/jwt/claims_test.go | 115 ++++ middlewares/jwt/jwt.go | 309 ++++++++++ middlewares/jwt/jwt_test.go | 1001 ++++++++++++++++++++++++++++++++ middlewares/jwt/types.go | 19 + middlewares/token/jwt.go | 97 ---- middlewares/token/jwt_test.go | 170 ------ middlewares/token/types.go | 10 - 15 files changed, 1523 insertions(+), 1340 deletions(-) delete mode 100644 handler/refresh_token.go delete mode 100644 handler/refresh_token_test.go delete mode 100644 middlewares/auth/auth.go delete mode 100644 middlewares/auth/auth_test.go delete mode 100644 middlewares/auth/jwt.go delete mode 100644 middlewares/auth/jwt_test.go delete mode 100644 middlewares/auth/types.go create mode 100644 middlewares/jwt/claims.go create mode 100644 middlewares/jwt/claims_test.go create mode 100644 middlewares/jwt/jwt.go create mode 100644 middlewares/jwt/jwt_test.go create mode 100644 middlewares/jwt/types.go delete mode 100644 middlewares/token/jwt.go delete mode 100644 middlewares/token/jwt_test.go delete mode 100644 middlewares/token/types.go diff --git a/handler/refresh_token.go b/handler/refresh_token.go deleted file mode 100644 index bcfb3e8..0000000 --- a/handler/refresh_token.go +++ /dev/null @@ -1,42 +0,0 @@ -package handler - -import ( - "net/http" - - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" - - "github.com/ecodeclub/ginx/middlewares/auth" -) - -type TokenHandler[T jwt.Claims] interface { - Refresh(ctx *gin.Context) -} - -type tokenHandler[T jwt.Claims] struct { - accessClaims T - auth.Handler[T] -} - -func NewTokenHandler[T jwt.Claims]( - accessClaims T, handler auth.Handler[T]) TokenHandler[T] { - return &tokenHandler[T]{ - accessClaims: accessClaims, - Handler: handler, - } -} - -func (t *tokenHandler[T]) Refresh(ctx *gin.Context) { - tokenStr := t.ExtractTokenString(ctx) - err := t.VerifyToken(ctx, tokenStr) - if err != nil { - ctx.AbortWithStatus(http.StatusUnauthorized) - return - } - err = t.SetAccessToken(ctx, t.accessClaims) - if err != nil { - ctx.AbortWithStatus(http.StatusInternalServerError) - return - } - ctx.Status(http.StatusOK) -} diff --git a/handler/refresh_token_test.go b/handler/refresh_token_test.go deleted file mode 100644 index 184b1cf..0000000 --- a/handler/refresh_token_test.go +++ /dev/null @@ -1,151 +0,0 @@ -package handler - -import ( - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/assert" - - "github.com/ecodeclub/ginx/middlewares/auth" - "github.com/ecodeclub/ginx/middlewares/token" -) - -type myClaims struct { - Foo string `json:"foo"` - jwt.RegisteredClaims -} - -func Test_token_Refresh(t *testing.T) { - nowTime := time.UnixMilli(1695571500000) - type testCase[T jwt.Claims] struct { - name string - hdl auth.Handler[T] - reqBuilder func(t *testing.T) *http.Request - accessClaims T - wantCode int - wantToken string - } - tests := []testCase[myClaims]{ - { - name: "normal", - hdl: auth.NewAuthHandler[myClaims]( - token.NewJWTToken[myClaims]("access-token-key", - token.WithNowFunc[myClaims]( - func() time.Time { - return nowTime - }, - ), - token.WithDecryptKey[myClaims]("refresh-token-key"), - ), - ), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/refresh", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJyZWZyZXNoIiwic3ViIjoiMSIsImV4cCI6MTY5NTU3MTgwMCwiaWF0IjoxNjk1NTcxMjAwfQ.8_LyHqansmkqcXJ1INVJDPI2XUAzd12keCrSltqnCJQ") - return req - }, - accessClaims: myClaims{ - Foo: "bar", - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "access", - Subject: "1", - IssuedAt: jwt.NewNumericDate(nowTime), - ExpiresAt: jwt.NewNumericDate( - nowTime.Add(10 * time.Minute)), - }, - }, - wantCode: http.StatusOK, - wantToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJhY2Nlc3MiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcyMTAwLCJpYXQiOjE2OTU1NzE1MDB9.rE74rZg00AtSwvFpVMMYQggfPpgsrK6oiil3PjKKpcA", - }, - { - name: "set_access_token_failed", - hdl: auth.NewAuthHandler[myClaims]( - token.NewJWTToken[myClaims]("access-token-key", - token.WithNowFunc[myClaims]( - func() time.Time { - return nowTime - }, - ), - token.WithSigningMethod[myClaims](jwt.SigningMethodRS256), - token.WithDecryptKey[myClaims]("refresh-token-key"), - ), - ), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/refresh", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJyZWZyZXNoIiwic3ViIjoiMSIsImV4cCI6MTY5NTU3MTgwMCwiaWF0IjoxNjk1NTcxMjAwfQ.8_LyHqansmkqcXJ1INVJDPI2XUAzd12keCrSltqnCJQ") - return req - }, - accessClaims: myClaims{ - Foo: "bar", - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "access", - Subject: "1", - IssuedAt: jwt.NewNumericDate(nowTime), - ExpiresAt: jwt.NewNumericDate( - nowTime.Add(10 * time.Minute)), - }, - }, - wantCode: http.StatusInternalServerError, - }, - { - name: "verify_failed", - hdl: auth.NewAuthHandler[myClaims]( - token.NewJWTToken[myClaims]("access-token-key", - token.WithNowFunc[myClaims]( - func() time.Time { - return nowTime - }, - ), - token.WithDecryptKey[myClaims]("mistake-refresh-token-key"), - ), - ), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/refresh", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJyZWZyZXNoIiwic3ViIjoiMSIsImV4cCI6MTY5NTU3MTgwMCwiaWF0IjoxNjk1NTcxMjAwfQ.8_LyHqansmkqcXJ1INVJDPI2XUAzd12keCrSltqnCJQ") - return req - }, - accessClaims: myClaims{ - Foo: "bar", - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "access", - Subject: "1", - IssuedAt: jwt.NewNumericDate(nowTime), - ExpiresAt: jwt.NewNumericDate( - nowTime.Add(10 * time.Minute)), - }, - }, - wantCode: http.StatusUnauthorized, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - svc := NewTokenHandler[myClaims](tt.accessClaims, tt.hdl) - recorder := httptest.NewRecorder() - ctx, _ := gin.CreateTestContext(recorder) - ctx.Request = tt.reqBuilder(t) - svc.Refresh(ctx) - assert.Equal(t, tt.wantCode, recorder.Code) - if recorder.Code != http.StatusOK { - return - } - assert.Equal(t, tt.wantToken, - recorder.Header().Get("x-access-token")) - }) - } -} - -func (t *tokenHandler[T]) registerRoutes(server *gin.Engine) { - server.GET("/refresh", t.Refresh) -} diff --git a/middlewares/auth/auth.go b/middlewares/auth/auth.go deleted file mode 100644 index 8fbe1ed..0000000 --- a/middlewares/auth/auth.go +++ /dev/null @@ -1,139 +0,0 @@ -package auth - -import ( - "strings" - - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" - - "github.com/ecodeclub/ginx/middlewares/token" -) - -type authHandler[T jwt.Claims] struct { - allowTokenHeader string - bearerPrefix string - claimsCTXKey string - exposeAccessHeader string - exposeRefreshHeader string - token token.Token[T] -} - -// NewAuthHandler -// allowTokenHeader: 默认使用 authorization 为认证请求头. -// bearerPrefix: 默认使用 Bearer 拼接 token. -// claimsCTXKey: 默认使用 claims 为设置到 gin.Context 的key -// exposeAccessHeader: 默认使用 x-access-token 为暴露外部的资源请求头. -// exposeRefreshHeader: 默认使用 x-refresh-token 为暴露外部的刷新请求头. -func NewAuthHandler[T jwt.Claims](token token.Token[T], - opts ...AuthHdlOption[T]) Handler[T] { - dOpts := defaultAuthHdlOption[T]() - dOpts.token = token - - for _, opt := range opts { - opt.apply(&dOpts) - } - - return &dOpts -} - -type AuthHdlOption[T jwt.Claims] interface { - apply(*authHandler[T]) -} - -type funcAuthHdlOption[T jwt.Claims] struct { - f func(handler *authHandler[T]) -} - -func (fdo *funcAuthHdlOption[T]) apply(do *authHandler[T]) { - fdo.f(do) -} - -func newFuncAuthHdlOption[T jwt.Claims]( - f func(handler *authHandler[T])) *funcAuthHdlOption[T] { - return &funcAuthHdlOption[T]{ - f: f, - } -} - -func defaultAuthHdlOption[T jwt.Claims]() authHandler[T] { - return authHandler[T]{ - allowTokenHeader: "authorization", - bearerPrefix: "Bearer", - claimsCTXKey: "claims", - exposeAccessHeader: "x-access-token", - exposeRefreshHeader: "x-refresh-token", - } -} - -func WithAllowTokenHeader[T jwt.Claims](header string) AuthHdlOption[T] { - return newFuncAuthHdlOption[T](func(h *authHandler[T]) { - h.allowTokenHeader = header - }) -} - -func WithBearerPrefix[T jwt.Claims](prefix string) AuthHdlOption[T] { - return newFuncAuthHdlOption[T](func(h *authHandler[T]) { - h.bearerPrefix = prefix - }) -} - -func WithClaimsCTXKey[T jwt.Claims](key string) AuthHdlOption[T] { - return newFuncAuthHdlOption[T](func(h *authHandler[T]) { - h.claimsCTXKey = key - }) -} - -func WithExposeAccessHeader[T jwt.Claims](header string) AuthHdlOption[T] { - return newFuncAuthHdlOption[T](func(h *authHandler[T]) { - h.exposeAccessHeader = header - }) -} - -func WithExposeRefreshHeader[T jwt.Claims](header string) AuthHdlOption[T] { - return newFuncAuthHdlOption[T](func(h *authHandler[T]) { - h.exposeRefreshHeader = header - }) -} - -// ExtractTokenString 提取 token -func (a *authHandler[T]) ExtractTokenString(ctx *gin.Context) string { - authCode := ctx.GetHeader(a.allowTokenHeader) - if authCode == "" { - return "" - } - var b strings.Builder - b.WriteString(a.bearerPrefix) - b.WriteString(" ") - prefix := b.String() - if strings.HasPrefix(authCode, prefix) { - return authCode[len(prefix):] - } - return "" -} - -func (a *authHandler[T]) VerifyToken(ctx *gin.Context, token string) error { - claims, err := a.token.Verify(token) - if err != nil { - return err - } - ctx.Set(a.claimsCTXKey, claims) - return nil -} - -func (a *authHandler[T]) SetAccessToken(ctx *gin.Context, claims T) error { - tokenStr, err := a.token.Generate(claims) - if err != nil { - return err - } - ctx.Header(a.exposeAccessHeader, tokenStr) - return nil -} - -func (a *authHandler[T]) SetRefreshToken(ctx *gin.Context, claims T) error { - tokenStr, err := a.token.Generate(claims) - if err != nil { - return err - } - ctx.Header(a.exposeRefreshHeader, tokenStr) - return nil -} diff --git a/middlewares/auth/auth_test.go b/middlewares/auth/auth_test.go deleted file mode 100644 index 120d095..0000000 --- a/middlewares/auth/auth_test.go +++ /dev/null @@ -1,473 +0,0 @@ -package auth - -import ( - "errors" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/assert" - - "github.com/ecodeclub/ginx/middlewares/token" -) - -type myClaims struct { - Foo string `json:"foo"` - jwt.RegisteredClaims -} - -var jwtToken = token.NewJWTToken[myClaims]("foo") - -func TestNewAuthHandler(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - token token.Token[T] - opts []AuthHdlOption[T] - want Handler[T] - } - tests := []testCase[myClaims]{ - { - name: "normal_default_creates", - token: jwtToken, - opts: []AuthHdlOption[myClaims]{}, - want: &authHandler[myClaims]{ - allowTokenHeader: "authorization", - bearerPrefix: "Bearer", - claimsCTXKey: "claims", - exposeAccessHeader: "x-access-token", - exposeRefreshHeader: "x-refresh-token", - token: jwtToken, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewAuthHandler(tt.token, tt.opts...) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithAllowTokenHeader(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - header string - want Handler[T] - } - tests := []testCase[myClaims]{ - { - name: "normal_set_allow_token_handler", - header: "auth", - want: &authHandler[myClaims]{ - allowTokenHeader: "auth", - bearerPrefix: "Bearer", - claimsCTXKey: "claims", - exposeAccessHeader: "x-access-token", - exposeRefreshHeader: "x-refresh-token", - token: jwtToken, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewAuthHandler[myClaims](jwtToken, - WithAllowTokenHeader[myClaims](tt.header)) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithBearerPrefix(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - prefix string - want Handler[T] - } - tests := []testCase[myClaims]{ - { - name: "normal_set_bearer_prefix", - prefix: "jwt", - want: &authHandler[myClaims]{ - allowTokenHeader: "authorization", - bearerPrefix: "jwt", - claimsCTXKey: "claims", - exposeAccessHeader: "x-access-token", - exposeRefreshHeader: "x-refresh-token", - token: jwtToken, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewAuthHandler[myClaims](jwtToken, - WithBearerPrefix[myClaims](tt.prefix)) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithClaimsCTXKey(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - claimsCTXKey string - want Handler[T] - } - tests := []testCase[myClaims]{ - { - name: "normal_set_claims_ctx_key", - claimsCTXKey: "clm", - want: &authHandler[myClaims]{ - allowTokenHeader: "authorization", - bearerPrefix: "Bearer", - claimsCTXKey: "clm", - exposeAccessHeader: "x-access-token", - exposeRefreshHeader: "x-refresh-token", - token: jwtToken, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewAuthHandler[myClaims](jwtToken, - WithClaimsCTXKey[myClaims](tt.claimsCTXKey)) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithExposeAccessHeader(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - exposeAccessHeader string - want Handler[T] - } - tests := []testCase[myClaims]{ - { - name: "normal_set_expose_access_header", - exposeAccessHeader: "access", - want: &authHandler[myClaims]{ - allowTokenHeader: "authorization", - bearerPrefix: "Bearer", - claimsCTXKey: "claims", - exposeAccessHeader: "access", - exposeRefreshHeader: "x-refresh-token", - token: jwtToken, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewAuthHandler[myClaims](jwtToken, - WithExposeAccessHeader[myClaims](tt.exposeAccessHeader)) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithExposeRefreshHeader(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - exposeRefreshHeader string - want Handler[T] - } - tests := []testCase[myClaims]{ - { - name: "normal_set_expose_refresh_Header", - exposeRefreshHeader: "refresh", - want: &authHandler[myClaims]{ - allowTokenHeader: "authorization", - bearerPrefix: "Bearer", - claimsCTXKey: "claims", - exposeAccessHeader: "x-access-token", - exposeRefreshHeader: "refresh", - token: jwtToken, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewAuthHandler[myClaims](jwtToken, - WithExposeRefreshHeader[myClaims](tt.exposeRefreshHeader)) - assert.Equal(t, tt.want, got) - }) - } -} - -func Test_authHandler_ExtractTokenString(t *testing.T) { - a := NewAuthHandler[myClaims](jwtToken) - type testCase[T jwt.Claims] struct { - name string - reqBuilder func(t *testing.T) *http.Request - want string - } - tests := []testCase[myClaims]{ - { - name: "normal_extract_token", - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ") - return req - }, - want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ", - }, - { - name: "bad_extract_token", - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer_eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ") - return req - }, - want: "", - }, - { - name: "header_value_not_found", - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "", nil) - if err != nil { - t.Fatal(err) - } - return req - }, - want: "", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) - ctx.Request = tt.reqBuilder(t) - - got := a.ExtractTokenString(ctx) - assert.Equal(t, tt.want, got) - }) - } -} - -func Test_authHandler_SetAccessToken(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - jwtToken token.Token[T] - claims T - want string - wantErr error - } - tests := []testCase[myClaims]{ - { - name: "normal_set_access_token", - jwtToken: token.NewJWTToken[myClaims]("foo", - token.WithNowFunc[myClaims](func() time.Time { - return time.UnixMilli(1695571200000) - })), - claims: myClaims{ - Foo: "bar", - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "bar", - Subject: "1", - IssuedAt: jwt.NewNumericDate( - time.UnixMilli(1695571200000)), - ExpiresAt: jwt.NewNumericDate( - time.UnixMilli(1695571200000). - Add(10 * time.Minute)), - }, - }, - want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY", - }, - { - name: "bad_claims", - jwtToken: token.NewJWTToken[myClaims]("foo", - token.WithSigningMethod[myClaims](jwt.SigningMethodRS512)), - claims: myClaims{ - Foo: "bar", - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "bar", - Subject: "1", - IssuedAt: jwt.NewNumericDate( - time.UnixMilli(1695571200000)), - ExpiresAt: jwt.NewNumericDate( - time.UnixMilli(1695571000000)), - }, - }, - wantErr: errors.New("key is invalid"), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := NewAuthHandler[myClaims](tt.jwtToken) - recorder := httptest.NewRecorder() - ctx, _ := gin.CreateTestContext(recorder) - req, err := http.NewRequest(http.MethodGet, "", nil) - if err != nil { - t.Fatal(err) - } - ctx.Request = req - - err = a.SetAccessToken(ctx, tt.claims) - assert.Equal(t, tt.wantErr, err) - assert.Equal(t, tt.want, - recorder.Header().Get("x-access-token")) - }) - } -} - -func Test_authHandler_SetRefreshToken(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - jwtToken token.Token[T] - claims T - want string - wantErr error - } - tests := []testCase[myClaims]{ - { - name: "normal_set_refresh_token", - jwtToken: token.NewJWTToken[myClaims]("foo", - token.WithNowFunc[myClaims](func() time.Time { - return time.UnixMilli(1695571200000) - })), - claims: myClaims{ - Foo: "bar", - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "bar", - Subject: "2", - IssuedAt: jwt.NewNumericDate( - time.UnixMilli(1695571200000)), - ExpiresAt: jwt.NewNumericDate( - time.UnixMilli(1695571200000). - Add(10 * time.Minute)), - }, - }, - want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIyIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.gc4-zm430YUBtQIJi07uAxMiMCG1tclhOODNM20fZlM", - }, - { - name: "bad_claims", - jwtToken: token.NewJWTToken[myClaims]("foo", - token.WithSigningMethod[myClaims](jwt.SigningMethodRS512)), - claims: myClaims{ - Foo: "bar", - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "bar", - Subject: "2", - IssuedAt: jwt.NewNumericDate( - time.UnixMilli(1695571200000)), - ExpiresAt: jwt.NewNumericDate( - time.UnixMilli(1695571000000)), - }, - }, - wantErr: errors.New("key is invalid"), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := NewAuthHandler[myClaims](tt.jwtToken) - recorder := httptest.NewRecorder() - ctx, _ := gin.CreateTestContext(recorder) - req, err := http.NewRequest(http.MethodGet, "", nil) - if err != nil { - t.Fatal(err) - } - ctx.Request = req - - err = a.SetRefreshToken(ctx, tt.claims) - assert.Equal(t, tt.wantErr, err) - assert.Equal(t, tt.want, - recorder.Header().Get("x-refresh-token")) - }) - } -} - -func Test_authHandler_VerifyToken(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - jwtToken token.Token[T] - token string - want T - wantErr error - } - tests := []testCase[myClaims]{ - { - name: "normal_set_claims", - jwtToken: token.NewJWTToken[myClaims]("foo", - token.WithNowFunc[myClaims](func() time.Time { - return time.UnixMilli(1695571500000) - }), - ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY", - want: myClaims{ - Foo: "bar", - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: "bar", - Subject: "1", - IssuedAt: jwt.NewNumericDate( - time.UnixMilli(1695571200000)), - ExpiresAt: jwt.NewNumericDate( - time.UnixMilli(1695571200000). - Add(10 * time.Minute)), - }, - }, - }, - { - name: "token_expired", - jwtToken: token.NewJWTToken[myClaims]("foo", - token.WithNowFunc[myClaims](func() time.Time { - return time.UnixMilli(1695572500000) - }), - ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY", - wantErr: fmt.Errorf("验证失败: %v", - fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), - }, - { - name: "wrong_signature", - jwtToken: token.NewJWTToken[myClaims]("foo", - token.WithNowFunc[myClaims](func() time.Time { - return time.UnixMilli(1695571500000) - }), - ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.5AgQBNdf08M3vUIi_N2fVQrlNdrIbMvRw-8smkXATWc", - wantErr: fmt.Errorf("验证失败: %v", - fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), - }, - { - name: "bad_token", - jwtToken: token.NewJWTToken[myClaims]("foo", - token.WithNowFunc[myClaims](func() time.Time { - return time.UnixMilli(1695571500000) - }), - ), - token: "bad_token", - wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", - jwt.ErrTokenMalformed), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := NewAuthHandler[myClaims](tt.jwtToken) - recorder := httptest.NewRecorder() - ctx, _ := gin.CreateTestContext(recorder) - - err := a.VerifyToken(ctx, tt.token) - assert.Equal(t, tt.wantErr, err) - if err != nil { - return - } - claims, ok := ctx.Get("claims") - if !ok { - t.Errorf("claims 设置失败") - } - clm, ok := claims.(myClaims) - if !ok { - t.Errorf("claims 类型错误") - } - assert.Equal(t, tt.want, clm) - }) - } -} diff --git a/middlewares/auth/jwt.go b/middlewares/auth/jwt.go deleted file mode 100644 index 896bd91..0000000 --- a/middlewares/auth/jwt.go +++ /dev/null @@ -1,76 +0,0 @@ -package auth - -import ( - "net/http" - - "github.com/ecodeclub/ekit/set" - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" -) - -type JWTBuilder[T jwt.Claims] struct { - publicPaths set.Set[string] - Handler[T] -} - -func NewJWTBuilder[T jwt.Claims](handler Handler[T], opts ...BuilderOption[T]) *JWTBuilder[T] { - dOpts := JWTBuilder[T]{ - publicPaths: set.NewMapSet[string](0), - Handler: handler, - } - - for _, opt := range opts { - opt.apply(&dOpts) - } - - return &dOpts -} - -type BuilderOption[T jwt.Claims] interface { - apply(*JWTBuilder[T]) -} - -type funcBuilderOption[T jwt.Claims] struct { - f func(*JWTBuilder[T]) -} - -func (fdo *funcBuilderOption[T]) apply(do *JWTBuilder[T]) { - fdo.f(do) -} - -func newFuncBuilderOption[T jwt.Claims](f func(*JWTBuilder[T])) *funcBuilderOption[T] { - return &funcBuilderOption[T]{ - f: f, - } -} - -func WithIgnorePaths[T jwt.Claims](paths ...string) BuilderOption[T] { - s := set.NewMapSet[string](len(paths)) - for _, path := range paths { - s.Add(path) - } - return newFuncBuilderOption[T](func(b *JWTBuilder[T]) { - b.publicPaths = s - }) -} - -func (b *JWTBuilder[T]) Build() gin.HandlerFunc { - return func(ctx *gin.Context) { - // 不需要校验 - if b.publicPaths.Exist(ctx.Request.URL.Path) { - return - } - - tokenStr := b.ExtractTokenString(ctx) - if tokenStr == "" { - ctx.AbortWithStatus(http.StatusUnauthorized) - return - } - - err := b.VerifyToken(ctx, tokenStr) - if err != nil { - ctx.AbortWithStatus(http.StatusUnauthorized) - return - } - } -} diff --git a/middlewares/auth/jwt_test.go b/middlewares/auth/jwt_test.go deleted file mode 100644 index 44714f0..0000000 --- a/middlewares/auth/jwt_test.go +++ /dev/null @@ -1,169 +0,0 @@ -package auth - -import ( - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/ecodeclub/ekit/set" - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/assert" - - "github.com/ecodeclub/ginx/middlewares/token" -) - -var authHdl = NewAuthHandler[myClaims](jwtToken) - -func TestNewJWTBuilder(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - handler Handler[T] - want *JWTBuilder[T] - } - tests := []testCase[myClaims]{ - { - name: "normal", - handler: authHdl, - want: &JWTBuilder[myClaims]{ - Handler: authHdl, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewJWTBuilder[myClaims](tt.handler) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithIgnorePaths(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - paths []string - want func() *JWTBuilder[T] - } - tests := []testCase[myClaims]{ - { - name: "normal", - paths: []string{ - "/login", - "/signup", - }, - want: func() *JWTBuilder[myClaims] { - pathSet := set.NewMapSet[string](2) - pathSet.Add("/login") - pathSet.Add("/signup") - - return &JWTBuilder[myClaims]{ - publicPaths: pathSet, - Handler: authHdl, - } - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewJWTBuilder(authHdl, - WithIgnorePaths[myClaims](tt.paths...)) - assert.Equal(t, tt.want(), got) - }) - } -} - -func TestJWTBuilder_Build(t *testing.T) { - type testCase[T jwt.Claims] struct { - name string - b *JWTBuilder[T] - reqBuilder func(t *testing.T) *http.Request - wantCode int - } - tests := []testCase[myClaims]{ - { - name: "normal", - b: NewJWTBuilder[myClaims]( - NewAuthHandler[myClaims]( - token.NewJWTToken[myClaims]("foo", - token.WithNowFunc[myClaims](func() time.Time { - return time.UnixMilli(1695571500000) - })))), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY") - return req - }, - wantCode: http.StatusOK, - }, - { - name: "verification_failed", - b: NewJWTBuilder[myClaims](authHdl), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY") - return req - }, - wantCode: http.StatusUnauthorized, - }, - { - name: "extract_token_failed", - b: NewJWTBuilder[myClaims]( - NewAuthHandler[myClaims]( - token.NewJWTToken[myClaims]("foo", - token.WithNowFunc[myClaims](func() time.Time { - return time.UnixMilli(1695571500000) - })))), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("authorization", "Bearer_eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.DkipgOka6QyyyvhW3IKLTnnWDQVuTBeGO5vb3Poj7ZY") - return req - }, - wantCode: http.StatusUnauthorized, - }, - { - name: "verification_failed", - b: NewJWTBuilder[myClaims](authHdl, - WithIgnorePaths[myClaims]("/login")), - reqBuilder: func(t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, "/login", nil) - if err != nil { - t.Fatal(err) - } - return req - }, - wantCode: http.StatusOK, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - server := gin.Default() - server.Use(tt.b.Build()) - tt.b.registerRoutes(server) - - req := tt.reqBuilder(t) - recorder := httptest.NewRecorder() - - server.ServeHTTP(recorder, req) - assert.Equal(t, tt.wantCode, recorder.Code) - }) - } -} - -func (b *JWTBuilder[T]) registerRoutes(server *gin.Engine) { - server.GET("/", func(ctx *gin.Context) { - ctx.Status(http.StatusOK) - }) - server.GET("/login", func(ctx *gin.Context) { - ctx.Status(http.StatusOK) - }) -} diff --git a/middlewares/auth/types.go b/middlewares/auth/types.go deleted file mode 100644 index 6efb6ef..0000000 --- a/middlewares/auth/types.go +++ /dev/null @@ -1,13 +0,0 @@ -package auth - -import ( - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" -) - -type Handler[T jwt.Claims] interface { - ExtractTokenString(ctx *gin.Context) string - VerifyToken(ctx *gin.Context, token string) error - SetAccessToken(ctx *gin.Context, claims T) error - SetRefreshToken(ctx *gin.Context, claims T) error -} diff --git a/middlewares/jwt/claims.go b/middlewares/jwt/claims.go new file mode 100644 index 0000000..d6c5233 --- /dev/null +++ b/middlewares/jwt/claims.go @@ -0,0 +1,79 @@ +package jwt + +import ( + "time" + + "github.com/golang-jwt/jwt/v5" +) + +type RegisteredClaims[T any] struct { + Data T `json:"data"` + jwt.RegisteredClaims +} + +type Options struct { + Issuer string // 签发人 + Expire time.Duration // 有效期 + EncryptionKey string // 加密密钥 + DecryptKey string // 解密密钥 + Method jwt.SigningMethod // 签名方式 +} + +// NewOptions 定义一个JWT Claims配置 +// Issuer: 默认使用 "". +// DecryptKey: 默认与 EncryptionKey 相同. +// Method: 默认使用 jwt.SigningMethodHS256 签名方式. +func NewOptions(expire time.Duration, encryptionKey string, opts ...Option) *Options { + dOpts := Options{ + Expire: expire, + EncryptionKey: encryptionKey, + DecryptKey: encryptionKey, + Method: jwt.SigningMethodHS256, + } + + for _, opt := range opts { + opt.apply(&dOpts) + } + + return &dOpts +} + +type Option interface { + apply(*Options) +} + +type funcOption struct { + f func(handler *Options) +} + +func (fdo *funcOption) apply(do *Options) { + fdo.f(do) +} + +func newFuncOption( + f func(*Options)) *funcOption { + return &funcOption{ + f: f, + } +} + +// WithIssuer 设置签发人. +func WithIssuer(issuer string) Option { + return newFuncOption(func(o *Options) { + o.Issuer = issuer + }) +} + +// WithDecryptKey 设置解密密钥. +func WithDecryptKey(decryptKey string) Option { + return newFuncOption(func(o *Options) { + o.DecryptKey = decryptKey + }) +} + +// WithMethod 设置 JWT 的签名方法. +func WithMethod(method jwt.SigningMethod) Option { + return newFuncOption(func(o *Options) { + o.Method = method + }) +} diff --git a/middlewares/jwt/claims_test.go b/middlewares/jwt/claims_test.go new file mode 100644 index 0000000..5b46aba --- /dev/null +++ b/middlewares/jwt/claims_test.go @@ -0,0 +1,115 @@ +package jwt + +import ( + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func TestNewOptions(t *testing.T) { + tests := []struct { + name string + expire time.Duration + encryptionKey string + want *Options + }{ + { + name: "normal", + expire: 10 * time.Minute, + encryptionKey: "sign key", + want: &Options{ + Expire: 10 * time.Minute, + EncryptionKey: "sign key", + DecryptKey: "sign key", + Method: jwt.SigningMethodHS256, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewOptions(tt.expire, tt.encryptionKey) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithDecryptKey(t *testing.T) { + tests := []struct { + name string + decryptKey string + want *Options + }{ + { + name: "set_another_key", + decryptKey: "other sign key", + want: &Options{ + Expire: defaultExpire, + EncryptionKey: encryptionKey, + DecryptKey: "other sign key", + Method: jwt.SigningMethodHS256, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewOptions(defaultExpire, encryptionKey, + WithDecryptKey(tt.decryptKey)) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithIssuer(t *testing.T) { + tests := []struct { + name string + issuer string + want *Options + }{ + { + name: "set_issuer", + issuer: "foo", + want: &Options{ + Issuer: "foo", + Expire: defaultExpire, + EncryptionKey: encryptionKey, + DecryptKey: encryptionKey, + Method: jwt.SigningMethodHS256, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewOptions(defaultExpire, encryptionKey, + WithIssuer(tt.issuer)) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithMethod(t *testing.T) { + tests := []struct { + name string + method jwt.SigningMethod + want *Options + }{ + { + name: "set_another_jwt_signing_method", + method: jwt.SigningMethodHS384, + want: &Options{ + Expire: defaultExpire, + EncryptionKey: encryptionKey, + DecryptKey: encryptionKey, + Method: jwt.SigningMethodHS384, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewOptions(defaultExpire, encryptionKey, + WithMethod(tt.method)) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/middlewares/jwt/jwt.go b/middlewares/jwt/jwt.go new file mode 100644 index 0000000..5559b30 --- /dev/null +++ b/middlewares/jwt/jwt.go @@ -0,0 +1,309 @@ +package jwt + +import ( + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/ecodeclub/ekit/set" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" +) + +var ( + ErrEmptyRefreshOpts = errors.New("refreshJWTOptions are nil") +) + +type Manager[T any] struct { + publicPaths set.Set[string] // 存放不需要认证的 path + allowTokenHeader string // 认证的请求头(存放 token 的请求头 key) + bearerPrefix string // 拼接 token 的前缀 + claimsCTXKey string // 存放到 gin.Context 的 key + exposeAccessHeader string // 暴露到外部的资源请求头 + exposeRefreshHeader string // 暴露到外部的刷新请求头 + + accessJWTOptions *Options // 资源 token 选项 + refreshJWTOptions *Options // 刷新 token 选项 + rotateRefreshToken bool // 轮换刷新令牌 + + nowFunc func() time.Time // 控制 jwt 的时间 +} + +// NewManager 定义一个 JWTLoginManger +// allowTokenHeader: 默认使用 authorization 为认证请求头. +// bearerPrefix: 默认使用 Bearer 拼接 token. +// claimsCTXKey: 默认使用 claims 为设置到 gin.Context 的key +// exposeAccessHeader: 默认使用 x-access-token 为暴露外部的资源请求头. +// exposeRefreshHeader: 默认使用 x-refresh-token 为暴露外部的刷新请求头. +// refreshJWTOptions: 默认使用 nil 为刷新 token 的配置, +// 如要使用 refresh 功能则需要使用 WithRefreshJWTOptions 添加相关配置. +// rotateRefreshToken: 默认不轮换刷新令牌. +// 该配置需要设置 refreshJWTOptions 才有效. +func NewManager[T any](accessJWTOptions *Options, + opts ...ManagerOption[T]) *Manager[T] { + dOpts := defaultManagerOption[T]() + dOpts.accessJWTOptions = accessJWTOptions + + for _, opt := range opts { + opt.apply(&dOpts) + } + + return &dOpts +} + +type ManagerOption[T any] interface { + apply(*Manager[T]) +} + +type funcManagerOption[T any] struct { + f func(handler *Manager[T]) +} + +func (fdo *funcManagerOption[T]) apply(do *Manager[T]) { + fdo.f(do) +} + +func newFuncManagerOption[T any]( + f func(handler *Manager[T])) *funcManagerOption[T] { + return &funcManagerOption[T]{ + f: f, + } +} + +func defaultManagerOption[T any]() Manager[T] { + return Manager[T]{ + publicPaths: set.NewMapSet[string](0), + allowTokenHeader: "authorization", + bearerPrefix: "Bearer", + claimsCTXKey: "claims", + exposeAccessHeader: "x-access-token", + exposeRefreshHeader: "x-refresh-token", + rotateRefreshToken: false, + nowFunc: time.Now, + } +} + +// WithIgnorePaths 设置忽略资源令牌认证的路径. +// 例如: '/login', '/api/v1/signup'. +func WithIgnorePaths[T any](paths ...string) ManagerOption[T] { + s := set.NewMapSet[string](len(paths)) + for _, path := range paths { + s.Add(path) + } + return newFuncManagerOption[T](func(l *Manager[T]) { + l.publicPaths = s + }) +} + +// WithAllowTokenHeader 设置允许 token 的请求头. +func WithAllowTokenHeader[T any](header string) ManagerOption[T] { + return newFuncManagerOption[T](func(m *Manager[T]) { + m.allowTokenHeader = header + }) +} + +// WithBearerPrefix 设置与 token 拼接的前缀. +// 例如: 'Bearer eyx.eyx.x'中的 'Bearer'. +func WithBearerPrefix[T any](prefix string) ManagerOption[T] { + return newFuncManagerOption[T](func(m *Manager[T]) { + m.bearerPrefix = prefix + }) +} + +// WithClaimsCTXKey 设置放到 gin.Context 中的 key. +func WithClaimsCTXKey[T any](key string) ManagerOption[T] { + return newFuncManagerOption[T](func(m *Manager[T]) { + m.claimsCTXKey = key + }) +} + +// WithExposeAccessHeader 设置公开资源令牌的请求头. +func WithExposeAccessHeader[T any](header string) ManagerOption[T] { + return newFuncManagerOption[T](func(m *Manager[T]) { + m.exposeAccessHeader = header + }) +} + +// WithExposeRefreshHeader 设置公开刷新令牌的请求头. +func WithExposeRefreshHeader[T any](header string) ManagerOption[T] { + return newFuncManagerOption[T](func(m *Manager[T]) { + m.exposeRefreshHeader = header + }) +} + +// WithRefreshJWTOptions 设置刷新令牌相关的配置. +func WithRefreshJWTOptions[T any](refreshOpts *Options) ManagerOption[T] { + return newFuncManagerOption(func(m *Manager[T]) { + m.refreshJWTOptions = refreshOpts + }) +} + +// WithRotateRefreshToken 设置轮换刷新令牌. +func WithRotateRefreshToken[T any](isRotate bool) ManagerOption[T] { + return newFuncManagerOption(func(m *Manager[T]) { + m.rotateRefreshToken = isRotate + }) +} + +// WithNowFunc 设置当前时间. +// 一般用于测试. +func WithNowFunc[T any](nowFunc func() time.Time) ManagerOption[T] { + return newFuncManagerOption(func(m *Manager[T]) { + m.nowFunc = nowFunc + }) +} + +// Refresh 刷新 token 的 gin.HandlerFunc +func (m *Manager[T]) Refresh(ctx *gin.Context) { + if m.refreshJWTOptions == nil { + ctx.Status(http.StatusInternalServerError) + return + } + + tokenStr := m.extractTokenString(ctx) + clm, err := m.VerifyRefreshToken(tokenStr) + if err != nil { + ctx.Status(http.StatusUnauthorized) + return + } + accessToken, err := m.GenerateAccessToken(clm.Data) + if err != nil { + ctx.Status(http.StatusInternalServerError) + return + } + ctx.Header(m.exposeAccessHeader, accessToken) + + // 轮换刷新令牌 + if m.rotateRefreshToken { + refreshToken, err := m.GenerateRefreshToken(clm.Data) + if err != nil { + ctx.Status(http.StatusInternalServerError) + return + } + ctx.Header(m.exposeRefreshHeader, refreshToken) + } + ctx.Status(http.StatusOK) +} + +// MiddlewareBuilder 登录认证的中间件 +func (m *Manager[T]) MiddlewareBuilder() gin.HandlerFunc { + return func(ctx *gin.Context) { + // 不需要校验 + if m.publicPaths.Exist(ctx.Request.URL.Path) { + return + } + + tokenStr := m.extractTokenString(ctx) + if tokenStr == "" { + ctx.AbortWithStatus(http.StatusUnauthorized) + return + } + + err := m.verifyTokenAndSetClm(ctx, tokenStr) + if err != nil { + ctx.AbortWithStatus(http.StatusUnauthorized) + return + } + } +} + +// extractTokenString 提取 token 字符串. +func (m *Manager[T]) extractTokenString(ctx *gin.Context) string { + authCode := ctx.GetHeader(m.allowTokenHeader) + if authCode == "" { + return "" + } + var b strings.Builder + b.WriteString(m.bearerPrefix) + b.WriteString(" ") + prefix := b.String() + if strings.HasPrefix(authCode, prefix) { + return authCode[len(prefix):] + } + return "" +} + +// verifyTokenAndSetClm 校验 access token 并把 claims 设置到 gin.Context 中. +func (m *Manager[T]) verifyTokenAndSetClm(ctx *gin.Context, token string) error { + claims, err := m.VerifyAccessToken(token) + if err != nil { + return err + } + ctx.Set(m.claimsCTXKey, claims) + return nil +} + +// GenerateAccessToken 生成资源 token. +func (m *Manager[T]) GenerateAccessToken(data T) (string, error) { + nowTime := m.nowFunc() + claims := RegisteredClaims[T]{ + Data: data, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: m.accessJWTOptions.Issuer, + ExpiresAt: jwt.NewNumericDate(nowTime.Add(m.accessJWTOptions.Expire)), + NotBefore: jwt.NewNumericDate(nowTime), + IssuedAt: jwt.NewNumericDate(nowTime), + }, + } + + token := jwt.NewWithClaims(m.accessJWTOptions.Method, claims) + return token.SignedString([]byte(m.accessJWTOptions.EncryptionKey)) +} + +// VerifyAccessToken 校验资源 token. +func (m *Manager[T]) VerifyAccessToken(token string) (RegisteredClaims[T], error) { + t, err := jwt.ParseWithClaims(token, &RegisteredClaims[T]{}, + func(*jwt.Token) (interface{}, error) { + return []byte(m.accessJWTOptions.DecryptKey), nil + }, + jwt.WithTimeFunc(m.nowFunc), + ) + if err != nil || !t.Valid { + return RegisteredClaims[T]{}, fmt.Errorf("验证失败: %v", err) + } + clm, _ := t.Claims.(*RegisteredClaims[T]) + return *clm, nil +} + +// GenerateRefreshToken 生成刷新 token. +// 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. +func (m *Manager[T]) GenerateRefreshToken(data T) (string, error) { + if m.refreshJWTOptions == nil { + return "", ErrEmptyRefreshOpts + } + + nowTime := m.nowFunc() + claims := RegisteredClaims[T]{ + Data: data, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: m.refreshJWTOptions.Issuer, + ExpiresAt: jwt.NewNumericDate(nowTime.Add(m.refreshJWTOptions.Expire)), + NotBefore: jwt.NewNumericDate(nowTime), + IssuedAt: jwt.NewNumericDate(nowTime), + }, + } + + token := jwt.NewWithClaims(m.refreshJWTOptions.Method, claims) + return token.SignedString([]byte(m.refreshJWTOptions.EncryptionKey)) +} + +// VerifyRefreshToken 校验刷新 token. +// 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. +func (m *Manager[T]) VerifyRefreshToken(token string) (RegisteredClaims[T], error) { + if m.refreshJWTOptions == nil { + return RegisteredClaims[T]{}, ErrEmptyRefreshOpts + } + t, err := jwt.ParseWithClaims(token, &RegisteredClaims[T]{}, + func(*jwt.Token) (interface{}, error) { + return []byte(m.refreshJWTOptions.DecryptKey), nil + }, + jwt.WithTimeFunc(m.nowFunc), + ) + if err != nil || !t.Valid { + return RegisteredClaims[T]{}, fmt.Errorf("验证失败: %v", err) + } + clm, _ := t.Claims.(*RegisteredClaims[T]) + return *clm, nil +} diff --git a/middlewares/jwt/jwt_test.go b/middlewares/jwt/jwt_test.go new file mode 100644 index 0000000..92c92e4 --- /dev/null +++ b/middlewares/jwt/jwt_test.go @@ -0,0 +1,1001 @@ +package jwt + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/ecodeclub/ekit/set" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +type data struct { + Foo string `json:"foo"` +} + +var ( + defaultExpire = 10 * time.Minute + defaultClaims = RegisteredClaims[data]{ + Data: data{Foo: "1"}, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(nowTime.Add(defaultExpire)), + NotBefore: jwt.NewNumericDate(nowTime), + IssuedAt: jwt.NewNumericDate(nowTime), + }, + } + encryptionKey = "sign key" + nowTime = time.UnixMilli(1695571200000) + defaultOption = &Options{ + Expire: defaultExpire, + EncryptionKey: encryptionKey, + DecryptKey: encryptionKey, + Method: jwt.SigningMethodHS256, + } + defaultManager = NewManager[data](defaultOption, + WithNowFunc[data](func() time.Time { + return nowTime + }), + ) +) + +func TestManager_GenerateAccessToken(t *testing.T) { + m := defaultManager + type testCase[T any] struct { + name string + data T + want string + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + data: data{Foo: "1"}, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := m.GenerateAccessToken(tt.data) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestManager_GenerateRefreshToken(t *testing.T) { + m := defaultManager + type testCase[T any] struct { + name string + refreshJWTOptions *Options + data T + want string + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + refreshJWTOptions: &Options{ + Expire: 24 * 60 * time.Minute, + EncryptionKey: "refresh sign key", + DecryptKey: "refresh sign key", + Method: jwt.SigningMethodHS256, + }, + data: data{Foo: "1"}, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw", + }, + { + name: "mistake", + data: data{Foo: "1"}, + want: "", + wantErr: ErrEmptyRefreshOpts, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m.refreshJWTOptions = tt.refreshJWTOptions + got, err := m.GenerateRefreshToken(tt.data) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestManager_MiddlewareBuilder(t *testing.T) { + type testCase[T any] struct { + name string + m *Manager[T] + reqBuilder func(t *testing.T) *http.Request + wantCode int + } + tests := []testCase[data]{ + { + // 验证失败 + name: "verify_failed", + m: NewManager[data](defaultOption, + WithIgnorePaths[data]("/login")), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y") + return req + }, + wantCode: http.StatusUnauthorized, + }, + { + // 提取 token 失败 + name: "extract_token_failed", + m: NewManager[data](defaultOption, + WithIgnorePaths[data]("/login")), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer ") + return req + }, + wantCode: http.StatusUnauthorized, + }, + { + // 无需认证直接通过 + name: "pass_without_authentication", + m: NewManager[data](defaultOption, + WithIgnorePaths[data]("/login")), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/login", nil) + if err != nil { + t.Fatal(err) + } + return req + }, + wantCode: http.StatusOK, + }, + { + // 验证通过 + name: "pass_the_verification", + m: NewManager[data](defaultOption, + WithIgnorePaths[data]("/login"), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695571500000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y") + return req + }, + wantCode: http.StatusOK, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := gin.Default() + server.Use(tt.m.MiddlewareBuilder()) + tt.m.registerRoutes(server) + + req := tt.reqBuilder(t) + recorder := httptest.NewRecorder() + + server.ServeHTTP(recorder, req) + assert.Equal(t, tt.wantCode, recorder.Code) + }) + } +} + +func TestManager_Refresh(t *testing.T) { + type testCase[T any] struct { + name string + m *Manager[T] + reqBuilder func(t *testing.T) *http.Request + wantCode int + wantAccessToken string + wantRefreshToken string + } + tests := []testCase[data]{ + { + // 更新资源令牌并轮换刷新令牌 + name: "refresh_access_token_and_rotate_refresh_token", + m: NewManager[data](defaultOption, + WithRefreshJWTOptions[data]( + NewOptions(24*60*time.Minute, + "refresh sign key", + )), + WithRotateRefreshToken[data](true), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695623000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + return req + }, + wantCode: http.StatusOK, + wantAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjIzNjAwLCJuYmYiOjE2OTU2MjMwMDAsImlhdCI6MTY5NTYyMzAwMH0.5Hv-Gq8RW0xAFBh4WhKc0KDLsdgTEv3RUhPceaM4e5M", + wantRefreshToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NzA5NDAwLCJuYmYiOjE2OTU2MjMwMDAsImlhdCI6MTY5NTYyMzAwMH0.4R-JmqcKHtsoFOGFDe5SBA2wNV0F-XvnP2Janp6NfZY", + }, + { + // 更新资源令牌但轮换刷新令牌生成失败 + name: "refresh_access_token_but_gen_rotate_refresh_token_failed", + m: NewManager[data](defaultOption, + WithRefreshJWTOptions[data]( + NewOptions(24*60*time.Minute, + "refresh sign key", + WithMethod(jwt.SigningMethodRS256), + )), + WithRotateRefreshToken[data](true), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695623000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + return req + }, + wantCode: http.StatusInternalServerError, + }, + { + // 更新资源令牌 + name: "refresh_access_token", + m: NewManager[data](defaultOption, + WithRefreshJWTOptions[data]( + NewOptions(24*60*time.Minute, + "refresh sign key", + )), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695623000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + return req + }, + wantCode: http.StatusOK, + wantAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjIzNjAwLCJuYmYiOjE2OTU2MjMwMDAsImlhdCI6MTY5NTYyMzAwMH0.5Hv-Gq8RW0xAFBh4WhKc0KDLsdgTEv3RUhPceaM4e5M", + }, + { + // 生成资源令牌失败 + name: "gen_access_token_failed", + m: NewManager[data]( + &Options{ + Expire: 10 * time.Minute, + EncryptionKey: encryptionKey, + DecryptKey: encryptionKey, + Method: jwt.SigningMethodRS256, + }, + WithRefreshJWTOptions[data]( + NewOptions(24*60*time.Minute, + "refresh sign key", + )), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695623000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + return req + }, + wantCode: http.StatusInternalServerError, + }, + { + // 刷新令牌认证失败 + name: "refresh_token_verify_failed", + m: NewManager[data]( + defaultOption, + WithRefreshJWTOptions[data]( + NewOptions(24*60*time.Minute, + "refresh sign key", + )), + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695723000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + return req + }, + wantCode: http.StatusUnauthorized, + }, + { + // 没有设置刷新令牌选项 + name: "not_set_refreshJWTOptions", + m: NewManager[data]( + defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695723000000) + }), + ), + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/refresh", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + return req + }, + wantCode: http.StatusInternalServerError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := gin.Default() + tt.m.registerRoutes(server) + + req := tt.reqBuilder(t) + recorder := httptest.NewRecorder() + + server.ServeHTTP(recorder, req) + assert.Equal(t, tt.wantCode, recorder.Code) + if tt.wantCode != http.StatusOK { + return + } + assert.Equal(t, tt.wantAccessToken, + recorder.Header().Get("x-access-token")) + assert.Equal(t, tt.wantRefreshToken, + recorder.Header().Get("x-refresh-token")) + }) + } +} + +func TestManager_VerifyAccessToken(t *testing.T) { + type testCase[T any] struct { + name string + m *Manager[T] + token string + want RegisteredClaims[T] + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + m: defaultManager, + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", + want: defaultClaims, + }, + { + // token 过期了 + name: "token_expired", + m: NewManager[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695671200000) + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), + }, + { + // token 签名错误 + name: "bad_sign_key", + m: NewManager[data]( + &Options{ + Expire: defaultExpire, + EncryptionKey: encryptionKey, + DecryptKey: "bad sign key", + Method: jwt.SigningMethodHS256, + }, + WithNowFunc[data](func() time.Time { + return nowTime + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), + }, + { + // 错误的 token + name: "bad_token", + m: defaultManager, + token: "bad_token", + wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", + jwt.ErrTokenMalformed), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.m.VerifyAccessToken(tt.token) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestManager_VerifyRefreshToken(t *testing.T) { + type testCase[T any] struct { + name string + m *Manager[T] + token string + want RegisteredClaims[T] + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + m: NewManager[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695601200000) + }), + WithRefreshJWTOptions[data](&Options{ + Expire: 24 * 60 * time.Minute, + EncryptionKey: "refresh sign key", + DecryptKey: "refresh sign key", + Method: jwt.SigningMethodHS256, + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw", + want: RegisteredClaims[data]{ + Data: data{Foo: "1"}, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(nowTime.Add(24 * 60 * time.Minute)), + NotBefore: jwt.NewNumericDate(nowTime), + IssuedAt: jwt.NewNumericDate(nowTime), + }, + }, + }, + { + // token 过期了 + name: "token_expired", + m: NewManager[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695701200000) + }), + WithRefreshJWTOptions[data](&Options{ + Expire: 24 * 60 * time.Minute, + EncryptionKey: "refresh sign key", + DecryptKey: "refresh sign key", + Method: jwt.SigningMethodHS256, + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), + }, + { + // token 签名错误 + name: "bad_sign_key", + m: NewManager[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695601200000) + }), + WithRefreshJWTOptions[data](&Options{ + Expire: 24 * 60 * time.Minute, + EncryptionKey: "bad refresh sign key", + DecryptKey: "bad refresh sign key", + Method: jwt.SigningMethodHS256, + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw", + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), + }, + { + // 错误的 token + name: "bad_token", + m: NewManager[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695601200000) + }), + WithRefreshJWTOptions[data](&Options{ + Expire: 24 * 60 * time.Minute, + EncryptionKey: "refresh sign key", + DecryptKey: "refresh sign key", + Method: jwt.SigningMethodHS256, + }), + ), + token: "bad_token", + wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", + jwt.ErrTokenMalformed), + }, + { + name: "no_refresh_options", + m: NewManager[data](defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695601200000) + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw", + wantErr: ErrEmptyRefreshOpts, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.m.VerifyRefreshToken(tt.token) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestManager_extractTokenString(t *testing.T) { + m := defaultManager + type header struct { + key string + value string + } + type testCase[T any] struct { + name string + header header + want string + } + tests := []testCase[data]{ + { + name: "normal", + header: header{ + key: "authorization", + value: "Bearer token", + }, + want: "token", + }, + { + name: "mistake_prefix", + header: header{ + key: "authorization", + value: "bearer token", + }, + }, + { + name: "no_allow_token_header", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req, err := http.NewRequest(http.MethodGet, "", nil) + req.Header.Add(tt.header.key, tt.header.value) + if err != nil { + t.Fatal(err) + } + ctx.Request = req + + got := m.extractTokenString(ctx) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestManager_verifyTokenAndSetClm(t *testing.T) { + type testCase[T any] struct { + name string + m *Manager[T] + token string + want RegisteredClaims[T] + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + m: defaultManager, + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", + want: defaultClaims, + }, + { + name: "verify_access_token_failed", + m: NewManager[data]( + defaultOption, + WithNowFunc[data](func() time.Time { + return time.UnixMilli(1695671200000) + }), + ), + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", + want: RegisteredClaims[data]{}, + wantErr: fmt.Errorf("验证失败: %v", + fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req, err := http.NewRequest(http.MethodGet, "", nil) + if err != nil { + t.Fatal(err) + } + ctx.Request = req + err = tt.m.verifyTokenAndSetClm(ctx, tt.token) + assert.Equal(t, tt.wantErr, err) + if err != nil { + return + } + v, ok := ctx.Get("claims") + if !ok { + t.Error("claims设置失败") + } + clm, ok := v.(RegisteredClaims[data]) + if !ok { + t.Error("claims不是 RegisteredClaims[T] 类型") + } + assert.Equal(t, tt.want, clm) + }) + } +} + +func TestWithAllowTokenHeader(t *testing.T) { + type testCase[T any] struct { + name string + fn func() ManagerOption[T] + want string + } + tests := []testCase[data]{ + { + name: "default", + fn: func() ManagerOption[data] { + return nil + }, + want: "authorization", + }, + { + name: "set_another_header", + fn: func() ManagerOption[data] { + return WithAllowTokenHeader[data]("jwt") + }, + want: "jwt", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewManager[data]( + defaultOption, + ).allowTokenHeader + } else { + got = NewManager[data]( + defaultOption, + tt.fn(), + ).allowTokenHeader + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithBearerPrefix(t *testing.T) { + type testCase[T any] struct { + name string + fn func() ManagerOption[T] + want string + } + tests := []testCase[data]{ + { + name: "default", + fn: func() ManagerOption[data] { + return nil + }, + want: "Bearer", + }, + { + name: "set_another_prefix", + fn: func() ManagerOption[data] { + return WithBearerPrefix[data]("jwt") + }, + want: "jwt", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewManager[data]( + defaultOption, + ).bearerPrefix + } else { + got = NewManager[data]( + defaultOption, + tt.fn(), + ).bearerPrefix + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithClaimsCTXKey(t *testing.T) { + type testCase[T any] struct { + name string + fn func() ManagerOption[T] + want string + } + tests := []testCase[data]{ + { + name: "default", + fn: func() ManagerOption[data] { + return nil + }, + want: "claims", + }, + { + name: "set_another_ctx_key", + fn: func() ManagerOption[data] { + return WithClaimsCTXKey[data]("clm") + }, + want: "clm", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewManager[data]( + defaultOption, + ).claimsCTXKey + } else { + got = NewManager[data]( + defaultOption, + tt.fn(), + ).claimsCTXKey + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithExposeAccessHeader(t *testing.T) { + type testCase[T any] struct { + name string + fn func() ManagerOption[T] + want string + } + tests := []testCase[data]{ + { + name: "default", + fn: func() ManagerOption[data] { + return nil + }, + want: "x-access-token", + }, + { + name: "set_another_header", + fn: func() ManagerOption[data] { + return WithExposeAccessHeader[data]("token") + }, + want: "token", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewManager[data]( + defaultOption, + ).exposeAccessHeader + } else { + got = NewManager[data]( + defaultOption, + tt.fn(), + ).exposeAccessHeader + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithExposeRefreshHeader(t *testing.T) { + type testCase[T any] struct { + name string + fn func() ManagerOption[T] + want string + } + tests := []testCase[data]{ + { + name: "default", + fn: func() ManagerOption[data] { + return nil + }, + want: "x-refresh-token", + }, + { + name: "set_another_header", + fn: func() ManagerOption[data] { + return WithExposeRefreshHeader[data]("refresh-token") + }, + want: "refresh-token", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewManager[data]( + defaultOption, + ).exposeRefreshHeader + } else { + got = NewManager[data]( + defaultOption, + tt.fn(), + ).exposeRefreshHeader + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithIgnorePaths(t *testing.T) { + type testCase[T any] struct { + name string + fn func() ManagerOption[T] + paths []string + want []bool + } + tests := []testCase[data]{ + { + name: "default", + fn: func() ManagerOption[data] { + return nil + }, + want: []bool{}, + }, + { + name: "all_exists_paths", + fn: func() ManagerOption[data] { + return WithIgnorePaths[data]([]string{ + "/login", + "/signup", + }...) + }, + paths: []string{"/login", "/signup"}, + want: []bool{true, true}, + }, + { + name: "one_path_does_not_exist", + fn: func() ManagerOption[data] { + return WithIgnorePaths[data]([]string{ + "/login", + "/signup", + }...) + }, + paths: []string{"/login", "/profile", "/signup"}, + want: []bool{true, false, true}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ignorePaths set.Set[string] + if tt.fn() == nil { + ignorePaths = NewManager[data]( + defaultOption, + ).publicPaths + } else { + ignorePaths = NewManager[data]( + defaultOption, + tt.fn(), + ).publicPaths + } + exists := make([]bool, 0, len(tt.paths)) + for _, path := range tt.paths { + exists = append(exists, ignorePaths.Exist(path)) + } + assert.Equal(t, tt.want, exists) + }) + } +} + +func TestWithNowFunc(t *testing.T) { + type testCase[T any] struct { + name string + fn func() ManagerOption[T] + want time.Time + } + tests := []testCase[data]{ + { + name: "default", + fn: func() ManagerOption[data] { + return nil + }, + want: time.Now(), + }, + { + name: "set_another_now_func", + fn: func() ManagerOption[data] { + return WithNowFunc[data](func() time.Time { + return nowTime + }) + }, + want: nowTime, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got time.Time + if tt.fn() == nil { + got = NewManager[data]( + defaultOption, + ).nowFunc() + } else { + got = NewManager[data]( + defaultOption, + tt.fn(), + ).nowFunc() + } + assert.Equal(t, tt.want.Unix(), got.Unix()) + }) + } +} + +func TestWithRefreshJWTOptions(t *testing.T) { + type testCase[T any] struct { + name string + fn func() ManagerOption[T] + want *Options + } + tests := []testCase[data]{ + { + name: "default", + fn: func() ManagerOption[data] { + return nil + }, + want: nil, + }, + { + name: "set_refresh_jwt_options", + fn: func() ManagerOption[data] { + return WithRefreshJWTOptions[data]( + NewOptions( + 24*60*time.Minute, + "refresh sign key", + ), + ) + }, + want: &Options{ + Expire: 24 * 60 * time.Minute, + EncryptionKey: "refresh sign key", + DecryptKey: "refresh sign key", + Method: jwt.SigningMethodHS256, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got *Options + if tt.fn() == nil { + got = NewManager[data]( + defaultOption, + ).refreshJWTOptions + } else { + got = NewManager[data]( + defaultOption, + tt.fn(), + ).refreshJWTOptions + } + assert.Equal(t, tt.want, got) + }) + } +} + +func (m *Manager[T]) registerRoutes(server *gin.Engine) { + server.GET("/", func(ctx *gin.Context) { + ctx.Status(http.StatusOK) + }) + server.GET("/login", func(ctx *gin.Context) { + ctx.Status(http.StatusOK) + }) + server.GET("refresh", m.Refresh) +} diff --git a/middlewares/jwt/types.go b/middlewares/jwt/types.go new file mode 100644 index 0000000..dd99e23 --- /dev/null +++ b/middlewares/jwt/types.go @@ -0,0 +1,19 @@ +package jwt + +import ( + "github.com/gin-gonic/gin" +) + +type LoginManager[T any] interface { + Refresh(ctx *gin.Context) // 刷新 token 的 gin.HandlerFunc + MiddlewareBuilder() gin.HandlerFunc // 登录认证的中间件 + + GenerateAccessToken(data T) (string, error) // 生成资源 token + VerifyAccessToken(token string) (RegisteredClaims[T], error) // 校验资源 token + // GenerateRefreshToken 生成刷新 token. + // 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. + GenerateRefreshToken(data T) (string, error) + // VerifyRefreshToken 校验刷新 token. + // 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. + VerifyRefreshToken(token string) (RegisteredClaims[T], error) // 校验刷新 token +} diff --git a/middlewares/token/jwt.go b/middlewares/token/jwt.go deleted file mode 100644 index 661a481..0000000 --- a/middlewares/token/jwt.go +++ /dev/null @@ -1,97 +0,0 @@ -package token - -import ( - "fmt" - "time" - - "github.com/golang-jwt/jwt/v5" -) - -type JWTToken[T jwt.Claims] struct { - encryptionKey string // 加密密钥 - decryptKey string // 解密密钥 - nowFunc func() time.Time - method jwt.SigningMethod -} - -// NewJWTToken -// method: 默认签名加密方式使用 SH256 -// decryptKey: 因默认使用对称加密所以与 encryptionKey 相同 -func NewJWTToken[T jwt.Claims](encryptionKey string, opts ...Option[T]) *JWTToken[T] { - dOpts := defaultOption[T]() - dOpts.encryptionKey = encryptionKey - dOpts.decryptKey = encryptionKey - - for _, opt := range opts { - opt.apply(&dOpts) - } - - return &dOpts -} - -type Option[T jwt.Claims] interface { - apply(*JWTToken[T]) -} - -type funcOption[T jwt.Claims] struct { - f func(*JWTToken[T]) -} - -func (fdo *funcOption[T]) apply(do *JWTToken[T]) { - fdo.f(do) -} - -func newFuncOption[T jwt.Claims](f func(*JWTToken[T])) *funcOption[T] { - return &funcOption[T]{ - f: f, - } -} - -func defaultOption[T jwt.Claims]() JWTToken[T] { - return JWTToken[T]{ - nowFunc: time.Now, - method: jwt.SigningMethodHS256, - } -} - -func WithDecryptKey[T jwt.Claims](decryptKey string) Option[T] { - return newFuncOption(func(j *JWTToken[T]) { - j.decryptKey = decryptKey - }) -} - -func WithNowFunc[T jwt.Claims](nowFunc func() time.Time) Option[T] { - return newFuncOption(func(j *JWTToken[T]) { - j.nowFunc = nowFunc - }) -} - -func WithSigningMethod[T jwt.Claims](method jwt.SigningMethod) Option[T] { - return newFuncOption(func(j *JWTToken[T]) { - j.method = method - }) -} - -// Generate 生成 jwt token. -func (j *JWTToken[T]) Generate(claims T) (string, error) { - token := jwt.NewWithClaims(j.method, claims) - return token.SignedString([]byte(j.encryptionKey)) -} - -// Verify 验证token.验证不通过则返回 error. -func (j *JWTToken[T]) Verify(token string) (T, error) { - var claimsZero T - claims := claimsZero - var claimsPtr any = &claims - t, err := jwt.ParseWithClaims(token, claimsPtr.(jwt.Claims), - func(*jwt.Token) (interface{}, error) { - return []byte(j.decryptKey), nil - }, - jwt.WithTimeFunc(j.nowFunc), - ) - if err != nil || !t.Valid { - return claimsZero, fmt.Errorf("验证失败: %v", err) - } - - return claims, nil -} diff --git a/middlewares/token/jwt_test.go b/middlewares/token/jwt_test.go deleted file mode 100644 index 4eed689..0000000 --- a/middlewares/token/jwt_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package token - -import ( - "fmt" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/assert" -) - -func TestJWTToken_Generate(t *testing.T) { - j := NewJWTToken[jwt.RegisteredClaims]("foo") - nowTime := time.UnixMilli(1695571200000) - type testCase[T jwt.Claims] struct { - name string - claims jwt.RegisteredClaims - want string - wantErr error - } - tests := []testCase[jwt.RegisteredClaims]{ - { - name: "生成token", - claims: jwt.RegisteredClaims{ - Issuer: "bar", - Subject: "1", - IssuedAt: jwt.NewNumericDate(nowTime), - ExpiresAt: jwt.NewNumericDate(nowTime.Add(10 * time.Minute)), - }, - want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := j.Generate(tt.claims) - assert.Equal(t, tt.wantErr, err) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestJWTToken_Verify(t *testing.T) { - j := NewJWTToken[jwt.RegisteredClaims]("foo") - type testCase[T jwt.Claims] struct { - name string - nowFunc func() time.Time - token string - want jwt.RegisteredClaims - wantErr error - } - tests := []testCase[jwt.RegisteredClaims]{ - { - name: "验证通过", - nowFunc: func() time.Time { - return time.UnixMilli(1695571500000) - }, - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ", - want: jwt.RegisteredClaims{ - Issuer: "bar", - Subject: "1", - IssuedAt: jwt.NewNumericDate(time.UnixMilli(1695571200000)), - ExpiresAt: jwt.NewNumericDate(time.UnixMilli(1695571200000).Add(10 * time.Minute)), - }, - }, - { - name: "token过期了", - nowFunc: func() time.Time { - return time.UnixMilli(1695572500000) - }, - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.a1q3jHKedQGbA-Zrn6S21QUpI2ZNYNHoeG5LkxAXRJQ", - wantErr: fmt.Errorf("验证失败: %v", - fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), - }, - { - name: "token签名错误", - nowFunc: func() time.Time { - return time.UnixMilli(1695571500000) - }, - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJiYXIiLCJzdWIiOiIxIiwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.5OeEzR5tNTGmXwvloac2wYdZvlt8U5UmFdsnpBJ_zb4", - wantErr: fmt.Errorf("验证失败: %v", - fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), - }, - { - name: "错误的token", - nowFunc: func() time.Time { - return time.UnixMilli(1695571500000) - }, - token: "bad_token", - wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", - jwt.ErrTokenMalformed), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - j.nowFunc = tt.nowFunc - got, err := j.Verify(tt.token) - assert.Equal(t, tt.wantErr, err) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithSigningMethod(t *testing.T) { - type jwtT = jwt.RegisteredClaims - type testCase[T jwt.Claims] struct { - name string - method jwt.SigningMethod - want jwt.SigningMethod - } - tests := []testCase[jwtT]{ - { - name: "设置成功", - method: jwt.SigningMethodHS512, - want: jwt.SigningMethodHS512, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewJWTToken[jwtT]("foo", WithSigningMethod[jwtT](tt.method)).method - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithDecryptKey(t *testing.T) { - type jwtT = jwt.RegisteredClaims - type testCase[T jwt.Claims] struct { - name string - decryptKey string - want string - } - tests := []testCase[jwtT]{ - { - name: "设置解密密钥成功", - decryptKey: "decryptKey", - want: "decryptKey", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewJWTToken[jwtT]("foo", WithDecryptKey[jwtT](tt.decryptKey)).decryptKey - assert.Equal(t, tt.want, got) - }) - } -} - -func TestWithNowFunc(t *testing.T) { - type jwtT = jwt.RegisteredClaims - type testCase[T jwt.Claims] struct { - name string - nowFunc func() time.Time - want time.Time - } - tests := []testCase[jwtT]{ - { - name: "设置解密密钥成功", - nowFunc: func() time.Time { - return time.UnixMilli(1695571200000) - }, - want: time.UnixMilli(1695571200000), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NewJWTToken[jwtT]("foo", - WithNowFunc[jwtT](tt.nowFunc)).nowFunc() - assert.Equal(t, tt.want, got) - }) - } -} diff --git a/middlewares/token/types.go b/middlewares/token/types.go deleted file mode 100644 index bd10856..0000000 --- a/middlewares/token/types.go +++ /dev/null @@ -1,10 +0,0 @@ -package token - -import ( - "github.com/golang-jwt/jwt/v5" -) - -type Token[T jwt.Claims] interface { - Generate(claims T) (string, error) - Verify(token string) (T, error) -} From 4b25feca61d189d7bfe63161689e70e27b8bdcd7 Mon Sep 17 00:00:00 2001 From: joil Date: Thu, 12 Oct 2023 14:20:10 +0800 Subject: [PATCH 4/4] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=20JWT=20?= =?UTF-8?q?=E4=B8=80=E9=94=AE=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middlewares/jwt/claims.go | 62 ++-- middlewares/jwt/claims_test.go | 145 +++++--- middlewares/jwt/jwt.go | 201 +++++------ middlewares/jwt/jwt_test.go | 606 ++++++++++++++++----------------- middlewares/jwt/types.go | 26 +- 5 files changed, 535 insertions(+), 505 deletions(-) diff --git a/middlewares/jwt/claims.go b/middlewares/jwt/claims.go index d6c5233..44ae07a 100644 --- a/middlewares/jwt/claims.go +++ b/middlewares/jwt/claims.go @@ -3,6 +3,7 @@ package jwt import ( "time" + "github.com/ecodeclub/ekit/bean/option" "github.com/golang-jwt/jwt/v5" ) @@ -12,68 +13,57 @@ type RegisteredClaims[T any] struct { } type Options struct { - Issuer string // 签发人 Expire time.Duration // 有效期 EncryptionKey string // 加密密钥 DecryptKey string // 解密密钥 Method jwt.SigningMethod // 签名方式 + Issuer string // 签发人 + genIDFn func() string // 生成 JWT ID (jti) 的函数 } -// NewOptions 定义一个JWT Claims配置 -// Issuer: 默认使用 "". +// NewOptions 定义一个 JWT 配置. // DecryptKey: 默认与 EncryptionKey 相同. // Method: 默认使用 jwt.SigningMethodHS256 签名方式. -func NewOptions(expire time.Duration, encryptionKey string, opts ...Option) *Options { +func NewOptions(expire time.Duration, encryptionKey string, + opts ...option.Option[Options]) *Options { dOpts := Options{ Expire: expire, EncryptionKey: encryptionKey, DecryptKey: encryptionKey, Method: jwt.SigningMethodHS256, + genIDFn: func() string { return "" }, } - for _, opt := range opts { - opt.apply(&dOpts) - } + option.Apply[Options](&dOpts, opts...) return &dOpts } -type Option interface { - apply(*Options) -} - -type funcOption struct { - f func(handler *Options) -} - -func (fdo *funcOption) apply(do *Options) { - fdo.f(do) +// WithDecryptKey 设置解密密钥. +func WithDecryptKey(decryptKey string) option.Option[Options] { + return func(o *Options) { + o.DecryptKey = decryptKey + } } -func newFuncOption( - f func(*Options)) *funcOption { - return &funcOption{ - f: f, +// WithMethod 设置 JWT 的签名方法. +func WithMethod(method jwt.SigningMethod) option.Option[Options] { + return func(o *Options) { + o.Method = method } } // WithIssuer 设置签发人. -func WithIssuer(issuer string) Option { - return newFuncOption(func(o *Options) { +func WithIssuer(issuer string) option.Option[Options] { + return func(o *Options) { o.Issuer = issuer - }) -} - -// WithDecryptKey 设置解密密钥. -func WithDecryptKey(decryptKey string) Option { - return newFuncOption(func(o *Options) { - o.DecryptKey = decryptKey - }) + } } -// WithMethod 设置 JWT 的签名方法. -func WithMethod(method jwt.SigningMethod) Option { - return newFuncOption(func(o *Options) { - o.Method = method - }) +// WithGenIDFunc 设置生成 JWT ID 的函数. +// 可以设置成 WithGenIDFunc(uuid.NewString). +func WithGenIDFunc(fn func() string) option.Option[Options] { + return func(o *Options) { + o.genIDFn = fn + } } diff --git a/middlewares/jwt/claims_test.go b/middlewares/jwt/claims_test.go index 5b46aba..7fbccfd 100644 --- a/middlewares/jwt/claims_test.go +++ b/middlewares/jwt/claims_test.go @@ -4,11 +4,13 @@ import ( "testing" "time" + "github.com/ecodeclub/ekit/bean/option" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" ) func TestNewOptions(t *testing.T) { + var genIDFn func() string tests := []struct { name string expire time.Duration @@ -24,12 +26,14 @@ func TestNewOptions(t *testing.T) { EncryptionKey: "sign key", DecryptKey: "sign key", Method: jwt.SigningMethodHS256, + genIDFn: genIDFn, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := NewOptions(tt.expire, tt.encryptionKey) + got.genIDFn = genIDFn assert.Equal(t, tt.want, got) }) } @@ -37,25 +41,71 @@ func TestNewOptions(t *testing.T) { func TestWithDecryptKey(t *testing.T) { tests := []struct { - name string - decryptKey string - want *Options + name string + fn func() option.Option[Options] + want string }{ { - name: "set_another_key", - decryptKey: "other sign key", - want: &Options{ - Expire: defaultExpire, - EncryptionKey: encryptionKey, - DecryptKey: "other sign key", - Method: jwt.SigningMethodHS256, + name: "normal", + fn: func() option.Option[Options] { + return nil + }, + want: encryptionKey, + }, + { + name: "set_another_key", + fn: func() option.Option[Options] { + return WithDecryptKey("another sign key") + }, + want: "another sign key", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got string + if tt.fn() == nil { + got = NewOptions(defaultExpire, encryptionKey). + DecryptKey + } else { + got = NewOptions(defaultExpire, encryptionKey, + tt.fn()).DecryptKey + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWithMethod(t *testing.T) { + tests := []struct { + name string + fn func() option.Option[Options] + want jwt.SigningMethod + }{ + { + name: "normal", + fn: func() option.Option[Options] { + return nil + }, + want: jwt.SigningMethodHS256, + }, + { + name: "set_another_method", + fn: func() option.Option[Options] { + return WithMethod(jwt.SigningMethodHS384) }, + want: jwt.SigningMethodHS384, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := NewOptions(defaultExpire, encryptionKey, - WithDecryptKey(tt.decryptKey)) + var got jwt.SigningMethod + if tt.fn() == nil { + got = NewOptions(defaultExpire, encryptionKey). + Method + } else { + got = NewOptions(defaultExpire, encryptionKey, + tt.fn()).Method + } assert.Equal(t, tt.want, got) }) } @@ -63,52 +113,71 @@ func TestWithDecryptKey(t *testing.T) { func TestWithIssuer(t *testing.T) { tests := []struct { - name string - issuer string - want *Options + name string + fn func() option.Option[Options] + want string }{ { - name: "set_issuer", - issuer: "foo", - want: &Options{ - Issuer: "foo", - Expire: defaultExpire, - EncryptionKey: encryptionKey, - DecryptKey: encryptionKey, - Method: jwt.SigningMethodHS256, + name: "normal", + fn: func() option.Option[Options] { + return nil }, }, + { + name: "set_another_issuer", + fn: func() option.Option[Options] { + return WithIssuer("foo") + }, + want: "foo", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := NewOptions(defaultExpire, encryptionKey, - WithIssuer(tt.issuer)) + var got string + if tt.fn() == nil { + got = NewOptions(defaultExpire, encryptionKey). + Issuer + } else { + got = NewOptions(defaultExpire, encryptionKey, + tt.fn()).Issuer + } assert.Equal(t, tt.want, got) }) } } -func TestWithMethod(t *testing.T) { +func TestWithGenIDFunc(t *testing.T) { tests := []struct { - name string - method jwt.SigningMethod - want *Options + name string + fn func() option.Option[Options] + want string }{ { - name: "set_another_jwt_signing_method", - method: jwt.SigningMethodHS384, - want: &Options{ - Expire: defaultExpire, - EncryptionKey: encryptionKey, - DecryptKey: encryptionKey, - Method: jwt.SigningMethodHS384, + name: "normal", + fn: func() option.Option[Options] { + return nil + }, + }, + { + name: "set_another_gen_id_func", + fn: func() option.Option[Options] { + return WithGenIDFunc(func() string { + return "unique id" + }) }, + want: "unique id", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := NewOptions(defaultExpire, encryptionKey, - WithMethod(tt.method)) + var got string + if tt.fn() == nil { + got = NewOptions(defaultExpire, encryptionKey). + genIDFn() + } else { + got = NewOptions(defaultExpire, encryptionKey, + tt.fn()).genIDFn() + } assert.Equal(t, tt.want, got) }) } diff --git a/middlewares/jwt/jwt.go b/middlewares/jwt/jwt.go index 5559b30..9bf6272 100644 --- a/middlewares/jwt/jwt.go +++ b/middlewares/jwt/jwt.go @@ -3,26 +3,28 @@ package jwt import ( "errors" "fmt" + "log/slog" "net/http" "strings" "time" + "github.com/ecodeclub/ekit/bean/option" "github.com/ecodeclub/ekit/set" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" ) +const bearerPrefix = "Bearer" + var ( ErrEmptyRefreshOpts = errors.New("refreshJWTOptions are nil") ) -type Manager[T any] struct { - publicPaths set.Set[string] // 存放不需要认证的 path - allowTokenHeader string // 认证的请求头(存放 token 的请求头 key) - bearerPrefix string // 拼接 token 的前缀 - claimsCTXKey string // 存放到 gin.Context 的 key - exposeAccessHeader string // 暴露到外部的资源请求头 - exposeRefreshHeader string // 暴露到外部的刷新请求头 +type Management[T any] struct { + ignorePath func(path string) bool // Middleware 方法中忽略认证的路径 + allowTokenHeader string // 认证的请求头(存放 token 的请求头 key) + exposeAccessHeader string // 暴露到外部的资源请求头 + exposeRefreshHeader string // 暴露到外部的刷新请求头 accessJWTOptions *Options // 资源 token 选项 refreshJWTOptions *Options // 刷新 token 选项 @@ -31,53 +33,32 @@ type Manager[T any] struct { nowFunc func() time.Time // 控制 jwt 的时间 } -// NewManager 定义一个 JWTLoginManger +// NewManagement 定义一个 Management. +// ignorePath: 默认使用 func(path string) bool { return false } 也就是全部不忽略. // allowTokenHeader: 默认使用 authorization 为认证请求头. -// bearerPrefix: 默认使用 Bearer 拼接 token. -// claimsCTXKey: 默认使用 claims 为设置到 gin.Context 的key // exposeAccessHeader: 默认使用 x-access-token 为暴露外部的资源请求头. // exposeRefreshHeader: 默认使用 x-refresh-token 为暴露外部的刷新请求头. // refreshJWTOptions: 默认使用 nil 为刷新 token 的配置, -// 如要使用 refresh 功能则需要使用 WithRefreshJWTOptions 添加相关配置. +// 如要使用 refresh 相关功能则需要使用 WithRefreshJWTOptions 添加相关配置. // rotateRefreshToken: 默认不轮换刷新令牌. // 该配置需要设置 refreshJWTOptions 才有效. -func NewManager[T any](accessJWTOptions *Options, - opts ...ManagerOption[T]) *Manager[T] { - dOpts := defaultManagerOption[T]() - dOpts.accessJWTOptions = accessJWTOptions +func NewManagement[T any](accessJWTOptions *Options, + opts ...option.Option[Management[T]]) *Management[T] { - for _, opt := range opts { - opt.apply(&dOpts) + if accessJWTOptions == nil { + panic("accessJWTOptions 不允许为 nil") } + dOpts := defaultManagementOptions[T]() + dOpts.accessJWTOptions = accessJWTOptions + option.Apply[Management[T]](&dOpts, opts...) return &dOpts } -type ManagerOption[T any] interface { - apply(*Manager[T]) -} - -type funcManagerOption[T any] struct { - f func(handler *Manager[T]) -} - -func (fdo *funcManagerOption[T]) apply(do *Manager[T]) { - fdo.f(do) -} - -func newFuncManagerOption[T any]( - f func(handler *Manager[T])) *funcManagerOption[T] { - return &funcManagerOption[T]{ - f: f, - } -} - -func defaultManagerOption[T any]() Manager[T] { - return Manager[T]{ - publicPaths: set.NewMapSet[string](0), +func defaultManagementOptions[T any]() Management[T] { + return Management[T]{ + ignorePath: func(path string) bool { return false }, allowTokenHeader: "authorization", - bearerPrefix: "Bearer", - claimsCTXKey: "claims", exposeAccessHeader: "x-access-token", exposeRefreshHeader: "x-refresh-token", rotateRefreshToken: false, @@ -85,91 +66,89 @@ func defaultManagerOption[T any]() Manager[T] { } } -// WithIgnorePaths 设置忽略资源令牌认证的路径. -// 例如: '/login', '/api/v1/signup'. -func WithIgnorePaths[T any](paths ...string) ManagerOption[T] { +// WithIgnorePath 设置忽略资源令牌认证的路径. +func WithIgnorePath[T any](fn func(path string) bool) option.Option[Management[T]] { + return func(m *Management[T]) { + m.ignorePath = fn + } +} + +// StaticIgnorePaths 设置静态忽略的路径. +func StaticIgnorePaths(paths ...string) func(path string) bool { s := set.NewMapSet[string](len(paths)) for _, path := range paths { s.Add(path) } - return newFuncManagerOption[T](func(l *Manager[T]) { - l.publicPaths = s - }) + return func(path string) bool { + if s.Exist(path) { + return true + } + return false + } } // WithAllowTokenHeader 设置允许 token 的请求头. -func WithAllowTokenHeader[T any](header string) ManagerOption[T] { - return newFuncManagerOption[T](func(m *Manager[T]) { +func WithAllowTokenHeader[T any](header string) option.Option[Management[T]] { + return func(m *Management[T]) { m.allowTokenHeader = header - }) -} - -// WithBearerPrefix 设置与 token 拼接的前缀. -// 例如: 'Bearer eyx.eyx.x'中的 'Bearer'. -func WithBearerPrefix[T any](prefix string) ManagerOption[T] { - return newFuncManagerOption[T](func(m *Manager[T]) { - m.bearerPrefix = prefix - }) -} - -// WithClaimsCTXKey 设置放到 gin.Context 中的 key. -func WithClaimsCTXKey[T any](key string) ManagerOption[T] { - return newFuncManagerOption[T](func(m *Manager[T]) { - m.claimsCTXKey = key - }) + } } // WithExposeAccessHeader 设置公开资源令牌的请求头. -func WithExposeAccessHeader[T any](header string) ManagerOption[T] { - return newFuncManagerOption[T](func(m *Manager[T]) { +func WithExposeAccessHeader[T any](header string) option.Option[Management[T]] { + return func(m *Management[T]) { m.exposeAccessHeader = header - }) + } } // WithExposeRefreshHeader 设置公开刷新令牌的请求头. -func WithExposeRefreshHeader[T any](header string) ManagerOption[T] { - return newFuncManagerOption[T](func(m *Manager[T]) { +func WithExposeRefreshHeader[T any](header string) option.Option[Management[T]] { + return func(m *Management[T]) { m.exposeRefreshHeader = header - }) + } } // WithRefreshJWTOptions 设置刷新令牌相关的配置. -func WithRefreshJWTOptions[T any](refreshOpts *Options) ManagerOption[T] { - return newFuncManagerOption(func(m *Manager[T]) { +func WithRefreshJWTOptions[T any](refreshOpts *Options) option.Option[Management[T]] { + return func(m *Management[T]) { m.refreshJWTOptions = refreshOpts - }) + } } // WithRotateRefreshToken 设置轮换刷新令牌. -func WithRotateRefreshToken[T any](isRotate bool) ManagerOption[T] { - return newFuncManagerOption(func(m *Manager[T]) { +func WithRotateRefreshToken[T any](isRotate bool) option.Option[Management[T]] { + return func(m *Management[T]) { m.rotateRefreshToken = isRotate - }) + } } // WithNowFunc 设置当前时间. -// 一般用于测试. -func WithNowFunc[T any](nowFunc func() time.Time) ManagerOption[T] { - return newFuncManagerOption(func(m *Manager[T]) { +// 一般用于测试固定 jwt. +func WithNowFunc[T any](nowFunc func() time.Time) option.Option[Management[T]] { + return func(m *Management[T]) { m.nowFunc = nowFunc - }) + } } -// Refresh 刷新 token 的 gin.HandlerFunc -func (m *Manager[T]) Refresh(ctx *gin.Context) { +// Refresh 刷新 token 的 gin.HandlerFunc. +func (m *Management[T]) Refresh(ctx *gin.Context) { if m.refreshJWTOptions == nil { + slog.Error("refreshJWTOptions 为 nil, 请使用 WithRefreshJWTOptions 设置 refresh 相关的配置") ctx.Status(http.StatusInternalServerError) return } tokenStr := m.extractTokenString(ctx) - clm, err := m.VerifyRefreshToken(tokenStr) + clm, err := m.VerifyRefreshToken(tokenStr, + jwt.WithTimeFunc(m.nowFunc)) if err != nil { + slog.Debug("refresh token verification failed") ctx.Status(http.StatusUnauthorized) return } accessToken, err := m.GenerateAccessToken(clm.Data) if err != nil { + slog.Error("failed to generate access token") ctx.Status(http.StatusInternalServerError) return } @@ -179,44 +158,53 @@ func (m *Manager[T]) Refresh(ctx *gin.Context) { if m.rotateRefreshToken { refreshToken, err := m.GenerateRefreshToken(clm.Data) if err != nil { + slog.Error("failed to generate refresh token") ctx.Status(http.StatusInternalServerError) return } ctx.Header(m.exposeRefreshHeader, refreshToken) } - ctx.Status(http.StatusOK) + ctx.Status(http.StatusNoContent) } -// MiddlewareBuilder 登录认证的中间件 -func (m *Manager[T]) MiddlewareBuilder() gin.HandlerFunc { +// Middleware 登录认证的中间件. +func (m *Management[T]) Middleware() gin.HandlerFunc { return func(ctx *gin.Context) { // 不需要校验 - if m.publicPaths.Exist(ctx.Request.URL.Path) { + if m.ignorePath(ctx.Request.URL.Path) { return } + // 提取 token tokenStr := m.extractTokenString(ctx) if tokenStr == "" { + slog.Debug("failed to extract token") ctx.AbortWithStatus(http.StatusUnauthorized) return } - err := m.verifyTokenAndSetClm(ctx, tokenStr) + // 校验 token + clm, err := m.VerifyAccessToken(tokenStr, + jwt.WithTimeFunc(m.nowFunc)) if err != nil { + slog.Debug("access token verification failed") ctx.AbortWithStatus(http.StatusUnauthorized) return } + + // 设置 claims + m.SetClaims(ctx, clm) } } // extractTokenString 提取 token 字符串. -func (m *Manager[T]) extractTokenString(ctx *gin.Context) string { +func (m *Management[T]) extractTokenString(ctx *gin.Context) string { authCode := ctx.GetHeader(m.allowTokenHeader) if authCode == "" { return "" } var b strings.Builder - b.WriteString(m.bearerPrefix) + b.WriteString(bearerPrefix) b.WriteString(" ") prefix := b.String() if strings.HasPrefix(authCode, prefix) { @@ -225,26 +213,16 @@ func (m *Manager[T]) extractTokenString(ctx *gin.Context) string { return "" } -// verifyTokenAndSetClm 校验 access token 并把 claims 设置到 gin.Context 中. -func (m *Manager[T]) verifyTokenAndSetClm(ctx *gin.Context, token string) error { - claims, err := m.VerifyAccessToken(token) - if err != nil { - return err - } - ctx.Set(m.claimsCTXKey, claims) - return nil -} - // GenerateAccessToken 生成资源 token. -func (m *Manager[T]) GenerateAccessToken(data T) (string, error) { +func (m *Management[T]) GenerateAccessToken(data T) (string, error) { nowTime := m.nowFunc() claims := RegisteredClaims[T]{ Data: data, RegisteredClaims: jwt.RegisteredClaims{ Issuer: m.accessJWTOptions.Issuer, ExpiresAt: jwt.NewNumericDate(nowTime.Add(m.accessJWTOptions.Expire)), - NotBefore: jwt.NewNumericDate(nowTime), IssuedAt: jwt.NewNumericDate(nowTime), + ID: m.accessJWTOptions.genIDFn(), }, } @@ -253,12 +231,12 @@ func (m *Manager[T]) GenerateAccessToken(data T) (string, error) { } // VerifyAccessToken 校验资源 token. -func (m *Manager[T]) VerifyAccessToken(token string) (RegisteredClaims[T], error) { +func (m *Management[T]) VerifyAccessToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) { t, err := jwt.ParseWithClaims(token, &RegisteredClaims[T]{}, func(*jwt.Token) (interface{}, error) { return []byte(m.accessJWTOptions.DecryptKey), nil }, - jwt.WithTimeFunc(m.nowFunc), + opts..., ) if err != nil || !t.Valid { return RegisteredClaims[T]{}, fmt.Errorf("验证失败: %v", err) @@ -269,7 +247,7 @@ func (m *Manager[T]) VerifyAccessToken(token string) (RegisteredClaims[T], error // GenerateRefreshToken 生成刷新 token. // 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. -func (m *Manager[T]) GenerateRefreshToken(data T) (string, error) { +func (m *Management[T]) GenerateRefreshToken(data T) (string, error) { if m.refreshJWTOptions == nil { return "", ErrEmptyRefreshOpts } @@ -280,8 +258,8 @@ func (m *Manager[T]) GenerateRefreshToken(data T) (string, error) { RegisteredClaims: jwt.RegisteredClaims{ Issuer: m.refreshJWTOptions.Issuer, ExpiresAt: jwt.NewNumericDate(nowTime.Add(m.refreshJWTOptions.Expire)), - NotBefore: jwt.NewNumericDate(nowTime), IssuedAt: jwt.NewNumericDate(nowTime), + ID: m.refreshJWTOptions.genIDFn(), }, } @@ -291,7 +269,7 @@ func (m *Manager[T]) GenerateRefreshToken(data T) (string, error) { // VerifyRefreshToken 校验刷新 token. // 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. -func (m *Manager[T]) VerifyRefreshToken(token string) (RegisteredClaims[T], error) { +func (m *Management[T]) VerifyRefreshToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) { if m.refreshJWTOptions == nil { return RegisteredClaims[T]{}, ErrEmptyRefreshOpts } @@ -299,7 +277,7 @@ func (m *Manager[T]) VerifyRefreshToken(token string) (RegisteredClaims[T], erro func(*jwt.Token) (interface{}, error) { return []byte(m.refreshJWTOptions.DecryptKey), nil }, - jwt.WithTimeFunc(m.nowFunc), + opts..., ) if err != nil || !t.Valid { return RegisteredClaims[T]{}, fmt.Errorf("验证失败: %v", err) @@ -307,3 +285,8 @@ func (m *Manager[T]) VerifyRefreshToken(token string) (RegisteredClaims[T], erro clm, _ := t.Claims.(*RegisteredClaims[T]) return *clm, nil } + +// SetClaims 设置 claims 到 key=`claims` 的 gin.Context 中. +func (m *Management[T]) SetClaims(ctx *gin.Context, claims RegisteredClaims[T]) { + ctx.Set("claims", claims) +} diff --git a/middlewares/jwt/jwt_test.go b/middlewares/jwt/jwt_test.go index 92c92e4..9f01891 100644 --- a/middlewares/jwt/jwt_test.go +++ b/middlewares/jwt/jwt_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/ecodeclub/ekit/set" + "github.com/ecodeclub/ekit/bean/option" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" @@ -23,91 +23,32 @@ var ( Data: data{Foo: "1"}, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(nowTime.Add(defaultExpire)), - NotBefore: jwt.NewNumericDate(nowTime), IssuedAt: jwt.NewNumericDate(nowTime), }, } - encryptionKey = "sign key" - nowTime = time.UnixMilli(1695571200000) - defaultOption = &Options{ - Expire: defaultExpire, - EncryptionKey: encryptionKey, - DecryptKey: encryptionKey, - Method: jwt.SigningMethodHS256, + encryptionKey = "sign key" + nowTime = time.UnixMilli(1695571200000) + defaultOption = NewOptions(defaultExpire, encryptionKey) + defaultIgnorePaths = func(path string) bool { + ignorePaths := []string{"/login", "/signup"} + for _, ignorePath := range ignorePaths { + if path == ignorePath { + return true + } + } + return false } - defaultManager = NewManager[data](defaultOption, + defaultManagement = NewManagement[data](defaultOption, WithNowFunc[data](func() time.Time { return nowTime }), ) ) -func TestManager_GenerateAccessToken(t *testing.T) { - m := defaultManager - type testCase[T any] struct { - name string - data T - want string - wantErr error - } - tests := []testCase[data]{ - { - name: "normal", - data: data{Foo: "1"}, - want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := m.GenerateAccessToken(tt.data) - assert.Equal(t, tt.wantErr, err) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestManager_GenerateRefreshToken(t *testing.T) { - m := defaultManager - type testCase[T any] struct { - name string - refreshJWTOptions *Options - data T - want string - wantErr error - } - tests := []testCase[data]{ - { - name: "normal", - refreshJWTOptions: &Options{ - Expire: 24 * 60 * time.Minute, - EncryptionKey: "refresh sign key", - DecryptKey: "refresh sign key", - Method: jwt.SigningMethodHS256, - }, - data: data{Foo: "1"}, - want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw", - }, - { - name: "mistake", - data: data{Foo: "1"}, - want: "", - wantErr: ErrEmptyRefreshOpts, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m.refreshJWTOptions = tt.refreshJWTOptions - got, err := m.GenerateRefreshToken(tt.data) - assert.Equal(t, tt.wantErr, err) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestManager_MiddlewareBuilder(t *testing.T) { +func TestManagement_Middleware(t *testing.T) { type testCase[T any] struct { name string - m *Manager[T] + m *Management[T] reqBuilder func(t *testing.T) *http.Request wantCode int } @@ -115,14 +56,14 @@ func TestManager_MiddlewareBuilder(t *testing.T) { { // 验证失败 name: "verify_failed", - m: NewManager[data](defaultOption, - WithIgnorePaths[data]("/login")), + m: NewManagement[data](defaultOption, + WithIgnorePath[data](StaticIgnorePaths("/login"))), reqBuilder: func(t *testing.T) *http.Request { req, err := http.NewRequest(http.MethodGet, "/", nil) if err != nil { t.Fatal(err) } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y") + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ") return req }, wantCode: http.StatusUnauthorized, @@ -130,8 +71,8 @@ func TestManager_MiddlewareBuilder(t *testing.T) { { // 提取 token 失败 name: "extract_token_failed", - m: NewManager[data](defaultOption, - WithIgnorePaths[data]("/login")), + m: NewManagement[data](defaultOption, + WithIgnorePath[data](StaticIgnorePaths("/login"))), reqBuilder: func(t *testing.T) *http.Request { req, err := http.NewRequest(http.MethodGet, "/", nil) if err != nil { @@ -145,8 +86,8 @@ func TestManager_MiddlewareBuilder(t *testing.T) { { // 无需认证直接通过 name: "pass_without_authentication", - m: NewManager[data](defaultOption, - WithIgnorePaths[data]("/login")), + m: NewManagement[data](defaultOption, + WithIgnorePath[data](StaticIgnorePaths("/login"))), reqBuilder: func(t *testing.T) *http.Request { req, err := http.NewRequest(http.MethodGet, "/login", nil) if err != nil { @@ -159,8 +100,8 @@ func TestManager_MiddlewareBuilder(t *testing.T) { { // 验证通过 name: "pass_the_verification", - m: NewManager[data](defaultOption, - WithIgnorePaths[data]("/login"), + m: NewManagement[data](defaultOption, + WithIgnorePath[data](StaticIgnorePaths("/login")), WithNowFunc[data](func() time.Time { return time.UnixMilli(1695571500000) }), @@ -170,7 +111,7 @@ func TestManager_MiddlewareBuilder(t *testing.T) { if err != nil { t.Fatal(err) } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y") + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ") return req }, wantCode: http.StatusOK, @@ -179,7 +120,7 @@ func TestManager_MiddlewareBuilder(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { server := gin.Default() - server.Use(tt.m.MiddlewareBuilder()) + server.Use(tt.m.Middleware()) tt.m.registerRoutes(server) req := tt.reqBuilder(t) @@ -191,10 +132,10 @@ func TestManager_MiddlewareBuilder(t *testing.T) { } } -func TestManager_Refresh(t *testing.T) { +func TestManagement_Refresh(t *testing.T) { type testCase[T any] struct { name string - m *Manager[T] + m *Management[T] reqBuilder func(t *testing.T) *http.Request wantCode int wantAccessToken string @@ -204,7 +145,7 @@ func TestManager_Refresh(t *testing.T) { { // 更新资源令牌并轮换刷新令牌 name: "refresh_access_token_and_rotate_refresh_token", - m: NewManager[data](defaultOption, + m: NewManagement[data](defaultOption, WithRefreshJWTOptions[data]( NewOptions(24*60*time.Minute, "refresh sign key", @@ -219,17 +160,17 @@ func TestManager_Refresh(t *testing.T) { if err != nil { t.Fatal(err) } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") return req }, - wantCode: http.StatusOK, - wantAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjIzNjAwLCJuYmYiOjE2OTU2MjMwMDAsImlhdCI6MTY5NTYyMzAwMH0.5Hv-Gq8RW0xAFBh4WhKc0KDLsdgTEv3RUhPceaM4e5M", - wantRefreshToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NzA5NDAwLCJuYmYiOjE2OTU2MjMwMDAsImlhdCI6MTY5NTYyMzAwMH0.4R-JmqcKHtsoFOGFDe5SBA2wNV0F-XvnP2Janp6NfZY", + wantCode: http.StatusNoContent, + wantAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjIzNjAwLCJpYXQiOjE2OTU2MjMwMDB9.i4kCx4-s5EM0a8w2o0usSfkMTLmzUSuEe-inlzg6ru0", + wantRefreshToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NzA5NDAwLCJpYXQiOjE2OTU2MjMwMDB9.IzPgEwXgoAwaFK-eby4uMl0GYBQwdfZYRi2Bhk3iE_8", }, { // 更新资源令牌但轮换刷新令牌生成失败 name: "refresh_access_token_but_gen_rotate_refresh_token_failed", - m: NewManager[data](defaultOption, + m: NewManagement[data](defaultOption, WithRefreshJWTOptions[data]( NewOptions(24*60*time.Minute, "refresh sign key", @@ -245,7 +186,7 @@ func TestManager_Refresh(t *testing.T) { if err != nil { t.Fatal(err) } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") return req }, wantCode: http.StatusInternalServerError, @@ -253,7 +194,7 @@ func TestManager_Refresh(t *testing.T) { { // 更新资源令牌 name: "refresh_access_token", - m: NewManager[data](defaultOption, + m: NewManagement[data](defaultOption, WithRefreshJWTOptions[data]( NewOptions(24*60*time.Minute, "refresh sign key", @@ -267,21 +208,22 @@ func TestManager_Refresh(t *testing.T) { if err != nil { t.Fatal(err) } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") return req }, - wantCode: http.StatusOK, - wantAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjIzNjAwLCJuYmYiOjE2OTU2MjMwMDAsImlhdCI6MTY5NTYyMzAwMH0.5Hv-Gq8RW0xAFBh4WhKc0KDLsdgTEv3RUhPceaM4e5M", + wantCode: http.StatusNoContent, + wantAccessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjIzNjAwLCJpYXQiOjE2OTU2MjMwMDB9.i4kCx4-s5EM0a8w2o0usSfkMTLmzUSuEe-inlzg6ru0", }, { // 生成资源令牌失败 name: "gen_access_token_failed", - m: NewManager[data]( + m: NewManagement[data]( &Options{ Expire: 10 * time.Minute, EncryptionKey: encryptionKey, DecryptKey: encryptionKey, Method: jwt.SigningMethodRS256, + genIDFn: func() string { return "" }, }, WithRefreshJWTOptions[data]( NewOptions(24*60*time.Minute, @@ -296,7 +238,7 @@ func TestManager_Refresh(t *testing.T) { if err != nil { t.Fatal(err) } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") return req }, wantCode: http.StatusInternalServerError, @@ -304,7 +246,7 @@ func TestManager_Refresh(t *testing.T) { { // 刷新令牌认证失败 name: "refresh_token_verify_failed", - m: NewManager[data]( + m: NewManagement[data]( defaultOption, WithRefreshJWTOptions[data]( NewOptions(24*60*time.Minute, @@ -319,7 +261,7 @@ func TestManager_Refresh(t *testing.T) { if err != nil { t.Fatal(err) } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") return req }, wantCode: http.StatusUnauthorized, @@ -327,7 +269,7 @@ func TestManager_Refresh(t *testing.T) { { // 没有设置刷新令牌选项 name: "not_set_refreshJWTOptions", - m: NewManager[data]( + m: NewManagement[data]( defaultOption, WithNowFunc[data](func() time.Time { return time.UnixMilli(1695723000000) @@ -338,7 +280,7 @@ func TestManager_Refresh(t *testing.T) { if err != nil { t.Fatal(err) } - req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw") + req.Header.Add("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE") return req }, wantCode: http.StatusInternalServerError, @@ -365,10 +307,34 @@ func TestManager_Refresh(t *testing.T) { } } -func TestManager_VerifyAccessToken(t *testing.T) { +func TestManagement_GenerateAccessToken(t *testing.T) { + m := defaultManagement type testCase[T any] struct { name string - m *Manager[T] + data T + want string + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + data: data{Foo: "1"}, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := m.GenerateAccessToken(tt.data) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestManagement_VerifyAccessToken(t *testing.T) { + type testCase[T any] struct { + name string + m *Management[T] token string want RegisteredClaims[T] wantErr error @@ -376,44 +342,39 @@ func TestManager_VerifyAccessToken(t *testing.T) { tests := []testCase[data]{ { name: "normal", - m: defaultManager, - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", + m: defaultManagement, + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ", want: defaultClaims, }, { // token 过期了 name: "token_expired", - m: NewManager[data](defaultOption, + m: NewManagement[data](defaultOption, WithNowFunc[data](func() time.Time { return time.UnixMilli(1695671200000) }), ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ", wantErr: fmt.Errorf("验证失败: %v", fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), }, { // token 签名错误 name: "bad_sign_key", - m: NewManager[data]( - &Options{ - Expire: defaultExpire, - EncryptionKey: encryptionKey, - DecryptKey: "bad sign key", - Method: jwt.SigningMethodHS256, - }, + m: NewManagement[data]( + defaultOption, WithNowFunc[data](func() time.Time { return nowTime }), ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.pnP991l48s_j4fkiZnmh48gjgDGult9Or_wLChHvYp0", wantErr: fmt.Errorf("验证失败: %v", fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), }, { // 错误的 token name: "bad_token", - m: defaultManager, + m: defaultManagement, token: "bad_token", wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", jwt.ErrTokenMalformed), @@ -421,17 +382,57 @@ func TestManager_VerifyAccessToken(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.m.VerifyAccessToken(tt.token) + got, err := tt.m.VerifyAccessToken(tt.token, + jwt.WithTimeFunc(tt.m.nowFunc)) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestManagement_GenerateRefreshToken(t *testing.T) { + m := defaultManagement + type testCase[T any] struct { + name string + refreshJWTOptions *Options + data T + want string + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + refreshJWTOptions: NewOptions(24*60*time.Minute, "refresh sign key"), + data: data{Foo: "1"}, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", + }, + { + name: "mistake", + data: data{Foo: "1"}, + want: "", + wantErr: ErrEmptyRefreshOpts, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m.refreshJWTOptions = tt.refreshJWTOptions + got, err := m.GenerateRefreshToken(tt.data) assert.Equal(t, tt.wantErr, err) assert.Equal(t, tt.want, got) }) } } -func TestManager_VerifyRefreshToken(t *testing.T) { +func TestManagement_VerifyRefreshToken(t *testing.T) { + defaultRefOpts := &Options{ + Expire: 24 * 60 * time.Minute, + EncryptionKey: "refresh sign key", + DecryptKey: "refresh sign key", + Method: jwt.SigningMethodHS256, + } type testCase[T any] struct { name string - m *Manager[T] + m *Management[T] token string want RegisteredClaims[T] wantErr error @@ -439,23 +440,17 @@ func TestManager_VerifyRefreshToken(t *testing.T) { tests := []testCase[data]{ { name: "normal", - m: NewManager[data](defaultOption, + m: NewManagement[data](defaultOption, WithNowFunc[data](func() time.Time { return time.UnixMilli(1695601200000) }), - WithRefreshJWTOptions[data](&Options{ - Expire: 24 * 60 * time.Minute, - EncryptionKey: "refresh sign key", - DecryptKey: "refresh sign key", - Method: jwt.SigningMethodHS256, - }), + WithRefreshJWTOptions[data](defaultRefOpts), ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", want: RegisteredClaims[data]{ Data: data{Foo: "1"}, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(nowTime.Add(24 * 60 * time.Minute)), - NotBefore: jwt.NewNumericDate(nowTime), IssuedAt: jwt.NewNumericDate(nowTime), }, }, @@ -463,52 +458,37 @@ func TestManager_VerifyRefreshToken(t *testing.T) { { // token 过期了 name: "token_expired", - m: NewManager[data](defaultOption, + m: NewManagement[data](defaultOption, WithNowFunc[data](func() time.Time { return time.UnixMilli(1695701200000) }), - WithRefreshJWTOptions[data](&Options{ - Expire: 24 * 60 * time.Minute, - EncryptionKey: "refresh sign key", - DecryptKey: "refresh sign key", - Method: jwt.SigningMethodHS256, - }), + WithRefreshJWTOptions[data](defaultRefOpts), ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", wantErr: fmt.Errorf("验证失败: %v", fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), }, { // token 签名错误 name: "bad_sign_key", - m: NewManager[data](defaultOption, + m: NewManagement[data](defaultOption, WithNowFunc[data](func() time.Time { return time.UnixMilli(1695601200000) }), - WithRefreshJWTOptions[data](&Options{ - Expire: 24 * 60 * time.Minute, - EncryptionKey: "bad refresh sign key", - DecryptKey: "bad refresh sign key", - Method: jwt.SigningMethodHS256, - }), + WithRefreshJWTOptions[data](defaultRefOpts), ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.yZ_ZlD1jE-0b3qd0bicTDLSdwGsenv6tRmOEqMCM2uw", wantErr: fmt.Errorf("验证失败: %v", fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), }, { // 错误的 token name: "bad_token", - m: NewManager[data](defaultOption, + m: NewManagement[data](defaultOption, WithNowFunc[data](func() time.Time { return time.UnixMilli(1695601200000) }), - WithRefreshJWTOptions[data](&Options{ - Expire: 24 * 60 * time.Minute, - EncryptionKey: "refresh sign key", - DecryptKey: "refresh sign key", - Method: jwt.SigningMethodHS256, - }), + WithRefreshJWTOptions[data](defaultRefOpts), ), token: "bad_token", wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", @@ -516,26 +496,60 @@ func TestManager_VerifyRefreshToken(t *testing.T) { }, { name: "no_refresh_options", - m: NewManager[data](defaultOption, + m: NewManagement[data](defaultOption, WithNowFunc[data](func() time.Time { return time.UnixMilli(1695601200000) }), ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.yb0pocXbtJuZziA6Ugs3wcYOAslrIk1-C_NpKgTrNVw", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NjU3NjAwLCJpYXQiOjE2OTU1NzEyMDB9.y2AQ98i0le5AbmJFgYCAfCVAphd_9NecmHdhtehMSZE", wantErr: ErrEmptyRefreshOpts, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.m.VerifyRefreshToken(tt.token) + got, err := tt.m.VerifyRefreshToken(tt.token, + jwt.WithTimeFunc(tt.m.nowFunc)) assert.Equal(t, tt.wantErr, err) assert.Equal(t, tt.want, got) }) } } -func TestManager_extractTokenString(t *testing.T) { - m := defaultManager +func TestManagement_SetClaims(t *testing.T) { + m := defaultManagement + type testCase[T any] struct { + name string + claims RegisteredClaims[T] + want RegisteredClaims[T] + wantErr error + } + tests := []testCase[data]{ + { + name: "normal", + claims: defaultClaims, + want: defaultClaims, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) + m.SetClaims(ctx, tt.claims) + v, ok := ctx.Get("claims") + if !ok { + t.Errorf("claims not found") + } + clm, ok := v.(RegisteredClaims[data]) + if !ok { + t.Errorf("claims type error") + } + assert.Equal(t, tt.want, clm) + }) + } +} + +func TestManagement_extractTokenString(t *testing.T) { + m := defaultManagement type header struct { key string value string @@ -582,176 +596,153 @@ func TestManager_extractTokenString(t *testing.T) { } } -func TestManager_verifyTokenAndSetClm(t *testing.T) { +func TestNewManagement(t *testing.T) { type testCase[T any] struct { - name string - m *Manager[T] - token string - want RegisteredClaims[T] - wantErr error + name string + accessJWTOptions *Options + wantPanic bool } tests := []testCase[data]{ { - name: "normal", - m: defaultManager, - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", - want: defaultClaims, + name: "normal", + accessJWTOptions: defaultOption, + wantPanic: false, }, { - name: "verify_access_token_failed", - m: NewManager[data]( - defaultOption, - WithNowFunc[data](func() time.Time { - return time.UnixMilli(1695671200000) - }), - ), - token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJuYmYiOjE2OTU1NzEyMDAsImlhdCI6MTY5NTU3MTIwMH0.UNuVOmAwgR-atNOMVi9JldtT7qGl7LCFuyq4uiYgg_Y", - want: RegisteredClaims[data]{}, - wantErr: fmt.Errorf("验证失败: %v", - fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), + name: "accessJWTOptions_are_nil", + accessJWTOptions: nil, + wantPanic: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - recorder := httptest.NewRecorder() - ctx, _ := gin.CreateTestContext(recorder) - req, err := http.NewRequest(http.MethodGet, "", nil) - if err != nil { - t.Fatal(err) - } - ctx.Request = req - err = tt.m.verifyTokenAndSetClm(ctx, tt.token) - assert.Equal(t, tt.wantErr, err) - if err != nil { - return - } - v, ok := ctx.Get("claims") - if !ok { - t.Error("claims设置失败") - } - clm, ok := v.(RegisteredClaims[data]) - if !ok { - t.Error("claims不是 RegisteredClaims[T] 类型") - } - assert.Equal(t, tt.want, clm) + defer func() { + if err := recover(); err != nil { + if !tt.wantPanic { + t.Errorf("期望出现 painc ,但没有") + } + } + }() + NewManagement[data](tt.accessJWTOptions) }) } } -func TestWithAllowTokenHeader(t *testing.T) { +func TestWithIgnorePath(t *testing.T) { type testCase[T any] struct { - name string - fn func() ManagerOption[T] - want string + name string + fn func() option.Option[Management[T]] + paths []string + want []bool } tests := []testCase[data]{ { name: "default", - fn: func() ManagerOption[data] { + fn: func() option.Option[Management[data]] { return nil }, - want: "authorization", + paths: []string{"profile", "abc"}, + want: []bool{false, false}, }, { - name: "set_another_header", - fn: func() ManagerOption[data] { - return WithAllowTokenHeader[data]("jwt") + name: "all_exists_paths", + fn: func() option.Option[Management[data]] { + return WithIgnorePath[data](defaultIgnorePaths) }, - want: "jwt", + paths: []string{"/login", "/signup"}, + want: []bool{true, true}, + }, + { + name: "one_path_does_not_exist", + fn: func() option.Option[Management[data]] { + return WithIgnorePath[data](defaultIgnorePaths) + }, + paths: []string{"/login", "/profile", "/signup"}, + want: []bool{true, false, true}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var got string + var ignoreFn func(path string) bool if tt.fn() == nil { - got = NewManager[data]( + ignoreFn = NewManagement[data]( defaultOption, - ).allowTokenHeader + ).ignorePath } else { - got = NewManager[data]( + ignoreFn = NewManagement[data]( defaultOption, tt.fn(), - ).allowTokenHeader + ).ignorePath } - assert.Equal(t, tt.want, got) + exists := make([]bool, 0, len(tt.paths)) + for _, path := range tt.paths { + exists = append(exists, ignoreFn(path)) + } + assert.Equal(t, tt.want, exists) }) } } -func TestWithBearerPrefix(t *testing.T) { - type testCase[T any] struct { - name string - fn func() ManagerOption[T] - want string - } - tests := []testCase[data]{ +func TestStaticIgnorePaths(t *testing.T) { + tests := []struct { + name string + paths []string + requestPaths []string + want []bool + }{ { - name: "default", - fn: func() ManagerOption[data] { - return nil - }, - want: "Bearer", - }, - { - name: "set_another_prefix", - fn: func() ManagerOption[data] { - return WithBearerPrefix[data]("jwt") - }, - want: "jwt", + name: "normal", + paths: []string{"login", "signup"}, + requestPaths: []string{"profile", "login", "info", "signup"}, + want: []bool{false, true, false, true}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var got string - if tt.fn() == nil { - got = NewManager[data]( - defaultOption, - ).bearerPrefix - } else { - got = NewManager[data]( - defaultOption, - tt.fn(), - ).bearerPrefix + gotBool := make([]bool, 0, len(tt.want)) + fn := StaticIgnorePaths(tt.paths...) + for _, path := range tt.requestPaths { + gotBool = append(gotBool, fn(path)) } - assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want, gotBool) }) } } -func TestWithClaimsCTXKey(t *testing.T) { +func TestWithAllowTokenHeader(t *testing.T) { type testCase[T any] struct { name string - fn func() ManagerOption[T] + fn func() option.Option[Management[T]] want string } tests := []testCase[data]{ { name: "default", - fn: func() ManagerOption[data] { + fn: func() option.Option[Management[data]] { return nil }, - want: "claims", + want: "authorization", }, { - name: "set_another_ctx_key", - fn: func() ManagerOption[data] { - return WithClaimsCTXKey[data]("clm") + name: "set_another_header", + fn: func() option.Option[Management[data]] { + return WithAllowTokenHeader[data]("jwt") }, - want: "clm", + want: "jwt", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got string if tt.fn() == nil { - got = NewManager[data]( + got = NewManagement[data]( defaultOption, - ).claimsCTXKey + ).allowTokenHeader } else { - got = NewManager[data]( + got = NewManagement[data]( defaultOption, tt.fn(), - ).claimsCTXKey + ).allowTokenHeader } assert.Equal(t, tt.want, got) }) @@ -761,20 +752,20 @@ func TestWithClaimsCTXKey(t *testing.T) { func TestWithExposeAccessHeader(t *testing.T) { type testCase[T any] struct { name string - fn func() ManagerOption[T] + fn func() option.Option[Management[T]] want string } tests := []testCase[data]{ { name: "default", - fn: func() ManagerOption[data] { + fn: func() option.Option[Management[data]] { return nil }, want: "x-access-token", }, { name: "set_another_header", - fn: func() ManagerOption[data] { + fn: func() option.Option[Management[data]] { return WithExposeAccessHeader[data]("token") }, want: "token", @@ -784,11 +775,11 @@ func TestWithExposeAccessHeader(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var got string if tt.fn() == nil { - got = NewManager[data]( + got = NewManagement[data]( defaultOption, ).exposeAccessHeader } else { - got = NewManager[data]( + got = NewManagement[data]( defaultOption, tt.fn(), ).exposeAccessHeader @@ -801,20 +792,20 @@ func TestWithExposeAccessHeader(t *testing.T) { func TestWithExposeRefreshHeader(t *testing.T) { type testCase[T any] struct { name string - fn func() ManagerOption[T] + fn func() option.Option[Management[T]] want string } tests := []testCase[data]{ { name: "default", - fn: func() ManagerOption[data] { + fn: func() option.Option[Management[data]] { return nil }, want: "x-refresh-token", }, { name: "set_another_header", - fn: func() ManagerOption[data] { + fn: func() option.Option[Management[data]] { return WithExposeRefreshHeader[data]("refresh-token") }, want: "refresh-token", @@ -824,11 +815,11 @@ func TestWithExposeRefreshHeader(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var got string if tt.fn() == nil { - got = NewManager[data]( + got = NewManagement[data]( defaultOption, ).exposeRefreshHeader } else { - got = NewManager[data]( + got = NewManagement[data]( defaultOption, tt.fn(), ).exposeRefreshHeader @@ -838,62 +829,42 @@ func TestWithExposeRefreshHeader(t *testing.T) { } } -func TestWithIgnorePaths(t *testing.T) { +func TestWithRotateRefreshToken(t *testing.T) { type testCase[T any] struct { - name string - fn func() ManagerOption[T] - paths []string - want []bool + name string + fn func() option.Option[Management[T]] + want bool } tests := []testCase[data]{ { name: "default", - fn: func() ManagerOption[data] { + fn: func() option.Option[Management[data]] { return nil }, - want: []bool{}, + want: false, }, { - name: "all_exists_paths", - fn: func() ManagerOption[data] { - return WithIgnorePaths[data]([]string{ - "/login", - "/signup", - }...) + name: "set_another_header", + fn: func() option.Option[Management[data]] { + return WithRotateRefreshToken[data](true) }, - paths: []string{"/login", "/signup"}, - want: []bool{true, true}, - }, - { - name: "one_path_does_not_exist", - fn: func() ManagerOption[data] { - return WithIgnorePaths[data]([]string{ - "/login", - "/signup", - }...) - }, - paths: []string{"/login", "/profile", "/signup"}, - want: []bool{true, false, true}, + want: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var ignorePaths set.Set[string] + var got bool if tt.fn() == nil { - ignorePaths = NewManager[data]( + got = NewManagement[data]( defaultOption, - ).publicPaths + ).rotateRefreshToken } else { - ignorePaths = NewManager[data]( + got = NewManagement[data]( defaultOption, tt.fn(), - ).publicPaths + ).rotateRefreshToken } - exists := make([]bool, 0, len(tt.paths)) - for _, path := range tt.paths { - exists = append(exists, ignorePaths.Exist(path)) - } - assert.Equal(t, tt.want, exists) + assert.Equal(t, tt.want, got) }) } } @@ -901,20 +872,20 @@ func TestWithIgnorePaths(t *testing.T) { func TestWithNowFunc(t *testing.T) { type testCase[T any] struct { name string - fn func() ManagerOption[T] + fn func() option.Option[Management[T]] want time.Time } tests := []testCase[data]{ { name: "default", - fn: func() ManagerOption[data] { + fn: func() option.Option[Management[data]] { return nil }, want: time.Now(), }, { name: "set_another_now_func", - fn: func() ManagerOption[data] { + fn: func() option.Option[Management[data]] { return WithNowFunc[data](func() time.Time { return nowTime }) @@ -926,11 +897,11 @@ func TestWithNowFunc(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var got time.Time if tt.fn() == nil { - got = NewManager[data]( + got = NewManagement[data]( defaultOption, ).nowFunc() } else { - got = NewManager[data]( + got = NewManagement[data]( defaultOption, tt.fn(), ).nowFunc() @@ -941,26 +912,28 @@ func TestWithNowFunc(t *testing.T) { } func TestWithRefreshJWTOptions(t *testing.T) { + var genIDFn func() string type testCase[T any] struct { name string - fn func() ManagerOption[T] + fn func() option.Option[Management[T]] want *Options } tests := []testCase[data]{ { name: "default", - fn: func() ManagerOption[data] { + fn: func() option.Option[Management[data]] { return nil }, want: nil, }, { name: "set_refresh_jwt_options", - fn: func() ManagerOption[data] { + fn: func() option.Option[Management[data]] { return WithRefreshJWTOptions[data]( NewOptions( 24*60*time.Minute, "refresh sign key", + WithGenIDFunc(genIDFn), ), ) }, @@ -969,6 +942,7 @@ func TestWithRefreshJWTOptions(t *testing.T) { EncryptionKey: "refresh sign key", DecryptKey: "refresh sign key", Method: jwt.SigningMethodHS256, + genIDFn: genIDFn, }, }, } @@ -976,11 +950,11 @@ func TestWithRefreshJWTOptions(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var got *Options if tt.fn() == nil { - got = NewManager[data]( + got = NewManagement[data]( defaultOption, ).refreshJWTOptions } else { - got = NewManager[data]( + got = NewManagement[data]( defaultOption, tt.fn(), ).refreshJWTOptions @@ -990,7 +964,7 @@ func TestWithRefreshJWTOptions(t *testing.T) { } } -func (m *Manager[T]) registerRoutes(server *gin.Engine) { +func (m *Management[T]) registerRoutes(server *gin.Engine) { server.GET("/", func(ctx *gin.Context) { ctx.Status(http.StatusOK) }) diff --git a/middlewares/jwt/types.go b/middlewares/jwt/types.go index dd99e23..6662d4f 100644 --- a/middlewares/jwt/types.go +++ b/middlewares/jwt/types.go @@ -2,18 +2,32 @@ package jwt import ( "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" ) -type LoginManager[T any] interface { - Refresh(ctx *gin.Context) // 刷新 token 的 gin.HandlerFunc - MiddlewareBuilder() gin.HandlerFunc // 登录认证的中间件 +// Manager jwt 管理器. +type Manager[T any] interface { + // Middleware 登录认证的中间件. + Middleware() gin.HandlerFunc + + // Refresh 刷新 token 的 gin.HandlerFunc. + // 需要设置 refreshJWTOptions 否则会出现 500 的 http 状态码. + Refresh(ctx *gin.Context) + + // GenerateAccessToken 生成资源 token. + GenerateAccessToken(data T) (string, error) + + // VerifyAccessToken 校验资源 token. + VerifyAccessToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) - GenerateAccessToken(data T) (string, error) // 生成资源 token - VerifyAccessToken(token string) (RegisteredClaims[T], error) // 校验资源 token // GenerateRefreshToken 生成刷新 token. // 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. GenerateRefreshToken(data T) (string, error) + // VerifyRefreshToken 校验刷新 token. // 需要设置 refreshJWTOptions 否则返回 ErrEmptyRefreshOpts 错误. - VerifyRefreshToken(token string) (RegisteredClaims[T], error) // 校验刷新 token + VerifyRefreshToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) + + // SetClaims 设置 claims 到 key=`claims` 的 gin.Context 中. + SetClaims(ctx *gin.Context, claims RegisteredClaims[T]) }