From 79fc08074e05ebeb45212ffc8ad1a1332a87b35e Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 15:04:52 +0200 Subject: [PATCH] aitools: extract pollStatement helper and pin OnWaitTimeout Refactor `executeAndPoll` in `experimental/aitools/cmd/query.go` to extract a pure `pollStatement(ctx, api, resp)` helper. The helper polls until the statement reaches a terminal state and returns the response without any signal handling, spinner, or server-side cancellation; those concerns stay in `executeAndPoll` where they belong. Also pin `OnWaitTimeout: CONTINUE` explicitly on the `ExecuteStatement` call. The SDK default happens to be CONTINUE today, but relying on it is a hidden coupling: a server-side default flip would silently break the poll loop by killing the statement before our first GET. Behavior is unchanged for the existing `query` command. Follow-up PRs (parallel batch queries, statement lifecycle command tree) will reuse the helper. Co-authored-by: Isaac --- experimental/aitools/cmd/query.go | 63 ++++++++++----- experimental/aitools/cmd/query_test.go | 105 ++++++++++++++++++++++++- 2 files changed, 146 insertions(+), 22 deletions(-) diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 7b95fdd4e23..6c125bbcd6b 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -262,9 +262,10 @@ func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, e func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) { // Submit asynchronously to get the statement ID immediately for cancellation. resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: statement, - WaitTimeout: "0s", + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, }) if err != nil { return nil, fmt.Errorf("execute statement: %w", err) @@ -272,11 +273,6 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa statementID := resp.StatementId - // Check if it completed immediately. - if isTerminalState(resp.Status) { - return resp, checkFailedState(resp.Status) - } - // Set up Ctrl+C: signal cancels the poll context, cleanup is unified below. pollCtx, pollCancel := context.WithCancel(ctx) defer pollCancel() @@ -327,34 +323,59 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa } }() + pollResp, err := pollStatement(pollCtx, api, resp) + if err != nil { + if pollCtx.Err() != nil { + cancelStatement() + cmdio.LogString(ctx, "Query cancelled.") + return nil, root.ErrAlreadyPrinted + } + return nil, err + } + + sp.Close() + if err := checkFailedState(pollResp.Status); err != nil { + return nil, err + } + return pollResp, nil +} + +// pollStatement polls until the statement reaches a terminal state. +// +// On context cancellation it returns the context error WITHOUT cancelling the +// server-side statement. Callers that want server-side cancellation should +// invoke CancelExecution explicitly. +// +// If the input response is already in a terminal state, it is returned without +// further polling. +func pollStatement(ctx context.Context, api sql.StatementExecutionInterface, resp *sql.StatementResponse) (*sql.StatementResponse, error) { + if isTerminalState(resp.Status) { + return resp, nil + } + + statementID := resp.StatementId + start := time.Now() + // Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped). interval := pollIntervalInitial for { select { - case <-pollCtx.Done(): - cancelStatement() - cmdio.LogString(ctx, "Query cancelled.") - return nil, root.ErrAlreadyPrinted + case <-ctx.Done(): + return nil, ctx.Err() case <-time.After(interval): } log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, time.Since(start).Truncate(time.Second)) - pollResp, err := api.GetStatementByStatementId(pollCtx, statementID) + pollResp, err := api.GetStatementByStatementId(ctx, statementID) if err != nil { - if pollCtx.Err() != nil { - cancelStatement() - cmdio.LogString(ctx, "Query cancelled.") - return nil, root.ErrAlreadyPrinted + if ctx.Err() != nil { + return nil, ctx.Err() } return nil, fmt.Errorf("poll statement status: %w", err) } if isTerminalState(pollResp.Status) { - sp.Close() - if err := checkFailedState(pollResp.Status); err != nil { - return nil, err - } return &sql.StatementResponse{ StatementId: pollResp.StatementId, Status: pollResp.Status, diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index aa33921c83b..4bc06c1d63b 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -2,6 +2,7 @@ package aitools import ( "context" + "errors" "os" "path/filepath" "strings" @@ -48,7 +49,9 @@ func TestExecuteAndPollImmediateSuccess(t *testing.T) { mockAPI := mocksql.NewMockStatementExecutionInterface(t) mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { - return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && req.WaitTimeout == "0s" + return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && + req.WaitTimeout == "0s" && + req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue })).Return(&sql.StatementResponse{ StatementId: "stmt-1", Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, @@ -154,6 +157,106 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) { require.ErrorIs(t, err, root.ErrAlreadyPrinted) } +func TestPollStatementImmediateTerminal(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + resp := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "1"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + } + + pollResp, err := pollStatement(ctx, mockAPI, resp) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) + assert.Equal(t, "stmt-1", pollResp.StatementId) +} + +func TestPollStatementTerminalFailureNotErrored(t *testing.T) { + // pollStatement returns the response without erroring on failed terminal + // states; callers (e.g. executeAndPoll) decide what to do via checkFailedState. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + resp := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ErrorCode: "ERR", Message: "boom"}, + }, + } + + pollResp, err := pollStatement(ctx, mockAPI, resp) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, pollResp.Status.State) +} + +func TestPollStatementEventualSuccess(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, + }, nil).Once() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) + assert.Equal(t, [][]string{{"42"}}, pollResp.Result.DataArray) +} + +func TestPollStatementContextCancellationDoesNotCancelServerSide(t *testing.T) { + // The mock asserts (via t.Cleanup) that no unexpected calls are made. + // Specifically, pollStatement must NOT call CancelExecution on context + // cancellation; that is the caller's responsibility. + ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + cancel() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.ErrorIs(t, err, context.Canceled) + assert.Nil(t, pollResp) +} + +func TestPollStatementGetErrorPropagated(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1"). + Return(nil, errors.New("network unreachable")).Once() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.Error(t, err) + assert.Contains(t, err.Error(), "poll statement status") + assert.Contains(t, err.Error(), "network unreachable") + assert.Nil(t, pollResp) +} + func TestResolveWarehouseIDWithFlag(t *testing.T) { ctx := t.Context() id, err := resolveWarehouseID(ctx, nil, "explicit-id")