diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fbd3147..9fc306af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `Cancel` and `CancelTx` to the `Client` to enable cancellation of jobs. [PR #141](https://github.com/riverqueue/river/pull/141). - Added `ClientFromContext` and `ClientWithContextSafely` helpers to extract the `Client` from the worker's context where it is now available to workers. This simplifies making the River client available within your workers for i.e. enqueueing additional jobs. [PR #145](https://github.com/riverqueue/river/pull/145). ## [0.0.16] - 2024-01-06 diff --git a/client.go b/client.go index ecae76e2..6e35d5fb 100644 --- a/client.go +++ b/client.go @@ -312,6 +312,11 @@ func (ts *clientTestSignals) Init() { } var ( + // ErrNotFound is returned when a query by ID does not match any existing + // rows. For example, attempting to cancel a job that doesn't exist will + // return this error. + ErrNotFound = errors.New("not found") + errMissingConfig = errors.New("missing config") errMissingDatabasePoolWithQueues = errors.New("must have a non-nil database pool to execute jobs (either use a driver with database pool or don't configure Queues)") errMissingDriver = errors.New("missing database driver (try wrapping a Pgx pool with river/riverdriver/riverpgxv5.New)") @@ -935,6 +940,106 @@ func (c *Client[TTx]) runProducers(fetchNewWorkCtx, workCtx context.Context) { } } +// Cancel cancels the job with the given ID. If possible, the job is cancelled +// immediately and will not be retried. The provided context is used for the +// underlying Postgres update and can be used to cancel the operation or apply a +// timeout. +// +// If the job is still in the queue (available, scheduled, or retryable), it is +// immediately marked as cancelled and will not be retried. 
+// +// If the job is already finalized (cancelled, completed, or discarded), no +// changes are made. +// +// If the job is currently running, it is not immediately cancelled, but is +// instead marked for cancellation. The client running the job will also be +// notified (via LISTEN/NOTIFY) to cancel the running job's context. Although +// the job's context will be cancelled, since Go does not provide a mechanism to +// interrupt a running goroutine the job will continue running until it returns. +// As always, it is important for workers to respect context cancellation and +// return promptly when the job context is done. +// +// Once the cancellation signal is received by the client running the job, any +// error returned by that job will result in it being cancelled permanently and +// not retried. However if the job returns no error, it will be completed as +// usual. +// +// In the event the running job finishes executing _before_ the cancellation +// signal is received but _after_ this update was made, the behavior depends on +// which state the job is being transitioned into (based on its return error): +// +// - If the job completed successfully, was cancelled from within, or was +// discarded due to exceeding its max attempts, the job will be updated as +// usual. +// - If the job was snoozed to run again later or encountered a retryable error, +// the job will be marked as cancelled and will not be attempted again. +// +// Returns the up-to-date JobRow for the specified jobID if it exists. Returns +// ErrNotFound if the job doesn't exist. +func (c *Client[TTx]) Cancel(ctx context.Context, jobID int64) (*rivertype.JobRow, error) { + job, err := c.adapter.JobCancel(ctx, jobID) + if errors.Is(err, riverdriver.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, err + } + + return dbsqlc.JobRowFromInternal(job), nil +} + +// CancelTx cancels the job with the given ID within the specified transaction. 
+// This variant lets a caller cancel a job atomically alongside other database +// changes. A cancellation doesn't take effect until the transaction commits, +// and if the transaction rolls back, so too is the cancellation. +// +// If possible, the job is cancelled immediately and will not be retried. The +// provided context is used for the underlying Postgres update and can be used +// to cancel the operation or apply a timeout. +// +// If the job is still in the queue (available, scheduled, or retryable), it is +// immediately marked as cancelled and will not be retried. +// +// If the job is already finalized (cancelled, completed, or discarded), no +// changes are made. +// +// If the job is currently running, it is not immediately cancelled, but is +// instead marked for cancellation. The client running the job will also be +// notified (via LISTEN/NOTIFY) to cancel the running job's context. Although +// the job's context will be cancelled, since Go does not provide a mechanism to +// interrupt a running goroutine the job will continue running until it returns. +// As always, it is important for workers to respect context cancellation and +// return promptly when the job context is done. +// +// Once the cancellation signal is received by the client running the job, any +// error returned by that job will result in it being cancelled permanently and +// not retried. However if the job returns no error, it will be completed as +// usual. +// +// In the event the running job finishes executing _before_ the cancellation +// signal is received but _after_ this update was made, the behavior depends on +// which state the job is being transitioned into (based on its return error): +// +// - If the job completed successfully, was cancelled from within, or was +// discarded due to exceeding its max attempts, the job will be updated as +// usual. 
+// - If the job was snoozed to run again later or encountered a retryable error, +// the job will be marked as cancelled and will not be attempted again. +// +// Returns the up-to-date JobRow for the specified jobID if it exists. Returns +// ErrNotFound if the job doesn't exist. +func (c *Client[TTx]) CancelTx(ctx context.Context, tx TTx, jobID int64) (*rivertype.JobRow, error) { + job, err := c.adapter.JobCancelTx(ctx, c.driver.UnwrapTx(tx), jobID) + if errors.Is(err, riverdriver.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, err + } + + return dbsqlc.JobRowFromInternal(job), nil +} + func insertParamsFromArgsAndOptions(args JobArgs, insertOpts *InsertOpts) (*dbadapter.JobInsertParams, error) { encodedArgs, err := json.Marshal(args) if err != nil { diff --git a/client_test.go b/client_test.go index 98b31fd9..d54ff1c2 100644 --- a/client_test.go +++ b/client_test.go @@ -216,7 +216,7 @@ func Test_Client(t *testing.T) { riverinternaltest.WaitOrTimeout(t, workedChan) }) - t.Run("JobCancel", func(t *testing.T) { + t.Run("JobCancelErrorReturned", func(t *testing.T) { t.Parallel() client, bundle := setup(t) @@ -245,7 +245,7 @@ func Test_Client(t *testing.T) { require.WithinDuration(t, time.Now(), *updatedJob.FinalizedAt, 2*time.Second) }) - t.Run("JobSnooze", func(t *testing.T) { + t.Run("JobSnoozeErrorReturned", func(t *testing.T) { t.Parallel() client, bundle := setup(t) @@ -274,6 +274,130 @@ func Test_Client(t *testing.T) { require.WithinDuration(t, time.Now().Add(15*time.Minute), updatedJob.ScheduledAt, 2*time.Second) }) + // This helper is used to test cancelling a job both _in_ a transaction and + // _outside of_ a transaction. The exact same test logic applies to each case, + // the only difference is a different cancelFunc provided by the specific + // subtest. 
+ cancelRunningJobTestHelper := func(t *testing.T, cancelFunc func(ctx context.Context, client *Client[pgx.Tx], jobID int64) (*rivertype.JobRow, error)) { //nolint:thelper + client, bundle := setup(t) + + jobStartedChan := make(chan int64) + + type JobArgs struct { + JobArgsReflectKind[JobArgs] + } + + AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error { + jobStartedChan <- job.ID + <-ctx.Done() + return ctx.Err() + })) + + statusUpdateCh := client.monitor.RegisterUpdates() + startClient(ctx, t, client) + waitForClientHealthy(ctx, t, statusUpdateCh) + + insertedJob, err := client.Insert(ctx, &JobArgs{}, nil) + require.NoError(t, err) + + startedJobID := riverinternaltest.WaitOrTimeout(t, jobStartedChan) + require.Equal(t, insertedJob.ID, startedJobID) + + // Cancel the job: + updatedJob, err := cancelFunc(ctx, client, insertedJob.ID) + require.NoError(t, err) + require.NotNil(t, updatedJob) + // Job is still actively running at this point because the query wouldn't + // modify that column for a running job: + require.Equal(t, rivertype.JobStateRunning, updatedJob.State) + + event := riverinternaltest.WaitOrTimeout(t, bundle.subscribeChan) + require.Equal(t, EventKindJobCancelled, event.Kind) + require.Equal(t, JobStateCancelled, event.Job.State) + require.WithinDuration(t, time.Now(), *event.Job.FinalizedAt, 2*time.Second) + + jobAfterCancel, err := bundle.queries.JobGetByID(ctx, client.driver.GetDBPool(), insertedJob.ID) + require.NoError(t, err) + require.Equal(t, dbsqlc.JobStateCancelled, jobAfterCancel.State) + require.WithinDuration(t, time.Now(), *jobAfterCancel.FinalizedAt, 2*time.Second) + } + + t.Run("CancelRunningJob", func(t *testing.T) { + t.Parallel() + + cancelRunningJobTestHelper(t, func(ctx context.Context, client *Client[pgx.Tx], jobID int64) (*rivertype.JobRow, error) { + return client.Cancel(ctx, jobID) + }) + }) + + t.Run("CancelRunningJobInTx", func(t *testing.T) { + t.Parallel() + + 
cancelRunningJobTestHelper(t, func(ctx context.Context, client *Client[pgx.Tx], jobID int64) (*rivertype.JobRow, error) { + var ( + job *rivertype.JobRow + err error + ) + txErr := pgx.BeginFunc(ctx, client.driver.GetDBPool(), func(tx pgx.Tx) error { + job, err = client.CancelTx(ctx, tx, jobID) + return err + }) + require.NoError(t, txErr) + return job, err + }) + }) + + t.Run("CancelScheduledJob", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + + jobStartedChan := make(chan int64) + + type JobArgs struct { + JobArgsReflectKind[JobArgs] + } + + AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error { + jobStartedChan <- job.ID + <-ctx.Done() + return ctx.Err() + })) + + startClient(ctx, t, client) + + insertedJob, err := client.Insert(ctx, &JobArgs{}, &InsertOpts{ScheduledAt: time.Now().Add(5 * time.Minute)}) + require.NoError(t, err) + + // Cancel the job: + updatedJob, err := client.Cancel(ctx, insertedJob.ID) + require.NoError(t, err) + require.NotNil(t, updatedJob) + require.Equal(t, rivertype.JobStateCancelled, updatedJob.State) + require.WithinDuration(t, time.Now(), *updatedJob.FinalizedAt, 2*time.Second) + }) + + t.Run("CancelNonExistentJob", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + startClient(ctx, t, client) + + // Cancel an unknown job ID: + jobAfter, err := client.Cancel(ctx, 0) + require.ErrorIs(t, err, ErrNotFound) + require.Nil(t, jobAfter) + + // Cancel an unknown job ID, within a transaction: + err = pgx.BeginFunc(ctx, client.driver.GetDBPool(), func(tx pgx.Tx) error { + jobAfter, err := client.CancelTx(ctx, tx, 0) + require.ErrorIs(t, err, ErrNotFound) + require.Nil(t, jobAfter) + return nil + }) + require.NoError(t, err) + }) + t.Run("AlternateSchema", func(t *testing.T) { t.Parallel() diff --git a/example_cancel_from_client_test.go b/example_cancel_from_client_test.go new file mode 100644 index 00000000..7dfa2658 --- /dev/null +++ b/example_cancel_from_client_test.go @@ 
-0,0 +1,108 @@ +package river_test + +import ( + "context" + "errors" + "log/slog" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/internal/util/slogutil" + "github.com/riverqueue/river/riverdriver/riverpgxv5" +) + +// SleepingArgs are the (empty) job args handled by SleepingWorker. +type SleepingArgs struct{} + +// Kind returns the job kind string for SleepingArgs. +func (args SleepingArgs) Kind() string { return "SleepingWorker" } + +// SleepingWorker sends each job's ID on jobChan and then blocks until the job's +// context is done or five seconds elapse. +type SleepingWorker struct { + river.WorkerDefaults[SleepingArgs] + jobChan chan int64 +} + +// Work sends the job's ID on jobChan, then waits. It returns ctx.Err() once the +// context is done, or a timeout error if five seconds pass first. +func (w *SleepingWorker) Work(ctx context.Context, job *river.Job[SleepingArgs]) error { + w.jobChan <- job.ID + select { + case <-ctx.Done(): + case <-time.After(5 * time.Second): + return errors.New("sleeping worker timed out") + } + return ctx.Err() +} + +// Example_cancelJobFromClient demonstrates how to permanently cancel a job from +// any Client using Cancel. +func Example_cancelJobFromClient() { + ctx := context.Background() + + dbPool, err := pgxpool.NewWithConfig(ctx, riverinternaltest.DatabaseConfig("river_testdb_example")) + if err != nil { + panic(err) + } + defer dbPool.Close() + + // Required for the purpose of this test, but not necessary in real usage. + if err := riverinternaltest.TruncateRiverTables(ctx, dbPool); err != nil { + panic(err) + } + + jobChan := make(chan int64) + + workers := river.NewWorkers() + river.AddWorker(workers, &SleepingWorker{jobChan: jobChan}) + + riverClient, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{ + Logger: slog.New(&slogutil.SlogMessageOnlyHandler{Level: slog.LevelWarn}), + Queues: map[string]river.QueueConfig{ + river.QueueDefault: {MaxWorkers: 10}, + }, + Workers: workers, + }) + if err != nil { + panic(err) + } + + // Not strictly needed, but used to help this test wait until job is worked. 
+ subscribeChan, subscribeCancel := riverClient.Subscribe(river.EventKindJobCancelled) + defer subscribeCancel() + + if err := riverClient.Start(ctx); err != nil { + panic(err) + } + job, err := riverClient.Insert(ctx, SleepingArgs{}, nil) + if err != nil { + panic(err) + } + select { + case <-jobChan: + case <-time.After(2 * time.Second): + panic("no jobChan signal received") + } + + // There is presently no way to wait for the client to be 100% ready, so we + // sleep for a bit to give it time to start up. This is only needed in this + // example because we need the notifier to be ready for it to receive the + // cancellation signal. + time.Sleep(500 * time.Millisecond) + + if _, err = riverClient.Cancel(ctx, job.ID); err != nil { + panic(err) + } + waitForNJobs(subscribeChan, 1) + + if err := riverClient.Stop(ctx); err != nil { + panic(err) + } + + // Output: + // jobExecutor: job cancelled remotely +} diff --git a/internal/dbadapter/db_adapter.go b/internal/dbadapter/db_adapter.go index dd299eab..78d39a88 100644 --- a/internal/dbadapter/db_adapter.go +++ b/internal/dbadapter/db_adapter.go @@ -19,6 +19,7 @@ import ( "github.com/riverqueue/river/internal/util/ptrutil" "github.com/riverqueue/river/internal/util/sliceutil" "github.com/riverqueue/river/internal/util/valutil" + "github.com/riverqueue/river/riverdriver" ) // When a job has specified unique options, but has not set the ByState @@ -82,6 +83,9 @@ type JobInsertResult struct { // expedience, but this should be converted to a more stable API if Adapter // would be exported. 
type Adapter interface { + JobCancel(ctx context.Context, id int64) (*dbsqlc.RiverJob, error) + JobCancelTx(ctx context.Context, tx pgx.Tx, id int64) (*dbsqlc.RiverJob, error) + JobInsert(ctx context.Context, params *JobInsertParams) (*JobInsertResult, error) JobInsertTx(ctx context.Context, tx pgx.Tx, params *JobInsertParams) (*JobInsertResult, error) @@ -154,6 +158,42 @@ func NewStandardAdapter(archetype *baseservice.Archetype, config *StandardAdapte }) } +// JobCancel cancels the job with the given ID by running JobCancelTx through +// dbutil.WithTxV on the adapter's executor. +func (a *StandardAdapter) JobCancel(ctx context.Context, id int64) (*dbsqlc.RiverJob, error) { + return dbutil.WithTxV(ctx, a.executor, func(ctx context.Context, tx pgx.Tx) (*dbsqlc.RiverJob, error) { + return a.JobCancelTx(ctx, tx, id) + }) +} + +// JobCancelTx cancels the job with the given ID on the supplied transaction. +// Returns riverdriver.ErrNoRows when the query yields pgx.ErrNoRows. +func (a *StandardAdapter) JobCancelTx(ctx context.Context, tx pgx.Tx, id int64) (*dbsqlc.RiverJob, error) { + ctx, cancel := context.WithTimeout(ctx, a.deadlineTimeout) + defer cancel() + + cancelledAt, err := a.TimeNowUTC().MarshalJSON() + if err != nil { + return nil, err + } + + // Run the query on the caller's tx rather than a.executor so the cancellation + // commits or rolls back together with that transaction. + job, err := a.queries.JobCancel(ctx, tx, dbsqlc.JobCancelParams{ + CancelAttemptedAt: cancelledAt, + ID: id, + JobControlTopic: string(notifier.NotificationTopicJobControl), + }) + if errors.Is(err, pgx.ErrNoRows) { + return nil, riverdriver.ErrNoRows + } + if err != nil { + return nil, err + } + + return job, nil +} + func (a *StandardAdapter) JobInsert(ctx context.Context, params *JobInsertParams) (*JobInsertResult, error) { return dbutil.WithTxV(ctx, a.executor, func(ctx context.Context, tx pgx.Tx) (*JobInsertResult, error) { return a.JobInsertTx(ctx, tx, params) diff --git a/internal/dbadapter/db_adapter_test.go b/internal/dbadapter/db_adapter_test.go index 1bf58ad7..8a35e5bd 100644 --- a/internal/dbadapter/db_adapter_test.go +++ b/internal/dbadapter/db_adapter_test.go @@ -18,8 +18,130 @@ import ( "github.com/riverqueue/river/internal/riverinternaltest" "github.com/riverqueue/river/internal/util/dbutil" "github.com/riverqueue/river/internal/util/ptrutil" + 
"github.com/riverqueue/river/riverdriver" ) +func Test_StandardAdapter_JobCancel(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + baselineTime time.Time // baseline time frozen at now when setup is called + ex dbutil.Executor + } + + setup := func(t *testing.T, ex dbutil.Executor) (*StandardAdapter, *testBundle) { + t.Helper() + + bundle := &testBundle{ + baselineTime: time.Now().UTC(), + ex: ex, + } + + adapter := NewStandardAdapter(riverinternaltest.BaseServiceArchetype(t), testAdapterConfig(bundle.ex)) + adapter.TimeNowUTC = func() time.Time { return bundle.baselineTime } + + return adapter, bundle + } + + setupTx := func(t *testing.T) (*StandardAdapter, *testBundle) { + t.Helper() + return setup(t, riverinternaltest.TestTx(ctx, t)) + } + + for _, startingState := range []dbsqlc.JobState{ + dbsqlc.JobStateAvailable, + dbsqlc.JobStateRetryable, + dbsqlc.JobStateScheduled, + } { + startingState := startingState + + t.Run(fmt.Sprintf("CancelsJobIn%sState", startingState), func(t *testing.T) { + t.Parallel() + + adapter, bundle := setupTx(t) + timeNowString := bundle.baselineTime.Format(time.RFC3339Nano) + + params := makeFakeJobInsertParams(0, nil) + params.State = startingState + insertResult, err := adapter.JobInsert(ctx, params) + require.NoError(t, err) + require.Equal(t, startingState, insertResult.Job.State) + + jobAfter, err := adapter.JobCancel(ctx, insertResult.Job.ID) + require.NoError(t, err) + require.NotNil(t, jobAfter) + + require.Equal(t, dbsqlc.JobStateCancelled, jobAfter.State) + require.WithinDuration(t, time.Now(), *jobAfter.FinalizedAt, 2*time.Second) + require.JSONEq(t, fmt.Sprintf(`{"cancel_attempted_at":%q}`, timeNowString), string(jobAfter.Metadata)) + }) + } + + t.Run("RunningJobIsNotImmediatelyCancelled", func(t *testing.T) { + t.Parallel() + + adapter, bundle := setupTx(t) + timeNowString := bundle.baselineTime.Format(time.RFC3339Nano) + + params := makeFakeJobInsertParams(0, nil) + 
params.State = dbsqlc.JobStateRunning + insertResult, err := adapter.JobInsert(ctx, params) + require.NoError(t, err) + require.Equal(t, dbsqlc.JobStateRunning, insertResult.Job.State) + + jobAfter, err := adapter.JobCancel(ctx, insertResult.Job.ID) + require.NoError(t, err) + require.NotNil(t, jobAfter) + require.Equal(t, dbsqlc.JobStateRunning, jobAfter.State) + require.Nil(t, jobAfter.FinalizedAt) + require.JSONEq(t, fmt.Sprintf(`{"cancel_attempted_at":%q}`, timeNowString), string(jobAfter.Metadata)) + }) + + for _, startingState := range []dbsqlc.JobState{ + dbsqlc.JobStateCancelled, + dbsqlc.JobStateCompleted, + dbsqlc.JobStateDiscarded, + } { + startingState := startingState + + t.Run(fmt.Sprintf("DoesNotAlterFinalizedJobIn%sState", startingState), func(t *testing.T) { + t.Parallel() + adapter, bundle := setupTx(t) + + params := makeFakeJobInsertParams(0, nil) + initialRes, err := adapter.JobInsert(ctx, params) + require.NoError(t, err) + + res, err := adapter.queries.JobUpdate(ctx, bundle.ex, dbsqlc.JobUpdateParams{ + ID: initialRes.Job.ID, + FinalizedAtDoUpdate: true, + FinalizedAt: ptrutil.Ptr(time.Now()), + StateDoUpdate: true, + State: startingState, + }) + require.NoError(t, err) + + jobAfter, err := adapter.JobCancel(ctx, res.ID) + require.NoError(t, err) + require.Equal(t, startingState, jobAfter.State) + require.WithinDuration(t, *res.FinalizedAt, *jobAfter.FinalizedAt, time.Microsecond) + require.JSONEq(t, `{}`, string(jobAfter.Metadata)) + }) + } + + t.Run("ReturnsErrNoRowsIfJobDoesNotExist", func(t *testing.T) { + t.Parallel() + + adapter, _ := setupTx(t) + + jobAfter, err := adapter.JobCancel(ctx, 1234567890) + require.ErrorIs(t, err, riverdriver.ErrNoRows) + require.Nil(t, jobAfter) + }) +} + func Test_StandardAdapter_JobGetAvailable(t *testing.T) { t.Parallel() @@ -765,6 +887,44 @@ func Test_StandardAdapter_JobSetStateErrored(t *testing.T) { require.Equal(t, dbsqlc.JobStateRetryable, j.State) require.WithinDuration(t, params.ScheduledAt, 
jAfter.ScheduledAt, time.Microsecond) }) + + t.Run("SetsAJobWithCancelAttemptedAtToCancelled", func(t *testing.T) { + // If a job has cancel_attempted_at in its metadata, it means that the user + // tried to cancel the job with the Cancel API but that the job + // finished/errored before the producer received the cancel notification. + // + // In this case, we want to move the job to cancelled instead of retryable + // so that the job is not retried. + t.Parallel() + + adapter, bundle := setupTx(t) + + params := makeFakeJobInsertParams(0, &makeFakeJobInsertParamsOpts{ + ScheduledAt: ptrutil.Ptr(bundle.baselineTime.Add(-10 * time.Second)), + }) + params.State = dbsqlc.JobStateRunning + params.Metadata = []byte(fmt.Sprintf(`{"cancel_attempted_at":"%s"}`, time.Now().UTC().Format(time.RFC3339))) + res, err := adapter.JobInsert(ctx, params) + require.NoError(t, err) + + jAfter, err := adapter.JobSetStateIfRunning(ctx, JobSetStateErrorRetryable(res.Job.ID, bundle.baselineTime, bundle.errPayload)) + require.NoError(t, err) + require.Equal(t, dbsqlc.JobStateCancelled, jAfter.State) + require.NotNil(t, jAfter.FinalizedAt) + // Loose assertion against FinalizedAt just to make sure it was set (it uses + // the database's now() instead of a passed-in time): + require.WithinDuration(t, time.Now().UTC(), *jAfter.FinalizedAt, 2*time.Second) + // ScheduledAt should not be touched: + require.WithinDuration(t, params.ScheduledAt, jAfter.ScheduledAt, time.Microsecond) + // Errors should still be appended to: + require.Len(t, jAfter.Errors, 1) + require.Contains(t, jAfter.Errors[0].Error, "fake error") + + j, err := adapter.queries.JobGetByID(ctx, bundle.ex, res.Job.ID) + require.NoError(t, err) + require.Equal(t, dbsqlc.JobStateCancelled, j.State) + require.WithinDuration(t, params.ScheduledAt, j.ScheduledAt, time.Microsecond) + }) } func Test_StandardAdapter_LeadershipAttemptElect_CannotElectTwiceInARow(t *testing.T) { diff --git a/internal/dbadaptertest/test_adapter.go 
b/internal/dbadaptertest/test_adapter.go index 1a7e40e4..6b24c0ba 100644 --- a/internal/dbadaptertest/test_adapter.go +++ b/internal/dbadaptertest/test_adapter.go @@ -18,6 +18,8 @@ type TestAdapter struct { fallthroughAdapter dbadapter.Adapter mu sync.Mutex + JobCancelCalled bool + JobCancelTxCalled bool JobInsertCalled bool JobInsertTxCalled bool JobInsertManyCalled bool @@ -28,6 +30,8 @@ type TestAdapter struct { LeadershipAttemptElectCalled bool LeadershipResignedCalled bool + JobCancelFunc func(ctx context.Context, id int64) (*dbsqlc.RiverJob, error) + JobCancelTxFunc func(ctx context.Context, tx pgx.Tx, id int64) (*dbsqlc.RiverJob, error) JobInsertFunc func(ctx context.Context, params *dbadapter.JobInsertParams) (*dbadapter.JobInsertResult, error) JobInsertTxFunc func(ctx context.Context, tx pgx.Tx, params *dbadapter.JobInsertParams) (*dbadapter.JobInsertResult, error) JobInsertManyFunc func(ctx context.Context, params []*dbadapter.JobInsertParams) (int64, error) @@ -39,6 +43,30 @@ type TestAdapter struct { LeadershipResignFunc func(ctx context.Context, name string, leaderID string) error } +// JobCancel sets JobCancelCalled, then delegates to JobCancelFunc when one is +// set and to the fallthrough adapter's JobCancel otherwise. +func (ta *TestAdapter) JobCancel(ctx context.Context, id int64) (*dbsqlc.RiverJob, error) { + ta.atomicSetBoolTrue(&ta.JobCancelCalled) + + if ta.JobCancelFunc != nil { + return ta.JobCancelFunc(ctx, id) + } + + return ta.fallthroughAdapter.JobCancel(ctx, id) +} + +// JobCancelTx sets JobCancelTxCalled, then delegates to JobCancelTxFunc when +// one is set and to the fallthrough adapter's JobCancelTx (same tx) otherwise. +func (ta *TestAdapter) JobCancelTx(ctx context.Context, tx pgx.Tx, id int64) (*dbsqlc.RiverJob, error) { + ta.atomicSetBoolTrue(&ta.JobCancelTxCalled) + + if ta.JobCancelTxFunc != nil { + return ta.JobCancelTxFunc(ctx, tx, id) + } + + return ta.fallthroughAdapter.JobCancelTx(ctx, tx, id) +} + func (ta *TestAdapter) JobInsert(ctx context.Context, params *dbadapter.JobInsertParams) (*dbadapter.JobInsertResult, error) { ta.atomicSetBoolTrue(&ta.JobInsertCalled) diff --git a/internal/dbsqlc/river_job.sql b/internal/dbsqlc/river_job.sql index 92f82fc7..dae9d241 100644 --- a/internal/dbsqlc/river_job.sql +++ 
b/internal/dbsqlc/river_job.sql @@ -31,6 +31,51 @@ CREATE TABLE river_job( CONSTRAINT kind_length CHECK (char_length(kind) > 0 AND char_length(kind) < 128) ); +-- name: JobCancel :one +WITH locked_job AS ( + SELECT + id, queue, state, finalized_at + FROM river_job + WHERE + river_job.id = @id + FOR UPDATE +), + +notification AS ( + SELECT + id, + pg_notify(@job_control_topic, json_build_object('action', 'cancel', 'job_id', id, 'queue', queue)::text) + FROM + locked_job + WHERE + state NOT IN ('cancelled', 'completed', 'discarded') + AND finalized_at IS NULL +), + +updated_job AS ( + UPDATE river_job + SET + -- If the job is actively running, we want to let its current client and + -- producer handle the cancellation. Otherwise, immediately cancel it. + state = CASE WHEN state = 'running'::river_job_state THEN state ELSE 'cancelled'::river_job_state END, + finalized_at = CASE WHEN state = 'running'::river_job_state THEN finalized_at ELSE now() END, + -- Mark the job as cancelled by query so that the rescuer knows not to + -- rescue it, even if it gets stuck in the running state: + metadata = jsonb_set(metadata, '{cancel_attempted_at}'::text[], @cancel_attempted_at::jsonb, true) + FROM notification + WHERE + river_job.id = notification.id + RETURNING river_job.* +) + +SELECT * +FROM river_job +WHERE id = @id::bigint + AND id NOT IN (SELECT id FROM updated_job) +UNION +SELECT * +FROM updated_job; + -- name: JobCountRunning :one SELECT count(*) @@ -240,18 +285,27 @@ RETURNING *; -- name: JobSetStateIfRunning :one WITH job_to_update AS ( - SELECT id + SELECT + id, + @state::river_job_state IN ('retryable'::river_job_state, 'scheduled'::river_job_state) AND metadata ? 
'cancel_attempted_at' AS should_cancel FROM river_job WHERE id = @id::bigint FOR UPDATE ), updated_job AS ( UPDATE river_job - SET errors = CASE WHEN @error_do_update::boolean THEN array_append(errors, @error::jsonb) ELSE errors END, - finalized_at = CASE WHEN @finalized_at_do_update::boolean THEN @finalized_at ELSE finalized_at END, - max_attempts = CASE WHEN @max_attempts_update::boolean THEN @max_attempts ELSE max_attempts END, - scheduled_at = CASE WHEN @scheduled_at_do_update::boolean THEN @scheduled_at ELSE scheduled_at END, - state = @state + SET + state = CASE WHEN should_cancel THEN 'cancelled'::river_job_state + ELSE @state::river_job_state END, + finalized_at = CASE WHEN should_cancel THEN now() + WHEN @finalized_at_do_update::boolean THEN @finalized_at + ELSE finalized_at END, + errors = CASE WHEN @error_do_update::boolean THEN array_append(errors, @error::jsonb) + ELSE errors END, + max_attempts = CASE WHEN NOT should_cancel AND @max_attempts_update::boolean THEN @max_attempts + ELSE max_attempts END, + scheduled_at = CASE WHEN NOT should_cancel AND @scheduled_at_do_update::boolean THEN @scheduled_at + ELSE scheduled_at END FROM job_to_update WHERE river_job.id = job_to_update.id AND river_job.state = 'running'::river_job_state @@ -290,6 +344,7 @@ UPDATE river_job SET attempt = CASE WHEN @attempt_do_update::boolean THEN @attempt ELSE attempt END, attempted_at = CASE WHEN @attempted_at_do_update::boolean THEN @attempted_at ELSE attempted_at END, + finalized_at = CASE WHEN @finalized_at_do_update::boolean THEN @finalized_at ELSE finalized_at END, state = CASE WHEN @state_do_update::boolean THEN @state ELSE state END WHERE id = @id RETURNING *; diff --git a/internal/dbsqlc/river_job.sql.go b/internal/dbsqlc/river_job.sql.go index 9a8f12fd..f8064d45 100644 --- a/internal/dbsqlc/river_job.sql.go +++ b/internal/dbsqlc/river_job.sql.go @@ -10,6 +10,82 @@ import ( "time" ) +const jobCancel = `-- name: JobCancel :one +WITH locked_job AS ( + SELECT + id, queue, 
state, finalized_at + FROM river_job + WHERE + river_job.id = $1 + FOR UPDATE +), + +notification AS ( + SELECT + id, + pg_notify($2, json_build_object('action', 'cancel', 'job_id', id, 'queue', queue)::text) + FROM + locked_job + WHERE + state NOT IN ('cancelled', 'completed', 'discarded') + AND finalized_at IS NULL +), + +updated_job AS ( + UPDATE river_job + SET + -- If the job is actively running, we want to let its current client and + -- producer handle the cancellation. Otherwise, immediately cancel it. + state = CASE WHEN state = 'running'::river_job_state THEN state ELSE 'cancelled'::river_job_state END, + finalized_at = CASE WHEN state = 'running'::river_job_state THEN finalized_at ELSE now() END, + -- Mark the job as cancelled by query so that the rescuer knows not to + -- rescue it, even if it gets stuck in the running state: + metadata = jsonb_set(metadata, '{cancel_attempted_at}'::text[], $3::jsonb, true) + FROM notification + WHERE + river_job.id = notification.id + RETURNING river_job.id, river_job.args, river_job.attempt, river_job.attempted_at, river_job.attempted_by, river_job.created_at, river_job.errors, river_job.finalized_at, river_job.kind, river_job.max_attempts, river_job.metadata, river_job.priority, river_job.queue, river_job.state, river_job.scheduled_at, river_job.tags +) + +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags +FROM river_job +WHERE id = $1::bigint + AND id NOT IN (SELECT id FROM updated_job) +UNION +SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags +FROM updated_job +` + +type JobCancelParams struct { + ID int64 + JobControlTopic string + CancelAttemptedAt []byte +} + +func (q *Queries) JobCancel(ctx context.Context, db DBTX, arg JobCancelParams) (*RiverJob, error) { + row := db.QueryRow(ctx, jobCancel, 
arg.ID, arg.JobControlTopic, arg.CancelAttemptedAt) + var i RiverJob + err := row.Scan( + &i.ID, + &i.Args, + &i.Attempt, + &i.AttemptedAt, + &i.AttemptedBy, + &i.CreatedAt, + &i.Errors, + &i.FinalizedAt, + &i.Kind, + &i.MaxAttempts, + &i.Metadata, + &i.Priority, + &i.Queue, + &i.State, + &i.ScheduledAt, + &i.Tags, + ) + return &i, err +} + const jobCountRunning = `-- name: JobCountRunning :one SELECT count(*) @@ -668,18 +744,27 @@ func (q *Queries) JobSetState(ctx context.Context, db DBTX, arg JobSetStateParam const jobSetStateIfRunning = `-- name: JobSetStateIfRunning :one WITH job_to_update AS ( - SELECT id + SELECT + id, + $1::river_job_state IN ('retryable'::river_job_state, 'scheduled'::river_job_state) AND metadata ? 'cancel_attempted_at' AS should_cancel FROM river_job - WHERE id = $1::bigint + WHERE id = $2::bigint FOR UPDATE ), updated_job AS ( UPDATE river_job - SET errors = CASE WHEN $2::boolean THEN array_append(errors, $3::jsonb) ELSE errors END, - finalized_at = CASE WHEN $4::boolean THEN $5 ELSE finalized_at END, - max_attempts = CASE WHEN $6::boolean THEN $7 ELSE max_attempts END, - scheduled_at = CASE WHEN $8::boolean THEN $9 ELSE scheduled_at END, - state = $10 + SET + state = CASE WHEN should_cancel THEN 'cancelled'::river_job_state + ELSE $1::river_job_state END, + finalized_at = CASE WHEN should_cancel THEN now() + WHEN $3::boolean THEN $4 + ELSE finalized_at END, + errors = CASE WHEN $5::boolean THEN array_append(errors, $6::jsonb) + ELSE errors END, + max_attempts = CASE WHEN NOT should_cancel AND $7::boolean THEN $8 + ELSE max_attempts END, + scheduled_at = CASE WHEN NOT should_cancel AND $9::boolean THEN $10 + ELSE scheduled_at END FROM job_to_update WHERE river_job.id = job_to_update.id AND river_job.state = 'running'::river_job_state @@ -687,7 +772,7 @@ updated_job AS ( ) SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags FROM 
river_job -WHERE id = $1::bigint +WHERE id = $2::bigint AND id NOT IN (SELECT id FROM updated_job) UNION SELECT id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags @@ -695,30 +780,30 @@ FROM updated_job ` type JobSetStateIfRunningParams struct { + State JobState ID int64 - ErrorDoUpdate bool - Error []byte FinalizedAtDoUpdate bool FinalizedAt *time.Time + ErrorDoUpdate bool + Error []byte MaxAttemptsUpdate bool MaxAttempts int16 ScheduledAtDoUpdate bool ScheduledAt time.Time - State JobState } func (q *Queries) JobSetStateIfRunning(ctx context.Context, db DBTX, arg JobSetStateIfRunningParams) (*RiverJob, error) { row := db.QueryRow(ctx, jobSetStateIfRunning, + arg.State, arg.ID, - arg.ErrorDoUpdate, - arg.Error, arg.FinalizedAtDoUpdate, arg.FinalizedAt, + arg.ErrorDoUpdate, + arg.Error, arg.MaxAttemptsUpdate, arg.MaxAttempts, arg.ScheduledAtDoUpdate, arg.ScheduledAt, - arg.State, ) var i RiverJob err := row.Scan( @@ -747,8 +832,9 @@ UPDATE river_job SET attempt = CASE WHEN $1::boolean THEN $2 ELSE attempt END, attempted_at = CASE WHEN $3::boolean THEN $4 ELSE attempted_at END, - state = CASE WHEN $5::boolean THEN $6 ELSE state END -WHERE id = $7 + finalized_at = CASE WHEN $5::boolean THEN $6 ELSE finalized_at END, + state = CASE WHEN $7::boolean THEN $8 ELSE state END +WHERE id = $9 RETURNING id, args, attempt, attempted_at, attempted_by, created_at, errors, finalized_at, kind, max_attempts, metadata, priority, queue, state, scheduled_at, tags ` @@ -757,6 +843,8 @@ type JobUpdateParams struct { Attempt int16 AttemptedAtDoUpdate bool AttemptedAt *time.Time + FinalizedAtDoUpdate bool + FinalizedAt *time.Time StateDoUpdate bool State JobState ID int64 @@ -770,6 +858,8 @@ func (q *Queries) JobUpdate(ctx context.Context, db DBTX, arg JobUpdateParams) ( arg.Attempt, arg.AttemptedAtDoUpdate, arg.AttemptedAt, + arg.FinalizedAtDoUpdate, + arg.FinalizedAt, arg.StateDoUpdate, 
arg.State, arg.ID, diff --git a/internal/maintenance/rescuer.go b/internal/maintenance/rescuer.go index aaef7d24..d9a5af7d 100644 --- a/internal/maintenance/rescuer.go +++ b/internal/maintenance/rescuer.go @@ -146,10 +146,15 @@ func (s *Rescuer) Start(ctx context.Context) error { } type rescuerRunOnceResult struct { + NumJobsCancelled int64 NumJobsDiscarded int64 NumJobsRetried int64 } +type metadataWithCancelAttemptedAt struct { + CancelAttemptedAt time.Time `json:"cancel_attempted_at"` +} + func (s *Rescuer) runOnce(ctx context.Context) (*rescuerRunOnceResult, error) { res := &rescuerRunOnceResult{} @@ -174,6 +179,11 @@ func (s *Rescuer) runOnce(ctx context.Context) (*rescuerRunOnceResult, error) { for i, job := range stuckJobs { rescueManyParams.ID[i] = job.ID + var metadata metadataWithCancelAttemptedAt + if err := json.Unmarshal(job.Metadata, &metadata); err != nil { + return nil, fmt.Errorf("error unmarshaling job metadata: %w", err) + } + rescueManyParams.Error[i], err = json.Marshal(rivertype.AttemptError{ At: now, Attempt: max(int(job.Attempt), 0), @@ -184,6 +194,13 @@ func (s *Rescuer) runOnce(ctx context.Context) (*rescuerRunOnceResult, error) { return nil, fmt.Errorf("error marshaling error JSON: %w", err) } + if !metadata.CancelAttemptedAt.IsZero() { + res.NumJobsCancelled++ + rescueManyParams.FinalizedAt[i] = now + rescueManyParams.ScheduledAt[i] = job.ScheduledAt // reuse previous value + rescueManyParams.State[i] = string(dbsqlc.JobStateCancelled) + continue + } shouldRetry, retryAt := s.makeRetryDecision(ctx, job) if shouldRetry { res.NumJobsRetried++ diff --git a/internal/maintenance/rescuer_test.go b/internal/maintenance/rescuer_test.go index 2c7947b1..e3028074 100644 --- a/internal/maintenance/rescuer_test.go +++ b/internal/maintenance/rescuer_test.go @@ -2,6 +2,7 @@ package maintenance import ( "context" + "fmt" "math" "testing" "time" @@ -65,6 +66,7 @@ func TestRescuer(t *testing.T) { Attempt int16 AttemptedAt *time.Time MaxAttempts int16 + 
Metadata []byte State dbsqlc.JobState } @@ -75,6 +77,7 @@ func TestRescuer(t *testing.T) { Args: []byte("{}"), Kind: rescuerJobKind, MaxAttempts: 5, + Metadata: params.Metadata, Priority: int16(rivercommon.PriorityDefault), Queue: rivercommon.QueueDefault, State: params.State, @@ -151,9 +154,15 @@ func TestRescuer(t *testing.T) { stuckToDiscardJob1 := insertJob(ctx, bundle.tx, insertJobParams{State: dbsqlc.JobStateRunning, Attempt: 5, AttemptedAt: ptrutil.Ptr(bundle.rescueHorizon.Add(-1 * time.Hour))}) stuckToDiscardJob2 := insertJob(ctx, bundle.tx, insertJobParams{State: dbsqlc.JobStateRunning, Attempt: 5, AttemptedAt: ptrutil.Ptr(bundle.rescueHorizon.Add(1 * time.Minute))}) // won't be rescued + // Marked as cancelled by query: + cancelTime := time.Now().UTC().Format(time.RFC3339Nano) + stuckToCancelJob1 := insertJob(ctx, bundle.tx, insertJobParams{State: dbsqlc.JobStateRunning, AttemptedAt: ptrutil.Ptr(bundle.rescueHorizon.Add(-1 * time.Hour)), Metadata: []byte(fmt.Sprintf(`{"cancel_attempted_at": %q}`, cancelTime))}) + stuckToCancelJob2 := insertJob(ctx, bundle.tx, insertJobParams{State: dbsqlc.JobStateRunning, AttemptedAt: ptrutil.Ptr(bundle.rescueHorizon.Add(1 * time.Minute)), Metadata: []byte(fmt.Sprintf(`{"cancel_attempted_at": %q}`, cancelTime))}) // won't be rescued + // these aren't touched: notRunningJob1 := insertJob(ctx, bundle.tx, insertJobParams{State: dbsqlc.JobStateCompleted, AttemptedAt: ptrutil.Ptr(bundle.rescueHorizon.Add(-1 * time.Hour))}) notRunningJob2 := insertJob(ctx, bundle.tx, insertJobParams{State: dbsqlc.JobStateDiscarded, AttemptedAt: ptrutil.Ptr(bundle.rescueHorizon.Add(-1 * time.Hour))}) + notRunningJob3 := insertJob(ctx, bundle.tx, insertJobParams{State: dbsqlc.JobStateCancelled, AttemptedAt: ptrutil.Ptr(bundle.rescueHorizon.Add(-1 * time.Hour))}) require.NoError(cleaner.Start(ctx)) @@ -184,12 +193,26 @@ func TestRescuer(t *testing.T) { require.Equal(dbsqlc.JobStateRunning, discard2After.State) 
require.Nil(discard2After.FinalizedAt) + cancel1After, err := queries.JobGetByID(ctx, bundle.tx, stuckToCancelJob1.ID) + require.NoError(err) + require.Equal(dbsqlc.JobStateCancelled, cancel1After.State) + require.WithinDuration(time.Now(), *cancel1After.FinalizedAt, 5*time.Second) + require.Len(cancel1After.Errors, 1) + + cancel2After, err := queries.JobGetByID(ctx, bundle.tx, stuckToCancelJob2.ID) + require.NoError(err) + require.Equal(dbsqlc.JobStateRunning, cancel2After.State) + require.Nil(cancel2After.FinalizedAt) + notRunning1After, err := queries.JobGetByID(ctx, bundle.tx, notRunningJob1.ID) require.NoError(err) require.Equal(notRunning1After.State, notRunningJob1.State) notRunning2After, err := queries.JobGetByID(ctx, bundle.tx, notRunningJob2.ID) require.NoError(err) require.Equal(notRunning2After.State, notRunningJob2.State) + notRunning3After, err := queries.JobGetByID(ctx, bundle.tx, notRunningJob3.ID) + require.NoError(err) + require.Equal(notRunning3After.State, notRunningJob3.State) }) t.Run("RescuesInBatches", func(t *testing.T) { diff --git a/internal/notifier/notifier.go b/internal/notifier/notifier.go index e95603b5..8b842737 100644 --- a/internal/notifier/notifier.go +++ b/internal/notifier/notifier.go @@ -23,6 +23,7 @@ type NotificationTopic string const ( NotificationTopicInsert NotificationTopic = "river_insert" NotificationTopicLeadership NotificationTopic = "river_leadership" + NotificationTopicJobControl NotificationTopic = "river_job_control" ) type NotifyFunc func(topic NotificationTopic, payload string) diff --git a/job_executor.go b/job_executor.go index d873d4a8..3f2bbd23 100644 --- a/job_executor.go +++ b/job_executor.go @@ -91,6 +91,8 @@ func (e *jobSnoozeError) Is(target error) bool { return ok } +var ErrJobCancelledRemotely = JobCancel(errors.New("job cancelled remotely")) + type jobExecutorResult struct { Err error NextRetry time.Time @@ -116,6 +118,7 @@ type jobExecutor struct { baseservice.BaseService Adapter dbadapter.Adapter 
+ CancelFunc context.CancelCauseFunc ClientJobTimeout time.Duration Completer jobcompleter.JobCompleter ClientRetryPolicy ClientRetryPolicy @@ -130,6 +133,11 @@ type jobExecutor struct { stats *jobstats.JobStatistics // initialized by the executor, and handed off to completer } +func (e *jobExecutor) Cancel() { + e.Logger.Warn(e.Name+": job cancelled remotely", slog.Int64("job_id", e.JobRow.ID)) + e.CancelFunc(ErrJobCancelledRemotely) +} + func (e *jobExecutor) Execute(ctx context.Context) { e.start = e.TimeNowUTC() e.stats = &jobstats.JobStatistics{ @@ -137,6 +145,9 @@ func (e *jobExecutor) Execute(ctx context.Context) { } res := e.execute(ctx) + if res.Err != nil && errors.Is(context.Cause(ctx), ErrJobCancelledRemotely) { + res.Err = context.Cause(ctx) + } e.reportResult(ctx, res) diff --git a/job_executor_test.go b/job_executor_test.go index 69b00b5c..b01ca605 100644 --- a/job_executor_test.go +++ b/job_executor_test.go @@ -606,6 +606,62 @@ func TestJobExecutor_Execute(t *testing.T) { require.True(t, bundle.errorHandler.HandlePanicCalled) }) + + runCancelTest := func(t *testing.T, returnErr error) *dbsqlc.RiverJob { //nolint:thelper + executor, bundle := setup(t) + + // ensure we still have remaining attempts: + require.Greater(t, bundle.jobRow.MaxAttempts, bundle.jobRow.Attempt) + + jobStarted := make(chan struct{}) + haveCancelled := make(chan struct{}) + executor.WorkUnit = newWorkUnitFactoryWithCustomRetry(func() error { + close(jobStarted) + <-haveCancelled + return returnErr + }, nil).MakeUnit(bundle.jobRow) + + go func() { + <-jobStarted + executor.Cancel() + close(haveCancelled) + }() + + workCtx, cancelFunc := context.WithCancelCause(ctx) + executor.CancelFunc = cancelFunc + + executor.Execute(workCtx) + executor.Completer.Wait() + + job, err := queries.JobGetByID(ctx, bundle.tx, bundle.jobRow.ID) + require.NoError(t, err) + return job + } + + t.Run("RemoteCancellationViaCancel", func(t *testing.T) { + t.Parallel() + + job := runCancelTest(t, 
errors.New("a non-nil error")) + + require.WithinDuration(t, time.Now(), *job.FinalizedAt, 2*time.Second) + require.Equal(t, dbsqlc.JobStateCancelled, job.State) + require.Len(t, job.Errors, 1) + require.WithinDuration(t, time.Now(), job.Errors[0].At, 2*time.Second) + require.Equal(t, uint16(1), job.Errors[0].Attempt) + require.Equal(t, "jobCancelError: job cancelled remotely", job.Errors[0].Error) + require.Equal(t, ErrJobCancelledRemotely.Error(), job.Errors[0].Error) + require.Equal(t, "", job.Errors[0].Trace) + }) + + t.Run("RemoteCancellationJobNotCancelledIfNoErrorReturned", func(t *testing.T) { + t.Parallel() + + job := runCancelTest(t, nil) + + require.WithinDuration(t, time.Now(), *job.FinalizedAt, 2*time.Second) + require.Equal(t, dbsqlc.JobStateCompleted, job.State) + require.Empty(t, job.Errors) + }) } func TestUnknownJobKindError_As(t *testing.T) { diff --git a/producer.go b/producer.go index 41ac946c..c79b51dc 100644 --- a/producer.go +++ b/producer.go @@ -63,6 +63,10 @@ type producer struct { errorHandler ErrorHandler workers *Workers + // Receives job IDs to cancel. Written by notifier goroutine, only read from + // main goroutine. + cancelCh chan int64 + // Receives completed jobs from workers. Written by completed workers, only // read from main goroutine. jobResultCh chan *rivertype.JobRow @@ -119,6 +123,7 @@ func newProducer(archetype *baseservice.Archetype, adapter dbadapter.Adapter, co return baseservice.Init(archetype, &producer{ activeJobs: make(map[int64]*jobExecutor), adapter: adapter, + cancelCh: make(chan int64, 1000), completer: completer, config: config, errorHandler: config.ErrorHandler, @@ -150,12 +155,47 @@ func (p *producer) Run(fetchCtx, workCtx context.Context, statusFunc producerSta // TODO: fetcher should have some jitter in it to avoid stampeding issues. 
fetchLimiter := chanutil.NewDebouncedChan(fetchCtx, p.config.FetchCooldown) + handleJobControlNotification := func(topic notifier.NotificationTopic, payload string) { + var decoded jobControlPayload + if err := json.Unmarshal([]byte(payload), &decoded); err != nil { + p.Logger.ErrorContext(workCtx, p.Name+": Failed to unmarshal job control notification payload", slog.String("err", err.Error())) + return + } + if string(decoded.Action) == string(jobControlActionCancel) && decoded.Queue == p.config.QueueName && decoded.JobID > 0 { + select { + case p.cancelCh <- decoded.JobID: + default: + p.Logger.WarnContext(workCtx, p.Name+": Job cancel notification dropped due to full buffer", slog.Int64("job_id", decoded.JobID)) + } + return + } + p.Logger.DebugContext(workCtx, p.Name+": Received job control notification with unknown action or other queue", + slog.String("action", string(decoded.Action)), + slog.Int64("job_id", decoded.JobID), + slog.String("queue", decoded.Queue), + ) + } + sub := p.config.Notifier.Listen(notifier.NotificationTopicJobControl, handleJobControlNotification) + defer sub.Unlisten() + p.fetchAndRunLoop(fetchCtx, workCtx, fetchLimiter, statusFunc) statusFunc(p.config.QueueName, componentstatus.ShuttingDown) p.executorShutdownLoop() statusFunc(p.config.QueueName, componentstatus.Stopped) } +type jobControlAction string + +const ( + jobControlActionCancel jobControlAction = "cancel" +) + +type jobControlPayload struct { + Action jobControlAction `json:"action"` + JobID int64 `json:"job_id"` + Queue string `json:"queue"` +} + type insertPayload struct { Queue string `json:"queue"` } @@ -237,6 +277,8 @@ func (p *producer) innerFetchLoop(workCtx context.Context, fetchResultCh chan pr return case result := <-p.jobResultCh: p.removeActiveJob(result.ID) + case jobID := <-p.cancelCh: + p.maybeCancelJob(jobID) } } } @@ -264,6 +306,14 @@ func (p *producer) removeActiveJob(id int64) { p.numJobsRan.Add(1) } +func (p *producer) maybeCancelJob(id int64) { + 
executor, ok := p.activeJobs[id] + if !ok { + return + } + executor.Cancel() +} + func (p *producer) dispatchWork(count int32, jobsFetchedCh chan<- producerFetchResult) { // This intentionally uses a background context because we don't want it to // get cancelled if the producer is asked to shut down. In that situation, we @@ -308,8 +358,11 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype. workUnit = workInfo.workUnitFactory.MakeUnit(job) } + jobCtx, jobCancel := context.WithCancelCause(workCtx) + executor := baseservice.Init(&p.Archetype, &jobExecutor{ Adapter: p.adapter, + CancelFunc: jobCancel, ClientJobTimeout: p.jobTimeout, ClientRetryPolicy: p.retryPolicy, Completer: p.completer, @@ -321,7 +374,7 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype. }) p.addActiveJob(job.ID, executor) - go executor.Execute(workCtx) + go executor.Execute(jobCtx) // TODO: // Errors can be recorded synchronously before the Executor slot is considered // available.