diff --git a/go.mod b/go.mod index c72f29b6b1..4b755a4baa 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 github.com/xeipuuv/gojsonschema v1.2.0 github.com/ydb-platform/ydb-go-sdk/v3 v3.115.3 - github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333 + github.com/ydb-platform/yql-parsers v0.0.0-20250911122629-e8a65d734cbd golang.org/x/sync v0.16.0 google.golang.org/grpc v1.75.0 google.golang.org/protobuf v1.36.8 diff --git a/go.sum b/go.sum index 53cb8e8aec..eb68917f04 100644 --- a/go.sum +++ b/go.sum @@ -144,8 +144,6 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jonboulle/clockwork v0.3.0 h1:9BSCMi8C+0qdApAp4auwX0RkLGUjs956h0EkuQymUhg= -github.com/jonboulle/clockwork v0.3.0/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbdFz6I= github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -242,12 +240,10 @@ github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17 github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77 h1:LY6cI8cP4B9rrpTleZk95+08kl2gF4rixG7+V/dwL6Q= github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77/go.mod h1:Er+FePu1dNUieD+XTMDduGpQuCPssK5Q4BjF+IIXJ3I= -github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0 h1:TwWSp3gRMcja/hRpOofncLvgxAXCmzpz5cGtmdaoITw= -github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0/go.mod h1:l5sSv153E18VvYcsmr51hok9Sjc16tEC8AXGbwrk+ho= github.com/ydb-platform/ydb-go-sdk/v3 v3.115.3 h1:SFeSK2c+PmiToyNIhr143u+YDzLhl/kboXwKLYDk0O4= github.com/ydb-platform/ydb-go-sdk/v3 v3.115.3/go.mod h1:Pp1w2xxUoLQ3NCNAwV7pvDq0TVQOdtAqs+ZiC+i8r14= -github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333 h1:KFtJwlPdOxWjCKXX0jFJ8k1FlbqbRbUW3k/kYSZX7SA= -github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333/go.mod h1:vrPJPS8cdPSV568YcXhB4bUwhyV8bmWKqmQ5c5Xi99o= +github.com/ydb-platform/yql-parsers v0.0.0-20250911122629-e8a65d734cbd h1:ZfUkkZ1m5JCAw7jHQavecv+gKJWA6SNxuKLqHQ5/988= +github.com/ydb-platform/yql-parsers v0.0.0-20250911122629-e8a65d734cbd/go.mod h1:vrPJPS8cdPSV568YcXhB4bUwhyV8bmWKqmQ5c5Xi99o= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= diff --git a/internal/engine/ydb/convert.go b/internal/engine/ydb/convert.go index e173cf287e..8b67191ce6 100755 --- a/internal/engine/ydb/convert.go +++ b/internal/engine/ydb/convert.go @@ -12,8 +12,8 @@ import ( ) type cc struct { - paramCount int - content string + parser.BaseYQLVisitor + content string } func (c *cc) pos(token antlr.Token) int { @@ -54,9 +54,9 @@ func NewIdentifier(t string) *ast.String { return &ast.String{Str: identifier(t)} } -func (c *cc) convertDrop_role_stmtContext(n *parser.Drop_role_stmtContext) ast.Node { +func (c *cc) VisitDrop_role_stmt(n *parser.Drop_role_stmtContext) interface{} { if n.DROP() == nil || (n.USER() == nil && n.GROUP() == nil) || len(n.AllRole_name()) == 0 { - return todo("Drop_role_stmtContext", n) + return todo("VisitDrop_role_stmt", n) } stmt := &ast.DropRoleStmt{ @@ -67,7 +67,7 @@ func (c *cc) convertDrop_role_stmtContext(n *parser.Drop_role_stmtContext) ast.N for _, role := range n.AllRole_name() { member, isParam, _ := c.extractRoleSpec(role, ast.RoleSpecType(1)) if member == nil { - return todo("Drop_role_stmtContext", n) + return todo("VisitDrop_role_stmt", role) } if debug.Active && isParam { @@ -80,13 +80,13 @@ func (c *cc) convertDrop_role_stmtContext(n *parser.Drop_role_stmtContext) ast.N return stmt } -func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) ast.Node { +func (c *cc) VisitAlter_group_stmt(n *parser.Alter_group_stmtContext) interface{} { if n.ALTER() == nil || n.GROUP() == nil || len(n.AllRole_name()) == 0 { - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", n) } role, paramFlag, _ := c.extractRoleSpec(n.Role_name(0), ast.RoleSpecType(1)) if role == nil { - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", n) } if debug.Active && paramFlag { @@ -101,7 +101,10 @@ func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) a switch { case n.RENAME() != nil && n.TO() != nil && len(n.AllRole_name()) > 1: - newName := c.convert(n.Role_name(1)) + newName, ok := n.Role_name(1).Accept(c).(ast.Node) + if !ok { + return todo("VisitAlter_group_stmt", n.Role_name(1)) + } action := "rename" defElem := &ast.DefElem{ @@ -120,12 +123,12 @@ func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) a case *ast.Boolean: defElem.Arg = val default: - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", n.Role_name(1)) } case *ast.ParamRef, *ast.A_Expr: defElem.Arg = newName default: - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", n.Role_name(1)) } if debug.Active && !paramFlag && bindFlag { @@ -140,7 +143,7 @@ func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) a for _, role := range n.AllRole_name()[1:] { member, isParam, _ := c.extractRoleSpec(role, ast.RoleSpecType(1)) if member == nil { - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", role) } if debug.Active && isParam && !paramFlag { @@ -161,21 +164,21 @@ func (c *cc) convertAlter_group_stmtContext(n *parser.Alter_group_stmtContext) a Defname: &defname, Arg: optionList, Defaction: action, - Location: c.pos(n.GetStart()), + Location: c.pos(n.Role_name(1).GetStart()), }) } return stmt } -func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast.Node { +func (c *cc) VisitAlter_user_stmt(n *parser.Alter_user_stmtContext) interface{} { if n.ALTER() == nil || n.USER() == nil || len(n.AllRole_name()) == 0 { - return todo("Alter_user_stmtContext", n) + return todo("VisitAlter_user_stmt", n) } role, paramFlag, _ := c.extractRoleSpec(n.Role_name(0), ast.RoleSpecType(1)) if role == nil { - return todo("convertAlter_group_stmtContext", n) + return todo("VisitAlter_group_stmt", n) } if debug.Active && paramFlag { @@ -190,7 +193,10 @@ func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast switch { case n.RENAME() != nil && n.TO() != nil && len(n.AllRole_name()) > 1: - newName := c.convert(n.Role_name(1)) + newName, ok := n.Role_name(1).Accept(c).(ast.Node) + if !ok { + return todo("VisitAlter_user_stmt", n.Role_name(1)) + } action := "rename" defElem := &ast.DefElem{ @@ -209,12 +215,12 @@ func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast case *ast.Boolean: defElem.Arg = val default: - return todo("Alter_user_stmtContext", n) + return todo("VisitAlter_user_stmt", n.Role_name(1)) } case *ast.ParamRef, *ast.A_Expr: defElem.Arg = newName default: - return todo("Alter_user_stmtContext", n) + return todo("VisitAlter_user_stmt", n.Role_name(1)) } if debug.Active && !paramFlag && bindFlag { @@ -225,7 +231,11 @@ func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast case len(n.AllUser_option()) > 0: for _, opt := range n.AllUser_option() { - if node := c.convert(opt); node != nil { + if temp := opt.Accept(c); temp != nil { + var node, ok = temp.(ast.Node) + if !ok { + return todo("VisitAlter_user_stmt", opt) + } stmt.Options.Items = append(stmt.Options.Items, node) } } @@ -234,11 +244,14 @@ func (c *cc) convertAlter_user_stmtContext(n *parser.Alter_user_stmtContext) ast return stmt } -func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) ast.Node { +func (c *cc) VisitCreate_group_stmt(n *parser.Create_group_stmtContext) interface{} { if n.CREATE() == nil || n.GROUP() == nil || len(n.AllRole_name()) == 0 { - return todo("Create_group_stmtContext", n) + return todo("VisitCreate_group_stmt", n) + } + groupName, ok := n.Role_name(0).Accept(c).(ast.Node) + if !ok { + return todo("VisitCreate_group_stmt", n.Role_name(0)) } - groupName := c.convert(n.Role_name(0)) stmt := &ast.CreateRoleStmt{ StmtType: ast.RoleStmtType(3), @@ -255,12 +268,12 @@ func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) case *ast.Boolean: stmt.BindRole = groupName default: - return todo("convertCreate_group_stmtContext", n) + return todo("VisitCreate_group_stmt", n.Role_name(0)) } case *ast.ParamRef, *ast.A_Expr: stmt.BindRole = groupName default: - return todo("convertCreate_group_stmtContext", n) + return todo("VisitCreate_group_stmt", n.Role_name(0)) } if debug.Active && paramFlag { @@ -273,7 +286,7 @@ func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) for _, role := range n.AllRole_name()[1:] { member, isParam, _ := c.extractRoleSpec(role, ast.RoleSpecType(1)) if member == nil { - return todo("convertCreate_group_stmtContext", n) + return todo("VisitCreate_group_stmt", role) } if debug.Active && isParam && !paramFlag { @@ -286,26 +299,29 @@ func (c *cc) convertCreate_group_stmtContext(n *parser.Create_group_stmtContext) stmt.Options.Items = append(stmt.Options.Items, &ast.DefElem{ Defname: &defname, Arg: optionList, - Location: c.pos(n.GetStart()), + Location: c.pos(n.Role_name(1).GetStart()), }) } return stmt } -func (c *cc) convertUse_stmtContext(n *parser.Use_stmtContext) ast.Node { +func (c *cc) VisitUse_stmt(n *parser.Use_stmtContext) interface{} { if n.USE() != nil && n.Cluster_expr() != nil { - clusterExpr := c.convert(n.Cluster_expr()) + clusterExpr, ok := n.Cluster_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitUse_stmt", n.Cluster_expr()) + } stmt := &ast.UseStmt{ Xpr: clusterExpr, - Location: c.pos(n.GetStart()), + Location: c.pos(n.Cluster_expr().GetStart()), } return stmt } - return todo("convertUse_stmtContext", n) + return todo("VisitUse_stmt", n) } -func (c *cc) convertCluster_exprContext(n *parser.Cluster_exprContext) ast.Node { +func (c *cc) VisitCluster_expr(n *parser.Cluster_exprContext) interface{} { var node ast.Node switch { @@ -318,12 +334,16 @@ func (c *cc) convertCluster_exprContext(n *parser.Cluster_exprContext) ast.Node Location: c.pos(anID.GetStart()), } } else if bp := pureCtx.Bind_parameter(); bp != nil { - node = c.convert(bp) + temp, ok := bp.Accept(c).(ast.Node) + if !ok { + return todo("VisitCluster_expr", bp) + } + node = temp } case n.ASTERISK() != nil: node = &ast.A_Star{} default: - return todo("convertCluster_exprContext", n) + return todo("VisitCluster_expr", n) } if n.An_id() != nil && n.COLON() != nil { @@ -339,11 +359,14 @@ func (c *cc) convertCluster_exprContext(n *parser.Cluster_exprContext) ast.Node return node } -func (c *cc) convertCreate_user_stmtContext(n *parser.Create_user_stmtContext) ast.Node { +func (c *cc) VisitCreate_user_stmt(n *parser.Create_user_stmtContext) interface{} { if n.CREATE() == nil || n.USER() == nil || n.Role_name() == nil { - return todo("convertCreate_user_stmtContext", n) + return todo("VisitCreate_user_stmt", n) + } + roleNode, ok := n.Role_name().Accept(c).(ast.Node) + if !ok { + return todo("VisitCreate_user_stmt", n.Role_name()) } - roleNode := c.convert(n.Role_name()) stmt := &ast.CreateRoleStmt{ StmtType: ast.RoleStmtType(2), @@ -360,12 +383,12 @@ func (c *cc) convertCreate_user_stmtContext(n *parser.Create_user_stmtContext) a case *ast.Boolean: stmt.BindRole = roleNode default: - return todo("convertCreate_user_stmtContext", n) + return todo("VisitCreate_user_stmt", n.Role_name()) } case *ast.ParamRef, *ast.A_Expr: stmt.BindRole = roleNode default: - return todo("convertCreate_user_stmtContext", n) + return todo("VisitCreate_user_stmt", n.Role_name()) } if debug.Active && paramFlag { @@ -375,7 +398,11 @@ func (c *cc) convertCreate_user_stmtContext(n *parser.Create_user_stmtContext) a if len(n.AllUser_option()) > 0 { options := []ast.Node{} for _, opt := range n.AllUser_option() { - if node := c.convert(opt); node != nil { + if temp := opt.Accept(c); temp != nil { + node, ok := temp.(ast.Node) + if !ok { + return todo("VisitCreate_user_stmt", opt) + } options = append(options, node) } } @@ -386,7 +413,7 @@ func (c *cc) convertCreate_user_stmtContext(n *parser.Create_user_stmtContext) a return stmt } -func (c *cc) convertUser_optionContext(n *parser.User_optionContext) ast.Node { +func (c *cc) VisitUser_option(n *parser.User_optionContext) interface{} { switch { case n.Authentication_option() != nil: aOpt := n.Authentication_option() @@ -436,40 +463,43 @@ func (c *cc) convertUser_optionContext(n *parser.User_optionContext) ast.Node { Location: c.pos(lOpt.GetStart()), } default: - return todo("convertUser_optionContext", n) + return todo("VisitUser_option", n) } - return nil + return todo("VisitUser_option", n) } -func (c *cc) convertRole_nameContext(n *parser.Role_nameContext) ast.Node { +func (c *cc) VisitRole_name(n *parser.Role_nameContext) interface{} { switch { case n.An_id_or_type() != nil: name := parseAnIdOrType(n.An_id_or_type()) - return &ast.A_Const{Val: NewIdentifier(name), Location: c.pos(n.GetStart())} + return &ast.A_Const{Val: NewIdentifier(name), Location: c.pos(n.An_id_or_type().GetStart())} case n.Bind_parameter() != nil: - bindPar := c.convert(n.Bind_parameter()) + bindPar, ok := n.Bind_parameter().Accept(c).(ast.Node) + if !ok { + return todo("VisitRole_name", n.Bind_parameter()) + } return bindPar } - return todo("convertRole_nameContext", n) + return todo("VisitRole_name", n) } -func (c *cc) convertCommit_stmtContext(n *parser.Commit_stmtContext) ast.Node { +func (c *cc) VisitCommit_stmt(n *parser.Commit_stmtContext) interface{} { if n.COMMIT() != nil { return &ast.TransactionStmt{Kind: ast.TransactionStmtKind(3)} } - return todo("convertCommit_stmtContext", n) + return todo("VisitCommit_stmt", n) } -func (c *cc) convertRollback_stmtContext(n *parser.Rollback_stmtContext) ast.Node { +func (c *cc) VisitRollback_stmt(n *parser.Rollback_stmtContext) interface{} { if n.ROLLBACK() != nil { return &ast.TransactionStmt{Kind: ast.TransactionStmtKind(4)} } - return todo("convertRollback_stmtContext", n) + return todo("VisitRollback_stmt", n) } -func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) ast.Node { +func (c *cc) VisitAlter_table_stmt(n *parser.Alter_table_stmtContext) interface{} { if n.ALTER() == nil || n.TABLE() == nil || n.Simple_table_ref() == nil || len(n.AllAlter_table_action()) == 0 { - return todo("convertAlter_table_stmtContext", n) + return todo("VisitAlter_table_stmt", n) } stmt := &ast.AlterTableStmt{ @@ -486,7 +516,14 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a case action.Alter_table_add_column() != nil: ac := action.Alter_table_add_column() if ac.ADD() != nil && ac.Column_schema() != nil { - columnDef := c.convertColumnSchema(ac.Column_schema().(*parser.Column_schemaContext)) + temp, ok := ac.Column_schema().Accept(c).(ast.Node) + if !ok { + return todo("VisitAlter_table_stmt", ac.Column_schema()) + } + columnDef, ok := temp.(*ast.ColumnDef) + if !ok { + return todo("VisitAlter_table_stmt", ac.Column_schema()) + } stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{ Name: &columnDef.Colname, Subtype: ast.AT_AddColumn, @@ -514,7 +551,7 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a case action.Alter_table_rename_to() != nil: ac := action.Alter_table_rename_to() if ac.RENAME() != nil && ac.TO() != nil && ac.An_id_table() != nil { - // TODO: Returning here may be incorrect if there are multiple specs + // FIXME: Returning here may be incorrect if there are multiple specs newName := parseAnIdTable(ac.An_id_table()) return &ast.RenameTableStmt{ Table: stmt.Table, @@ -541,25 +578,33 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a return stmt } -func (c *cc) convertDo_stmtContext(n *parser.Do_stmtContext) ast.Node { +func (c *cc) VisitDo_stmt(n *parser.Do_stmtContext) interface{} { if n.DO() == nil || (n.Call_action() == nil && n.Inline_action() == nil) { - return todo("convertDo_stmtContext", n) + return todo("VisitDo_stmt", n) } switch { case n.Call_action() != nil: - return c.convert(n.Call_action()) + result, ok := n.Call_action().Accept(c).(ast.Node) + if !ok { + return todo("VisitDo_stmt", n.Call_action()) + } + return result case n.Inline_action() != nil: - return c.convert(n.Inline_action()) + result, ok := n.Inline_action().Accept(c).(ast.Node) + if !ok { + return todo("VisitDo_stmt", n.Inline_action()) + } + return result } - return todo("convertDo_stmtContext", n) + return todo("VisitDo_stmt", n) } -func (c *cc) convertCall_actionContext(n *parser.Call_actionContext) ast.Node { +func (c *cc) VisitCall_action(n *parser.Call_actionContext) interface{} { if n == nil { - return nil + return todo("VisitCall_action", n) } if n.LPAREN() != nil && n.RPAREN() != nil { funcCall := &ast.FuncCall{ @@ -569,14 +614,22 @@ func (c *cc) convertCall_actionContext(n *parser.Call_actionContext) ast.Node { } if n.Bind_parameter() != nil { - funcCall.Funcname.Items = append(funcCall.Funcname.Items, c.convert(n.Bind_parameter())) + bindPar, ok := n.Bind_parameter().Accept(c).(ast.Node) + if !ok { + return todo("VisitCall_action", n.Bind_parameter()) + } + funcCall.Funcname.Items = append(funcCall.Funcname.Items, bindPar) } else if n.EMPTY_ACTION() != nil { funcCall.Funcname.Items = append(funcCall.Funcname.Items, &ast.String{Str: "EMPTY_ACTION"}) } if n.Expr_list() != nil { for _, expr := range n.Expr_list().AllExpr() { - funcCall.Args.Items = append(funcCall.Args.Items, c.convert(expr)) + exprNode, ok := expr.Accept(c).(ast.Node) + if !ok { + return todo("VisitCall_action", expr) + } + funcCall.Args.Items = append(funcCall.Args.Items, exprNode) } } @@ -584,29 +637,33 @@ func (c *cc) convertCall_actionContext(n *parser.Call_actionContext) ast.Node { Args: &ast.List{Items: []ast.Node{funcCall}}, } } - return todo("convertCall_actionContext", n) + return todo("VisitCall_action", n) } -func (c *cc) convertInline_actionContext(n *parser.Inline_actionContext) ast.Node { +func (c *cc) VisitInline_action(n *parser.Inline_actionContext) interface{} { if n == nil { - return nil + return todo("VisitInline_action", n) } if n.BEGIN() != nil && n.END() != nil && n.DO() != nil { args := &ast.List{} if defineBody := n.Define_action_or_subquery_body(); defineBody != nil { cores := defineBody.AllSql_stmt_core() for _, stmtCore := range cores { - if converted := c.convert(stmtCore); converted != nil { - args.Items = append(args.Items, converted) + if converted := stmtCore.Accept(c); converted != nil { + var convertedNode, ok = converted.(ast.Node) + if !ok { + return todo("VisitInline_action", stmtCore) + } + args.Items = append(args.Items, convertedNode) } } } return &ast.DoStmt{Args: args} } - return todo("convertInline_actionContext", n) + return todo("VisitInline_action", n) } -func (c *cc) convertDrop_table_stmtContext(n *parser.Drop_table_stmtContext) ast.Node { +func (c *cc) VisitDrop_table_stmt(n *parser.Drop_table_stmtContext) interface{} { if n.DROP() != nil && (n.TABLESTORE() != nil || (n.EXTERNAL() != nil && n.TABLE() != nil) || n.TABLE() != nil) { name := parseTableName(n.Simple_table_ref().Simple_table_ref_core()) stmt := &ast.DropTableStmt{ @@ -615,10 +672,10 @@ func (c *cc) convertDrop_table_stmtContext(n *parser.Drop_table_stmtContext) ast } return stmt } - return todo("convertDrop_Table_stmtContxt", n) + return todo("VisitDrop_table_stmt", n) } -func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { +func (c *cc) VisitDelete_stmt(n *parser.Delete_stmtContext) interface{} { batch := n.BATCH() != nil tableName := identifier(n.Simple_table_ref().Simple_table_ref_core().GetText()) @@ -626,7 +683,11 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { var where ast.Node if n.WHERE() != nil && n.Expr() != nil { - where = c.convert(n.Expr()) + whereNode, ok := n.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitDelete_stmt", n.Expr()) + } + where = whereNode } var cols *ast.List var source ast.Node @@ -649,17 +710,37 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { switch { case valSource.Values_stmt() != nil: stmt := emptySelectStmt() - stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + temp, ok := valSource.Values_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitDelete_stmt", valSource.Values_stmt()) + } + list, ok := temp.(*ast.List) + if !ok { + return todo("VisitDelete_stmt", valSource.Values_stmt()) + } + stmt.ValuesLists = list source = stmt case valSource.Select_stmt() != nil: - source = c.convert(valSource.Select_stmt()) + temp, ok := valSource.Select_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitDelete_stmt", valSource.Select_stmt()) + } + source = temp } } } returning := &ast.List{} if ret := n.Returning_columns_list(); ret != nil { - returning = c.convert(ret).(*ast.List) + temp, ok := ret.Accept(c).(ast.Node) + if !ok { + return todo("VisitDelete_stmt", n.Returning_columns_list()) + } + returningNode, ok := temp.(*ast.List) + if !ok { + return todo("VisitDelete_stmt", n.Returning_columns_list()) + } + returning = returningNode } stmts := &ast.DeleteStmt{ @@ -674,7 +755,7 @@ func (c *cc) convertDelete_stmtContext(n *parser.Delete_stmtContext) ast.Node { return stmts } -func (c *cc) convertPragma_stmtContext(n *parser.Pragma_stmtContext) ast.Node { +func (c *cc) VisitPragma_stmt(n *parser.Pragma_stmtContext) interface{} { if n.PRAGMA() != nil && n.An_id() != nil { prefix := "" if p := n.Opt_id_prefix_or_type(); p != nil { @@ -696,22 +777,30 @@ func (c *cc) convertPragma_stmtContext(n *parser.Pragma_stmtContext) ast.Node { if n.EQUALS() != nil { stmt.Equals = true if val := n.Pragma_value(0); val != nil { - stmt.Values = &ast.List{Items: []ast.Node{c.convert(val)}} + valNode, ok := val.Accept(c).(ast.Node) + if !ok { + return todo("VisitPragma_stmt", n.Pragma_value(0)) + } + stmt.Values = &ast.List{Items: []ast.Node{valNode}} } } else if lp := n.LPAREN(); lp != nil { values := []ast.Node{} for _, v := range n.AllPragma_value() { - values = append(values, c.convert(v)) + valNode, ok := v.Accept(c).(ast.Node) + if !ok { + return todo("VisitPragma_stmt", v) + } + values = append(values, valNode) } stmt.Values = &ast.List{Items: values} } return stmt } - return todo("convertPragma_stmtContext", n) + return todo("VisitPragma_stmt", n) } -func (c *cc) convertPragma_valueContext(n *parser.Pragma_valueContext) ast.Node { +func (c *cc) VisitPragma_value(n *parser.Pragma_valueContext) interface{} { switch { case n.Signed_number() != nil: if n.Signed_number().Integer() != nil { @@ -742,16 +831,20 @@ func (c *cc) convertPragma_valueContext(n *parser.Pragma_valueContext) ast.Node } return &ast.A_Const{Val: &ast.Boolean{Boolval: i}, Location: c.pos(n.GetStart())} case n.Bind_parameter() != nil: - bindPar := c.convert(n.Bind_parameter()) - return bindPar + bindPar := n.Bind_parameter().Accept(c) + var bindParNode, ok = bindPar.(ast.Node) + if !ok { + return todo("VisitPragma_value", n.Bind_parameter()) + } + return bindParNode } - return todo("convertPragma_valueContext", n) + return todo("VisitPragma_value", n) } -func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { - if n.UPDATE() == nil { - return nil +func (c *cc) VisitUpdate_stmt(n *parser.Update_stmtContext) interface{} { + if n == nil || n.UPDATE() == nil { + return todo("VisitUpdate_stmt", n) } batch := n.BATCH() != nil @@ -772,7 +865,10 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { for _, clause := range nSet.Set_clause_list().AllSet_clause() { targetCtx := clause.Set_target() columnName := identifier(targetCtx.Column_name().GetText()) - expr := c.convert(clause.Expr()) + expr, ok := clause.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitUpdate_stmt", clause.Expr()) + } resTarget := &ast.ResTarget{ Name: &columnName, Val: expr, @@ -798,7 +894,11 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { Args: &ast.List{}, } for _, expr := range exprList.AllExpr() { - rowExpr.Args.Items = append(rowExpr.Args.Items, c.convert(expr)) + exprNode, ok := expr.Accept(c).(ast.Node) + if !ok { + return todo("VisitUpdate_stmt", expr) + } + rowExpr.Args.Items = append(rowExpr.Args.Items, exprNode) } } @@ -817,7 +917,11 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { } if n.WHERE() != nil && n.Expr() != nil { - where = c.convert(n.Expr()) + whereNode, ok := n.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitUpdate_stmt", n.Expr()) + } + where = whereNode } } else if n.ON() != nil && n.Into_values_source() != nil { @@ -841,17 +945,37 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { switch { case valSource.Values_stmt() != nil: stmt := emptySelectStmt() - stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + temp, ok := valSource.Values_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitUpdate_stmt", valSource.Values_stmt()) + } + list, ok := temp.(*ast.List) + if !ok { + return todo("VisitUpdate_stmt", valSource.Values_stmt()) + } + stmt.ValuesLists = list source = stmt case valSource.Select_stmt() != nil: - source = c.convert(valSource.Select_stmt()) + temp, ok := valSource.Select_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitUpdate_stmt", valSource.Select_stmt()) + } + source = temp } } } returning := &ast.List{} if ret := n.Returning_columns_list(); ret != nil { - returning = c.convert(ret).(*ast.List) + temp, ok := ret.Accept(c).(ast.Node) + if !ok { + return todo("VisitDelete_stmt", n.Returning_columns_list()) + } + returningNode, ok := temp.(*ast.List) + if !ok { + return todo("VisitDelete_stmt", n.Returning_columns_list()) + } + returning = returningNode } stmts := &ast.UpdateStmt{ @@ -869,7 +993,7 @@ func (c *cc) convertUpdate_stmtContext(n *parser.Update_stmtContext) ast.Node { return stmts } -func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast.Node { +func (c *cc) VisitInto_table_stmt(n *parser.Into_table_stmtContext) interface{} { tableName := identifier(n.Into_simple_table_ref().Simple_table_ref().Simple_table_ref_core().GetText()) rel := &ast.RangeVar{ Relname: &tableName, @@ -911,17 +1035,37 @@ func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast switch { case valSource.Values_stmt() != nil: stmt := emptySelectStmt() - stmt.ValuesLists = c.convert(valSource.Values_stmt()).(*ast.List) + temp, ok := valSource.Values_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitInto_table_stmt", valSource.Values_stmt()) + } + stmtNode, ok := temp.(*ast.List) + if !ok { + return todo("VisitInto_table_stmt", valSource.Values_stmt()) + } + stmt.ValuesLists = stmtNode source = stmt case valSource.Select_stmt() != nil: - source = c.convert(valSource.Select_stmt()) + sourceNode, ok := valSource.Select_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitInto_table_stmt", valSource.Select_stmt()) + } + source = sourceNode } } } returning := &ast.List{} if ret := n.Returning_columns_list(); ret != nil { - returning = c.convert(ret).(*ast.List) + temp, ok := ret.Accept(c).(ast.Node) + if !ok { + return todo("VisitInto_table_stmt", n.Returning_columns_list()) + } + returningNode, ok := temp.(*ast.List) + if !ok { + return todo("VisitInto_table_stmt", n.Returning_columns_list()) + } + returning = returningNode } stmts := &ast.InsertStmt{ @@ -935,7 +1079,7 @@ func (c *cc) convertInto_table_stmtContext(n *parser.Into_table_stmtContext) ast return stmts } -func (c *cc) convertValues_stmtContext(n *parser.Values_stmtContext) ast.Node { +func (c *cc) VisitValues_stmt(n *parser.Values_stmtContext) interface{} { mainList := &ast.List{} for _, rowCtx := range n.Values_source_row_list().AllValues_source_row() { @@ -943,8 +1087,12 @@ func (c *cc) convertValues_stmtContext(n *parser.Values_stmtContext) ast.Node { exprListCtx := rowCtx.Expr_list().(*parser.Expr_listContext) for _, exprCtx := range exprListCtx.AllExpr() { - if converted := c.convert(exprCtx); converted != nil { - rowList.Items = append(rowList.Items, converted) + if converted := exprCtx.Accept(c); converted != nil { + var convertedNode, ok = converted.(ast.Node) + if !ok { + return todo("VisitValues_stmt", exprCtx) + } + rowList.Items = append(rowList.Items, convertedNode) } } @@ -955,7 +1103,7 @@ func (c *cc) convertValues_stmtContext(n *parser.Values_stmtContext) ast.Node { return mainList } -func (c *cc) convertReturning_columns_listContext(n *parser.Returning_columns_listContext) ast.Node { +func (c *cc) VisitReturning_columns_list(n *parser.Returning_columns_listContext) interface{} { list := &ast.List{Items: []ast.Node{}} if n.ASTERISK() != nil { @@ -988,30 +1136,36 @@ func (c *cc) convertReturning_columns_listContext(n *parser.Returning_columns_li return list } -func (c *cc) convertSelectStmtContext(n *parser.Select_stmtContext) ast.Node { +func (c *cc) VisitSelect_stmt(n *parser.Select_stmtContext) interface{} { if len(n.AllSelect_kind_parenthesis()) == 0 { - return todo("convertSelectStmtContext", n) + return todo("VisitSelect_stmt", n) } skp := n.Select_kind_parenthesis(0) if skp == nil { - return todo("convertSelectStmtContext", skp) + return todo("VisitSelect_stmt", skp) } - stmt := c.convertSelectKindParenthesis(skp) - left, ok := stmt.(*ast.SelectStmt) + temp, ok := skp.Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_kind_parenthesis", skp) + } + left, ok := temp.(*ast.SelectStmt) if left == nil || !ok { - return todo("convertSelectKindParenthesis", skp) + return todo("VisitSelect_kind_parenthesis", skp) } kinds := n.AllSelect_kind_parenthesis() ops := n.AllSelect_op() for i := 1; i < len(kinds); i++ { - stmt := c.convertSelectKindParenthesis(kinds[i]) - right, ok := stmt.(*ast.SelectStmt) + temp, ok := kinds[i].Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_kind_parenthesis", kinds[i]) + } + right, ok := temp.(*ast.SelectStmt) if right == nil || !ok { - return todo("convertSelectKindParenthesis", kinds[i]) + return todo("VisitSelect_kind_parenthesis", kinds[i]) } var op ast.SetOperation @@ -1041,21 +1195,25 @@ func (c *cc) convertSelectStmtContext(n *parser.Select_stmtContext) ast.Node { return left } -func (c *cc) convertSelectKindParenthesis(n parser.ISelect_kind_parenthesisContext) ast.Node { +func (c *cc) VisitSelect_kind_parenthesis(n *parser.Select_kind_parenthesisContext) interface{} { if n == nil || n.Select_kind_partial() == nil { - return todo("convertSelectKindParenthesis", n) + return todo("VisitSelect_kind_parenthesis", n) } partial := n.Select_kind_partial() sk := partial.Select_kind() if sk == nil { - return todo("convertSelectKind", sk) + return todo("VisitSelect_kind_parenthesis", sk) } var base ast.Node switch { case sk.Select_core() != nil: - base = c.convertSelectCoreContext(sk.Select_core()) + baseNode, ok := sk.Select_core().Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_kind_parenthesis", sk.Select_core()) + } + base = baseNode case sk.Process_core() != nil: log.Fatalf("PROCESS is not supported in YDB engine") case sk.Reduce_core() != nil: @@ -1063,7 +1221,7 @@ func (c *cc) convertSelectKindParenthesis(n parser.ISelect_kind_parenthesisConte } stmt, ok := base.(*ast.SelectStmt) if !ok || stmt == nil { - return todo("convertSelectKindParenthesis", sk.Select_core()) + return todo("VisitSelect_kind_parenthesis", sk.Select_core()) } // TODO: handle INTO RESULT clause @@ -1071,11 +1229,19 @@ func (c *cc) convertSelectKindParenthesis(n parser.ISelect_kind_parenthesisConte if partial.LIMIT() != nil { exprs := partial.AllExpr() if len(exprs) >= 1 { - stmt.LimitCount = c.convert(exprs[0]) + temp, ok := exprs[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_kind_parenthesis", exprs[0]) + } + stmt.LimitCount = temp } if partial.OFFSET() != nil { if len(exprs) >= 2 { - stmt.LimitOffset = c.convert(exprs[1]) + temp, ok := exprs[1].Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_kind_parenthesis", exprs[1]) + } + stmt.LimitOffset = temp } } } @@ -1083,7 +1249,7 @@ func (c *cc) convertSelectKindParenthesis(n parser.ISelect_kind_parenthesisConte return stmt } -func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { +func (c *cc) VisitSelect_core(n *parser.Select_coreContext) interface{} { stmt := emptySelectStmt() if n.Opt_set_quantifier() != nil { oq := n.Opt_set_quantifier() @@ -1095,14 +1261,11 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { if len(resultCols) > 0 { var items []ast.Node for _, rc := range resultCols { - resCol, ok := rc.(*parser.Result_columnContext) + convNode, ok := rc.Accept(c).(ast.Node) if !ok { - continue - } - convNode := c.convertResultColumn(resCol) - if convNode != nil { - items = append(items, convNode) + return todo("VisitSelect_core", rc) } + items = append(items, convNode) } stmt.TargetList = &ast.List{ Items: items, @@ -1118,15 +1281,11 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { if len(jsList) > 0 { var fromItems []ast.Node for _, js := range jsList { - jsCon, ok := js.(*parser.Join_sourceContext) + joinNode, ok := js.Accept(c).(ast.Node) if !ok { - continue - } - - joinNode := c.convertJoinSource(jsCon) - if joinNode != nil { - fromItems = append(fromItems, joinNode) + return todo("VisitSelect_core", js) } + fromItems = append(fromItems, joinNode) } stmt.FromClause = &ast.List{ Items: fromItems, @@ -1136,13 +1295,21 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { exprIdx := 0 if n.WHERE() != nil { if whereCtx := n.Expr(exprIdx); whereCtx != nil { - stmt.WhereClause = c.convert(whereCtx) + where, ok := whereCtx.Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_core", whereCtx) + } + stmt.WhereClause = where } exprIdx++ } if n.HAVING() != nil { if havingCtx := n.Expr(exprIdx); havingCtx != nil { - stmt.HavingClause = c.convert(havingCtx) + having, ok := havingCtx.Accept(c).(ast.Node) + if !ok || having == nil { + return todo("VisitSelect_core", havingCtx) + } + stmt.HavingClause = having } exprIdx++ } @@ -1151,7 +1318,11 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { if gel := gbc.Grouping_element_list(); gel != nil { var groups []ast.Node for _, ne := range gel.AllGrouping_element() { - groups = append(groups, c.convert(ne)) + groupBy, ok := ne.Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_core", ne) + } + groups = append(groups, groupBy) } if len(groups) > 0 { stmt.GroupClause = &ast.List{Items: groups} @@ -1165,15 +1336,14 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { if sl := ob.Sort_specification_list(); sl != nil { var orderItems []ast.Node for _, sp := range sl.AllSort_specification() { - ss, ok := sp.(*parser.Sort_specificationContext) - if !ok || ss == nil { - continue + expr, ok := sp.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitSelect_core", sp.Expr()) } - expr := c.convert(ss.Expr()) dir := ast.SortByDirDefault - if ss.ASC() != nil { + if sp.ASC() != nil { dir = ast.SortByDirAsc - } else if ss.DESC() != nil { + } else if sp.DESC() != nil { dir = ast.SortByDirDesc } orderItems = append(orderItems, &ast.SortBy{ @@ -1181,7 +1351,7 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { SortbyDir: dir, SortbyNulls: ast.SortByNullsUndefined, UseOp: &ast.List{}, - Location: c.pos(ss.GetStart()), + Location: c.pos(sp.GetStart()), }) } if len(orderItems) > 0 { @@ -1193,77 +1363,109 @@ func (c *cc) convertSelectCoreContext(n parser.ISelect_coreContext) ast.Node { return stmt } -func (c *cc) convertGrouping_elementContext(n parser.IGrouping_elementContext) ast.Node { +func (c *cc) VisitGrouping_element(n *parser.Grouping_elementContext) interface{} { if n == nil { - return todo("convertGrouping_elementContext", n) + return todo("VisitGrouping_element", n) } if ogs := n.Ordinary_grouping_set(); ogs != nil { - return c.convert(ogs) + groupingSet, ok := ogs.Accept(c).(ast.Node) + if !ok { + return todo("VisitGrouping_element", ogs) + } + return groupingSet } if rl := n.Rollup_list(); rl != nil { - return c.convert(rl) + rollupList, ok := rl.Accept(c).(ast.Node) + if !ok { + return todo("VisitGrouping_element", rl) + } + return rollupList } if cl := n.Cube_list(); cl != nil { - return c.convert(cl) + cubeList, ok := cl.Accept(c).(ast.Node) + if !ok { + return todo("VisitGrouping_element", cl) + } + return cubeList } if gss := n.Grouping_sets_specification(); gss != nil { - return c.convert(gss) + groupingSets, ok := gss.Accept(c).(ast.Node) + if !ok { + return todo("VisitGrouping_element", gss) + } + return groupingSets } - return todo("convertGrouping_elementContext", n) + return todo("VisitGrouping_element", n) } -func (c *cc) convertOrdinary_grouping_setContext(n *parser.Ordinary_grouping_setContext) ast.Node { +func (c *cc) VisitOrdinary_grouping_set(n *parser.Ordinary_grouping_setContext) interface{} { if n == nil || n.Named_expr() == nil { - return todo("convertOrdinary_grouping_setContext", n) + return todo("VisitOrdinary_grouping_set", n) } - return c.convert(n.Named_expr()) + namedExpr, ok := n.Named_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitOrdinary_grouping_set", n.Named_expr()) + } + return namedExpr } -func (c *cc) convertRollup_listContext(n *parser.Rollup_listContext) ast.Node { +func (c *cc) VisitRollup_list(n *parser.Rollup_listContext) interface{} { if n == nil || n.ROLLUP() == nil || n.LPAREN() == nil || n.RPAREN() == nil { - return todo("convertRollup_listContext", n) + return todo("VisitRollup_list", n) } var items []ast.Node if list := n.Ordinary_grouping_set_list(); list != nil { for _, ogs := range list.AllOrdinary_grouping_set() { - items = append(items, c.convert(ogs)) + og, ok := ogs.Accept(c).(ast.Node) + if !ok { + return todo("VisitRollup_list", ogs) + } + items = append(items, og) } } return &ast.GroupingSet{Kind: 1, Content: &ast.List{Items: items}} } -func (c *cc) convertCube_listContext(n *parser.Cube_listContext) ast.Node { +func (c *cc) VisitCube_list(n *parser.Cube_listContext) interface{} { if n == nil || n.CUBE() == nil || n.LPAREN() == nil || n.RPAREN() == nil { - return todo("convertCube_listContext", n) + return todo("VisitCube_list", n) } var items []ast.Node if list := n.Ordinary_grouping_set_list(); list != nil { for _, ogs := range list.AllOrdinary_grouping_set() { - items = append(items, c.convert(ogs)) + og, ok := ogs.Accept(c).(ast.Node) + if !ok { + return todo("VisitCube_list", ogs) + } + items = append(items, og) } } return &ast.GroupingSet{Kind: 2, Content: &ast.List{Items: items}} } -func (c *cc) convertGrouping_sets_specificationContext(n *parser.Grouping_sets_specificationContext) ast.Node { +func (c *cc) VisitGrouping_sets_specification(n *parser.Grouping_sets_specificationContext) interface{} { if n == nil || n.GROUPING() == nil || n.SETS() == nil || n.LPAREN() == nil || n.RPAREN() == nil { - return todo("convertGrouping_sets_specificationContext", n) + return todo("VisitGrouping_sets_specification", n) } var items []ast.Node if gel := n.Grouping_element_list(); gel != nil { for _, ge := range gel.AllGrouping_element() { - items = append(items, c.convert(ge)) + g, ok := ge.Accept(c).(ast.Node) + if !ok { + return todo("VisitGrouping_sets_specification", ge) + } + items = append(items, g) } } return &ast.GroupingSet{Kind: 3, Content: &ast.List{Items: items}} } -func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { +func (c *cc) VisitResult_column(n *parser.Result_columnContext) interface{} { // todo: support opt_id_prefix target := &ast.ResTarget{ Location: c.pos(n.GetStart()), @@ -1274,11 +1476,15 @@ func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { case n.ASTERISK() != nil: val = c.convertWildCardField(n) case iexpr != nil: - val = c.convert(iexpr) + temp, ok := iexpr.Accept(c).(ast.Node) + if !ok { + return todo("VisitResult_column", iexpr) + } + val = temp } if val == nil { - return nil + return todo("VisitResult_column", n) } switch { case n.AS() != nil && n.An_id_or_type() != nil: @@ -1291,30 +1497,27 @@ func (c *cc) convertResultColumn(n *parser.Result_columnContext) ast.Node { return target } -func (c *cc) convertJoinSource(n *parser.Join_sourceContext) ast.Node { - if n == nil { - return nil +func (c *cc) VisitJoin_source(n *parser.Join_sourceContext) interface{} { + if n == nil || len(n.AllFlatten_source()) == 0 { + return todo("VisitJoin_source", n) } fsList := n.AllFlatten_source() - if len(fsList) == 0 { - return nil - } joinOps := n.AllJoin_op() joinConstraints := n.AllJoin_constraint() // todo: add ANY support - leftNode := c.convertFlattenSource(fsList[0]) - if leftNode == nil { - return nil + leftNode, ok := fsList[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitJoin_source", fsList[0]) } for i, jopCtx := range joinOps { if i+1 >= len(fsList) { break } - rightNode := c.convertFlattenSource(fsList[i+1]) - if rightNode == nil { - return leftNode + rightNode, ok := fsList[i+1].Accept(c).(ast.Node) + if !ok { + return todo("VisitJoin_source", fsList[i+1]) } jexpr := &ast.JoinExpr{ Larg: leftNode, @@ -1343,7 +1546,11 @@ func (c *cc) convertJoinSource(n *parser.Join_sourceContext) ast.Node { switch { case jc.ON() != nil: if exprCtx := jc.Expr(); exprCtx != nil { - jexpr.Quals = c.convert(exprCtx) + expr, ok := exprCtx.Accept(c).(ast.Node) + if !ok { + return todo("VisitJoin_source", exprCtx) + } + jexpr.Quals = expr } case jc.USING() != nil: if pureListCtx := jc.Pure_column_or_named_list(); pureListCtx != nil { @@ -1353,12 +1560,17 @@ func (c *cc) convertJoinSource(n *parser.Join_sourceContext) ast.Node { if anID := pureCtx.An_id(); anID != nil { using.Items = append(using.Items, NewIdentifier(parseAnId(anID))) } else if bp := pureCtx.Bind_parameter(); bp != nil { - bindPar := c.convert(bp) + bindPar, ok := bp.Accept(c).(ast.Node) + if !ok { + return todo("VisitJoin_source", bp) + } using.Items = append(using.Items, bindPar) } } jexpr.UsingClause = &using } + default: + return todo("VisitJoin_source", jc) } } } @@ -1367,31 +1579,25 @@ func (c *cc) convertJoinSource(n *parser.Join_sourceContext) ast.Node { return leftNode } -func (c *cc) convertFlattenSource(n parser.IFlatten_sourceContext) ast.Node { - if n == nil { - return nil - } - nss := n.Named_single_source() - if nss == nil { - return nil +func (c *cc) VisitFlatten_source(n *parser.Flatten_sourceContext) interface{} { + if n == nil || n.Named_single_source() == nil { + return todo("VisitFlatten_source", n) } - namedSingleSource, ok := nss.(*parser.Named_single_sourceContext) + namedSingleSource, ok := n.Named_single_source().Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitFlatten_source", n.Named_single_source()) } - return c.convertNamedSingleSource(namedSingleSource) + return namedSingleSource } -func (c *cc) convertNamedSingleSource(n *parser.Named_single_sourceContext) ast.Node { - ss := n.Single_source() - if ss == nil { - return nil +func (c *cc) VisitNamed_single_source(n *parser.Named_single_sourceContext) interface{} { + if n == nil || n.Single_source() == nil { + return todo("VisitNamed_single_source", n) } - SingleSource, ok := ss.(*parser.Single_sourceContext) + base, ok := n.Single_source().Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitNamed_single_source", n.Single_source()) } - base := c.convertSingleSource(SingleSource) if n.AS() != nil && n.An_id() != nil { aliasText := parseAnId(n.An_id()) @@ -1407,7 +1613,11 @@ func (c *cc) convertNamedSingleSource(n *parser.Named_single_sourceContext) ast. return base } -func (c *cc) convertSingleSource(n *parser.Single_sourceContext) ast.Node { +func (c *cc) VisitSingle_source(n *parser.Single_sourceContext) interface{} { + if n == nil { + return todo("VisitSingle_source", n) + } + if n.Table_ref() != nil { tableName := n.Table_ref().GetText() // !! debug !! return &ast.RangeVar{ @@ -1417,7 +1627,10 @@ func (c *cc) convertSingleSource(n *parser.Single_sourceContext) ast.Node { } if n.Select_stmt() != nil { - subquery := c.convert(n.Select_stmt()) + subquery, ok := n.Select_stmt().Accept(c).(ast.Node) + if !ok { + return todo("VisitSingle_source", n.Select_stmt()) + } return &ast.RangeSubselect{ Subquery: subquery, } @@ -1425,35 +1638,30 @@ func (c *cc) convertSingleSource(n *parser.Single_sourceContext) ast.Node { } // todo: Values stmt - return nil + return todo("VisitSingle_source", n) } -func (c *cc) convertBindParameter(n *parser.Bind_parameterContext) ast.Node { - // !!debug later!! - if n.DOLLAR() != nil { - if n.TRUE() != nil { - return &ast.A_Const{Val: &ast.Boolean{Boolval: true}, Location: c.pos(n.GetStart())} - } - if n.FALSE() != nil { - return &ast.A_Const{Val: &ast.Boolean{Boolval: false}, Location: c.pos(n.GetStart())} - } +func (c *cc) VisitBind_parameter(n *parser.Bind_parameterContext) interface{} { + if n == nil || n.DOLLAR() == nil { + return todo("VisitBind_parameter", n) + } - if an := n.An_id_or_type(); an != nil { - idText := parseAnIdOrType(an) - return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "@"}}}, - Rexpr: &ast.String{Str: idText}, - Location: c.pos(n.GetStart()), - } - } - c.paramCount++ - return &ast.ParamRef{ - Number: c.paramCount, + if n.TRUE() != nil { + return &ast.A_Const{Val: &ast.Boolean{Boolval: true}, Location: c.pos(n.GetStart())} + } + if n.FALSE() != nil { + return &ast.A_Const{Val: &ast.Boolean{Boolval: false}, Location: c.pos(n.GetStart())} + } + + if an := n.An_id_or_type(); an != nil { + idText := parseAnIdOrType(an) + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "@"}}}, + Rexpr: &ast.String{Str: idText}, Location: c.pos(n.GetStart()), - Dollar: true, } } - return &ast.TODO{} + return todo("VisitBind_parameter", n) } func (c *cc) convertWildCardField(n *parser.Result_columnContext) *ast.ColumnRef { @@ -1472,129 +1680,154 @@ func (c *cc) convertWildCardField(n *parser.Result_columnContext) *ast.ColumnRef } } -func (c *cc) convertOptIdPrefix(ctx parser.IOpt_id_prefixContext) string { - if ctx == nil { +func (c *cc) convertOptIdPrefix(n parser.IOpt_id_prefixContext) string { + if n == nil { return "" } - if ctx.An_id() != nil { - return ctx.An_id().GetText() + if n.An_id() != nil { + return n.An_id().GetText() } return "" } -func (c *cc) convertCreate_table_stmtContext(n *parser.Create_table_stmtContext) ast.Node { +func (c *cc) VisitCreate_table_stmt(n *parser.Create_table_stmtContext) interface{} { stmt := &ast.CreateTableStmt{ Name: parseTableName(n.Simple_table_ref().Simple_table_ref_core()), IfNotExists: n.EXISTS() != nil, } - for _, idef := range n.AllCreate_table_entry() { - if def, ok := idef.(*parser.Create_table_entryContext); ok { + for _, def := range n.AllCreate_table_entry() { + switch { + case def.Column_schema() != nil: + temp, ok := def.Column_schema().Accept(c).(ast.Node) + if !ok { + return todo("VisitCreate_table_stmt", def.Column_schema()) + } + colCtx, ok := temp.(*ast.ColumnDef) + if !ok { + return todo("VisitCreate_table_stmt", def.Column_schema()) + } + stmt.Cols = append(stmt.Cols, colCtx) + case def.Table_constraint() != nil: + conCtx := def.Table_constraint() switch { - case def.Column_schema() != nil: - if colCtx, ok := def.Column_schema().(*parser.Column_schemaContext); ok { - colDef := c.convertColumnSchema(colCtx) - if colDef != nil { - stmt.Cols = append(stmt.Cols, colDef) - } - } - case def.Table_constraint() != nil: - if conCtx, ok := def.Table_constraint().(*parser.Table_constraintContext); ok { - switch { - case conCtx.PRIMARY() != nil && conCtx.KEY() != nil: - for _, cname := range conCtx.AllAn_id() { - for _, col := range stmt.Cols { - if col.Colname == parseAnId(cname) { - col.IsNotNull = true - } - } + case conCtx.PRIMARY() != nil && conCtx.KEY() != nil: + for _, cname := range conCtx.AllAn_id() { + for _, col := range stmt.Cols { + if col.Colname == parseAnId(cname) { + col.IsNotNull = true } - case conCtx.PARTITION() != nil && conCtx.BY() != nil: - _ = conCtx - // todo: partition by constraint - case conCtx.ORDER() != nil && conCtx.BY() != nil: - _ = conCtx - // todo: order by constraint } } - - case def.Table_index() != nil: - if indCtx, ok := def.Table_index().(*parser.Table_indexContext); ok { - _ = indCtx - // todo - } - case def.Family_entry() != nil: - if famCtx, ok := def.Family_entry().(*parser.Family_entryContext); ok { - _ = famCtx - // todo - } - case def.Changefeed() != nil: // таблица ориентированная - if cgfCtx, ok := def.Changefeed().(*parser.ChangefeedContext); ok { - _ = cgfCtx - // todo - } + case conCtx.PARTITION() != nil && conCtx.BY() != nil: + return todo("VisitCreate_table_stmt", conCtx) + case conCtx.ORDER() != nil && conCtx.BY() != nil: + return todo("VisitCreate_table_stmt", conCtx) } + + case def.Table_index() != nil: + return todo("VisitCreate_table_stmt", def.Table_index()) + case def.Family_entry() != nil: + return todo("VisitCreate_table_stmt", def.Family_entry()) + case def.Changefeed() != nil: // table-oriented + return todo("VisitCreate_table_stmt", def.Changefeed()) } } return stmt } -func (c *cc) convertColumnSchema(n *parser.Column_schemaContext) *ast.ColumnDef { - +func (c *cc) VisitColumn_schema(n *parser.Column_schemaContext) interface{} { + if n == nil { + return todo("VisitColumn_schema", n) + } col := &ast.ColumnDef{} if anId := n.An_id_schema(); anId != nil { col.Colname = identifier(parseAnIdSchema(anId)) } if tnb := n.Type_name_or_bind(); tnb != nil { - col.TypeName = c.convertTypeNameOrBind(tnb) + temp, ok := tnb.Accept(c).(ast.Node) + if !ok { + return todo("VisitColumn_schema", tnb) + } + typeName, ok := temp.(*ast.TypeName) + if !ok { + return todo("VisitColumn_schema", tnb) + } + col.TypeName = typeName } if colCons := n.Opt_column_constraints(); colCons != nil { col.IsNotNull = colCons.NOT() != nil && colCons.NULL() != nil - //todo: cover exprs if needed + + if colCons.DEFAULT() != nil && colCons.Expr() != nil { + defaultExpr, ok := colCons.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitColumn_schema", colCons.Expr()) + } + col.RawDefault = defaultExpr + } } // todo: family return col } -func (c *cc) convertTypeNameOrBind(n parser.IType_name_or_bindContext) *ast.TypeName { +func (c *cc) VisitType_name_or_bind(n *parser.Type_name_or_bindContext) interface{} { + if n == nil { + return todo("VisitType_name_or_bind", n) + } + if t := n.Type_name(); t != nil { - return c.convertTypeName(t) + temp, ok := t.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_or_bind", t) + } + typeName, ok := temp.(*ast.TypeName) + if !ok { + return todo("VisitType_name_or_bind", t) + } + return typeName } else if b := n.Bind_parameter(); b != nil { return &ast.TypeName{Name: "BIND:" + identifier(parseAnIdOrType(b.An_id_or_type()))} } - return nil + return todo("VisitType_name_or_bind", n) } -func (c *cc) convertTypeName(n parser.IType_nameContext) *ast.TypeName { +func (c *cc) VisitType_name(n *parser.Type_nameContext) interface{} { if n == nil { - return nil + return todo("VisitType_name", n) } if composite := n.Type_name_composite(); composite != nil { - if node := c.convertTypeNameComposite(composite); node != nil { - if typeName, ok := node.(*ast.TypeName); ok { - return typeName - } + typeName, ok := composite.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_or_bind", composite) } + return typeName } if decimal := n.Type_name_decimal(); decimal != nil { if integerOrBinds := decimal.AllInteger_or_bind(); len(integerOrBinds) >= 2 { + first, ok := integerOrBinds[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name", decimal.Integer_or_bind(0)) + } + second, ok := integerOrBinds[1].Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name", decimal.Integer_or_bind(1)) + } return &ast.TypeName{ Name: "decimal", TypeOid: 0, Names: &ast.List{ Items: []ast.Node{ - c.convertIntegerOrBind(integerOrBinds[0]), - c.convertIntegerOrBind(integerOrBinds[1]), + first, + second, }, }, } } } - // Handle simple types if simple := n.Type_name_simple(); simple != nil { return &ast.TypeName{ Name: simple.GetText(), @@ -1602,41 +1835,49 @@ func (c *cc) convertTypeName(n parser.IType_nameContext) *ast.TypeName { } } - return nil + return todo("VisitType_name", n) } -func (c *cc) convertIntegerOrBind(n parser.IInteger_or_bindContext) ast.Node { +func (c *cc) VisitInteger_or_bind(n *parser.Integer_or_bindContext) interface{} { if n == nil { - return nil + return todo("VisitInteger_or_bind", n) } if integer := n.Integer(); integer != nil { val, err := parseIntegerValue(integer.GetText()) if err != nil { - return &ast.TODO{} + return todo("VisitInteger_or_bind", n.Integer()) } return &ast.Integer{Ival: val} } if bind := n.Bind_parameter(); bind != nil { - return c.convertBindParameter(bind.(*parser.Bind_parameterContext)) + temp, ok := bind.Accept(c).(ast.Node) + if !ok { + return todo("VisitInteger_or_bind", bind) + } + return temp } - return nil + return todo("VisitInteger_or_bind", n) } -func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast.Node { +func (c *cc) VisitType_name_composite(n *parser.Type_name_compositeContext) interface{} { if n == nil { - return nil + return todo("VisitType_name_composite", n) } if opt := n.Type_name_optional(); opt != nil { if typeName := opt.Type_name_or_bind(); typeName != nil { + tn, ok := typeName.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeName) + } return &ast.TypeName{ Name: "Optional", TypeOid: 0, Names: &ast.List{ - Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + Items: []ast.Node{tn}, }, } } @@ -1646,7 +1887,11 @@ func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast. if typeNames := tuple.AllType_name_or_bind(); len(typeNames) > 0 { var items []ast.Node for _, tn := range typeNames { - items = append(items, c.convertTypeNameOrBind(tn)) + tnNode, ok := tn.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", tn) + } + items = append(items, tnNode) } return &ast.TypeName{ Name: "Tuple", @@ -1688,11 +1933,15 @@ func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast. if list := n.Type_name_list(); list != nil { if typeName := list.Type_name_or_bind(); typeName != nil { + tn, ok := typeName.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeName) + } return &ast.TypeName{ Name: "List", TypeOid: 0, Names: &ast.List{ - Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + Items: []ast.Node{tn}, }, } } @@ -1700,37 +1949,41 @@ func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast. if stream := n.Type_name_stream(); stream != nil { if typeName := stream.Type_name_or_bind(); typeName != nil { + tn, ok := typeName.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeName) + } return &ast.TypeName{ Name: "Stream", TypeOid: 0, Names: &ast.List{ - Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + Items: []ast.Node{tn}, }, } } } if flow := n.Type_name_flow(); flow != nil { - if typeName := flow.Type_name_or_bind(); typeName != nil { - return &ast.TypeName{ - Name: "Flow", - TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, - }, - } - } + return todo("VisitType_name_composite", flow) } if dict := n.Type_name_dict(); dict != nil { if typeNames := dict.AllType_name_or_bind(); len(typeNames) >= 2 { + first, ok := typeNames[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeNames[0]) + } + second, ok := typeNames[1].Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeNames[1]) + } return &ast.TypeName{ Name: "Dict", TypeOid: 0, Names: &ast.List{ Items: []ast.Node{ - c.convertTypeNameOrBind(typeNames[0]), - c.convertTypeNameOrBind(typeNames[1]), + first, + second, }, }, } @@ -1739,259 +1992,234 @@ func (c *cc) convertTypeNameComposite(n parser.IType_name_compositeContext) ast. if set := n.Type_name_set(); set != nil { if typeName := set.Type_name_or_bind(); typeName != nil { + tn, ok := typeName.Accept(c).(ast.Node) + if !ok { + return todo("VisitType_name_composite", typeName) + } return &ast.TypeName{ Name: "Set", TypeOid: 0, Names: &ast.List{ - Items: []ast.Node{c.convertTypeNameOrBind(typeName)}, + Items: []ast.Node{tn}, }, } } } - if enum := n.Type_name_enum(); enum != nil { - if typeTags := enum.AllType_name_tag(); len(typeTags) > 0 { - var items []ast.Node - for range typeTags { // todo: Handle enum tags - items = append(items, &ast.TODO{}) - } - return &ast.TypeName{ - Name: "Enum", - TypeOid: 0, - Names: &ast.List{Items: items}, - } - } + if enum := n.Type_name_enum(); enum != nil { // todo: handle enum + todo("VisitType_name_composite", enum) } - if resource := n.Type_name_resource(); resource != nil { - if typeTag := resource.Type_name_tag(); typeTag != nil { - // TODO: Handle resource tag - return &ast.TypeName{ - Name: "Resource", - TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{&ast.TODO{}}, - }, - } - } + if resource := n.Type_name_resource(); resource != nil { // todo: handle resource + todo("VisitType_name_composite", resource) } - if tagged := n.Type_name_tagged(); tagged != nil { - if typeName := tagged.Type_name_or_bind(); typeName != nil { - if typeTag := tagged.Type_name_tag(); typeTag != nil { - // TODO: Handle tagged type and tag - return &ast.TypeName{ - Name: "Tagged", - TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{ - c.convertTypeNameOrBind(typeName), - &ast.TODO{}, - }, - }, - } - } - } + if tagged := n.Type_name_tagged(); tagged != nil { // todo: handle tagged + todo("VisitType_name_composite", tagged) } - if callable := n.Type_name_callable(); callable != nil { - // TODO: Handle callable argument list and return type - return &ast.TypeName{ - Name: "Callable", - TypeOid: 0, - Names: &ast.List{ - Items: []ast.Node{&ast.TODO{}}, - }, - } + if callable := n.Type_name_callable(); callable != nil { // todo: handle callable + todo("VisitType_name_composite", callable) } - return nil + return todo("VisitType_name_composite", n) } -func (c *cc) convertSqlStmtCore(n parser.ISql_stmt_coreContext) ast.Node { +func (c *cc) VisitSql_stmt_core(n *parser.Sql_stmt_coreContext) interface{} { if n == nil { - return nil + return todo("VisitSql_stmt_core", n) } if stmt := n.Pragma_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Select_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Named_nodes_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_table_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) + } + if stmt := n.Named_nodes_stmt(); stmt != nil { + return stmt.Accept(c) + } + if stmt := n.Create_table_stmt(); stmt != nil { + return stmt.Accept(c) } if stmt := n.Drop_table_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Use_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Into_table_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Commit_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Update_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Delete_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Rollback_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Declare_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Import_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Export_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_table_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_external_table_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Do_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Define_action_or_subquery_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.If_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.For_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Values_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_user_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_user_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_group_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_group_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_role_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_object_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_object_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_object_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_external_data_source_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_external_data_source_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_external_data_source_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_replication_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_replication_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_topic_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_topic_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_topic_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Grant_permissions_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Revoke_permissions_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_table_store_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Upsert_object_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_view_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_view_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_replication_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_resource_pool_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_resource_pool_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_resource_pool_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_backup_collection_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_backup_collection_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_backup_collection_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Analyze_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Create_resource_pool_classifier_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_resource_pool_classifier_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Drop_resource_pool_classifier_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Backup_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Restore_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } if stmt := n.Alter_sequence_stmt(); stmt != nil { - return c.convert(stmt) + return stmt.Accept(c) } - return nil + return todo("VisitSql_stmt_core", n) } -func (c *cc) convertNamed_exprContext(n *parser.Named_exprContext) ast.Node { +func (c *cc) VisitNamed_expr(n *parser.Named_exprContext) interface{} { if n == nil || n.Expr() == nil { - return todo("convertNamed_exprContext", n) + return todo("VisitNamed_expr", n) } - expr := c.convert(n.Expr()) + + expr, ok := n.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitNamed_expr", n) + } + if n.AS() != nil && n.An_id_or_type() != nil { name := parseAnIdOrType(n.An_id_or_type()) return &ast.ResTarget{ @@ -2003,32 +2231,32 @@ func (c *cc) convertNamed_exprContext(n *parser.Named_exprContext) ast.Node { return expr } -func (c *cc) convertExpr(n *parser.ExprContext) ast.Node { +func (c *cc) VisitExpr(n *parser.ExprContext) interface{} { if n == nil { - return nil + return todo("VisitExpr", n) } if tn := n.Type_name_composite(); tn != nil { - return c.convertTypeNameComposite(tn) + return tn.Accept(c) } orSubs := n.AllOr_subexpr() if len(orSubs) == 0 { - return nil + return todo("VisitExpr", n) } - orSub, ok := orSubs[0].(*parser.Or_subexprContext) + left, ok := n.Or_subexpr(0).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitExpr", n) } - left := c.convertOrSubExpr(orSub) for i := 1; i < len(orSubs); i++ { - orSub, ok = orSubs[i].(*parser.Or_subexprContext) + + right, ok := orSubs[i].Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitExpr", n) } - right := c.convertOrSubExpr(orSub) + left = &ast.BoolExpr{ Boolop: ast.BoolExprTypeOr, Args: &ast.List{Items: []ast.Node{left, right}}, @@ -2038,26 +2266,23 @@ func (c *cc) convertExpr(n *parser.ExprContext) ast.Node { return left } -func (c *cc) convertOrSubExpr(n *parser.Or_subexprContext) ast.Node { - if n == nil { - return nil +func (c *cc) VisitOr_subexpr(n *parser.Or_subexprContext) interface{} { + if n == nil || len(n.AllAnd_subexpr()) == 0 { + return todo("VisitOr_subexpr", n) } - andSubs := n.AllAnd_subexpr() - if len(andSubs) == 0 { - return nil - } - andSub, ok := andSubs[0].(*parser.And_subexprContext) + + left, ok := n.And_subexpr(0).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitOr_subexpr", n) } - left := c.convertAndSubexpr(andSub) - for i := 1; i < len(andSubs); i++ { - andSub, ok = andSubs[i].(*parser.And_subexprContext) + for i := 1; i < len(n.AllAnd_subexpr()); i++ { + + right, ok := n.And_subexpr(i).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitOr_subexpr", n) } - right := c.convertAndSubexpr(andSub) + left = &ast.BoolExpr{ Boolop: ast.BoolExprTypeAnd, Args: &ast.List{Items: []ast.Node{left, right}}, @@ -2067,28 +2292,23 @@ func (c *cc) convertOrSubExpr(n *parser.Or_subexprContext) ast.Node { return left } -func (c *cc) convertAndSubexpr(n *parser.And_subexprContext) ast.Node { - if n == nil { - return nil - } - - xors := n.AllXor_subexpr() - if len(xors) == 0 { - return nil +func (c *cc) VisitAnd_subexpr(n *parser.And_subexprContext) interface{} { + if n == nil || len(n.AllXor_subexpr()) == 0 { + return todo("VisitAnd_subexpr", n) } - xor, ok := xors[0].(*parser.Xor_subexprContext) + left, ok := n.Xor_subexpr(0).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitAnd_subexpr", n) } - left := c.convertXorSubexpr(xor) - for i := 1; i < len(xors); i++ { - xor, ok = xors[i].(*parser.Xor_subexprContext) + for i := 1; i < len(n.AllXor_subexpr()); i++ { + + right, ok := n.Xor_subexpr(i).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitAnd_subexpr", n) } - right := c.convertXorSubexpr(xor) + left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: "XOR"}}}, Lexpr: left, @@ -2099,40 +2319,53 @@ func (c *cc) convertAndSubexpr(n *parser.And_subexprContext) ast.Node { return left } -func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { - if n == nil { - return nil - } - es := n.Eq_subexpr() - if es == nil { - return nil +func (c *cc) VisitXor_subexpr(n *parser.Xor_subexprContext) interface{} { + if n == nil || n.Eq_subexpr() == nil { + return todo("VisitXor_subexpr", n) } - subExpr, ok := es.(*parser.Eq_subexprContext) + + base, ok := n.Eq_subexpr().Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitXor_subexpr", n) } - base := c.convertEqSubexpr(subExpr) - if cond := n.Cond_expr(); cond != nil { - condCtx, ok := cond.(*parser.Cond_exprContext) - if !ok { - return base - } + + if condCtx := n.Cond_expr(); condCtx != nil { switch { case condCtx.IN() != nil: if inExpr := condCtx.In_expr(); inExpr != nil { - return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "IN"}}}, - Lexpr: base, - Rexpr: c.convert(inExpr), + temp, ok := inExpr.Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", inExpr) + } + list, ok := temp.(*ast.List) + if !ok { + return todo("VisitXor_subexpr", inExpr) + } + return &ast.In{ + Expr: base, + List: list.Items, + Not: condCtx.NOT() != nil, + Location: c.pos(n.GetStart()), } } case condCtx.BETWEEN() != nil: if eqSubs := condCtx.AllEq_subexpr(); len(eqSubs) >= 2 { + + first, ok := eqSubs[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", n) + } + + second, ok := eqSubs[1].Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", n) + } + return &ast.BetweenExpr{ Expr: base, - Left: c.convert(eqSubs[0]), - Right: c.convert(eqSubs[1]), + Left: first, + Right: second, Not: condCtx.NOT() != nil, Location: c.pos(n.GetStart()), } @@ -2155,7 +2388,7 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { Nulltesttype: 1, // IS NULL Location: c.pos(n.GetStart()), } - case condCtx.IS() != nil && condCtx.NOT() != nil && condCtx.NULL() != nil: + case condCtx.NOT() != nil && condCtx.NULL() != nil: return &ast.NullTest{ Arg: base, Nulltesttype: 2, // IS NOT NULL @@ -2165,10 +2398,16 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { // debug!!! matchOp := condCtx.Match_op().GetText() if eqSubs := condCtx.AllEq_subexpr(); len(eqSubs) >= 1 { + + xpr, ok := eqSubs[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", n) + } + expr := &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: matchOp}}}, Lexpr: base, - Rexpr: c.convert(eqSubs[0]), + Rexpr: xpr, } if condCtx.ESCAPE() != nil && len(eqSubs) >= 2 { //nolint // todo: Add ESCAPE support @@ -2177,25 +2416,43 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { } case len(condCtx.AllEQUALS()) > 0 || len(condCtx.AllEQUALS2()) > 0 || len(condCtx.AllNOT_EQUALS()) > 0 || len(condCtx.AllNOT_EQUALS2()) > 0: - // debug!!! - var op string - switch { - case len(condCtx.AllEQUALS()) > 0: - op = "=" - case len(condCtx.AllEQUALS2()) > 0: - op = "==" - case len(condCtx.AllNOT_EQUALS()) > 0: - op = "!=" - case len(condCtx.AllNOT_EQUALS2()) > 0: - op = "<>" - } - if eqSubs := condCtx.AllEq_subexpr(); len(eqSubs) >= 1 { - return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, - Lexpr: base, - Rexpr: c.convert(eqSubs[0]), + eqSubs := condCtx.AllEq_subexpr() + if len(eqSubs) >= 1 { + left := base + + ops := c.collectEqualityOps(condCtx) + + for i, eqSub := range eqSubs { + right, ok := eqSub.Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", condCtx) + } + + var op string + if i < len(ops) { + op = ops[i].GetText() + } else { + if len(condCtx.AllEQUALS()) > 0 { + op = "=" + } else if len(condCtx.AllEQUALS2()) > 0 { + op = "==" + } else if len(condCtx.AllNOT_EQUALS()) > 0 { + op = "!=" + } else if len(condCtx.AllNOT_EQUALS2()) > 0 { + op = "<>" + } + } + + left = &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, + Lexpr: left, + Rexpr: right, + Location: c.pos(condCtx.GetStart()), + } } + return left } + return todo("VisitXor_subexpr", condCtx) case len(condCtx.AllDistinct_from_op()) > 0: // debug!!! distinctOps := condCtx.AllDistinct_from_op() @@ -2206,10 +2463,16 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { if not { op = "IS NOT DISTINCT FROM" } + + xpr, ok := eqSubs[0].Accept(c).(ast.Node) + if !ok { + return todo("VisitXor_subexpr", n) + } + return &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, Lexpr: base, - Rexpr: c.convert(eqSubs[0]), + Rexpr: xpr, } } } @@ -2218,26 +2481,24 @@ func (c *cc) convertXorSubexpr(n *parser.Xor_subexprContext) ast.Node { return base } -func (c *cc) convertEqSubexpr(n *parser.Eq_subexprContext) ast.Node { - if n == nil { - return nil - } - neqList := n.AllNeq_subexpr() - if len(neqList) == 0 { - return nil +func (c *cc) VisitEq_subexpr(n *parser.Eq_subexprContext) interface{} { + if n == nil || len(n.AllNeq_subexpr()) == 0 { + return todo("VisitEq_subexpr", n) } - neq, ok := neqList[0].(*parser.Neq_subexprContext) + + left, ok := n.Neq_subexpr(0).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitEq_subexpr", n) } - left := c.convertNeqSubexpr(neq) + ops := c.collectComparisonOps(n) - for i := 1; i < len(neqList); i++ { - neq, ok = neqList[i].(*parser.Neq_subexprContext) + for i := 1; i < len(n.AllNeq_subexpr()); i++ { + + right, ok := n.Neq_subexpr(i).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitEq_subexpr", n) } - right := c.convertNeqSubexpr(neq) + opText := ops[i-1].GetText() left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, @@ -2249,40 +2510,22 @@ func (c *cc) convertEqSubexpr(n *parser.Eq_subexprContext) ast.Node { return left } -func (c *cc) collectComparisonOps(n parser.IEq_subexprContext) []antlr.TerminalNode { - var ops []antlr.TerminalNode - for _, child := range n.GetChildren() { - if tn, ok := child.(antlr.TerminalNode); ok { - switch tn.GetText() { - case "<", "<=", ">", ">=": - ops = append(ops, tn) - } - } +func (c *cc) VisitNeq_subexpr(n *parser.Neq_subexprContext) interface{} { + if n == nil || len(n.AllBit_subexpr()) == 0 { + return todo("VisitNeq_subexpr", n) } - return ops -} -func (c *cc) convertNeqSubexpr(n *parser.Neq_subexprContext) ast.Node { - if n == nil { - return nil - } - bitList := n.AllBit_subexpr() - if len(bitList) == 0 { - return nil - } - - bl, ok := bitList[0].(*parser.Bit_subexprContext) + left, ok := n.Bit_subexpr(0).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitNeq_subexpr", n) } - left := c.convertBitSubexpr(bl) + ops := c.collectBitwiseOps(n) - for i := 1; i < len(bitList); i++ { - bl, ok = bitList[i].(*parser.Bit_subexprContext) + for i := 1; i < len(n.AllBit_subexpr()); i++ { + right, ok := n.Bit_subexpr(i).Accept(c).(ast.Node) if !ok { - return nil + return todo("VisitNeq_subexpr", n) } - right := c.convertBitSubexpr(bl) opText := ops[i-1].GetText() left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, @@ -2293,13 +2536,12 @@ func (c *cc) convertNeqSubexpr(n *parser.Neq_subexprContext) ast.Node { } if n.Double_question() != nil { - nextCtx := n.Neq_subexpr() - if nextCtx != nil { - neq, ok2 := nextCtx.(*parser.Neq_subexprContext) + if nextCtx := n.Neq_subexpr(); nextCtx != nil { + right, ok2 := nextCtx.Accept(c).(ast.Node) if !ok2 { - return nil + return todo("VisitNeq_subexpr", n) } - right := c.convertNeqSubexpr(neq) + left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: "??"}}}, Lexpr: left, @@ -2326,28 +2568,24 @@ func (c *cc) convertNeqSubexpr(n *parser.Neq_subexprContext) ast.Node { return left } -func (c *cc) collectBitwiseOps(ctx parser.INeq_subexprContext) []antlr.TerminalNode { - var ops []antlr.TerminalNode - children := ctx.GetChildren() - for _, child := range children { - if tn, ok := child.(antlr.TerminalNode); ok { - txt := tn.GetText() - switch txt { - case "<<", ">>", "<<|", ">>|", "&", "|", "^": - ops = append(ops, tn) - } - } +func (c *cc) VisitBit_subexpr(n *parser.Bit_subexprContext) interface{} { + if n == nil || len(n.AllAdd_subexpr()) == 0 { + return todo("VisitBit_subexpr", n) } - return ops -} -func (c *cc) convertBitSubexpr(n *parser.Bit_subexprContext) ast.Node { - addList := n.AllAdd_subexpr() - left := c.convertAddSubexpr(addList[0].(*parser.Add_subexprContext)) + left, ok := n.Add_subexpr(0).Accept(c).(ast.Node) + if !ok { + return todo("VisitBit_subexpr", n) + } ops := c.collectBitOps(n) - for i := 1; i < len(addList); i++ { - right := c.convertAddSubexpr(addList[i].(*parser.Add_subexprContext)) + for i := 1; i < len(n.AllAdd_subexpr()); i++ { + + right, ok := n.Add_subexpr(i).Accept(c).(ast.Node) + if !ok { + return todo("VisitBit_subexpr", n) + } + opText := ops[i-1].GetText() left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, @@ -2359,28 +2597,24 @@ func (c *cc) convertBitSubexpr(n *parser.Bit_subexprContext) ast.Node { return left } -func (c *cc) collectBitOps(ctx parser.IBit_subexprContext) []antlr.TerminalNode { - var ops []antlr.TerminalNode - children := ctx.GetChildren() - for _, child := range children { - if tn, ok := child.(antlr.TerminalNode); ok { - txt := tn.GetText() - switch txt { - case "+", "-": - ops = append(ops, tn) - } - } +func (c *cc) VisitAdd_subexpr(n *parser.Add_subexprContext) interface{} { + if n == nil || len(n.AllMul_subexpr()) == 0 { + return todo("VisitAdd_subexpr", n) } - return ops -} -func (c *cc) convertAddSubexpr(n *parser.Add_subexprContext) ast.Node { - mulList := n.AllMul_subexpr() - left := c.convertMulSubexpr(mulList[0].(*parser.Mul_subexprContext)) + left, ok := n.Mul_subexpr(0).Accept(c).(ast.Node) + if !ok { + return todo("VisitAdd_subexpr", n) + } ops := c.collectAddOps(n) - for i := 1; i < len(mulList); i++ { - right := c.convertMulSubexpr(mulList[i].(*parser.Mul_subexprContext)) + for i := 1; i < len(n.AllMul_subexpr()); i++ { + + right, ok := n.Mul_subexpr(i).Accept(c).(ast.Node) + if !ok { + return todo("VisitAdd_subexpr", n) + } + opText := ops[i-1].GetText() left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: opText}}}, @@ -2392,25 +2626,23 @@ func (c *cc) convertAddSubexpr(n *parser.Add_subexprContext) ast.Node { return left } -func (c *cc) collectAddOps(ctx parser.IAdd_subexprContext) []antlr.TerminalNode { - var ops []antlr.TerminalNode - for _, child := range ctx.GetChildren() { - if tn, ok := child.(antlr.TerminalNode); ok { - switch tn.GetText() { - case "*", "/", "%": - ops = append(ops, tn) - } - } +func (c *cc) VisitMul_subexpr(n *parser.Mul_subexprContext) interface{} { + if n == nil || len(n.AllCon_subexpr()) == 0 { + return todo("VisitMul_subexpr", n) } - return ops -} -func (c *cc) convertMulSubexpr(n *parser.Mul_subexprContext) ast.Node { - conList := n.AllCon_subexpr() - left := c.convertConSubexpr(conList[0].(*parser.Con_subexprContext)) + left, ok := n.Con_subexpr(0).Accept(c).(ast.Node) + if !ok { + return todo("VisitMul_subexpr", n) + } + + for i := 1; i < len(n.AllCon_subexpr()); i++ { + + right, ok := n.Con_subexpr(i).Accept(c).(ast.Node) + if !ok { + return todo("VisitMul_subexpr", n) + } - for i := 1; i < len(conList); i++ { - right := c.convertConSubexpr(conList[i].(*parser.Con_subexprContext)) left = &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: "||"}}}, Lexpr: left, @@ -2421,42 +2653,76 @@ func (c *cc) convertMulSubexpr(n *parser.Mul_subexprContext) ast.Node { return left } -func (c *cc) convertConSubexpr(n *parser.Con_subexprContext) ast.Node { +func (c *cc) VisitCon_subexpr(n *parser.Con_subexprContext) interface{} { + if n == nil || (n.Unary_op() == nil && n.Unary_subexpr() == nil) { + return todo("VisitCon_subexpr", n) + } + if opCtx := n.Unary_op(); opCtx != nil { op := opCtx.GetText() - operand := c.convertUnarySubexpr(n.Unary_subexpr().(*parser.Unary_subexprContext)) + operand, ok := n.Unary_subexpr().Accept(c).(ast.Node) + if !ok { + return todo("VisitCon_subexpr", opCtx) + } return &ast.A_Expr{ Name: &ast.List{Items: []ast.Node{&ast.String{Str: op}}}, Rexpr: operand, Location: c.pos(n.GetStart()), } } - return c.convertUnarySubexpr(n.Unary_subexpr().(*parser.Unary_subexprContext)) + + operand, ok := n.Unary_subexpr().Accept(c).(ast.Node) + if !ok { + return todo("VisitCon_subexpr", n.Unary_subexpr()) + } + return operand + } -func (c *cc) convertUnarySubexpr(n *parser.Unary_subexprContext) ast.Node { +func (c *cc) VisitUnary_subexpr(n *parser.Unary_subexprContext) interface{} { + if n == nil || (n.Unary_casual_subexpr() == nil && n.Json_api_expr() == nil) { + return todo("VisitUnary_subexpr", n) + } + if casual := n.Unary_casual_subexpr(); casual != nil { - return c.convertUnaryCasualSubexpr(casual.(*parser.Unary_casual_subexprContext)) + expr, ok := casual.Accept(c).(ast.Node) + if !ok { + return todo("VisitUnary_subexpr", casual) + } + return expr } if jsonExpr := n.Json_api_expr(); jsonExpr != nil { - return c.convertJsonApiExpr(jsonExpr.(*parser.Json_api_exprContext)) + expr, ok := jsonExpr.Accept(c).(ast.Node) + if !ok { + return todo("VisitUnary_subexpr", jsonExpr) + } + return expr } - return nil + + return todo("VisitUnary_subexpr", n) } -func (c *cc) convertJsonApiExpr(n *parser.Json_api_exprContext) ast.Node { - return todo("Json_api_exprContext", n) +func (c *cc) VisitJson_api_expr(n *parser.Json_api_exprContext) interface{} { + return todo("VisitJson_api_expr", n) } -func (c *cc) convertUnaryCasualSubexpr(n *parser.Unary_casual_subexprContext) ast.Node { +func (c *cc) VisitUnary_casual_subexpr(n *parser.Unary_casual_subexprContext) interface{} { var current ast.Node switch { case n.Id_expr() != nil: - current = c.convertIdExpr(n.Id_expr().(*parser.Id_exprContext)) + expr, ok := n.Id_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitUnary_casual_subexpr", n.Id_expr()) + } + current = expr case n.Atom_expr() != nil: - current = c.convertAtomExpr(n.Atom_expr().(*parser.Atom_exprContext)) + expr, ok := n.Atom_expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitUnary_casual_subexpr", n.Atom_expr()) + } + current = expr default: - return todo("Unary_casual_subexprContext", n) + return todo("VisitUnary_casual_subexpr", n) } if suffix := n.Unary_subexpr_suffix(); suffix != nil { @@ -2478,17 +2744,24 @@ func (c *cc) processSuffixChain(base ast.Node, suffix *parser.Unary_subexpr_suff case antlr.TerminalNode: if elem.GetText() == "." { current = c.handleDotSuffix(current, suffix, &i) + } else { + return todo("Unary_subexpr_suffixContext", suffix) } + default: + return todo("Unary_subexpr_suffixContext", suffix) } } return current } func (c *cc) handleKeySuffix(base ast.Node, keyCtx *parser.Key_exprContext) ast.Node { - keyNode := c.convertKey_exprContext(keyCtx) + keyNode, ok := keyCtx.Accept(c).(ast.Node) + if !ok { + return todo("VisitKey_expr", keyCtx) + } ind, ok := keyNode.(*ast.A_Indirection) if !ok { - return todo("Key_exprContext", keyCtx) + return todo("VisitKey_expr", keyCtx) } if indirection, ok := base.(*ast.A_Indirection); ok { @@ -2505,9 +2778,13 @@ func (c *cc) handleKeySuffix(base ast.Node, keyCtx *parser.Key_exprContext) ast. } func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprContext, idx int) ast.Node { - funcCall, ok := c.convertInvoke_exprContext(invokeCtx).(*ast.FuncCall) + temp, ok := invokeCtx.Accept(c).(ast.Node) if !ok { - return todo("Invoke_exprContext", invokeCtx) + return todo("VisitInvoke_expr", invokeCtx) + } + funcCall, ok := temp.(*ast.FuncCall) + if !ok { + return todo("VisitInvoke_expr", invokeCtx) } if idx == 0 { @@ -2535,7 +2812,7 @@ func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprCont return funcCall } default: - return todo("Invoke_exprContext", invokeCtx) + return todo("VisitInvoke_expr", invokeCtx) } } @@ -2562,14 +2839,18 @@ func (c *cc) handleDotSuffix(base ast.Node, suffix *parser.Unary_subexpr_suffixC var field ast.Node switch v := next.(type) { case *parser.Bind_parameterContext: - field = c.convertBindParameter(v) + temp, ok := v.Accept(c).(ast.Node) + if !ok { + return todo("VisitBind_parameter", v) + } + field = temp case *parser.An_id_or_typeContext: field = &ast.String{Str: parseAnIdOrType(v)} case antlr.TerminalNode: if val, err := parseIntegerValue(v.GetText()); err == nil { field = &ast.A_Const{Val: &ast.Integer{Ival: val}} } else { - return &ast.TODO{} + return todo("Unary_subexpr_suffixContext", suffix) } } @@ -2586,16 +2867,19 @@ func (c *cc) handleDotSuffix(base ast.Node, suffix *parser.Unary_subexpr_suffixC } } -func (c *cc) convertKey_exprContext(n *parser.Key_exprContext) ast.Node { +func (c *cc) VisitKey_expr(n *parser.Key_exprContext) interface{} { if n.LBRACE_SQUARE() == nil || n.RBRACE_SQUARE() == nil || n.Expr() == nil { - return todo("Key_exprContext", n) + return todo("VisitKey_expr", n) } stmt := &ast.A_Indirection{ Indirection: &ast.List{}, } - expr := c.convert(n.Expr()) + expr, ok := n.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitKey_expr", n.Expr()) + } stmt.Indirection.Items = append(stmt.Indirection.Items, &ast.A_Indices{ Uidx: expr, @@ -2604,9 +2888,9 @@ func (c *cc) convertKey_exprContext(n *parser.Key_exprContext) ast.Node { return stmt } -func (c *cc) convertInvoke_exprContext(n *parser.Invoke_exprContext) ast.Node { +func (c *cc) VisitInvoke_expr(n *parser.Invoke_exprContext) interface{} { if n.LPAREN() == nil || n.RPAREN() == nil { - return todo("Invoke_exprContext", n) + return todo("VisitInvoke_expr", n) } distinct := false @@ -2625,7 +2909,10 @@ func (c *cc) convertInvoke_exprContext(n *parser.Invoke_exprContext) ast.Node { if nList := n.Named_expr_list(); nList != nil { for _, namedExpr := range nList.AllNamed_expr() { name := parseAnIdOrType(namedExpr.An_id_or_type()) - expr := c.convert(namedExpr.Expr()) + expr, ok := namedExpr.Expr().Accept(c).(ast.Node) + if !ok { + return todo("VisitInvoke_expr", namedExpr.Expr()) + } var res ast.Node if rt, ok := expr.(*ast.ResTarget); ok { @@ -2652,7 +2939,10 @@ func (c *cc) convertInvoke_exprContext(n *parser.Invoke_exprContext) ast.Node { return stmt } -func (c *cc) convertIdExpr(n *parser.Id_exprContext) ast.Node { +func (c *cc) VisitId_expr(n *parser.Id_exprContext) interface{} { + if n == nil { + return todo("VisitId_expr", n) + } if id := n.Identifier(); id != nil { return &ast.ColumnRef{ Fields: &ast.List{ @@ -2663,25 +2953,43 @@ func (c *cc) convertIdExpr(n *parser.Id_exprContext) ast.Node { Location: c.pos(id.GetStart()), } } - return &ast.TODO{} + return todo("VisitId_expr", n) } -func (c *cc) convertAtomExpr(n *parser.Atom_exprContext) ast.Node { +func (c *cc) VisitAtom_expr(n *parser.Atom_exprContext) interface{} { + if n == nil { + return todo("VisitAtom_expr", n) + } + switch { - case n.An_id_or_type() != nil && n.NAMESPACE() != nil: - return NewIdentifier(parseAnIdOrType(n.An_id_or_type()) + "::" + parseIdOrType(n.Id_or_type())) case n.An_id_or_type() != nil: + if n.NAMESPACE() != nil { + return NewIdentifier(parseAnIdOrType(n.An_id_or_type()) + "::" + parseIdOrType(n.Id_or_type())) + } return NewIdentifier(parseAnIdOrType(n.An_id_or_type())) case n.Literal_value() != nil: - return c.convertLiteralValue(n.Literal_value().(*parser.Literal_valueContext)) + expr, ok := n.Literal_value().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.Literal_value()) + } + return expr case n.Bind_parameter() != nil: - return c.convertBindParameter(n.Bind_parameter().(*parser.Bind_parameterContext)) + expr, ok := n.Bind_parameter().Accept(c).(ast.Node) + if !ok { + return todo("VisitAtom_expr", n.Bind_parameter()) + } + return expr + // TODO: check other cases default: - return &ast.TODO{} + return todo("VisitAtom_expr", n) } } -func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { +func (c *cc) VisitLiteral_value(n *parser.Literal_valueContext) interface{} { + if n == nil { + return todo("VisitLiteral_value", n) + } + switch { case n.Integer() != nil: text := n.Integer().GetText() @@ -2690,7 +2998,7 @@ func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { if debug.Active { log.Printf("Failed to parse integer value '%s': %v", text, err) } - return &ast.TODO{} + return todo("VisitLiteral_value", n.Integer()) } return &ast.A_Const{Val: &ast.Integer{Ival: val}, Location: c.pos(n.GetStart())} @@ -2716,22 +3024,16 @@ func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { return &ast.Null{} case n.CURRENT_TIME() != nil: - if debug.Active { - log.Printf("TODO: Implement CURRENT_TIME") - } - return &ast.TODO{} + log.Fatalf("CURRENT_TIME is not supported yet") + return todo("VisitLiteral_value", n) case n.CURRENT_DATE() != nil: - if debug.Active { - log.Printf("TODO: Implement CURRENT_DATE") - } - return &ast.TODO{} + log.Fatalf("CURRENT_DATE is not supported yet") + return todo("VisitLiteral_value", n) case n.CURRENT_TIMESTAMP() != nil: - if debug.Active { - log.Printf("TODO: Implement CURRENT_TIMESTAMP") - } - return &ast.TODO{} + log.Fatalf("CURRENT_TIMESTAMP is not supported yet") + return todo("VisitLiteral_value", n) case n.BLOB() != nil: blobText := n.BLOB().GetText() @@ -2744,205 +3046,36 @@ func (c *cc) convertLiteralValue(n *parser.Literal_valueContext) ast.Node { return &ast.TODO{} default: - if debug.Active { - log.Printf("Unknown literal value type: %T", n) - } - return &ast.TODO{} + return todo("VisitLiteral_value", n) } } -func (c *cc) convertSqlStmt(n *parser.Sql_stmtContext) ast.Node { - if n == nil { - return nil - } - // todo: handle explain - if core := n.Sql_stmt_core(); core != nil { - return c.convert(core) +func (c *cc) VisitSql_stmt(n *parser.Sql_stmtContext) interface{} { + if n == nil || n.Sql_stmt_core() == nil { + return todo("VisitSql_stmt", n) } - return nil -} - -func (c *cc) convert(node node) ast.Node { - switch n := node.(type) { - case *parser.Sql_stmtContext: - return c.convertSqlStmt(n) - - case *parser.Sql_stmt_coreContext: - return c.convertSqlStmtCore(n) - - case *parser.Create_table_stmtContext: - return c.convertCreate_table_stmtContext(n) - - case *parser.Select_stmtContext: - return c.convertSelectStmtContext(n) - - case *parser.Result_columnContext: - return c.convertResultColumn(n) - - case *parser.Join_sourceContext: - return c.convertJoinSource(n) - - case *parser.Flatten_sourceContext: - return c.convertFlattenSource(n) - - case *parser.Named_single_sourceContext: - return c.convertNamedSingleSource(n) - - case *parser.Single_sourceContext: - return c.convertSingleSource(n) - - case *parser.Bind_parameterContext: - return c.convertBindParameter(n) - - case *parser.ExprContext: - return c.convertExpr(n) - - case *parser.Or_subexprContext: - return c.convertOrSubExpr(n) - - case *parser.And_subexprContext: - return c.convertAndSubexpr(n) - - case *parser.Xor_subexprContext: - return c.convertXorSubexpr(n) - - case *parser.Eq_subexprContext: - return c.convertEqSubexpr(n) - - case *parser.Neq_subexprContext: - return c.convertNeqSubexpr(n) - - case *parser.Bit_subexprContext: - return c.convertBitSubexpr(n) - - case *parser.Add_subexprContext: - return c.convertAddSubexpr(n) - - case *parser.Mul_subexprContext: - return c.convertMulSubexpr(n) - - case *parser.Con_subexprContext: - return c.convertConSubexpr(n) - - case *parser.Unary_subexprContext: - return c.convertUnarySubexpr(n) - - case *parser.Unary_casual_subexprContext: - return c.convertUnaryCasualSubexpr(n) - - case *parser.Id_exprContext: - return c.convertIdExpr(n) - - case *parser.Atom_exprContext: - return c.convertAtomExpr(n) - - case *parser.Literal_valueContext: - return c.convertLiteralValue(n) - - case *parser.Json_api_exprContext: - return c.convertJsonApiExpr(n) - - case *parser.Type_name_compositeContext: - return c.convertTypeNameComposite(n) - - case *parser.Type_nameContext: - return c.convertTypeName(n) - - case *parser.Integer_or_bindContext: - return c.convertIntegerOrBind(n) - - case *parser.Type_name_or_bindContext: - return c.convertTypeNameOrBind(n) - - case *parser.Into_table_stmtContext: - return c.convertInto_table_stmtContext(n) - - case *parser.Values_stmtContext: - return c.convertValues_stmtContext(n) - - case *parser.Returning_columns_listContext: - return c.convertReturning_columns_listContext(n) - - case *parser.Delete_stmtContext: - return c.convertDelete_stmtContext(n) - - case *parser.Update_stmtContext: - return c.convertUpdate_stmtContext(n) - - case *parser.Alter_table_stmtContext: - return c.convertAlter_table_stmtContext(n) - - case *parser.Do_stmtContext: - return c.convertDo_stmtContext(n) - - case *parser.Drop_table_stmtContext: - return c.convertDrop_table_stmtContext(n) - - case *parser.Commit_stmtContext: - return c.convertCommit_stmtContext(n) - - case *parser.Rollback_stmtContext: - return c.convertRollback_stmtContext(n) - - case *parser.Pragma_valueContext: - return c.convertPragma_valueContext(n) - - case *parser.Pragma_stmtContext: - return c.convertPragma_stmtContext(n) - - case *parser.Use_stmtContext: - return c.convertUse_stmtContext(n) - - case *parser.Cluster_exprContext: - return c.convertCluster_exprContext(n) - - case *parser.Create_user_stmtContext: - return c.convertCreate_user_stmtContext(n) - - case *parser.Role_nameContext: - return c.convertRole_nameContext(n) - - case *parser.User_optionContext: - return c.convertUser_optionContext(n) - - case *parser.Create_group_stmtContext: - return c.convertCreate_group_stmtContext(n) - - case *parser.Alter_user_stmtContext: - return c.convertAlter_user_stmtContext(n) - - case *parser.Alter_group_stmtContext: - return c.convertAlter_group_stmtContext(n) - - case *parser.Drop_role_stmtContext: - return c.convertDrop_role_stmtContext(n) - - case *parser.Grouping_elementContext: - return c.convertGrouping_elementContext(n) - - case *parser.Ordinary_grouping_setContext: - return c.convertOrdinary_grouping_setContext(n) - - case *parser.Rollup_listContext: - return c.convertRollup_listContext(n) - - case *parser.Cube_listContext: - return c.convertCube_listContext(n) - - case *parser.Grouping_sets_specificationContext: - return c.convertGrouping_sets_specificationContext(n) - - case *parser.Named_exprContext: - return c.convertNamed_exprContext(n) + expr, ok := n.Sql_stmt_core().Accept(c).(ast.Node) + if !ok { + return todo("VisitSql_stmt", n.Sql_stmt_core()) + } - case *parser.Call_actionContext: - return c.convertCall_actionContext(n) + if n.EXPLAIN() != nil { + options := &ast.List{Items: []ast.Node{}} - case *parser.Inline_actionContext: - return c.convertInline_actionContext(n) + if n.QUERY() != nil && n.PLAN() != nil { + queryPlan := "QUERY PLAN" + options.Items = append(options.Items, &ast.DefElem{ + Defname: &queryPlan, + Arg: &ast.TODO{}, + }) + } - default: - return todo("convert(case=default)", n) + return &ast.ExplainStmt{ + Query: expr, + Options: options, + } } + + return expr } diff --git a/internal/engine/ydb/parse.go b/internal/engine/ydb/parse.go index 1c263924a5..8fbdd81ebb 100755 --- a/internal/engine/ydb/parse.go +++ b/internal/engine/ydb/parse.go @@ -64,7 +64,10 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { loc := 0 for _, stmt := range stmtListCtx.AllSql_stmt() { converter := &cc{content: string(blob)} - out := converter.convert(stmt) + out, ok := stmt.Accept(converter).(ast.Node) + if !ok { + return nil, fmt.Errorf("expected ast.Node; got %T", out) + } if _, ok := out.(*ast.TODO); ok { loc = byteOffset(content, stmt.GetStop().GetStop() + 2) continue diff --git a/internal/engine/ydb/utils.go b/internal/engine/ydb/utils.go index f2023e8ba9..8f118df09b 100755 --- a/internal/engine/ydb/utils.go +++ b/internal/engine/ydb/utils.go @@ -156,7 +156,13 @@ func parseIntegerValue(text string) (int64, error) { } func (c *cc) extractRoleSpec(n parser.IRole_nameContext, roletype ast.RoleSpecType) (*ast.RoleSpec, bool, ast.Node) { - roleNode := c.convert(n) + if n == nil { + return nil, false, nil + } + roleNode, ok := n.Accept(c).(ast.Node) + if !ok { + return nil, false, nil + } roleSpec := &ast.RoleSpec{ Roletype: roletype, @@ -219,3 +225,73 @@ func emptySelectStmt() *ast.SelectStmt { LockingClause: &ast.List{}, } } + +func (c *cc) collectComparisonOps(n parser.IEq_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + for _, child := range n.GetChildren() { + if tn, ok := child.(antlr.TerminalNode); ok { + switch tn.GetText() { + case "<", "<=", ">", ">=": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) collectBitwiseOps(ctx parser.INeq_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + children := ctx.GetChildren() + for _, child := range children { + if tn, ok := child.(antlr.TerminalNode); ok { + txt := tn.GetText() + switch txt { + case "<<", ">>", "<<|", ">>|", "&", "|", "^": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) collectBitOps(ctx parser.IBit_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + children := ctx.GetChildren() + for _, child := range children { + if tn, ok := child.(antlr.TerminalNode); ok { + txt := tn.GetText() + switch txt { + case "+", "-": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) collectAddOps(ctx parser.IAdd_subexprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + for _, child := range ctx.GetChildren() { + if tn, ok := child.(antlr.TerminalNode); ok { + switch tn.GetText() { + case "*", "/", "%": + ops = append(ops, tn) + } + } + } + return ops +} + +func (c *cc) collectEqualityOps(ctx parser.ICond_exprContext) []antlr.TerminalNode { + var ops []antlr.TerminalNode + children := ctx.GetChildren() + for _, child := range children { + if tn, ok := child.(antlr.TerminalNode); ok { + switch tn.GetText() { + case "=", "==", "!=", "<>": + ops = append(ops, tn) + } + } + } + return ops +}