From d261a51dce15ab095457bbed925597653250f06f Mon Sep 17 00:00:00 2001 From: Brandur Date: Sat, 8 Mar 2025 00:46:41 -0800 Subject: [PATCH] Safer template usage with more aggressive errors on likely problems Follows up the addition of `sqlctemplate` in #794. I noticed while adding functionality in with templates that it was quite easy to (1) add a `/* TEMPLATE` tag to `river_job.sql`, (2) put in context parameters to a driver, but then (3) forget to run `make generate`. The context parameters are injected, but `sqlctemplate` no ops with a fast short circuit because there's no `/* TEMPLATE` tag present in the generated Go code that the driver is executing. This leads to confusion. Here, add a few more error conditions: * If a context container is present without any `/* TEMPLATE` tags, error. * If any `/* TEMPLATE` tags are present without a context container, error. This makes dumb bugs easier to catch because we get an explicit error instead of them failing silently. Tests are updated to check for them. --- rivershared/sqlctemplate/sqlc_template.go | 20 +++++++++------ .../sqlctemplate/sqlc_template_test.go | 25 ++++++++++++------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/rivershared/sqlctemplate/sqlc_template.go b/rivershared/sqlctemplate/sqlc_template.go index 36986b24..c8417377 100644 --- a/rivershared/sqlctemplate/sqlc_template.go +++ b/rivershared/sqlctemplate/sqlc_template.go @@ -110,10 +110,20 @@ func (r *Replacer) Run(ctx context.Context, sql string, args []any) (string, []a // RunSafely is the same as Run, but returns an error in case of missing or // extra templates. func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (string, []any, error) { - // If nothing present in context, short circuit quickly. - container, containerOK := ctx.Value(contextKey{}).(*contextContainer) - if !containerOK { + var ( + container, containerOK = ctx.Value(contextKey{}).(*contextContainer) + sqlContainsTemplate = strings.Contains(sql, "/* TEMPLATE") + ) + switch { + case !containerOK && !sqlContainsTemplate: + // Neither context container or template in SQL; short circuit fast because there's no work to do. return sql, args, nil + + case containerOK && !sqlContainsTemplate: + return "", nil, errors.New("sqlctemplate found context container but SQL contains no templates; bug?") + + case !containerOK && sqlContainsTemplate: + return "", nil, errors.New("sqlctemplate found template(s) in SQL, but no context container; bug?") } r.cacheMu.RLock() @@ -134,10 +144,6 @@ func (r *Replacer) RunSafely(ctx context.Context, sql string, args []any) (strin return cachedSQL, args, nil } - if !strings.Contains(sql, "/* TEMPLATE") { - return sql, args, nil - } - var ( templatesExpected = maputil.Keys(container.Templates) templatesMissing []string // not preallocated because we don't expect any missing parameters in the common case diff --git a/rivershared/sqlctemplate/sqlc_template_test.go b/rivershared/sqlctemplate/sqlc_template_test.go index f10fa30f..bbe47251 100644 --- a/rivershared/sqlctemplate/sqlc_template_test.go +++ b/rivershared/sqlctemplate/sqlc_template_test.go @@ -21,36 +21,43 @@ func TestReplacer(t *testing.T) { return &Replacer{}, &testBundle{} } - t.Run("NoContainer", func(t *testing.T) { + t.Run("NoContainerError", func(t *testing.T) { t.Parallel() replacer, _ := setup(t) - updatedSQL, args, err := replacer.RunSafely(ctx, ` + _, _, err := replacer.RunSafely(ctx, ` SELECT /* TEMPLATE: schema */river_job; `, nil) - require.NoError(t, err) - require.Nil(t, args) - require.Equal(t, ` - SELECT /* TEMPLATE: schema */river_job; - `, updatedSQL) + require.EqualError(t, err, "sqlctemplate found template(s) in SQL, but no context container; bug?") }) - t.Run("NoTemplate", func(t *testing.T) { + t.Run("NoTemplateError", func(t *testing.T) { t.Parallel() replacer, _ := setup(t) ctx := WithReplacements(ctx, map[string]Replacement{}, nil) + _, _, err := replacer.RunSafely(ctx, ` + SELECT 1; + `, nil) + require.EqualError(t, err, "sqlctemplate found context container but SQL contains no templates; bug?") + }) + + t.Run("NoContainerOrTemplate", func(t *testing.T) { + t.Parallel() + + replacer, _ := setup(t) + updatedSQL, args, err := replacer.RunSafely(ctx, ` SELECT 1; `, nil) require.NoError(t, err) - require.Nil(t, args) require.Equal(t, ` SELECT 1; `, updatedSQL) + require.Nil(t, args) }) t.Run("BasicTemplate", func(t *testing.T) {