Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions middlewares/accesslog/builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package accesslog

import (
"bytes"
"context"
"io"
"time"

"github.com/gin-gonic/gin"
"go.uber.org/atomic"
)

type AccessLog struct {
//http 请求类型
Method string
//url 整个请求的url
Url string
//请求体
ReqBody string
//响应体
RespBody string
//处理时间
Duration string
//状态码
Status int
}

type Builder struct {
allowReqBody *atomic.Bool
allowRespBody *atomic.Bool
//logger logger.LoggerV1 //这里要自己确认用什么日志级别
//
loggerFunc func(ctx context.Context, al *AccessLog)
maxLength *atomic.Int64
}

func NewBuilder(fn func(ctx context.Context, al *AccessLog)) *Builder {
return &Builder{
allowReqBody: atomic.NewBool(false),
allowRespBody: atomic.NewBool(false),
loggerFunc: fn,
maxLength: atomic.NewInt64(1024),
}
}

// AllowReqBody 是否打印请求体
func (b *Builder) AllowReqBody() *Builder {
b.allowReqBody.Store(true)
return b
}

// AllowRespBody 是否打印响应体
func (b *Builder) AllowRespBody() *Builder {
b.allowRespBody.Store(true)
return b
}

// MaxLength 打印的最大长度
func (b *Builder) MaxLength(maxLength int64) *Builder {
b.maxLength.Store(maxLength)
return b
}

func (b *Builder) Builder() gin.HandlerFunc {
return func(ctx *gin.Context) {
var (
//请求处理开始时间
start = time.Now()
//url
url = ctx.Request.URL.String()
//url 长度
curLen = int64(len(url))
//运行打印的最大长度
maxLength = b.maxLength.Load()
//是否打印请求体
allowReqBody = b.allowReqBody.Load()
//是否打印响应体
allowRespBody = b.allowRespBody.Load()
)

if curLen >= maxLength {
url = url[:maxLength]
}

accessLog := &AccessLog{
Method: ctx.Request.Method,
Url: url,
}
if ctx.Request.Body != nil && allowReqBody {
body, _ := ctx.GetRawData()
ctx.Request.Body = io.NopCloser(bytes.NewReader(body))
if int64(len(body)) >= maxLength {
body = body[:maxLength]
}
//注意资源的消耗
accessLog.ReqBody = string(body)
}

if allowRespBody {
ctx.Writer = responseWriter{
ResponseWriter: ctx.Writer,
al: accessLog,
maxLength: maxLength,
}
}

defer func() {
accessLog.Duration = time.Now().Sub(start).String()
//日志打印
b.loggerFunc(ctx, accessLog)
}()
ctx.Next()
}
}

type responseWriter struct {
gin.ResponseWriter
al *AccessLog
maxLength int64
}

func (r responseWriter) WriteHeader(statusCode int) {

r.al.Status = statusCode
r.ResponseWriter.WriteHeader(statusCode)
}

