From 3fff4c233b80f47a432a5853db42c149b20c6b38 Mon Sep 17 00:00:00 2001 From: Brandur Date: Sat, 22 Jun 2024 16:32:30 -0700 Subject: [PATCH] Introduce lightweight API framework for Go code + test suite Put in a lightweight API framework for River UI's Go code to make writing endpoints more succinct and better enable testing. Endpoints are defined as a type that embeds a struct declaring their request and response types along with metadata that declares their path and success status code: type jobCancelEndpoint struct { apiBundle apiendpoint.Endpoint[jobCancelRequest, jobCancelResponse] } func (*jobCancelEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Path: "POST /api/jobs/cancel", StatusCode: http.StatusOK, } } The request/response types know to unmarshal and marshal themselves from to JSON, or encapsulate any path/query parameters they need to capture: type jobCancelRequest struct { JobIDs []int64String `json:"ids"` } type jobCancelResponse struct { Status string `json:"status"` } Endpoints define an `Execute` function that takes a request struct and returns a response struct along with a possible error: func (a *jobCancelEndpoint) Execute(ctx context.Context, req *jobCancelRequest) (*jobCancelResponse, error) { ... return &jobCancelResponse{Status: "ok"}, nil } This makes the endpoints a lot easier to write because serialization code gets removed, and errors can be returned succinctly according to normal Go practices instead of each one having to be handled in a custom way and be a liability in case of a forgotten `return` after writing it back in the response. The underlying API framework takes care of writing back errors that should be user-facing (anything in the newly added `apierror` package) or logging an internal error and return a generic message. Context deadline code also gets pushed down. The newly added test suite shows that the `Execute` shape also makes tests easier and more succinct to write because structs can be sent and read directly without having to go through a JSON/HTTP layer, and errors can be handled directly without having to worry about them being converted to a server error, which makes debugging broken tests a lot easier. resp, err := endpoint.Execute(ctx, &jobCancelRequest{JobIDs: []int64String{int64String(insertRes1.Job.ID), int64String(insertRes2.Job.ID)}}) require.NoError(t, err) require.Equal(t, &jobCancelResponse{Status: "ok"}, resp) We also add a suite of integration-level tests that test each endpoint through the entire HTTP/JSON stack to make sure that everything works at that level. This suite is written much more sparingly -- one test fer endpoint -- because the vast majority of endpoint tests should be written in the handler-level suite for the reasons mentioned above. --- .env.example | 2 +- .github/workflows/ci.yaml | 9 +- .gitignore | 1 + .golangci.yaml | 2 + api_handler.go | 148 +++++++----- api_handler_test.go | 112 +++++++++ common_test.go | 61 +++++ docs/development.md | 21 +- go.mod | 4 + go.sum | 8 + handler.go | 18 +- handler_test.go | 93 ++++++++ int64_string.go | 35 +++ int64_string_test.go | 38 +++ internal/apiendpoint/api_endpoint.go | 168 +++++++++++++ internal/apiendpoint/api_endpoint_test.go | 220 ++++++++++++++++++ internal/apierror/api_error.go | 122 ++++++++++ internal/apierror/api_error_test.go | 49 ++++ .../riverinternaltest/riverinternaltest.go | 117 ++++++++++ .../slogtest/slog_test_handler.go | 86 +++++++ .../slogtest/slog_test_handler_test.go | 58 +++++ internal/util/dbutil/db_util.go | 44 ++++ internal/util/dbutil/db_util_test.go | 42 ++++ 23 files changed, 1383 insertions(+), 75 deletions(-) create mode 100644 api_handler_test.go create mode 100644 common_test.go create mode 100644 handler_test.go create mode 100644 int64_string.go create mode 100644 int64_string_test.go create mode 100644 internal/apiendpoint/api_endpoint.go create mode 100644 internal/apiendpoint/api_endpoint_test.go create mode 100644 internal/apierror/api_error.go create mode 100644 internal/apierror/api_error_test.go create mode 100644 internal/riverinternaltest/riverinternaltest.go create mode 100644 internal/riverinternaltest/slogtest/slog_test_handler.go create mode 100644 internal/riverinternaltest/slogtest/slog_test_handler_test.go create mode 100644 internal/util/dbutil/db_util.go create mode 100644 internal/util/dbutil/db_util_test.go diff --git a/.env.example b/.env.example index 40669211..77f7741f 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,4 @@ CORS_ORIGINS=http://localhost:5173,https://example.com -DATABASE_URL=postgres://postgres:postgres@localhost:5432/river-development +DATABASE_URL=postgres://postgres:postgres@localhost:5432/river_dev OTEL_ENABLED=false PORT=8080 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fbe3b715..55d4d092 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,7 +2,10 @@ name: CI env: # A suitable URL for the test database. - DATABASE_URL: postgres://postgres:postgres@127.0.0.1:5432/riverui_dev?sslmode=disable + DATABASE_URL: postgres://postgres:postgres@127.0.0.1:5432/river_dev?sslmode=disable + + # Test database. + TEST_DATABASE_URL: postgres://postgres:postgres@127.0.0.1:5432/river_test?sslmode=disable on: push: @@ -58,14 +61,14 @@ jobs: run: go install github.com/riverqueue/river/cmd/river@latest - name: Create test DB - run: createdb riverui_dev + run: createdb river_test env: PGHOST: 127.0.0.1 PGUSER: postgres PGPASSWORD: postgres - name: Migrate test DB - run: river migrate-up --database-url "$DATABASE_URL" + run: river migrate-up --database-url "$TEST_DATABASE_URL" # ensure that there is a file in `ui/dist` to prevent a lint error about # it during CI when there is nothing there: diff --git a/.gitignore b/.gitignore index 6ea9b69c..3321e19d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .env .env.* !.env.example +.tool-versions /riverui diff --git a/.golangci.yaml b/.golangci.yaml index b85be384..fa53e4e9 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -82,9 +82,11 @@ linters-settings: - id - j - mu + - r # common for http.Request - rw # common for http.ResponseWriter - sb # common convention for string builder - t - tt # common convention for table tests - tx + - w # common for http.ResponseWriter - wg diff --git a/api_handler.go b/api_handler.go index d667e460..be56f508 100644 --- a/api_handler.go +++ b/api_handler.go @@ -12,63 +12,78 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" "github.com/riverqueue/river" "github.com/riverqueue/river/rivertype" + "github.com/riverqueue/riverui/internal/apiendpoint" + "github.com/riverqueue/riverui/internal/apierror" "github.com/riverqueue/riverui/internal/db" + "github.com/riverqueue/riverui/internal/util/dbutil" ) -type jobCancelRequest struct { - JobIDStrings []string `json:"ids"` -} - -type apiHandler struct { +// A bundle of common utilities needed for many API endpoints. +type apiBundle struct { client *river.Client[pgx.Tx] - dbPool *pgxpool.Pool + dbPool DBTXWithBegin logger *slog.Logger queries *db.Queries } -func (a *apiHandler) JobCancel(rw http.ResponseWriter, req *http.Request) { - ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) - defer cancel() +// SetBundle sets all values to the same as the given bundle. +func (a *apiBundle) SetBundle(bundle *apiBundle) { + *a = *bundle +} - var payload jobCancelRequest - if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { - a.logger.ErrorContext(ctx, "error decoding request", slog.String("error", err.Error())) - http.Error(rw, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return - } - jobIDs, err := stringIDsToInt64s(payload.JobIDStrings) - if err != nil { - a.logger.ErrorContext(ctx, "error decoding job IDs", slog.String("error", err.Error())) - http.Error(rw, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return +// withSetBundle is an interface that's automatically implemented by types that +// embed apiBundle. It lets places like tests generically set bundle values on +// any general endpoint type. +type withSetBundle interface { + // SetBundle sets all values to the same as the given bundle. + SetBundle(bundle *apiBundle) +} + +type jobCancelEndpoint struct { + apiBundle + apiendpoint.Endpoint[jobCancelRequest, jobCancelResponse] +} + +func (*jobCancelEndpoint) Meta() *apiendpoint.EndpointMeta { + return &apiendpoint.EndpointMeta{ + Pattern: "POST /api/jobs/cancel", + StatusCode: http.StatusOK, } +} + +type jobCancelRequest struct { + JobIDs []int64String `json:"ids"` +} - updatedJobs := make(map[int64]*rivertype.JobRow) +type jobCancelResponse struct { + Status string `json:"status"` +} - if err := pgx.BeginFunc(ctx, a.dbPool, func(tx pgx.Tx) error { - for _, jobID := range jobIDs { +func (a *jobCancelEndpoint) Execute(ctx context.Context, req *jobCancelRequest) (*jobCancelResponse, error) { + return dbutil.WithTxV(ctx, a.dbPool, func(ctx context.Context, tx pgx.Tx) (*jobCancelResponse, error) { + updatedJobs := make(map[int64]*rivertype.JobRow) + for _, jobID := range req.JobIDs { + jobID := int64(jobID) job, err := a.client.JobCancelTx(ctx, tx, jobID) if err != nil { if errors.Is(err, river.ErrNotFound) { - fmt.Printf("job %d not found\n", jobID) + return nil, apierror.NewNotFoundJob(jobID) } - return err + return nil, fmt.Errorf("error canceling jobs: %w", err) } updatedJobs[jobID] = job } - return nil - }); err != nil { - a.logger.ErrorContext(ctx, "error cancelling jobs", slog.String("error", err.Error())) - http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return - } - // TODO: return jobs in response, use in frontend instead of invalidating - a.writeResponse(ctx, rw, []byte("{\"status\": \"ok\"}")) + // TODO: return jobs in response, use in frontend instead of invalidating + return &jobCancelResponse{Status: "ok"}, nil + }) +} + +type apiHandler struct { + apiBundle } func (a *apiHandler) writeResponse(ctx context.Context, rw http.ResponseWriter, data []byte) { @@ -175,38 +190,45 @@ func (a *apiHandler) JobRetry(rw http.ResponseWriter, req *http.Request) { a.writeResponse(ctx, rw, []byte("{\"status\": \"ok\"}")) } -func (a *apiHandler) JobGet(rw http.ResponseWriter, req *http.Request) { - ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) - defer cancel() +type jobGetEndpoint struct { + apiBundle + apiendpoint.Endpoint[jobGetRequest, RiverJob] +} - idString := req.PathValue("id") - if idString == "" { - http.Error(rw, "missing job id", http.StatusBadRequest) - return +func (*jobGetEndpoint) Meta() *apiendpoint.EndpointMeta { + return &apiendpoint.EndpointMeta{ + Pattern: "GET /api/jobs/{job_id}", + StatusCode: http.StatusOK, } +} + +type jobGetRequest struct { + JobID int64 `json:"-"` // from ExtractRaw +} + +func (req *jobGetRequest) ExtractRaw(r *http.Request) error { + idString := r.PathValue("job_id") jobID, err := strconv.ParseInt(idString, 10, 64) if err != nil { - http.Error(rw, fmt.Sprintf("invalid job id: %s", err), http.StatusBadRequest) - return + return apierror.NewBadRequest("Couldn't convert job ID to int64: %s.", err) } + req.JobID = jobID - job, err := a.client.JobGet(ctx, jobID) - if errors.Is(err, river.ErrNotFound) { - http.Error(rw, "{\"error\": {\"msg\": \"job not found\"}}", http.StatusNotFound) - return - } - if err != nil { - a.logger.ErrorContext(ctx, "error getting job", slog.String("error", err.Error())) - http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return - } + return nil +} - if err := json.NewEncoder(rw).Encode(riverJobToSerializableJob(*job)); err != nil { - a.logger.ErrorContext(ctx, "error encoding job", slog.String("error", err.Error())) - http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return - } +func (a *jobGetEndpoint) Execute(ctx context.Context, req *jobGetRequest) (*RiverJob, error) { + return dbutil.WithTxV(ctx, a.dbPool, func(ctx context.Context, tx pgx.Tx) (*RiverJob, error) { + job, err := a.client.JobGetTx(ctx, tx, req.JobID) + if err != nil { + if errors.Is(err, river.ErrNotFound) { + return nil, apierror.NewNotFoundJob(req.JobID) + } + return nil, fmt.Errorf("error getting job: %w", err) + } + return riverJobToSerializableJob(job), nil + }) } func (a *apiHandler) JobList(rw http.ResponseWriter, req *http.Request) { @@ -544,7 +566,7 @@ func internalJobsToSerializableJobs(internal []db.RiverJob) []RiverJob { return jobs } -func riverJobToSerializableJob(riverJob rivertype.JobRow) RiverJob { +func riverJobToSerializableJob(riverJob *rivertype.JobRow) *RiverJob { attemptedBy := riverJob.AttemptedBy if attemptedBy == nil { attemptedBy = []string{} @@ -554,7 +576,7 @@ func riverJobToSerializableJob(riverJob rivertype.JobRow) RiverJob { errs = []rivertype.AttemptError{} } - return RiverJob{ + return &RiverJob{ ID: riverJob.ID, Args: riverJob.EncodedArgs, Attempt: riverJob.Attempt, @@ -574,10 +596,10 @@ func riverJobToSerializableJob(riverJob rivertype.JobRow) RiverJob { } } -func riverJobsToSerializableJobs(result *river.JobListResult) []RiverJob { - jobs := make([]RiverJob, len(result.Jobs)) +func riverJobsToSerializableJobs(result *river.JobListResult) []*RiverJob { + jobs := make([]*RiverJob, len(result.Jobs)) for i, internalJob := range result.Jobs { - jobs[i] = riverJobToSerializableJob(*internalJob) + jobs[i] = riverJobToSerializableJob(internalJob) } return jobs } diff --git a/api_handler_test.go b/api_handler_test.go new file mode 100644 index 00000000..61043128 --- /dev/null +++ b/api_handler_test.go @@ -0,0 +1,112 @@ +package riverui + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/rivertype" + "github.com/riverqueue/riverui/internal/apierror" + "github.com/riverqueue/riverui/internal/db" + "github.com/riverqueue/riverui/internal/riverinternaltest" +) + +type setupEndpointTestBundle struct { + client *river.Client[pgx.Tx] + tx pgx.Tx +} + +func setupEndpoint[TEndpoint any](ctx context.Context, t *testing.T) (*TEndpoint, *setupEndpointTestBundle) { + t.Helper() + + var ( + endpoint TEndpoint + logger = riverinternaltest.Logger(t) + client = insertOnlyClient(t, logger) + tx = riverinternaltest.TestTx(ctx, t) + ) + + if withSetBundle, ok := any(&endpoint).(withSetBundle); ok { + withSetBundle.SetBundle(&apiBundle{ + client: client, + dbPool: tx, + logger: logger, + queries: db.New(tx), + }) + } + + return &endpoint, &setupEndpointTestBundle{ + client: client, + tx: tx, + } +} + +func TestAPIHandlerJobCancel(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + + endpoint, bundle := setupEndpoint[jobCancelEndpoint](ctx, t) + + insertRes1, err := bundle.client.InsertTx(ctx, bundle.tx, &noOpArgs{}, nil) + require.NoError(t, err) + + insertRes2, err := bundle.client.InsertTx(ctx, bundle.tx, &noOpArgs{}, nil) + require.NoError(t, err) + + resp, err := endpoint.Execute(ctx, &jobCancelRequest{JobIDs: []int64String{int64String(insertRes1.Job.ID), int64String(insertRes2.Job.ID)}}) + require.NoError(t, err) + require.Equal(t, &jobCancelResponse{Status: "ok"}, resp) + + updatedJob1, err := bundle.client.JobGetTx(ctx, bundle.tx, insertRes1.Job.ID) + require.NoError(t, err) + require.Equal(t, rivertype.JobStateCancelled, updatedJob1.State) + + updatedJob2, err := bundle.client.JobGetTx(ctx, bundle.tx, insertRes2.Job.ID) + require.NoError(t, err) + require.Equal(t, rivertype.JobStateCancelled, updatedJob2.State) + }) + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + + endpoint, _ := setupEndpoint[jobCancelEndpoint](ctx, t) + + _, err := endpoint.Execute(ctx, &jobCancelRequest{JobIDs: []int64String{123}}) + requireAPIError(t, apierror.NewNotFoundJob(123), err) + }) +} + +func TestAPIHandlerJobGet(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + + endpoint, bundle := setupEndpoint[jobGetEndpoint](ctx, t) + + insertRes, err := bundle.client.InsertTx(ctx, bundle.tx, &noOpArgs{}, nil) + require.NoError(t, err) + + resp, err := endpoint.Execute(ctx, &jobGetRequest{JobID: insertRes.Job.ID}) + require.NoError(t, err) + require.Equal(t, insertRes.Job.ID, resp.ID) + }) + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + + endpoint, _ := setupEndpoint[jobGetEndpoint](ctx, t) + + _, err := endpoint.Execute(ctx, &jobGetRequest{JobID: 123}) + requireAPIError(t, apierror.NewNotFoundJob(123), err) + }) +} diff --git a/common_test.go b/common_test.go new file mode 100644 index 00000000..1cd43d80 --- /dev/null +++ b/common_test.go @@ -0,0 +1,61 @@ +package riverui + +import ( + "context" + "encoding/json" + "log/slog" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/riverdriver/riverpgxv5" +) + +type noOpArgs struct { + Name string `json:"name"` +} + +func (noOpArgs) Kind() string { return "noOp" } + +type noOpWorker struct { + river.WorkerDefaults[noOpArgs] +} + +func (w *noOpWorker) Work(_ context.Context, _ *river.Job[noOpArgs]) error { return nil } + +func insertOnlyClient(t *testing.T, logger *slog.Logger) *river.Client[pgx.Tx] { + t.Helper() + + workers := river.NewWorkers() + river.AddWorker(workers, &noOpWorker{}) + + client, err := river.NewClient(riverpgxv5.New(nil), &river.Config{ + Logger: logger, + Workers: workers, + }) + require.NoError(t, err) + + return client +} + +func mustMarshalJSON(t *testing.T, v any) []byte { + t.Helper() + + data, err := json.Marshal(v) + require.NoError(t, err) + return data +} + +// Requires that err is an equivalent API error to expectedErr. +// +// TError is a pointer to an API error type like *apierror.NotFound. +func requireAPIError[TError error](t *testing.T, expectedErr TError, err error) { + t.Helper() + + require.Error(t, err) + var apiErr TError + require.ErrorAs(t, err, &apiErr) + require.Equal(t, expectedErr, apiErr) +} diff --git a/docs/development.md b/docs/development.md index dd8c6366..ff4d8171 100644 --- a/docs/development.md +++ b/docs/development.md @@ -2,7 +2,7 @@ River UI consists of two apps: a Go backend API, and a TypeScript UI frontend. -### Migrate database +## Migrate database ```sh cp .env.sample .env @@ -14,7 +14,7 @@ $ go install github.com/riverqueue/river/cmd/river $ river migrate-up --database-url postgres://localhost/river-development ``` -### Go API +## Go API ```sh go build ./cmd/riverui && ./riverui @@ -24,7 +24,22 @@ By default it starts at http://localhost:8080. The API will need a build TypeScript UI in `ui/dist`, or you'll have to serve it separately (see below). -### TypeScript UI +## Run tests + +Raise test database: + +```sh +$ createdb river-test +$ river migrate-up --database-url postgres://localhost/river-test +``` + +Run tests: + +```sh +$ go test ./... +``` + +## TypeScript UI The UI lives in the `ui/` subdirectory. Go to it and install dependencies: diff --git a/go.mod b/go.mod index 34078b80..ba8d8788 100644 --- a/go.mod +++ b/go.mod @@ -10,19 +10,23 @@ require ( github.com/riverqueue/river/rivertype v0.7.0 github.com/rs/cors v1.10.0 github.com/samber/slog-http v1.0.0 + github.com/stretchr/testify v1.9.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/riverqueue/river/riverdriver v0.7.0 // indirect go.opentelemetry.io/otel v1.19.0 // indirect go.opentelemetry.io/otel/trace v1.19.0 // indirect golang.org/x/crypto v0.22.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/text v0.16.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) // replace github.com/riverqueue/river => ../river diff --git a/go.sum b/go.sum index 5658edac..422b8837 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,10 @@ github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -33,6 +37,8 @@ github.com/riverqueue/river/rivertype v0.7.0 h1:sqnl40ymCfT5DfstHsLbg35hddVccPkP github.com/riverqueue/river/rivertype v0.7.0/go.mod h1:nDd50b/mIdxR/ezQzGS/JiAhBPERA7tUIne21GdfspQ= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/rs/cors v1.10.0 h1:62NOS1h+r8p1mW6FM0FSB0exioXLhd/sh15KpjWBZ+8= github.com/rs/cors v1.10.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/samber/slog-http v1.0.0 h1:KjxyJm2lOsuWBt904A04qvrp+0ZvOfwDnk6jI8h7/5c= @@ -55,6 +61,8 @@ golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handler.go b/handler.go index 15fd7ca8..6d460525 100644 --- a/handler.go +++ b/handler.go @@ -1,6 +1,7 @@ package riverui import ( + "context" "errors" "fmt" "io/fs" @@ -10,19 +11,24 @@ import ( "strings" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" "github.com/riverqueue/river" + "github.com/riverqueue/riverui/internal/apiendpoint" "github.com/riverqueue/riverui/internal/db" "github.com/riverqueue/riverui/ui" ) +type DBTXWithBegin interface { + Begin(ctx context.Context) (pgx.Tx, error) + db.DBTX +} + // HandlerOpts are the options for creating a new Handler. type HandlerOpts struct { // Client is the River client to use for API requests. Client *river.Client[pgx.Tx] // DBPool is the database connection pool to use for API requests. - DBPool *pgxpool.Pool + DBPool DBTXWithBegin // Logger is the logger to use logging errors within the handler. Logger *slog.Logger // Prefix is the path prefix to use for the API and UI HTTP requests. @@ -71,20 +77,22 @@ func NewHandler(opts *HandlerOpts) (http.Handler, error) { fileServer := http.FileServer(httpFS) serveIndex := serveFileContents("index.html", httpFS) - handler := &apiHandler{ + apiBundle := apiBundle{ client: opts.Client, dbPool: opts.DBPool, logger: opts.Logger, queries: db.New(opts.DBPool), } + + handler := &apiHandler{apiBundle: apiBundle} prefix := opts.Prefix mux := http.NewServeMux() mux.HandleFunc("GET /api/jobs", handler.JobList) - mux.HandleFunc("POST /api/jobs/cancel", handler.JobCancel) + apiendpoint.Mount(mux, opts.Logger, &jobCancelEndpoint{apiBundle: apiBundle}) mux.HandleFunc("POST /api/jobs/delete", handler.JobDelete) mux.HandleFunc("POST /api/jobs/retry", handler.JobRetry) - mux.HandleFunc("GET /api/jobs/{id}", handler.JobGet) + apiendpoint.Mount(mux, opts.Logger, &jobGetEndpoint{apiBundle: apiBundle}) mux.HandleFunc("GET /api/queues", handler.QueueList) mux.HandleFunc("GET /api/queues/{name}", handler.QueueGet) mux.HandleFunc("PUT /api/queues/{name}/pause", handler.QueuePause) diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 00000000..ac49b7b8 --- /dev/null +++ b/handler_test.go @@ -0,0 +1,93 @@ +package riverui + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/riverui/internal/riverinternaltest" +) + +func TestNewHandlerIntegration(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + logger = riverinternaltest.Logger(t) + client = insertOnlyClient(t, logger) + tx = riverinternaltest.TestTx(ctx, t) + ) + + // + // Helper functions + // + + makeAPICall := func(t *testing.T, testCaseName, method, path string, payload []byte) { + t.Helper() + + t.Run(testCaseName, func(t *testing.T) { + logger := riverinternaltest.Logger(t) + + // Start a new savepoint so that the state of our test data stays + // prestine between API calls. + tx, err := tx.Begin(ctx) + require.NoError(t, err) + t.Cleanup(func() { tx.Rollback(ctx) }) + + handler, err := NewHandler(&HandlerOpts{ + Client: client, + DBPool: tx, + Logger: logger, + }) + require.NoError(t, err) + + var body io.Reader + if len(payload) > 0 { + body = bytes.NewBuffer(payload) + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(method, path, body) + + t.Logf("--> %s %s", method, path) + + handler.ServeHTTP(recorder, req) + + status := recorder.Result().StatusCode //nolint:bodyclose + t.Logf("Response status: %d", status) + + if status >= 200 && status < 300 { + return + } + + // Only print the body in the event of a problem because it may be + // quite sizable. + t.Logf("Response body: %s", recorder.Body.String()) + + require.FailNow(t, "Got non-OK status code making request", "Expected status >= 200 && < 300; got: %d", status) + }) + } + + makeURL := fmt.Sprintf + + // + // Test data + // + + insertRes, err := client.InsertTx(ctx, tx, &noOpArgs{}, nil) + require.NoError(t, err) + job := insertRes.Job + + // + // API calls + // + + makeAPICall(t, "JobCancel", http.MethodPost, makeURL("/api/jobs/cancel"), mustMarshalJSON(t, &jobCancelRequest{JobIDs: []int64String{int64String(job.ID)}})) + makeAPICall(t, "JobGet", http.MethodGet, makeURL("/api/jobs/%d", job.ID), nil) +} diff --git a/int64_string.go b/int64_string.go new file mode 100644 index 00000000..98568e8a --- /dev/null +++ b/int64_string.go @@ -0,0 +1,35 @@ +package riverui + +import ( + "errors" + "fmt" + "strconv" +) + +// int64String is an int64 type that marshals itself as a string and can +// unmarshal from a string to make sure that the use of the entire possible +// range of int64 is safe. +type int64String int64 + +func (i int64String) MarshalJSON() ([]byte, error) { + return []byte(`"` + strconv.FormatInt(int64(i), 10) + `"`), nil +} + +func (i *int64String) UnmarshalJSON(data []byte) error { + if len(data) < 1 { + return errors.New("can't unmarshal empty int64 string value") + } + + str := string(data) + if str[0] == '"' && len(str) > 1 { + str = str[1 : len(str)-1] + } + + parsedInt, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return fmt.Errorf("error parsing int64 string: %w", err) + } + + *i = int64String(parsedInt) + return nil +} diff --git a/int64_string_test.go b/int64_string_test.go new file mode 100644 index 00000000..2811a7bc --- /dev/null +++ b/int64_string_test.go @@ -0,0 +1,38 @@ +package riverui + +import ( + "encoding/json" + "math" + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestInt64String(t *testing.T) { + t.Parallel() + + t.Run("MarshalJSON", func(t *testing.T) { + t.Parallel() + + require.Equal(t, `"123"`, string(mustMarshalJSON(t, int64String(123)))) + }) + + t.Run("UnmarshalJSON", func(t *testing.T) { + t.Parallel() + + var myLargeInt int64String + + // With quotes. + require.NoError(t, json.Unmarshal([]byte(`"123"`), &myLargeInt)) + require.Equal(t, int64String(123), myLargeInt) + + // Without quotes. + require.NoError(t, json.Unmarshal([]byte(`123`), &myLargeInt)) + require.Equal(t, int64String(123), myLargeInt) + + // Integer larger than JSON's maximum number size. + require.NoError(t, json.Unmarshal([]byte(`"`+strconv.FormatInt(math.MaxInt64, 10)+`"`), &myLargeInt)) + require.Equal(t, int64String(math.MaxInt64), myLargeInt) + }) +} diff --git a/internal/apiendpoint/api_endpoint.go b/internal/apiendpoint/api_endpoint.go new file mode 100644 index 00000000..52288e0c --- /dev/null +++ b/internal/apiendpoint/api_endpoint.go @@ -0,0 +1,168 @@ +// Package apiendpoint provides a lightweight API framework for use with River +// UI. It lets API endpoints be defined, then mounted into an http.ServeMux. +package apiendpoint + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "time" + + "github.com/riverqueue/riverui/internal/apierror" +) + +// Endpoint is a struct that should be embedded on an API endpoint, and which +// provides a partial implementation for EndpointInterface. +type Endpoint[TReq any, TResp any] struct { + // Logger used to log information about endpoint execution. + logger *slog.Logger + + // Metadata about the endpoint. This is not available until SetMeta is + // invoked on the endpoint, which is usually done in Mount. + meta *EndpointMeta +} + +func (e *Endpoint[TReq, TResp]) SetLogger(logger *slog.Logger) { e.logger = logger } +func (e *Endpoint[TReq, TResp]) SetMeta(meta *EndpointMeta) { e.meta = meta } + +// EndpointInterface is an interface to an API endpoint. Some of it is +// implemented by an embedded Endpoint struct, and some of it should be +// implemented by the endpoint itself. +type EndpointInterface[TReq any, TResp any] interface { + // Execute executes the API endpoint. + // + // This should be implemented by each specific API endpoint. + Execute(ctx context.Context, req *TReq) (*TResp, error) + + // Meta returns metadata about an API endpoint, like the path it should be + // mounted at, and the status code it returns on success. + // + // This should be implemented by each specific API endpoint. + Meta() *EndpointMeta + + // SetLogger sets a logger on the endpoint. + // + // Implementation inherited from an embedded Endpoint struct. + SetLogger(logger *slog.Logger) + + // SetMeta sets metadata on an Endpoint struct after its extracted from a + // call to an endpoint's Meta function. + // + // Implementation inherited from an embedded Endpoint struct. + SetMeta(meta *EndpointMeta) +} + +// EndpointMeta is metadata about an API endpoint. +type EndpointMeta struct { + // Pattern is the API endpoint's HTTP method and path where it should be + // mounted, which is passed to http.ServeMux by Mount. It should start with + // a verb like `GET` or `POST`, and may contain Go 1.22 path variables like + // `{name}`, whose values should be extracted by an endpoint request + // struct's custom ExtractRaw implementation. + Pattern string + + // StatusCode is the status code to be set on a successful response. + StatusCode int +} + +func (m *EndpointMeta) validate() { + if m.Pattern == "" { + panic("Endpoint.Path is required") + } + if m.StatusCode == 0 { + panic("Endpoint.StatusCode is required") + } +} + +// Mount mounts an endpoint to a Go http.ServeMux. The logger is used to log +// information about endpoint execution. +func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndpoint EndpointInterface[TReq, TResp]) { + apiEndpoint.SetLogger(logger) + + meta := apiEndpoint.Meta() + meta.validate() // panic on problem + apiEndpoint.SetMeta(meta) + + mux.HandleFunc(meta.Pattern, func(w http.ResponseWriter, r *http.Request) { + executeAPIEndpoint(w, r, logger, meta, apiEndpoint.Execute) + }) +} + +func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, execute func(ctx context.Context, req *TReq) (*TResp, error)) { + ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) + defer cancel() + + // Run as much code as we can in a sub-function that can return an error. + // This is more convenient to write, but is also safer because unlike when + // writing errors to ResponseWriter, there's no danger of a missing return. + err := func() error { + var req TReq + if r.Method != http.MethodGet { + reqData, err := io.ReadAll(r.Body) + if err != nil { + return fmt.Errorf("error reading request body: %w", err) + } + + if len(reqData) > 0 { + if err := json.Unmarshal(reqData, &req); err != nil { + return apierror.NewBadRequest("Error unmarshaling request body: %s.", err) + } + } + } + + if rawExtractor, ok := any(&req).(RawExtractor); ok { + if err := rawExtractor.ExtractRaw(r); err != nil { + return err + } + } + + resp, err := execute(ctx, &req) + if err != nil { + return err + } + + respData, err := json.Marshal(resp) + if err != nil { + return fmt.Errorf("error marshaling response JSON: %w", err) + } + + w.WriteHeader(meta.StatusCode) + + if _, err := w.Write(respData); err != nil { + return fmt.Errorf("error writing response: %w", err) + } + + return nil + }() + if err != nil { + var apiErr apierror.Interface + if errors.As(err, &apiErr) { + // Logged at info level because API errors are normal. + logger.InfoContext(ctx, "API error response", slog.String("error", apiErr.Error())) + apiErr.Write(ctx, logger, w) + return + } + + if errors.Is(err, context.DeadlineExceeded) { + logger.ErrorContext(ctx, "request timeout", slog.String("error", err.Error())) + apierror.NewServiceUnavailable("Request timed out. Retrying the request might work.").Write(ctx, logger, w) + return + } + + // Internal server error. The error goes to logs but should not be + // included in the response in case there's something sensitive in + // the error string. + logger.ErrorContext(ctx, "error running API route", slog.String("error", err.Error())) + apierror.NewInternalServerError("Internal server error. Check logs for more information.").Write(ctx, logger, w) + } +} + +// RawExtractor is an interface that can be implemented by request structs that +// allows them to extract information from a raw request, like path values. +type RawExtractor interface { + ExtractRaw(r *http.Request) error +} diff --git a/internal/apiendpoint/api_endpoint_test.go b/internal/apiendpoint/api_endpoint_test.go new file mode 100644 index 00000000..168d4c20 --- /dev/null +++ b/internal/apiendpoint/api_endpoint_test.go @@ -0,0 +1,220 @@ +package apiendpoint + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/riverui/internal/apierror" + "github.com/riverqueue/riverui/internal/riverinternaltest" +) + +func TestMountAndServe(t *testing.T) { + t.Parallel() + + type testBundle struct { + recorder *httptest.ResponseRecorder + } + + setup := func(t *testing.T) (*http.ServeMux, *testBundle) { + t.Helper() + + var ( + logger = riverinternaltest.Logger(t) + mux = http.NewServeMux() + ) + + Mount(mux, logger, &getEndpoint{}) + Mount(mux, logger, &postEndpoint{}) + + return mux, &testBundle{ + recorder: httptest.NewRecorder(), + } + } + + t.Run("GetEndpointAndExtractRaw", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.", nil) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder) + }) + + t.Run("BodyIgnoredOnGet", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.", + bytes.NewBuffer(mustMarshalJSON(t, &getRequest{IgnoredJSONMessage: "Ignored hello."}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder) + }) + + t.Run("MethodNotAllowed", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/get-endpoint/Hello.", nil) + mux.ServeHTTP(bundle.recorder, req) + + // This error comes from net/http. + requireStatusAndResponse(t, http.StatusMethodNotAllowed, "Method Not Allowed\n", bundle.recorder) + }) + + t.Run("PostEndpoint", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder) + }) + + t.Run("APIError", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", nil) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Missing message value."}, bundle.recorder) + }) + + t.Run("InternalServerError", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusInternalServerError, &apierror.APIError{Message: "Internal server error. Check logs for more information."}, bundle.recorder) + }) +} + +func mustMarshalJSON(t *testing.T, v any) []byte { + t.Helper() + + data, err := json.Marshal(v) + require.NoError(t, err) + return data +} + +func mustUnmarshalJSON[T any](t *testing.T, data []byte) *T { + t.Helper() + + var val T + err := json.Unmarshal(data, &val) + require.NoError(t, err) + return &val +} + +// Shortcut for requiring an HTTP status code and a JSON-marshaled response +// equivalent to expectedResp. The important thing that is does is that in the +// event of a failure on status code, it prints the response body as additional +// context to help debug the problem. +func requireStatusAndJSONResponse[T any](t *testing.T, expectedStatusCode int, expectedResp *T, recorder *httptest.ResponseRecorder) { + t.Helper() + + require.Equal(t, expectedStatusCode, recorder.Result().StatusCode, "Unexpected status code; response body: %s", recorder.Body.String()) //nolint:bodyclose + require.Equal(t, expectedResp, mustUnmarshalJSON[T](t, recorder.Body.Bytes())) +} + +// Same as the above, but for a non-JSON response. +func requireStatusAndResponse(t *testing.T, expectedStatusCode int, expectedResp string, recorder *httptest.ResponseRecorder) { + t.Helper() + + require.Equal(t, expectedStatusCode, recorder.Result().StatusCode, "Unexpected status code; response body: %s", recorder.Body.String()) //nolint:bodyclose + require.Equal(t, expectedResp, recorder.Body.String()) +} + +// +// getEndpoint +// + +type getEndpoint struct { + Endpoint[getRequest, getResponse] +} + +func (*getEndpoint) Meta() *EndpointMeta { + return &EndpointMeta{ + Pattern: "GET /api/get-endpoint/{message}", + StatusCode: http.StatusOK, + } +} + +type getRequest struct { + IgnoredJSONMessage string `json:"ignored_json"` + Message string `json:"-"` +} + +func (req *getRequest) ExtractRaw(r *http.Request) error { + req.Message = r.PathValue("message") + return nil +} + +type getResponse struct { + Message string `json:"message"` +} + +func (a *getEndpoint) Execute(_ context.Context, req *getRequest) (*getResponse, error) { + // This branch never gets taken because request bodies are ignored on GET. + if req.IgnoredJSONMessage != "" { + return &getResponse{Message: req.IgnoredJSONMessage}, nil + } + + return &getResponse{Message: req.Message}, nil +} + +// +// postEndpoint +// + +type postEndpoint struct { + Endpoint[postRequest, postResponse] +} + +func (*postEndpoint) Meta() *EndpointMeta { + return &EndpointMeta{ + Pattern: "POST /api/post-endpoint", + StatusCode: http.StatusCreated, + } +} + +type postRequest struct { + MakeInternalError bool `json:"make_internal_error"` + Message string `json:"message"` +} + +type postResponse struct { + Message string `json:"message"` +} + +func (a *postEndpoint) Execute(_ context.Context, req *postRequest) (*postResponse, error) { + if req.MakeInternalError { + return nil, errors.New("an internal error occurred") + } + + if req.Message == "" { + return nil, apierror.NewBadRequest("Missing message value.") + } + + return &postResponse{Message: req.Message}, nil +} diff --git a/internal/apierror/api_error.go b/internal/apierror/api_error.go new file mode 100644 index 00000000..ac7bc257 --- /dev/null +++ b/internal/apierror/api_error.go @@ -0,0 +1,122 @@ +// Package apierror contains a variety of marshalable API errors that adhere to +// a unified error response convention. +package apierror + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" +) + +// APIError is a struct that's embedded on a more specific API error struct (as +// seen below), and which provides a JSON serialization and a wait to +// conveniently write itself to an HTTP response. +// +// APIErrorInterface should be used with errors.As instead of this struct. +type APIError struct { + // Message is a descriptive, human-friendly message indicating what went + // wrong. Try to make error messages as actionable as possible to help the + // caller easily fix what went wrong. + Message string `json:"message"` + + // StatusCode is the API error's HTTP status code. It's not marshaled to + // JSON, but determines how the error is written to a response. + StatusCode int `json:"-"` +} + +func (e *APIError) Error() string { return e.Message } + +// Write writes the API error to an HTTP response, writing to the given logger +// in case of a problem. +func (e *APIError) Write(ctx context.Context, logger *slog.Logger, w http.ResponseWriter) { + w.WriteHeader(e.StatusCode) + + respData, err := json.Marshal(e) + if err != nil { + logger.ErrorContext(ctx, "error marshaling API error", slog.String("error", err.Error())) + } + + if _, err := w.Write(respData); err != nil { + logger.ErrorContext(ctx, "error writing API error", slog.String("error", err.Error())) + } +} + +// Interface is an interface to an API error. This is needed for use with +// errors.As because APIError itself is emedded on another error struct, and +// won't be usable as an errors.As target. +type Interface interface { + Error() string + Write(ctx context.Context, logger *slog.Logger, w http.ResponseWriter) +} + +// +// BadRequest +// + +type BadRequest struct { + APIError +} + +func NewBadRequest(format string, a ...any) *BadRequest { + return &BadRequest{ + APIError: APIError{ + Message: fmt.Sprintf(format, a...), + StatusCode: http.StatusBadRequest, + }, + } +} + +// +// InternalServerError +// + +type InternalServerError struct { + APIError +} + +func NewInternalServerError(format string, a ...any) *InternalServerError { + return &InternalServerError{ + APIError: APIError{ + Message: fmt.Sprintf(format, a...), + StatusCode: http.StatusInternalServerError, + }, + } +} + +// +// NotFound +// + +type NotFound struct { + APIError +} + +func NewNotFound(format string, a ...any) *NotFound { + return &NotFound{ + APIError: APIError{ + Message: fmt.Sprintf(format, a...), + StatusCode: http.StatusNotFound, + }, + } +} + +func NewNotFoundJob(jobID int64) *NotFound { return NewNotFound("Job not found: %d.", jobID) } + +// +// ServiceUnavailable +// + +type ServiceUnavailable struct { + APIError +} + +func NewServiceUnavailable(format string, a ...any) *ServiceUnavailable { + return &ServiceUnavailable{ + APIError: APIError{ + Message: fmt.Sprintf(format, a...), + StatusCode: http.StatusServiceUnavailable, + }, + } +} diff --git a/internal/apierror/api_error_test.go b/internal/apierror/api_error_test.go new file mode 100644 index 00000000..060e8b8a --- /dev/null +++ b/internal/apierror/api_error_test.go @@ -0,0 +1,49 @@ +package apierror + +import ( + "context" + "encoding/json" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/riverui/internal/riverinternaltest" +) + +func TestAPIErrorJSON(t *testing.T) { + t.Parallel() + + require.Equal(t, + `{"message":"Bad request. Try sending JSON next time."}`, + string(mustMarshalJSON( + t, NewBadRequest("Bad request. Try sending JSON next time.")), + ), + ) +} + +func TestAPIErrorWrite(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + logger = riverinternaltest.Logger(t) + recorder = httptest.NewRecorder() + ) + + NewBadRequest("Bad request. Try sending JSON next time.").Write(ctx, logger, recorder) + + require.Equal(t, 400, recorder.Result().StatusCode) //nolint:bodyclose + require.Equal(t, + `{"message":"Bad request. Try sending JSON next time."}`, + recorder.Body.String(), + ) +} + +func mustMarshalJSON(t *testing.T, v any) []byte { + t.Helper() + + data, err := json.Marshal(v) + require.NoError(t, err) + return data +} diff --git a/internal/riverinternaltest/riverinternaltest.go b/internal/riverinternaltest/riverinternaltest.go new file mode 100644 index 00000000..a8673a41 --- /dev/null +++ b/internal/riverinternaltest/riverinternaltest.go @@ -0,0 +1,117 @@ +package riverinternaltest + +import ( + "context" + "errors" + "log/slog" + "os" + "sync" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/riverui/internal/riverinternaltest/slogtest" +) + +// Logger returns a logger suitable for use in tests. +// +// Defaults to informational verbosity. If env is set with `RIVER_DEBUG=true`, +// debug level verbosity is activated. +func Logger(tb testing.TB) *slog.Logger { + tb.Helper() + + if os.Getenv("RIVER_DEBUG") == "1" || os.Getenv("RIVER_DEBUG") == "true" { + return slogtest.NewLogger(tb, &slog.HandlerOptions{Level: slog.LevelDebug}) + } + + return slogtest.NewLogger(tb, nil) +} + +// A pool and mutex to protect it, lazily initialized by TestTx. Once open, this +// pool is never explicitly closed, instead closing implicitly as the package +// tests finish. +var ( + dbPool *pgxpool.Pool //nolint:gochecknoglobals + dbPoolMu sync.RWMutex //nolint:gochecknoglobals +) + +// TestTx starts a test transaction that's rolled back automatically as the test +// case is cleaning itself up. This can be used as a lighter weight alternative +// to `testdb.Manager` in components where it's not necessary to have many +// connections open simultaneously. +func TestTx(ctx context.Context, tb testing.TB) pgx.Tx { + tb.Helper() + + tryPool := func() *pgxpool.Pool { + dbPoolMu.RLock() + defer dbPoolMu.RUnlock() + return dbPool + } + + getPool := func() *pgxpool.Pool { + if dbPool := tryPool(); dbPool != nil { + return dbPool + } + + dbPoolMu.Lock() + defer dbPoolMu.Unlock() + + // Multiple goroutines may have passed the initial `nil` check on start + // up, so check once more to make sure pool hasn't been set yet. + if dbPool != nil { + return dbPool + } + + testDatabaseURL := os.Getenv("TEST_DATABASE_URL") + if testDatabaseURL == "" { + testDatabaseURL = "postgres://localhost/river-test" + } + + var err error + dbPool, err = pgxpool.New(ctx, testDatabaseURL) + require.NoError(tb, err) + + return dbPool + } + + tx, err := getPool().Begin(ctx) + require.NoError(tb, err) + + tb.Cleanup(func() { + err := tx.Rollback(ctx) + + if err == nil { + return + } + + // Try to look for an error on rollback because it does occasionally + // reveal a real problem in the way a test is written. However, allow + // tests to roll back their transaction early if they like, so ignore + // `ErrTxClosed`. + if errors.Is(err, pgx.ErrTxClosed) { + return + } + + // In case of a cancelled context during a database operation, which + // happens in many tests, pgx seems to not only roll back the + // transaction, but closes the connection, and returns this error on + // rollback. Allow this error since it's hard to prevent it in our flows + // that use contexts heavily. + if err.Error() == "conn closed" { + return + } + + // Similar to the above, but a newly appeared error that wraps the + // above. As far as I can tell, no error variables are available to use + // with `errors.Is`. + if err.Error() == "failed to deallocate cached statement(s): conn closed" { + return + } + + require.NoError(tb, err) + }) + + return tx +} diff --git a/internal/riverinternaltest/slogtest/slog_test_handler.go b/internal/riverinternaltest/slogtest/slog_test_handler.go new file mode 100644 index 00000000..7c1cfed4 --- /dev/null +++ b/internal/riverinternaltest/slogtest/slog_test_handler.go @@ -0,0 +1,86 @@ +package slogtest + +import ( + "bytes" + "context" + "io" + "log/slog" + "sync" + "testing" +) + +// NewLogger returns a new slog text logger that outputs to `t.Log`. This helps +// keep test output better formatted, and allows it to be differentiated in case +// of a failure during a parallel test suite run. +func NewLogger(tb testing.TB, opts *slog.HandlerOptions) *slog.Logger { + tb.Helper() + + var buf bytes.Buffer + + textHandler := slog.NewTextHandler(&buf, opts) + + return slog.New(&slogTestHandler{ + buf: &buf, + inner: textHandler, + mu: &sync.Mutex{}, + tb: tb, + }) +} + +type slogTestHandler struct { + buf *bytes.Buffer + inner slog.Handler + mu *sync.Mutex + tb testing.TB +} + +func (b *slogTestHandler) Enabled(ctx context.Context, level slog.Level) bool { + return b.inner.Enabled(ctx, level) +} + +func (b *slogTestHandler) Handle(ctx context.Context, rec slog.Record) error { + b.mu.Lock() + defer b.mu.Unlock() + + err := b.inner.Handle(ctx, rec) + if err != nil { + return err + } + + output, err := io.ReadAll(b.buf) + if err != nil { + return err + } + + // t.Log adds its own newline, so trim the one from slog. + output = bytes.TrimSuffix(output, []byte("\n")) + + // Register as a helper, but unfortunately still not enough to fix the + // reported callsite of the log line and it'll still show `logger.go` from + // slog's internals. See explanation and discussion here: + // + // https://github.com/neilotoole/slogt#deficiency + b.tb.Helper() + + b.tb.Log(string(output)) + + return nil +} + +func (b *slogTestHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &slogTestHandler{ + buf: b.buf, + inner: b.inner.WithAttrs(attrs), + mu: b.mu, + tb: b.tb, + } +} + +func (b *slogTestHandler) WithGroup(name string) slog.Handler { + return &slogTestHandler{ + buf: b.buf, + inner: b.inner.WithGroup(name), + mu: b.mu, + tb: b.tb, + } +} diff --git a/internal/riverinternaltest/slogtest/slog_test_handler_test.go b/internal/riverinternaltest/slogtest/slog_test_handler_test.go new file mode 100644 index 00000000..94749659 --- /dev/null +++ b/internal/riverinternaltest/slogtest/slog_test_handler_test.go @@ -0,0 +1,58 @@ +package slogtest + +import ( + "log/slog" + "sync" + "testing" +) + +// This test doesn't assert anything due to the inherent difficulty of testing +// this test helper, but it can be run with `-test.v` to observe that it's +// working correctly. +func TestSlogTestHandler_levels(t *testing.T) { + t.Parallel() + + testCases := []struct { + desc string + level slog.Level + }{ + {desc: "Debug", level: slog.LevelDebug}, + {desc: "Info", level: slog.LevelInfo}, + {desc: "Warn", level: slog.LevelWarn}, + {desc: "Error", level: slog.LevelError}, + } + for _, tt := range testCases { + tt := tt + t.Run(tt.desc, func(t *testing.T) { + t.Parallel() + + logger := NewLogger(t, &slog.HandlerOptions{Level: tt.level}) + + logger.Debug("debug message") + logger.Info("info message") + logger.Warn("warn message") + logger.Error("error message") + }) + } +} + +func TestSlogTestHandler_stress(t *testing.T) { + t.Parallel() + + var ( + logger = NewLogger(t, nil) + wg sync.WaitGroup + ) + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + for j := 0; j < 100; j++ { + logger.Info("message", "key", "value") + } + wg.Done() + }() + } + + wg.Wait() +} diff --git a/internal/util/dbutil/db_util.go b/internal/util/dbutil/db_util.go new file mode 100644 index 00000000..4e1f3226 --- /dev/null +++ b/internal/util/dbutil/db_util.go @@ -0,0 +1,44 @@ +package dbutil + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" +) + +type TxBegin interface { + Begin(ctx context.Context) (pgx.Tx, error) +} + +// WithTx starts and commits a transaction on a driver executor around +// the given function, allowing the return of a generic value. +func WithTx(ctx context.Context, txBegin TxBegin, innerFunc func(ctx context.Context, tx pgx.Tx) error) error { + _, err := WithTxV(ctx, txBegin, func(ctx context.Context, tx pgx.Tx) (struct{}, error) { + return struct{}{}, innerFunc(ctx, tx) + }) + return err +} + +// WithTxV starts and commits a transaction on a driver executor around +// the given function, allowing the return of a generic value. +func WithTxV[T any](ctx context.Context, txBegin TxBegin, innerFunc func(ctx context.Context, exec pgx.Tx) (T, error)) (T, error) { + var defaultRes T + + tx, err := txBegin.Begin(ctx) + if err != nil { + return defaultRes, fmt.Errorf("error beginning transaction: %w", err) + } + defer tx.Rollback(ctx) + + res, err := innerFunc(ctx, tx) + if err != nil { + return defaultRes, err + } + + if err := tx.Commit(ctx); err != nil { + return defaultRes, fmt.Errorf("error committing transaction: %w", err) + } + + return res, nil +} diff --git a/internal/util/dbutil/db_util_test.go b/internal/util/dbutil/db_util_test.go new file mode 100644 index 00000000..2c95db07 --- /dev/null +++ b/internal/util/dbutil/db_util_test.go @@ -0,0 +1,42 @@ +package dbutil + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/riverui/internal/riverinternaltest" +) + +func TestWithTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tx := riverinternaltest.TestTx(ctx, t) + + err := WithTx(ctx, tx, func(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, "SELECT 1") + require.NoError(t, err) + + return nil + }) + require.NoError(t, err) +} + +func TestWithTxV(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tx := riverinternaltest.TestTx(ctx, t) + + ret, err := WithTxV(ctx, tx, func(ctx context.Context, tx pgx.Tx) (int, error) { + _, err := tx.Exec(ctx, "SELECT 1") + require.NoError(t, err) + + return 7, nil + }) + require.NoError(t, err) + require.Equal(t, 7, ret) +}