diff --git a/experimental/aitools/README.md b/experimental/aitools/README.md
index 571136538c..f645e4de51 100644
--- a/experimental/aitools/README.md
+++ b/experimental/aitools/README.md
@@ -16,6 +16,18 @@ Current behavior:
 - `skills install` installs Databricks skills for detected coding agents.
 - `install` is a compatibility alias for `skills install`.
 - `tools` exposes a small set of AI-oriented workspace helpers.
+- `tools query` accepts a single SQL or multiple SQLs in one invocation. Pass
+  several positional arguments and/or repeat `--file` to run them in parallel
+  against the warehouse. Multi-query runs require `--output json`; control
+  parallelism with `--concurrency` (default 8).
+
+  ```bash
+  databricks experimental aitools tools query \
+    --warehouse abc123 --output json \
+    "SELECT count(*) FROM samples.nyctaxi.trips" \
+    "SELECT min(tpep_pickup_datetime), max(tpep_pickup_datetime) FROM samples.nyctaxi.trips" \
+    "SELECT vendor_id, count(*) FROM samples.nyctaxi.trips GROUP BY 1"
+  ```
 
 Removed behavior:
 
diff --git a/experimental/aitools/cmd/batch.go b/experimental/aitools/cmd/batch.go
new file mode 100644
index 0000000000..38ecea531e
--- /dev/null
+++ b/experimental/aitools/cmd/batch.go
@@ -0,0 +1,215 @@
+package aitools
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"os"
+	"os/signal"
+	"sync/atomic"
+	"syscall"
+	"time"
+
+	"github.com/databricks/cli/libs/cmdio"
+	"github.com/databricks/cli/libs/log"
+	"github.com/databricks/databricks-sdk-go/service/sql"
+	"golang.org/x/sync/errgroup"
+)
+
+// defaultBatchConcurrency caps in-flight statements when --concurrency is unset.
+// Matches the default used by cmd/fs/cp.go for similar fan-out work.
+const defaultBatchConcurrency = 8
+
+// errInvalidBatchConcurrency is returned when --concurrency is set to a value
+// that errgroup.SetLimit can't honor (0 deadlocks, negative removes the cap).
+var errInvalidBatchConcurrency = errors.New("--concurrency must be at least 1") + +// batchResult is the per-statement payload emitted in batch mode JSON output. +// State is the server-reported terminal state. Error is set whenever the +// statement did not produce usable rows, regardless of state, so consumers +// can branch on `error == null` alone. +type batchResult struct { + SQL string `json:"sql"` + StatementID string `json:"statement_id,omitempty"` + State sql.StatementState `json:"state,omitempty"` + ElapsedMs int64 `json:"elapsed_ms"` + Columns []string `json:"columns,omitempty"` + Rows [][]string `json:"rows,omitempty"` + Error *batchResultError `json:"error,omitempty"` +} + +// batchResultError captures user-visible error info for a failed statement. +type batchResultError struct { + Message string `json:"message"` + ErrorCode string `json:"error_code,omitempty"` +} + +// executeBatch submits sqls against the warehouse in parallel, polls each to +// completion, and returns one batchResult per input in input order. +// +// Individual statement failures do not abort siblings; failures are encoded in +// the per-result Error field so callers can render partial results. +// +// On context cancellation (Ctrl+C or parent context), still-running statements +// are cancelled server-side via CancelExecution. Statements that finished +// before cancellation are left as-is. 
+func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) []batchResult { + pollCtx, pollCancel := context.WithCancel(ctx) + defer pollCancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigCh) + + go func() { + select { + case <-sigCh: + log.Infof(ctx, "Received interrupt, cancelling %d in-flight queries", len(sqls)) + pollCancel() + case <-pollCtx.Done(): + } + }() + + sp := cmdio.NewSpinner(pollCtx) + defer sp.Close() + sp.Update(fmt.Sprintf("Executing %d queries...", len(sqls))) + + var completed atomic.Int64 + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + go func() { + for { + select { + case <-pollCtx.Done(): + return + case <-ticker.C: + sp.Update(fmt.Sprintf("Executing %d queries... (%d/%d done)", len(sqls), completed.Load(), len(sqls))) + } + } + }() + + results := make([]batchResult, len(sqls)) + // Each goroutine writes to a distinct slot, safe without a mutex. + // We read after g.Wait(), establishing happens-before for all writes. + statementIDs := make([]string, len(sqls)) + + g := new(errgroup.Group) + g.SetLimit(concurrency) + for i, sqlStr := range sqls { + g.Go(func() error { + results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, statementIDs, i) + completed.Add(1) + return nil + }) + } + _ = g.Wait() + + // pollStatement is a pure helper that returns ctx.Err() on cancellation + // without touching the server. Sweep any not-yet-terminal statements here. + if pollCtx.Err() != nil { + cancelInFlight(ctx, api, statementIDs, results) + } + + return results +} + +// runOneBatchQuery submits one SQL, polls to completion, and returns its +// batchResult. All errors are encoded into the result; never returns an error. 
+func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, statementIDs []string, idx int) batchResult { + start := time.Now() + result := batchResult{SQL: sqlStr} + + resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ + WarehouseId: warehouseID, + Statement: sqlStr, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, + }) + if err != nil { + if ctx.Err() != nil { + result.State = sql.StatementStateCanceled + result.Error = &batchResultError{Message: "submission cancelled"} + } else { + result.State = sql.StatementStateFailed + result.Error = &batchResultError{Message: fmt.Sprintf("execute statement: %v", err)} + } + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + + statementIDs[idx] = resp.StatementId + result.StatementID = resp.StatementId + + pollResp, err := pollStatement(ctx, api, resp) + if err != nil { + if ctx.Err() != nil { + result.State = sql.StatementStateCanceled + result.Error = &batchResultError{Message: "cancelled"} + } else { + result.State = sql.StatementStateFailed + result.Error = &batchResultError{Message: err.Error()} + } + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + + if pollResp.Status != nil { + result.State = pollResp.Status.State + } + + if result.State != sql.StatementStateSucceeded { + result.Error = &batchResultError{} + if pollResp.Status != nil && pollResp.Status.Error != nil { + result.Error.Message = pollResp.Status.Error.Message + result.Error.ErrorCode = string(pollResp.Status.Error.ErrorCode) + } else { + result.Error.Message = fmt.Sprintf("query reached terminal state %s", result.State) + } + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + + result.Columns = extractColumns(pollResp.Manifest) + rows, err := fetchAllRows(ctx, api, pollResp) + if err != nil { + result.Error = &batchResultError{Message: fmt.Sprintf("fetch rows: %v", err)} + 
result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + result.Rows = rows + result.ElapsedMs = time.Since(start).Milliseconds() + return result +} + +// cancelInFlight sends CancelExecution for every statement that didn't reach +// a terminal state server-side before context cancellation. Best effort: errors +// are logged at warn but don't fail the batch. +func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, statementIDs []string, results []batchResult) { + var cancelled int + for i, sid := range statementIDs { + if sid == "" { + continue + } + switch results[i].State { + case sql.StatementStateSucceeded, sql.StatementStateFailed, sql.StatementStateClosed: + continue + case sql.StatementStateCanceled, sql.StatementStatePending, sql.StatementStateRunning: + // Either still running server-side, or our internal "canceled" + // marker meaning the goroutine bailed without telling the server. + // Either way, send CancelExecution. + } + // Detach from the inbound ctx (which is typically already cancelled by + // the time we reach this sweep): WithoutCancel keeps the caller's + // values but drops the cancellation signal so the cancel RPC actually + // reaches the warehouse instead of short-circuiting on ctx.Err(). 
+ cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout) + if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{StatementId: sid}); err != nil { + log.Warnf(ctx, "Failed to cancel statement %s: %v", sid, err) + } + cancel() + cancelled++ + } + if cancelled > 0 { + cmdio.LogString(ctx, fmt.Sprintf("Cancelled %d in-flight queries.", cancelled)) + } +} diff --git a/experimental/aitools/cmd/batch_test.go b/experimental/aitools/cmd/batch_test.go new file mode 100644 index 0000000000..f6f468768f --- /dev/null +++ b/experimental/aitools/cmd/batch_test.go @@ -0,0 +1,243 @@ +package aitools + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "testing" + + "github.com/databricks/cli/libs/cmdio" + mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestRenderBatchJSON(t *testing.T) { + results := []batchResult{ + { + SQL: "SELECT 1", + StatementID: "stmt-1", + State: sql.StatementStateSucceeded, + ElapsedMs: 42, + Columns: []string{"n"}, + Rows: [][]string{{"1"}}, + }, + { + SQL: "SELECT bad_syntax", + StatementID: "stmt-2", + State: sql.StatementStateFailed, + ElapsedMs: 12, + Error: &batchResultError{ + Message: "near 'bad_syntax': syntax error", + ErrorCode: "SYNTAX_ERROR", + }, + }, + } + + var buf strings.Builder + err := renderBatchJSON(&buf, results) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, `"sql": "SELECT 1"`) + assert.Contains(t, output, `"statement_id": "stmt-1"`) + assert.Contains(t, output, `"state": "SUCCEEDED"`) + assert.Contains(t, output, `"elapsed_ms": 42`) + assert.Contains(t, output, `"columns": [`) + assert.Contains(t, output, `"rows": [`) + assert.Contains(t, output, `"sql": "SELECT bad_syntax"`) + assert.Contains(t, output, `"error": {`) + 
assert.Contains(t, output, `"error_code": "SYNTAX_ERROR"`) + // Trailing newline. + assert.True(t, strings.HasSuffix(output, "\n")) +} + +func TestExecuteBatchAllSucceed(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + sqls := []string{"SELECT 1", "SELECT 2", "SELECT 3"} + for i, sqlStr := range sqls { + sid := fmt.Sprintf("stmt-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{strconv.Itoa(i + 1)}}}, + }, nil).Once() + } + + results := executeBatch(ctx, mockAPI, "wh-123", sqls, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sqls[i], r.SQL, "result %d sql", i) + assert.Equal(t, sql.StatementStateSucceeded, r.State, "result %d state", i) + assert.Nil(t, r.Error, "result %d error", i) + assert.Equal(t, []string{"n"}, r.Columns, "result %d columns", i) + assert.Equal(t, [][]string{{strconv.Itoa(i + 1)}}, r.Rows, "result %d rows", i) + assert.NotEmpty(t, r.StatementID, "result %d statement_id", i) + } +} + +func TestExecuteBatchPartialFailure(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT 1" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-good", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + }, 
nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT bad" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-bad", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ + ErrorCode: "SYNTAX_ERROR", + Message: "near 'bad': syntax error", + }, + }, + }, nil).Once() + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT 1", "SELECT bad"}, 8) + + require.Len(t, results, 2) + assert.Nil(t, results[0].Error) + assert.Equal(t, sql.StatementStateSucceeded, results[0].State) + + require.NotNil(t, results[1].Error) + assert.Equal(t, sql.StatementStateFailed, results[1].State) + assert.Equal(t, "SYNTAX_ERROR", results[1].Error.ErrorCode) + assert.Contains(t, results[1].Error.Message, "syntax error") +} + +func TestExecuteBatchSubmissionFailure(t *testing.T) { + // ExecuteStatement transport error is encoded into the per-result error, + // not propagated up to abort siblings. 
+ ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT good" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-good", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT broken" + })).Return(nil, errors.New("network unreachable")).Once() + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT good", "SELECT broken"}, 8) + + require.Len(t, results, 2) + assert.Nil(t, results[0].Error) + require.NotNil(t, results[1].Error) + assert.Contains(t, results[1].Error.Message, "execute statement") + assert.Contains(t, results[1].Error.Message, "network unreachable") + assert.Empty(t, results[1].StatementID) +} + +func TestExecuteBatchSetsOnWaitTimeoutContinue(t *testing.T) { + // Guards against a silent SDK default flip from CONTINUE to CANCEL. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.WaitTimeout == "0s" && req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue + })).Return(&sql.StatementResponse{ + StatementId: "stmt-x", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Times(2) + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"q1", "q2"}, 8) + require.Len(t, results, 2) +} + +func TestExecuteBatchPreservesInputOrder(t *testing.T) { + // Index 0 is slow (PENDING then SUCCEEDED on first poll); 1 and 2 are + // immediate. Despite the staggered completion, results stay in input order. 
+ ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT 'slow'" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-slow", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-slow").Return(&sql.StatementResponse{ + StatementId: "stmt-slow", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + + for i, sqlStr := range []string{"SELECT 'fast1'", "SELECT 'fast2'"} { + sid := fmt.Sprintf("stmt-fast-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + } + + sqls := []string{"SELECT 'slow'", "SELECT 'fast1'", "SELECT 'fast2'"} + results := executeBatch(ctx, mockAPI, "wh-1", sqls, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sqls[i], r.SQL, "result %d", i) + assert.Equal(t, sql.StatementStateSucceeded, r.State, "result %d", i) + } +} + +func TestExecuteBatchContextCancellationCancelsInFlight(t *testing.T) { + // All statements are PENDING when the context is cancelled. cancelInFlight + // sweeps the in-flight set with CancelExecution. Each cancel RPC must + // carry a NON-cancelled context, otherwise the SDK short-circuits on + // ctx.Err() and never reaches the warehouse. 
+ ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + aliveCtx := mock.MatchedBy(func(c context.Context) bool { + return c.Err() == nil + }) + + for i, sqlStr := range []string{"q1", "q2", "q3"} { + sid := fmt.Sprintf("stmt-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + mockAPI.EXPECT().CancelExecution(aliveCtx, sql.CancelExecutionRequest{ + StatementId: sid, + }).Return(nil).Once() + } + + cancel() + + results := executeBatch(ctx, mockAPI, "wh", []string{"q1", "q2", "q3"}, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sql.StatementStateCanceled, r.State, "result %d state", i) + require.NotNil(t, r.Error, "result %d error", i) + } +} diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 7b95fdd4e2..7e9ae1d030 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -75,32 +75,47 @@ func selectQueryOutputMode(outputType flags.Output, stdoutInteractive, promptSup func newQueryCmd() *cobra.Command { var warehouseID string - var filePath string + var filePaths []string var outputFormat string + var concurrency int cmd := &cobra.Command{ - Use: "query [SQL | file.sql]", + Use: "query [SQL | file.sql]...", Short: "Execute SQL against a Databricks warehouse", - Long: `Execute a SQL statement against a Databricks SQL warehouse and return results. + Long: `Execute one or more SQL statements against a Databricks SQL warehouse +and return results. -SQL can be provided as a positional argument, read from a file with --file, -or piped via stdin. If the positional argument ends in .sql and the file -exists, it is read as a SQL file automatically. 
+A single SQL can be provided as a positional argument, read from a file with +--file, or piped via stdin. If a positional argument ends in .sql and the +file exists, it is read as a SQL file automatically. + +Pass multiple positional arguments and/or repeat --file to run several +queries in parallel against the warehouse. Multi-query output is always +JSON: an array of {sql, statement_id, state, elapsed_ms, columns, rows, +error} objects. Result order is: --file inputs first (in flag order), +then positional SQLs (in arg order). The exit code is non-zero if any +query failed. The command auto-detects an available warehouse unless --warehouse is set or the DATABRICKS_WAREHOUSE_ID environment variable is configured. -Output is JSON in non-interactive contexts. In interactive terminals it renders -tables, and large results open an interactive table browser. Use --output csv -to export results as CSV.`, +For a single query, output is JSON in non-interactive contexts. In +interactive terminals it renders tables, and large results open an +interactive table browser. 
Use --output csv to export results as CSV.`, Example: ` databricks experimental aitools tools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5" databricks experimental aitools tools query --warehouse abc123 "SELECT 1" databricks experimental aitools tools query --file report.sql databricks experimental aitools tools query report.sql databricks experimental aitools tools query --output csv "SELECT * FROM samples.nyctaxi.trips LIMIT 5" + databricks experimental aitools tools query --output json "SELECT 1" "SELECT 2" "SELECT 3" echo "SELECT 1" | databricks experimental aitools tools query`, - Args: cobra.MaximumNArgs(1), - PreRunE: root.MustWorkspaceClient, + Args: cobra.ArbitraryArgs, + PreRunE: func(cmd *cobra.Command, args []string) error { + if concurrency <= 0 { + return errInvalidBatchConcurrency + } + return root.MustWorkspaceClient(cmd, args) + }, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() @@ -124,19 +139,29 @@ to export results as CSV.`, return fmt.Errorf("unsupported output format %q, accepted values: text, json, csv", outputFormat) } - w := cmdctx.WorkspaceClient(ctx) - - sqlStatement, err := resolveSQL(ctx, cmd, args, filePath) + sqls, err := resolveSQLs(ctx, cmd, args, filePaths) if err != nil { return err } + // Reject incompatible flag combinations before any API call so the + // user sees the real error instead of an auth/warehouse failure. 
+ if len(sqls) > 1 && flags.Output(outputFormat) != flags.OutputJSON { + return fmt.Errorf("multiple queries require --output json (got %q); pass --output json to receive a JSON array of per-statement results", outputFormat) + } + + w := cmdctx.WorkspaceClient(ctx) + wID, err := resolveWarehouseID(ctx, w, warehouseID) if err != nil { return err } - resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqlStatement) + if len(sqls) > 1 { + return runBatch(ctx, cmd, w.StatementExecution, wID, sqls, concurrency) + } + + resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqls[0]) if err != nil { return err } @@ -177,7 +202,8 @@ to export results as CSV.`, } cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution") - cmd.Flags().StringVarP(&filePath, "file", "f", "", "Path to a SQL file to execute") + cmd.Flags().StringSliceVarP(&filePaths, "file", "f", nil, "Path to a SQL file to execute (repeatable; pair with positional SQLs to run a batch)") + cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum in-flight statements when running a batch of queries") // Local --output flag shadows the root command's persistent --output flag, // adding csv support for this command only. cmd.Flags().StringVarP(&outputFormat, "output", "o", string(flags.OutputText), "Output format: text, json, or csv") @@ -188,59 +214,85 @@ to export results as CSV.`, return cmd } -// resolveSQL determines the SQL statement to execute from the available input sources. -// Priority: --file flag > positional arg > stdin. -func resolveSQL(ctx context.Context, cmd *cobra.Command, args []string, filePath string) (string, error) { - var raw string +// resolveSQLs collects SQL statements from --file paths, positional args, and +// stdin. The returned slice preserves source order: --file paths first (in flag +// order), then positional args (in arg order), then stdin (only if no other +// source produced anything). 
Each SQL is run through cleanSQL. +func resolveSQLs(ctx context.Context, cmd *cobra.Command, args, filePaths []string) ([]string, error) { + var raws []string - switch { - case filePath != "": - if len(args) > 0 { - return "", errors.New("cannot use both --file and a positional SQL argument") - } - data, err := os.ReadFile(filePath) + for _, path := range filePaths { + data, err := os.ReadFile(path) if err != nil { - return "", fmt.Errorf("read SQL file: %w", err) + return nil, fmt.Errorf("read SQL file %s: %w", path, err) } - raw = string(data) + raws = append(raws, string(data)) + } - case len(args) > 0: + for _, arg := range args { // If the argument looks like a .sql file, try to read it. // Only fall through to literal SQL if the file doesn't exist. // Surface other errors (permission denied, etc.) directly. - if strings.HasSuffix(args[0], sqlFileExtension) { - data, err := os.ReadFile(args[0]) + if strings.HasSuffix(arg, sqlFileExtension) { + data, err := os.ReadFile(arg) if err != nil && !errors.Is(err, os.ErrNotExist) { - return "", fmt.Errorf("read SQL file: %w", err) + return nil, fmt.Errorf("read SQL file: %w", err) } if err == nil { - raw = string(data) - break + raws = append(raws, string(data)) + continue } } - raw = args[0] + raws = append(raws, arg) + } - default: - // No args: try reading from stdin if it's piped. + if len(raws) == 0 { + // No --file and no positional args: try reading from stdin if it's piped. // If stdin was overridden (e.g. cmd.SetIn in tests), always read from it. // Otherwise, only read if stdin is not a TTY (i.e. piped input). 
in := cmd.InOrStdin() _, isOsFile := in.(*os.File) if isOsFile && cmdio.IsPromptSupported(ctx) { - return "", errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin") + return nil, errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin") } data, err := io.ReadAll(in) if err != nil { - return "", fmt.Errorf("read stdin: %w", err) + return nil, fmt.Errorf("read stdin: %w", err) + } + raws = append(raws, string(data)) + } + + cleaned := make([]string, 0, len(raws)) + for i, raw := range raws { + c := cleanSQL(raw) + if c == "" { + if len(raws) == 1 { + return nil, errors.New("SQL statement is empty after removing comments and blank lines") + } + return nil, fmt.Errorf("SQL statement #%d is empty after removing comments and blank lines", i+1) } - raw = string(data) + cleaned = append(cleaned, c) + } + return cleaned, nil +} + +// runBatch executes multiple SQL statements in parallel and renders the result +// as a JSON array. Returns root.ErrAlreadyPrinted (so the exit code is non-zero +// without an extra error message) when any statement failed; the failure detail +// is already encoded in the printed JSON. The caller is responsible for +// rejecting incompatible output formats before invoking this. +func runBatch(ctx context.Context, cmd *cobra.Command, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) error { + results := executeBatch(ctx, api, warehouseID, sqls, concurrency) + if err := renderBatchJSON(cmd.OutOrStdout(), results); err != nil { + return err } - result := cleanSQL(raw) - if result == "" { - return "", errors.New("SQL statement is empty after removing comments and blank lines") + for _, r := range results { + if r.Error != nil { + return root.ErrAlreadyPrinted + } } - return result, nil + return nil } // resolveWarehouseID returns the warehouse ID to use for query execution. 
@@ -262,9 +314,10 @@ func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, e func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) { // Submit asynchronously to get the statement ID immediately for cancellation. resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: statement, - WaitTimeout: "0s", + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, }) if err != nil { return nil, fmt.Errorf("execute statement: %w", err) @@ -272,11 +325,6 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa statementID := resp.StatementId - // Check if it completed immediately. - if isTerminalState(resp.Status) { - return resp, checkFailedState(resp.Status) - } - // Set up Ctrl+C: signal cancels the poll context, cleanup is unified below. pollCtx, pollCancel := context.WithCancel(ctx) defer pollCancel() @@ -297,8 +345,11 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa // cancelStatement performs best-effort server-side cancellation. // Called on any poll exit due to context cancellation (signal or parent). cancelStatement := func() { - // Use the parent context (ctx), not the cancelled pollCtx. - cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout) + // Detach from any cancellation on the inbound ctx (the caller might + // have cancelled the parent before invoking this path): WithoutCancel + // preserves values but drops cancellation so the cancel RPC actually + // reaches the warehouse. 
+ cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout) defer cancel() if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{ StatementId: statementID, @@ -327,34 +378,59 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa } }() + pollResp, err := pollStatement(pollCtx, api, resp) + if err != nil { + if pollCtx.Err() != nil { + cancelStatement() + cmdio.LogString(ctx, "Query cancelled.") + return nil, root.ErrAlreadyPrinted + } + return nil, err + } + + sp.Close() + if err := checkFailedState(pollResp.Status); err != nil { + return nil, err + } + return pollResp, nil +} + +// pollStatement polls until the statement reaches a terminal state. +// +// On context cancellation it returns the context error WITHOUT cancelling the +// server-side statement. Callers that want server-side cancellation should +// invoke CancelExecution explicitly. +// +// If the input response is already in a terminal state, it is returned without +// further polling. +func pollStatement(ctx context.Context, api sql.StatementExecutionInterface, resp *sql.StatementResponse) (*sql.StatementResponse, error) { + if isTerminalState(resp.Status) { + return resp, nil + } + + statementID := resp.StatementId + start := time.Now() + // Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped). 
interval := pollIntervalInitial for { select { - case <-pollCtx.Done(): - cancelStatement() - cmdio.LogString(ctx, "Query cancelled.") - return nil, root.ErrAlreadyPrinted + case <-ctx.Done(): + return nil, ctx.Err() case <-time.After(interval): } log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, time.Since(start).Truncate(time.Second)) - pollResp, err := api.GetStatementByStatementId(pollCtx, statementID) + pollResp, err := api.GetStatementByStatementId(ctx, statementID) if err != nil { - if pollCtx.Err() != nil { - cancelStatement() - cmdio.LogString(ctx, "Query cancelled.") - return nil, root.ErrAlreadyPrinted + if ctx.Err() != nil { + return nil, ctx.Err() } return nil, fmt.Errorf("poll statement status: %w", err) } if isTerminalState(pollResp.Status) { - sp.Close() - if err := checkFailedState(pollResp.Status); err != nil { - return nil, err - } return &sql.StatementResponse{ StatementId: pollResp.StatementId, Status: pollResp.Status, diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index aa33921c83..59de11d578 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -2,6 +2,7 @@ package aitools import ( "context" + "errors" "os" "path/filepath" "strings" @@ -48,7 +49,9 @@ func TestExecuteAndPollImmediateSuccess(t *testing.T) { mockAPI := mocksql.NewMockStatementExecutionInterface(t) mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { - return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && req.WaitTimeout == "0s" + return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && + req.WaitTimeout == "0s" && + req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue })).Return(&sql.StatementResponse{ StatementId: "stmt-1", Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, @@ -143,8 +146,12 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t 
*testing.T) { Status: &sql.StatementStatus{State: sql.StatementStatePending}, }, nil) - // CancelExecution must be called when context is cancelled (not just on signal). - mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{ + // CancelExecution must be called when context is cancelled (not just on + // signal). Assert the RPC's own ctx is NOT cancelled; otherwise the SDK + // would short-circuit on ctx.Err() and never reach the warehouse. + mockAPI.EXPECT().CancelExecution(mock.MatchedBy(func(c context.Context) bool { + return c.Err() == nil + }), sql.CancelExecutionRequest{ StatementId: "stmt-1", }).Return(nil).Once() @@ -154,6 +161,106 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) { require.ErrorIs(t, err, root.ErrAlreadyPrinted) } +func TestPollStatementImmediateTerminal(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + resp := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "1"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + } + + pollResp, err := pollStatement(ctx, mockAPI, resp) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) + assert.Equal(t, "stmt-1", pollResp.StatementId) +} + +func TestPollStatementTerminalFailureNotErrored(t *testing.T) { + // pollStatement returns the response without erroring on failed terminal + // states; callers (e.g. executeAndPoll) decide what to do via checkFailedState. 
+ ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + resp := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ErrorCode: "ERR", Message: "boom"}, + }, + } + + pollResp, err := pollStatement(ctx, mockAPI, resp) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, pollResp.Status.State) +} + +func TestPollStatementEventualSuccess(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, + }, nil).Once() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) + assert.Equal(t, [][]string{{"42"}}, pollResp.Result.DataArray) +} + +func TestPollStatementContextCancellationDoesNotCancelServerSide(t *testing.T) { + // The mock asserts (via t.Cleanup) that no unexpected calls are made. + // Specifically, pollStatement must NOT call CancelExecution on context + // cancellation; that is the caller's responsibility. 
+ ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + cancel() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.ErrorIs(t, err, context.Canceled) + assert.Nil(t, pollResp) +} + +func TestPollStatementGetErrorPropagated(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1"). + Return(nil, errors.New("network unreachable")).Once() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.Error(t, err) + assert.Contains(t, err.Error(), "poll statement status") + assert.Contains(t, err.Error(), "network unreachable") + assert.Nil(t, pollResp) +} + func TestResolveWarehouseIDWithFlag(t *testing.T) { ctx := t.Context() id, err := resolveWarehouseID(ctx, nil, "explicit-id") @@ -330,69 +437,95 @@ func TestPollingConstants(t *testing.T) { assert.Equal(t, 10*time.Second, cancelTimeout) } -// newTestCmd creates a minimal cobra.Command for testing resolveSQL. +// newTestCmd creates a minimal cobra.Command for testing resolveSQLs. 
func newTestCmd() *cobra.Command { return &cobra.Command{Use: "test"} } -func TestResolveSQLFromFileFlag(t *testing.T) { +func TestResolveSQLsFromFileFlag(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "query.sql") err := os.WriteFile(path, []byte("SELECT 1"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.NoError(t, err) - assert.Equal(t, "SELECT 1", result) + assert.Equal(t, []string{"SELECT 1"}, result) } -func TestResolveSQLFromFileFlagWithComments(t *testing.T) { +func TestResolveSQLsFromFileFlagWithComments(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "query.sql") err := os.WriteFile(path, []byte("-- header comment\nSELECT 1\n-- trailing"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 1"}, result) +} + +func TestResolveSQLsMixedFileAndPositional(t *testing.T) { + // --file paths are emitted before positional args, in flag order. 
+ dir := t.TempDir() + path := filepath.Join(dir, "from-file.sql") + err := os.WriteFile(path, []byte("SELECT 'from file'"), 0o644) + require.NoError(t, err) + + cmd := newTestCmd() + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 'from arg'"}, []string{path}) require.NoError(t, err) - assert.Equal(t, "SELECT 1", result) + assert.Equal(t, []string{"SELECT 'from file'", "SELECT 'from arg'"}, result) } -func TestResolveSQLFileFlagConflictsWithArg(t *testing.T) { +func TestResolveSQLsMultiplePositional(t *testing.T) { cmd := newTestCmd() - _, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1"}, "/some/file.sql") - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot use both --file and a positional SQL argument") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1", "SELECT 2", "SELECT 3"}, nil) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 1", "SELECT 2", "SELECT 3"}, result) +} + +func TestResolveSQLsMultipleFiles(t *testing.T) { + dir := t.TempDir() + pathA := filepath.Join(dir, "a.sql") + pathB := filepath.Join(dir, "b.sql") + require.NoError(t, os.WriteFile(pathA, []byte("SELECT 'a'"), 0o644)) + require.NoError(t, os.WriteFile(pathB, []byte("SELECT 'b'"), 0o644)) + + cmd := newTestCmd() + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{pathA, pathB}) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 'a'", "SELECT 'b'"}, result) } -func TestResolveSQLFromPositionalArg(t *testing.T) { +func TestResolveSQLsFromPositionalArg(t *testing.T) { cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 42"}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 42"}, nil) require.NoError(t, err) - assert.Equal(t, "SELECT 42", result) + assert.Equal(t, []string{"SELECT 42"}, result) } -func TestResolveSQLAutoDetectsSQLFile(t *testing.T) { 
+func TestResolveSQLsAutoDetectsSQLFile(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "report.sql") err := os.WriteFile(path, []byte("SELECT * FROM sales"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{path}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{path}, nil) require.NoError(t, err) - assert.Equal(t, "SELECT * FROM sales", result) + assert.Equal(t, []string{"SELECT * FROM sales"}, result) } -func TestResolveSQLNonexistentSQLFileTreatedAsString(t *testing.T) { +func TestResolveSQLsNonexistentSQLFileTreatedAsString(t *testing.T) { cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"nonexistent.sql"}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"nonexistent.sql"}, nil) require.NoError(t, err) - assert.Equal(t, "nonexistent.sql", result) + assert.Equal(t, []string{"nonexistent.sql"}, result) } -func TestResolveSQLUnreadableSQLFileReturnsError(t *testing.T) { +func TestResolveSQLsUnreadableSQLFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "locked.sql") err := os.WriteFile(path, []byte("SELECT 1"), 0o644) @@ -404,47 +537,54 @@ func TestResolveSQLUnreadableSQLFileReturnsError(t *testing.T) { t.Cleanup(func() { _ = os.Chmod(path, 0o644) }) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{path}, "") + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{path}, nil) require.Error(t, err) assert.Contains(t, err.Error(), "read SQL file") } -func TestResolveSQLFromStdin(t *testing.T) { +func TestResolveSQLsFromStdin(t *testing.T) { cmd := newTestCmd() cmd.SetIn(strings.NewReader("SELECT 1 FROM stdin_test")) - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, nil) require.NoError(t, err) - 
assert.Equal(t, "SELECT 1 FROM stdin_test", result) + assert.Equal(t, []string{"SELECT 1 FROM stdin_test"}, result) } -func TestResolveSQLEmptyFileReturnsError(t *testing.T) { +func TestResolveSQLsEmptyFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "empty.sql") err := os.WriteFile(path, []byte(""), 0o644) require.NoError(t, err) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.Error(t, err) assert.Contains(t, err.Error(), "empty") } -func TestResolveSQLCommentsOnlyFileReturnsError(t *testing.T) { +func TestResolveSQLsCommentsOnlyFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "comments.sql") err := os.WriteFile(path, []byte("-- just a comment\n-- another"), 0o644) require.NoError(t, err) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.Error(t, err) assert.Contains(t, err.Error(), "empty") } -func TestResolveSQLMissingFileReturnsError(t *testing.T) { +func TestResolveSQLsBatchEmptyAtIndexReturnsIndexedError(t *testing.T) { cmd := newTestCmd() - _, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, "/nonexistent/path/query.sql") + _, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1", "-- comment only", "SELECT 3"}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "SQL statement #2 is empty") +} + +func TestResolveSQLsMissingFileReturnsError(t *testing.T) { + cmd := newTestCmd() + _, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{"/nonexistent/path/query.sql"}) require.Error(t, err) assert.Contains(t, err.Error(), "read SQL file") } @@ -458,6 +598,34 @@ func TestQueryCommandUnsupportedOutputReturnsError(t *testing.T) { assert.Contains(t, err.Error(), "unsupported output format") 
} +func TestQueryCommandBatchOutputRejection(t *testing.T) { + // Multi-query mode is JSON-only. text and csv are rejected with an + // actionable error before any API call. + for _, format := range []string{"text", "csv"} { + t.Run(format, func(t *testing.T) { + cmd := newQueryCmd() + cmd.PreRunE = nil + cmd.SetArgs([]string{"--output", format, "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple queries require --output json") + }) + } +} + +func TestQueryCommandConcurrencyRejection(t *testing.T) { + // errgroup.SetLimit(0) deadlocks; negative removes the cap entirely. + // Both surprise users, so PreRunE rejects anything <= 0. + for _, value := range []string{"0", "-1"} { + t.Run(value, func(t *testing.T) { + cmd := newQueryCmd() + cmd.SetArgs([]string{"--concurrency", value, "--output", "json", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) + }) + } +} + func TestQueryCommandOutputFlagIsCaseInsensitive(t *testing.T) { cmd := newQueryCmd() cmd.PreRunE = nil diff --git a/experimental/aitools/cmd/render.go b/experimental/aitools/cmd/render.go index 7727c37106..d0b62926c2 100644 --- a/experimental/aitools/cmd/render.go +++ b/experimental/aitools/cmd/render.go @@ -29,6 +29,17 @@ func extractColumns(manifest *sql.ResultManifest) []string { return columns } +// renderBatchJSON writes batch results as a JSON array. The array preserves +// input order and includes one object per submitted statement. +func renderBatchJSON(w io.Writer, results []batchResult) error { + output, err := json.MarshalIndent(results, "", " ") + if err != nil { + return fmt.Errorf("marshal batch results: %w", err) + } + fmt.Fprintf(w, "%s\n", output) + return nil +} + // renderJSON writes query results as a parseable JSON array to stdout. // Row count is written to stderr so stdout remains valid JSON for piping. 
func renderJSON(w io.Writer, columns []string, rows [][]string) error {