diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index a5a12a7c8d..3d60cedde8 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -574,6 +574,31 @@ func outputColumns(c core.Catalog, node nodes.Node) ([]core.Column, error) { cols = append(cols, core.Column{Name: name, DataType: "any", NotNull: false}) } + case nodes.CaseExpr: + name := "" + if res.Name != nil { + name = *res.Name + } + // TODO: The TypeCase code has been copied from below. Instead, we need a recurse function to get the type of a node. + if tc, ok := n.Defresult.(nodes.TypeCast); ok { + if tc.TypeName == nil { + return nil, errors.New("no type name type cast") + } + name := "" + if ref, ok := tc.Arg.(nodes.ColumnRef); ok { + name = join(ref.Fields, "_") + } + if res.Name != nil { + name = *res.Name + } + // TODO Validate column names + col := catalog.ToColumn(tc.TypeName) + col.Name = name + cols = append(cols, col) + } else { + cols = append(cols, core.Column{Name: name, DataType: "any", NotNull: false}) + } + case nodes.CoalesceExpr: for _, arg := range n.Args.Items { if ref, ok := arg.(nodes.ColumnRef); ok { @@ -652,6 +677,14 @@ func outputColumns(c core.Catalog, node nodes.Node) ([]core.Column, error) { col := catalog.ToColumn(n.TypeName) col.Name = name cols = append(cols, col) + + default: + name := "" + if res.Name != nil { + name = *res.Name + } + cols = append(cols, core.Column{Name: name, DataType: "any", NotNull: false}) + } } return cols, nil diff --git a/internal/dinosql/query_test.go b/internal/dinosql/query_test.go index 06bdd63534..d90c60d40f 100644 --- a/internal/dinosql/query_test.go +++ b/internal/dinosql/query_test.go @@ -703,6 +703,25 @@ func TestQueries(t *testing.T) { }, }, }, + { + "case-stmt-bool", + ` + CREATE TABLE foo (id text not null); + SELECT CASE + WHEN id = $1 THEN true + ELSE false + END is_one + FROM foo; + `, + Query{ + Params: []Parameter{ + {1, core.Column{Table: public("foo"), Name: "id", DataType: "text", NotNull: true}}, + }, + Columns: []core.Column{ + {Name: "is_one", DataType: "pg_catalog.bool", NotNull: true}, + }, + }, + }, } { test := tc t.Run(test.name, func(t *testing.T) {