diff --git a/.golangci.yaml b/.golangci.yaml index fa53e4e9..3f37f9d9 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -40,6 +40,9 @@ linters-settings: - desc: "Use `github.com/google/uuid` package for UUIDs instead." pkg: "github.com/xtgo/uuid" + exhaustive: + default-signifies-exhaustive: true + forbidigo: forbid: - msg: "Use `require` variants instead." diff --git a/api_handler.go b/api_handler.go index 291b1ee2..55e9aa80 100644 --- a/api_handler.go +++ b/api_handler.go @@ -108,7 +108,7 @@ func (*jobCancelEndpoint) Meta() *apiendpoint.EndpointMeta { } type jobCancelRequest struct { - JobIDs []int64String `json:"ids"` + JobIDs []int64String `json:"ids" validate:"required,min=1,max=1000"` } func (a *jobCancelEndpoint) Execute(ctx context.Context, req *jobCancelRequest) (*statusResponse, error) { @@ -252,7 +252,7 @@ func (*jobGetEndpoint) Meta() *apiendpoint.EndpointMeta { } type jobGetRequest struct { - JobID int64 `json:"-"` // from ExtractRaw + JobID int64 `json:"-" validate:"required"` // from ExtractRaw } func (req *jobGetRequest) ExtractRaw(r *http.Request) error { diff --git a/go.mod b/go.mod index e5e5b0d1..2180c7cb 100644 --- a/go.mod +++ b/go.mod @@ -16,16 +16,23 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.22.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/riverqueue/river/riverdriver v0.7.0 // indirect go.opentelemetry.io/otel v1.19.0 // indirect go.opentelemetry.io/otel/trace v1.19.0 // indirect golang.org/x/crypto v0.22.0 // indirect + golang.org/x/net v0.21.0 // indirect golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.19.0 // indirect golang.org/x/text v0.16.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9f7f7692..7379cb51 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,14 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.22.0 h1:k6HsTZ0sTnROkhS//R0O+55JgM8C4Bx7ia+JlgcnOao= +github.com/go-playground/validator/v10 v10.22.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -21,6 +29,8 @@ github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -56,8 +66,12 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/apiendpoint/api_endpoint.go b/internal/apiendpoint/api_endpoint.go index 5de0914c..1f433bf6 100644 --- a/internal/apiendpoint/api_endpoint.go +++ b/internal/apiendpoint/api_endpoint.go @@ -16,6 +16,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/riverqueue/riverui/internal/apierror" + "github.com/riverqueue/riverui/internal/validate" ) // Endpoint is a struct that should be embedded on an API endpoint, and which @@ -123,6 +124,10 @@ func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Requ } } + if err := validate.StructCtx(ctx, &req); err != nil { + return apierror.NewBadRequest(validate.PublicFacingMessage(err)) + } + resp, err := execute(ctx, &req) if err != nil { return err diff --git a/internal/apiendpoint/api_endpoint_test.go b/internal/apiendpoint/api_endpoint_test.go index 24ed15bf..49dfc175 100644 --- a/internal/apiendpoint/api_endpoint_test.go +++ b/internal/apiendpoint/api_endpoint_test.go @@ -91,7 +91,7 @@ func TestMountAndServe(t *testing.T) { requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder) }) - t.Run("APIError", func(t *testing.T) { + t.Run("ValidationError", func(t *testing.T) { t.Parallel() mux, bundle := setup(t) @@ -99,7 +99,19 @@ func TestMountAndServe(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", nil) mux.ServeHTTP(bundle.recorder, req) - requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Missing message value."}, bundle.recorder) + requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Field `message` is required."}, bundle.recorder) + }) + + t.Run("APIError", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeAPIError: true, Message: "Hello."}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Bad request."}, bundle.recorder) }) t.Run("InterpretedPostgresError", func(t *testing.T) { @@ -108,7 +120,7 @@ func TestMountAndServe(t *testing.T) { mux, bundle := setup(t) req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakePostgresError: true}))) + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakePostgresError: true, Message: "Hello."}))) mux.ServeHTTP(bundle.recorder, req) requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Insufficient database privilege to perform this operation."}, bundle.recorder) @@ -123,11 +135,11 @@ func TestMountAndServe(t *testing.T) { t.Cleanup(cancel) req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true}))) + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) require.NoError(t, err) mux.ServeHTTP(bundle.recorder, req) - requireStatusAndJSONResponse(t, http.StatusInternalServerError, &apierror.APIError{Message: "Internal server error. Check logs for more information."}, bundle.recorder) + requireStatusAndJSONResponse(t, http.StatusServiceUnavailable, &apierror.APIError{Message: "Request timed out. Retrying the request might work."}, bundle.recorder) }) t.Run("InternalServerError", func(t *testing.T) { @@ -136,7 +148,7 @@ func TestMountAndServe(t *testing.T) { mux, bundle := setup(t) req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true}))) + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true, Message: "Hello."}))) mux.ServeHTTP(bundle.recorder, req) requireStatusAndJSONResponse(t, http.StatusInternalServerError, &apierror.APIError{Message: "Internal server error. Check logs for more information."}, bundle.recorder) @@ -195,8 +207,8 @@ func (*getEndpoint) Meta() *EndpointMeta { } type getRequest struct { - IgnoredJSONMessage string `json:"ignored_json"` - Message string `json:"-"` + IgnoredJSONMessage string `json:"ignored_json" validate:"-"` + Message string `json:"-" validate:"required"` } func (req *getRequest) ExtractRaw(r *http.Request) error { @@ -205,7 +217,7 @@ func (req *getRequest) ExtractRaw(r *http.Request) error { } type getResponse struct { - Message string `json:"message"` + Message string `json:"message" validate:"required"` } func (a *getEndpoint) Execute(_ context.Context, req *getRequest) (*getResponse, error) { @@ -233,16 +245,25 @@ func (*postEndpoint) Meta() *EndpointMeta { } type postRequest struct { - MakeInternalError bool `json:"make_internal_error"` - MakePostgresError bool `json:"make_postgres_error"` - Message string `json:"message"` + MakeAPIError bool `json:"make_api_error" validate:"-"` + MakeInternalError bool `json:"make_internal_error" validate:"-"` + MakePostgresError bool `json:"make_postgres_error" validate:"-"` + Message string `json:"message" validate:"required"` } type postResponse struct { Message string `json:"message"` } -func (a *postEndpoint) Execute(_ context.Context, req *postRequest) (*postResponse, error) { +func (a *postEndpoint) Execute(ctx context.Context, req *postRequest) (*postResponse, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + if req.MakeAPIError { + return nil, apierror.NewBadRequest("Bad request.") + } + if req.MakeInternalError { return nil, errors.New("an internal error occurred") } @@ -252,9 +273,5 @@ func (a *postEndpoint) Execute(_ context.Context, req *postRequest) (*postRespon return nil, fmt.Errorf("error runnning Postgres query: %w", &pgconn.PgError{Code: pgerrcode.InsufficientPrivilege}) } - if req.Message == "" { - return nil, apierror.NewBadRequest("Missing message value.") - } - return &postResponse{Message: req.Message}, nil } diff --git a/internal/validate/validate.go b/internal/validate/validate.go new file mode 100644 index 00000000..fb5c86ab --- /dev/null +++ b/internal/validate/validate.go @@ -0,0 +1,114 @@ +// Package validate internalizes Go Playground's Validator framework, setting +// some common options that we use everywhere, providing some useful helpers, +// and exporting a simplified API. +package validate + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/go-playground/validator/v10" +) + +// WithRequiredStructEnabled can be removed once validator/v11 is released. +var validate = validator.New(validator.WithRequiredStructEnabled()) //nolint:gochecknoglobals + +func init() { //nolint:gochecknoinits + validate.RegisterTagNameFunc(preferPublicName) +} + +// PublicFacingMessage builds a complete error message from a validator error +// that's suitable for public-facing consumption. +// +// I only added a few possible validations to start. We'll probably need to add +// more as we go and expand our usage. +func PublicFacingMessage(validatorErr error) string { + var message string + + //nolint:errorlint + if validationErrs, ok := validatorErr.(validator.ValidationErrors); ok { + for _, fieldErr := range validationErrs { + switch fieldErr.Tag() { + case "lte": + fallthrough // lte and max are synonyms + case "max": + kind := fieldErr.Kind() + if kind == reflect.Ptr { + kind = fieldErr.Type().Elem().Kind() + } + + switch kind { + case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int32, reflect.Int64: + message += fmt.Sprintf(" Field `%s` must be less than or equal to %s.", + fieldErr.Field(), fieldErr.Param()) + + case reflect.Slice, reflect.Map: + message += fmt.Sprintf(" Field `%s` must contain at most %s element(s).", + fieldErr.Field(), fieldErr.Param()) + + case reflect.String: + message += fmt.Sprintf(" Field `%s` must be at most %s character(s) long.", + fieldErr.Field(), fieldErr.Param()) + + default: + message += fieldErr.Error() + } + + case "gte": + fallthrough // gte and min are synonyms + case "min": + kind := fieldErr.Kind() + if kind == reflect.Ptr { + kind = fieldErr.Type().Elem().Kind() + } + + switch kind { + case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int32, reflect.Int64: + message += fmt.Sprintf(" Field `%s` must be greater or equal to %s.", + fieldErr.Field(), fieldErr.Param()) + + case reflect.Slice, reflect.Map: + message += fmt.Sprintf(" Field `%s` must contain at least %s element(s).", + fieldErr.Field(), fieldErr.Param()) + + case reflect.String: + message += fmt.Sprintf(" Field `%s` must be at least %s character(s) long.", + fieldErr.Field(), fieldErr.Param()) + + default: + message += fieldErr.Error() + } + + case "required": + message += fmt.Sprintf(" Field `%s` is required.", fieldErr.Field()) + + default: + message += fmt.Sprintf(" Validation on field `%s` failed on the `%s` tag.", fieldErr.Field(), fieldErr.Tag()) + } + } + } + + return strings.TrimSpace(message) +} + +// StructCtx validates a structs exposed fields, and automatically validates +// nested structs, unless otherwise specified and also allows passing of +// context.Context for contextual validation information. +func StructCtx(ctx context.Context, s any) error { + return validate.StructCtx(ctx, s) +} + +// preferPublicName is a validator tag naming function that uses public names +// like a field's JSON tag instead of actual field names in structs. +// This is important because we sent these back as user-facing errors (and the +// users submitted them as JSON/path parameters). +func preferPublicName(fld reflect.StructField) string { + name, _, _ := strings.Cut(fld.Tag.Get("json"), ",") + if name != "" && name != "-" { + return name + } + + return fld.Name +} diff --git a/internal/validate/validate_test.go b/internal/validate/validate_test.go new file mode 100644 index 00000000..47433b53 --- /dev/null +++ b/internal/validate/validate_test.go @@ -0,0 +1,125 @@ +package validate + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFromValidator(t *testing.T) { + t.Parallel() + + // Fields have JSON tags so we can verify those are used over the + // property name. + type TestStruct struct { + MinInt int `json:"min_int" validate:"min=1"` + MinSlice []string `json:"min_slice" validate:"min=1"` + MinString string `json:"min_string" validate:"min=1"` + MaxInt int `json:"max_int" validate:"max=0"` + MaxSlice []string `json:"max_slice" validate:"max=0"` + MaxString string `json:"max_string" validate:"max=0"` + Required string `json:"required" validate:"required"` + Unsupported string `json:"unsupported" validate:"e164"` + } + + validTestStruct := func() *TestStruct { + return &TestStruct{ + MinInt: 1, + MinSlice: []string{"1"}, + MinString: "value", + MaxInt: 0, + MaxSlice: []string{}, + MaxString: "", + Required: "value", + Unsupported: "+1123456789", + } + } + + t.Run("MaxInt", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MaxInt = 1 + require.Equal(t, "Field `max_int` must be less than or equal to 0.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MaxSlice", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MaxSlice = []string{"1"} + require.Equal(t, "Field `max_slice` must contain at most 0 element(s).", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MaxString", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MaxString = "value" + require.Equal(t, "Field `max_string` must be at most 0 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MinInt", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinInt = 0 + require.Equal(t, "Field `min_int` must be greater or equal to 1.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MinSlice", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinSlice = nil + require.Equal(t, "Field `min_slice` must contain at least 1 element(s).", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MinString", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinString = "" + require.Equal(t, "Field `min_string` must be at least 1 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("Required", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.Required = "" + require.Equal(t, "Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("Unsupported", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.Unsupported = "abc" + require.Equal(t, "Validation on field `unsupported` failed on the `e164` tag.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MultipleErrors", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinInt = 0 + testStruct.Required = "" + require.Equal(t, "Field `min_int` must be greater or equal to 1. Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) + }) +} + +func TestPreferPublicNames(t *testing.T) { + t.Parallel() + + type testStruct struct { + JSONNameField string `json:"json_name"` + StructNameField string `apiquery:"-"` + } + + require.Equal(t, "json_name", + preferPublicName(reflect.TypeOf(testStruct{}).Field(0))) + require.Equal(t, "StructNameField", + preferPublicName(reflect.TypeOf(testStruct{}).Field(1))) +}