diff --git a/middlewares/accesslog/builder.go b/middlewares/accesslog/builder.go new file mode 100644 index 0000000..c3ab0ac --- /dev/null +++ b/middlewares/accesslog/builder.go @@ -0,0 +1,135 @@ +package accesslog + +import ( + "bytes" + "context" + "io" + "time" + + "github.com/gin-gonic/gin" + "go.uber.org/atomic" +) + +type AccessLog struct { + //http 请求类型 + Method string + //url 整个请求的url + Url string + //请求体 + ReqBody string + //响应体 + RespBody string + //处理时间 + Duration string + //状态码 + Status int +} + +type Builder struct { + allowReqBody *atomic.Bool + allowRespBody *atomic.Bool + //logger logger.LoggerV1 //这里要自己确认用什么日志级别 + // + loggerFunc func(ctx context.Context, al *AccessLog) + maxLength *atomic.Int64 +} + +func NewBuilder(fn func(ctx context.Context, al *AccessLog)) *Builder { + return &Builder{ + allowReqBody: atomic.NewBool(false), + allowRespBody: atomic.NewBool(false), + loggerFunc: fn, + maxLength: atomic.NewInt64(1024), + } +} + +// AllowReqBody 是否打印请求体 +func (b *Builder) AllowReqBody() *Builder { + b.allowReqBody.Store(true) + return b +} + +// AllowRespBody 是否打印响应体 +func (b *Builder) AllowRespBody() *Builder { + b.allowRespBody.Store(true) + return b +} + +// MaxLength 打印的最大长度 +func (b *Builder) MaxLength(maxLength int64) *Builder { + b.maxLength.Store(maxLength) + return b +} + +func (b *Builder) Builder() gin.HandlerFunc { + return func(ctx *gin.Context) { + var ( + //请求处理开始时间 + start = time.Now() + //url + url = ctx.Request.URL.String() + //url 长度 + curLen = int64(len(url)) + //运行打印的最大长度 + maxLength = b.maxLength.Load() + //是否打印请求体 + allowReqBody = b.allowReqBody.Load() + //是否打印响应体 + allowRespBody = b.allowRespBody.Load() + ) + + if curLen >= maxLength { + url = url[:maxLength] + } + + accessLog := &AccessLog{ + Method: ctx.Request.Method, + Url: url, + } + if ctx.Request.Body != nil && allowReqBody { + body, _ := ctx.GetRawData() + ctx.Request.Body = io.NopCloser(bytes.NewReader(body)) + if int64(len(body)) >= maxLength { + body = body[:maxLength] + } + //注意资源的消耗 + accessLog.ReqBody = string(body) + } + + if allowRespBody { + ctx.Writer = responseWriter{ + ResponseWriter: ctx.Writer, + al: accessLog, + maxLength: maxLength, + } + } + + defer func() { + accessLog.Duration = time.Now().Sub(start).String() + //日志打印 + b.loggerFunc(ctx, accessLog) + }() + ctx.Next() + } +} + +type responseWriter struct { + gin.ResponseWriter + al *AccessLog + maxLength int64 +} + +func (r responseWriter) WriteHeader(statusCode int) { + + r.al.Status = statusCode + r.ResponseWriter.WriteHeader(statusCode) +} + +func (r responseWriter) Write(data []byte) (int, error) { + curLen := int64(len(data)) + if curLen >= r.maxLength { + data = data[:r.maxLength] + } + r.al.RespBody = string(data) + return r.ResponseWriter.Write(data) +} diff --git a/middlewares/accesslog/builder_test.go b/middlewares/accesslog/builder_test.go new file mode 100644 index 0000000..30ecd4f --- /dev/null +++ b/middlewares/accesslog/builder_test.go @@ -0,0 +1,249 @@ +package accesslog + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuilder_Builder(t *testing.T) { + testCases := []struct { + name string + getReq func() *http.Request + accesslog *AccessLog + logfunc func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) + middleWarebuilder func(func(ctx context.Context, al *AccessLog)) gin.HandlerFunc + setStatus int + setRsp string + resultAccessLog *AccessLog + }{ + { + name: "不打印请求体,响应体", + getReq: func() *http.Request { + req, err := http.NewRequest(http.MethodGet, "/accesslog", nil) + require.NoError(t, err) + return req + }, + accesslog: &AccessLog{}, + logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { + + return func(ctx context.Context, al *AccessLog) { + // + //accesslog.Status = al.Status + //accesslog.Method = al.Method + ////url 整个请求的u + //accesslog.Url = al.Url + ////请求体 + //accesslog.ReqBody = al.ReqBody + ////响应体 + //accesslog.RespBody = al.RespBody + ////处理时间 + //accesslog.Duration = al.Duration + ////状态码 + //accesslog.Status = al.Status + copy(accesslog, al) + fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) + } + }, + middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { + return NewBuilder(f).Builder() + }, + resultAccessLog: &AccessLog{ + Method: "GET", + Url: "/accesslog", + }, + }, + { + name: "不打印请求体,打印响应体", + getReq: func() *http.Request { + req, err := http.NewRequest(http.MethodGet, "/accesslog", nil) + require.NoError(t, err) + return req + }, + accesslog: &AccessLog{}, + logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { + + return func(ctx context.Context, al *AccessLog) { + + copy(accesslog, al) + + fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) + } + }, + middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { + return NewBuilder(f).AllowRespBody().Builder() + }, + resultAccessLog: &AccessLog{ + Method: "GET", + Url: "/accesslog", + RespBody: `{"msg":"aa22"}`, + Status: http.StatusOK, + }, + }, + { + name: "打印请求体,不打印响应体", + getReq: func() *http.Request { + read := strings.NewReader(`{"msg":"aa11"}`) + + req, err := http.NewRequest(http.MethodGet, "/accesslog", read) + require.NoError(t, err) + return req + }, + accesslog: &AccessLog{}, + logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { + + return func(ctx context.Context, al *AccessLog) { + + copy(accesslog, al) + + fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) + } + }, + middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { + return NewBuilder(f).AllowReqBody().Builder() + }, + resultAccessLog: &AccessLog{ + Method: "GET", + Url: "/accesslog", + ReqBody: `{"msg":"aa11"}`, + }, + }, + { + name: "打印请求体,打印响应体", + getReq: func() *http.Request { + read := strings.NewReader(`{"msg":"aa11"}`) + + req, err := http.NewRequest(http.MethodGet, "/accesslog", read) + require.NoError(t, err) + return req + }, + accesslog: &AccessLog{}, + logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { + + return func(ctx context.Context, al *AccessLog) { + + copy(accesslog, al) + + fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) + } + }, + middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { + return NewBuilder(f).AllowReqBody().AllowRespBody().Builder() + }, + resultAccessLog: &AccessLog{ + Method: "GET", + Url: "/accesslog", + ReqBody: `{"msg":"aa11"}`, + RespBody: `{"msg":"aa22"}`, + Status: http.StatusOK, + }, + }, + { + name: "打印请求体超标,不打印响应体,限制长度为10", + getReq: func() *http.Request { + read := strings.NewReader(`{"msg":"aa11"}`) + + req, err := http.NewRequest(http.MethodGet, "/accesslog", read) + require.NoError(t, err) + return req + }, + accesslog: &AccessLog{}, + logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { + + return func(ctx context.Context, al *AccessLog) { + + copy(accesslog, al) + + fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) + } + }, + middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { + return NewBuilder(f).AllowReqBody().MaxLength(10).Builder() + }, + resultAccessLog: &AccessLog{ + Method: "GET", + Url: "/accesslog", + ReqBody: `{"msg":"aa`, + }, + }, + { + name: "不打印请求体,打印响应体超标,限制长度为10", + getReq: func() *http.Request { + read := strings.NewReader(`{"msg":"aa11"}`) + + req, err := http.NewRequest(http.MethodGet, "/accesslog", read) + require.NoError(t, err) + return req + }, + accesslog: &AccessLog{}, + logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { + + return func(ctx context.Context, al *AccessLog) { + + copy(accesslog, al) + + fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) + } + }, + middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { + return NewBuilder(f).AllowRespBody().MaxLength(10).Builder() + }, + resultAccessLog: &AccessLog{ + Method: "GET", + Url: "/accesslog", + RespBody: `{"msg":"aa`, + Status: http.StatusOK, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := gin.Default() + server.Use(tc.middleWarebuilder(tc.logfunc(tc.accesslog))) + server.GET("/accesslog", func(ctx *gin.Context) { + ctx.JSON(http.StatusOK, map[string]any{ + "msg": "aa22", + }) + }) + resp := httptest.NewRecorder() + + server.ServeHTTP(resp, tc.getReq()) + //中间件使用的defer 所有这里要给点时间 + time.Sleep(time.Millisecond * 100) + assert.Equal(t, tc.accesslog.Method, tc.resultAccessLog.Method) + assert.Equal(t, tc.accesslog.Url, tc.resultAccessLog.Url) + assert.Equal(t, tc.accesslog.ReqBody, tc.resultAccessLog.ReqBody) + assert.Equal(t, tc.accesslog.RespBody, tc.resultAccessLog.RespBody) + //时间不好判断 + //assert.Equal(t, tc.accesslog.Duration, tc.resultAccessLog.Duration) + + assert.Equal(t, tc.accesslog.Status, tc.resultAccessLog.Status) + + }) + + } +} + +func copy(source, target *AccessLog) { + source.Status = target.Status + source.Method = target.Method + //url 整个请求的u + source.Url = target.Url + //请求体 + source.ReqBody = target.ReqBody + //响应体 + source.RespBody = target.RespBody + //处理时间 + source.Duration = target.Duration + //状态码 + source.Status = target.Status +}