From 79fc08074e05ebeb45212ffc8ad1a1332a87b35e Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 15:04:52 +0200 Subject: [PATCH 1/5] aitools: extract pollStatement helper and pin OnWaitTimeout Refactor `executeAndPoll` in `experimental/aitools/cmd/query.go` to extract a pure `pollStatement(ctx, api, resp)` helper. The helper polls until the statement reaches a terminal state and returns the response without any signal handling, spinner, or server-side cancellation; those concerns stay in `executeAndPoll` where they belong. Also pin `OnWaitTimeout: CONTINUE` explicitly on the `ExecuteStatement` call. The SDK default happens to be CONTINUE today, but relying on it is a hidden coupling: a server-side default flip would silently break the poll loop by killing the statement before our first GET. Behavior is unchanged for the existing `query` command. Follow-up PRs (parallel batch queries, statement lifecycle command tree) will reuse the helper. Co-authored-by: Isaac --- experimental/aitools/cmd/query.go | 63 ++++++++++----- experimental/aitools/cmd/query_test.go | 105 ++++++++++++++++++++++++- 2 files changed, 146 insertions(+), 22 deletions(-) diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 7b95fdd4e23..6c125bbcd6b 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -262,9 +262,10 @@ func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, e func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) { // Submit asynchronously to get the statement ID immediately for cancellation. resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: statement, - WaitTimeout: "0s", + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, }) if err != nil { return nil, fmt.Errorf("execute statement: %w", err) @@ -272,11 +273,6 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa statementID := resp.StatementId - // Check if it completed immediately. - if isTerminalState(resp.Status) { - return resp, checkFailedState(resp.Status) - } - // Set up Ctrl+C: signal cancels the poll context, cleanup is unified below. pollCtx, pollCancel := context.WithCancel(ctx) defer pollCancel() @@ -327,34 +323,59 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa } }() + pollResp, err := pollStatement(pollCtx, api, resp) + if err != nil { + if pollCtx.Err() != nil { + cancelStatement() + cmdio.LogString(ctx, "Query cancelled.") + return nil, root.ErrAlreadyPrinted + } + return nil, err + } + + sp.Close() + if err := checkFailedState(pollResp.Status); err != nil { + return nil, err + } + return pollResp, nil +} + +// pollStatement polls until the statement reaches a terminal state. +// +// On context cancellation it returns the context error WITHOUT cancelling the +// server-side statement. Callers that want server-side cancellation should +// invoke CancelExecution explicitly. +// +// If the input response is already in a terminal state, it is returned without +// further polling. +func pollStatement(ctx context.Context, api sql.StatementExecutionInterface, resp *sql.StatementResponse) (*sql.StatementResponse, error) { + if isTerminalState(resp.Status) { + return resp, nil + } + + statementID := resp.StatementId + start := time.Now() + // Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped). interval := pollIntervalInitial for { select { - case <-pollCtx.Done(): - cancelStatement() - cmdio.LogString(ctx, "Query cancelled.") - return nil, root.ErrAlreadyPrinted + case <-ctx.Done(): + return nil, ctx.Err() case <-time.After(interval): } log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, time.Since(start).Truncate(time.Second)) - pollResp, err := api.GetStatementByStatementId(pollCtx, statementID) + pollResp, err := api.GetStatementByStatementId(ctx, statementID) if err != nil { - if pollCtx.Err() != nil { - cancelStatement() - cmdio.LogString(ctx, "Query cancelled.") - return nil, root.ErrAlreadyPrinted + if ctx.Err() != nil { + return nil, ctx.Err() } return nil, fmt.Errorf("poll statement status: %w", err) } if isTerminalState(pollResp.Status) { - sp.Close() - if err := checkFailedState(pollResp.Status); err != nil { - return nil, err - } return &sql.StatementResponse{ StatementId: pollResp.StatementId, Status: pollResp.Status, diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index aa33921c83b..4bc06c1d63b 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -2,6 +2,7 @@ package aitools import ( "context" + "errors" "os" "path/filepath" "strings" @@ -48,7 +49,9 @@ func TestExecuteAndPollImmediateSuccess(t *testing.T) { mockAPI := mocksql.NewMockStatementExecutionInterface(t) mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { - return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && req.WaitTimeout == "0s" + return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && + req.WaitTimeout == "0s" && + req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue })).Return(&sql.StatementResponse{ StatementId: "stmt-1", Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, @@ -154,6 +157,106 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) { require.ErrorIs(t, err, root.ErrAlreadyPrinted) } +func TestPollStatementImmediateTerminal(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + resp := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "1"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + } + + pollResp, err := pollStatement(ctx, mockAPI, resp) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) + assert.Equal(t, "stmt-1", pollResp.StatementId) +} + +func TestPollStatementTerminalFailureNotErrored(t *testing.T) { + // pollStatement returns the response without erroring on failed terminal + // states; callers (e.g. executeAndPoll) decide what to do via checkFailedState. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + resp := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ErrorCode: "ERR", Message: "boom"}, + }, + } + + pollResp, err := pollStatement(ctx, mockAPI, resp) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, pollResp.Status.State) +} + +func TestPollStatementEventualSuccess(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, + }, nil).Once() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) + assert.Equal(t, [][]string{{"42"}}, pollResp.Result.DataArray) +} + +func TestPollStatementContextCancellationDoesNotCancelServerSide(t *testing.T) { + // The mock asserts (via t.Cleanup) that no unexpected calls are made. + // Specifically, pollStatement must NOT call CancelExecution on context + // cancellation; that is the caller's responsibility. + ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + cancel() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.ErrorIs(t, err, context.Canceled) + assert.Nil(t, pollResp) +} + +func TestPollStatementGetErrorPropagated(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1"). + Return(nil, errors.New("network unreachable")).Once() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.Error(t, err) + assert.Contains(t, err.Error(), "poll statement status") + assert.Contains(t, err.Error(), "network unreachable") + assert.Nil(t, pollResp) +} + func TestResolveWarehouseIDWithFlag(t *testing.T) { ctx := t.Context() id, err := resolveWarehouseID(ctx, nil, "explicit-id") From 6b6128a7c10f802e5c67838ad1215ef6ec8fb985 Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 15:37:40 +0200 Subject: [PATCH 2/5] aitools: run multiple SQL queries in parallel from one query invocation Allow `databricks experimental aitools tools query` to accept several SQLs in a single invocation and run them in parallel against the warehouse. Pass multiple positional arguments and/or repeat `--file` to fan out: databricks experimental aitools tools query \ --warehouse --output json \ "SELECT count(*) FROM t" \ "SELECT min(ts), max(ts) FROM t" \ "SELECT col, count(*) FROM t GROUP BY 1" Multi-query output is always a JSON array of one object per input, preserving input order. The shape is `{sql, statement_id, state, elapsed_ms, columns, rows, error}`. Individual statement failures don't abort siblings; each is encoded in the per-result `error` field, and the exit code is non-zero when any statement failed. A new `--concurrency` flag (default 8) caps in-flight statements. On Ctrl+C the still-running statements are cancelled server-side via CancelExecution before exit. Single-query behavior is unchanged. The previous restriction that forbade mixing `--file` and a positional SQL is lifted, since both now contribute to the batch. Co-authored-by: Isaac --- experimental/aitools/README.md | 12 ++ experimental/aitools/cmd/batch.go | 206 +++++++++++++++++++++ experimental/aitools/cmd/batch_test.go | 237 +++++++++++++++++++++++++ experimental/aitools/cmd/query.go | 132 +++++++++----- experimental/aitools/cmd/query_test.go | 113 ++++++++---- experimental/aitools/cmd/render.go | 11 ++ 6 files changed, 637 insertions(+), 74 deletions(-) create mode 100644 experimental/aitools/cmd/batch.go create mode 100644 experimental/aitools/cmd/batch_test.go diff --git a/experimental/aitools/README.md b/experimental/aitools/README.md index 571136538c9..f645e4de51d 100644 --- a/experimental/aitools/README.md +++ b/experimental/aitools/README.md @@ -16,6 +16,18 @@ Current behavior: - `skills install` installs Databricks skills for detected coding agents. - `install` is a compatibility alias for `skills install`. - `tools` exposes a small set of AI-oriented workspace helpers. +- `tools query` accepts a single SQL or multiple SQLs in one invocation. Pass + several positional arguments and/or repeat `--file` to run them in parallel + against the warehouse. Multi-query output is always JSON; control parallelism + with `--concurrency` (default 8). + + ```bash + databricks experimental aitools tools query \ + --warehouse --output json \ + "SELECT count(*) FROM samples.nyctaxi.trips" \ + "SELECT min(tpep_pickup_datetime), max(tpep_pickup_datetime) FROM samples.nyctaxi.trips" \ + "SELECT vendor_id, count(*) FROM samples.nyctaxi.trips GROUP BY 1" + ``` Removed behavior: diff --git a/experimental/aitools/cmd/batch.go b/experimental/aitools/cmd/batch.go new file mode 100644 index 00000000000..8965923c17c --- /dev/null +++ b/experimental/aitools/cmd/batch.go @@ -0,0 +1,206 @@ +package aitools + +import ( + "context" + "fmt" + "os" + "os/signal" + "sync/atomic" + "syscall" + "time" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go/service/sql" + "golang.org/x/sync/errgroup" +) + +// defaultBatchConcurrency caps in-flight statements when --concurrency is unset. +// Matches the default used by cmd/fs/cp.go for similar fan-out work. +const defaultBatchConcurrency = 8 + +// batchResult is the per-statement payload emitted in batch mode JSON output. +// State is the server-reported terminal state. Error is set whenever the +// statement did not produce usable rows, regardless of state, so consumers +// can branch on `error == null` alone. +type batchResult struct { + SQL string `json:"sql"` + StatementID string `json:"statement_id,omitempty"` + State sql.StatementState `json:"state,omitempty"` + ElapsedMs int64 `json:"elapsed_ms"` + Columns []string `json:"columns,omitempty"` + Rows [][]string `json:"rows,omitempty"` + Error *batchResultError `json:"error,omitempty"` +} + +// batchResultError captures user-visible error info for a failed statement. +type batchResultError struct { + Message string `json:"message"` + ErrorCode string `json:"error_code,omitempty"` +} + +// executeBatch submits sqls against the warehouse in parallel, polls each to +// completion, and returns one batchResult per input in input order. +// +// Individual statement failures do not abort siblings; failures are encoded in +// the per-result Error field so callers can render partial results. +// +// On context cancellation (Ctrl+C or parent context), still-running statements +// are cancelled server-side via CancelExecution. Statements that finished +// before cancellation are left as-is. +func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) []batchResult { + pollCtx, pollCancel := context.WithCancel(ctx) + defer pollCancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigCh) + + go func() { + select { + case <-sigCh: + log.Infof(ctx, "Received interrupt, cancelling %d in-flight queries", len(sqls)) + pollCancel() + case <-pollCtx.Done(): + } + }() + + sp := cmdio.NewSpinner(pollCtx) + defer sp.Close() + sp.Update(fmt.Sprintf("Executing %d queries...", len(sqls))) + + var completed atomic.Int64 + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + go func() { + for { + select { + case <-pollCtx.Done(): + return + case <-ticker.C: + sp.Update(fmt.Sprintf("Executing %d queries... (%d/%d done)", len(sqls), completed.Load(), len(sqls))) + } + } + }() + + results := make([]batchResult, len(sqls)) + // Each goroutine writes to a distinct slot, safe without a mutex. + // We read after g.Wait(), establishing happens-before for all writes. + statementIDs := make([]string, len(sqls)) + + g := new(errgroup.Group) + g.SetLimit(concurrency) + for i, sqlStr := range sqls { + g.Go(func() error { + results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, statementIDs, i) + completed.Add(1) + return nil + }) + } + _ = g.Wait() + + // pollStatement is a pure helper that returns ctx.Err() on cancellation + // without touching the server. Sweep any not-yet-terminal statements here. + if pollCtx.Err() != nil { + cancelInFlight(ctx, api, statementIDs, results) + } + + return results +} + +// runOneBatchQuery submits one SQL, polls to completion, and returns its +// batchResult. All errors are encoded into the result; never returns an error. +func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, statementIDs []string, idx int) batchResult { + start := time.Now() + result := batchResult{SQL: sqlStr} + + resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ + WarehouseId: warehouseID, + Statement: sqlStr, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, + }) + if err != nil { + if ctx.Err() != nil { + result.State = sql.StatementStateCanceled + result.Error = &batchResultError{Message: "submission cancelled"} + } else { + result.State = sql.StatementStateFailed + result.Error = &batchResultError{Message: fmt.Sprintf("execute statement: %v", err)} + } + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + + statementIDs[idx] = resp.StatementId + result.StatementID = resp.StatementId + + pollResp, err := pollStatement(ctx, api, resp) + if err != nil { + if ctx.Err() != nil { + result.State = sql.StatementStateCanceled + result.Error = &batchResultError{Message: "cancelled"} + } else { + result.State = sql.StatementStateFailed + result.Error = &batchResultError{Message: err.Error()} + } + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + + if pollResp.Status != nil { + result.State = pollResp.Status.State + } + + if result.State != sql.StatementStateSucceeded { + result.Error = &batchResultError{} + if pollResp.Status != nil && pollResp.Status.Error != nil { + result.Error.Message = pollResp.Status.Error.Message + result.Error.ErrorCode = string(pollResp.Status.Error.ErrorCode) + } else { + result.Error.Message = fmt.Sprintf("query reached terminal state %s", result.State) + } + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + + result.Columns = extractColumns(pollResp.Manifest) + rows, err := fetchAllRows(ctx, api, pollResp) + if err != nil { + result.Error = &batchResultError{Message: fmt.Sprintf("fetch rows: %v", err)} + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + result.Rows = rows + result.ElapsedMs = time.Since(start).Milliseconds() + return result +} + +// cancelInFlight sends CancelExecution for every statement that didn't reach +// a terminal state server-side before context cancellation. Best effort: errors +// are logged at warn but don't fail the batch. +func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, statementIDs []string, results []batchResult) { + var cancelled int + for i, sid := range statementIDs { + if sid == "" { + continue + } + switch results[i].State { + case sql.StatementStateSucceeded, sql.StatementStateFailed, sql.StatementStateClosed: + continue + case sql.StatementStateCanceled, sql.StatementStatePending, sql.StatementStateRunning: + // Either still running server-side, or our internal "canceled" + // marker meaning the goroutine bailed without telling the server. + // Either way, send CancelExecution. + } + cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout) + if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{StatementId: sid}); err != nil { + log.Warnf(ctx, "Failed to cancel statement %s: %v", sid, err) + } + cancel() + cancelled++ + } + if cancelled > 0 { + cmdio.LogString(ctx, fmt.Sprintf("Cancelled %d in-flight queries.", cancelled)) + } +} diff --git a/experimental/aitools/cmd/batch_test.go b/experimental/aitools/cmd/batch_test.go new file mode 100644 index 00000000000..96235530f4d --- /dev/null +++ b/experimental/aitools/cmd/batch_test.go @@ -0,0 +1,237 @@ +package aitools + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "testing" + + "github.com/databricks/cli/libs/cmdio" + mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestRenderBatchJSON(t *testing.T) { + results := []batchResult{ + { + SQL: "SELECT 1", + StatementID: "stmt-1", + State: sql.StatementStateSucceeded, + ElapsedMs: 42, + Columns: []string{"n"}, + Rows: [][]string{{"1"}}, + }, + { + SQL: "SELECT bad_syntax", + StatementID: "stmt-2", + State: sql.StatementStateFailed, + ElapsedMs: 12, + Error: &batchResultError{ + Message: "near 'bad_syntax': syntax error", + ErrorCode: "SYNTAX_ERROR", + }, + }, + } + + var buf strings.Builder + err := renderBatchJSON(&buf, results) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, `"sql": "SELECT 1"`) + assert.Contains(t, output, `"statement_id": "stmt-1"`) + assert.Contains(t, output, `"state": "SUCCEEDED"`) + assert.Contains(t, output, `"elapsed_ms": 42`) + assert.Contains(t, output, `"columns": [`) + assert.Contains(t, output, `"rows": [`) + assert.Contains(t, output, `"sql": "SELECT bad_syntax"`) + assert.Contains(t, output, `"error": {`) + assert.Contains(t, output, `"error_code": "SYNTAX_ERROR"`) + // Trailing newline. + assert.True(t, strings.HasSuffix(output, "\n")) +} + +func TestExecuteBatchAllSucceed(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + sqls := []string{"SELECT 1", "SELECT 2", "SELECT 3"} + for i, sqlStr := range sqls { + sid := fmt.Sprintf("stmt-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{strconv.Itoa(i + 1)}}}, + }, nil).Once() + } + + results := executeBatch(ctx, mockAPI, "wh-123", sqls, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sqls[i], r.SQL, "result %d sql", i) + assert.Equal(t, sql.StatementStateSucceeded, r.State, "result %d state", i) + assert.Nil(t, r.Error, "result %d error", i) + assert.Equal(t, []string{"n"}, r.Columns, "result %d columns", i) + assert.Equal(t, [][]string{{strconv.Itoa(i + 1)}}, r.Rows, "result %d rows", i) + assert.NotEmpty(t, r.StatementID, "result %d statement_id", i) + } +} + +func TestExecuteBatchPartialFailure(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT 1" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-good", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + }, nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT bad" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-bad", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ + ErrorCode: "SYNTAX_ERROR", + Message: "near 'bad': syntax error", + }, + }, + }, nil).Once() + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT 1", "SELECT bad"}, 8) + + require.Len(t, results, 2) + assert.Nil(t, results[0].Error) + assert.Equal(t, sql.StatementStateSucceeded, results[0].State) + + require.NotNil(t, results[1].Error) + assert.Equal(t, sql.StatementStateFailed, results[1].State) + assert.Equal(t, "SYNTAX_ERROR", results[1].Error.ErrorCode) + assert.Contains(t, results[1].Error.Message, "syntax error") +} + +func TestExecuteBatchSubmissionFailure(t *testing.T) { + // ExecuteStatement transport error is encoded into the per-result error, + // not propagated up to abort siblings. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT good" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-good", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT broken" + })).Return(nil, errors.New("network unreachable")).Once() + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT good", "SELECT broken"}, 8) + + require.Len(t, results, 2) + assert.Nil(t, results[0].Error) + require.NotNil(t, results[1].Error) + assert.Contains(t, results[1].Error.Message, "execute statement") + assert.Contains(t, results[1].Error.Message, "network unreachable") + assert.Empty(t, results[1].StatementID) +} + +func TestExecuteBatchSetsOnWaitTimeoutContinue(t *testing.T) { + // Guards against a silent SDK default flip from CONTINUE to CANCEL. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.WaitTimeout == "0s" && req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue + })).Return(&sql.StatementResponse{ + StatementId: "stmt-x", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Times(2) + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"q1", "q2"}, 8) + require.Len(t, results, 2) +} + +func TestExecuteBatchPreservesInputOrder(t *testing.T) { + // Index 0 is slow (PENDING then SUCCEEDED on first poll); 1 and 2 are + // immediate. Despite the staggered completion, results stay in input order. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT 'slow'" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-slow", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-slow").Return(&sql.StatementResponse{ + StatementId: "stmt-slow", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + + for i, sqlStr := range []string{"SELECT 'fast1'", "SELECT 'fast2'"} { + sid := fmt.Sprintf("stmt-fast-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + } + + sqls := []string{"SELECT 'slow'", "SELECT 'fast1'", "SELECT 'fast2'"} + results := executeBatch(ctx, mockAPI, "wh-1", sqls, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sqls[i], r.SQL, "result %d", i) + assert.Equal(t, sql.StatementStateSucceeded, r.State, "result %d", i) + } +} + +func TestExecuteBatchContextCancellationCancelsInFlight(t *testing.T) { + // All statements are PENDING when the context is cancelled. cancelInFlight + // sweeps the in-flight set with CancelExecution. + ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + for i, sqlStr := range []string{"q1", "q2", "q3"} { + sid := fmt.Sprintf("stmt-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{ + StatementId: sid, + }).Return(nil).Once() + } + + cancel() + + results := executeBatch(ctx, mockAPI, "wh", []string{"q1", "q2", "q3"}, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sql.StatementStateCanceled, r.State, "result %d state", i) + require.NotNil(t, r.Error, "result %d error", i) + } +} diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 6c125bbcd6b..b7e4d5ede34 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -75,31 +75,40 @@ func selectQueryOutputMode(outputType flags.Output, stdoutInteractive, promptSup func newQueryCmd() *cobra.Command { var warehouseID string - var filePath string + var filePaths []string var outputFormat string + var concurrency int cmd := &cobra.Command{ - Use: "query [SQL | file.sql]", + Use: "query [SQL | file.sql]...", Short: "Execute SQL against a Databricks warehouse", - Long: `Execute a SQL statement against a Databricks SQL warehouse and return results. + Long: `Execute one or more SQL statements against a Databricks SQL warehouse +and return results. -SQL can be provided as a positional argument, read from a file with --file, -or piped via stdin. If the positional argument ends in .sql and the file -exists, it is read as a SQL file automatically. +A single SQL can be provided as a positional argument, read from a file with +--file, or piped via stdin. If a positional argument ends in .sql and the +file exists, it is read as a SQL file automatically. + +Pass multiple positional arguments and/or repeat --file to run several +queries in parallel against the warehouse. Multi-query output is always +JSON: an array of {sql, statement_id, state, elapsed_ms, columns, rows, +error} objects in input order. The exit code is non-zero if any query +failed. The command auto-detects an available warehouse unless --warehouse is set or the DATABRICKS_WAREHOUSE_ID environment variable is configured. -Output is JSON in non-interactive contexts. In interactive terminals it renders -tables, and large results open an interactive table browser. Use --output csv -to export results as CSV.`, +For a single query, output is JSON in non-interactive contexts. In +interactive terminals it renders tables, and large results open an +interactive table browser. Use --output csv to export results as CSV.`, Example: ` databricks experimental aitools tools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5" databricks experimental aitools tools query --warehouse abc123 "SELECT 1" databricks experimental aitools tools query --file report.sql databricks experimental aitools tools query report.sql databricks experimental aitools tools query --output csv "SELECT * FROM samples.nyctaxi.trips LIMIT 5" + databricks experimental aitools tools query --output json "SELECT 1" "SELECT 2" "SELECT 3" echo "SELECT 1" | databricks experimental aitools tools query`, - Args: cobra.MaximumNArgs(1), + Args: cobra.ArbitraryArgs, PreRunE: root.MustWorkspaceClient, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() @@ -124,19 +133,29 @@ to export results as CSV.`, return fmt.Errorf("unsupported output format %q, accepted values: text, json, csv", outputFormat) } - w := cmdctx.WorkspaceClient(ctx) - - sqlStatement, err := resolveSQL(ctx, cmd, args, filePath) + sqls, err := resolveSQLs(ctx, cmd, args, filePaths) if err != nil { return err } + // Reject incompatible flag combinations before any API call so the + // user sees the real error instead of an auth/warehouse failure. + if len(sqls) > 1 && flags.Output(outputFormat) != flags.OutputJSON { + return fmt.Errorf("multiple queries require --output json (got %q); pass --output json to receive a JSON array of per-statement results", outputFormat) + } + + w := cmdctx.WorkspaceClient(ctx) + wID, err := resolveWarehouseID(ctx, w, warehouseID) if err != nil { return err } - resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqlStatement) + if len(sqls) > 1 { + return runBatch(ctx, cmd, w.StatementExecution, wID, sqls, concurrency) + } + + resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqls[0]) if err != nil { return err } @@ -177,7 +196,8 @@ to export results as CSV.`, } cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution") - cmd.Flags().StringVarP(&filePath, "file", "f", "", "Path to a SQL file to execute") + cmd.Flags().StringSliceVarP(&filePaths, "file", "f", nil, "Path to a SQL file to execute (repeatable; pair with positional SQLs to run a batch)") + cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum in-flight statements when running a batch of queries") // Local --output flag shadows the root command's persistent --output flag, // adding csv support for this command only. cmd.Flags().StringVarP(&outputFormat, "output", "o", string(flags.OutputText), "Output format: text, json, or csv") @@ -188,59 +208,85 @@ to export results as CSV.`, return cmd } -// resolveSQL determines the SQL statement to execute from the available input sources. -// Priority: --file flag > positional arg > stdin. -func resolveSQL(ctx context.Context, cmd *cobra.Command, args []string, filePath string) (string, error) { - var raw string +// resolveSQLs collects SQL statements from --file paths, positional args, and +// stdin. The returned slice preserves source order: --file paths first (in flag +// order), then positional args (in arg order), then stdin (only if no other +// source produced anything). Each SQL is run through cleanSQL. +func resolveSQLs(ctx context.Context, cmd *cobra.Command, args, filePaths []string) ([]string, error) { + var raws []string - switch { - case filePath != "": - if len(args) > 0 { - return "", errors.New("cannot use both --file and a positional SQL argument") - } - data, err := os.ReadFile(filePath) + for _, path := range filePaths { + data, err := os.ReadFile(path) if err != nil { - return "", fmt.Errorf("read SQL file: %w", err) + return nil, fmt.Errorf("read SQL file %s: %w", path, err) } - raw = string(data) + raws = append(raws, string(data)) + } - case len(args) > 0: + for _, arg := range args { // If the argument looks like a .sql file, try to read it. // Only fall through to literal SQL if the file doesn't exist. // Surface other errors (permission denied, etc.) directly. - if strings.HasSuffix(args[0], sqlFileExtension) { - data, err := os.ReadFile(args[0]) + if strings.HasSuffix(arg, sqlFileExtension) { + data, err := os.ReadFile(arg) if err != nil && !errors.Is(err, os.ErrNotExist) { - return "", fmt.Errorf("read SQL file: %w", err) + return nil, fmt.Errorf("read SQL file: %w", err) } if err == nil { - raw = string(data) - break + raws = append(raws, string(data)) + continue } } - raw = args[0] + raws = append(raws, arg) + } - default: - // No args: try reading from stdin if it's piped. + if len(raws) == 0 { + // No --file and no positional args: try reading from stdin if it's piped. // If stdin was overridden (e.g. cmd.SetIn in tests), always read from it. // Otherwise, only read if stdin is not a TTY (i.e. piped input). in := cmd.InOrStdin() _, isOsFile := in.(*os.File) if isOsFile && cmdio.IsPromptSupported(ctx) { - return "", errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin") + return nil, errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin") } data, err := io.ReadAll(in) if err != nil { - return "", fmt.Errorf("read stdin: %w", err) + return nil, fmt.Errorf("read stdin: %w", err) + } + raws = append(raws, string(data)) + } + + cleaned := make([]string, 0, len(raws)) + for i, raw := range raws { + c := cleanSQL(raw) + if c == "" { + if len(raws) == 1 { + return nil, errors.New("SQL statement is empty after removing comments and blank lines") + } + return nil, fmt.Errorf("SQL statement #%d is empty after removing comments and blank lines", i+1) } - raw = string(data) + cleaned = append(cleaned, c) } + return cleaned, nil +} - result := cleanSQL(raw) - if result == "" { - return "", errors.New("SQL statement is empty after removing comments and blank lines") +// runBatch executes multiple SQL statements in parallel and renders the result +// as a JSON array. Returns root.ErrAlreadyPrinted (so the exit code is non-zero +// without an extra error message) when any statement failed; the failure detail +// is already encoded in the printed JSON. The caller is responsible for +// rejecting incompatible output formats before invoking this. +func runBatch(ctx context.Context, cmd *cobra.Command, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) error { + results := executeBatch(ctx, api, warehouseID, sqls, concurrency) + if err := renderBatchJSON(cmd.OutOrStdout(), results); err != nil { + return err } - return result, nil + + for _, r := range results { + if r.Error != nil { + return root.ErrAlreadyPrinted + } + } + return nil } // resolveWarehouseID returns the warehouse ID to use for query execution. diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index 4bc06c1d63b..a5d079acf8a 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -433,69 +433,95 @@ func TestPollingConstants(t *testing.T) { assert.Equal(t, 10*time.Second, cancelTimeout) } -// newTestCmd creates a minimal cobra.Command for testing resolveSQL. +// newTestCmd creates a minimal cobra.Command for testing resolveSQLs. func newTestCmd() *cobra.Command { return &cobra.Command{Use: "test"} } -func TestResolveSQLFromFileFlag(t *testing.T) { +func TestResolveSQLsFromFileFlag(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "query.sql") err := os.WriteFile(path, []byte("SELECT 1"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.NoError(t, err) - assert.Equal(t, "SELECT 1", result) + assert.Equal(t, []string{"SELECT 1"}, result) } -func TestResolveSQLFromFileFlagWithComments(t *testing.T) { +func TestResolveSQLsFromFileFlagWithComments(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "query.sql") err := os.WriteFile(path, []byte("-- header comment\nSELECT 1\n-- trailing"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.NoError(t, err) - assert.Equal(t, "SELECT 1", result) + assert.Equal(t, []string{"SELECT 1"}, result) } -func TestResolveSQLFileFlagConflictsWithArg(t *testing.T) { +func TestResolveSQLsMixedFileAndPositional(t *testing.T) { + // --file paths are emitted before positional args, in flag order. + dir := t.TempDir() + path := filepath.Join(dir, "from-file.sql") + err := os.WriteFile(path, []byte("SELECT 'from file'"), 0o644) + require.NoError(t, err) + cmd := newTestCmd() - _, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1"}, "/some/file.sql") - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot use both --file and a positional SQL argument") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 'from arg'"}, []string{path}) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 'from file'", "SELECT 'from arg'"}, result) +} + +func TestResolveSQLsMultiplePositional(t *testing.T) { + cmd := newTestCmd() + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1", "SELECT 2", "SELECT 3"}, nil) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 1", "SELECT 2", "SELECT 3"}, result) +} + +func TestResolveSQLsMultipleFiles(t *testing.T) { + dir := t.TempDir() + pathA := filepath.Join(dir, "a.sql") + pathB := filepath.Join(dir, "b.sql") + require.NoError(t, os.WriteFile(pathA, []byte("SELECT 'a'"), 0o644)) + require.NoError(t, os.WriteFile(pathB, []byte("SELECT 'b'"), 0o644)) + + cmd := newTestCmd() + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{pathA, pathB}) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 'a'", "SELECT 'b'"}, result) } -func TestResolveSQLFromPositionalArg(t *testing.T) { +func TestResolveSQLsFromPositionalArg(t *testing.T) { cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 42"}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 42"}, nil) require.NoError(t, err) - assert.Equal(t, "SELECT 42", result) + assert.Equal(t, []string{"SELECT 42"}, result) } -func TestResolveSQLAutoDetectsSQLFile(t *testing.T) { +func TestResolveSQLsAutoDetectsSQLFile(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "report.sql") err := os.WriteFile(path, []byte("SELECT * FROM sales"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{path}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{path}, nil) require.NoError(t, err) - assert.Equal(t, "SELECT * FROM sales", result) + assert.Equal(t, []string{"SELECT * FROM sales"}, result) } -func TestResolveSQLNonexistentSQLFileTreatedAsString(t *testing.T) { +func TestResolveSQLsNonexistentSQLFileTreatedAsString(t *testing.T) { cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"nonexistent.sql"}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"nonexistent.sql"}, nil) require.NoError(t, err) - assert.Equal(t, "nonexistent.sql", result) + assert.Equal(t, []string{"nonexistent.sql"}, result) } -func TestResolveSQLUnreadableSQLFileReturnsError(t *testing.T) { +func TestResolveSQLsUnreadableSQLFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "locked.sql") err := os.WriteFile(path, []byte("SELECT 1"), 0o644) @@ -507,47 +533,54 @@ func TestResolveSQLUnreadableSQLFileReturnsError(t *testing.T) { t.Cleanup(func() { _ = os.Chmod(path, 0o644) }) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{path}, "") + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{path}, nil) require.Error(t, err) assert.Contains(t, err.Error(), "read SQL file") } -func TestResolveSQLFromStdin(t *testing.T) { +func TestResolveSQLsFromStdin(t *testing.T) { cmd := newTestCmd() cmd.SetIn(strings.NewReader("SELECT 1 FROM stdin_test")) - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, nil) require.NoError(t, err) - assert.Equal(t, "SELECT 1 FROM stdin_test", result) + assert.Equal(t, []string{"SELECT 1 FROM stdin_test"}, result) } -func TestResolveSQLEmptyFileReturnsError(t *testing.T) { +func TestResolveSQLsEmptyFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "empty.sql") err := os.WriteFile(path, []byte(""), 0o644) require.NoError(t, err) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.Error(t, err) assert.Contains(t, err.Error(), "empty") } -func TestResolveSQLCommentsOnlyFileReturnsError(t *testing.T) { +func TestResolveSQLsCommentsOnlyFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "comments.sql") err := os.WriteFile(path, []byte("-- just a comment\n-- another"), 0o644) require.NoError(t, err) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.Error(t, err) assert.Contains(t, err.Error(), "empty") } -func TestResolveSQLMissingFileReturnsError(t *testing.T) { +func TestResolveSQLsBatchEmptyAtIndexReturnsIndexedError(t *testing.T) { cmd := newTestCmd() - _, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, "/nonexistent/path/query.sql") + _, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1", "-- comment only", "SELECT 3"}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "SQL statement #2 is empty") +} + +func TestResolveSQLsMissingFileReturnsError(t *testing.T) { + cmd := newTestCmd() + _, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{"/nonexistent/path/query.sql"}) require.Error(t, err) assert.Contains(t, err.Error(), "read SQL file") } @@ -561,6 +594,24 @@ func TestQueryCommandUnsupportedOutputReturnsError(t *testing.T) { assert.Contains(t, err.Error(), "unsupported output format") } +func TestQueryCommandBatchTextOutputRejected(t *testing.T) { + cmd := newQueryCmd() + cmd.PreRunE = nil + cmd.SetArgs([]string{"--output", "text", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple queries require --output json") +} + +func TestQueryCommandBatchCsvOutputRejected(t *testing.T) { + cmd := newQueryCmd() + cmd.PreRunE = nil + cmd.SetArgs([]string{"--output", "csv", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple queries require --output json") +} + func TestQueryCommandOutputFlagIsCaseInsensitive(t *testing.T) { cmd := newQueryCmd() cmd.PreRunE = nil diff --git a/experimental/aitools/cmd/render.go b/experimental/aitools/cmd/render.go index 7727c37106c..d0b62926c20 100644 --- a/experimental/aitools/cmd/render.go +++ b/experimental/aitools/cmd/render.go @@ -29,6 +29,17 @@ func extractColumns(manifest *sql.ResultManifest) []string { return columns } +// renderBatchJSON writes batch results as a JSON array. The array preserves +// input order and includes one object per submitted statement. +func renderBatchJSON(w io.Writer, results []batchResult) error { + output, err := json.MarshalIndent(results, "", " ") + if err != nil { + return fmt.Errorf("marshal batch results: %w", err) + } + fmt.Fprintf(w, "%s\n", output) + return nil +} + // renderJSON writes query results as a parseable JSON array to stdout. // Row count is written to stderr so stdout remains valid JSON for piping. func renderJSON(w io.Writer, columns []string, rows [][]string) error { From bc06013af476d113a5eb92ecfe77702cb9dd0e3f Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 15:49:30 +0200 Subject: [PATCH 3/5] aitools: validate --concurrency and document batch result order Address two findings from a cursor PR review: 1. --concurrency was passed straight into errgroup.SetLimit. A value of 0 deadlocks (errgroup refuses to add goroutines), and a negative value silently removes the cap. Add a PreRunE check that rejects anything <= 0 with errInvalidBatchConcurrency, matching the shape used by cmd/fs/cp.go for the same flag. 2. The Long help previously said multi-query results come back "in input order", which was ambiguous when --file and positional SQLs are mixed. The actual behavior (already covered by TestResolveSQLsMixedFileAndPositional) is: --file inputs first in flag order, then positional SQLs in arg order. Tighten the help text to state that contract precisely. Adds two unit tests that verify --concurrency 0 and -1 are rejected before any API call. Co-authored-by: Isaac --- experimental/aitools/cmd/batch.go | 5 +++++ experimental/aitools/cmd/query.go | 14 ++++++++++---- experimental/aitools/cmd/query_test.go | 16 ++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/experimental/aitools/cmd/batch.go b/experimental/aitools/cmd/batch.go index 8965923c17c..3f8fc3015bb 100644 --- a/experimental/aitools/cmd/batch.go +++ b/experimental/aitools/cmd/batch.go @@ -2,6 +2,7 @@ package aitools import ( "context" + "errors" "fmt" "os" "os/signal" @@ -19,6 +20,10 @@ import ( // Matches the default used by cmd/fs/cp.go for similar fan-out work. const defaultBatchConcurrency = 8 +// errInvalidBatchConcurrency is returned when --concurrency is set to a value +// that errgroup.SetLimit can't honor (0 deadlocks, negative removes the cap). +var errInvalidBatchConcurrency = errors.New("--concurrency must be at least 1") + // batchResult is the per-statement payload emitted in batch mode JSON output. // State is the server-reported terminal state. Error is set whenever the // statement did not produce usable rows, regardless of state, so consumers diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index b7e4d5ede34..afe544c0e26 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -92,8 +92,9 @@ file exists, it is read as a SQL file automatically. Pass multiple positional arguments and/or repeat --file to run several queries in parallel against the warehouse. Multi-query output is always JSON: an array of {sql, statement_id, state, elapsed_ms, columns, rows, -error} objects in input order. The exit code is non-zero if any query -failed. +error} objects. Result order is: --file inputs first (in flag order), +then positional SQLs (in arg order). The exit code is non-zero if any +query failed. The command auto-detects an available warehouse unless --warehouse is set or the DATABRICKS_WAREHOUSE_ID environment variable is configured. @@ -108,8 +109,13 @@ interactive table browser. Use --output csv to export results as CSV.`, databricks experimental aitools tools query --output csv "SELECT * FROM samples.nyctaxi.trips LIMIT 5" databricks experimental aitools tools query --output json "SELECT 1" "SELECT 2" "SELECT 3" echo "SELECT 1" | databricks experimental aitools tools query`, - Args: cobra.ArbitraryArgs, - PreRunE: root.MustWorkspaceClient, + Args: cobra.ArbitraryArgs, + PreRunE: func(cmd *cobra.Command, args []string) error { + if concurrency <= 0 { + return errInvalidBatchConcurrency + } + return root.MustWorkspaceClient(cmd, args) + }, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index a5d079acf8a..abd6ffe8341 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -612,6 +612,22 @@ func TestQueryCommandBatchCsvOutputRejected(t *testing.T) { assert.Contains(t, err.Error(), "multiple queries require --output json") } +func TestQueryCommandConcurrencyZeroRejected(t *testing.T) { + // errgroup.SetLimit(0) deadlocks; we reject it in PreRunE. + cmd := newQueryCmd() + cmd.SetArgs([]string{"--concurrency", "0", "--output", "json", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) +} + +func TestQueryCommandConcurrencyNegativeRejected(t *testing.T) { + // Negative removes the cap entirely in errgroup, which surprises users. + cmd := newQueryCmd() + cmd.SetArgs([]string{"--concurrency", "-1", "--output", "json", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) +} + func TestQueryCommandOutputFlagIsCaseInsensitive(t *testing.T) { cmd := newQueryCmd() cmd.PreRunE = nil From 4d15c81687a0b7971bbf00624d228546f2c646dd Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 21:36:58 +0200 Subject: [PATCH 4/5] aitools: fold redundant cobra-level rejection tests into table-driven cases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two pairs of cobra-level tests were each testing one rejection code path with two flag values. Fold them into table-driven subtests so the shared assertion lives in one place: - TestQueryCommandBatchTextOutputRejected + ...CsvOutputRejected → TestQueryCommandBatchOutputRejection (text, csv subtests) - TestQueryCommandConcurrencyZeroRejected + ...NegativeRejected → TestQueryCommandConcurrencyRejection (0, -1 subtests) Same coverage, half the test functions. Co-authored-by: Isaac --- experimental/aitools/cmd/query_test.go | 54 ++++++++++++-------------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index abd6ffe8341..e6bf9362fff 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -594,38 +594,32 @@ func TestQueryCommandUnsupportedOutputReturnsError(t *testing.T) { assert.Contains(t, err.Error(), "unsupported output format") } -func TestQueryCommandBatchTextOutputRejected(t *testing.T) { - cmd := newQueryCmd() - cmd.PreRunE = nil - cmd.SetArgs([]string{"--output", "text", "SELECT 1", "SELECT 2"}) - err := cmd.Execute() - require.Error(t, err) - assert.Contains(t, err.Error(), "multiple queries require --output json") -} - -func TestQueryCommandBatchCsvOutputRejected(t *testing.T) { - cmd := newQueryCmd() - cmd.PreRunE = nil - cmd.SetArgs([]string{"--output", "csv", "SELECT 1", "SELECT 2"}) - err := cmd.Execute() - require.Error(t, err) - assert.Contains(t, err.Error(), "multiple queries require --output json") -} - -func TestQueryCommandConcurrencyZeroRejected(t *testing.T) { - // errgroup.SetLimit(0) deadlocks; we reject it in PreRunE. - cmd := newQueryCmd() - cmd.SetArgs([]string{"--concurrency", "0", "--output", "json", "SELECT 1", "SELECT 2"}) - err := cmd.Execute() - require.ErrorIs(t, err, errInvalidBatchConcurrency) +func TestQueryCommandBatchOutputRejection(t *testing.T) { + // Multi-query mode is JSON-only. text and csv are rejected with an + // actionable error before any API call. + for _, format := range []string{"text", "csv"} { + t.Run(format, func(t *testing.T) { + cmd := newQueryCmd() + cmd.PreRunE = nil + cmd.SetArgs([]string{"--output", format, "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple queries require --output json") + }) + } } -func TestQueryCommandConcurrencyNegativeRejected(t *testing.T) { - // Negative removes the cap entirely in errgroup, which surprises users. - cmd := newQueryCmd() - cmd.SetArgs([]string{"--concurrency", "-1", "--output", "json", "SELECT 1", "SELECT 2"}) - err := cmd.Execute() - require.ErrorIs(t, err, errInvalidBatchConcurrency) +func TestQueryCommandConcurrencyRejection(t *testing.T) { + // errgroup.SetLimit(0) deadlocks; negative removes the cap entirely. + // Both surprise users, so PreRunE rejects anything <= 0. + for _, value := range []string{"0", "-1"} { + t.Run(value, func(t *testing.T) { + cmd := newQueryCmd() + cmd.SetArgs([]string{"--concurrency", value, "--output", "json", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) + }) + } } func TestQueryCommandOutputFlagIsCaseInsensitive(t *testing.T) { From a1c5ca637443c2a3f1f2153d3c0b5785a3849eb8 Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 28 Apr 2026 10:13:02 +0200 Subject: [PATCH 5/5] aitools: detach cancel-RPC ctx from cancelled parent Address Arseni's P2 finding on the batch PR. cancelInFlight (batch.go) and cancelStatement (query.go) used to derive the cancel-RPC ctx via context.WithTimeout(ctx, cancelTimeout). On the actual hot path (Ctrl+C or parent ctx cancelled), the inbound ctx is already cancelled by the time we reach the cancel sweep. The SDK then short-circuits on ctx.Err() and the cancel RPC never reaches the warehouse, leaving in-flight statements running server-side. Wrap with context.WithoutCancel(ctx) (Go 1.21+) so the timeout context keeps the caller's values but drops the cancellation signal. The cancel RPC now actually fires. Also tighten the existing tests: - TestExecuteBatchContextCancellationCancelsInFlight - TestExecuteAndPollCancelledContextCallsCancelExecution Both previously matched mock.Anything for the ctx argument, so they passed regardless of whether the bug was present. They now use mock.MatchedBy(c.Err() == nil) to assert the cancel-RPC ctx is alive. This is a regression guard; reverting the production fix makes the tests fail with "unexpected call" because the matcher no longer matches. Co-authored-by: Isaac --- experimental/aitools/cmd/batch.go | 6 +++++- experimental/aitools/cmd/batch_test.go | 10 ++++++++-- experimental/aitools/cmd/query.go | 7 +++++-- experimental/aitools/cmd/query_test.go | 8 ++++++-- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/experimental/aitools/cmd/batch.go b/experimental/aitools/cmd/batch.go index 3f8fc3015bb..38ecea531e6 100644 --- a/experimental/aitools/cmd/batch.go +++ b/experimental/aitools/cmd/batch.go @@ -198,7 +198,11 @@ func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, st // marker meaning the goroutine bailed without telling the server. // Either way, send CancelExecution. } - cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout) + // Detach from the inbound ctx (which is typically already cancelled by + // the time we reach this sweep): WithoutCancel keeps the caller's + // values but drops the cancellation signal so the cancel RPC actually + // reaches the warehouse instead of short-circuiting on ctx.Err(). + cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout) if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{StatementId: sid}); err != nil { log.Warnf(ctx, "Failed to cancel statement %s: %v", sid, err) } diff --git a/experimental/aitools/cmd/batch_test.go b/experimental/aitools/cmd/batch_test.go index 96235530f4d..f6f468768f9 100644 --- a/experimental/aitools/cmd/batch_test.go +++ b/experimental/aitools/cmd/batch_test.go @@ -207,10 +207,16 @@ func TestExecuteBatchPreservesInputOrder(t *testing.T) { func TestExecuteBatchContextCancellationCancelsInFlight(t *testing.T) { // All statements are PENDING when the context is cancelled. cancelInFlight - // sweeps the in-flight set with CancelExecution. + // sweeps the in-flight set with CancelExecution. Each cancel RPC must + // carry a NON-cancelled context, otherwise the SDK short-circuits on + // ctx.Err() and never reaches the warehouse. ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) mockAPI := mocksql.NewMockStatementExecutionInterface(t) + aliveCtx := mock.MatchedBy(func(c context.Context) bool { + return c.Err() == nil + }) + for i, sqlStr := range []string{"q1", "q2", "q3"} { sid := fmt.Sprintf("stmt-%d", i+1) mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { @@ -220,7 +226,7 @@ func TestExecuteBatchContextCancellationCancelsInFlight(t *testing.T) { Status: &sql.StatementStatus{State: sql.StatementStatePending}, }, nil).Once() - mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{ + mockAPI.EXPECT().CancelExecution(aliveCtx, sql.CancelExecutionRequest{ StatementId: sid, }).Return(nil).Once() } diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index afe544c0e26..7e9ae1d030d 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -345,8 +345,11 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa // cancelStatement performs best-effort server-side cancellation. // Called on any poll exit due to context cancellation (signal or parent). cancelStatement := func() { - // Use the parent context (ctx), not the cancelled pollCtx. - cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout) + // Detach from any cancellation on the inbound ctx (the caller might + // have cancelled the parent before invoking this path): WithoutCancel + // preserves values but drops cancellation so the cancel RPC actually + // reaches the warehouse. + cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout) defer cancel() if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{ StatementId: statementID, diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index e6bf9362fff..59de11d578a 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -146,8 +146,12 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) { Status: &sql.StatementStatus{State: sql.StatementStatePending}, }, nil) - // CancelExecution must be called when context is cancelled (not just on signal). - mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{ + // CancelExecution must be called when context is cancelled (not just on + // signal). Assert the RPC's own ctx is NOT cancelled, otherwise the SDK + // would short-circuit on ctx.Err() and never reach the warehouse. + mockAPI.EXPECT().CancelExecution(mock.MatchedBy(func(c context.Context) bool { + return c.Err() == nil + }), sql.CancelExecutionRequest{ StatementId: "stmt-1", }).Return(nil).Once()