diff --git a/internal/codegen/golang/mysql_type.go b/internal/codegen/golang/mysql_type.go index 15c5838e3a..4a56d64f23 100644 --- a/internal/codegen/golang/mysql_type.go +++ b/internal/codegen/golang/mysql_type.go @@ -21,7 +21,20 @@ func mysqlType(r *compiler.Result, col *compiler.Column, settings config.Combine } return "sql.NullString" - case "int", "integer", "tinyint", "smallint", "mediumint", "year": + case "tinyint": + if col.Length != nil && *col.Length == 1 { + if notNull { + return "bool" + } + return "sql.NullBool" + } else { + if notNull { + return "int32" + } + return "sql.NullInt32" + } + + case "int", "integer", "smallint", "mediumint", "year": if notNull { return "int32" } diff --git a/internal/compiler/query.go b/internal/compiler/query.go index 768c67c7b0..3c6ebfc266 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -1,6 +1,8 @@ package compiler -import "github.com/kyleconroy/sqlc/internal/sql/ast" +import ( + "github.com/kyleconroy/sqlc/internal/sql/ast" +) type Table struct { Rel *ast.TableName @@ -13,6 +15,7 @@ type Column struct { NotNull bool IsArray bool Comment string + Length *int // XXX: Figure out what PostgreSQL calls `foo.id` Scope string diff --git a/internal/compiler/query_catalog.go b/internal/compiler/query_catalog.go index 5071f17662..ea477208c2 100644 --- a/internal/compiler/query_catalog.go +++ b/internal/compiler/query_catalog.go @@ -52,6 +52,7 @@ func ConvertColumn(rel *ast.TableName, c *catalog.Column) *Column { NotNull: c.IsNotNull, IsArray: c.IsArray, Type: &c.Type, + Length: c.Length, } } diff --git a/internal/endtoend/testdata/data_type_boolean/mysql/db/db.go b/internal/endtoend/testdata/data_type_boolean/mysql/db/db.go new file mode 100644 index 0000000000..c3c034ae37 --- /dev/null +++ b/internal/endtoend/testdata/data_type_boolean/mysql/db/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/data_type_boolean/mysql/db/models.go b/internal/endtoend/testdata/data_type_boolean/mysql/db/models.go new file mode 100644 index 0000000000..84236c7d59 --- /dev/null +++ b/internal/endtoend/testdata/data_type_boolean/mysql/db/models.go @@ -0,0 +1,19 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import ( + "database/sql" +) + +type Bar struct { + ColA sql.NullBool + ColB sql.NullBool + ColC sql.NullBool +} + +type Foo struct { + ColA bool + ColB bool + ColC bool +} diff --git a/internal/endtoend/testdata/data_type_boolean/mysql/db/query.sql.go b/internal/endtoend/testdata/data_type_boolean/mysql/db/query.sql.go new file mode 100644 index 0000000000..e1a4be881c --- /dev/null +++ b/internal/endtoend/testdata/data_type_boolean/mysql/db/query.sql.go @@ -0,0 +1,75 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package db + +import ( + "context" + "database/sql" +) + +const listBar = `-- name: ListBar :many +SELECT col_a, col_b, col_c FROM bar +` + +type ListBarRow struct { + ColA sql.NullInt32 + ColB sql.NullInt32 + ColC sql.NullInt32 +} + +func (q *Queries) ListBar(ctx context.Context) ([]ListBarRow, error) { + rows, err := q.db.QueryContext(ctx, listBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListBarRow + for rows.Next() { + var i ListBarRow + if err := rows.Scan(&i.ColA, &i.ColB, &i.ColC); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listFoo = `-- name: ListFoo :many +SELECT col_a, col_b, col_c FROM foo +` + +type ListFooRow struct { + ColA int32 + ColB int32 + ColC int32 +} + +func (q *Queries) ListFoo(ctx context.Context) ([]ListFooRow, error) { + rows, err := q.db.QueryContext(ctx, listFoo) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListFooRow + for rows.Next() { + var i ListFooRow + if err := rows.Scan(&i.ColA, &i.ColB, &i.ColC); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/data_type_boolean/mysql/query.sql b/internal/endtoend/testdata/data_type_boolean/mysql/query.sql new file mode 100644 index 0000000000..210488ceac --- /dev/null +++ b/internal/endtoend/testdata/data_type_boolean/mysql/query.sql @@ -0,0 +1,19 @@ +CREATE TABLE foo +( + col_a BOOL NOT NULL, + col_b BOOLEAN NOT NULL, + col_c TINYINT(1) NOT NULL +); + +-- name: ListFoo :many +SELECT * FROM foo; + +CREATE TABLE bar +( + col_a BOOL, + col_b BOOLEAN, + col_c TINYINT(1) +); + +-- name: ListBar :many +SELECT * FROM bar; diff --git a/internal/endtoend/testdata/data_type_boolean/mysql/sqlc.json b/internal/endtoend/testdata/data_type_boolean/mysql/sqlc.json new file mode 100644 index 0000000000..a279d21f99 --- /dev/null +++ b/internal/endtoend/testdata/data_type_boolean/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "db", + "engine": "mysql", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} + diff --git a/internal/endtoend/testdata/data_type_boolean/postgresql/db/db.go b/internal/endtoend/testdata/data_type_boolean/postgresql/db/db.go new file mode 100644 index 0000000000..c3c034ae37 --- /dev/null +++ b/internal/endtoend/testdata/data_type_boolean/postgresql/db/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/data_type_boolean/postgresql/db/models.go b/internal/endtoend/testdata/data_type_boolean/postgresql/db/models.go new file mode 100644 index 0000000000..49607fd39f --- /dev/null +++ b/internal/endtoend/testdata/data_type_boolean/postgresql/db/models.go @@ -0,0 +1,17 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import ( + "database/sql" +) + +type Bar struct { + ColA sql.NullBool + ColB sql.NullBool +} + +type Foo struct { + ColA bool + ColB bool +} diff --git a/internal/endtoend/testdata/data_type_boolean/postgresql/db/query.sql.go b/internal/endtoend/testdata/data_type_boolean/postgresql/db/query.sql.go new file mode 100644 index 0000000000..61aa8aab72 --- /dev/null +++ b/internal/endtoend/testdata/data_type_boolean/postgresql/db/query.sql.go @@ -0,0 +1,62 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package db + +import ( + "context" +) + +const listBar = `-- name: ListBar :many +SELECT col_a, col_b FROM bar +` + +func (q *Queries) ListBar(ctx context.Context) ([]Bar, error) { + rows, err := q.db.QueryContext(ctx, listBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Bar + for rows.Next() { + var i Bar + if err := rows.Scan(&i.ColA, &i.ColB); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listFoo = `-- name: ListFoo :many +SELECT col_a, col_b FROM foo +` + +func (q *Queries) ListFoo(ctx context.Context) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, listFoo) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.ColA, &i.ColB); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/data_type_boolean/postgresql/query.sql b/internal/endtoend/testdata/data_type_boolean/postgresql/query.sql new file mode 100644 index 0000000000..53ad078373 --- /dev/null +++ b/internal/endtoend/testdata/data_type_boolean/postgresql/query.sql @@ -0,0 +1,17 @@ +CREATE TABLE foo +( + col_a BOOL NOT NULL, + col_b BOOLEAN NOT NULL +); + +-- name: ListFoo :many +SELECT * FROM foo; + +CREATE TABLE bar +( + col_a BOOL, + col_b BOOLEAN +); + +-- name: ListBar :many +SELECT * FROM bar; diff --git a/internal/endtoend/testdata/data_type_boolean/postgresql/sqlc.json b/internal/endtoend/testdata/data_type_boolean/postgresql/sqlc.json new file mode 100644 index 0000000000..5fb2f8b6de --- /dev/null +++ b/internal/endtoend/testdata/data_type_boolean/postgresql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "db", + "engine": "postgresql", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} + diff --git a/internal/endtoend/testdata/insert_select/mysql/go/models.go b/internal/endtoend/testdata/insert_select/mysql/go/models.go index d216fceb7e..b9f6dc7049 100644 --- a/internal/endtoend/testdata/insert_select/mysql/go/models.go +++ b/internal/endtoend/testdata/insert_select/mysql/go/models.go @@ -6,7 +6,7 @@ import () type Bar struct { Name string - Ready int32 + Ready bool } type Foo struct { diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index df857c4691..1859a495fe 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -35,14 +35,19 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { case pcast.AlterTableAddColumns: for _, def := range spec.NewColumns { name := def.Name.String() + columnDef := ast.ColumnDef{ + Colname: def.Name.String(), + TypeName: &ast.TypeName{Name: types.TypeStr(def.Tp.Tp)}, + IsNotNull: isNotNull(def), + } + if def.Tp.Flen >= 0 { + length := def.Tp.Flen + columnDef.Length = &length + } alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_AddColumn, - Def: &ast.ColumnDef{ - Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeStr(def.Tp.Tp)}, - IsNotNull: isNotNull(def), - }, + Def: &columnDef, }) } @@ -60,6 +65,15 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { case pcast.AlterTableModifyColumn: for _, def := range spec.NewColumns { name := def.Name.String() + columnDef := ast.ColumnDef{ + Colname: def.Name.String(), + TypeName: &ast.TypeName{Name: types.TypeStr(def.Tp.Tp)}, + IsNotNull: isNotNull(def), + } + if def.Tp.Flen >= 0 { + length := def.Tp.Flen + columnDef.Length = &length + } alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_DropColumn, @@ -67,11 +81,7 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_AddColumn, - Def: &ast.ColumnDef{ - Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeStr(def.Tp.Tp)}, - IsNotNull: isNotNull(def), - }, + Def: &columnDef, }) } @@ -224,13 +234,18 @@ func (c *cc) convertCreateTableStmt(n *pcast.CreateTableStmt) ast.Node { } } } - create.Cols = append(create.Cols, &ast.ColumnDef{ + columnDef := ast.ColumnDef{ Colname: def.Name.String(), TypeName: &ast.TypeName{Name: types.TypeStr(def.Tp.Tp)}, IsNotNull: isNotNull(def), Comment: comment, Vals: vals, - }) + } + if def.Tp.Flen >= 0 { + length := def.Tp.Flen + columnDef.Length = &length + } + create.Cols = append(create.Cols, &columnDef) } for _, opt := range n.Options { switch opt.Tp { diff --git a/internal/sql/ast/column_def.go b/internal/sql/ast/column_def.go index d52d8002f3..279f97bf48 100644 --- a/internal/sql/ast/column_def.go +++ b/internal/sql/ast/column_def.go @@ -6,6 +6,7 @@ type ColumnDef struct { IsNotNull bool IsArray bool Vals *List + Length *int // From pg.ColumnDef Inhcount int diff --git a/internal/sql/catalog/catalog.go b/internal/sql/catalog/catalog.go index 94312c4877..2263ec5b6e 100644 --- a/internal/sql/catalog/catalog.go +++ b/internal/sql/catalog/catalog.go @@ -193,6 +193,7 @@ type Column struct { IsNotNull bool IsArray bool Comment string + Length *int } type Type interface { diff --git a/internal/sql/catalog/table.go b/internal/sql/catalog/table.go index a2fc0b5da8..a91404c8d8 100644 --- a/internal/sql/catalog/table.go +++ b/internal/sql/catalog/table.go @@ -74,6 +74,7 @@ func (c *Catalog) alterTable(stmt *ast.AlterTableStmt) error { Type: *cmd.Def.TypeName, IsNotNull: cmd.Def.IsNotNull, IsArray: cmd.Def.IsArray, + Length: cmd.Def.Length, }) case ast.AT_AlterColumnType: @@ -160,6 +161,7 @@ func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error { IsNotNull: col.IsNotNull, IsArray: col.IsArray, Comment: col.Comment, + Length: col.Length, } if col.Vals != nil { typeName := ast.TypeName{