From 9c22a9b0963c6ec698a71265bb7051a2017bf191 Mon Sep 17 00:00:00 2001 From: Brandur Date: Thu, 27 Jun 2024 21:09:17 -0700 Subject: [PATCH] Use Go validator framework to succinctly validate requests A general problem with Go is that validating things is a very noisy affair involving long lists of if statements combined with error returns for every field on a struct. A technique that we've been using for a while is to use the Go validator framework [1] that allows fields to be tagged with succinct validation syntax for a variety of different things. e.g. type jobCancelRequest struct { JobIDs []int64String `json:"ids" validate:"required,min=1,max=1000"` } I'm not sure the use of something like this is necessary for a project that's a dependency like core River itself, but but internal use on more of an "application" project like River UI, it might be helpful. We combine the validator with the new API framework from #63 so that incoming request structs are validated automatically for every endpoint, which shaves a lot of lines of otherwise necessary validation code out of individual API endpoint definitions. [1] https://github.com/go-playground/validator?tab=readme-ov-file --- .golangci.yaml | 3 + api_handler.go | 4 +- go.mod | 7 ++ go.sum | 14 +++ internal/apiendpoint/api_endpoint.go | 5 + internal/apiendpoint/api_endpoint_test.go | 51 ++++++--- internal/validate/validate.go | 114 ++++++++++++++++++++ internal/validate/validate_test.go | 125 ++++++++++++++++++++++ 8 files changed, 304 insertions(+), 19 deletions(-) create mode 100644 internal/validate/validate.go create mode 100644 internal/validate/validate_test.go diff --git a/.golangci.yaml b/.golangci.yaml index fa53e4e9..3f37f9d9 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -40,6 +40,9 @@ linters-settings: - desc: "Use `github.com/google/uuid` package for UUIDs instead." pkg: "github.com/xtgo/uuid" + exhaustive: + default-signifies-exhaustive: true + forbidigo: forbid: - msg: "Use `require` variants instead." diff --git a/api_handler.go b/api_handler.go index 291b1ee2..55e9aa80 100644 --- a/api_handler.go +++ b/api_handler.go @@ -108,7 +108,7 @@ func (*jobCancelEndpoint) Meta() *apiendpoint.EndpointMeta { } type jobCancelRequest struct { - JobIDs []int64String `json:"ids"` + JobIDs []int64String `json:"ids" validate:"required,min=1,max=1000"` } func (a *jobCancelEndpoint) Execute(ctx context.Context, req *jobCancelRequest) (*statusResponse, error) { @@ -252,7 +252,7 @@ func (*jobGetEndpoint) Meta() *apiendpoint.EndpointMeta { } type jobGetRequest struct { - JobID int64 `json:"-"` // from ExtractRaw + JobID int64 `json:"-" validate:"required"` // from ExtractRaw } func (req *jobGetRequest) ExtractRaw(r *http.Request) error { diff --git a/go.mod b/go.mod index e5e5b0d1..2180c7cb 100644 --- a/go.mod +++ b/go.mod @@ -16,16 +16,23 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.22.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/riverqueue/river/riverdriver v0.7.0 // indirect go.opentelemetry.io/otel v1.19.0 // indirect go.opentelemetry.io/otel/trace v1.19.0 // indirect golang.org/x/crypto v0.22.0 // indirect + golang.org/x/net v0.21.0 // indirect golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.19.0 // indirect golang.org/x/text v0.16.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9f7f7692..7379cb51 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,14 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.22.0 h1:k6HsTZ0sTnROkhS//R0O+55JgM8C4Bx7ia+JlgcnOao= +github.com/go-playground/validator/v10 v10.22.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -21,6 +29,8 @@ github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -56,8 +66,12 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/apiendpoint/api_endpoint.go b/internal/apiendpoint/api_endpoint.go index 5de0914c..1f433bf6 100644 --- a/internal/apiendpoint/api_endpoint.go +++ b/internal/apiendpoint/api_endpoint.go @@ -16,6 +16,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/riverqueue/riverui/internal/apierror" + "github.com/riverqueue/riverui/internal/validate" ) // Endpoint is a struct that should be embedded on an API endpoint, and which @@ -123,6 +124,10 @@ func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Requ } } + if err := validate.StructCtx(ctx, &req); err != nil { + return apierror.NewBadRequest(validate.PublicFacingMessage(err)) + } + resp, err := execute(ctx, &req) if err != nil { return err diff --git a/internal/apiendpoint/api_endpoint_test.go b/internal/apiendpoint/api_endpoint_test.go index 24ed15bf..49dfc175 100644 --- a/internal/apiendpoint/api_endpoint_test.go +++ b/internal/apiendpoint/api_endpoint_test.go @@ -91,7 +91,7 @@ func TestMountAndServe(t *testing.T) { requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder) }) - t.Run("APIError", func(t *testing.T) { + t.Run("ValidationError", func(t *testing.T) { t.Parallel() mux, bundle := setup(t) @@ -99,7 +99,19 @@ func TestMountAndServe(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", nil) mux.ServeHTTP(bundle.recorder, req) - requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Missing message value."}, bundle.recorder) + requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Field `message` is required."}, bundle.recorder) + }) + + t.Run("APIError", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeAPIError: true, Message: "Hello."}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Bad request."}, bundle.recorder) }) t.Run("InterpretedPostgresError", func(t *testing.T) { @@ -108,7 +120,7 @@ func TestMountAndServe(t *testing.T) { mux, bundle := setup(t) req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakePostgresError: true}))) + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakePostgresError: true, Message: "Hello."}))) mux.ServeHTTP(bundle.recorder, req) requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Insufficient database privilege to perform this operation."}, bundle.recorder) @@ -123,11 +135,11 @@ func TestMountAndServe(t *testing.T) { t.Cleanup(cancel) req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true}))) + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) require.NoError(t, err) mux.ServeHTTP(bundle.recorder, req) - requireStatusAndJSONResponse(t, http.StatusInternalServerError, &apierror.APIError{Message: "Internal server error. Check logs for more information."}, bundle.recorder) + requireStatusAndJSONResponse(t, http.StatusServiceUnavailable, &apierror.APIError{Message: "Request timed out. Retrying the request might work."}, bundle.recorder) }) t.Run("InternalServerError", func(t *testing.T) { @@ -136,7 +148,7 @@ func TestMountAndServe(t *testing.T) { mux, bundle := setup(t) req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true}))) + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true, Message: "Hello."}))) mux.ServeHTTP(bundle.recorder, req) requireStatusAndJSONResponse(t, http.StatusInternalServerError, &apierror.APIError{Message: "Internal server error. Check logs for more information."}, bundle.recorder) @@ -195,8 +207,8 @@ func (*getEndpoint) Meta() *EndpointMeta { } type getRequest struct { - IgnoredJSONMessage string `json:"ignored_json"` - Message string `json:"-"` + IgnoredJSONMessage string `json:"ignored_json" validate:"-"` + Message string `json:"-" validate:"required"` } func (req *getRequest) ExtractRaw(r *http.Request) error { @@ -205,7 +217,7 @@ func (req *getRequest) ExtractRaw(r *http.Request) error { } type getResponse struct { - Message string `json:"message"` + Message string `json:"message" validate:"required"` } func (a *getEndpoint) Execute(_ context.Context, req *getRequest) (*getResponse, error) { @@ -233,16 +245,25 @@ func (*postEndpoint) Meta() *EndpointMeta { } type postRequest struct { - MakeInternalError bool `json:"make_internal_error"` - MakePostgresError bool `json:"make_postgres_error"` - Message string `json:"message"` + MakeAPIError bool `json:"make_api_error" validate:"-"` + MakeInternalError bool `json:"make_internal_error" validate:"-"` + MakePostgresError bool `json:"make_postgres_error" validate:"-"` + Message string `json:"message" validate:"required"` } type postResponse struct { Message string `json:"message"` } -func (a *postEndpoint) Execute(_ context.Context, req *postRequest) (*postResponse, error) { +func (a *postEndpoint) Execute(ctx context.Context, req *postRequest) (*postResponse, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + if req.MakeAPIError { + return nil, apierror.NewBadRequest("Bad request.") + } + if req.MakeInternalError { return nil, errors.New("an internal error occurred") } @@ -252,9 +273,5 @@ func (a *postEndpoint) Execute(_ context.Context, req *postRequest) (*postRespon return nil, fmt.Errorf("error runnning Postgres query: %w", &pgconn.PgError{Code: pgerrcode.InsufficientPrivilege}) } - if req.Message == "" { - return nil, apierror.NewBadRequest("Missing message value.") - } - return &postResponse{Message: req.Message}, nil } diff --git a/internal/validate/validate.go b/internal/validate/validate.go new file mode 100644 index 00000000..fb5c86ab --- /dev/null +++ b/internal/validate/validate.go @@ -0,0 +1,114 @@ +// Package validate internalizes Go Playground's Validator framework, setting +// some common options that we use everywhere, providing some useful helpers, +// and exporting a simplified API. +package validate + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/go-playground/validator/v10" +) + +// WithRequiredStructEnabled can be removed once validator/v11 is released. +var validate = validator.New(validator.WithRequiredStructEnabled()) //nolint:gochecknoglobals + +func init() { //nolint:gochecknoinits + validate.RegisterTagNameFunc(preferPublicName) +} + +// PublicFacingMessage builds a complete error message from a validator error +// that's suitable for public-facing consumption. +// +// I only added a few possible validations to start. We'll probably need to add +// more as we go and expand our usage. +func PublicFacingMessage(validatorErr error) string { + var message string + + //nolint:errorlint + if validationErrs, ok := validatorErr.(validator.ValidationErrors); ok { + for _, fieldErr := range validationErrs { + switch fieldErr.Tag() { + case "lte": + fallthrough // lte and max are synonyms + case "max": + kind := fieldErr.Kind() + if kind == reflect.Ptr { + kind = fieldErr.Type().Elem().Kind() + } + + switch kind { + case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int32, reflect.Int64: + message += fmt.Sprintf(" Field `%s` must be less than or equal to %s.", + fieldErr.Field(), fieldErr.Param()) + + case reflect.Slice, reflect.Map: + message += fmt.Sprintf(" Field `%s` must contain at most %s element(s).", + fieldErr.Field(), fieldErr.Param()) + + case reflect.String: + message += fmt.Sprintf(" Field `%s` must be at most %s character(s) long.", + fieldErr.Field(), fieldErr.Param()) + + default: + message += fieldErr.Error() + } + + case "gte": + fallthrough // gte and min are synonyms + case "min": + kind := fieldErr.Kind() + if kind == reflect.Ptr { + kind = fieldErr.Type().Elem().Kind() + } + + switch kind { + case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int32, reflect.Int64: + message += fmt.Sprintf(" Field `%s` must be greater or equal to %s.", + fieldErr.Field(), fieldErr.Param()) + + case reflect.Slice, reflect.Map: + message += fmt.Sprintf(" Field `%s` must contain at least %s element(s).", + fieldErr.Field(), fieldErr.Param()) + + case reflect.String: + message += fmt.Sprintf(" Field `%s` must be at least %s character(s) long.", + fieldErr.Field(), fieldErr.Param()) + + default: + message += fieldErr.Error() + } + + case "required": + message += fmt.Sprintf(" Field `%s` is required.", fieldErr.Field()) + + default: + message += fmt.Sprintf(" Validation on field `%s` failed on the `%s` tag.", fieldErr.Field(), fieldErr.Tag()) + } + } + } + + return strings.TrimSpace(message) +} + +// StructCtx validates a structs exposed fields, and automatically validates +// nested structs, unless otherwise specified and also allows passing of +// context.Context for contextual validation information. +func StructCtx(ctx context.Context, s any) error { + return validate.StructCtx(ctx, s) +} + +// preferPublicName is a validator tag naming function that uses public names +// like a field's JSON tag instead of actual field names in structs. +// This is important because we sent these back as user-facing errors (and the +// users submitted them as JSON/path parameters). +func preferPublicName(fld reflect.StructField) string { + name, _, _ := strings.Cut(fld.Tag.Get("json"), ",") + if name != "" && name != "-" { + return name + } + + return fld.Name +} diff --git a/internal/validate/validate_test.go b/internal/validate/validate_test.go new file mode 100644 index 00000000..47433b53 --- /dev/null +++ b/internal/validate/validate_test.go @@ -0,0 +1,125 @@ +package validate + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFromValidator(t *testing.T) { + t.Parallel() + + // Fields have JSON tags so we can verify those are used over the + // property name. + type TestStruct struct { + MinInt int `json:"min_int" validate:"min=1"` + MinSlice []string `json:"min_slice" validate:"min=1"` + MinString string `json:"min_string" validate:"min=1"` + MaxInt int `json:"max_int" validate:"max=0"` + MaxSlice []string `json:"max_slice" validate:"max=0"` + MaxString string `json:"max_string" validate:"max=0"` + Required string `json:"required" validate:"required"` + Unsupported string `json:"unsupported" validate:"e164"` + } + + validTestStruct := func() *TestStruct { + return &TestStruct{ + MinInt: 1, + MinSlice: []string{"1"}, + MinString: "value", + MaxInt: 0, + MaxSlice: []string{}, + MaxString: "", + Required: "value", + Unsupported: "+1123456789", + } + } + + t.Run("MaxInt", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MaxInt = 1 + require.Equal(t, "Field `max_int` must be less than or equal to 0.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MaxSlice", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MaxSlice = []string{"1"} + require.Equal(t, "Field `max_slice` must contain at most 0 element(s).", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MaxString", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MaxString = "value" + require.Equal(t, "Field `max_string` must be at most 0 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MinInt", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinInt = 0 + require.Equal(t, "Field `min_int` must be greater or equal to 1.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MinSlice", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinSlice = nil + require.Equal(t, "Field `min_slice` must contain at least 1 element(s).", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MinString", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinString = "" + require.Equal(t, "Field `min_string` must be at least 1 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("Required", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.Required = "" + require.Equal(t, "Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("Unsupported", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.Unsupported = "abc" + require.Equal(t, "Validation on field `unsupported` failed on the `e164` tag.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MultipleErrors", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinInt = 0 + testStruct.Required = "" + require.Equal(t, "Field `min_int` must be greater or equal to 1. Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) + }) +} + +func TestPreferPublicNames(t *testing.T) { + t.Parallel() + + type testStruct struct { + JSONNameField string `json:"json_name"` + StructNameField string `apiquery:"-"` + } + + require.Equal(t, "json_name", + preferPublicName(reflect.TypeOf(testStruct{}).Field(0))) + require.Equal(t, "StructNameField", + preferPublicName(reflect.TypeOf(testStruct{}).Field(1))) +}