func (r responseWriter) Write(data []byte) (int, error) {
curLen := int64(len(data))
if curLen >= r.maxLength {
data = data[:r.maxLength]
}
r.al.RespBody = string(data)
return r.ResponseWriter.Write(data)
}
249 changes: 249 additions & 0 deletions middlewares/accesslog/builder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
package accesslog

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestBuilder_Builder(t *testing.T) {
testCases := []struct {
name string
getReq func() *http.Request
accesslog *AccessLog
logfunc func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog)
middleWarebuilder func(func(ctx context.Context, al *AccessLog)) gin.HandlerFunc
setStatus int
setRsp string
resultAccessLog *AccessLog
}{
{
name: "不打印请求体,响应体",
getReq: func() *http.Request {
req, err := http.NewRequest(http.MethodGet, "/accesslog", nil)
require.NoError(t, err)
return req
},
accesslog: &AccessLog{},
logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) {

return func(ctx context.Context, al *AccessLog) {
//
//accesslog.Status = al.Status
//accesslog.Method = al.Method
////url 整个请求的u
//accesslog.Url = al.Url
////请求体
//accesslog.ReqBody = al.ReqBody
////响应体
//accesslog.RespBody = al.RespBody
////处理时间
//accesslog.Duration = al.Duration
////状态码
//accesslog.Status = al.Status
copy(accesslog, al)
fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration)
}
},
middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc {
return NewBuilder(f).Builder()
},
resultAccessLog: &AccessLog{
Method: "GET",
Url: "/accesslog",
},
},
{
name: "不打印请求体,打印响应体",
getReq: func() *http.Request {
req, err := http.NewRequest(http.MethodGet, "/accesslog", nil)
require.NoError(t, err)
return req
},
accesslog: &AccessLog{},
logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) {

return func(ctx context.Context, al *AccessLog) {

copy(accesslog, al)

fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration)
}
},
middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc {
return NewBuilder(f).AllowRespBody().Builder()
},
resultAccessLog: &AccessLog{
Method: "GET",
Url: "/accesslog",
RespBody: `{"msg":"aa22"}`,
Status: http.StatusOK,
},
},
{
name: "打印请求体,不打印响应体",
getReq: func() *http.Request {
read := strings.NewReader(`{"msg":"aa11"}`)

req, err := http.NewRequest(http.MethodGet, "/accesslog", read)
require.NoError(t, err)
return req
},
accesslog: &AccessLog{},
logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) {

return func(ctx context.Context, al *AccessLog) {

copy(accesslog, al)

fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration)
}
},
middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc {
return NewBuilder(f).AllowReqBody().Builder()
},
resultAccessLog: &AccessLog{
Method: "GET",
Url: "/accesslog",
ReqBody: `{"msg":"aa11"}`,
},
},
{
name: "打印请求体,打印响应体",
getReq: func() *http.Request {
read := strings.NewReader(`{"msg":"aa11"}`)

req, err := http.NewRequest(http.MethodGet, "/accesslog", read)
require.NoError(t, err)
return req
},
accesslog: &AccessLog{},
logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) {

return func(ctx context.Context, al *AccessLog) {

copy(accesslog, al)

fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration)
}
},
middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc {
return NewBuilder(f).AllowReqBody().AllowRespBody().Builder()
},
resultAccessLog: &AccessLog{
Method: "GET",
Url: "/accesslog",
ReqBody: `{"msg":"aa11"}`,
RespBody: `{"msg":"aa22"}`,
Status: http.StatusOK,
},
},
{
name: "打印请求体超标,不打印响应体,限制长度为10",
getReq: func() *http.Request {
read := strings.NewReader(`{"msg":"aa11"}`)

req, err := http.NewRequest(http.MethodGet, "/accesslog", read)
require.NoError(t, err)
return req
},
accesslog: &AccessLog{},
logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) {

return func(ctx context.Context, al *AccessLog) {

copy(accesslog, al)

fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration)
}
},
middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc {
return NewBuilder(f).AllowReqBody().MaxLength(10).Builder()
},
resultAccessLog: &AccessLog{
Method: "GET",
Url: "/accesslog",
ReqBody: `{"msg":"aa`,
},
},
{
name: "不打印请求体,打印响应体超标,限制长度为10",
getReq: func() *http.Request {
read := strings.NewReader(`{"msg":"aa11"}`)

req, err := http.NewRequest(http.MethodGet, "/accesslog", read)
require.NoError(t, err)
return req
},
accesslog: &AccessLog{},
logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) {

return func(ctx context.Context, al *AccessLog) {

copy(accesslog, al)

fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration)
}
},
middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc {
return NewBuilder(f).AllowRespBody().MaxLength(10).Builder()
},
resultAccessLog: &AccessLog{
Method: "GET",
Url: "/accesslog",
RespBody: `{"msg":"aa`,
Status: http.StatusOK,
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
server := gin.Default()
server.Use(tc.middleWarebuilder(tc.logfunc(tc.accesslog)))
server.GET("/accesslog", func(ctx *gin.Context) {
ctx.JSON(http.StatusOK, map[string]any{
"msg": "aa22",
})
})
resp := httptest.NewRecorder()

server.ServeHTTP(resp, tc.getReq())
//中间件使用的defer 所有这里要给点时间
time.Sleep(time.Millisecond * 100)
assert.Equal(t, tc.accesslog.Method, tc.resultAccessLog.Method)
assert.Equal(t, tc.accesslog.Url, tc.resultAccessLog.Url)
assert.Equal(t, tc.accesslog.ReqBody, tc.resultAccessLog.ReqBody)
assert.Equal(t, tc.accesslog.RespBody, tc.resultAccessLog.RespBody)
//时间不好判断
//assert.Equal(t, tc.accesslog.Duration, tc.resultAccessLog.Duration)

assert.Equal(t, tc.accesslog.Status, tc.resultAccessLog.Status)

})

}
}

func copy(source, target *AccessLog) {
source.Status = target.Status
source.Method = target.Method
//url 整个请求的u
source.Url = target.Url
//请求体
source.ReqBody = target.ReqBody
//响应体
source.RespBody = target.RespBody
//处理时间
source.Duration = target.Duration
//状态码
source.Status = target.Status
}