diff --git a/experimental/aitools/README.md b/experimental/aitools/README.md index f645e4de51..ec12ed10f7 100644 --- a/experimental/aitools/README.md +++ b/experimental/aitools/README.md @@ -10,6 +10,10 @@ Current commands: - `databricks experimental aitools tools query` - `databricks experimental aitools tools discover-schema` - `databricks experimental aitools tools get-default-warehouse` +- `databricks experimental aitools tools statement submit` +- `databricks experimental aitools tools statement get` +- `databricks experimental aitools tools statement status` +- `databricks experimental aitools tools statement cancel` Current behavior: @@ -29,6 +33,19 @@ Current behavior: "SELECT vendor_id, count(*) FROM samples.nyctaxi.trips GROUP BY 1" ``` +- `tools statement` is a low-level lifecycle for asynchronous statements. + `submit` returns a `statement_id` immediately, `get` polls until terminal + and emits rows, `status` peeks without blocking, and `cancel` requests + termination. Ctrl+C on `get` stops polling but does NOT cancel the + server-side statement; use `cancel` for that. + + ```bash + SID=$(databricks experimental aitools tools statement submit \ + --warehouse "SELECT pg_sleep(5)" | jq -r '.statement_id') + databricks experimental aitools tools statement status "$SID" + databricks experimental aitools tools statement get "$SID" + ``` + Removed behavior: - there is no MCP server under `experimental aitools` diff --git a/experimental/aitools/cmd/statement.go b/experimental/aitools/cmd/statement.go new file mode 100644 index 0000000000..e1c48a7ddb --- /dev/null +++ b/experimental/aitools/cmd/statement.go @@ -0,0 +1,77 @@ +package aitools + +import ( + "encoding/json" + "fmt" + "io" + + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +// statementInfo is the JSON shape emitted by every `tools statement` +// subcommand. Fields are populated as the subcommand has them. omitempty keeps +// the output tight: `submit` doesn't emit columns/rows, `cancel` doesn't emit a +// warehouse_id, etc. +type statementInfo struct { + StatementID string `json:"statement_id"` + State sql.StatementState `json:"state,omitempty"` + WarehouseID string `json:"warehouse_id,omitempty"` + Columns []string `json:"columns,omitempty"` + Rows [][]string `json:"rows,omitempty"` + Error *batchResultError `json:"error,omitempty"` +} + +func renderStatementInfo(w io.Writer, info statementInfo) error { + data, err := json.MarshalIndent(info, "", " ") + if err != nil { + return fmt.Errorf("marshal statement info: %w", err) + } + fmt.Fprintf(w, "%s\n", data) + return nil +} + +// statementErrorFromStatus builds a batchResultError for any terminal non-success +// state (FAILED, CANCELED, CLOSED), populating it from the server's ServiceError +// when available and synthesizing a message when it isn't. Returns nil for +// SUCCEEDED, non-terminal states, and nil status. The synthesized fallback +// matters because the Statements API can hand back a non-success terminal state +// with `Error == nil`, and skill consumers should be able to branch on +// `error == null` alone instead of inspecting `state`. +func statementErrorFromStatus(status *sql.StatementStatus) *batchResultError { + if status == nil || !isTerminalState(status) || status.State == sql.StatementStateSucceeded { + return nil + } + out := &batchResultError{} + if status.Error != nil { + out.Message = status.Error.Message + out.ErrorCode = string(status.Error.ErrorCode) + } else { + out.Message = fmt.Sprintf("statement reached terminal state %s", status.State) + } + return out +} + +func newStatementCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "statement", + Short: "Manage SQL statement lifecycle (submit, get, status, cancel)", + Long: `Low-level command tree for asynchronous SQL execution. + +Use 'submit' to fire a statement and get its statement_id back, then +'get' to block on results, 'status' to peek without blocking, and +'cancel' to terminate. For "I want results now," use 'tools query' +instead. + +All subcommands emit a JSON object with the statement_id and state. +'get' adds columns and rows on success; any subcommand may emit an +error object when the server reports a non-success terminal state.`, + } + + cmd.AddCommand(newStatementSubmitCmd()) + cmd.AddCommand(newStatementGetCmd()) + cmd.AddCommand(newStatementStatusCmd()) + cmd.AddCommand(newStatementCancelCmd()) + + return cmd +} diff --git a/experimental/aitools/cmd/statement_cancel.go b/experimental/aitools/cmd/statement_cancel.go new file mode 100644 index 0000000000..1774b7abe6 --- /dev/null +++ b/experimental/aitools/cmd/statement_cancel.go @@ -0,0 +1,53 @@ +package aitools + +import ( + "context" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementCancelCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "cancel STATEMENT_ID", + Short: "Request cancellation of a running statement", + Long: `Send a cancellation request for the given statement_id. The Statements +API returns no body on cancel; this command optimistically reports +state=CANCELED on success. Use 'statement status' afterwards to confirm +the server-side state if you need certainty.`, + Example: ` databricks experimental aitools tools statement cancel 01ef...`, + Args: cobra.ExactArgs(1), + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + statementID := args[0] + + info, err := cancelStatementExecution(ctx, w.StatementExecution, statementID) + if err != nil { + return err + } + return renderStatementInfo(cmd.OutOrStdout(), info) + }, + } + + return cmd +} + +// cancelStatementExecution issues CancelExecution and reports state=CANCELED on success. +// CancelExecution returns no body; the actual server-side state is verified +// asynchronously. Use 'statement status' to confirm if certainty is required. +func cancelStatementExecution(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) { + if err := api.CancelExecution(ctx, sql.CancelExecutionRequest{ + StatementId: statementID, + }); err != nil { + return statementInfo{}, fmt.Errorf("cancel statement: %w", err) + } + return statementInfo{ + StatementID: statementID, + State: sql.StatementStateCanceled, + }, nil +} diff --git a/experimental/aitools/cmd/statement_get.go b/experimental/aitools/cmd/statement_get.go new file mode 100644 index 0000000000..617b5c274d --- /dev/null +++ b/experimental/aitools/cmd/statement_get.go @@ -0,0 +1,96 @@ +package aitools + +import ( + "context" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementGetCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "get STATEMENT_ID", + Short: "Block until a previously submitted statement is terminal and emit its result", + Long: `Poll a statement_id until it reaches a terminal state, then emit +columns and rows on success or an error object on failure. + +Ctrl+C stops polling but does NOT cancel the server-side statement. +Use 'statement cancel ' to terminate explicitly. (This differs from +'tools query', which cancels server-side on Ctrl+C because the user +invoked the synchronous path.)`, + Example: ` databricks experimental aitools tools statement get 01ef...`, + Args: cobra.ExactArgs(1), + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + statementID := args[0] + + info, err := getStatementResult(ctx, w.StatementExecution, statementID) + if err != nil { + return err + } + + if err := renderStatementInfo(cmd.OutOrStdout(), info); err != nil { + return err + } + + // Non-zero exit when the statement reached a non-success terminal + // state OR a chunk-fetch failure prevented assembling the rows. + // In both cases the failure detail is already in the JSON output. + if info.State != sql.StatementStateSucceeded || info.Error != nil { + return root.ErrAlreadyPrinted + } + return nil + }, + } + + return cmd +} + +// getStatementResult polls a statement until terminal, then assembles a +// statementInfo with rows on success or an error object on failure. +// +// Context cancellation propagates from pollStatement WITHOUT cancelling the +// server-side statement (intentional: 'get' is a poll-only operation; use +// 'cancel' to terminate explicitly). +func getStatementResult(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) { + // Fetch the current state first so pollStatement can short-circuit if + // the statement is already terminal. + resp, err := api.GetStatementByStatementId(ctx, statementID) + if err != nil { + return statementInfo{}, fmt.Errorf("get statement: %w", err) + } + + pollResp, err := pollStatement(ctx, api, resp) + if err != nil { + return statementInfo{}, err + } + + info := statementInfo{StatementID: pollResp.StatementId} + if pollResp.Status != nil { + info.State = pollResp.Status.State + } + info.Error = statementErrorFromStatus(pollResp.Status) + + if info.State == sql.StatementStateSucceeded { + info.Columns = extractColumns(pollResp.Manifest) + rows, err := fetchAllRows(ctx, api, pollResp) + if err != nil { + // The query succeeded server-side but a later chunk fetch failed + // (network blip, throttling, transient 5xx). Surface this as a + // structured error on the same statementInfo so the caller still + // gets a parseable JSON response with the statement_id; RunE then + // signals exit-non-zero based on info.Error. + info.Error = &batchResultError{ + Message: fmt.Sprintf("fetch result rows: %v", err), + } + return info, nil + } + info.Rows = rows + } + return info, nil +} diff --git a/experimental/aitools/cmd/statement_status.go b/experimental/aitools/cmd/statement_status.go new file mode 100644 index 0000000000..9981f49aa6 --- /dev/null +++ b/experimental/aitools/cmd/statement_status.go @@ -0,0 +1,52 @@ +package aitools + +import ( + "context" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status STATEMENT_ID", + Short: "Return the current state of a statement without polling", + Long: `Single GET against the Statements API. Use this to peek at progress +without blocking. For a blocking poll-until-terminal call, use +'statement get'.`, + Example: ` databricks experimental aitools tools statement status 01ef...`, + Args: cobra.ExactArgs(1), + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + statementID := args[0] + + info, err := getStatementStatus(ctx, w.StatementExecution, statementID) + if err != nil { + return err + } + return renderStatementInfo(cmd.OutOrStdout(), info) + }, + } + + return cmd +} + +// getStatementStatus performs a single GET against the Statements API, no polling. +func getStatementStatus(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) { + resp, err := api.GetStatementByStatementId(ctx, statementID) + if err != nil { + return statementInfo{}, fmt.Errorf("get statement: %w", err) + } + + info := statementInfo{StatementID: resp.StatementId} + if resp.Status != nil { + info.State = resp.Status.State + } + info.Error = statementErrorFromStatus(resp.Status) + return info, nil +} diff --git a/experimental/aitools/cmd/statement_submit.go b/experimental/aitools/cmd/statement_submit.go new file mode 100644 index 0000000000..ac8bf424e5 --- /dev/null +++ b/experimental/aitools/cmd/statement_submit.go @@ -0,0 +1,94 @@ +package aitools + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementSubmitCmd() *cobra.Command { + var warehouseID string + var filePath string + // resolved by PreRunE so input validation runs before any auth/profile + // work and the documented "validates input before WorkspaceClient" claim + // in the PR description is actually true. + var sqlStatement string + + cmd := &cobra.Command{ + Use: "submit [SQL | file.sql]", + Short: "Submit a SQL statement asynchronously and return its statement_id", + Long: `Submit a SQL statement to a Databricks SQL warehouse and return its +statement_id immediately, without waiting for results. + +The statement keeps running server-side. Harvest results with +'statement get ', inspect with 'statement status ', or stop +with 'statement cancel '.`, + Example: ` databricks experimental aitools tools statement submit "SELECT pg_sleep(60)" --warehouse + databricks experimental aitools tools statement submit --file query.sql`, + Args: cobra.MaximumNArgs(1), + PreRunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + var fps []string + if filePath != "" { + fps = []string{filePath} + } + sqls, err := resolveSQLs(ctx, cmd, args, fps) + if err != nil { + return err + } + if len(sqls) != 1 { + return errors.New("submit accepts exactly one SQL statement; pass multiple to 'query' for batch") + } + sqlStatement = sqls[0] + + return root.MustWorkspaceClient(cmd, args) + }, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + wID, err := resolveWarehouseID(ctx, w, warehouseID) + if err != nil { + return err + } + + info, err := submitStatement(ctx, w.StatementExecution, sqlStatement, wID) + if err != nil { + return err + } + return renderStatementInfo(cmd.OutOrStdout(), info) + }, + } + + cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution") + cmd.Flags().StringVarP(&filePath, "file", "f", "", "Path to a SQL file to execute") + + return cmd +} + +// submitStatement issues an asynchronous ExecuteStatement and returns the handle. +func submitStatement(ctx context.Context, api sql.StatementExecutionInterface, statement, warehouseID string) (statementInfo, error) { + resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, + }) + if err != nil { + return statementInfo{}, fmt.Errorf("execute statement: %w", err) + } + + info := statementInfo{ + StatementID: resp.StatementId, + WarehouseID: warehouseID, + } + if resp.Status != nil { + info.State = resp.Status.State + } + return info, nil +} diff --git a/experimental/aitools/cmd/statement_test.go b/experimental/aitools/cmd/statement_test.go new file mode 100644 index 0000000000..9c2264daf2 --- /dev/null +++ b/experimental/aitools/cmd/statement_test.go @@ -0,0 +1,352 @@ +package aitools + +import ( + "context" + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/databricks/cli/libs/cmdio" + mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestSubmitStatementReturnsHandle(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.WarehouseId == "wh-1" && req.Statement == "SELECT 1" && + req.WaitTimeout == "0s" && + req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue + })).Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + info, err := submitStatement(ctx, mockAPI, "SELECT 1", "wh-1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) + assert.Equal(t, sql.StatementStatePending, info.State) + assert.Equal(t, "wh-1", info.WarehouseID) +} + +func TestSubmitStatementWrapsTransportError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything). + Return(nil, errors.New("network unreachable")).Once() + + _, err := submitStatement(ctx, mockAPI, "SELECT 1", "wh-1") + require.Error(t, err) + assert.Contains(t, err.Error(), "execute statement") + assert.Contains(t, err.Error(), "network unreachable") +} + +func TestGetStatementResultPolls(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}, TotalChunkCount: 1}, + Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, + }, nil).Once() + + info, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) + assert.Equal(t, sql.StatementStateSucceeded, info.State) + assert.Equal(t, []string{"n"}, info.Columns) + assert.Equal(t, [][]string{{"42"}}, info.Rows) + assert.Nil(t, info.Error) +} + +func TestGetStatementResultFailedStateReportsError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ + ErrorCode: "SYNTAX_ERROR", + Message: "near 'bad': syntax error", + }, + }, + }, nil).Once() + + info, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, info.State) + assert.Nil(t, info.Rows) + require.NotNil(t, info.Error) + assert.Equal(t, "SYNTAX_ERROR", info.Error.ErrorCode) + assert.Contains(t, info.Error.Message, "syntax error") +} + +func TestGetStatementResultDoesNotCancelServerSideOnContextCancel(t *testing.T) { + // 'statement get' is a poll-only operation: ctx cancellation must NOT + // trigger CancelExecution. The mock asserts (via t.Cleanup) that no + // unexpected calls happen. + ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + cancel() + + _, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.ErrorIs(t, err, context.Canceled) +} + +func TestGetStatementResultChunkFetchFailureRendersPartialInfo(t *testing.T) { + // SUCCEEDED state but a later chunk fetch fails (network blip, throttle, + // 5xx). getStatementResult should surface this as a structured error on + // the same statementInfo so the caller still gets parseable JSON with the + // statement_id, instead of returning a raw Go error that RunE would + // discard along with the populated info. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{ + Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}, + TotalChunkCount: 2, + }, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementResultChunkNByStatementIdAndChunkIndex(mock.Anything, "stmt-1", 1). + Return(nil, errors.New("network blip")).Once() + + info, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, info.State) + assert.Equal(t, []string{"n"}, info.Columns, "columns from the initial response are still surfaced") + require.NotNil(t, info.Error) + assert.Contains(t, info.Error.Message, "fetch result rows") + assert.Contains(t, info.Error.Message, "network blip") +} + +func TestGetStatementStatusSinglePoll(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + info, err := getStatementStatus(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) + assert.Equal(t, sql.StatementStateRunning, info.State) + assert.Nil(t, info.Error) +} + +func TestGetStatementStatusReportsError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ + ErrorCode: "TIMEOUT", + Message: "warehouse timed out", + }, + }, + }, nil).Once() + + info, err := getStatementStatus(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, info.State) + require.NotNil(t, info.Error) + assert.Equal(t, "TIMEOUT", info.Error.ErrorCode) +} + +func TestCancelStatementExecutionCallsAPI(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{ + StatementId: "stmt-1", + }).Return(nil).Once() + + info, err := cancelStatementExecution(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) + assert.Equal(t, sql.StatementStateCanceled, info.State) +} + +func TestCancelStatementExecutionWrapsAPIError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().CancelExecution(mock.Anything, mock.Anything). + Return(errors.New("not found")).Once() + + _, err := cancelStatementExecution(ctx, mockAPI, "stmt-1") + require.Error(t, err) + assert.Contains(t, err.Error(), "cancel statement") + assert.Contains(t, err.Error(), "not found") +} + +func TestRenderStatementInfo(t *testing.T) { + tests := []struct { + name string + info statementInfo + mustHave []string + mustNotHave []string + }{ + { + name: "full payload renders every populated field", + info: statementInfo{ + StatementID: "stmt-1", + State: sql.StatementStateSucceeded, + WarehouseID: "wh-1", + Columns: []string{"n"}, + Rows: [][]string{{"42"}}, + }, + mustHave: []string{ + `"statement_id": "stmt-1"`, + `"state": "SUCCEEDED"`, + `"warehouse_id": "wh-1"`, + `"columns": [`, + `"rows": [`, + }, + }, + { + name: "cancel-style payload omits unset fields", + info: statementInfo{ + StatementID: "stmt-1", + State: sql.StatementStateCanceled, + }, + mustHave: []string{ + `"statement_id": "stmt-1"`, + `"state": "CANCELED"`, + }, + mustNotHave: []string{`"warehouse_id"`, `"columns"`, `"rows"`, `"error"`}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var buf strings.Builder + require.NoError(t, renderStatementInfo(&buf, tc.info)) + out := buf.String() + for _, want := range tc.mustHave { + assert.Contains(t, out, want) + } + for _, missing := range tc.mustNotHave { + assert.NotContains(t, out, missing) + } + assert.True(t, strings.HasSuffix(out, "\n")) + }) + } +} + +func TestStatementSubmitRejectsMultipleSQLsBeforeWorkspaceClient(t *testing.T) { + // The "exactly one SQL" check runs in PreRunE BEFORE MustWorkspaceClient, + // so a malformed invocation is rejected without any auth/profile work. + // The test relies on this ordering: it does not stub out PreRunE, so if + // validation moved back after MustWorkspaceClient the test would panic + // on a missing workspace client instead of returning the validation error. + dir := t.TempDir() + path := filepath.Join(dir, "test.sql") + require.NoError(t, os.WriteFile(path, []byte("SELECT 1"), 0o644)) + + cmd := newStatementSubmitCmd() + cmd.SetArgs([]string{"--file", path, "SELECT 2"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "exactly one") +} + +func TestStatementErrorFromStatus(t *testing.T) { + tests := []struct { + name string + status *sql.StatementStatus + wantNil bool + wantMsg string + wantCode string + }{ + { + name: "nil status", + status: nil, + wantNil: true, + }, + { + name: "succeeded never produces an error", + status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + wantNil: true, + }, + { + name: "running is not terminal", + status: &sql.StatementStatus{State: sql.StatementStateRunning}, + wantNil: true, + }, + { + name: "pending is not terminal", + status: &sql.StatementStatus{State: sql.StatementStatePending}, + wantNil: true, + }, + { + name: "failed with backend error preserves both fields", + status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ErrorCode: "SYNTAX_ERROR", Message: "near 'bad'"}, + }, + wantMsg: "near 'bad'", + wantCode: "SYNTAX_ERROR", + }, + { + name: "failed without backend error synthesizes message", + status: &sql.StatementStatus{State: sql.StatementStateFailed}, + wantMsg: "statement reached terminal state FAILED", + }, + { + name: "canceled without backend error synthesizes message", + status: &sql.StatementStatus{State: sql.StatementStateCanceled}, + wantMsg: "statement reached terminal state CANCELED", + }, + { + name: "closed without backend error synthesizes message", + status: &sql.StatementStatus{State: sql.StatementStateClosed}, + wantMsg: "statement reached terminal state CLOSED", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := statementErrorFromStatus(tc.status) + if tc.wantNil { + assert.Nil(t, got) + return + } + require.NotNil(t, got) + assert.Equal(t, tc.wantMsg, got.Message) + assert.Equal(t, tc.wantCode, got.ErrorCode) + }) + } +} diff --git a/experimental/aitools/cmd/tools.go b/experimental/aitools/cmd/tools.go index b5dd306d21..22781f987f 100644 --- a/experimental/aitools/cmd/tools.go +++ b/experimental/aitools/cmd/tools.go @@ -15,6 +15,7 @@ func newToolsCmd() *cobra.Command { cmd.AddCommand(newQueryCmd()) cmd.AddCommand(newDiscoverSchemaCmd()) cmd.AddCommand(newGetDefaultWarehouseCmd()) + cmd.AddCommand(newStatementCmd()) return cmd }