Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions internal/dinosql/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ type PackageSettings struct {
EmitJSONTags bool `json:"emit_json_tags"`
EmitPreparedQueries bool `json:"emit_prepared_queries"`
Overrides []Override `json:"overrides"`
// HACK: this is only set in tests, only here till Kotlin support can be merged.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I promise to remove this when Kolint support is merged :)

rewriteParams bool
}

type Override struct {
Expand Down
77 changes: 59 additions & 18 deletions internal/dinosql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ func ParseQueries(c core.Catalog, pkg PackageSettings) (*Result, error) {
continue
}
for _, stmt := range tree.Statements {
query, err := parseQuery(c, stmt, source)
rewriteParameters := pkg.rewriteParams
query, err := parseQuery(c, stmt, source, rewriteParameters)
if err == errUnsupportedStatementType {
continue
}
Expand Down Expand Up @@ -407,7 +408,7 @@ func validateCmd(n nodes.Node, name, cmd string) error {

var errUnsupportedStatementType = errors.New("parseQuery: unsupported statement type")

func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) {
func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameters bool) (*Query, error) {
if err := validateParamRef(stmt); err != nil {
return nil, err
}
Expand Down Expand Up @@ -443,6 +444,16 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
}
rvs := rangeVars(raw.Stmt)
refs := findParameters(raw.Stmt)
var edits []edit
if rewriteParameters {
edits, err = rewriteNumberedParameters(refs, raw, rawSQL)
if err != nil {
return nil, err
}
} else {
refs = uniqueParamRefs(refs)
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
}
params, err := resolveCatalogRefs(c, rvs, refs)
if err != nil {
return nil, err
Expand All @@ -452,7 +463,13 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
if err != nil {
return nil, err
}
expanded, err := expand(c, raw, rawSQL)
expandEdits, err := expand(c, raw, rawSQL)
if err != nil {
return nil, err
}
edits = append(edits, expandEdits...)

expanded, err := editQuery(rawSQL, edits)
if err != nil {
return nil, err
}
Expand All @@ -472,6 +489,18 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
}, nil
}

func rewriteNumberedParameters(refs []paramRef, raw nodes.RawStmt, sql string) ([]edit, error) {
edits := make([]edit, len(refs))
for i, ref := range refs {
edits[i] = edit{
Location: ref.ref.Location - raw.StmtLocation,
Old: fmt.Sprintf("$%d", ref.ref.Number),
New: "?",
}
}
return edits, nil
}

func stripComments(sql string) (string, []string, error) {
s := bufio.NewScanner(strings.NewReader(sql))
var lines, comments []string
Expand All @@ -494,7 +523,7 @@ type edit struct {
New string
}

func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) {
func expand(c core.Catalog, raw nodes.RawStmt, sql string) ([]edit, error) {
list := search(raw, func(node nodes.Node) bool {
switch node.(type) {
case nodes.DeleteStmt:
Expand All @@ -507,17 +536,17 @@ func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) {
return true
})
if len(list.Items) == 0 {
return sql, nil
return nil, nil
}
var edits []edit
for _, item := range list.Items {
edit, err := expandStmt(c, raw, item)
if err != nil {
return "", err
return nil, err
}
edits = append(edits, edit...)
}
return editQuery(sql, edits)
return edits, nil
}

func expandStmt(c core.Catalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) {
Expand Down Expand Up @@ -958,7 +987,8 @@ type paramRef struct {
type paramSearch struct {
parent nodes.Node
rangeVar *nodes.RangeVar
refs map[int]paramRef
refs *[]paramRef
Copy link
Contributor Author

@mightyguava mightyguava Jan 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I tried to make Visit take a pointer receiver: (p *paramSearch), rather than doing a pointer to a slice here. This broke the DeleteVenue query parsing somehow. The parent for the second parameter became a FunctionCall instead of an Expr. I'm a compilers noob so I decided to not dive down that rabbit hole.

seen map[int]struct{}

// XXX: Gross state hack for limit
limitCount nodes.Node
Expand Down Expand Up @@ -1005,7 +1035,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
continue
}
// TODO: Out-of-bounds panic
p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar})
p.seen[ref.Location] = struct{}{}
}
for _, vl := range s.ValuesLists {
for i, v := range vl {
Expand All @@ -1014,7 +1045,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
continue
}
// TODO: Out-of-bounds panic
p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar})
p.seen[ref.Location] = struct{}{}
}
}
}
Expand Down Expand Up @@ -1050,7 +1082,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
parent = limitOffset{}
}
}
if _, found := p.refs[n.Number]; found {
if _, found := p.seen[n.Location]; found {
break
}

Expand All @@ -1072,21 +1104,18 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
}

