-
Notifications
You must be signed in to change notification settings - Fork 1k
parser: optionally rewrite numbered params to positional params #306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -229,7 +229,8 @@ func ParseQueries(c core.Catalog, pkg PackageSettings) (*Result, error) { | |
| continue | ||
| } | ||
| for _, stmt := range tree.Statements { | ||
| query, err := parseQuery(c, stmt, source) | ||
| rewriteParameters := pkg.rewriteParams | ||
| query, err := parseQuery(c, stmt, source, rewriteParameters) | ||
| if err == errUnsupportedStatementType { | ||
| continue | ||
| } | ||
|
|
@@ -407,7 +408,7 @@ func validateCmd(n nodes.Node, name, cmd string) error { | |
|
|
||
| var errUnsupportedStatementType = errors.New("parseQuery: unsupported statement type") | ||
|
|
||
| func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) { | ||
| func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameters bool) (*Query, error) { | ||
| if err := validateParamRef(stmt); err != nil { | ||
| return nil, err | ||
| } | ||
|
|
@@ -443,6 +444,16 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) | |
| } | ||
| rvs := rangeVars(raw.Stmt) | ||
| refs := findParameters(raw.Stmt) | ||
| var edits []edit | ||
| if rewriteParameters { | ||
| edits, err = rewriteNumberedParameters(refs, raw, rawSQL) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| } else { | ||
| refs = uniqueParamRefs(refs) | ||
| sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) | ||
| } | ||
| params, err := resolveCatalogRefs(c, rvs, refs) | ||
| if err != nil { | ||
| return nil, err | ||
|
|
@@ -452,7 +463,13 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) | |
| if err != nil { | ||
| return nil, err | ||
| } | ||
| expanded, err := expand(c, raw, rawSQL) | ||
| expandEdits, err := expand(c, raw, rawSQL) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| edits = append(edits, expandEdits...) | ||
|
|
||
| expanded, err := editQuery(rawSQL, edits) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
@@ -472,6 +489,18 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) | |
| }, nil | ||
| } | ||
|
|
||
| func rewriteNumberedParameters(refs []paramRef, raw nodes.RawStmt, sql string) ([]edit, error) { | ||
| edits := make([]edit, len(refs)) | ||
| for i, ref := range refs { | ||
| edits[i] = edit{ | ||
| Location: ref.ref.Location - raw.StmtLocation, | ||
| Old: fmt.Sprintf("$%d", ref.ref.Number), | ||
| New: "?", | ||
| } | ||
| } | ||
| return edits, nil | ||
| } | ||
|
|
||
| func stripComments(sql string) (string, []string, error) { | ||
| s := bufio.NewScanner(strings.NewReader(sql)) | ||
| var lines, comments []string | ||
|
|
@@ -494,7 +523,7 @@ type edit struct { | |
| New string | ||
| } | ||
|
|
||
| func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) { | ||
| func expand(c core.Catalog, raw nodes.RawStmt, sql string) ([]edit, error) { | ||
| list := search(raw, func(node nodes.Node) bool { | ||
| switch node.(type) { | ||
| case nodes.DeleteStmt: | ||
|
|
@@ -507,17 +536,17 @@ func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) { | |
| return true | ||
| }) | ||
| if len(list.Items) == 0 { | ||
| return sql, nil | ||
| return nil, nil | ||
| } | ||
| var edits []edit | ||
| for _, item := range list.Items { | ||
| edit, err := expandStmt(c, raw, item) | ||
| if err != nil { | ||
| return "", err | ||
| return nil, err | ||
| } | ||
| edits = append(edits, edit...) | ||
| } | ||
| return editQuery(sql, edits) | ||
| return edits, nil | ||
| } | ||
|
|
||
| func expandStmt(c core.Catalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) { | ||
|
|
@@ -958,7 +987,8 @@ type paramRef struct { | |
| type paramSearch struct { | ||
| parent nodes.Node | ||
| rangeVar *nodes.RangeVar | ||
| refs map[int]paramRef | ||
| refs *[]paramRef | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So I tried to make |
||
| seen map[int]struct{} | ||
|
|
||
| // XXX: Gross state hack for limit | ||
| limitCount nodes.Node | ||
|
|
@@ -1005,7 +1035,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { | |
| continue | ||
| } | ||
| // TODO: Out-of-bounds panic | ||
| p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar} | ||
| *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}) | ||
| p.seen[ref.Location] = struct{}{} | ||
| } | ||
| for _, vl := range s.ValuesLists { | ||
| for i, v := range vl { | ||
|
|
@@ -1014,7 +1045,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { | |
| continue | ||
| } | ||
| // TODO: Out-of-bounds panic | ||
| p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar} | ||
| *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}) | ||
| p.seen[ref.Location] = struct{}{} | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -1050,7 +1082,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { | |
| parent = limitOffset{} | ||
| } | ||
| } | ||
| if _, found := p.refs[n.Number]; found { | ||
| if _, found := p.seen[n.Location]; found { | ||
| break | ||
| } | ||
|
|
||
|
|
@@ -1072,21 +1104,18 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { | |
| } | ||
|
|
||
| if set { | ||
| p.refs[n.Number] = paramRef{parent: parent, ref: n, rv: p.rangeVar} | ||
| *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar}) | ||
| p.seen[n.Location] = struct{}{} | ||
| } | ||
| return nil | ||
| } | ||
| return p | ||
| } | ||
|
|
||
| func findParameters(root nodes.Node) []paramRef { | ||
| v := paramSearch{refs: map[int]paramRef{}} | ||
| Walk(v, root) | ||
| refs := make([]paramRef, 0) | ||
| for _, r := range v.refs { | ||
| refs = append(refs, r) | ||
| } | ||
| sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) | ||
| v := paramSearch{seen: make(map[int]struct{}), refs: &refs} | ||
| Walk(v, root) | ||
| return refs | ||
| } | ||
|
|
||
|
|
@@ -1348,3 +1377,15 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( | |
| } | ||
| return a, nil | ||
| } | ||
|
|
||
| func uniqueParamRefs(in []paramRef) []paramRef { | ||
| m := make(map[int]struct{}, len(in)) | ||
| o := make([]paramRef, 0, len(in)) | ||
| for _, v := range in { | ||
| if _, ok := m[v.ref.Number]; !ok { | ||
| m[v.ref.Number] = struct{}{} | ||
| o = append(o, v) | ||
| } | ||
| } | ||
| return o | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I promise to remove this when Kolint support is merged :)