From 433cd721e87c5eadc47d04318d284b99c161e4ef Mon Sep 17 00:00:00 2001 From: Brandur Date: Wed, 5 Mar 2025 18:02:36 -0800 Subject: [PATCH] Add sqlc templates for arbitrary text substitution (minimal) This one aims to give us a workable resolution to one of our most common problems with sqlc. Namely, that although it allows substitution for parameters that work with a prepared query, it can't replace arbitrary parts of a SQL query, leading to operations that aren't possible so that we either don't do them or end up degrading to raw SQL that's only checked at runtime. Here, we add a `sqlctemplate` package that's designed to be run from inside custom implementations of sqlc's `DBTX` interface so that it it runs after sqlc's generated code but before the query goes to Postgres. In sqlc code, templates look like this: -- name: JobCountByState :one SELECT count(*) FROM /* TEMPLATE: schema */river_job WHERE state = @state; The template replacement is modeled as a comment so that it doesn't interfere with with sqlc's parsing of SQL syntax. The above is valid SQL with or without the template, but with it, `sqlctemplate` can add an arbitrary schema name to the queried table. It also supports a form of syntax where a value is required for SQL to be valid. For example, `WHERE` and `ORDER BY` clauses both require a value for them to be valid. Here, a stand in value is provide between template tags. It's processed by sqlc's parser, but then replace by the template engine before the SQL is executed: -- name: JobList :many SELECT * FROM river_job WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ LIMIT @max::int; Template values are injected via context (don't love this, but there's no other way in getting information down to a layer below `DBTX`): ctx = sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{ "order_by_clause": {Value: params.OrderByClause}, "where_clause": {Value: params.WhereClause}, }, params.NamedArgs) jobs, err := dbsqlc.New().JobList(ctx, e.dbtx, params.Max) The template engine is written to be root out as many error as possible by noticing if a template replacement is passed that doesn't have an equivalent template in SQL, or if a template in SQL is present for which there's no replacement. Named args are support in templates similar to how sqlc supports them. This allows pgx's prepared statement cache to continue to operate as it did before, thereby keeping everything fast. Lastly, I should note that templates are meant as a utility of last resort. All effort should be made to resolve problems via mainstream sqlc, and only bring in templates when there's no other option. --- client_test.go | 2 +- .../river_database_sql_driver.go | 54 ++- riverdriver/riverpgxv5/river_pgx_v5_driver.go | 62 ++- rivershared/sqlctemplate/sqlc_template.go | 239 ++++++++++++ .../sqlctemplate/sqlc_template_test.go | 364 ++++++++++++++++++ 5 files changed, 703 insertions(+), 18 deletions(-) create mode 100644 rivershared/sqlctemplate/sqlc_template.go create mode 100644 rivershared/sqlctemplate/sqlc_template_test.go diff --git a/client_test.go b/client_test.go index 9d2af82b..31864be5 100644 --- a/client_test.go +++ b/client_test.go @@ -710,7 +710,7 @@ func Test_Client(t *testing.T) { &overridableJobMiddleware{ workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { middlewareCalled = true - require.Equal(t, `{"name": "inserted name"}`, string(job.EncodedArgs)) + require.JSONEq(t, `{"name": "inserted name"}`, string(job.EncodedArgs)) job.EncodedArgs = []byte(`{"name": "middleware name"}`) return doInner(ctx) }, diff --git a/riverdriver/riverdatabasesql/river_database_sql_driver.go b/riverdriver/riverdatabasesql/river_database_sql_driver.go index 34918e19..195609ec 100644 --- a/riverdriver/riverdatabasesql/river_database_sql_driver.go +++ b/riverdriver/riverdatabasesql/river_database_sql_driver.go @@ -25,6 +25,7 @@ import ( "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverdatabasesql/internal/dbsqlc" "github.com/riverqueue/river/riverdriver/riverdatabasesql/internal/pgtypealias" + "github.com/riverqueue/river/rivershared/sqlctemplate" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivershared/util/valutil" "github.com/riverqueue/river/rivertype" @@ -35,7 +36,8 @@ var migrationFS embed.FS // Driver is an implementation of riverdriver.Driver for database/sql. type Driver struct { - dbPool *sql.DB + dbPool *sql.DB + replacer sqlctemplate.Replacer } // New returns a new database/sql River driver for use with River. @@ -44,11 +46,13 @@ type Driver struct { // configured to use the schema specified in the client's Schema field. The pool // must not be closed while associated River objects are running. func New(dbPool *sql.DB) *Driver { - return &Driver{dbPool: dbPool} + return &Driver{ + dbPool: dbPool, + } } func (d *Driver) GetExecutor() riverdriver.Executor { - return &Executor{d.dbPool, d.dbPool} + return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, &d.replacer}, d} } func (d *Driver) GetListener() riverdriver.Listener { panic(riverdriver.ErrNotImplemented) } @@ -63,12 +67,21 @@ func (d *Driver) HasPool() bool { return d.dbPool != nil } func (d *Driver) SupportsListener() bool { return false } func (d *Driver) UnwrapExecutor(tx *sql.Tx) riverdriver.ExecutorTx { - return &ExecutorTx{Executor: Executor{nil, tx}, tx: tx} + // Allows UnwrapExecutor to be invoked even if driver is nil. + var replacer *sqlctemplate.Replacer + if d == nil { + replacer = &sqlctemplate.Replacer{} + } else { + replacer = &d.replacer + } + + return &ExecutorTx{Executor: Executor{nil, templateReplaceWrapper{tx, replacer}, d}, tx: tx} } type Executor struct { dbPool *sql.DB - dbtx dbsqlc.DBTX + dbtx templateReplaceWrapper + driver *Driver } func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { @@ -76,7 +89,7 @@ func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { if err != nil { return nil, err } - return &ExecutorTx{Executor: Executor{nil, tx}, tx: tx}, nil + return &ExecutorTx{Executor: Executor{nil, templateReplaceWrapper{tx, &e.driver.replacer}, e.driver}, tx: tx}, nil } func (e *Executor) ColumnExists(ctx context.Context, tableName, columnName string) (bool, error) { @@ -846,7 +859,7 @@ type ExecutorTx struct { } func (t *ExecutorTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { - return (&ExecutorSubTx{Executor: Executor{nil, t.tx}, savepointNum: 0, single: &singleTransaction{}, tx: t.tx}).Begin(ctx) + return (&ExecutorSubTx{Executor: Executor{nil, templateReplaceWrapper{t.tx, &t.driver.replacer}, t.driver}, savepointNum: 0, single: &singleTransaction{}, tx: t.tx}).Begin(ctx) } func (t *ExecutorTx) Commit(ctx context.Context) error { @@ -878,7 +891,7 @@ func (t *ExecutorSubTx) Begin(ctx context.Context) (riverdriver.ExecutorTx, erro if err != nil { return nil, err } - return &ExecutorSubTx{Executor: Executor{nil, t.tx}, savepointNum: nextSavepointNum, single: &singleTransaction{parent: t.single}, tx: t.tx}, nil + return &ExecutorSubTx{Executor: Executor{nil, templateReplaceWrapper{t.tx, &t.driver.replacer}, t.driver}, savepointNum: nextSavepointNum, single: &singleTransaction{parent: t.single}, tx: t.tx}, nil } func (t *ExecutorSubTx) Commit(ctx context.Context) error { @@ -944,6 +957,31 @@ func (t *singleTransaction) setDone() { } } +type templateReplaceWrapper struct { + dbtx dbsqlc.DBTX + replacer *sqlctemplate.Replacer +} + +func (w templateReplaceWrapper) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.ExecContext(ctx, sql, args...) +} + +func (w templateReplaceWrapper) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) { + sql, _ = w.replacer.Run(ctx, sql, nil) + return w.dbtx.PrepareContext(ctx, sql) +} + +func (w templateReplaceWrapper) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.QueryContext(ctx, sql, args...) +} + +func (w templateReplaceWrapper) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *sql.Row { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.QueryRowContext(ctx, sql, args...) +} + func jobRowFromInternal(internal *dbsqlc.RiverJob) (*rivertype.JobRow, error) { var attemptedAt *time.Time if internal.AttemptedAt != nil { diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver.go b/riverdriver/riverpgxv5/river_pgx_v5_driver.go index c446c673..3db9d3af 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver.go @@ -17,6 +17,7 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/puddle/v2" @@ -24,6 +25,7 @@ import ( "github.com/riverqueue/river/internal/dbunique" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverpgxv5/internal/dbsqlc" + "github.com/riverqueue/river/rivershared/sqlctemplate" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" ) @@ -33,7 +35,8 @@ var migrationFS embed.FS // Driver is an implementation of riverdriver.Driver for Pgx v5. type Driver struct { - dbPool *pgxpool.Pool + dbPool *pgxpool.Pool + replacer sqlctemplate.Replacer } // New returns a new Pgx v5 River driver for use with River. @@ -49,10 +52,14 @@ type Driver struct { // in testing so that inserts can be performed and verified on a test // transaction that will be rolled back. func New(dbPool *pgxpool.Pool) *Driver { - return &Driver{dbPool: dbPool} + return &Driver{ + dbPool: dbPool, + } } -func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool} } +func (d *Driver) GetExecutor() riverdriver.Executor { + return &Executor{templateReplaceWrapper{d.dbPool, &d.replacer}, d} +} func (d *Driver) GetListener() riverdriver.Listener { return &Listener{dbPool: d.dbPool} } func (d *Driver) GetMigrationFS(line string) fs.FS { if line == riverdriver.MigrationLineMain { @@ -65,14 +72,20 @@ func (d *Driver) HasPool() bool { return d.dbPool != nil } func (d *Driver) SupportsListener() bool { return true } func (d *Driver) UnwrapExecutor(tx pgx.Tx) riverdriver.ExecutorTx { - return &ExecutorTx{Executor: Executor{tx}, tx: tx} + // Allows UnwrapExecutor to be invoked even if driver is nil. + var replacer *sqlctemplate.Replacer + if d == nil { + replacer = &sqlctemplate.Replacer{} + } else { + replacer = &d.replacer + } + + return &ExecutorTx{Executor: Executor{templateReplaceWrapper{tx, replacer}, d}, tx: tx} } type Executor struct { - dbtx interface { - dbsqlc.DBTX - Begin(ctx context.Context) (pgx.Tx, error) - } + dbtx templateReplaceWrapper + driver *Driver } func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { @@ -80,7 +93,7 @@ func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { if err != nil { return nil, err } - return &ExecutorTx{Executor: Executor{tx}, tx: tx}, nil + return &ExecutorTx{Executor: Executor{templateReplaceWrapper{tx, &e.driver.replacer}, e.driver}, tx: tx}, nil } func (e *Executor) ColumnExists(ctx context.Context, tableName, columnName string) (bool, error) { @@ -814,6 +827,37 @@ func (l *Listener) WaitForNotification(ctx context.Context) (*riverdriver.Notifi }, nil } +type templateReplaceWrapper struct { + dbtx interface { + dbsqlc.DBTX + Begin(ctx context.Context) (pgx.Tx, error) + } + replacer *sqlctemplate.Replacer +} + +func (w templateReplaceWrapper) Begin(ctx context.Context) (pgx.Tx, error) { + return w.dbtx.Begin(ctx) +} + +func (w templateReplaceWrapper) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.Exec(ctx, sql, args...) +} + +func (w templateReplaceWrapper) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.Query(ctx, sql, args...) +} + +func (w templateReplaceWrapper) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { + sql, args = w.replacer.Run(ctx, sql, args) + return w.dbtx.QueryRow(ctx, sql, args...) +} + +func (w templateReplaceWrapper) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + return w.dbtx.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + func interpretError(err error) error { if errors.Is(err, puddle.ErrClosedPool) { return riverdriver.ErrClosedPool diff --git a/rivershared/sqlctemplate/sqlc_template.go b/rivershared/sqlctemplate/sqlc_template.go new file mode 100644 index 00000000..36986b24 --- /dev/null +++ b/rivershared/sqlctemplate/sqlc_template.go @@ -0,0 +1,239 @@ +// Package sqlctemplate provides a way of making arbitrary text replacement in +// sqlc queries which normally only allow parameters which are in places valid +// in a prepared statement. For example, it can be used to insert a schema name +// as a prefix to tables referenced in sqlc, which is otherwise impossible. +// +// Replacement is carried out from within invocations of sqlc's generated DBTX +// interface, after sqlc generated code runs, but before queries are executed. +// This is accomplished by implementing DBTX, calling Replacer.Run from within +// them, and injecting parameters in with WithReplacements (which is unfortunately +// the only way of injecting them). +// +// Templates are modeled as SQL comments so that they're still parseable as +// valid SQL. An example use of the basic /* TEMPLATE ... */ syntax: +// +// -- name: JobCountByState :one +// SELECT count(*) +// FROM /* TEMPLATE: schema */river_job +// WHERE state = @state; +// +// An open/close syntax is also available for when SQL is required before +// processing for the query to be valid. For example, a WHERE or ORDER BY clause +// can't be empty, so the SQL includes a sentinel value that's parseable which +// is then replaced later with template values: +// +// -- name: JobList :many +// SELECT * +// FROM river_job +// WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ +// ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ +// LIMIT @max::int; +// +// Be careful not to place a template on a line by itself because sqlc will +// strip any lines that start with a comment. For example, this does NOT work: +// +// -- name: JobList :many +// SELECT * +// FROM river_job +// /* TEMPLATE_BEGIN: where_clause */ +// LIMIT @max::int; +package sqlctemplate + +import ( + "context" + "errors" + "fmt" + "regexp" + "slices" + "strconv" + "strings" + "sync" + + "github.com/riverqueue/river/rivershared/util/maputil" +) + +type contextContainer struct { + NamedArgs map[string]any + Templates map[string]Replacement +} + +type contextKey struct{} + +// Replacement defines a replacement for a template value in some input SQL. +type Replacement struct { + // Stable is whether the replacement value is expected to be stable for any + // number of times Replacer.Run is called with the same given input SQL. If + // all replacements are stable, then the output of Replacer.Run is cached so + // that it doesn't have to be processed again. Replacements should be not be + // stable if they depend on input parameters. + Stable bool + + // Value is the value which the template should be replaced with. For a /* + // TEMPLATE ... */ tag, replaces template and the comment containing it. For + // a /* TEMPLATE_BEGIN ... */ ... /* TEMPLATE_END */ tag pair, replaces both + // templates, comments, and the value between them. + Value string +} + +// Replacer replaces templates with template values. As an optimization, it +// contains an internal cache to short circuit SQL that has entirely stable +// template replacements and whose output is invariant of input parameters. +// +// The struct is written so that it's safe to use as a value and doesn't need to +// be initialized with a constructor. This lets it default to a usable instance +// on drivers that may themselves not be initialized. +type Replacer struct { + cache map[string]string + cacheMu sync.RWMutex +} + +var ( + templateBeginEndRE = regexp.MustCompile(`/\* TEMPLATE_BEGIN: (.*?) \*/ .*? /\* TEMPLATE_END \*/`) + templateRE = regexp.MustCompile(`/\* TEMPLATE: (.*?) \*/`) +) + +// Run replaces any tempates in input SQL with values from context added via +// WithReplacements. +// +// args aren't used for replacements in the input SQL, but are needed to +// determine which placeholder number (e.g. $1, $2, $3, ...) we should start +// with to replace any template named args. The returned args value should then +// be used as query input as named args from context may have been added to it. +func (r *Replacer) Run(ctx context.Context, sql string, args []any) (string, []any) { + sql, namedArgs, err := r.RunSafely(ctx, sql, args) + if err != nil { + panic(err) + } + return sql, namedArgs +} + +// RunSafely is the same as Run, but returns an error in case of missing or +// extra templates. +func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (string, []any, error) { + // If nothing present in context, short circuit quickly. + container, containerOK := ctx.Value(contextKey{}).(*contextContainer) + if !containerOK { + return sql, args, nil + } + + r.cacheMu.RLock() + var ( + cachedSQL string + cachedSQLOK bool + ) + if r.cache != nil { // protect against map not initialized yet + cachedSQL, cachedSQLOK = r.cache[sql] + } + r.cacheMu.RUnlock() + + // If all input templates were stable, the finished SQL will have been + if cachedSQLOK { + if len(container.NamedArgs) > 0 { + args = append(args, maputil.Values(container.NamedArgs)...) + } + return cachedSQL, args, nil + } + + if !strings.Contains(sql, "/* TEMPLATE") { + return sql, args, nil + } + + var ( + templatesExpected = maputil.Keys(container.Templates) + templatesMissing []string // not preallocated because we don't expect any missing parameters in the common case + ) + + replaceTemplate := func(sql string, templateRE *regexp.Regexp) string { + return templateRE.ReplaceAllStringFunc(sql, func(templateStr string) string { + // Really dumb, but Go doesn't provide any way to get submatches in a + // function, so we have to match twice. + // https://github.com/golang/go/issues/5690 + matches := templateRE.FindStringSubmatch(templateStr) + + template := matches[1] + + if tmpl, ok := container.Templates[template]; ok { + templatesExpected = slices.DeleteFunc(templatesExpected, func(p string) bool { return p == template }) + return tmpl.Value + } else { + templatesMissing = append(templatesMissing, template) + } + + return templateStr + }) + } + + updatedSQL := sql + updatedSQL = replaceTemplate(updatedSQL, templateBeginEndRE) + updatedSQL = replaceTemplate(updatedSQL, templateRE) + + if len(templatesExpected) > 0 { + return "", nil, errors.New("sqlctemplate params present in context but missing in SQL: " + strings.Join(templatesExpected, ", ")) + } + + if len(templatesMissing) > 0 { + return "", nil, errors.New("sqlctemplate params present in SQL but missing in context: " + strings.Join(templatesMissing, ", ")) + } + + if len(container.NamedArgs) > 0 { + placeholderNum := len(args) + for arg, val := range container.NamedArgs { + placeholderNum++ + + var ( + symbol = "@" + arg + symbolIndex = strings.Index(updatedSQL, symbol) + ) + + if symbolIndex == -1 { + return "", nil, fmt.Errorf("sqltemplate expected to find named arg %q, but it wasn't present", symbol) + } + + // ReplaceAll because an input parameter may appear multiple times. + updatedSQL = strings.ReplaceAll(updatedSQL, symbol, "$"+strconv.Itoa(placeholderNum)) + args = append(args, val) + } + } + + for _, tmpl := range container.Templates { + if !tmpl.Stable { + return updatedSQL, args, nil + } + } + + r.cacheMu.Lock() + if r.cache == nil { + r.cache = make(map[string]string) + } + r.cache[sql] = updatedSQL + r.cacheMu.Unlock() + + return updatedSQL, args, nil +} + +// WithReplacements adds sqlctemplate templates to the given context (they go in +// context because it's the only way to get them down into the innards of sqlc). +// namedArgs can also be passed in to replace arguments found in +// +// If sqlctemplate params are already present in context, the two sets are +// merged, with the new params taking precedent. +func WithReplacements(ctx context.Context, templates map[string]Replacement, namedArgs map[string]any) context.Context { + if container, ok := ctx.Value(contextKey{}).(*contextContainer); ok { + for arg, val := range namedArgs { + container.NamedArgs[arg] = val + } + for template, tmpl := range templates { + container.Templates[template] = tmpl + } + return ctx + } + + if namedArgs == nil { + namedArgs = make(map[string]any) + } + + return context.WithValue(ctx, contextKey{}, &contextContainer{ + NamedArgs: namedArgs, + Templates: templates, + }) +} diff --git a/rivershared/sqlctemplate/sqlc_template_test.go b/rivershared/sqlctemplate/sqlc_template_test.go new file mode 100644 index 00000000..f10fa30f --- /dev/null +++ b/rivershared/sqlctemplate/sqlc_template_test.go @@ -0,0 +1,364 @@ +package sqlctemplate + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReplacer(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct{} + + setup := func(t *testing.T) (*Replacer, *testBundle) { //nolint:unparam + t.Helper() + + return &Replacer{}, &testBundle{} + } + + t.Run("NoContainer", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT /* TEMPLATE: schema */river_job; + `, updatedSQL) + }) + + t.Run("NoTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{}, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT 1; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT 1; + `, updatedSQL) + }) + + t.Run("BasicTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Value: "test_schema."}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + -- name: JobCountByState :one + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE state = @state; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + -- name: JobCountByState :one + SELECT count(*) + FROM test_schema.river_job + WHERE state = @state; + `, updatedSQL) + }) + + t.Run("BeginEndTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "order_by_clause": {Value: "kind, id"}, + "where_clause": {Value: "kind = 'no_op'"}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + -- name: JobList :many + SELECT * + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ + ORDER BY /* TEMPLATE_BEGIN: order_by_clause */ id /* TEMPLATE_END */ + LIMIT @max::int; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + -- name: JobList :many + SELECT * + FROM river_job + WHERE kind = 'no_op' + ORDER BY kind, id + LIMIT @max::int; + `, updatedSQL) + }) + + t.Run("RepeatedTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Value: "test_schema."}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job r1 + INNER JOIN /* TEMPLATE: schema */river_job r2 ON r1.id = r2.id; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job r1 + INNER JOIN test_schema.river_job r2 ON r1.id = r2.id; + `, updatedSQL) + }) + + t.Run("AllTemplatesStableCached", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job; + `, updatedSQL) + + require.Len(t, replacer.cache, 1) + + // Invoke again to make sure we get back the same result. + updatedSQL, args, err = replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job; + `, updatedSQL) + }) + + t.Run("AnyTemplateNotStableNotCached", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + "where_clause": {Value: "kind = 'no_op'"}, + }, nil) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */; + `, nil) + require.NoError(t, err) + require.Nil(t, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = 'no_op'; + `, updatedSQL) + + require.Empty(t, replacer.cache) + }) + + t.Run("NamedArgsNoInitialArgs", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */; + `, nil) + require.NoError(t, err) + require.Equal(t, []any{"no_op"}, args) + require.Equal(t, ` + SELECT count(*) + FROM river_job + WHERE kind = $1; + `, updatedSQL) + }) + + t.Run("NamedArgsWithInitialArgs", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */ + AND status = $1; + `, []any{"succeeded"}) + require.NoError(t, err) + require.Equal(t, []any{"succeeded", "no_op"}, args) + require.Equal(t, ` + SELECT count(*) + FROM river_job + WHERE kind = $2 + AND status = $1; + `, updatedSQL) + }) + + t.Run("MultipleWithReplacementsOverrides", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + "where_clause": {Value: "kind = @kind"}, + }, map[string]any{ + "kind": "no_op", + }) + + ctx = WithReplacements(ctx, map[string]Replacement{ + "where_clause": {Value: "kind = @kind AND status = @status"}, + }, map[string]any{ + "status": "succeeded", + }) + + updatedSQL, args, err := replacer.RunSafely(ctx, ` + SELECT count(*) + FROM /* TEMPLATE: schema */river_job + WHERE /* TEMPLATE_BEGIN: where_clause */ 1 /* TEMPLATE_END */; + `, nil) + require.NoError(t, err) + require.Equal(t, []any{"no_op", "succeeded"}, args) + require.Equal(t, ` + SELECT count(*) + FROM test_schema.river_job + WHERE kind = $1 AND status = $2; + `, updatedSQL) + }) + + t.Run("Stress", func(t *testing.T) { + t.Parallel() + + const ( + clearCacheIterations = 10 + numIterations = 50 + ) + + replacer, _ := setup(t) + + periodicallyClearCache := func(i int, replacer *Replacer) { + if i+1%clearCacheIterations == 0 { // +1 so we don't clear cache on i == 0 + replacer.cacheMu.Lock() + replacer.cache = nil + replacer.cacheMu.Unlock() + } + } + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + + for i := range numIterations { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Value: "test_schema."}, + }, nil) + + updatedSQL, _, err := replacer.RunSafely(ctx, ` + SELECT count(*) FROM /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Equal(t, ` + SELECT count(*) FROM test_schema.river_job; + `, updatedSQL) + + periodicallyClearCache(i, replacer) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + for i := range numIterations { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, nil) + + updatedSQL, _, err := replacer.RunSafely(ctx, ` + SELECT distinct(kind) FROM /* TEMPLATE: schema */river_job; + `, nil) + require.NoError(t, err) + require.Equal(t, ` + SELECT distinct(kind) FROM test_schema.river_job; + `, updatedSQL) + + periodicallyClearCache(i, replacer) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + for i := range numIterations { + ctx := WithReplacements(ctx, map[string]Replacement{ + "schema": {Stable: true, Value: "test_schema."}, + }, nil) + + updatedSQL, _, err := replacer.RunSafely(ctx, ` + SELECT count(*) FROM /* TEMPLATE: schema */river_job WHERE status = 'succeeded'; + `, nil) + require.NoError(t, err) + require.Equal(t, ` + SELECT count(*) FROM test_schema.river_job WHERE status = 'succeeded'; + `, updatedSQL) + + periodicallyClearCache(i, replacer) + } + }() + + wg.Wait() + }) +}