diff --git a/internal/engine/ydb/catalog_tests/delete_test.go b/internal/engine/ydb/catalog_tests/delete_test.go index ab7b709be9..1885deb9ce 100644 --- a/internal/engine/ydb/catalog_tests/delete_test.go +++ b/internal/engine/ydb/catalog_tests/delete_test.go @@ -101,6 +101,7 @@ func TestDelete(t *testing.T) { }, }, OnSelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, ValuesLists: &ast.List{ Items: []ast.Node{ &ast.List{ @@ -110,8 +111,12 @@ func TestDelete(t *testing.T) { }, }, }, - FromClause: &ast.List{}, - TargetList: &ast.List{}, + FromClause: &ast.List{}, + TargetList: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, ReturningList: &ast.List{ Items: []ast.Node{ @@ -145,6 +150,7 @@ func TestDelete(t *testing.T) { }, }, OnSelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -153,7 +159,12 @@ func TestDelete(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, ReturningList: &ast.List{ Items: []ast.Node{ diff --git a/internal/engine/ydb/catalog_tests/insert_test.go b/internal/engine/ydb/catalog_tests/insert_test.go index 4dea2ceccb..c60d0920da 100644 --- a/internal/engine/ydb/catalog_tests/insert_test.go +++ b/internal/engine/ydb/catalog_tests/insert_test.go @@ -28,6 +28,7 @@ func TestInsert(t *testing.T) { }, }, SelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, ValuesLists: &ast.List{ Items: []ast.Node{ &ast.List{ @@ -40,6 +41,10 @@ func TestInsert(t *testing.T) { }, TargetList: &ast.List{}, FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, OnConflictClause: &ast.OnConflictClause{}, ReturningList: &ast.List{ @@ -68,6 +73,7 @@ func TestInsert(t *testing.T) { }, }, SelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, ValuesLists: &ast.List{ Items: []ast.Node{ &ast.List{ @@ -79,6 +85,10 @@ func TestInsert(t *testing.T) { }, TargetList: &ast.List{}, FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, OnConflictClause: &ast.OnConflictClause{ Action: ast.OnConflictAction_INSERT_OR_IGNORE, @@ -106,7 +116,16 @@ func TestInsert(t *testing.T) { Stmt: &ast.InsertStmt{ Relation: &ast.RangeVar{Relname: strPtr("users")}, Cols: &ast.List{Items: []ast.Node{&ast.ResTarget{Name: strPtr("id")}}}, - SelectStmt: &ast.SelectStmt{ValuesLists: &ast.List{Items: []ast.Node{&ast.List{Items: []ast.Node{&ast.A_Const{Val: &ast.Integer{Ival: 4}}}}}}, TargetList: &ast.List{}, FromClause: &ast.List{}}, + SelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + ValuesLists: &ast.List{Items: []ast.Node{&ast.List{Items: []ast.Node{&ast.A_Const{Val: &ast.Integer{Ival: 4}}}}}}, + TargetList: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + }, OnConflictClause: &ast.OnConflictClause{Action: ast.OnConflictAction_UPSERT}, ReturningList: &ast.List{Items: []ast.Node{&ast.ResTarget{Val: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}}}, Indirection: &ast.List{}}}}, }, diff --git a/internal/engine/ydb/catalog_tests/select_test.go b/internal/engine/ydb/catalog_tests/select_test.go index fa7b22677c..f01171f12a 100644 --- a/internal/engine/ydb/catalog_tests/select_test.go +++ b/internal/engine/ydb/catalog_tests/select_test.go @@ -25,6 +25,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -34,7 +35,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -44,6 +50,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -53,7 +60,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -63,6 +75,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -72,7 +85,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -82,6 +100,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -91,7 +110,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -101,6 +125,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -108,7 +133,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -118,6 +148,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -127,7 +158,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -137,6 +173,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -166,7 +203,12 @@ func TestSelect(t *testing.T) { }, }, }, - FromClause: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -178,6 +220,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -198,6 +241,11 @@ func TestSelect(t *testing.T) { }, }, }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -207,6 +255,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -228,6 +277,11 @@ func TestSelect(t *testing.T) { }, }, }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -237,6 +291,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -259,6 +314,11 @@ func TestSelect(t *testing.T) { }, }, }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -268,6 +328,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -307,6 +368,11 @@ func TestSelect(t *testing.T) { }, }, }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -316,6 +382,7 @@ func TestSelect(t *testing.T) { expected: &ast.Statement{ Raw: &ast.RawStmt{ Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, TargetList: &ast.List{ Items: []ast.Node{ &ast.ResTarget{ @@ -362,6 +429,287 @@ func TestSelect(t *testing.T) { Val: &ast.Integer{Ival: 30}, }, }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + }, + }, + }, + }, + { + stmt: `(SELECT 1) UNION ALL (SELECT 2)`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + Op: ast.Union, + All: true, + Larg: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.A_Const{ + Val: &ast.Integer{Ival: 1}, + }, + }, + }, + }, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + }, + Rarg: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.A_Const{ + Val: &ast.Integer{Ival: 2}, + }, + }, + }, + }, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + }, + }, + }, + }, + }, + { + stmt: `SELECT id FROM users ORDER BY id DESC`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + LockingClause: &ast.List{}, + SortClause: &ast.List{ + Items: []ast.Node{ + &ast.SortBy{ + Node: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + SortbyDir: ast.SortByDirDesc, + UseOp: &ast.List{}, + }, + }, + }, + }, + }, + }, + }, + { + stmt: `SELECT id FROM users LIMIT 10 OFFSET 5`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + LimitCount: &ast.A_Const{ + Val: &ast.Integer{Ival: 10}, + }, + LimitOffset: &ast.A_Const{ + Val: &ast.Integer{Ival: 5}, + }, + }, + }, + }, + }, + { + stmt: `SELECT id FROM users WHERE id > 10 GROUP BY id HAVING id > 10`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + GroupClause: &ast.List{ + Items: []ast.Node{ + &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + WhereClause: &ast.A_Expr{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: ">"}, + }, + }, + Lexpr: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + Rexpr: &ast.A_Const{ + Val: &ast.Integer{Ival: 10}, + }, + }, + HavingClause: &ast.A_Expr{ + Name: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: ">"}, + }, + }, + Lexpr: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + Rexpr: &ast.A_Const{ + Val: &ast.Integer{Ival: 10}, + }, + }, + }, + }, + }, + }, + { + stmt: `SELECT id FROM users GROUP BY ROLLUP (id)`, + expected: &ast.Statement{ + Raw: &ast.RawStmt{ + Stmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: strPtr("users"), + }, + }, + }, + GroupClause: &ast.List{ + Items: []ast.Node{ + &ast.GroupingSet{ + Kind: 1, // T_GroupingSet: ROLLUP + Content: &ast.List{ + Items: []ast.Node{ + &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &ast.String{Str: "id"}, + }, + }, + }, + }, + }, + }, + }, + }, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, }, }, @@ -382,12 +730,12 @@ func TestSelect(t *testing.T) { diff := cmp.Diff(tc.expected, &stmts[0], cmpopts.IgnoreFields(ast.RawStmt{}, "StmtLocation", "StmtLen"), - // cmpopts.IgnoreFields(ast.SelectStmt{}, "Location"), cmpopts.IgnoreFields(ast.A_Const{}, "Location"), cmpopts.IgnoreFields(ast.ResTarget{}, "Location"), cmpopts.IgnoreFields(ast.ColumnRef{}, "Location"), cmpopts.IgnoreFields(ast.A_Expr{}, "Location"), cmpopts.IgnoreFields(ast.RangeVar{}, "Location"), + cmpopts.IgnoreFields(ast.SortBy{}, "Location"), ) if diff != "" { t.Errorf("AST mismatch for %q (-expected +got):\n%s", tc.stmt, diff) diff --git a/internal/engine/ydb/catalog_tests/update_test.go b/internal/engine/ydb/catalog_tests/update_test.go index f4f00a92bc..b7ebeb3d6a 100644 --- a/internal/engine/ydb/catalog_tests/update_test.go +++ b/internal/engine/ydb/catalog_tests/update_test.go @@ -121,6 +121,7 @@ func TestUpdate(t *testing.T) { }, }, OnSelectStmt: &ast.SelectStmt{ + DistinctClause: &ast.List{}, ValuesLists: &ast.List{ Items: []ast.Node{ &ast.List{ @@ -132,6 +133,10 @@ func TestUpdate(t *testing.T) { }, FromClause: &ast.List{}, TargetList: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, }, ReturningList: &ast.List{ Items: []ast.Node{ diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index b4d9490d0b..250dc467e0 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -54,7 +54,7 @@ func NewIdentifier(t string) *ast.String { return &ast.String{Str: identifier(t)} } -func (c *cc) convertDrop_role_stmtCOntext(n *parser.Drop_role_stmtContext) ast.Node { +func (c *cc) convertDrop_role_stmtContext(n *parser.Drop_role_stmtContext) ast.Node { if n.DROP() == nil || (n.USER() == nil && n.GROUP() == nil) || len(n.AllRole_name()) == 0 { return todo("Drop_role_stmtContext", n) } @@ -467,6 +467,145 @@ func (c *cc) convertRollback_stmtContext(n *parser.Rollback_stmtContext) ast.Nod return todo("convertRollback_stmtContext", n) } +func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) ast.Node { + if n.ALTER() == nil || n.TABLE() == nil || n.Simple_table_ref() == nil || len(n.AllAlter_table_action()) == 0 { + return todo("convertAlter_table_stmtContext", n) + } + + stmt := &ast.AlterTableStmt{ + Table: parseTableName(n.Simple_table_ref().Simple_table_ref_core()), + Cmds: &ast.List{}, + } + + for _, action := range n.AllAlter_table_action() { + if action == nil { + continue + } + + switch { + case action.Alter_table_add_column() != nil: + ac := action.Alter_table_add_column() + if ac.ADD() != nil && ac.Column_schema() != nil { + columnDef := c.convertColumnSchema(ac.Column_schema().(*parser.Column_schemaContext)) + stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{ + Name: &columnDef.Colname, + Subtype: ast.AT_AddColumn, + Def: columnDef, + }) + } + case action.Alter_table_drop_column() != nil: + ac := action.Alter_table_drop_column() + if ac.DROP() != nil && ac.An_id() != nil { + name := parseAnId(ac.An_id()) + stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{ + Name: &name, + Subtype: ast.AT_DropColumn, + }) + } + case action.Alter_table_alter_column_drop_not_null() != nil: + ac := action.Alter_table_alter_column_drop_not_null() + if ac.DROP() != nil && ac.NOT() != nil && ac.NULL() != nil && ac.An_id() != nil { + name := parseAnId(ac.An_id()) + stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{ + Name: &name, + Subtype: ast.AT_DropNotNull, + }) + } + case action.Alter_table_rename_to() != nil: + ac := action.Alter_table_rename_to() + if ac.RENAME() != nil && ac.TO() != nil && ac.An_id_table() != nil { + // TODO: Returning here may be incorrect if there are multiple specs + newName := parseAnIdTable(ac.An_id_table()) + return &ast.RenameTableStmt{ + Table: stmt.Table, + NewName: &newName, + } + } + case action.Alter_table_add_index() != nil, + action.Alter_table_drop_index() != nil, + action.Alter_table_add_column_family() != nil, + action.Alter_table_alter_column_family() != nil, + action.Alter_table_set_table_setting_uncompat() != nil, + action.Alter_table_set_table_setting_compat() != nil, + action.Alter_table_reset_table_setting() != nil, + action.Alter_table_add_changefeed() != nil, + action.Alter_table_alter_changefeed() != nil, + action.Alter_table_drop_changefeed() != nil, + action.Alter_table_rename_index_to() != nil, + action.Alter_table_alter_index() != nil: + // All these actions do not change column schema relevant to sqlc; no-op. + // Intentionally ignored. + } + } + + return stmt +} + +func (c *cc) convertDo_stmtContext(n *parser.Do_stmtContext) ast.Node { + if n.DO() == nil || (n.Call_action() == nil && n.Inline_action() == nil) { + return todo("convertDo_stmtContext", n) + } + + switch { + case n.Call_action() != nil: + return c.convert(n.Call_action()) + + case n.Inline_action() != nil: + return c.convert(n.Inline_action()) + } + + return todo("convertDo_stmtContext", n) +} + +func (c *cc) convertCall_actionContext(n *parser.Call_actionContext) ast.Node { + if n == nil { + return nil + } + if n.LPAREN() != nil && n.RPAREN() != nil { + funcCall := &ast.FuncCall{ + Funcname: &ast.List{}, + Args: &ast.List{}, + AggOrder: &ast.List{}, + } + + if n.Bind_parameter() != nil { + funcCall.Funcname.Items = append(funcCall.Funcname.Items, c.convert(n.Bind_parameter())) + } else if n.EMPTY_ACTION() != nil { + funcCall.Funcname.Items = append(funcCall.Funcname.Items, &ast.String{Str: "EMPTY_ACTION"}) + } + + if n.Expr_list() != nil { + for _, expr := range n.Expr_list().AllExpr() { + funcCall.Args.Items = append(funcCall.Args.Items, c.convert(expr)) + } + } + + return &ast.DoStmt{ + Args: &ast.List{Items: []ast.Node{funcCall}}, + } + } + return todo("convertCall_actionContext", n) +} + +func (c *cc) convertInline_actionContext(n *parser.Inline_actionContext) ast.Node { + if n == nil { + return nil + } + if n.BEGIN() != nil && n.END() != nil && n.DO() != nil { + args := &ast.List{} + if defineBody := n.Define_action_or_subquery_body(); defineBody != nil { + cores := defineBody.AllSql_stmt_core() + for _, stmtCore := range cores { + if converted := c.convert(stmtCore); converted != nil { + args.Items = append(args.Items, converted) + } + } + } + return &ast.DoStmt{Args: args} + } + return todo("convertInline_actionContext", n) +} + func (c *cc) convertDrop_table_stmtContext(n *parser.Drop_table_stmtContext) ast.Node { if n.DROP() != nil && (n.TABLESTORE() != nil || (n.EXTERNAL() != nil && n.TABLE() != nil) || n.TABLE() != nil) { name := parseTableName(n.Simple_table_ref().Simple_table_ref_core()) @@ -509,12 +648,9 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { if valSource != nil { switch { case valSource.Values_stmt() != nil: - source = &ast.SelectStmt{ - ValuesLists: c.convert(valSource.Values_stmt()).(*ast.List), - FromClause: &ast.List{}, - TargetList: &ast.List{}, - } - + stmt := emptySelectStmt() + stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + source = stmt case valSource.Select_stmt() != nil: source = c.convert(valSource.Select_stmt()) } @@ -704,12 +840,9 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { if valSource != nil { switch { case valSource.Values_stmt() != nil: - source = &ast.SelectStmt{ - ValuesLists: c.convert(valSource.Values_stmt()).(*ast.List), - FromClause: &ast.List{}, - TargetList: &ast.List{}, - } - + stmt := emptySelectStmt() + stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + source = stmt case valSource.Select_stmt() != nil: source = c.convert(valSource.Select_stmt()) } @@ -777,12 +910,9 @@ func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast if valSource != nil { switch { case valSource.Values_stmt() != nil: - source = &ast.SelectStmt{ - ValuesLists: c.convert(valSource.Values_stmt()).(*ast.List), - FromClause: &ast.List{}, - TargetList: &ast.List{}, - } - + stmt := emptySelectStmt() + stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + source = stmt case valSource.Select_stmt() != nil: source = c.convert(valSource.Select_stmt()) } @@ -859,60 +989,106 @@ func (c *cc) convertReturning_columns_listContext(n *parser.Returning_columns_li } func (c *cc) convertSelectStmtContext(n *parser.Select_stmtContext) ast.Node { + if len(n.AllSelect_kind_parenthesis()) == 0 { + return todo("convertSelectStmtContext", n) + } + skp := n.Select_kind_parenthesis(0) if skp == nil { - return nil + return todo("convertSelectStmtContext", skp) } - partial := skp.Select_kind_partial() - if partial == nil { - return nil + + stmt := c.convertSelectKindParenthesis(skp) + left, ok := stmt.(*ast.SelectStmt) + if left == nil || !ok { + return todo("convertSelectKindParenthesis", skp) } + + kinds := n.AllSelect_kind_parenthesis() + ops := n.AllSelect_op() + + for i := 1; i < len(kinds); i++ { + stmt := c.convertSelectKindParenthesis(kinds[i]) + right, ok := stmt.(*ast.SelectStmt) + if right == nil || !ok { + return todo("convertSelectKindParenthesis", kinds[i]) + } + + var op ast.SetOperation + var all bool + if i-1 < len(ops) && ops[i-1] != nil { + so := ops[i-1] + switch { + case so.UNION() != nil: + op = ast.Union + case so.INTERSECT() != nil: + log.Fatalf("YDB: INTERSECT is not implemented yet") + case so.EXCEPT() != nil: + log.Fatalf("YDB: EXCEPT is not implemented yet") + default: + op = ast.None + } + all = so.ALL() != nil + } + larg := left + left = emptySelectStmt() + left.Op = op + left.All = all + left.Larg = larg + left.Rarg = right + } + + return left +} + +func (c *cc) convertSelectKindParenthesis(n parser.ISelect_kind_parenthesisContext) ast.Node { + if n == nil || n.Select_kind_partial() == nil { + return todo("convertSelectKindParenthesis", n) + } + partial := n.Select_kind_partial() + sk := partial.Select_kind() if sk == nil { - return nil + return todo("convertSelectKind", sk) } - selectStmt := &ast.SelectStmt{} + var base ast.Node switch { - case sk.Process_core() != nil: - cnode := c.convert(sk.Process_core()) - stmt, ok := cnode.(*ast.SelectStmt) - if !ok { - return nil - } - selectStmt = stmt case sk.Select_core() != nil: - cnode := c.convert(sk.Select_core()) - stmt, ok := cnode.(*ast.SelectStmt) - if !ok { - return nil - } - selectStmt = stmt + base = c.convertSelectCoreContext(sk.Select_core()) + case sk.Process_core() != nil: + log.Fatalf("PROCESS is not supported in YDB engine") case sk.Reduce_core() != nil: - cnode := c.convert(sk.Reduce_core()) - stmt, ok := cnode.(*ast.SelectStmt) - if !ok { - return nil - } - selectStmt = stmt + log.Fatalf("REDUCE is not supported in YDB engine") + } + stmt, ok := base.(*ast.SelectStmt) + if !ok || stmt == nil { + return todo("convertSelectKindParenthesis", sk.Select_core()) } - // todo: cover process and reduce core, - // todo: cover LIMIT and OFFSET + // TODO: handle INTO RESULT clause - return selectStmt + if partial.LIMIT() != nil { + exprs := partial.AllExpr() + if len(exprs) >= 1 { + stmt.LimitCount = c.convert(exprs[0]) + } + if partial.OFFSET() != nil { + if len(exprs) >= 2 { + stmt.LimitOffset = c.convert(exprs[1]) + } + } + } + + return stmt } -func (c *cc) convertSelectCoreContext(n *parser.Select_coreContext) ast.Node { - stmt := &ast.SelectStmt{ - TargetList: &ast.List{}, - FromClause: &ast.List{}, - } +func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { + stmt := emptySelectStmt() if n.Opt_set_quantifier() != nil { oq := n.Opt_set_quantifier() if oq.DISTINCT() != nil { - // todo: add distinct support - stmt.DistinctClause = &ast.List{} + stmt.DistinctClause.Items = append(stmt.DistinctClause.Items, &ast.TODO{}) // trick to handle distinct } } resultCols := n.AllResult_column() @@ -932,8 +1108,14 @@ func (c *cc) convertSelectCoreContext(n *parser.Select_coreContext) ast.Node { Items: items, } } + + // TODO: handle WITHOUT clause + jsList := n.AllJoin_source() - if len(n.AllFROM()) > 0 && len(jsList) > 0 { + if len(n.AllFROM()) > 1 { + log.Fatalf("YDB: Only one FROM clause is allowed") + } + if len(jsList) > 0 { var fromItems []ast.Node for _, js := range jsList { jsCon, ok := js.(*parser.Join_sourceContext) @@ -950,15 +1132,137 @@ func (c *cc) convertSelectCoreContext(n *parser.Select_coreContext) ast.Node { Items: fromItems, } } + + exprIdx := 0 if n.WHERE() != nil { - whereCtx := n.Expr(0) - if whereCtx != nil { + if whereCtx := n.Expr(exprIdx); whereCtx != nil { stmt.WhereClause = c.convert(whereCtx) } + exprIdx++ + } + if n.HAVING() != nil { + if havingCtx := n.Expr(exprIdx); havingCtx != nil { + stmt.HavingClause = c.convert(havingCtx) + } + exprIdx++ + } + + if gbc := n.Group_by_clause(); gbc != nil { + if gel := gbc.Grouping_element_list(); gel != nil { + var groups []ast.Node + for _, ne := range gel.AllGrouping_element() { + groups = append(groups, c.convert(ne)) + } + if len(groups) > 0 { + stmt.GroupClause = &ast.List{Items: groups} + } + } + } + + if ext := n.Ext_order_by_clause(); ext != nil { + if ob := ext.Order_by_clause(); ob != nil && ob.ORDER() != nil && ob.BY() != nil { + // TODO: ASSUME ORDER BY + if sl := ob.Sort_specification_list(); sl != nil { + var orderItems []ast.Node + for _, sp := range sl.AllSort_specification() { + ss, ok := sp.(*parser.Sort_specificationContext) + if !ok || ss == nil { + continue + } + expr := c.convert(ss.Expr()) + dir := ast.SortByDirDefault + if ss.ASC() != nil { + dir = ast.SortByDirAsc + } else if ss.DESC() != nil { + dir = ast.SortByDirDesc + } + orderItems = append(orderItems, &ast.SortBy{ + Node: expr, + SortbyDir: dir, + SortbyNulls: ast.SortByNullsUndefined, + UseOp: &ast.List{}, + Location: c.pos(ss.GetStart()), + }) + } + if len(orderItems) > 0 { + stmt.SortClause = &ast.List{Items: orderItems} + } + } + } } return stmt } +func (c *cc) convertGrouping_elementContext(n parser.IGrouping_elementContext) ast.Node { + if n == nil { + return todo("convertGrouping_elementContext", n) + } + if ogs := n.Ordinary_grouping_set(); ogs != nil { + return c.convert(ogs) + } + if rl := n.Rollup_list(); rl != nil { + return c.convert(rl) + } + if cl := n.Cube_list(); cl != nil { + return c.convert(cl) + } + if gss := n.Grouping_sets_specification(); gss != nil { + return c.convert(gss) + } + return todo("convertGrouping_elementContext", n) +} + +func (c *cc) convertOrdinary_grouping_setContext(n *parser.Ordinary_grouping_setContext) ast.Node { + if n == nil || n.Named_expr() == nil { + return todo("convertOrdinary_grouping_setContext", n) + } + + return c.convert(n.Named_expr()) +} + +func (c *cc) convertRollup_listContext(n *parser.Rollup_listContext) ast.Node { + if n == nil || n.ROLLUP() == nil || n.LPAREN() == nil || n.RPAREN() == nil { + return todo("convertRollup_listContext", n) + } + + var items []ast.Node + if list := n.Ordinary_grouping_set_list(); list != nil { + for _, ogs := range list.AllOrdinary_grouping_set() { + items = append(items, c.convert(ogs)) + } + } + return &ast.GroupingSet{Kind: 1, Content: &ast.List{Items: items}} +} + +func (c *cc) convertCube_listContext(n *parser.Cube_listContext) ast.Node { + if n == nil || n.CUBE() == nil || n.LPAREN() == nil || n.RPAREN() == nil { + return todo("convertCube_listContext", n) + } + + var items []ast.Node + if list := n.Ordinary_grouping_set_list(); list != nil { + for _, ogs := range list.AllOrdinary_grouping_set() { + items = append(items, c.convert(ogs)) + } + } + + return &ast.GroupingSet{Kind: 2, Content: &ast.List{Items: items}} +} + +func (c *cc) convertGrouping_sets_specificationContext(n *parser.Grouping_sets_specificationContext) ast.Node { + if n == nil || n.GROUPING() == nil || n.SETS() == nil || n.LPAREN() == nil || n.RPAREN() == nil { + return todo("convertGrouping_sets_specificationContext", n) + } + + var items []ast.Node + if gel := n.Grouping_element_list(); gel != nil { + for _, ge := range gel.AllGrouping_element() { + items = append(items, c.convert(ge)) + } + } + return &ast.GroupingSet{Kind: 3, Content: &ast.List{Items: items}} +} + func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { // todo: support opt_id_prefix target := &ast.ResTarget{ @@ -1683,6 +1987,22 @@ func (c *cc) convertSqlStmtCore(n parser.ISql_stmt_coreContext) ast.Node { return nil } +func (c *cc) convertNamed_exprContext(n *parser.Named_exprContext) ast.Node { + if n == nil || n.Expr() == nil { + return todo("convertNamed_exprContext", n) + } + expr := c.convert(n.Expr()) + if n.AS() != nil && n.An_id_or_type() != nil { + name := parseAnIdOrType(n.An_id_or_type()) + return &ast.ResTarget{ + Name: &name, + Val: expr, + Location: c.pos(n.Expr().GetStart()), + } + } + return expr +} + func (c *cc) convertExpr(n *parser.ExprContext) ast.Node { if n == nil { return nil @@ -2457,9 +2777,6 @@ func (c *cc) convert(node node) ast.Node { case *parser.Select_stmtContext: return c.convertSelectStmtContext(n) - case *parser.Select_coreContext: - return c.convertSelectCoreContext(n) - case *parser.Result_columnContext: return c.convertResultColumn(n) @@ -2553,6 +2870,12 @@ func (c *cc) convert(node node) ast.Node { case *parser.Update_stmtContext: return c.convertUpdate_stmtContext(n) + case *parser.Alter_table_stmtContext: + return c.convertAlter_table_stmtContext(n) + + case *parser.Do_stmtContext: + return c.convertDo_stmtContext(n) + case *parser.Drop_table_stmtContext: return c.convertDrop_table_stmtContext(n) @@ -2593,7 +2916,31 @@ func (c *cc) convert(node node) ast.Node { return c.convertAlter_group_stmtContext(n) case *parser.Drop_role_stmtContext: - return c.convertDrop_role_stmtCOntext(n) + return c.convertDrop_role_stmtContext(n) + + case *parser.Grouping_elementContext: + return c.convertGrouping_elementContext(n) + + case *parser.Ordinary_grouping_setContext: + return c.convertOrdinary_grouping_setContext(n) + + case *parser.Rollup_listContext: + return c.convertRollup_listContext(n) + + case *parser.Cube_listContext: + return c.convertCube_listContext(n) + + case *parser.Grouping_sets_specificationContext: + return c.convertGrouping_sets_specificationContext(n) + + case *parser.Named_exprContext: + return c.convertNamed_exprContext(n) + + case *parser.Call_actionContext: + return c.convertCall_actionContext(n) + + case *parser.Inline_actionContext: + return c.convertInline_actionContext(n) default: return todo("convert(case=default)", n) diff --git a/internal/engine/ydb/utils.go b/internal/engine/ydb/utils.go index 3847ee5055..f2023e8ba9 100755 --- a/internal/engine/ydb/utils.go +++ b/internal/engine/ydb/utils.go @@ -85,7 +85,7 @@ func parseIdOrType(ctx parser.IId_or_typeContext) string { } Id := ctx.(*parser.Id_or_typeContext) if Id.Id() != nil { - return identifier(parseIdTable(Id.Id())) + return identifier(parseId(Id.Id())) } return "" @@ -112,13 +112,25 @@ func parseAnIdSchema(ctx parser.IAn_id_schemaContext) string { return "" } -func parseIdTable(ctx parser.IIdContext) string { +func parseId(ctx parser.IIdContext) string { if ctx == nil { return "" } return ctx.GetText() } +func parseAnIdTable(ctx parser.IAn_id_tableContext) string { + if ctx == nil { + return "" + } + if id := ctx.Id_table(); id != nil { + return id.GetText() + } else if str := ctx.STRING_VALUE(); str != nil { + return str.GetText() + } + return "" +} + func parseIntegerValue(text string) (int64, error) { text = strings.ToLower(text) base := 10 @@ -194,3 +206,16 @@ func byteOffsetFromRuneIndex(s string, runeIndex int) int { } return bytePos } + +func emptySelectStmt() *ast.SelectStmt { + return &ast.SelectStmt{ + DistinctClause: &ast.List{}, + TargetList: &ast.List{}, + FromClause: &ast.List{}, + GroupClause: &ast.List{}, + WindowClause: &ast.List{}, + ValuesLists: &ast.List{}, + SortClause: &ast.List{}, + LockingClause: &ast.List{}, + } +}