if set {
p.refs[n.Number] = paramRef{parent: parent, ref: n, rv: p.rangeVar}
*p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar})
p.seen[n.Location] = struct{}{}
}
return nil
}
return p
}

func findParameters(root nodes.Node) []paramRef {
v := paramSearch{refs: map[int]paramRef{}}
Walk(v, root)
refs := make([]paramRef, 0)
for _, r := range v.refs {
refs = append(refs, r)
}
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
v := paramSearch{seen: make(map[int]struct{}), refs: &refs}
Walk(v, root)
return refs
}

Expand Down Expand Up @@ -1348,3 +1377,15 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
}
return a, nil
}

func uniqueParamRefs(in []paramRef) []paramRef {
m := make(map[int]struct{}, len(in))
o := make([]paramRef, 0, len(in))
for _, v := range in {
if _, ok := m[v.ref.Number]; !ok {
m[v.ref.Number] = struct{}{}
o = append(o, v)
}
}
return o
}
56 changes: 48 additions & 8 deletions internal/dinosql/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dinosql
import (
"testing"

"github.com/google/go-cmp/cmp"
pg "github.com/lfittl/pg_query_go"
nodes "github.com/lfittl/pg_query_go/nodes"
)
Expand Down Expand Up @@ -87,13 +88,14 @@ func TestLineColumn(t *testing.T) {

func TestExtractArgs(t *testing.T) {
queries := []struct {
query string
count int
query string
bindNumbers []int
}{
{"SELECT * FROM venue WHERE slug = $1 AND city = $2", 2},
{"SELECT * FROM venue WHERE slug = $1", 1},
{"SELECT * FROM venue LIMIT $1", 1},
{"SELECT * FROM venue OFFSET $1", 1},
{"SELECT * FROM venue WHERE slug = $1 AND city = $2", []int{1, 2}},
{"SELECT * FROM venue WHERE slug = $1 AND region = $2 AND city = $3 AND country = $2", []int{1, 2, 3, 2}},
{"SELECT * FROM venue WHERE slug = $1", []int{1}},
{"SELECT * FROM venue LIMIT $1", []int{1}},
{"SELECT * FROM venue OFFSET $1", []int{1}},
}
for _, q := range queries {
tree, err := pg.Parse(q.query)
Expand All @@ -105,8 +107,46 @@ func TestExtractArgs(t *testing.T) {
if err != nil {
t.Error(err)
}
if len(refs) != q.count {
t.Errorf("expected %d refs, got %d", q.count, len(refs))
nums := make([]int, len(refs))
for i, n := range refs {
nums[i] = n.ref.Number
}
if diff := cmp.Diff(q.bindNumbers, nums); diff != "" {
t.Errorf("expected bindings %v, got %v", q.bindNumbers, nums)
}
}
}
}

func TestRewriteParameters(t *testing.T) {
queries := []struct {
orig string
new string
}{
{"SELECT * FROM venue WHERE slug = $1 AND city = $3 AND bar = $2", "SELECT * FROM venue WHERE slug = ? AND city = ? AND bar = ?"},
{"DELETE FROM venue WHERE slug = $1 AND slug = $1", "DELETE FROM venue WHERE slug = ? AND slug = ?"},
{"SELECT * FROM venue LIMIT $1", "SELECT * FROM venue LIMIT ?"},
}
for _, q := range queries {
tree, err := pg.Parse(q.orig)
if err != nil {
t.Fatal(err)
}
for _, stmt := range tree.Statements {
refs := findParameters(stmt)
if err != nil {
t.Error(err)
}
edits, err := rewriteNumberedParameters(refs, stmt.(nodes.RawStmt), q.orig)
if err != nil {
t.Error(err)
}
rewritten, err := editQuery(q.orig, edits)
if err != nil {
t.Error(err)
}
if rewritten != q.new {
t.Errorf("expected %q, got %q", q.new, rewritten)
}
}
}
Expand Down