diff --git a/internal/dinosql/config.go b/internal/dinosql/config.go index e14dc9a2f0..bda15aa2aa 100644 --- a/internal/dinosql/config.go +++ b/internal/dinosql/config.go @@ -51,6 +51,8 @@ type PackageSettings struct { EmitJSONTags bool `json:"emit_json_tags"` EmitPreparedQueries bool `json:"emit_prepared_queries"` Overrides []Override `json:"overrides"` + // HACK: this is only set in tests, only here till Kotlin support can be merged. + rewriteParams bool } type Override struct { diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index 8154f1a5c6..1106fb4e38 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -229,7 +229,8 @@ func ParseQueries(c core.Catalog, pkg PackageSettings) (*Result, error) { continue } for _, stmt := range tree.Statements { - query, err := parseQuery(c, stmt, source) + rewriteParameters := pkg.rewriteParams + query, err := parseQuery(c, stmt, source, rewriteParameters) if err == errUnsupportedStatementType { continue } @@ -407,7 +408,7 @@ func validateCmd(n nodes.Node, name, cmd string) error { var errUnsupportedStatementType = errors.New("parseQuery: unsupported statement type") -func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) { +func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameters bool) (*Query, error) { if err := validateParamRef(stmt); err != nil { return nil, err } @@ -443,6 +444,16 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) } rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) + var edits []edit + if rewriteParameters { + edits, err = rewriteNumberedParameters(refs, raw, rawSQL) + if err != nil { + return nil, err + } + } else { + refs = uniqueParamRefs(refs) + sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) + } params, err := resolveCatalogRefs(c, rvs, refs) if err != nil { return nil, err @@ -452,7 +463,13 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err != nil { return nil, err } - expanded, err := expand(c, raw, rawSQL) + expandEdits, err := expand(c, raw, rawSQL) + if err != nil { + return nil, err + } + edits = append(edits, expandEdits...) + + expanded, err := editQuery(rawSQL, edits) if err != nil { return nil, err } @@ -472,6 +489,18 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) }, nil } +func rewriteNumberedParameters(refs []paramRef, raw nodes.RawStmt, sql string) ([]edit, error) { + edits := make([]edit, len(refs)) + for i, ref := range refs { + edits[i] = edit{ + Location: ref.ref.Location - raw.StmtLocation, + Old: fmt.Sprintf("$%d", ref.ref.Number), + New: "?", + } + } + return edits, nil +} + func stripComments(sql string) (string, []string, error) { s := bufio.NewScanner(strings.NewReader(sql)) var lines, comments []string @@ -494,7 +523,7 @@ type edit struct { New string } -func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) { +func expand(c core.Catalog, raw nodes.RawStmt, sql string) ([]edit, error) { list := search(raw, func(node nodes.Node) bool { switch node.(type) { case nodes.DeleteStmt: @@ -507,17 +536,17 @@ func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) { return true }) if len(list.Items) == 0 { - return sql, nil + return nil, nil } var edits []edit for _, item := range list.Items { edit, err := expandStmt(c, raw, item) if err != nil { - return "", err + return nil, err } edits = append(edits, edit...) } - return editQuery(sql, edits) + return edits, nil } func expandStmt(c core.Catalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) { @@ -958,7 +987,8 @@ type paramRef struct { type paramSearch struct { parent nodes.Node rangeVar *nodes.RangeVar - refs map[int]paramRef + refs *[]paramRef + seen map[int]struct{} // XXX: Gross state hack for limit limitCount nodes.Node @@ -1005,7 +1035,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { continue } // TODO: Out-of-bounds panic - p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar} + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}) + p.seen[ref.Location] = struct{}{} } for _, vl := range s.ValuesLists { for i, v := range vl { @@ -1014,7 +1045,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { continue } // TODO: Out-of-bounds panic - p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar} + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}) + p.seen[ref.Location] = struct{}{} } } } @@ -1050,7 +1082,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { parent = limitOffset{} } } - if _, found := p.refs[n.Number]; found { + if _, found := p.seen[n.Location]; found { break } @@ -1072,7 +1104,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { } if set { - p.refs[n.Number] = paramRef{parent: parent, ref: n, rv: p.rangeVar} + *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar}) + p.seen[n.Location] = struct{}{} } return nil } @@ -1080,13 +1113,9 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { } func findParameters(root nodes.Node) []paramRef { - v := paramSearch{refs: map[int]paramRef{}} - Walk(v, root) refs := make([]paramRef, 0) - for _, r := range v.refs { - refs = append(refs, r) - } - sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) + v := paramSearch{seen: make(map[int]struct{}), refs: &refs} + Walk(v, root) return refs } @@ -1348,3 +1377,15 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( } return a, nil } + +func uniqueParamRefs(in []paramRef) []paramRef { + m := make(map[int]struct{}, len(in)) + o := make([]paramRef, 0, len(in)) + for _, v := range in { + if _, ok := m[v.ref.Number]; !ok { + m[v.ref.Number] = struct{}{} + o = append(o, v) + } + } + return o +} diff --git a/internal/dinosql/parser_test.go b/internal/dinosql/parser_test.go index e53c59955c..d1741cc6be 100644 --- a/internal/dinosql/parser_test.go +++ b/internal/dinosql/parser_test.go @@ -3,6 +3,7 @@ package dinosql import ( "testing" + "github.com/google/go-cmp/cmp" pg "github.com/lfittl/pg_query_go" nodes "github.com/lfittl/pg_query_go/nodes" ) @@ -87,13 +88,14 @@ func TestLineColumn(t *testing.T) { func TestExtractArgs(t *testing.T) { queries := []struct { - query string - count int + query string + bindNumbers []int }{ - {"SELECT * FROM venue WHERE slug = $1 AND city = $2", 2}, - {"SELECT * FROM venue WHERE slug = $1", 1}, - {"SELECT * FROM venue LIMIT $1", 1}, - {"SELECT * FROM venue OFFSET $1", 1}, + {"SELECT * FROM venue WHERE slug = $1 AND city = $2", []int{1, 2}}, + {"SELECT * FROM venue WHERE slug = $1 AND region = $2 AND city = $3 AND country = $2", []int{1, 2, 3, 2}}, + {"SELECT * FROM venue WHERE slug = $1", []int{1}}, + {"SELECT * FROM venue LIMIT $1", []int{1}}, + {"SELECT * FROM venue OFFSET $1", []int{1}}, } for _, q := range queries { tree, err := pg.Parse(q.query) @@ -105,8 +107,46 @@ func TestExtractArgs(t *testing.T) { if err != nil { t.Error(err) } - if len(refs) != q.count { - t.Errorf("expected %d refs, got %d", q.count, len(refs)) + nums := make([]int, len(refs)) + for i, n := range refs { + nums[i] = n.ref.Number + } + if diff := cmp.Diff(q.bindNumbers, nums); diff != "" { + t.Errorf("expected bindings %v, got %v", q.bindNumbers, nums) + } + } + } +} + +func TestRewriteParameters(t *testing.T) { + queries := []struct { + orig string + new string + }{ + {"SELECT * FROM venue WHERE slug = $1 AND city = $3 AND bar = $2", "SELECT * FROM venue WHERE slug = ? AND city = ? AND bar = ?"}, + {"DELETE FROM venue WHERE slug = $1 AND slug = $1", "DELETE FROM venue WHERE slug = ? AND slug = ?"}, + {"SELECT * FROM venue LIMIT $1", "SELECT * FROM venue LIMIT ?"}, + } + for _, q := range queries { + tree, err := pg.Parse(q.orig) + if err != nil { + t.Fatal(err) + } + for _, stmt := range tree.Statements { + refs := findParameters(stmt) + if err != nil { + t.Error(err) + } + edits, err := rewriteNumberedParameters(refs, stmt.(nodes.RawStmt), q.orig) + if err != nil { + t.Error(err) + } + rewritten, err := editQuery(q.orig, edits) + if err != nil { + t.Error(err) + } + if rewritten != q.new { + t.Errorf("expected %q, got %q", q.new, rewritten) } } }