From 38ae6c1f4e3eebf2cfcfceffd69ee3f212adc3e0 Mon Sep 17 00:00:00 2001 From: Blake Gentry Date: Thu, 30 May 2024 20:11:46 -0500 Subject: [PATCH 1/2] extract Client subscriptions into service This change extracts the Client subscriptions logic into a separate `startstop.Service` which can be started and stopped along with the other services. The important change that enables this is switching from a _callback_ for job events to a _channel_ for job events. The channel is passed to the completer during init, and the completer then owns it as the sole sender. When the completer is stopped, it must close the channel to indicate that there are no more job completion events to be processed. This moves us closer to having all the key client services be able to be managed as a single pool of services, and they can all have their shutdown initiated in parallel. Importantly, this paves the way for additional services to be added (even by external libraries) without needing to deal with more complex startup & shutdown ordering scenarios. In order to make this work with a client that can be started and stopped repeatedly, a new `ResetSubscribeChan` method was added to the `JobCompleter` interface to be called at the beginning of each `Client.Start()` call. 
--- client.go | 184 ++------------- client_test.go | 2 +- internal/jobcompleter/job_completer.go | 145 +++++++----- internal/jobcompleter/job_completer_test.go | 218 +++++++++--------- job_executor_test.go | 87 ++++--- producer_test.go | 59 +++-- riverdriver/river_driver_interface.go | 2 +- subscription_manager.go | 242 ++++++++++++++++++++ subscription_manager_test.go | 126 ++++++++++ 9 files changed, 665 insertions(+), 400 deletions(-) create mode 100644 subscription_manager.go create mode 100644 subscription_manager_test.go diff --git a/client.go b/client.go index 4b95cd58..700298b8 100644 --- a/client.go +++ b/client.go @@ -9,7 +9,6 @@ import ( "os" "regexp" "strings" - "sync" "time" "github.com/riverqueue/river/internal/baseservice" @@ -17,7 +16,6 @@ import ( "github.com/riverqueue/river/internal/dblist" "github.com/riverqueue/river/internal/dbunique" "github.com/riverqueue/river/internal/jobcompleter" - "github.com/riverqueue/river/internal/jobstats" "github.com/riverqueue/river/internal/leadership" "github.com/riverqueue/river/internal/maintenance" "github.com/riverqueue/river/internal/maintenance/startstop" @@ -304,6 +302,7 @@ type Client[TTx any] struct { baseStartStop startstop.BaseStartStop completer jobcompleter.JobCompleter + completerSubscribeCh chan []jobcompleter.CompleterJobUpdated config *Config driver riverdriver.Driver[TTx] elector *leadership.Elector @@ -314,12 +313,7 @@ type Client[TTx any] struct { producersByQueueName map[string]*producer queueMaintainer *maintenance.QueueMaintainer services []startstop.Service - subscriptions map[int]*eventSubscription - subscriptionsMu sync.Mutex - subscriptionsSeq int // used for generating simple IDs - statsAggregate jobstats.JobStatistics - statsMu sync.Mutex - statsNumJobs int + subscriptionManager *subscriptionManager stopped chan struct{} testSignals clientTestSignals uniqueInserter *dbunique.UniqueInserter @@ -471,7 +465,6 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) 
(*Client driver: driver, monitor: newClientMonitor(), producersByQueueName: make(map[string]*producer), - subscriptions: make(map[int]*eventSubscription), testSignals: clientTestSignals{}, uniqueInserter: baseservice.Init(archetype, &dbunique.UniqueInserter{ AdvisoryLockPrefix: config.AdvisoryLockPrefix, @@ -490,8 +483,9 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client return nil, errMissingDatabasePoolWithQueues } - client.completer = jobcompleter.NewBatchCompleter(archetype, driver.GetExecutor()) - client.services = append(client.services, client.completer) + client.completer = jobcompleter.NewBatchCompleter(archetype, driver.GetExecutor(), nil) + client.subscriptionManager = newSubscriptionManager(archetype, nil) + client.services = append(client.services, client.completer, client.subscriptionManager) // In poll only mode, we don't try to initialize a notifier that uses // listen/notify. Instead, each service polls for changes it's @@ -517,7 +511,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client MaxWorkers: queueConfig.MaxWorkers, Notifier: client.notifier, Queue: queue, - QueueEventCallback: client.distributeQueueEvent, + QueueEventCallback: client.subscriptionManager.distributeQueueEvent, RetryPolicy: config.RetryPolicy, SchedulerInterval: config.schedulerInterval, StatusFunc: client.monitor.SetProducerStatus, @@ -666,6 +660,13 @@ func (c *Client[TTx]) Start(ctx context.Context) error { return fmt.Errorf("error making initial connection to database: %w", err) } + // Each time we start, we need a fresh completer subscribe channel to + // send job completion events on, because the completer will close it + // each time it shuts down. 
+ c.completerSubscribeCh = make(chan []jobcompleter.CompleterJobUpdated, 10) + c.completer.ResetSubscribeChan(c.completerSubscribeCh) + c.subscriptionManager.ResetSubscribeChan(c.completerSubscribeCh) + // In case of error, stop any services that might have started. This // is safe because even services that were never started will still // tolerate being stopped. @@ -695,10 +696,6 @@ func (c *Client[TTx]) Start(ctx context.Context) error { return err } - // Receives job complete notifications from the completer and - // distributes them to any subscriptions. - c.completer.Subscribe(c.distributeJobCompleterCallback) - // We use separate contexts for fetching and working to allow for a graceful // stop. Both inherit from the provided context, so if it's cancelled, a // more aggressive stop will be initiated. @@ -758,17 +755,6 @@ func (c *Client[TTx]) Start(ctx context.Context) error { c.queueMaintainer, )) - // Remove all subscriptions and close corresponding channels. - func() { - c.subscriptionsMu.Lock() - defer c.subscriptionsMu.Unlock() - - for subID, sub := range c.subscriptions { - close(sub.Chan) - delete(c.subscriptions, subID) - } - }() - // Shut down the monitor last so it can broadcast final status updates: c.monitor.Stop() }() @@ -881,125 +867,7 @@ type SubscribeConfig struct { // Special internal variant that lets us inject an overridden size. func (c *Client[TTx]) SubscribeConfig(config *SubscribeConfig) (<-chan *Event, func()) { - if config.ChanSize < 0 { - panic("SubscribeConfig.ChanSize must be greater or equal to 1") - } - if config.ChanSize == 0 { - config.ChanSize = subscribeChanSizeDefault - } - - for _, kind := range config.Kinds { - if _, ok := allKinds[kind]; !ok { - panic(fmt.Errorf("unknown event kind: %s", kind)) - } - } - - c.subscriptionsMu.Lock() - defer c.subscriptionsMu.Unlock() - - subChan := make(chan *Event, config.ChanSize) - - // Just gives us an easy way of removing the subscription again later. 
- subID := c.subscriptionsSeq - c.subscriptionsSeq++ - - c.subscriptions[subID] = &eventSubscription{ - Chan: subChan, - Kinds: sliceutil.KeyBy(config.Kinds, func(k EventKind) (EventKind, struct{}) { return k, struct{}{} }), - } - - cancel := func() { - c.subscriptionsMu.Lock() - defer c.subscriptionsMu.Unlock() - - // May no longer be present in case this was called after a stop. - sub, ok := c.subscriptions[subID] - if !ok { - return - } - - close(sub.Chan) - - delete(c.subscriptions, subID) - } - - return subChan, cancel -} - -// Distribute a single event into any listening subscriber channels. -// -// Job events should specify the job and stats, while queue events should only specify -// the queue. -func (c *Client[TTx]) distributeJobEvent(job *rivertype.JobRow, stats *JobStatistics) { - c.subscriptionsMu.Lock() - defer c.subscriptionsMu.Unlock() - - // Quick path so we don't need to allocate anything if no one is listening. - if len(c.subscriptions) < 1 { - return - } - - var event *Event - switch job.State { - case rivertype.JobStateCancelled: - event = &Event{Kind: EventKindJobCancelled, Job: job, JobStats: stats} - case rivertype.JobStateCompleted: - event = &Event{Kind: EventKindJobCompleted, Job: job, JobStats: stats} - case rivertype.JobStateScheduled: - event = &Event{Kind: EventKindJobSnoozed, Job: job, JobStats: stats} - case rivertype.JobStateAvailable, rivertype.JobStateDiscarded, rivertype.JobStateRetryable, rivertype.JobStateRunning: - event = &Event{Kind: EventKindJobFailed, Job: job, JobStats: stats} - case rivertype.JobStatePending: - panic("completion subscriber unexpectedly received job in pending state, river bug") - default: - // linter exhaustive rule prevents this from being reached - panic("unreachable state to distribute, river bug") - } - - // All subscription channels are non-blocking so this is always fast and - // there's no risk of falling behind what producers are sending. 
- for _, sub := range c.subscriptions { - if sub.ListensFor(event.Kind) { - select { - case sub.Chan <- event: - default: - } - } - } -} - -func (c *Client[TTx]) distributeQueueEvent(event *Event) { - c.subscriptionsMu.Lock() - defer c.subscriptionsMu.Unlock() - - // All subscription channels are non-blocking so this is always fast and - // there's no risk of falling behind what producers are sending. - for _, sub := range c.subscriptions { - if sub.ListensFor(event.Kind) { - select { - case sub.Chan <- event: - default: - } - } - } -} - -// Callback invoked by the completer and which prompts the client to update -// statistics and distribute jobs into any listening subscriber channels. -// (Subscriber channels are non-blocking so this should be quite fast.) -func (c *Client[TTx]) distributeJobCompleterCallback(update jobcompleter.CompleterJobUpdated) { - func() { - c.statsMu.Lock() - defer c.statsMu.Unlock() - - stats := update.JobStats - c.statsAggregate.CompleteDuration += stats.CompleteDuration - c.statsAggregate.QueueWaitDuration += stats.QueueWaitDuration - c.statsAggregate.RunDuration += stats.RunDuration - c.statsNumJobs++ - }() - - c.distributeJobEvent(update.Job, jobStatisticsFromInternal(update.JobStats)) + return c.subscriptionManager.SubscribeConfig(config) } // Dump aggregate stats from job completions to logs periodically. These @@ -1007,28 +875,6 @@ func (c *Client[TTx]) distributeJobCompleterCallback(update jobcompleter.Complet // proportions of each compared to each other, and may help flag outlying values // indicative of a problem. func (c *Client[TTx]) logStatsLoop(ctx context.Context, shouldStart bool, stopped chan struct{}) error { - // Handles a potential divide by zero. 
- safeDurationAverage := func(d time.Duration, n int) time.Duration { - if n == 0 { - return 0 - } - return d / time.Duration(n) - } - - logStats := func() { - c.statsMu.Lock() - defer c.statsMu.Unlock() - - c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Job stats (since last stats line)", - "num_jobs_run", c.statsNumJobs, - "average_complete_duration", safeDurationAverage(c.statsAggregate.CompleteDuration, c.statsNumJobs), - "average_queue_wait_duration", safeDurationAverage(c.statsAggregate.QueueWaitDuration, c.statsNumJobs), - "average_run_duration", safeDurationAverage(c.statsAggregate.RunDuration, c.statsNumJobs)) - - c.statsAggregate = jobstats.JobStatistics{} - c.statsNumJobs = 0 - } - if !shouldStart { return nil } @@ -1047,7 +893,7 @@ func (c *Client[TTx]) logStatsLoop(ctx context.Context, shouldStart bool, stoppe return case <-ticker.C: - logStats() + c.subscriptionManager.logStats(ctx, c.baseService.Name) } } }() diff --git a/client_test.go b/client_test.go index 9b8cfc04..754e17bb 100644 --- a/client_test.go +++ b/client_test.go @@ -2957,7 +2957,7 @@ func Test_Client_Subscribe(t *testing.T) { // Drops through immediately because the channel is closed. riverinternaltest.WaitOrTimeout(t, subscribeChan) - require.Empty(t, client.subscriptions) + require.Empty(t, client.subscriptionManager.subscriptions) }) } diff --git a/internal/jobcompleter/job_completer.go b/internal/jobcompleter/job_completer.go index a1c8c6a2..d8b72195 100644 --- a/internal/jobcompleter/job_completer.go +++ b/internal/jobcompleter/job_completer.go @@ -11,6 +11,7 @@ import ( "github.com/riverqueue/river/internal/baseservice" "github.com/riverqueue/river/internal/jobstats" "github.com/riverqueue/river/internal/maintenance/startstop" + "github.com/riverqueue/river/internal/util/sliceutil" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivertype" ) @@ -27,11 +28,16 @@ type JobCompleter interface { // still running (i.e. 
its state has not changed to something else already). JobSetStateIfRunning(ctx context.Context, stats *jobstats.JobStatistics, params *riverdriver.JobSetStateIfRunningParams) error - // Subscribe injects a callback which will be invoked whenever a job is - // updated. - Subscribe(subscribeFunc func(update CompleterJobUpdated)) + // ResetSubscribeChan resets the subscription channel for the completer. It + // must only be called when the completer is stopped. + ResetSubscribeChan(subscribeCh SubscribeChan) } +type SubscribeChan chan<- []CompleterJobUpdated + +// SubscribeFunc will be invoked whenever a job is updated. +type SubscribeFunc func(update CompleterJobUpdated) + type CompleterJobUpdated struct { Job *rivertype.JobRow JobStats *jobstats.JobStatistics @@ -47,10 +53,11 @@ type PartialExecutor interface { type InlineCompleter struct { baseservice.BaseService - withSubscribe + startstop.BaseStartStop disableSleep bool // disable sleep in testing exec PartialExecutor + subscribeCh SubscribeChan // A waitgroup is not actually needed for the inline completer because as // long as the caller is waiting on each function call, completion is @@ -60,9 +67,10 @@ type InlineCompleter struct { wg sync.WaitGroup } -func NewInlineCompleter(archetype *baseservice.Archetype, exec PartialExecutor) *InlineCompleter { +func NewInlineCompleter(archetype *baseservice.Archetype, exec PartialExecutor, subscribeCh SubscribeChan) *InlineCompleter { return baseservice.Init(archetype, &InlineCompleter{ - exec: exec, + exec: exec, + subscribeCh: subscribeCh, }) } @@ -80,15 +88,35 @@ func (c *InlineCompleter) JobSetStateIfRunning(ctx context.Context, stats *jobst } stats.CompleteDuration = c.TimeNowUTC().Sub(start) - c.sendJobToSubscription(job, stats) + c.subscribeCh <- []CompleterJobUpdated{{Job: job, JobStats: stats}} return nil } -func (c *InlineCompleter) Start(ctx context.Context) error { return nil } +func (c *InlineCompleter) ResetSubscribeChan(subscribeCh SubscribeChan) { + 
c.subscribeCh = subscribeCh +} + +func (c *InlineCompleter) Start(ctx context.Context) error { + ctx, shouldStart, stopped := c.StartInit(ctx) + if !shouldStart { + return nil + } + + if c.subscribeCh == nil { + panic("subscribeCh must be non-nil") + } + + go func() { + defer close(stopped) + defer close(c.subscribeCh) -func (c *InlineCompleter) Stop() { - c.wg.Wait() + <-ctx.Done() + + c.wg.Wait() + }() + + return nil } // A default concurrency of 100 seems to perform better a much smaller number @@ -100,19 +128,20 @@ const asyncCompleterDefaultConcurrency = 100 type AsyncCompleter struct { baseservice.BaseService - withSubscribe + startstop.BaseStartStop concurrency int disableSleep bool // disable sleep in testing errGroup *errgroup.Group exec PartialExecutor + subscribeCh SubscribeChan } -func NewAsyncCompleter(archetype *baseservice.Archetype, exec PartialExecutor) *AsyncCompleter { - return newAsyncCompleterWithConcurrency(archetype, exec, asyncCompleterDefaultConcurrency) +func NewAsyncCompleter(archetype *baseservice.Archetype, exec PartialExecutor, subscribeCh SubscribeChan) *AsyncCompleter { + return newAsyncCompleterWithConcurrency(archetype, exec, asyncCompleterDefaultConcurrency, subscribeCh) } -func newAsyncCompleterWithConcurrency(archetype *baseservice.Archetype, exec PartialExecutor, concurrency int) *AsyncCompleter { +func newAsyncCompleterWithConcurrency(archetype *baseservice.Archetype, exec PartialExecutor, concurrency int, subscribeCh SubscribeChan) *AsyncCompleter { errGroup := &errgroup.Group{} errGroup.SetLimit(concurrency) @@ -120,6 +149,7 @@ func newAsyncCompleterWithConcurrency(archetype *baseservice.Archetype, exec Par exec: exec, concurrency: concurrency, errGroup: errGroup, + subscribeCh: subscribeCh, }) } @@ -137,19 +167,39 @@ func (c *AsyncCompleter) JobSetStateIfRunning(ctx context.Context, stats *jobsta } stats.CompleteDuration = c.TimeNowUTC().Sub(start) - c.sendJobToSubscription(job, stats) + c.subscribeCh <- 
[]CompleterJobUpdated{{Job: job, JobStats: stats}} return nil }) return nil } -func (c *AsyncCompleter) Start(ctx context.Context) error { return nil } +func (c *AsyncCompleter) ResetSubscribeChan(subscribeCh SubscribeChan) { + c.subscribeCh = subscribeCh +} -func (c *AsyncCompleter) Stop() { - if err := c.errGroup.Wait(); err != nil { - c.Logger.Error("Error waiting on async completer: %s", err) +func (c *AsyncCompleter) Start(ctx context.Context) error { + ctx, shouldStart, stopped := c.StartInit(ctx) + if !shouldStart { + return nil + } + + if c.subscribeCh == nil { + panic("subscribeCh must be non-nil") } + + go func() { + defer close(stopped) + defer close(c.subscribeCh) + + <-ctx.Done() + + if err := c.errGroup.Wait(); err != nil { + c.Logger.Error("Error waiting on async completer: %s", err) + } + }() + + return nil } type batchCompleterSetState struct { @@ -166,7 +216,6 @@ type batchCompleterSetState struct { type BatchCompleter struct { baseservice.BaseService startstop.BaseStartStop - withSubscribe asyncCompleter *AsyncCompleter // used for non-complete completions completionMaxSize int // configurable for testing purposes; max jobs to complete in single database operation @@ -176,31 +225,46 @@ type BatchCompleter struct { setStateParams map[int64]*batchCompleterSetState setStateParamsMu sync.RWMutex started chan struct{} + subscribeCh SubscribeChan waitOnBacklogChan chan struct{} waitOnBacklogWaiting bool } -func NewBatchCompleter(archetype *baseservice.Archetype, exec PartialExecutor) *BatchCompleter { +func NewBatchCompleter(archetype *baseservice.Archetype, exec PartialExecutor, subscribeCh SubscribeChan) *BatchCompleter { const ( completionMaxSize = 5_000 maxBacklog = 20_000 ) return baseservice.Init(archetype, &BatchCompleter{ - asyncCompleter: NewAsyncCompleter(archetype, exec), + asyncCompleter: NewAsyncCompleter(archetype, exec, subscribeCh), completionMaxSize: completionMaxSize, exec: exec, maxBacklog: maxBacklog, setStateParams: 
make(map[int64]*batchCompleterSetState), + subscribeCh: subscribeCh, }) } +func (c *BatchCompleter) ResetSubscribeChan(subscribeCh SubscribeChan) { + c.subscribeCh = subscribeCh + c.asyncCompleter.subscribeCh = subscribeCh +} + func (c *BatchCompleter) Start(ctx context.Context) error { stopCtx, shouldStart, stopped := c.StartInit(ctx) if !shouldStart { return nil } + if c.subscribeCh == nil { + panic("subscribeCh must be non-nil") + } + + if err := c.asyncCompleter.Start(ctx); err != nil { + return err + } + c.started = make(chan struct{}) go func() { @@ -347,11 +411,13 @@ func (c *BatchCompleter) handleBatch(ctx context.Context) error { } } - for _, jobRow := range jobRows { + events := sliceutil.Map(jobRows, func(jobRow *rivertype.JobRow) CompleterJobUpdated { setState := setStateBatch[jobRow.ID] setState.Stats.CompleteDuration = c.TimeNowUTC().Sub(*setState.Params.FinalizedAt) - c.sendJobToSubscription(jobRow, setState.Stats) - } + return CompleterJobUpdated{Job: jobRow, JobStats: setState.Stats} + }) + + c.subscribeCh <- events func() { c.setStateParamsMu.Lock() @@ -391,11 +457,7 @@ func (c *BatchCompleter) JobSetStateIfRunning(ctx context.Context, stats *jobsta func (c *BatchCompleter) Stop() { c.BaseStartStop.Stop() c.asyncCompleter.Stop() -} - -func (c *BatchCompleter) Subscribe(subscribeFunc func(update CompleterJobUpdated)) { - c.withSubscribe.Subscribe(subscribeFunc) - c.asyncCompleter.Subscribe(subscribeFunc) + // subscribeCh already closed by asyncCompleter.Stop ^ } func (c *BatchCompleter) WaitStarted() <-chan struct{} { @@ -496,29 +558,6 @@ func withRetries[T any](logCtx context.Context, baseService *baseservice.BaseSer return defaultVal, lastErr } -// Utility struct embedded in completers to give them an easy way to provide a -// Subscribe function and to handle locking around its use. 
-type withSubscribe struct { - subscribeFunc func(update CompleterJobUpdated) - subscribeFuncMu sync.RWMutex -} - -func (c *withSubscribe) Subscribe(subscribeFunc func(update CompleterJobUpdated)) { - c.subscribeFuncMu.Lock() - defer c.subscribeFuncMu.Unlock() - - c.subscribeFunc = subscribeFunc -} - -func (c *withSubscribe) sendJobToSubscription(job *rivertype.JobRow, stats *jobstats.JobStatistics) { - c.subscribeFuncMu.RLock() - defer c.subscribeFuncMu.RUnlock() - - if c.subscribeFunc != nil { - c.subscribeFunc(CompleterJobUpdated{Job: job, JobStats: stats}) - } -} - // withWaitStarted is an additional completer interface that can wait on the // completer to full start, and which is used by benchmarks. // diff --git a/internal/jobcompleter/job_completer_test.go b/internal/jobcompleter/job_completer_test.go index 9a82c6a7..789e7799 100644 --- a/internal/jobcompleter/job_completer_test.go +++ b/internal/jobcompleter/job_completer_test.go @@ -17,7 +17,6 @@ import ( "github.com/riverqueue/river/internal/riverinternaltest" "github.com/riverqueue/river/internal/riverinternaltest/testfactory" "github.com/riverqueue/river/internal/util/ptrutil" - "github.com/riverqueue/river/internal/util/randutil" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverpgxv5" "github.com/riverqueue/river/rivertype" @@ -71,7 +70,10 @@ func TestInlineJobCompleter_Complete(t *testing.T) { }, } - completer := NewInlineCompleter(riverinternaltest.BaseServiceArchetype(t), execMock) + subscribeCh := make(chan []CompleterJobUpdated, 10) + t.Cleanup(riverinternaltest.DiscardContinuously(subscribeCh)) + + completer := NewInlineCompleter(riverinternaltest.BaseServiceArchetype(t), execMock, subscribeCh) t.Cleanup(completer.Stop) completer.disableSleep = true @@ -87,16 +89,16 @@ func TestInlineJobCompleter_Complete(t *testing.T) { func TestInlineJobCompleter_Subscribe(t *testing.T) { t.Parallel() - testCompleterSubscribe(t, func(exec PartialExecutor) JobCompleter { - 
return NewInlineCompleter(riverinternaltest.BaseServiceArchetype(t), exec) + testCompleterSubscribe(t, func(exec PartialExecutor, subscribeCh SubscribeChan) JobCompleter { + return NewInlineCompleter(riverinternaltest.BaseServiceArchetype(t), exec, subscribeCh) }) } func TestInlineJobCompleter_Wait(t *testing.T) { t.Parallel() - testCompleterWait(t, func(exec PartialExecutor) JobCompleter { - return NewInlineCompleter(riverinternaltest.BaseServiceArchetype(t), exec) + testCompleterWait(t, func(exec PartialExecutor, subscribeChan SubscribeChan) JobCompleter { + return NewInlineCompleter(riverinternaltest.BaseServiceArchetype(t), exec, subscribeChan) }) } @@ -128,9 +130,11 @@ func TestAsyncJobCompleter_Complete(t *testing.T) { return nil, err }, } - completer := newAsyncCompleterWithConcurrency(riverinternaltest.BaseServiceArchetype(t), adapter, 2) - t.Cleanup(completer.Stop) + subscribeChan := make(chan []CompleterJobUpdated, 10) + completer := newAsyncCompleterWithConcurrency(riverinternaltest.BaseServiceArchetype(t), adapter, 2, subscribeChan) completer.disableSleep = true + require.NoError(t, completer.Start(ctx)) + t.Cleanup(completer.Stop) // launch 4 completions, only 2 can be inline due to the concurrency limit: for i := int64(0); i < 2; i++ { @@ -191,20 +195,20 @@ func TestAsyncJobCompleter_Complete(t *testing.T) { func TestAsyncJobCompleter_Subscribe(t *testing.T) { t.Parallel() - testCompleterSubscribe(t, func(exec PartialExecutor) JobCompleter { - return newAsyncCompleterWithConcurrency(riverinternaltest.BaseServiceArchetype(t), exec, 4) + testCompleterSubscribe(t, func(exec PartialExecutor, subscribeCh SubscribeChan) JobCompleter { + return newAsyncCompleterWithConcurrency(riverinternaltest.BaseServiceArchetype(t), exec, 4, subscribeCh) }) } func TestAsyncJobCompleter_Wait(t *testing.T) { t.Parallel() - testCompleterWait(t, func(exec PartialExecutor) JobCompleter { - return newAsyncCompleterWithConcurrency(riverinternaltest.BaseServiceArchetype(t), exec, 
4) + testCompleterWait(t, func(exec PartialExecutor, subscribeCh SubscribeChan) JobCompleter { + return newAsyncCompleterWithConcurrency(riverinternaltest.BaseServiceArchetype(t), exec, 4, subscribeCh) }) } -func testCompleterSubscribe(t *testing.T, constructor func(PartialExecutor) JobCompleter) { +func testCompleterSubscribe(t *testing.T, constructor func(PartialExecutor, SubscribeChan) JobCompleter) { t.Helper() ctx := context.Background() @@ -217,26 +221,38 @@ func testCompleterSubscribe(t *testing.T, constructor func(PartialExecutor) JobC }, } - completer := constructor(exec) + subscribeChan := make(chan []CompleterJobUpdated, 10) + completer := constructor(exec, subscribeChan) + require.NoError(t, completer.Start(ctx)) + // Flatten the slice results from subscribeChan into jobUpdateChan: jobUpdateChan := make(chan CompleterJobUpdated, 10) - completer.Subscribe(func(update CompleterJobUpdated) { - jobUpdateChan <- update - }) + go func() { + defer close(jobUpdateChan) + for update := range subscribeChan { + for _, u := range update { + jobUpdateChan <- u + } + } + }() for i := 0; i < 4; i++ { require.NoError(t, completer.JobSetStateIfRunning(ctx, &jobstats.JobStatistics{}, riverdriver.JobSetStateCompleted(int64(i), time.Now()))) } - completer.Stop() + completer.Stop() // closes subscribeChan updates := riverinternaltest.WaitOrTimeoutN(t, jobUpdateChan, 4) for i := 0; i < 4; i++ { require.Equal(t, rivertype.JobStateCompleted, updates[0].Job.State) } + go completer.Stop() + // drain all remaining jobs + for range jobUpdateChan { + } } -func testCompleterWait(t *testing.T, constructor func(PartialExecutor) JobCompleter) { +func testCompleterWait(t *testing.T, constructor func(PartialExecutor, SubscribeChan) JobCompleter) { t.Helper() ctx := context.Background() @@ -250,8 +266,10 @@ func testCompleterWait(t *testing.T, constructor func(PartialExecutor) JobComple return nil, err }, } + subscribeCh := make(chan []CompleterJobUpdated, 100) - completer := 
constructor(exec) + completer := constructor(exec, subscribeCh) + require.NoError(t, completer.Start(ctx)) // launch 4 completions: for i := 0; i < 4; i++ { @@ -300,9 +318,9 @@ func testCompleterWait(t *testing.T, constructor func(PartialExecutor) JobComple func TestAsyncCompleter(t *testing.T) { t.Parallel() - testCompleter(t, func(t *testing.T, exec riverdriver.Executor) *AsyncCompleter { + testCompleter(t, func(t *testing.T, exec riverdriver.Executor, subscribeCh chan<- []CompleterJobUpdated) *AsyncCompleter { t.Helper() - return NewAsyncCompleter(riverinternaltest.BaseServiceArchetype(t), exec) + return NewAsyncCompleter(riverinternaltest.BaseServiceArchetype(t), exec, subscribeCh) }, func(completer *AsyncCompleter) { completer.disableSleep = true }, func(completer *AsyncCompleter, exec PartialExecutor) { completer.exec = exec }) @@ -311,9 +329,9 @@ func TestAsyncCompleter(t *testing.T) { func TestBatchCompleter(t *testing.T) { t.Parallel() - testCompleter(t, func(t *testing.T, exec riverdriver.Executor) *BatchCompleter { + testCompleter(t, func(t *testing.T, exec riverdriver.Executor, subscribeCh chan<- []CompleterJobUpdated) *BatchCompleter { t.Helper() - return NewBatchCompleter(riverinternaltest.BaseServiceArchetype(t), exec) + return NewBatchCompleter(riverinternaltest.BaseServiceArchetype(t), exec, subscribeCh) }, func(completer *BatchCompleter) { completer.disableSleep = true }, func(completer *BatchCompleter, exec PartialExecutor) { completer.exec = exec }) @@ -321,16 +339,18 @@ func TestBatchCompleter(t *testing.T) { ctx := context.Background() type testBundle struct { - exec riverdriver.Executor + exec riverdriver.Executor + subscribeCh <-chan []CompleterJobUpdated } setup := func(t *testing.T) (*BatchCompleter, *testBundle) { t.Helper() var ( - driver = riverpgxv5.New(riverinternaltest.TestDB(ctx, t)) - exec = driver.GetExecutor() - completer = NewBatchCompleter(riverinternaltest.BaseServiceArchetype(t), exec) + driver = 
riverpgxv5.New(riverinternaltest.TestDB(ctx, t)) + exec = driver.GetExecutor() + subscribeCh = make(chan []CompleterJobUpdated, 10) + completer = NewBatchCompleter(riverinternaltest.BaseServiceArchetype(t), exec, subscribeCh) ) require.NoError(t, completer.Start(ctx)) @@ -339,7 +359,8 @@ func TestBatchCompleter(t *testing.T) { riverinternaltest.WaitOrTimeout(t, completer.WaitStarted()) return completer, &testBundle{ - exec: exec, + exec: exec, + subscribeCh: subscribeCh, } } @@ -350,12 +371,14 @@ func TestBatchCompleter(t *testing.T) { completer.completionMaxSize = 10 // set to something artificially low jobUpdateChan := make(chan CompleterJobUpdated, 100) - completer.Subscribe(func(update CompleterJobUpdated) { - select { - case jobUpdateChan <- update: - default: + go func() { + defer close(jobUpdateChan) + for update := range bundle.subscribeCh { + for _, u := range update { + jobUpdateChan <- u + } } - }) + }() stopInsertion := doContinuousInsertion(ctx, t, completer, bundle.exec) @@ -365,6 +388,10 @@ func TestBatchCompleter(t *testing.T) { riverinternaltest.WaitOrTimeoutN(t, jobUpdateChan, 100) stopInsertion() + go completer.Stop() + // drain all remaining jobs + for range jobUpdateChan { + } }) t.Run("BacklogWaitAndContinue", func(t *testing.T) { @@ -374,12 +401,14 @@ func TestBatchCompleter(t *testing.T) { completer.maxBacklog = 10 // set to something artificially low jobUpdateChan := make(chan CompleterJobUpdated, 100) - completer.Subscribe(func(update CompleterJobUpdated) { - select { - case jobUpdateChan <- update: - default: + go func() { + defer close(jobUpdateChan) + for update := range bundle.subscribeCh { + for _, u := range update { + jobUpdateChan <- u + } } - }) + }() stopInsertion := doContinuousInsertion(ctx, t, completer, bundle.exec) @@ -389,15 +418,19 @@ func TestBatchCompleter(t *testing.T) { riverinternaltest.WaitOrTimeoutN(t, jobUpdateChan, 100) stopInsertion() + go completer.Stop() + // drain all remaining jobs + for range jobUpdateChan { 
+ } }) } func TestInlineCompleter(t *testing.T) { t.Parallel() - testCompleter(t, func(t *testing.T, exec riverdriver.Executor) *InlineCompleter { + testCompleter(t, func(t *testing.T, exec riverdriver.Executor, subscribeCh chan<- []CompleterJobUpdated) *InlineCompleter { t.Helper() - return NewInlineCompleter(riverinternaltest.BaseServiceArchetype(t), exec) + return NewInlineCompleter(riverinternaltest.BaseServiceArchetype(t), exec, subscribeCh) }, func(completer *InlineCompleter) { completer.disableSleep = true }, func(completer *InlineCompleter, exec PartialExecutor) { completer.exec = exec }) @@ -405,7 +438,7 @@ func TestInlineCompleter(t *testing.T) { func testCompleter[TCompleter JobCompleter]( t *testing.T, - newCompleter func(t *testing.T, exec riverdriver.Executor) TCompleter, + newCompleter func(t *testing.T, exec riverdriver.Executor, subscribeCh chan<- []CompleterJobUpdated) TCompleter, // These functions are here to help us inject test behavior that's not part // of the JobCompleter interface. 
We could alternatively define a second @@ -419,23 +452,26 @@ func testCompleter[TCompleter JobCompleter]( ctx := context.Background() type testBundle struct { - exec riverdriver.Executor + exec riverdriver.Executor + subscribeCh <-chan []CompleterJobUpdated } setup := func(t *testing.T) (TCompleter, *testBundle) { t.Helper() var ( - driver = riverpgxv5.New(riverinternaltest.TestDB(ctx, t)) - exec = driver.GetExecutor() - completer = newCompleter(t, exec) + driver = riverpgxv5.New(riverinternaltest.TestDB(ctx, t)) + exec = driver.GetExecutor() + subscribeCh = make(chan []CompleterJobUpdated, 10) + completer = newCompleter(t, exec, subscribeCh) ) require.NoError(t, completer.Start(ctx)) t.Cleanup(completer.Stop) return completer, &testBundle{ - exec: exec, + exec: exec, + subscribeCh: subscribeCh, } } @@ -519,6 +555,8 @@ func testCompleter[TCompleter JobCompleter]( jobs, err := bundle.exec.JobGetByKindMany(ctx, []string{kind}) require.NoError(t, err) + t.Cleanup(riverinternaltest.DiscardContinuously(bundle.subscribeCh)) + for i := range jobs { require.NoError(t, completer.JobSetStateIfRunning(ctx, &stats[i], riverdriver.JobSetStateCompleted(jobs[i].ID, time.Now()))) } @@ -542,6 +580,7 @@ func testCompleter[TCompleter JobCompleter]( completer, bundle := setup(t) + t.Cleanup(riverinternaltest.DiscardContinuously(bundle.subscribeCh)) stopInsertion := doContinuousInsertion(ctx, t, completer, bundle.exec) // Give some time for some jobs to be inserted, and a guaranteed pass by @@ -623,19 +662,15 @@ func testCompleter[TCompleter JobCompleter]( completer, bundle := setup(t) - var jobUpdate CompleterJobUpdated - completer.Subscribe(func(update CompleterJobUpdated) { - jobUpdate = update - }) - job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateRunning)}) require.NoError(t, completer.JobSetStateIfRunning(ctx, &jobstats.JobStatistics{}, riverdriver.JobSetStateCompleted(job.ID, time.Now()))) completer.Stop() - 
require.NotZero(t, jobUpdate) - require.Equal(t, rivertype.JobStateCompleted, jobUpdate.Job.State) + jobUpdate := riverinternaltest.WaitOrTimeout(t, bundle.subscribeCh) + require.Len(t, jobUpdate, 1) + require.Equal(t, rivertype.JobStateCompleted, jobUpdate[0].Job.State) }) t.Run("MultipleCycles", func(t *testing.T) { @@ -653,6 +688,9 @@ func testCompleter[TCompleter JobCompleter]( requireState(t, bundle.exec, job.ID, rivertype.JobStateCompleted) } + // Completer closes the subscribe channel on stop, so we need to reset it between runs. + completer.ResetSubscribeChan(make(SubscribeChan, 10)) + { require.NoError(t, completer.Start(ctx)) @@ -793,51 +831,6 @@ func testCompleter[TCompleter JobCompleter]( requireState(t, bundle.exec, job.ID, rivertype.JobStateRunning) }) - t.Run("SubscribeStress", func(t *testing.T) { - t.Parallel() - - completer, bundle := setup(t) - - stopInsertion := doContinuousInsertion(ctx, t, completer, bundle.exec) - - const numGoroutines = 5 - - var ( - rand = randutil.NewCryptoSeededConcurrentSafeRand() - stopSubscribing = make(chan struct{}) - wg sync.WaitGroup - ) - - wg.Add(numGoroutines) - for i := 0; i < numGoroutines; i++ { - go func() { - defer wg.Done() - for { - select { - case <-stopSubscribing: - return - case <-time.After(time.Duration(randutil.IntBetween(rand, int(2*time.Millisecond), int(20*time.Millisecond)))): - completer.Subscribe(func(update CompleterJobUpdated) {}) - } - } - }() - } - - // Give some time for some jobs to be inserted and the subscriber - // goroutines to churn. - time.Sleep(100 * time.Millisecond) - - close(stopSubscribing) - wg.Wait() - - // Signal to stop insertion and wait for the goroutine to return. - numInserted := stopInsertion() - - require.Greater(t, numInserted, 0) - - completer.Stop() - }) - // The batch completer supports an interface that lets caller wait for it to // start. Make sure this works as expected. 
t.Run("WithStartedWaitsForStarted", func(t *testing.T) { @@ -853,36 +846,36 @@ func testCompleter[TCompleter JobCompleter]( } func BenchmarkAsyncCompleter_Concurrency10(b *testing.B) { - benchmarkCompleter(b, func(b *testing.B, exec riverdriver.Executor) JobCompleter { + benchmarkCompleter(b, func(b *testing.B, exec riverdriver.Executor, subscribeCh chan<- []CompleterJobUpdated) JobCompleter { b.Helper() - return newAsyncCompleterWithConcurrency(riverinternaltest.BaseServiceArchetype(b), exec, 10) + return newAsyncCompleterWithConcurrency(riverinternaltest.BaseServiceArchetype(b), exec, 10, subscribeCh) }) } func BenchmarkAsyncCompleter_Concurrency100(b *testing.B) { - benchmarkCompleter(b, func(b *testing.B, exec riverdriver.Executor) JobCompleter { + benchmarkCompleter(b, func(b *testing.B, exec riverdriver.Executor, subscribeCh chan<- []CompleterJobUpdated) JobCompleter { b.Helper() - return newAsyncCompleterWithConcurrency(riverinternaltest.BaseServiceArchetype(b), exec, 100) + return newAsyncCompleterWithConcurrency(riverinternaltest.BaseServiceArchetype(b), exec, 100, subscribeCh) }) } func BenchmarkBatchCompleter(b *testing.B) { - benchmarkCompleter(b, func(b *testing.B, exec riverdriver.Executor) JobCompleter { + benchmarkCompleter(b, func(b *testing.B, exec riverdriver.Executor, subscribeCh chan<- []CompleterJobUpdated) JobCompleter { b.Helper() - return NewBatchCompleter(riverinternaltest.BaseServiceArchetype(b), exec) + return NewBatchCompleter(riverinternaltest.BaseServiceArchetype(b), exec, subscribeCh) }) } func BenchmarkInlineCompleter(b *testing.B) { - benchmarkCompleter(b, func(b *testing.B, exec riverdriver.Executor) JobCompleter { + benchmarkCompleter(b, func(b *testing.B, exec riverdriver.Executor, subscribeCh chan<- []CompleterJobUpdated) JobCompleter { b.Helper() - return NewInlineCompleter(riverinternaltest.BaseServiceArchetype(b), exec) + return NewInlineCompleter(riverinternaltest.BaseServiceArchetype(b), exec, subscribeCh) }) } func 
benchmarkCompleter( b *testing.B, - newCompleter func(b *testing.B, exec riverdriver.Executor) JobCompleter, + newCompleter func(b *testing.B, exec riverdriver.Executor, subscribeCh chan<- []CompleterJobUpdated) JobCompleter, ) { b.Helper() @@ -898,11 +891,14 @@ func benchmarkCompleter( b.Helper() var ( - driver = riverpgxv5.New(riverinternaltest.TestDB(ctx, b)) - exec = driver.GetExecutor() - completer = newCompleter(b, exec) + driver = riverpgxv5.New(riverinternaltest.TestDB(ctx, b)) + exec = driver.GetExecutor() + subscribeCh = make(chan []CompleterJobUpdated, 100) + completer = newCompleter(b, exec, subscribeCh) ) + b.Cleanup(riverinternaltest.DiscardContinuously(subscribeCh)) + require.NoError(b, completer.Start(ctx)) b.Cleanup(completer.Stop) diff --git a/job_executor_test.go b/job_executor_test.go index 79f2179e..4233072b 100644 --- a/job_executor_test.go +++ b/job_executor_test.go @@ -116,11 +116,11 @@ func TestJobExecutor_Execute(t *testing.T) { ctx := context.Background() type testBundle struct { - completer *jobcompleter.InlineCompleter - exec riverdriver.Executor - errorHandler *testErrorHandler - getUpdatesAndStop func() []jobcompleter.CompleterJobUpdated - jobRow *rivertype.JobRow + completer *jobcompleter.InlineCompleter + exec riverdriver.Executor + errorHandler *testErrorHandler + jobRow *rivertype.JobRow + updateCh <-chan []jobcompleter.CompleterJobUpdated } setup := func(t *testing.T) (*jobExecutor, *testBundle) { @@ -130,19 +130,11 @@ func TestJobExecutor_Execute(t *testing.T) { tx = riverinternaltest.TestTx(ctx, t) archetype = riverinternaltest.BaseServiceArchetype(t) exec = riverpgxv5.New(nil).UnwrapExecutor(tx) - completer = jobcompleter.NewInlineCompleter(archetype, exec) + updateCh = make(chan []jobcompleter.CompleterJobUpdated, 10) + completer = jobcompleter.NewInlineCompleter(archetype, exec, updateCh) ) - var updates []jobcompleter.CompleterJobUpdated - completer.Subscribe(func(update jobcompleter.CompleterJobUpdated) { - updates = 
append(updates, update) - }) - - getJobUpdates := func() []jobcompleter.CompleterJobUpdated { - completer.Stop() - return updates - } - t.Cleanup(func() { _ = getJobUpdates() }) + t.Cleanup(completer.Stop) workUnitFactory := newWorkUnitFactoryWithCustomRetry(func() error { return nil }, nil) @@ -167,11 +159,11 @@ func TestJobExecutor_Execute(t *testing.T) { job = jobs[0] bundle := &testBundle{ - completer: completer, - exec: exec, - errorHandler: newTestErrorHandler(), - getUpdatesAndStop: getJobUpdates, - jobRow: job, + completer: completer, + exec: exec, + errorHandler: newTestErrorHandler(), + jobRow: job, + updateCh: updateCh, } // allocate this context just so we can set the CancelFunc: @@ -206,19 +198,24 @@ func TestJobExecutor_Execute(t *testing.T) { }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + jobUpdates := riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) require.Equal(t, rivertype.JobStateCompleted, job.State) - jobUpdates := bundle.getUpdatesAndStop() require.Len(t, jobUpdates, 1) jobUpdate := jobUpdates[0] t.Logf("Job statistics: %+v", jobUpdate.JobStats) require.NotZero(t, jobUpdate.JobStats.CompleteDuration) require.NotZero(t, jobUpdate.JobStats.QueueWaitDuration) require.NotZero(t, jobUpdate.JobStats.RunDuration) + + select { + case <-bundle.updateCh: + t.Fatalf("unexpected job update: %+v", jobUpdate) + default: + } }) t.Run("FirstError", func(t *testing.T) { @@ -232,7 +229,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor.WorkUnit = newWorkUnitFactoryWithCustomRetry(func() error { return workerErr }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -256,7 +253,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor.WorkUnit = 
newWorkUnitFactoryWithCustomRetry(func() error { return workerErr }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -276,7 +273,7 @@ func TestJobExecutor_Execute(t *testing.T) { { executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -295,7 +292,7 @@ func TestJobExecutor_Execute(t *testing.T) { { executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -315,7 +312,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor.WorkUnit = newWorkUnitFactoryWithCustomRetry(func() error { return workerErr }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -335,7 +332,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor.WorkUnit = newWorkUnitFactoryWithCustomRetry(func() error { return cancelErr }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -358,7 +355,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor.WorkUnit = newWorkUnitFactoryWithCustomRetry(func() error { return cancelErr }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -378,7 +375,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor.WorkUnit = 
newWorkUnitFactoryWithCustomRetry(func() error { return cancelErr }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -398,7 +395,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor.WorkUnit = newWorkUnitFactoryWithCustomRetry(func() error { return workerErr }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -418,7 +415,7 @@ func TestJobExecutor_Execute(t *testing.T) { }).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -436,7 +433,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor.WorkUnit = newWorkUnitFactoryWithCustomRetry(func() error { return workerErr }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -457,7 +454,7 @@ func TestJobExecutor_Execute(t *testing.T) { } executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -478,7 +475,7 @@ func TestJobExecutor_Execute(t *testing.T) { } executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -499,7 +496,7 @@ func TestJobExecutor_Execute(t *testing.T) { } executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := 
bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -515,7 +512,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor.WorkUnit = newWorkUnitFactoryWithCustomRetry(func() error { panic("panic val") }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -536,7 +533,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor.WorkUnit = newWorkUnitFactoryWithCustomRetry(func() error { panic("panic val") }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -554,7 +551,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor.WorkUnit = newWorkUnitFactoryWithCustomRetry(func() error { panic("panic val") }, nil).MakeUnit(bundle.jobRow) executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -574,7 +571,7 @@ func TestJobExecutor_Execute(t *testing.T) { } executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -594,7 +591,7 @@ func TestJobExecutor_Execute(t *testing.T) { } executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -614,7 +611,7 @@ func TestJobExecutor_Execute(t *testing.T) { } executor.Execute(ctx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) job, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) @@ -634,7 +631,7 @@ func TestJobExecutor_Execute(t 
*testing.T) { executor.CancelFunc = cancelFunc executor.Execute(workCtx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) require.ErrorIs(t, context.Cause(workCtx), errExecutorDefaultCancel) }) @@ -664,7 +661,7 @@ func TestJobExecutor_Execute(t *testing.T) { t.Cleanup(func() { cancelFunc(nil) }) executor.Execute(workCtx) - executor.Completer.Stop() + riverinternaltest.WaitOrTimeout(t, bundle.updateCh) jobRow, err := bundle.exec.JobGetByID(ctx, bundle.jobRow.ID) require.NoError(t, err) diff --git a/producer_test.go b/producer_test.go index 2e82653c..4b157d16 100644 --- a/producer_test.go +++ b/producer_test.go @@ -54,7 +54,10 @@ func Test_Producer_CanSafelyCompleteJobsWhileFetchingNewOnes(t *testing.T) { exec := dbDriver.GetExecutor() listener := dbDriver.GetListener() - completer := jobcompleter.NewInlineCompleter(archetype, exec) + subscribeCh := make(chan []jobcompleter.CompleterJobUpdated, 100) + t.Cleanup(riverinternaltest.DiscardContinuously(subscribeCh)) + + completer := jobcompleter.NewInlineCompleter(archetype, exec, subscribeCh) t.Cleanup(completer.Stop) type WithJobNumArgs struct { @@ -145,7 +148,7 @@ func Test_Producer_CanSafelyCompleteJobsWhileFetchingNewOnes(t *testing.T) { func TestProducer_PollOnly(t *testing.T) { t.Parallel() - testProducer(t, func(ctx context.Context, t *testing.T) *producer { + testProducer(t, func(ctx context.Context, t *testing.T) (*producer, chan []jobcompleter.CompleterJobUpdated) { t.Helper() var ( @@ -159,10 +162,16 @@ func TestProducer_PollOnly(t *testing.T) { tx = sharedtx.NewSharedTx(tx) var ( - exec = driver.UnwrapExecutor(tx) - completer = jobcompleter.NewInlineCompleter(archetype, exec) + exec = driver.UnwrapExecutor(tx) + jobUpdates = make(chan []jobcompleter.CompleterJobUpdated, 10) ) + completer := jobcompleter.NewInlineCompleter(archetype, exec, jobUpdates) + { + require.NoError(t, completer.Start(ctx)) + t.Cleanup(completer.Stop) + } + return newProducer(archetype, exec, 
&producerConfig{ ClientID: testClientID, Completer: completer, @@ -179,25 +188,31 @@ func TestProducer_PollOnly(t *testing.T) { SchedulerInterval: riverinternaltest.SchedulerShortInterval, StatusFunc: func(queue string, status componentstatus.Status) {}, Workers: NewWorkers(), - }) + }), jobUpdates }) } func TestProducer_WithNotifier(t *testing.T) { t.Parallel() - testProducer(t, func(ctx context.Context, t *testing.T) *producer { + testProducer(t, func(ctx context.Context, t *testing.T) (*producer, chan []jobcompleter.CompleterJobUpdated) { t.Helper() var ( - archetype = riverinternaltest.BaseServiceArchetype(t) - dbPool = riverinternaltest.TestDB(ctx, t) - driver = riverpgxv5.New(dbPool) - exec = driver.GetExecutor() - listener = driver.GetListener() - completer = jobcompleter.NewInlineCompleter(archetype, exec) + archetype = riverinternaltest.BaseServiceArchetype(t) + dbPool = riverinternaltest.TestDB(ctx, t) + driver = riverpgxv5.New(dbPool) + exec = driver.GetExecutor() + jobUpdates = make(chan []jobcompleter.CompleterJobUpdated, 10) + listener = driver.GetListener() ) + completer := jobcompleter.NewInlineCompleter(archetype, exec, jobUpdates) + { + require.NoError(t, completer.Start(ctx)) + t.Cleanup(completer.Stop) + } + notifier := notifier.New(archetype, listener, func(componentstatus.Status) {}) { require.NoError(t, notifier.Start(ctx)) @@ -220,11 +235,11 @@ func TestProducer_WithNotifier(t *testing.T) { SchedulerInterval: riverinternaltest.SchedulerShortInterval, StatusFunc: func(queue string, status componentstatus.Status) {}, Workers: NewWorkers(), - }) + }), jobUpdates }) } -func testProducer(t *testing.T, makeProducer func(ctx context.Context, t *testing.T) *producer) { +func testProducer(t *testing.T, makeProducer func(ctx context.Context, t *testing.T) (*producer, chan []jobcompleter.CompleterJobUpdated)) { t.Helper() ctx := context.Background() @@ -240,20 +255,24 @@ func testProducer(t *testing.T, makeProducer func(ctx context.Context, t *testin 
setup := func(t *testing.T) (*producer, *testBundle) { t.Helper() - producer := makeProducer(ctx, t) + producer, jobUpdates := makeProducer(ctx, t) producer.testSignals.Init() config := newTestConfig(t, nil) - jobUpdates := make(chan jobcompleter.CompleterJobUpdated, 10) - producer.completer.Subscribe(func(update jobcompleter.CompleterJobUpdated) { - jobUpdates <- update - }) + jobUpdatesFlattened := make(chan jobcompleter.CompleterJobUpdated, 10) + go func() { + for updates := range jobUpdates { + for _, update := range updates { + jobUpdatesFlattened <- update + } + } + }() return producer, &testBundle{ completer: producer.completer, config: config, exec: producer.exec, - jobUpdates: jobUpdates, + jobUpdates: jobUpdatesFlattened, workers: producer.workers, } } diff --git a/riverdriver/river_driver_interface.go b/riverdriver/river_driver_interface.go index 25855242..61b70f00 100644 --- a/riverdriver/river_driver_interface.go +++ b/riverdriver/river_driver_interface.go @@ -56,7 +56,7 @@ type Driver[TTx any] interface { // API is not stable. DO NOT USE. HasPool() bool - // UnwrapExecutor gets unwraps executor from a driver transaction. + // UnwrapExecutor gets an executor from a driver transaction. // // API is not stable. DO NOT USE. 
UnwrapExecutor(tx TTx) ExecutorTx diff --git a/subscription_manager.go b/subscription_manager.go new file mode 100644 index 00000000..1d74a6a6 --- /dev/null +++ b/subscription_manager.go @@ -0,0 +1,242 @@ +package river + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/riverqueue/river/internal/baseservice" + "github.com/riverqueue/river/internal/jobcompleter" + "github.com/riverqueue/river/internal/jobstats" + "github.com/riverqueue/river/internal/maintenance/startstop" + "github.com/riverqueue/river/internal/util/sliceutil" + "github.com/riverqueue/river/rivertype" +) + +type subscriptionManager struct { + baseservice.BaseService + startstop.BaseStartStop + + subscribeCh <-chan []jobcompleter.CompleterJobUpdated + + statsMu sync.Mutex // protects stats fields + statsAggregate jobstats.JobStatistics + statsNumJobs int + + mu sync.Mutex // protects subscription fields + subscriptions map[int]*eventSubscription + subscriptionsSeq int // used for generating simple IDs +} + +func newSubscriptionManager(archetype *baseservice.Archetype, subscribeCh <-chan []jobcompleter.CompleterJobUpdated) *subscriptionManager { + return baseservice.Init(archetype, &subscriptionManager{ + subscribeCh: subscribeCh, + subscriptions: make(map[int]*eventSubscription), + }) +} + +// ResetSubscribeChan is used to change the channel that the subscription +// manager listens on. It must only be called when the subscription manager is +// stopped. +func (sm *subscriptionManager) ResetSubscribeChan(subscribeCh <-chan []jobcompleter.CompleterJobUpdated) { + sm.subscribeCh = subscribeCh +} + +func (sm *subscriptionManager) Start(ctx context.Context) error { + _, shouldStart, stopped := sm.StartInit(ctx) + if !shouldStart { + return nil + } + + go func() { + // This defer should come first so that it's last out, thereby avoiding + // races. 
+ defer close(stopped) + + for updates := range sm.subscribeCh { + sm.distributeJobUpdates(updates) + } + }() + + return nil +} + +func (sm *subscriptionManager) Stop() { + shouldStop, stopped, finalizeStop := sm.StopInit() + if !shouldStop { + return + } + + <-stopped + + // Remove all subscriptions and close corresponding channels. + func() { + sm.mu.Lock() + defer sm.mu.Unlock() + + for subID, sub := range sm.subscriptions { + close(sub.Chan) + delete(sm.subscriptions, subID) + } + }() + + finalizeStop(true) +} + +func (sm *subscriptionManager) logStats(ctx context.Context, svcName string) { + sm.statsMu.Lock() + defer sm.statsMu.Unlock() + + sm.Logger.InfoContext(ctx, svcName+": Job stats (since last stats line)", + "num_jobs_run", sm.statsNumJobs, + "average_complete_duration", sm.safeDurationAverage(sm.statsAggregate.CompleteDuration, sm.statsNumJobs), + "average_queue_wait_duration", sm.safeDurationAverage(sm.statsAggregate.QueueWaitDuration, sm.statsNumJobs), + "average_run_duration", sm.safeDurationAverage(sm.statsAggregate.RunDuration, sm.statsNumJobs)) + + sm.statsAggregate = jobstats.JobStatistics{} + sm.statsNumJobs = 0 +} + +// Handles a potential divide by zero. +func (sm *subscriptionManager) safeDurationAverage(d time.Duration, n int) time.Duration { + if n == 0 { + return 0 + } + return d / time.Duration(n) +} + +// Receives updates from the completer and prompts the client to update +// statistics and distribute jobs into any listening subscriber channels. +// (Subscriber channels are non-blocking so this should be quite fast.) 
+func (sm *subscriptionManager) distributeJobUpdates(updates []jobcompleter.CompleterJobUpdated) { + func() { + sm.statsMu.Lock() + defer sm.statsMu.Unlock() + + for _, update := range updates { + stats := update.JobStats + sm.statsAggregate.CompleteDuration += stats.CompleteDuration + sm.statsAggregate.QueueWaitDuration += stats.QueueWaitDuration + sm.statsAggregate.RunDuration += stats.RunDuration + sm.statsNumJobs++ + } + }() + + sm.mu.Lock() + defer sm.mu.Unlock() + + // Quick path so we don't need to allocate anything if no one is listening. + if len(sm.subscriptions) < 1 { + return + } + + for _, update := range updates { + sm.distributeJobEvent(update.Job, jobStatisticsFromInternal(update.JobStats)) + } +} + +// Distribute a single event into any listening subscriber channels. +// +// Job events should specify the job and stats, while queue events should only specify +// the queue. +// +// MUST be called with sm.mu already held. +func (sm *subscriptionManager) distributeJobEvent(job *rivertype.JobRow, stats *JobStatistics) { + var event *Event + switch job.State { + case rivertype.JobStateCancelled: + event = &Event{Kind: EventKindJobCancelled, Job: job, JobStats: stats} + case rivertype.JobStateCompleted: + event = &Event{Kind: EventKindJobCompleted, Job: job, JobStats: stats} + case rivertype.JobStateScheduled: + event = &Event{Kind: EventKindJobSnoozed, Job: job, JobStats: stats} + case rivertype.JobStateAvailable, rivertype.JobStateDiscarded, rivertype.JobStateRetryable, rivertype.JobStateRunning: + event = &Event{Kind: EventKindJobFailed, Job: job, JobStats: stats} + case rivertype.JobStatePending: + panic("completion subscriber unexpectedly received job in pending state, river bug") + default: + // linter exhaustive rule prevents this from being reached + panic("unreachable state to distribute, river bug") + } + + // All subscription channels are non-blocking so this is always fast and + // there's no risk of falling behind what producers are sending. 
+ for _, sub := range sm.subscriptions { + if sub.ListensFor(event.Kind) { + // TODO: THIS IS UNSAFE AND WILL LEAD TO DROPPED EVENTS. + // + // We are allocating subscriber channels with a fixed size of 1000, but + // potentially processing job events in batches of 5000 (batch completer + // max batch size). It's probably not possible for the subscriber to keep + // up with these bursts. + select { + case sub.Chan <- event: + default: + } + } + } +} + +func (sm *subscriptionManager) distributeQueueEvent(event *Event) { + sm.mu.Lock() + defer sm.mu.Unlock() + + // All subscription channels are non-blocking so this is always fast and + // there's no risk of falling behind what producers are sending. + for _, sub := range sm.subscriptions { + if sub.ListensFor(event.Kind) { + select { + case sub.Chan <- event: + default: + } + } + } +} + +// Special internal variant that lets us inject an overridden size. +func (sm *subscriptionManager) SubscribeConfig(config *SubscribeConfig) (<-chan *Event, func()) { + if config.ChanSize < 0 { + panic("SubscribeConfig.ChanSize must be greater or equal to 1") + } + if config.ChanSize == 0 { + config.ChanSize = subscribeChanSizeDefault + } + + for _, kind := range config.Kinds { + if _, ok := allKinds[kind]; !ok { + panic(fmt.Errorf("unknown event kind: %s", kind)) + } + } + + subChan := make(chan *Event, config.ChanSize) + + sm.mu.Lock() + defer sm.mu.Unlock() + + // Just gives us an easy way of removing the subscription again later. + subID := sm.subscriptionsSeq + sm.subscriptionsSeq++ + + sm.subscriptions[subID] = &eventSubscription{ + Chan: subChan, + Kinds: sliceutil.KeyBy(config.Kinds, func(k EventKind) (EventKind, struct{}) { return k, struct{}{} }), + } + + cancel := func() { + sm.mu.Lock() + defer sm.mu.Unlock() + + // May no longer be present in case this was called after a stop. 
+ sub, ok := sm.subscriptions[subID] + if !ok { + return + } + + close(sub.Chan) + + delete(sm.subscriptions, subID) + } + + return subChan, cancel +} diff --git a/subscription_manager_test.go b/subscription_manager_test.go new file mode 100644 index 00000000..5406d4a8 --- /dev/null +++ b/subscription_manager_test.go @@ -0,0 +1,126 @@ +package river + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/riverqueue/river/internal/jobcompleter" + "github.com/riverqueue/river/internal/jobstats" + "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/internal/riverinternaltest/testfactory" + "github.com/riverqueue/river/internal/util/ptrutil" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivertype" + "github.com/stretchr/testify/require" +) + +func Test_SubscriptionManager(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + exec riverdriver.Executor + subscribeCh chan []jobcompleter.CompleterJobUpdated + tx pgx.Tx + } + + setup := func(t *testing.T) (*subscriptionManager, *testBundle) { + t.Helper() + + tx := riverinternaltest.TestTx(ctx, t) + exec := riverpgxv5.New(nil).UnwrapExecutor(tx) + + subscribeCh := make(chan []jobcompleter.CompleterJobUpdated, 1) + manager := newSubscriptionManager(riverinternaltest.BaseServiceArchetype(t), subscribeCh) + + require.NoError(t, manager.Start(ctx)) + t.Cleanup(manager.Stop) + + return manager, &testBundle{ + exec: exec, + subscribeCh: subscribeCh, + tx: tx, + } + } + + t.Run("DistributesRequestedEventsToSubscribers", func(t *testing.T) { + t.Parallel() + + manager, bundle := setup(t) + t.Cleanup(func() { close(bundle.subscribeCh) }) + + sub, cancelSub := manager.SubscribeConfig(&SubscribeConfig{ChanSize: 10, Kinds: []EventKind{EventKindJobCompleted, EventKindJobSnoozed}}) + t.Cleanup(cancelSub) + + // Send some events + job1 := 
testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateCompleted), FinalizedAt: ptrutil.Ptr(time.Now())}) + job2 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateCancelled), FinalizedAt: ptrutil.Ptr(time.Now())}) + job3 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateRetryable)}) + job4 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateScheduled)}) + + makeStats := func(complete, wait, run time.Duration) *jobstats.JobStatistics { + return &jobstats.JobStatistics{ + CompleteDuration: complete, + QueueWaitDuration: wait, + RunDuration: run, + } + } + + bundle.subscribeCh <- []jobcompleter.CompleterJobUpdated{ + {Job: job1, JobStats: makeStats(101, 102, 103)}, // completed, should be sent + {Job: job2, JobStats: makeStats(201, 202, 203)}, // cancelled, should be skipped + } + bundle.subscribeCh <- []jobcompleter.CompleterJobUpdated{ + {Job: job3, JobStats: makeStats(301, 302, 303)}, // retryable, should be skipped + {Job: job4, JobStats: makeStats(401, 402, 403)}, // snoozed/scheduled, should be sent + } + + received := riverinternaltest.WaitOrTimeoutN(t, sub, 2) + require.Equal(t, job1.ID, received[0].Job.ID) + require.Equal(t, rivertype.JobStateCompleted, received[0].Job.State) + require.Equal(t, time.Duration(101), received[0].JobStats.CompleteDuration) + require.Equal(t, time.Duration(102), received[0].JobStats.QueueWaitDuration) + require.Equal(t, time.Duration(103), received[0].JobStats.RunDuration) + require.Equal(t, job4.ID, received[1].Job.ID) + require.Equal(t, rivertype.JobStateScheduled, received[1].Job.State) + require.Equal(t, time.Duration(401), received[1].JobStats.CompleteDuration) + require.Equal(t, time.Duration(402), received[1].JobStats.QueueWaitDuration) + require.Equal(t, time.Duration(403), received[1].JobStats.RunDuration) + + cancelSub() + select { + case 
value, stillOpen := <-sub: + require.False(t, stillOpen, "subscription channel should be closed") + require.Nil(t, value, "subscription channel should be closed") + default: + require.Fail(t, "subscription channel should have been closed") + } + }) + + t.Run("StartStopRepeatedly", func(t *testing.T) { + // This service does not use the typical `startstoptest.Stress()` test + // because there are some additional steps required after a `Stop` for the + // subsequent `Start` to succeed. It's also not friendly for multiple + // concurrent calls to `Start` and `Stop`, but this is fine because the only + // usage within `Client` is already protected by a mutex. + t.Parallel() + + manager, bundle := setup(t) + + subscribeCh := bundle.subscribeCh + for i := 0; i < 100; i++ { + go func() { close(subscribeCh) }() + manager.Stop() + + subscribeCh = make(chan []jobcompleter.CompleterJobUpdated, 1) + manager.ResetSubscribeChan(subscribeCh) + + require.NoError(t, manager.Start(ctx)) + } + close(subscribeCh) + }) +} From 01eb6e1f723fceae335780c25a921862f3651610 Mon Sep 17 00:00:00 2001 From: Blake Gentry Date: Sun, 9 Jun 2024 22:19:35 -0500 Subject: [PATCH 2/2] normalize subscription manager service Respects a stop, but in the case of one, makes sure to clear the subscription channel before leaving, which means that it still correctly clears all events on a client shutdown. This gives us a way to use the stress test because all we need to do is close the channel in advance before calling startstoptest.Stress (the service is still a little weird compared to others because it requires that channel close, but a little less so). Also normalizes things a bit by removing the custom Stop implementation, which most services shouldn't need.
Co-Authored-By: Brandur Leach --- subscription_manager.go | 56 +++++++++++++++++++++--------------- subscription_manager_test.go | 15 +++++++++- 2 files changed, 47 insertions(+), 24 deletions(-) diff --git a/subscription_manager.go b/subscription_manager.go index 1d74a6a6..e1df69a9 100644 --- a/subscription_manager.go +++ b/subscription_manager.go @@ -44,7 +44,7 @@ func (sm *subscriptionManager) ResetSubscribeChan(subscribeCh <-chan []jobcomple } func (sm *subscriptionManager) Start(ctx context.Context) error { - _, shouldStart, stopped := sm.StartInit(ctx) + ctx, shouldStart, stopped := sm.StartInit(ctx) if !shouldStart { return nil } @@ -54,34 +54,44 @@ func (sm *subscriptionManager) Start(ctx context.Context) error { // races. defer close(stopped) - for updates := range sm.subscribeCh { - sm.distributeJobUpdates(updates) - } - }() - - return nil -} - -func (sm *subscriptionManager) Stop() { - shouldStop, stopped, finalizeStop := sm.StopInit() - if !shouldStop { - return - } + sm.Logger.DebugContext(ctx, sm.Name+": Run loop started") + defer sm.Logger.DebugContext(ctx, sm.Name+": Run loop stopped") - <-stopped + // On shutdown, close and remove all active subscriptions. + defer func() { + sm.mu.Lock() + defer sm.mu.Unlock() - // Remove all subscriptions and close corresponding channels. - func() { - sm.mu.Lock() - defer sm.mu.Unlock() + for subID, sub := range sm.subscriptions { + close(sub.Chan) + delete(sm.subscriptions, subID) + } + }() - for subID, sub := range sm.subscriptions { - close(sub.Chan) - delete(sm.subscriptions, subID) + for { + select { + case <-ctx.Done(): + // Distribute remaining subscriptions until the channel is + // closed. This does make the subscription manager a little + // problematic in that it requires the subscription channel to + // be closed before it will fully stop. 
This always happens in + // the case of a real client by virtue of the completer always + // stopping at the same time as the subscription manager, but + // one has to be careful in tests. + sm.Logger.DebugContext(ctx, sm.Name+": Stopping; distributing subscriptions until channel is closed") + for updates := range sm.subscribeCh { + sm.distributeJobUpdates(updates) + } + + return + + case updates := <-sm.subscribeCh: + sm.distributeJobUpdates(updates) + } } }() - finalizeStop(true) + return nil } func (sm *subscriptionManager) logStats(ctx context.Context, svcName string) { diff --git a/subscription_manager_test.go b/subscription_manager_test.go index 5406d4a8..be005008 100644 --- a/subscription_manager_test.go +++ b/subscription_manager_test.go @@ -9,6 +9,7 @@ import ( "github.com/riverqueue/river/internal/jobcompleter" "github.com/riverqueue/river/internal/jobstats" "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/internal/riverinternaltest/startstoptest" "github.com/riverqueue/river/internal/riverinternaltest/testfactory" "github.com/riverqueue/river/internal/util/ptrutil" "github.com/riverqueue/river/riverdriver" @@ -113,7 +114,7 @@ func Test_SubscriptionManager(t *testing.T) { subscribeCh := bundle.subscribeCh for i := 0; i < 100; i++ { - go func() { close(subscribeCh) }() + close(subscribeCh) manager.Stop() subscribeCh = make(chan []jobcompleter.CompleterJobUpdated, 1) @@ -123,4 +124,16 @@ func Test_SubscriptionManager(t *testing.T) { } close(subscribeCh) }) + + t.Run("StartStopStress", func(t *testing.T) { + t.Parallel() + + svc, bundle := setup(t) + + // Close the subscription channel in advance so that stops can leave + // successfully. + close(bundle.subscribeCh) + + startstoptest.Stress(ctx, t, svc) + }) }