diff --git a/examples/authors/ydb/query.sql.go b/examples/authors/ydb/query.sql.go index 3233b705d3..e9b6b332a4 100644 --- a/examples/authors/ydb/query.sql.go +++ b/examples/authors/ydb/query.sql.go @@ -25,11 +25,15 @@ type CreateOrUpdateAuthorParams struct { func (q *Queries) CreateOrUpdateAuthor(ctx context.Context, arg CreateOrUpdateAuthorParams, opts ...query.ExecuteOption) error { err := q.db.Exec(ctx, createOrUpdateAuthor, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ - "$p0": arg.P0, - "$p1": arg.P1, - "$p2": arg.P2, - })))..., + append(opts, + query.WithParameters( + ydb.ParamsBuilder(). + Param("$p0").Uint64(arg.P0). + Param("$p1").Text(arg.P1). + Param("$p2").BeginOptional().Text(arg.P2).EndOptional(). + Build(), + ), + )..., ) if err != nil { return xerrors.WithStackTrace(err) @@ -43,9 +47,13 @@ DELETE FROM authors WHERE id = $p0 func (q *Queries) DeleteAuthor(ctx context.Context, p0 uint64, opts ...query.ExecuteOption) error { err := q.db.Exec(ctx, deleteAuthor, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ - "$p0": p0, - })))..., + append(opts, + query.WithParameters( + ydb.ParamsBuilder(). + Param("$p0").Uint64(p0). + Build(), + ), + )..., ) if err != nil { return xerrors.WithStackTrace(err) @@ -72,9 +80,13 @@ WHERE id = $p0 LIMIT 1 func (q *Queries) GetAuthor(ctx context.Context, p0 uint64, opts ...query.ExecuteOption) (Author, error) { row, err := q.db.QueryRow(ctx, getAuthor, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ - "$p0": p0, - })))..., + append(opts, + query.WithParameters( + ydb.ParamsBuilder(). + Param("$p0").Uint64(p0). + Build(), + ), + )..., ) var i Author if err != nil { @@ -92,25 +104,20 @@ SELECT id, name, bio FROM authors ORDER BY name ` func (q *Queries) ListAuthors(ctx context.Context, opts ...query.ExecuteOption) ([]Author, error) { - result, err := q.db.Query(ctx, listAuthors, opts...) + result, err := q.db.QueryResultSet(ctx, listAuthors, opts...) if err != nil { return nil, xerrors.WithStackTrace(err) } var items []Author - for set, err := range result.ResultSets(ctx) { + for row, err := range result.Rows(ctx) { if err != nil { return nil, xerrors.WithStackTrace(err) } - for row, err := range set.Rows(ctx) { - if err != nil { - return nil, xerrors.WithStackTrace(err) - } - var i Author - if err := row.Scan(&i.ID, &i.Name, &i.Bio); err != nil { - return nil, xerrors.WithStackTrace(err) - } - items = append(items, i) + var i Author + if err := row.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, xerrors.WithStackTrace(err) } + items = append(items, i) } if err := result.Close(ctx); err != nil { return nil, xerrors.WithStackTrace(err) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 02a09c3870..7cda1b7c2b 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts" + "github.com/sqlc-dev/sqlc/internal/codegen/sdk" "github.com/sqlc-dev/sqlc/internal/metadata" "github.com/sqlc-dev/sqlc/internal/plugin" ) @@ -294,6 +295,106 @@ func (v QueryValue) YDBParamMapEntries() string { return "\n" + strings.Join(parts, ",\n") } +// ydbBuilderMethodForColumnType maps a YDB column data type to a ParamsBuilder method name. +func ydbBuilderMethodForColumnType(dbType string) string { + switch strings.ToLower(dbType) { + case "bool": + return "Bool" + case "uint64": + return "Uint64" + case "int64": + return "Int64" + case "uint32": + return "Uint32" + case "int32": + return "Int32" + case "uint16": + return "Uint16" + case "int16": + return "Int16" + case "uint8": + return "Uint8" + case "int8": + return "Int8" + case "float": + return "Float" + case "double": + return "Double" + case "json": + return "JSON" + case "jsondocument": + return "JSONDocument" + case "utf8", "text", "string": + return "Text" + case "date": + return "Date" + case "date32": + return "Date32" + case "datetime": + return "Datetime" + case "timestamp": + return "Timestamp" + case "tzdate": + return "TzDate" + case "tzdatetime": + return "TzDatetime" + case "tztimestamp": + return "TzTimestamp" + + //TODO: support other types + default: + return "" + } +} + +// YDBParamsBuilder emits Go code that constructs YDB params using ParamsBuilder. +func (v QueryValue) YDBParamsBuilder() string { + if v.isEmpty() { + return "" + } + + var lines []string + + for _, field := range v.getParameterFields() { + if field.Column != nil && field.Column.IsNamedParam { + name := field.Column.GetName() + if name == "" { + continue + } + paramName := fmt.Sprintf("%q", addDollarPrefix(name)) + variable := escape(v.VariableForField(field)) + + var method string + if field.Column != nil && field.Column.Type != nil { + method = ydbBuilderMethodForColumnType(sdk.DataType(field.Column.Type)) + } + + goType := field.Type + isPtr := strings.HasPrefix(goType, "*") + if isPtr { + goType = strings.TrimPrefix(goType, "*") + } + + if method == "" { + panic(fmt.Sprintf("unknown YDB column type for param %s (goType=%s)", name, goType)) + } + + if isPtr { + lines = append(lines, fmt.Sprintf("\t\t\tParam(%s).BeginOptional().%s(%s).EndOptional().", paramName, method, variable)) + } else { + lines = append(lines, fmt.Sprintf("\t\t\tParam(%s).%s(%s).", paramName, method, variable)) + } + } + } + + if len(lines) == 0 { + return "" + } + + params := strings.Join(lines, "\n") + return fmt.Sprintf("\nquery.WithParameters(\n\t\tydb.ParamsBuilder().\n%s\n\t\t\tBuild(),\n\t\t),\n", params) +} + func (v QueryValue) getParameterFields() []Field { if v.Struct == nil { return []Field{ diff --git a/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl b/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl index ecd78b1344..c56fc953f8 100644 --- a/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl +++ b/internal/codegen/golang/templates/ydb-go-sdk/queryCode.tmpl @@ -27,8 +27,8 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA {{- if .Arg.IsEmpty }} row, err := {{$dbArg}}.QueryRow(ctx, {{.ConstantName}}, opts...) {{- else }} - row, err := {{$dbArg}}.QueryRow(ctx, {{.ConstantName}}, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ {{.Arg.YDBParamMapEntries}} })))..., + row, err := {{$dbArg}}.QueryRow(ctx, {{.ConstantName}}, + append(opts, {{.Arg.YDBParamsBuilder}})..., ) {{- end }} {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} @@ -61,10 +61,10 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBArgument}}db DBTX, {{end}}{{if not .Arg.IsEmpty}}{{.Arg.Pair}}, {{end}}opts ...query.ExecuteOption) ([]{{.Ret.DefineType}}, error) { {{- $dbArg := "q.db" }}{{- if $.EmitMethodsWithDBArgument }}{{- $dbArg = "db" }}{{- end -}} {{- if .Arg.IsEmpty }} - result, err := {{$dbArg}}.Query(ctx, {{.ConstantName}}, opts...) + result, err := {{$dbArg}}.QueryResultSet(ctx, {{.ConstantName}}, opts...) {{- else }} - result, err := {{$dbArg}}.Query(ctx, {{.ConstantName}}, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ {{.Arg.YDBParamMapEntries}} })))..., + result, err := {{$dbArg}}.QueryResultSet(ctx, {{.ConstantName}}, + append(opts, {{.Arg.YDBParamsBuilder}})..., ) {{- end }} if err != nil { @@ -79,7 +79,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA {{else}} var items []{{.Ret.DefineType}} {{end -}} - for set, err := range result.ResultSets(ctx) { + for row, err := range result.Rows(ctx) { if err != nil { {{- if $.WrapErrors}} return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) @@ -87,24 +87,15 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA return nil, xerrors.WithStackTrace(err) {{- end }} } - for row, err := range set.Rows(ctx) { - if err != nil { - {{- if $.WrapErrors}} - return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) - {{- else }} - return nil, xerrors.WithStackTrace(err) - {{- end }} - } - var {{.Ret.Name}} {{.Ret.Type}} - if err := row.Scan({{.Ret.Scan}}); err != nil { - {{- if $.WrapErrors}} - return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) - {{- else }} - return nil, xerrors.WithStackTrace(err) - {{- end }} - } - items = append(items, {{.Ret.ReturnName}}) - } + var {{.Ret.Name}} {{.Ret.Type}} + if err := row.Scan({{.Ret.Scan}}); err != nil { + {{- if $.WrapErrors}} + return nil, xerrors.WithStackTrace(fmt.Errorf("query {{.MethodName}}: %w", err)) + {{- else }} + return nil, xerrors.WithStackTrace(err) + {{- end }} + } + items = append(items, {{.Ret.ReturnName}}) } if err := result.Close(ctx); err != nil { {{- if $.WrapErrors}} @@ -125,8 +116,8 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{if $.EmitMethodsWithDBA {{- if .Arg.IsEmpty }} err := {{$dbArg}}.Exec(ctx, {{.ConstantName}}, opts...) {{- else }} - err := {{$dbArg}}.Exec(ctx, {{.ConstantName}}, - append(opts, query.WithParameters(ydb.ParamsFromMap(map[string]any{ {{.Arg.YDBParamMapEntries}} })))..., + err := {{$dbArg}}.Exec(ctx, {{.ConstantName}}, + append(opts, {{.Arg.YDBParamsBuilder}})..., ) {{- end }} if err != nil { diff --git a/internal/codegen/golang/ydb_type.go b/internal/codegen/golang/ydb_type.go index e9e5c46344..0ef665aee1 100644 --- a/internal/codegen/golang/ydb_type.go +++ b/internal/codegen/golang/ydb_type.go @@ -151,6 +151,24 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col // return "sql.NullInt64" return "*int64" + case "json", "jsondocument": + if notNull { + return "string" + } + if emitPointersForNull { + return "*string" + } + return "*string" + + case "date", "date32", "datetime", "timestamp", "tzdate", "tztimestamp", "tzdatetime": + if notNull { + return "time.Time" + } + if emitPointersForNull { + return "*time.Time" + } + return "*time.Time" + case "null": // return "sql.Null" return "interface{}" diff --git a/internal/engine/ydb/catalog_tests/create_table_test.go b/internal/engine/ydb/catalog_tests/create_table_test.go index e98288d75a..7761118927 100644 --- a/internal/engine/ydb/catalog_tests/create_table_test.go +++ b/internal/engine/ydb/catalog_tests/create_table_test.go @@ -106,7 +106,7 @@ func TestCreateTable(t *testing.T) { { Name: "amount", Type: ast.TypeName{ - Name: "Decimal", + Name: "decimal", Names: &ast.List{ Items: []ast.Node{ &ast.Integer{Ival: 22}, diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index b4d9490d0b..905c972375 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -1278,7 +1278,7 @@ func (c *cc) convertTypeName(n parser.IType_nameContext) *ast.TypeName { if decimal := n.Type_name_decimal(); decimal != nil { if integerOrBinds := decimal.AllInteger_or_bind(); len(integerOrBinds) >= 2 { return &ast.TypeName{ - Name: "Decimal", + Name: "decimal", TypeOid: 0, Names: &ast.List{ Items: []ast.Node{