diff --git a/client_test.go b/client_test.go index 9d2af82b..31864be5 100644 --- a/client_test.go +++ b/client_test.go @@ -710,7 +710,7 @@ func Test_Client(t *testing.T) { &overridableJobMiddleware{ workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { middlewareCalled = true - require.Equal(t, `{"name": "inserted name"}`, string(job.EncodedArgs)) + require.JSONEq(t, `{"name": "inserted name"}`, string(job.EncodedArgs)) job.EncodedArgs = []byte(`{"name": "middleware name"}`) return doInner(ctx) }, diff --git a/internal/dblist/db_list.go b/internal/dblist/db_list.go index a3529ca6..64e76b1e 100644 --- a/internal/dblist/db_list.go +++ b/internal/dblist/db_list.go @@ -3,7 +3,6 @@ package dblist import ( "context" "errors" - "fmt" "strings" "github.com/riverqueue/river/riverdriver" @@ -11,17 +10,6 @@ import ( "github.com/riverqueue/river/rivertype" ) -const jobList = `-- name: JobList :many -SELECT - %s -FROM - river_job -%s -ORDER BY - %s -LIMIT @count::integer -` - type SortOrder int const ( @@ -47,7 +35,7 @@ type JobListParams struct { } func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListParams) ([]*rivertype.JobRow, error) { - var conditionsBuilder strings.Builder + var whereBuilder strings.Builder orderBy := make([]JobListOrderBy, len(params.OrderBy)) for i, o := range params.OrderBy { @@ -62,41 +50,44 @@ func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListPara namedArgs = make(map[string]any) } - writeWhereOrAnd := func() { - if conditionsBuilder.Len() == 0 { - conditionsBuilder.WriteString("WHERE\n ") - } else { - conditionsBuilder.WriteString("\n AND ") + writeAndAfterFirst := func() { + if whereBuilder.Len() != 0 { + whereBuilder.WriteString("\n AND ") } } if len(params.Kinds) > 0 { - writeWhereOrAnd() - conditionsBuilder.WriteString("kind = any(@kinds::text[])") + writeAndAfterFirst() + whereBuilder.WriteString("kind = any(@kinds::text[])") namedArgs["kinds"] = params.Kinds } if len(params.Queues) > 0 { - writeWhereOrAnd() - conditionsBuilder.WriteString("queue = any(@queues::text[])") + writeAndAfterFirst() + whereBuilder.WriteString("queue = any(@queues::text[])") namedArgs["queues"] = params.Queues } if len(params.States) > 0 { - writeWhereOrAnd() - conditionsBuilder.WriteString("state = any(@states::river_job_state[])") + writeAndAfterFirst() + whereBuilder.WriteString("state = any(@states::river_job_state[])") namedArgs["states"] = sliceutil.Map(params.States, func(s rivertype.JobState) string { return string(s) }) } if params.Conditions != "" { - writeWhereOrAnd() - conditionsBuilder.WriteString(params.Conditions) + writeAndAfterFirst() + whereBuilder.WriteString(params.Conditions) + } + + // A condition of some kind is needed, so given no others write one that'll + // always return true. + if whereBuilder.Len() < 1 { + whereBuilder.WriteString("1") } if params.LimitCount < 1 { return nil, errors.New("required parameter 'Count' in JobList must be greater than zero") } - namedArgs["count"] = params.LimitCount if len(params.OrderBy) == 0 { return nil, errors.New("sort order is required") @@ -116,7 +107,10 @@ func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListPara } } - sql := fmt.Sprintf(jobList, exec.JobListFields(), conditionsBuilder.String(), orderByBuilder.String()) - - return exec.JobList(ctx, sql, namedArgs) + return exec.JobList(ctx, &riverdriver.JobListParams{ + Max: params.LimitCount, + NamedArgs: namedArgs, + OrderByClause: orderByBuilder.String(), + WhereClause: whereBuilder.String(), + }) } diff --git a/internal/riverinternaltest/riverdrivertest/riverdrivertest.go b/internal/riverinternaltest/riverdrivertest/riverdrivertest.go index fbb77183..86fbfd7d 100644 --- a/internal/riverinternaltest/riverdrivertest/riverdrivertest.go +++ b/internal/riverinternaltest/riverdrivertest/riverdrivertest.go @@ -1281,11 +1281,15 @@ func Exercise[TTx any](ctx context.Context, t *testing.T, UniqueStates: 0xFF, }) - fetchedJobs, err := exec.JobList( - ctx, - fmt.Sprintf("SELECT %s FROM river_job WHERE id = @job_id_123", exec.JobListFields()), - map[string]any{"job_id_123": job.ID}, - ) + // Does not match predicate (makes sure where clause is working). + _ = testfactory.Job(ctx, t, exec, &testfactory.JobOpts{}) + + fetchedJobs, err := exec.JobList(ctx, &riverdriver.JobListParams{ + Max: 100, + NamedArgs: map[string]any{"job_id_123": job.ID}, + OrderByClause: "id", + WhereClause: "id = @job_id_123", + }) require.NoError(t, err) require.Len(t, fetchedJobs, 1) @@ -1316,36 +1320,29 @@ func Exercise[TTx any](ctx context.Context, t *testing.T, job2 := testfactory.Job(ctx, t, exec, &testfactory.JobOpts{Kind: ptrutil.Ptr("test_kind2")}) { - fetchedJobs, err := exec.JobList( - ctx, - fmt.Sprintf("SELECT %s FROM river_job WHERE kind = @kind", exec.JobListFields()), - map[string]any{"kind": job1.Kind}, - ) + fetchedJobs, err := exec.JobList(ctx, &riverdriver.JobListParams{ + Max: 100, + NamedArgs: map[string]any{"kind": job1.Kind}, + OrderByClause: "id", + WhereClause: "kind = @kind", + }) require.NoError(t, err) require.Len(t, fetchedJobs, 1) } { - fetchedJobs, err := exec.JobList( - ctx, - fmt.Sprintf("SELECT %s FROM river_job WHERE kind = any(@kind::text[])", exec.JobListFields()), - map[string]any{"kind": []string{job1.Kind, job2.Kind}}, - ) + fetchedJobs, err := exec.JobList(ctx, &riverdriver.JobListParams{ + Max: 100, + NamedArgs: map[string]any{"kind": []string{job1.Kind, job2.Kind}}, + OrderByClause: "id", + WhereClause: "kind = any(@kind::text[])", + }) require.NoError(t, err) require.Len(t, fetchedJobs, 2) } }) }) - t.Run("JobListFields", func(t *testing.T) { - t.Parallel() - - exec, _ := setup(ctx, t) - - require.Equal(t, "id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states", - exec.JobListFields()) - }) - t.Run("JobRescueMany", func(t *testing.T) { t.Parallel() diff --git a/riverdriver/river_driver_interface.go b/riverdriver/river_driver_interface.go index f6d8ff5a..50884cae 100644 --- a/riverdriver/river_driver_interface.go +++ b/riverdriver/river_driver_interface.go @@ -118,8 +118,7 @@ type Executor interface { JobInsertFastMany(ctx context.Context, params []*JobInsertFastParams) ([]*JobInsertFastResult, error) JobInsertFastManyNoReturning(ctx context.Context, params []*JobInsertFastParams) (int, error) JobInsertFull(ctx context.Context, params *JobInsertFullParams) (*rivertype.JobRow, error) - JobList(ctx context.Context, query string, namedArgs map[string]any) ([]*rivertype.JobRow, error) - JobListFields() string + JobList(ctx context.Context, params *JobListParams) ([]*rivertype.JobRow, error) JobRescueMany(ctx context.Context, params *JobRescueManyParams) (*struct{}, error) JobRetry(ctx context.Context, id int64) (*rivertype.JobRow, error) JobSchedule(ctx context.Context, params *JobScheduleParams) ([]*JobScheduleResult, error) @@ -290,6 +289,13 @@ type JobInsertFullParams struct { UniqueStates byte } +type JobListParams struct { + Max int32 + NamedArgs map[string]any + OrderByClause string + WhereClause string +} + type JobRescueManyParams struct { ID []int64 Error [][]byte diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/river_job.sql.go b/riverdriver/riverdatabasesql/internal/dbsqlc/river_job.sql.go index 429ff2f8..16446ad3 100644 --- a/riverdriver/riverdatabasesql/internal/dbsqlc/river_job.sql.go +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/river_job.sql.go @@ -91,7 +91,7 @@ func (q *Queries) JobCancel(ctx context.Context, db DBTX, arg *JobCancelParams) const jobCountByState = `-- name: JobCountByState :one SELECT count(*) -FROM river_job +FROM /* TEMPLATE: schema */river_job WHERE state = $1 ` @@ -825,6 +825,56 @@ func (q *Queries) JobInsertFull(ctx context.Context, db DBTX, arg *JobInsertFull return &i, err } +const jobList = `-- name: JobList :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM river_job +WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ +ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ +LIMIT $1::int +` + +func (q *Queries) JobList(ctx context.Context, db DBTX, max int32) ([]*RiverJob, error) { + rows, err := db.QueryContext(ctx, jobList, max) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + pq.Array(&i.AttemptedBy), + &i.CreatedAt, + pq.Array(&i.Errors), + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + pq.Array(&i.Tags), + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const jobRescueMany = `-- name: JobRescueMany :exec UPDATE river_job SET diff --git a/riverdriver/riverdatabasesql/river_database_sql_driver.go b/riverdriver/riverdatabasesql/river_database_sql_driver.go index 34918e19..66c8df14 100644 --- a/riverdriver/riverdatabasesql/river_database_sql_driver.go +++ b/riverdriver/riverdatabasesql/river_database_sql_driver.go @@ -19,12 +19,12 @@ import ( "time" "github.com/jackc/pgx/v5/pgtype" - "github.com/lib/pq" "github.com/riverqueue/river/internal/dbunique" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverdatabasesql/internal/dbsqlc" "github.com/riverqueue/river/riverdriver/riverdatabasesql/internal/pgtypealias" + "github.com/riverqueue/river/rivershared/sqlctemplate" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivershared/util/valutil" "github.com/riverqueue/river/rivertype" @@ -35,7 +35,8 @@ var migrationFS embed.FS // Driver is an implementation of riverdriver.Driver for database/sql. type Driver struct { - dbPool *sql.DB + dbPool *sql.DB + replacer *sqlctemplate.Replacer } // New returns a new database/sql River driver for use with River. @@ -44,11 +45,14 @@ type Driver struct { // configured to use the schema specified in the client's Schema field. The pool // must not be closed while associated River objects are running. func New(dbPool *sql.DB) *Driver { - return &Driver{dbPool: dbPool} + return &Driver{ + dbPool: dbPool, + replacer: sqlctemplate.NewReplacer(), + } } func (d *Driver) GetExecutor() riverdriver.Executor { - return &Executor{d.dbPool, d.dbPool} + return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, d.replacer}, d} } func (d *Driver) GetListener() riverdriver.Listener { panic(riverdriver.ErrNotImplemented) } @@ -63,12 +67,13 @@ func (d *Driver) HasPool() bool { return d.dbPool != nil } func (d *Driver) SupportsListener() bool { return false } func (d *Driver) UnwrapExecutor(tx *sql.Tx) riverdriver.ExecutorTx { - return &ExecutorTx{Executor: Executor{nil, tx}, tx: tx} + return &ExecutorTx{Executor: Executor{nil, templateReplaceWrapper{tx, d.replacer}, d}, tx: tx} } type Executor struct { dbPool *sql.DB - dbtx dbsqlc.DBTX + dbtx templateReplaceWrapper + driver *Driver } func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { @@ -76,7 +81,7 @@ func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { if err != nil { return nil, err } - return &ExecutorTx{Executor: Executor{nil, tx}, tx: tx}, nil + return &ExecutorTx{Executor: Executor{nil, templateReplaceWrapper{tx, e.driver.replacer}, e.driver}, tx: tx}, nil } func (e *Executor) ColumnExists(ctx context.Context, tableName, columnName string) (bool, error) { @@ -351,50 +356,22 @@ func (e *Executor) JobInsertFull(ctx context.Context, params *riverdriver.JobIns return jobRowFromInternal(job) } -func (e *Executor) JobList(ctx context.Context, query string, namedArgs map[string]any) ([]*rivertype.JobRow, error) { - query, err := replaceNamed(query, namedArgs) +func (e *Executor) JobList(ctx context.Context, params *riverdriver.JobListParams) ([]*rivertype.JobRow, error) { + whereClause, err := replaceNamed(params.WhereClause, params.NamedArgs) if err != nil { return nil, err } - rows, err := e.dbtx.QueryContext(ctx, query) + ctx = sqlctemplate.WithTemplates(ctx, map[string]sqlctemplate.Replacement{ + "order_by_clause": {Value: params.OrderByClause}, + "where_clause": {Value: whereClause}, + }, nil) // named params not passed because they've already been replaced above + + jobs, err := dbsqlc.New().JobList(ctx, e.dbtx, params.Max) if err != nil { - return nil, err - } - defer rows.Close() - - var items []*dbsqlc.RiverJob - for rows.Next() { - var i dbsqlc.RiverJob - if err := rows.Scan( - &i.ID, - &i.Args, - &i.Attempt, - &i.AttemptedAt, - pq.Array(&i.AttemptedBy), - &i.CreatedAt, - pq.Array(&i.Errors), - &i.FinalizedAt, - &i.Kind, - &i.MaxAttempts, - &i.Metadata, - &i.Priority, - &i.Queue, - &i.State, - &i.ScheduledAt, - pq.Array(&i.Tags), - &i.UniqueKey, - &i.UniqueStates, - ); err != nil { - return nil, err - } - items = append(items, &i) - } - if err := rows.Err(); err != nil { return nil, interpretError(err) } - - return mapSliceError(items, jobRowFromInternal) + return mapSliceError(jobs, jobRowFromInternal) } func escapeSinglePostgresValue(value any) string { @@ -489,10 +466,6 @@ func replaceNamed(query string, namedArgs map[string]any) (string, error) { return query, nil } -func (e *Executor) JobListFields() string { - return "id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states" -} - func (e *Executor) JobRescueMany(ctx context.Context, params *riverdriver.JobRescueManyParams) (*struct{}, error) { err := dbsqlc.New().JobRescueMany(ctx, e.dbtx, &dbsqlc.JobRescueManyParams{ ID: params.ID, @@ -846,7 +819,7 @@ type ExecutorTx struct { } func (t *ExecutorTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { - return (&ExecutorSubTx{Executor: Executor{nil, t.tx}, savepointNum: 0, single: &singleTransaction{}, tx: t.tx}).Begin(ctx) + return (&ExecutorSubTx{Executor: Executor{nil, templateReplaceWrapper{t.tx, t.driver.replacer}, t.driver}, savepointNum: 0, single: &singleTransaction{}, tx: t.tx}).Begin(ctx) } func (t *ExecutorTx) Commit(ctx context.Context) error { @@ -878,7 +851,7 @@ func (t *ExecutorSubTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, erro if err != nil { return nil, err } - return &ExecutorSubTx{Executor: Executor{nil, t.tx}, savepointNum: nextSavepointNum, single: &singleTransaction{parent: t.single}, tx: t.tx}, nil + return &ExecutorSubTx{Executor: Executor{nil, templateReplaceWrapper{t.tx, t.driver.replacer}, t.driver}, savepointNum: nextSavepointNum, single: &singleTransaction{parent: t.single}, tx: t.tx}, nil } func (t *ExecutorSubTx) Commit(ctx context.Context) error { @@ -944,6 +917,31 @@ func (t *singleTransaction) setDone() { } } +type templateReplaceWrapper struct { + dbtx dbsqlc.DBTX + replacer *sqlctemplate.Replacer +} + +func (w templateReplaceWrapper) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.ExecContext(ctx, sql, args...) +} + +func (w templateReplaceWrapper) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) { + sql, _ = w.replacer.Run(ctx, sql, nil) + return w.dbtx.PrepareContext(ctx, sql) +} + +func (w templateReplaceWrapper) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.QueryContext(ctx, sql, args...) +} + +func (w templateReplaceWrapper) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *sql.Row { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.QueryRowContext(ctx, sql, args...) +} + func jobRowFromInternal(internal *dbsqlc.RiverJob) (*rivertype.JobRow, error) { var attemptedAt *time.Time if internal.AttemptedAt != nil { diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/river_job.sql b/riverdriver/riverpgxv5/internal/dbsqlc/river_job.sql index a6be7556..aeb057fe 100644 --- a/riverdriver/riverpgxv5/internal/dbsqlc/river_job.sql +++ b/riverdriver/riverpgxv5/internal/dbsqlc/river_job.sql @@ -82,7 +82,7 @@ FROM updated_job; -- name: JobCountByState :one SELECT count(*) -FROM river_job +FROM /* TEMPLATE: schema */river_job WHERE state = @state; -- name: JobDelete :one @@ -315,6 +315,12 @@ INSERT INTO river_job( @unique_states ) RETURNING *; +-- name: JobList :many +SELECT * +FROM river_job +WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ +ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ +LIMIT @max::int; -- Run by the rescuer to queue for retry or discard depending on job state. -- name: JobRescueMany :exec diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/river_job.sql.go b/riverdriver/riverpgxv5/internal/dbsqlc/river_job.sql.go index 4cd0810c..68233263 100644 --- a/riverdriver/riverpgxv5/internal/dbsqlc/river_job.sql.go +++ b/riverdriver/riverpgxv5/internal/dbsqlc/river_job.sql.go @@ -90,7 +90,7 @@ func (q *Queries) JobCancel(ctx context.Context, db DBTX, arg *JobCancelParams) const jobCountByState = `-- name: JobCountByState :one SELECT count(*) -FROM river_job +FROM /* TEMPLATE: schema */river_job WHERE state = $1 ` @@ -809,6 +809,53 @@ func (q *Queries) JobInsertFull(ctx context.Context, db DBTX, arg *JobInsertFull return &i, err } +const jobList = `-- name: JobList :many +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states +FROM river_job +WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ +ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ +LIMIT $1::int +` + +func (q *Queries) JobList(ctx context.Context, db DBTX, max int32) ([]*RiverJob, error) { + rows, err := db.Query(ctx, jobList, max) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverJob + for rows.Next() { + var i RiverJob + if err := rows.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + &i.UniqueKey, + &i.UniqueStates, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const jobRescueMany = `-- name: JobRescueMany :exec UPDATE river_job SET diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver.go b/riverdriver/riverpgxv5/river_pgx_v5_driver.go index c446c673..755e9655 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver.go @@ -17,6 +17,7 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/puddle/v2" @@ -24,6 +25,7 @@ import ( "github.com/riverqueue/river/internal/dbunique" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverpgxv5/internal/dbsqlc" + "github.com/riverqueue/river/rivershared/sqlctemplate" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" ) @@ -31,9 +33,15 @@ import ( //go:embed migration/*/*.sql var migrationFS embed.FS +type Config struct { + Schema string +} + // Driver is an implementation of riverdriver.Driver for Pgx v5. type Driver struct { - dbPool *pgxpool.Pool + config *Config + dbPool *pgxpool.Pool + replacer *sqlctemplate.Replacer } // New returns a new Pgx v5 River driver for use with River. @@ -49,10 +57,24 @@ type Driver struct { // in testing so that inserts can be performed and verified on a test // transaction that will be rolled back. func New(dbPool *pgxpool.Pool) *Driver { - return &Driver{dbPool: dbPool} + return &Driver{ + config: &Config{}, + dbPool: dbPool, + replacer: sqlctemplate.NewReplacer(), + } } -func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool} } +func NewConfig(dbPool *pgxpool.Pool, config *Config) *Driver { + return &Driver{ + config: config, + dbPool: dbPool, + replacer: sqlctemplate.NewReplacer(), + } +} + +func (d *Driver) GetExecutor() riverdriver.Executor { + return &Executor{templateReplaceWrapper{d.dbPool, d.replacer}, d} +} func (d *Driver) GetListener() riverdriver.Listener { return &Listener{dbPool: d.dbPool} } func (d *Driver) GetMigrationFS(line string) fs.FS { if line == riverdriver.MigrationLineMain { @@ -65,14 +87,12 @@ func (d *Driver) HasPool() bool { return d.dbPool != nil } func (d *Driver) SupportsListener() bool { return true } func (d *Driver) UnwrapExecutor(tx pgx.Tx) riverdriver.ExecutorTx { - return &ExecutorTx{Executor: Executor{tx}, tx: tx} + return &ExecutorTx{Executor: Executor{templateReplaceWrapper{tx, d.replacer}, d}, tx: tx} } type Executor struct { - dbtx interface { - dbsqlc.DBTX - Begin(ctx context.Context) (pgx.Tx, error) - } + dbtx templateReplaceWrapper + driver *Driver } func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { @@ -80,7 +100,7 @@ func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { if err != nil { return nil, err } - return &ExecutorTx{Executor: Executor{tx}, tx: tx}, nil + return &ExecutorTx{Executor: Executor{templateReplaceWrapper{tx, e.driver.replacer}, e.driver}, tx: tx}, nil } func (e *Executor) ColumnExists(ctx context.Context, tableName, columnName string) (bool, error) { @@ -113,8 +133,19 @@ func (e *Executor) JobCancel(ctx context.Context, params *riverdriver.JobCancelP return jobRowFromInternal(job) } +func (e *Executor) schemaTemplateParam(ctx context.Context) context.Context { + var schemaPrefix string + if e.driver.config.Schema != "" { + schemaPrefix = e.driver.config.Schema + "." + } + + return sqlctemplate.WithTemplates(ctx, map[string]sqlctemplate.Replacement{ + "schema": {Value: schemaPrefix}, + }, nil) +} + func (e *Executor) JobCountByState(ctx context.Context, state rivertype.JobState) (int, error) { - numJobs, err := dbsqlc.New().JobCountByState(ctx, e.dbtx, dbsqlc.RiverJobState(state)) + numJobs, err := dbsqlc.New().JobCountByState(e.schemaTemplateParam(ctx), e.dbtx, dbsqlc.RiverJobState(state)) if err != nil { return 0, err } @@ -337,49 +368,17 @@ func (e *Executor) JobInsertFull(ctx context.Context, params *riverdriver.JobIns return jobRowFromInternal(job) } -func (e *Executor) JobList(ctx context.Context, query string, namedArgs map[string]any) ([]*rivertype.JobRow, error) { - rows, err := e.dbtx.Query(ctx, query, pgx.NamedArgs(namedArgs)) +func (e *Executor) JobList(ctx context.Context, params *riverdriver.JobListParams) ([]*rivertype.JobRow, error) { + ctx = sqlctemplate.WithTemplates(ctx, map[string]sqlctemplate.Replacement{ + "order_by_clause": {Value: params.OrderByClause}, + "where_clause": {Value: params.WhereClause}, + }, params.NamedArgs) + + jobs, err := dbsqlc.New().JobList(ctx, e.dbtx, params.Max) if err != nil { - return nil, err - } - defer rows.Close() - - var items []*dbsqlc.RiverJob - for rows.Next() { - var i dbsqlc.RiverJob - if err := rows.Scan( - &i.ID, - &i.Args, - &i.Attempt, - &i.AttemptedAt, - &i.AttemptedBy, - &i.CreatedAt, - &i.Errors, - &i.FinalizedAt, - &i.Kind, - &i.MaxAttempts, - &i.Metadata, - &i.Priority, - &i.Queue, - &i.State, - &i.ScheduledAt, - &i.Tags, - &i.UniqueKey, - &i.UniqueStates, - ); err != nil { - return nil, err - } - items = append(items, &i) - } - if err := rows.Err(); err != nil { return nil, interpretError(err) } - - return mapSliceError(items, jobRowFromInternal) -} - -func (e *Executor) JobListFields() string { - return "id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags, unique_key, unique_states" + return mapSliceError(jobs, jobRowFromInternal) } func (e *Executor) JobRescueMany(ctx context.Context, params *riverdriver.JobRescueManyParams) (*struct{}, error) { @@ -814,6 +813,37 @@ func (l *Listener) WaitForNotification(ctx context.Context) (*riverdriver.Notifi }, nil } +type templateReplaceWrapper struct { + dbtx interface { + dbsqlc.DBTX + Begin(ctx context.Context) (pgx.Tx, error) + } + replacer *sqlctemplate.Replacer +} + +func (w templateReplaceWrapper) Begin(ctx context.Context) (pgx.Tx, error) { + return w.dbtx.Begin(ctx) +} + +func (w templateReplaceWrapper) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.Exec(ctx, sql, args...) +} + +func (w templateReplaceWrapper) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.Query(ctx, sql, args...) +} + +func (w templateReplaceWrapper) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.QueryRow(ctx, sql, args...) +} + +func (w templateReplaceWrapper) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + return w.dbtx.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + func interpretError(err error) error { if errors.Is(err, puddle.ErrClosedPool) { return riverdriver.ErrClosedPool diff --git a/rivershared/sqlctemplate/sqlc_template.go b/rivershared/sqlctemplate/sqlc_template.go new file mode 100644 index 00000000..046126a9 --- /dev/null +++ b/rivershared/sqlctemplate/sqlc_template.go @@ -0,0 +1,225 @@ +// Package sqlctemplate provides a way of making arbitrary text replacement in +// sqlc queries which normally only allow parameters which are in places valid +// in a prepared statement. For example, it can be used to insert a schema name +// as a prefix to tables referenced in sqlc, which is otherwise impossible. +// +// Replacement is carried out from within invocations of sqlc's generated DBTX +// interface, after sqlc generated code runs, but before queries are executed. +// This is accomplished by implementing DBTX, calling Replacer.Run from within +// them, and injecting parameters in with WithTemplates (which is unfortunately +// the only way of injecting them). +// +// Templates are modeled as SQL comments so that they're still parseable as +// valid SQL. An example use of the basic /* TEMPLATE ... */ syntax: +// +// -- name: JobCountByState :one +// SELECT count(*) +// FROM /* TEMPLATE: schema */river_job +// WHERE state = @state; +// +// An open/close syntax is also available for when SQL is required before +// processing for the query to be valid. For example, a WHERE or ORDER BY clause +// can't be empty, so the SQL includes a sentinel value that's parseable which +// is then replaced later with template values: +// +// -- name: JobList :many +// SELECT * +// FROM river_job +// WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ +// ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ +// LIMIT @max::int; +// +// Be careful not to place a template on a line by itself because sqlc will +// strip any lines that start with a comment. For example, this does NOT work: +// +// -- name: JobList :many +// SELECT * +// FROM river_job +// /* TEMPLATE_BEGIN: where_clause */ +// LIMIT @max::int; +package sqlctemplate + +import ( + "context" + "errors" + "fmt" + "regexp" + "slices" + "strconv" + "strings" + + "github.com/riverqueue/river/rivershared/util/maputil" +) + +type contextContainer struct { + NamedArgs map[string]any + Templates map[string]Replacement +} + +type contextKey struct{} + +// Replacement defines a replacement for a template value in some input SQL. +type Replacement struct { + // Stable is whether the replacement value is expected to be stable for any + // number of times Replacer.Run is called with the same given input SQL. If + // all replacements are stable, then the output of Replacer.Run is cached so + // that it doesn't have to be processed again. Replacements should be not be + // stable if they depend on input parameters. + Stable bool + + // Value is the value which the template should be replaced with. For a /* + // TEMPLATE ... */ tag, replaces template and the comment containing it. For + // a /* TEMPLATE_BEGIN ... */ ... /* TEMPLATE_END */ tag pair, replaces both + // templates, comments, and the value between them. + Value string +} + +// Replacer replaces templates with template values. As an optimization, it +// contains an internal cache to short circuit SQL that has entirely stable +// template replacements and whose output is invariant of input parameters. +type Replacer struct { + cache map[string]string +} + +// NewReplacer initializes a new template Replacer. +func NewReplacer() *Replacer { + return &Replacer{ + cache: make(map[string]string), + } +} + +var ( + templateBeginEndRE = regexp.MustCompile(`/\* TEMPLATE_BEGIN: (.*?) \*/ .*? /\* TEMPLATE_END \*/`) + templateRE = regexp.MustCompile(`/\* TEMPLATE: (.*?) \*/`) +) + +// Run replaces any tempates in input SQL with values from context added via +// WithTemplates. +// +// args aren't used for replacements in the input SQL, but are needed to +// determine which placeholder number (e.g. $1, $2, $3, ...) we should start +// with to replace any template named args. The returned args value should then +// be used as query input as named args from context may have been added to it. +func (r *Replacer) Run(ctx context.Context, sql string, args []any) (string, []any) { + sql, namedArgs, err := r.RunSafely(ctx, sql, args) + if err != nil { + panic(err) + } + return sql, namedArgs +} + +// RunSafely is the same as Run, but returns an error in case of missing or +// extra templates. +func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (string, []any, error) { + // If nothing present in context, short circuit quickly. + container, ok := ctx.Value(contextKey{}).(*contextContainer) + if !ok { + return sql, args, nil + } + + // If all input templates were stable, the finished SQL will have been + if cachedSQL, ok := r.cache[sql]; ok { + if len(container.NamedArgs) > 0 { + args = append(args, maputil.Values(container.NamedArgs)...) + } + return cachedSQL, args, nil + } + + if !strings.Contains(sql, "/* TEMPLATE") { + return sql, args, nil + } + + var ( + templatesExpected = maputil.Keys(container.Templates) + templatesMissing []string // not preallocated because we don't expect any missing parameters in the common case + ) + + replaceTemplate := func(sql string, templateRE *regexp.Regexp) string { + return templateRE.ReplaceAllStringFunc(sql, func(templateStr string) string { + // Really dumb, but Go doesn't provide any way to get submatches in a + // function, so we have to match twice. + // https://github.com/golang/go/issues/5690 + matches := templateRE.FindStringSubmatch(templateStr) + + template := matches[1] + + if tmpl, ok := container.Templates[template]; ok { + templatesExpected = slices.DeleteFunc(templatesExpected, func(p string) bool { return p == template }) + return tmpl.Value + } else { + templatesMissing = append(templatesMissing, template) + } + + return templateStr + }) + } + + updatedSQL := sql + updatedSQL = replaceTemplate(updatedSQL, templateBeginEndRE) + updatedSQL = replaceTemplate(updatedSQL, templateRE) + + if len(templatesExpected) > 0 { + return "", nil, errors.New("sqlctemplate params present in context but missing in SQL: " + strings.Join(templatesExpected, ", ")) + } + + if len(templatesMissing) > 0 { + return "", nil, errors.New("sqlctemplate params present in SQL but missing in context: " + strings.Join(templatesMissing, ", ")) + } + + if len(container.NamedArgs) > 0 { + placeholderNum := len(args) + for arg, val := range container.NamedArgs { + placeholderNum++ + + var ( + symbol = "@" + arg + symbolIndex = strings.Index(updatedSQL, symbol) + ) + + if symbolIndex == -1 { + return "", nil, fmt.Errorf("sqltemplate expected to find named arg %q, but it wasn't present", symbol) + } + + // ReplaceAll because an input parameter may appear multiple times. + updatedSQL = strings.ReplaceAll(updatedSQL, symbol, "$"+strconv.Itoa(placeholderNum)) + args = append(args, val) + } + } + + for _, tmpl := range container.Templates { + if !tmpl.Stable { + return updatedSQL, args, nil + } + } + + r.cache[sql] = updatedSQL + + return updatedSQL, args, nil +} + +// WithTemplates adds sqlctemplate templates to the given context (they go in +// context because it's the only way to get them down into the innards of sqlc). +// namedArgs can also be passed in to replace arguments found in +// +// If sqlctemplate params are already present in context, the two sets are +// merged, with the new params taking precedent. +func WithTemplates(ctx context.Context, templates map[string]Replacement, namedArgs map[string]any) context.Context { + if container, ok := ctx.Value(contextKey{}).(*contextContainer); ok { + for arg, val := range namedArgs { + container.NamedArgs[arg] = val + } + for template, tmpl := range templates { + container.Templates[template] = tmpl + } + return ctx + } + + if namedArgs == nil { + namedArgs = make(map[string]any) + } + + return context.WithValue(ctx, contextKey{}, &contextContainer{ + NamedArgs: namedArgs, + Templates: templates, + }) +} diff --git a/rivershared/sqlctemplate/sqlc_template_test.go b/rivershared/sqlctemplate/sqlc_template_test.go new file mode 100644 index 00000000..d897419c --- /dev/null +++ b/rivershared/sqlctemplate/sqlc_template_test.go @@ -0,0 +1,245 @@ +package sqlctemplate + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReplacer(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct{} + + setup := func(t *testing.T) (*Replacer, *testBundle) { //nolint:unparam + t.Helper() + + return NewReplacer(), &testBundle{} + } + + t.Run("BasicTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithTemplates(ctx, map[string]Replacement{ + "schema": {Value: "test_schema."}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + -- name: JobCountByState :one + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE state = @state; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + -- name: JobCountByState :one + SELECT count(*) + FROM test_schema.river_job + WHERE state = @state; + `, updatedSQL) + }) + + t.Run("BeginEndTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithTemplates(ctx, map[string]Replacement{ + "order_by_clause": {Value: "kind, id"}, + "where_clause": {Value: "kind = 'no_op'"}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + -- name: JobList :many + SELECT * + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ + ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ + LIMIT @max::int; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + -- name: JobList :many + SELECT * + FROM river_job + WHERE kind = 'no_op' + ORDER BY kind, id + LIMIT @max::int; + `, updatedSQL) + }) + + t.Run("RepeatedTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithTemplates(ctx, map[string]Replacement{ + "schema": {Value: "test_schema."}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job r1 + INNER JOIN /* TEMPLATE: schema */river_job r2 ON r1.id = r2.id; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job r1 + INNER JOIN test_schema.river_job r2 ON r1.id = r2.id; + `, updatedSQL) + }) + + t.Run("AllTemplatesStableCached", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithTemplates(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job; + `, updatedSQL) + + require.Len(t, replacer.cache, 1) + + // Invoke again to make sure we get back the same result. + updatedSQL, args, err = replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job; + `, updatedSQL) + }) + + t.Run("AnyTemplateNotStableNotCached", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithTemplates(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + "where_clause": {Value: "kind = 'no_op'"}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = 'no_op'; + `, updatedSQL) + + require.Empty(t, replacer.cache) + }) + + t.Run("NamedArgsNoInitialArgs", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithTemplates(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */; + `, nil) + require.NoError(t, err) + require.Equal(t, []any{"no_op"}, args) + require.Equal(t, ` + SELECT count(*) + FROM river_job + WHERE kind = $1; + `, updatedSQL) + }) + + t.Run("NamedArgsWithInitialArgs", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithTemplates(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ + AND status = $1; + `, []any{"succeeded"}) + require.NoError(t, err) + require.Equal(t, []any{"succeeded", "no_op"}, args) + require.Equal(t, ` + SELECT count(*) + FROM river_job + WHERE kind = $2 + AND status = $1; + `, updatedSQL) + }) + + t.Run("MultipleWithTemplatesOverrides", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithTemplates(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + ctx = WithTemplates(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind AND status = @status"}, + }, map[string]any{ + "status": "succeeded", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */; + `, nil) + require.NoError(t, err) + require.Equal(t, []any{"no_op", "succeeded"}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = $1 AND status = $2; + `, updatedSQL) + }) +}