diff --git a/CHANGELOG.md b/CHANGELOG.md index b305eace..69342d69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ⚠️ Version 0.19.0 has minor breaking changes for the `Worker.Middleware`, introduced fairly recently in 0.17.0. We tried not to make this change, but found the existing middleware interface insufficient to provide the necessary range of functionality we wanted, and this is a secondary middleware facility that won't be in use for many users, so it seemed worthwhile. +### Added + +- Added a new "hooks" API for tying into River functionality at various points like job inserts or working. Differs from middleware in that it doesn't go on the stack and can't modify context, but in some cases is able to run at a more granular level (e.g. for each job insert rather than each _batch_ of inserts). [PR #789](https://github.com/riverqueue/river/pull/789). + ### Changed - The `river.RecordOutput` function now returns an error if the output is too large. The output is limited to 32MB in size. [PR #782](https://github.com/riverqueue/river/pull/782). diff --git a/client.go b/client.go index 0568a972..2e669c95 100644 --- a/client.go +++ b/client.go @@ -14,6 +14,7 @@ import ( "github.com/riverqueue/river/internal/dblist" "github.com/riverqueue/river/internal/dbunique" + "github.com/riverqueue/river/internal/hooklookup" "github.com/riverqueue/river/internal/jobcompleter" "github.com/riverqueue/river/internal/leadership" "github.com/riverqueue/river/internal/maintenance" @@ -166,6 +167,11 @@ type Config struct { // Defaults to 1 minute. JobTimeout time.Duration + // Hooks are functions that may activate at certain points during a job's + // lifecycle (see rivertype.Hook), installed globally. Jobs may have their + // own specific hooks by implementing the JobArgsWithHooks interface. + Hooks []rivertype.Hook + // Logger is the structured logger to use for logging purposes. If none is // specified, logs will be emitted to STDOUT with messages at warn level // or higher. @@ -314,6 +320,7 @@ func (c *Config) WithDefaults() *Config { FetchCooldown: valutil.ValOrDefault(c.FetchCooldown, FetchCooldownDefault), FetchPollInterval: valutil.ValOrDefault(c.FetchPollInterval, FetchPollIntervalDefault), ID: valutil.ValOrDefaultFunc(c.ID, func() string { return defaultClientID(time.Now().UTC()) }), + Hooks: c.Hooks, JobInsertMiddleware: c.JobInsertMiddleware, JobTimeout: valutil.ValOrDefault(c.JobTimeout, JobTimeoutDefault), Logger: logger, @@ -427,6 +434,8 @@ type Client[TTx any] struct { config *Config driver riverdriver.Driver[TTx] elector *leadership.Elector + hookLookupByJob *hooklookup.JobHookLookup + hookLookupGlobal hooklookup.HookLookupInterface insertNotifyLimiter *notifylimiter.Limiter notifier *notifier.Notifier // may be nil in poll-only mode periodicJobs *PeriodicJobBundle @@ -543,6 +552,8 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client client := &Client[TTx]{ config: config, driver: driver, + hookLookupByJob: hooklookup.NewJobHookLookup(), + hookLookupGlobal: hooklookup.NewHookLookup(config.Hooks), producersByQueueName: make(map[string]*producer), testSignals: clientTestSignals{}, workCancel: func(cause error) {}, // replaced on start, but here in case StopAndCancel is called before start up @@ -1510,6 +1521,24 @@ func (c *Client[TTx]) insertManyShared( execute func(context.Context, []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error), ) ([]*rivertype.JobInsertResult, error) { doInner := func(ctx context.Context) ([]*rivertype.JobInsertResult, error) { + for _, params := range insertParams { + // TODO(brandur): This range clause and the one below it are + // identical, and it'd be nice to merge them together, but in such a + // way that doesn't require array allocation. I think we can do this + // using iterators after we drop support for Go 1.22. + for _, hook := range c.hookLookupGlobal.ByHookKind(hooklookup.HookKindInsertBegin) { + if err := hook.(rivertype.HookInsertBegin).InsertBegin(ctx, params); err != nil { //nolint:forcetypeassert + return nil, err + } + } + + for _, hook := range c.hookLookupByJob.ByJobArgs(params.Args).ByHookKind(hooklookup.HookKindInsertBegin) { + if err := hook.(rivertype.HookInsertBegin).InsertBegin(ctx, params); err != nil { //nolint:forcetypeassert + return nil, err + } + } + } + finalInsertParams := sliceutil.Map(insertParams, func(params *rivertype.JobInsertParams) *riverdriver.JobInsertFastParams { return (*riverdriver.JobInsertFastParams)(params) }) @@ -1754,7 +1783,8 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) *pr ErrorHandler: c.config.ErrorHandler, FetchCooldown: c.config.FetchCooldown, FetchPollInterval: c.config.FetchPollInterval, - GlobalMiddleware: c.config.WorkerMiddleware, + HookLookupByJob: c.hookLookupByJob, + HookLookupGlobal: c.hookLookupGlobal, JobTimeout: c.config.JobTimeout, MaxWorkers: queueConfig.MaxWorkers, Notifier: c.notifier, @@ -1763,6 +1793,7 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) *pr RetryPolicy: c.config.RetryPolicy, SchedulerInterval: c.config.schedulerInterval, Workers: c.config.Workers, + WorkerMiddleware: c.config.WorkerMiddleware, }) c.producersByQueueName[queueName] = producer return producer diff --git a/client_test.go b/client_test.go index 31864be5..256c0d12 100644 --- a/client_test.go +++ b/client_test.go @@ -24,6 +24,7 @@ import ( "github.com/tidwall/sjson" "github.com/riverqueue/river/internal/dbunique" + "github.com/riverqueue/river/internal/jobexecutor" "github.com/riverqueue/river/internal/maintenance" "github.com/riverqueue/river/internal/notifier" "github.com/riverqueue/river/internal/rivercommon" @@ -610,7 +611,117 @@ func Test_Client(t *testing.T) { require.Equal(t, `relation "river_job" does not exist`, pgErr.Message) }) - t.Run("WithWorkerMiddleware", func(t *testing.T) { + t.Run("WithGlobalInsertBeginHook", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + + insertBeginHookCalled := false + + bundle.config.Hooks = []rivertype.Hook{ + HookInsertBeginFunc(func(ctx context.Context, params *rivertype.JobInsertParams) error { + insertBeginHookCalled = true + return nil + }), + } + + AddWorker(bundle.config.Workers, WorkFunc(func(ctx context.Context, job *Job[callbackArgs]) error { + return nil + })) + + client, err := NewClient(riverpgxv5.New(bundle.dbPool), bundle.config) + require.NoError(t, err) + + _, err = client.Insert(ctx, callbackArgs{}, nil) + require.NoError(t, err) + + require.True(t, insertBeginHookCalled) + }) + + t.Run("WithGlobalWorkBeginHook", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + + workBeginHookCalled := false + + bundle.config.Hooks = []rivertype.Hook{ + HookWorkBeginFunc(func(ctx context.Context, job *rivertype.JobRow) error { + workBeginHookCalled = true + return nil + }), + } + + AddWorker(bundle.config.Workers, WorkFunc(func(ctx context.Context, job *Job[callbackArgs]) error { + return nil + })) + + client, err := NewClient(riverpgxv5.New(bundle.dbPool), bundle.config) + require.NoError(t, err) + + subscribeChan := subscribe(t, client) + startClient(ctx, t, client) + + insertRes, err := client.Insert(ctx, callbackArgs{}, nil) + require.NoError(t, err) + + event := riversharedtest.WaitOrTimeout(t, subscribeChan) + require.Equal(t, EventKindJobCompleted, event.Kind) + require.Equal(t, insertRes.Job.ID, event.Job.ID) + + require.True(t, workBeginHookCalled) + }) + + t.Run("WithInsertBeginHookOnJobArgs", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + + AddWorker(bundle.config.Workers, WorkFunc(func(ctx context.Context, job *Job[jobArgsWithCustomHook]) error { + return nil + })) + + client, err := NewClient(riverpgxv5.New(bundle.dbPool), bundle.config) + require.NoError(t, err) + + insertRes, err := client.Insert(ctx, jobArgsWithCustomHook{}, nil) + require.NoError(t, err) + + var metadataMap map[string]any + err = json.Unmarshal(insertRes.Job.Metadata, &metadataMap) + require.NoError(t, err) + require.Equal(t, "called", metadataMap["insert_begin_hook"]) + }) + + t.Run("WithWorkBeginHookOnJobArgs", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + + AddWorker(bundle.config.Workers, WorkFunc(func(ctx context.Context, job *Job[jobArgsWithCustomHook]) error { + return nil + })) + + client, err := NewClient(riverpgxv5.New(bundle.dbPool), bundle.config) + require.NoError(t, err) + + subscribeChan := subscribe(t, client) + startClient(ctx, t, client) + + insertRes, err := client.Insert(ctx, jobArgsWithCustomHook{}, nil) + require.NoError(t, err) + + event := riversharedtest.WaitOrTimeout(t, subscribeChan) + require.Equal(t, EventKindJobCompleted, event.Kind) + require.Equal(t, insertRes.Job.ID, event.Job.ID) + + var metadataMap map[string]any + err = json.Unmarshal(event.Job.Metadata, &metadataMap) + require.NoError(t, err) + require.Equal(t, "called", metadataMap["work_begin_hook"]) + }) + + t.Run("WithGlobalWorkerMiddleware", func(t *testing.T) { t.Parallel() _, bundle := setup(t) @@ -983,6 +1094,51 @@ func Test_Client(t *testing.T) { }) } +type jobArgsWithCustomHook struct{} + +func (jobArgsWithCustomHook) Kind() string { return "with_custom_hook" } + +func (jobArgsWithCustomHook) Hooks() []rivertype.Hook { + return []rivertype.Hook{ + &testHookInsertAndWorkBegin{}, + } +} + +var ( + _ rivertype.HookInsertBegin = &testHookInsertAndWorkBegin{} + _ rivertype.HookWorkBegin = &testHookInsertAndWorkBegin{} +) + +type testHookInsertAndWorkBegin struct{ HookDefaults } + +func (t *testHookInsertAndWorkBegin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error { + var metadataMap map[string]any + if err := json.Unmarshal(params.Metadata, &metadataMap); err != nil { + return err + } + + metadataMap["insert_begin_hook"] = "called" + + var err error + params.Metadata, err = json.Marshal(metadataMap) + if err != nil { + return err + } + + return nil +} + +func (t *testHookInsertAndWorkBegin) WorkBegin(ctx context.Context, job *rivertype.JobRow) error { + metadataUpdates, hasMetadataUpdates := jobexecutor.MetadataUpdatesFromWorkContext(ctx) + if !hasMetadataUpdates { + panic("expected to be called from within job executor") + } + + metadataUpdates["work_begin_hook"] = "called" + + return nil +} + type workerWithMiddleware[T JobArgs] struct { WorkerDefaults[T] workFunc func(context.Context, *Job[T]) error @@ -5081,6 +5237,7 @@ func Test_NewClient_Defaults(t *testing.T) { require.Equal(t, FetchCooldownDefault, client.config.FetchCooldown) require.Equal(t, FetchPollIntervalDefault, client.config.FetchPollInterval) require.Equal(t, JobTimeoutDefault, client.config.JobTimeout) + require.Nil(t, client.config.Hooks) require.NotZero(t, client.baseService.Logger) require.Equal(t, MaxAttemptsDefault, client.config.MaxAttempts) require.IsType(t, &DefaultClientRetryPolicy{}, client.config.RetryPolicy) @@ -5103,6 +5260,10 @@ func Test_NewClient_Overrides(t *testing.T) { retryPolicy := &DefaultClientRetryPolicy{} + type noOpHook struct { + HookDefaults + } + type noOpInsertMiddleware struct { JobInsertMiddlewareDefaults } @@ -5119,6 +5280,7 @@ func Test_NewClient_Overrides(t *testing.T) { ErrorHandler: errorHandler, FetchCooldown: 123 * time.Millisecond, FetchPollInterval: 124 * time.Millisecond, + Hooks: []rivertype.Hook{&noOpHook{}}, JobInsertMiddleware: []rivertype.JobInsertMiddleware{&noOpInsertMiddleware{}}, JobTimeout: 125 * time.Millisecond, Logger: logger, @@ -5154,6 +5316,7 @@ func Test_NewClient_Overrides(t *testing.T) { require.Equal(t, 124*time.Millisecond, client.config.FetchPollInterval) require.Len(t, client.config.JobInsertMiddleware, 1) require.Equal(t, 125*time.Millisecond, client.config.JobTimeout) + require.Equal(t, []rivertype.Hook{&noOpHook{}}, client.config.Hooks) require.Equal(t, logger, client.baseService.Logger) require.Equal(t, 5, client.config.MaxAttempts) require.Equal(t, retryPolicy, client.config.RetryPolicy) diff --git a/hook_defaults_funcs.go b/hook_defaults_funcs.go new file mode 100644 index 00000000..c685076f --- /dev/null +++ b/hook_defaults_funcs.go @@ -0,0 +1,34 @@ +package river + +import ( + "context" + + "github.com/riverqueue/river/rivertype" +) + +// HookDefaults should be embedded on any hook implementation. It helps +// guarantee forward compatibility in case additions are necessary to the Hook +// interface. +type HookDefaults struct{} + +func (d *HookDefaults) IsHook() bool { return true } + +// HookInsertBeginFunc is a convenience helper for implementing HookInsertBegin +// using a simple function instead of a struct. +type HookInsertBeginFunc func(ctx context.Context, params *rivertype.JobInsertParams) error + +func (f HookInsertBeginFunc) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error { + return f(ctx, params) +} + +func (f HookInsertBeginFunc) IsHook() bool { return true } + +// HookWorkBeginFunc is a convenience helper for implementing HookworkBegin +// using a simple function instead of a struct. +type HookWorkBeginFunc func(ctx context.Context, job *rivertype.JobRow) error + +func (f HookWorkBeginFunc) WorkBegin(ctx context.Context, job *rivertype.JobRow) error { + return f(ctx, job) +} + +func (f HookWorkBeginFunc) IsHook() bool { return true } diff --git a/hook_defaults_funcs_test.go b/hook_defaults_funcs_test.go new file mode 100644 index 00000000..2c64bde2 --- /dev/null +++ b/hook_defaults_funcs_test.go @@ -0,0 +1,16 @@ +package river + +import ( + "context" + + "github.com/riverqueue/river/rivertype" +) + +// Verify interface compliance. +var ( + _ rivertype.Hook = HookInsertBeginFunc(func(ctx context.Context, params *rivertype.JobInsertParams) error { return nil }) + _ rivertype.HookInsertBegin = HookInsertBeginFunc(func(ctx context.Context, params *rivertype.JobInsertParams) error { return nil }) + + _ rivertype.Hook = HookWorkBeginFunc(func(ctx context.Context, job *rivertype.JobRow) error { return nil }) + _ rivertype.HookWorkBegin = HookWorkBeginFunc(func(ctx context.Context, job *rivertype.JobRow) error { return nil }) +) diff --git a/internal/hooklookup/hook_lookup.go b/internal/hooklookup/hook_lookup.go new file mode 100644 index 00000000..007ef497 --- /dev/null +++ b/internal/hooklookup/hook_lookup.go @@ -0,0 +1,152 @@ +package hooklookup + +import ( + "sync" + + "github.com/riverqueue/river/rivertype" +) + +// +// HookKind +// + +type HookKind string + +const ( + HookKindInsertBegin HookKind = "insert_begin" + HookKindWorkBegin HookKind = "work_begin" +) + +// +// HookLookupInterface +// + +// HookLookupInterface is an interface to look up hooks by hook kind. It's +// commonly implemented by HookLookup, but may also be EmptyHookLookup as a +// memory allocation optimization for bundles where no hooks are present. +type HookLookupInterface interface { + ByHookKind(kind HookKind) []rivertype.Hook +} + +// NewHookLookup returns a new hook lookup interface based on the given hooks +// that satisfies HookLookupInterface. This is often hookLookup, but may be +// emptyHookLookup as an optimization for the common case of an empty hook +// bundle. +func NewHookLookup(hooks []rivertype.Hook) HookLookupInterface { + if len(hooks) < 1 { + return &emptyHookLookup{} + } + + return &hookLookup{ + hooks: hooks, + hooksByKind: make(map[HookKind][]rivertype.Hook), + mu: &sync.RWMutex{}, + } +} + +// +// hookLookup +// + +// hookLookup looks up and caches hooks based on a HookKind, saving work when +// looking up hooks for specific operations, a common operation that gets +// repeated over and over again. This struct may be used as a lookup for +// globally installed hooks or hooks for specific job kinds through the use of +// JobHookLookup. +type hookLookup struct { + hooks []rivertype.Hook + hooksByKind map[HookKind][]rivertype.Hook + mu *sync.RWMutex +} + +func (c *hookLookup) ByHookKind(kind HookKind) []rivertype.Hook { + c.mu.RLock() + cache, ok := c.hooksByKind[kind] + c.mu.RUnlock() + if ok { + return cache + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Even if this ends up being empty, make sure there's an entry for the next + // time the cache gets invoked for this kind. + c.hooksByKind[kind] = nil + + // Rely on exhaustlint to find any missing hook kinds here. + switch kind { + case HookKindInsertBegin: + for _, hook := range c.hooks { + if typedHook, ok := hook.(rivertype.HookInsertBegin); ok { + c.hooksByKind[kind] = append(c.hooksByKind[kind], typedHook) + } + } + case HookKindWorkBegin: + for _, hook := range c.hooks { + if typedHook, ok := hook.(rivertype.HookWorkBegin); ok { + c.hooksByKind[kind] = append(c.hooksByKind[kind], typedHook) + } + } + } + + return c.hooksByKind[kind] +} + +// +// emptyHookLookup +// + +// emptyHookLookup is an empty version of HookLookup that's zero allocation. For +// most applications, most job args won't have hooks, so this prevents us from +// allocating dozens/hundreds of small HookLookup objects that go unused. +type emptyHookLookup struct{} + +func (c *emptyHookLookup) ByHookKind(kind HookKind) []rivertype.Hook { return nil } + +// +// JobHookLookup +// + +type JobHookLookup struct { + hookLookupByKind map[string]HookLookupInterface + mu sync.RWMutex +} + +func NewJobHookLookup() *JobHookLookup { + return &JobHookLookup{ + hookLookupByKind: make(map[string]HookLookupInterface), + } +} + +// ByJobArgs returns a HookLookupInterface by job args, which is a HookLookup if +// the job args had specific hooks (i.e. implements JobArgsWithHooks and returns +// a non-empty set of hooks), or an EmptyHashLookup otherwise. +func (c *JobHookLookup) ByJobArgs(args rivertype.JobArgs) HookLookupInterface { + kind := args.Kind() + + c.mu.RLock() + lookup, ok := c.hookLookupByKind[kind] + c.mu.RUnlock() + if ok { + return lookup + } + + c.mu.Lock() + defer c.mu.Unlock() + + var hooks []rivertype.Hook + if argsWithHooks, ok := args.(jobArgsWithHooks); ok { + hooks = argsWithHooks.Hooks() + } + + lookup = NewHookLookup(hooks) + c.hookLookupByKind[kind] = lookup + return lookup +} + +// Same as river.JobArgsWithHooks, but duplicated here so that can still live in +// the top level package. +type jobArgsWithHooks interface { + Hooks() []rivertype.Hook +} diff --git a/internal/hooklookup/hook_lookup_test.go b/internal/hooklookup/hook_lookup_test.go new file mode 100644 index 00000000..90da7e35 --- /dev/null +++ b/internal/hooklookup/hook_lookup_test.go @@ -0,0 +1,244 @@ +package hooklookup + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/rivertype" +) + +func TestHookLookup(t *testing.T) { + t.Parallel() + + type testBundle struct{} + + setup := func(t *testing.T) (*hookLookup, *testBundle) { //nolint:unparam + t.Helper() + + return NewHookLookup([]rivertype.Hook{ //nolint:forcetypeassert + &testHookInsertAndWorkBegin{}, + &testHookInsertBegin{}, + &testHookWorkBegin{}, + }).(*hookLookup), &testBundle{} + } + + t.Run("LooksUpHooks", func(t *testing.T) { + t.Parallel() + + hookLookup, _ := setup(t) + + require.Equal(t, []rivertype.Hook{ + &testHookInsertAndWorkBegin{}, + &testHookInsertBegin{}, + }, hookLookup.ByHookKind(HookKindInsertBegin)) + require.Equal(t, []rivertype.Hook{ + &testHookInsertAndWorkBegin{}, + &testHookWorkBegin{}, + }, hookLookup.ByHookKind(HookKindWorkBegin)) + + require.Len(t, hookLookup.hooksByKind, 2) + + // Repeat lookups to make sure we get the same result. + require.Equal(t, []rivertype.Hook{ + &testHookInsertAndWorkBegin{}, + &testHookInsertBegin{}, + }, hookLookup.ByHookKind(HookKindInsertBegin)) + require.Equal(t, []rivertype.Hook{ + &testHookInsertAndWorkBegin{}, + &testHookWorkBegin{}, + }, hookLookup.ByHookKind(HookKindWorkBegin)) + }) + + t.Run("Stress", func(t *testing.T) { + t.Parallel() + + hookLookup, _ := setup(t) + + var wg sync.WaitGroup + + parallelLookupLoop := func(kind HookKind) { + wg.Add(1) + go func() { + defer wg.Done() + + for range 50 { + hookLookup.ByHookKind(kind) + } + }() + } + + parallelLookupLoop(HookKindInsertBegin) + parallelLookupLoop(HookKindWorkBegin) + parallelLookupLoop(HookKindInsertBegin) + parallelLookupLoop(HookKindWorkBegin) + + wg.Wait() + }) +} + +func TestEmptyHookLookup(t *testing.T) { + t.Parallel() + + type testBundle struct{} + + setup := func(t *testing.T) (*emptyHookLookup, *testBundle) { + t.Helper() + + return NewHookLookup(nil).(*emptyHookLookup), &testBundle{} //nolint:forcetypeassert + } + + t.Run("AlwaysReturnsNil", func(t *testing.T) { + t.Parallel() + + hookLookup, _ := setup(t) + + require.Nil(t, hookLookup.ByHookKind(HookKindInsertBegin)) + require.Nil(t, hookLookup.ByHookKind(HookKindWorkBegin)) + }) +} + +func TestJobHookLookup(t *testing.T) { + t.Parallel() + + type testBundle struct{} + + setup := func(t *testing.T) (*JobHookLookup, *testBundle) { //nolint:unparam + t.Helper() + + return NewJobHookLookup(), &testBundle{} + } + + t.Run("LooksUpHooks", func(t *testing.T) { + t.Parallel() + + jobHookLookup, _ := setup(t) + + require.Nil(t, jobHookLookup.ByJobArgs(&jobArgsNoHooks{}).ByHookKind(HookKindInsertBegin)) + require.Nil(t, jobHookLookup.ByJobArgs(&jobArgsNoHooks{}).ByHookKind(HookKindWorkBegin)) + require.Equal(t, []rivertype.Hook{ + &testHookInsertAndWorkBegin{}, + &testHookInsertBegin{}, + }, jobHookLookup.ByJobArgs(&jobArgsWithCustomHooks{}).ByHookKind(HookKindInsertBegin)) + require.Equal(t, []rivertype.Hook{ + &testHookInsertAndWorkBegin{}, + &testHookWorkBegin{}, + }, jobHookLookup.ByJobArgs(&jobArgsWithCustomHooks{}).ByHookKind(HookKindWorkBegin)) + + require.Len(t, jobHookLookup.hookLookupByKind, 2) + + // Repeat lookups to make sure we get the same result. + require.Nil(t, jobHookLookup.ByJobArgs(&jobArgsNoHooks{}).ByHookKind(HookKindInsertBegin)) + require.Nil(t, jobHookLookup.ByJobArgs(&jobArgsNoHooks{}).ByHookKind(HookKindWorkBegin)) + require.Equal(t, []rivertype.Hook{ + &testHookInsertAndWorkBegin{}, + &testHookInsertBegin{}, + }, jobHookLookup.ByJobArgs(&jobArgsWithCustomHooks{}).ByHookKind(HookKindInsertBegin)) + require.Equal(t, []rivertype.Hook{ + &testHookInsertAndWorkBegin{}, + &testHookWorkBegin{}, + }, jobHookLookup.ByJobArgs(&jobArgsWithCustomHooks{}).ByHookKind(HookKindWorkBegin)) + }) + + t.Run("Stress", func(t *testing.T) { + t.Parallel() + + jobHookLookup, _ := setup(t) + + var wg sync.WaitGroup + + parallelLookupLoop := func(args rivertype.JobArgs) { + wg.Add(1) + go func() { + defer wg.Done() + + for range 50 { + jobHookLookup.ByJobArgs(args) + } + }() + } + + parallelLookupLoop(&jobArgsNoHooks{}) + parallelLookupLoop(&jobArgsWithCustomHooks{}) + parallelLookupLoop(&jobArgsNoHooks{}) + parallelLookupLoop(&jobArgsWithCustomHooks{}) + + wg.Wait() + }) +} + +// +// jobArgsNoHooks +// + +var _ rivertype.JobArgs = &jobArgsNoHooks{} + +type jobArgsNoHooks struct{} + +func (jobArgsNoHooks) Kind() string { return "no_hooks" } + +// +// jobArgsWithHooks +// + +var ( + _ rivertype.JobArgs = &jobArgsWithCustomHooks{} + _ jobArgsWithHooks = &jobArgsWithCustomHooks{} +) + +type jobArgsWithCustomHooks struct{} + +func (jobArgsWithCustomHooks) Hooks() []rivertype.Hook { + return []rivertype.Hook{ + &testHookInsertAndWorkBegin{}, + &testHookInsertBegin{}, + &testHookWorkBegin{}, + } +} + +func (jobArgsWithCustomHooks) Kind() string { return "with_custom_hooks" } + +// +// testHookInsertAndWorkBegin +// + +var ( + _ rivertype.HookInsertBegin = &testHookInsertAndWorkBegin{} + _ rivertype.HookWorkBegin = &testHookInsertAndWorkBegin{} +) + +type testHookInsertAndWorkBegin struct{ rivertype.Hook } + +func (t *testHookInsertAndWorkBegin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error { + return nil +} + +func (t *testHookInsertAndWorkBegin) WorkBegin(ctx context.Context, job *rivertype.JobRow) error { + return nil +} + +// +// testHookInsertBegin +// + +var _ rivertype.HookInsertBegin = &testHookInsertBegin{} + +type testHookInsertBegin struct{ rivertype.Hook } + +func (t *testHookInsertBegin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error { + return nil +} + +// +// testHookWorkBegin +// + +var _ rivertype.HookWorkBegin = &testHookWorkBegin{} + +type testHookWorkBegin struct{ rivertype.Hook } + +func (t *testHookWorkBegin) WorkBegin(ctx context.Context, job *rivertype.JobRow) error { + return nil +} diff --git a/internal/jobexecutor/job_executor.go b/internal/jobexecutor/job_executor.go index 980fd065..a5c3b62c 100644 --- a/internal/jobexecutor/job_executor.go +++ b/internal/jobexecutor/job_executor.go @@ -13,6 +13,7 @@ import ( "github.com/tidwall/sjson" "github.com/riverqueue/river/internal/execution" + "github.com/riverqueue/river/internal/hooklookup" "github.com/riverqueue/river/internal/jobcompleter" "github.com/riverqueue/river/internal/jobstats" "github.com/riverqueue/river/internal/workunit" @@ -108,9 +109,11 @@ type JobExecutor struct { ClientRetryPolicy ClientRetryPolicy DefaultClientRetryPolicy ClientRetryPolicy ErrorHandler ErrorHandler + HookLookupByJob *hooklookup.JobHookLookup + HookLookupGlobal hooklookup.HookLookupInterface InformProducerDoneFunc func(jobRow *rivertype.JobRow) JobRow *rivertype.JobRow - GlobalMiddleware []rivertype.WorkerMiddleware + WorkerMiddleware []rivertype.WorkerMiddleware SchedulerInterval time.Duration WorkUnit workunit.WorkUnit @@ -184,6 +187,24 @@ func (e *JobExecutor) execute(ctx context.Context) (res *jobExecutorResult) { } doInner := execution.Func(func(ctx context.Context) error { + { + // TODO(brandur): This range clause and the one below it are + // identical, and it'd be nice to merge them together, but in such a + // way that doesn't require array allocation. I think we can do this + // using iterators after we drop support for Go 1.22. + for _, hook := range e.HookLookupGlobal.ByHookKind(hooklookup.HookKindWorkBegin) { + if err := hook.(rivertype.HookWorkBegin).WorkBegin(ctx, e.JobRow); err != nil { //nolint:forcetypeassert + return err + } + } + + for _, hook := range e.WorkUnit.HookLookup(e.HookLookupByJob).ByHookKind(hooklookup.HookKindWorkBegin) { + if err := hook.(rivertype.HookWorkBegin).WorkBegin(ctx, e.JobRow); err != nil { //nolint:forcetypeassert + return err + } + } + } + if err := e.WorkUnit.UnmarshalJob(); err != nil { return err } @@ -195,7 +216,7 @@ func (e *JobExecutor) execute(ctx context.Context) (res *jobExecutorResult) { return e.WorkUnit.Work(ctx) }) - executeFunc := execution.MiddlewareChain(e.GlobalMiddleware, e.WorkUnit.Middleware(), doInner, e.JobRow) + executeFunc := execution.MiddlewareChain(e.WorkerMiddleware, e.WorkUnit.Middleware(), doInner, e.JobRow) return &jobExecutorResult{Err: executeFunc(ctx), MetadataUpdates: metadataUpdates} } diff --git a/internal/jobexecutor/job_executor_test.go b/internal/jobexecutor/job_executor_test.go index 6f08b151..788e5364 100644 --- a/internal/jobexecutor/job_executor_test.go +++ b/internal/jobexecutor/job_executor_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/riverqueue/river/internal/hooklookup" "github.com/riverqueue/river/internal/jobcompleter" "github.com/riverqueue/river/internal/rivercommon" "github.com/riverqueue/river/internal/riverinternaltest" @@ -33,6 +34,10 @@ type customizableWorkUnit struct { work func() error } +func (w *customizableWorkUnit) HookLookup(lookup *hooklookup.JobHookLookup) hooklookup.HookLookupInterface { + return hooklookup.NewHookLookup(nil) +} + func (w *customizableWorkUnit) Middleware() []rivertype.WorkerMiddleware { return w.middleware } @@ -176,6 +181,8 @@ func TestJobExecutor_Execute(t *testing.T) { Completer: bundle.completer, DefaultClientRetryPolicy: &retrypolicytest.RetryPolicyNoJitter{}, ErrorHandler: bundle.errorHandler, + HookLookupByJob: hooklookup.NewJobHookLookup(), + HookLookupGlobal: hooklookup.NewHookLookup(nil), InformProducerDoneFunc: func(job *rivertype.JobRow) {}, JobRow: bundle.jobRow, SchedulerInterval: riverinternaltest.SchedulerShortInterval, @@ -575,7 +582,7 @@ func TestJobExecutor_Execute(t *testing.T) { executor, bundle := setup(t) // Add a middleware so we can verify it's in the trace too: - executor.GlobalMiddleware = []rivertype.WorkerMiddleware{ + executor.WorkerMiddleware = []rivertype.WorkerMiddleware{ &testMiddleware{ work: func(ctx context.Context, job *rivertype.JobRow, next func(context.Context) error) error { return next(ctx) diff --git a/internal/maintenance/job_rescuer_test.go b/internal/maintenance/job_rescuer_test.go index 9de1bda5..0390276d 100644 --- a/internal/maintenance/job_rescuer_test.go +++ b/internal/maintenance/job_rescuer_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/riverqueue/river/internal/hooklookup" "github.com/riverqueue/river/internal/riverinternaltest" "github.com/riverqueue/river/internal/workunit" "github.com/riverqueue/river/riverdriver" @@ -38,6 +39,9 @@ type callbackWorkUnit struct { timeout time.Duration // defaults to 0, which signals default timeout } +func (w *callbackWorkUnit) HookLookup(cache *hooklookup.JobHookLookup) hooklookup.HookLookupInterface { + return nil +} func (w *callbackWorkUnit) Middleware() []rivertype.WorkerMiddleware { return nil } func (w *callbackWorkUnit) NextRetry() time.Time { return time.Now().Add(30 * time.Second) } func (w *callbackWorkUnit) Timeout() time.Duration { return w.timeout } diff --git a/internal/workunit/work_unit.go b/internal/workunit/work_unit.go index fc1f541d..dbc05345 100644 --- a/internal/workunit/work_unit.go +++ b/internal/workunit/work_unit.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/riverqueue/river/internal/hooklookup" "github.com/riverqueue/river/rivertype" ) @@ -15,6 +16,11 @@ import ( // // Implemented by river.wrapperWorkUnit. type WorkUnit interface { + // HookLookup procures the a hook lookup bundle for the wrapped job using + // the given job hook lookup bundle. Hooks are looked up by job args and + // otherwise not available to jobexecutor. + HookLookup(lookup *hooklookup.JobHookLookup) hooklookup.HookLookupInterface + Middleware() []rivertype.WorkerMiddleware NextRetry() time.Time Timeout() time.Duration diff --git a/job.go b/job.go index 0954fec8..0868abef 100644 --- a/job.go +++ b/job.go @@ -24,6 +24,19 @@ type JobArgs interface { Kind() string } +type JobArgsWithHooks interface { + // Hooks returns specific hooks to run for this job type. These will run + // after the global hooks configured on the client. + // + // Warning: Hooks returned should be based on the job type only and be + // invariant of the specific contents of a job. Hooks are extracted by + // instantiating a generic instance of the job even when a specific instance + // is available, so any conditional logic within will be ignored. This is + // done because although specific job information may be available in some + // hook contexts like on InsertBegin, it won't be in others like WorkBegin. + Hooks() []rivertype.Hook +} + // JobArgsWithInsertOpts is an extra interface that a job may implement on top // of JobArgs to provide insertion-time options for all jobs of this type. type JobArgsWithInsertOpts interface { diff --git a/producer.go b/producer.go index a0d98332..8700fb24 100644 --- a/producer.go +++ b/producer.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "time" + "github.com/riverqueue/river/internal/hooklookup" "github.com/riverqueue/river/internal/jobcompleter" "github.com/riverqueue/river/internal/jobexecutor" "github.com/riverqueue/river/internal/notifier" @@ -64,7 +65,8 @@ type producerConfig struct { // LISTEN/NOTIFY, but this provides a fallback. FetchPollInterval time.Duration - GlobalMiddleware []rivertype.WorkerMiddleware + HookLookupByJob *hooklookup.JobHookLookup + HookLookupGlobal hooklookup.HookLookupInterface JobTimeout time.Duration MaxWorkers int @@ -87,6 +89,7 @@ type producerConfig struct { RetryPolicy ClientRetryPolicy SchedulerInterval time.Duration Workers *Workers + WorkerMiddleware []rivertype.WorkerMiddleware } func (c *producerConfig) mustValidate() *producerConfig { @@ -616,10 +619,12 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype. Completer: p.completer, DefaultClientRetryPolicy: &DefaultClientRetryPolicy{}, ErrorHandler: p.errorHandler, + HookLookupByJob: p.config.HookLookupByJob, + HookLookupGlobal: p.config.HookLookupGlobal, InformProducerDoneFunc: p.handleWorkerDone, - GlobalMiddleware: p.config.GlobalMiddleware, JobRow: job, SchedulerInterval: p.config.SchedulerInterval, + WorkerMiddleware: p.config.WorkerMiddleware, WorkUnit: workUnit, }) p.addActiveJob(job.ID, executor) diff --git a/producer_test.go b/producer_test.go index 3be9b8c8..bd7df759 100644 --- a/producer_test.go +++ b/producer_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/riverqueue/river/internal/hooklookup" "github.com/riverqueue/river/internal/jobcompleter" "github.com/riverqueue/river/internal/maintenance" "github.com/riverqueue/river/internal/notifier" @@ -90,6 +91,8 @@ func Test_Producer_CanSafelyCompleteJobsWhileFetchingNewOnes(t *testing.T) { // Fetch constantly to more aggressively trigger the potential data race: FetchCooldown: time.Millisecond, FetchPollInterval: time.Millisecond, + HookLookupByJob: hooklookup.NewJobHookLookup(), + HookLookupGlobal: hooklookup.NewHookLookup(nil), JobTimeout: JobTimeoutDefault, MaxWorkers: 1000, Notifier: notifier, @@ -176,6 +179,8 @@ func TestProducer_PollOnly(t *testing.T) { ErrorHandler: newTestErrorHandler(), FetchCooldown: FetchCooldownDefault, FetchPollInterval: 50 * time.Millisecond, // more aggressive than normal because we have no notifier + HookLookupByJob: hooklookup.NewJobHookLookup(), + HookLookupGlobal: hooklookup.NewHookLookup(nil), JobTimeout: JobTimeoutDefault, MaxWorkers: 1_000, Notifier: nil, // no notifier @@ -222,6 +227,8 @@ func TestProducer_WithNotifier(t *testing.T) { ErrorHandler: newTestErrorHandler(), FetchCooldown: FetchCooldownDefault, FetchPollInterval: 50 * time.Millisecond, // more aggressive than normal so in case we miss the event, tests still pass quickly + HookLookupByJob: hooklookup.NewJobHookLookup(), + HookLookupGlobal: hooklookup.NewHookLookup(nil), JobTimeout: JobTimeoutDefault, MaxWorkers: 1_000, Notifier: notifier, diff --git a/rivertest/worker.go b/rivertest/worker.go index 1b18169f..d1254fe1 100644 --- a/rivertest/worker.go +++ b/rivertest/worker.go @@ -10,6 +10,7 @@ import ( "github.com/riverqueue/river" "github.com/riverqueue/river/internal/execution" + "github.com/riverqueue/river/internal/hooklookup" "github.com/riverqueue/river/internal/jobcompleter" "github.com/riverqueue/river/internal/jobexecutor" "github.com/riverqueue/river/internal/maintenance" @@ -191,7 +192,9 @@ func (w *Worker[T, TTx]) workJob(ctx context.Context, tb testing.TB, tx TTx, job }, }, InformProducerDoneFunc: func(job *rivertype.JobRow) { close(executionDone) }, - GlobalMiddleware: w.config.WorkerMiddleware, + HookLookupGlobal: hooklookup.NewHookLookup(w.config.Hooks), + HookLookupByJob: hooklookup.NewJobHookLookup(), + WorkerMiddleware: w.config.WorkerMiddleware, JobRow: job, SchedulerInterval: maintenance.JobSchedulerIntervalDefault, WorkUnit: workUnit, @@ -295,13 +298,17 @@ type wrapperWorkUnit[T river.JobArgs] struct { worker river.Worker[T] } -func (w *wrapperWorkUnit[T]) NextRetry() time.Time { return w.worker.NextRetry(w.job) } -func (w *wrapperWorkUnit[T]) Timeout() time.Duration { return w.worker.Timeout(w.job) } -func (w *wrapperWorkUnit[T]) Work(ctx context.Context) error { return w.worker.Work(ctx, w.job) } +func (w *wrapperWorkUnit[T]) HookLookup(lookup *hooklookup.JobHookLookup) hooklookup.HookLookupInterface { + var job T + return lookup.ByJobArgs(job) +} func (w *wrapperWorkUnit[T]) Middleware() []rivertype.WorkerMiddleware { return w.worker.Middleware(w.jobRow) } +func (w *wrapperWorkUnit[T]) NextRetry() time.Time { return w.worker.NextRetry(w.job) } +func (w *wrapperWorkUnit[T]) Timeout() time.Duration { return w.worker.Timeout(w.job) } +func (w *wrapperWorkUnit[T]) Work(ctx context.Context) error { return w.worker.Work(ctx, w.job) } func (w *wrapperWorkUnit[T]) UnmarshalJob() error { w.job = &river.Job[T]{ diff --git a/rivertype/river_type.go b/rivertype/river_type.go index bce3088f..a5863ce6 100644 --- a/rivertype/river_type.go +++ b/rivertype/river_type.go @@ -270,6 +270,55 @@ type JobInsertParams struct { UniqueStates byte } +// Hook is an arbitrary interface for a plugin "hook" which will execute some +// arbitrary code at a predefined step in the job lifecycle. +// +// This interface is left purposely non-specific. Hook structs should embed +// river.HookDefaults to inherit an IsHook implementation, then implement one +// of the more specific hook interfaces like HookInsertBegin or HookWorkBegin. A +// hook struct may also implement multiple specific hook interfaces which are +// logically related and benefit from being grouped together. +// +// Hooks differ from middleware in that they're invoked at a specific lifecycle +// phase, but finish immediately instead of wrapping an inner call like a +// middleware does. One of the main ramifications of this different is that a +// hook cannot modify context in any useful way to pass down into the stack. +// Like a normal function, any changes it makes to its context are discarded on +// return. +// +// All else equal, hooks should generally be preferred over middleware because +// they don't add anything to the call stack. Call stacks that get overly deep +// can become a bit of an operational nightmare because they get hard to read. +// +// In a language with more specific type capabilities, this should be a union +// type. In Go we implement it somewhat awkwardly so that we can get future +// extensibility, but also some typing guarantees to prevent misuse (i.e. if +// Hook was an empty interface, then any object could be passed as a hook, but +// having a single function to implement forces the caller to make some token +// motions in the direction of implementing hooks). +type Hook interface { + // IsHook is a sentinel function to check that a type is implementing Hook + // on purpose and not by accident (Hook would otherwise be an empty + // interface). Hooks should embed river.HookDefaults to pick up an + // implementation for this function automatically. + IsHook() bool +} + +// HookInsertBegin is an interface to a hook that runs before job insertion. +type HookInsertBegin interface { + Hook + + InsertBegin(ctx context.Context, params *JobInsertParams) error +} + +// HookWorkBegin is an interface to a hook that runs after a job has been locked +// for work and before it's worked. +type HookWorkBegin interface { + Hook + + WorkBegin(ctx context.Context, job *JobRow) error +} + // JobInsertMiddleware provides an interface for middleware that integrations can // use to encapsulate common logic around job insertion. // diff --git a/work_unit_wrapper.go b/work_unit_wrapper.go index 2c8e3915..d0fcabc2 100644 --- a/work_unit_wrapper.go +++ b/work_unit_wrapper.go @@ -5,6 +5,7 @@ import ( "encoding/json" "time" + "github.com/riverqueue/river/internal/hooklookup" "github.com/riverqueue/river/internal/workunit" "github.com/riverqueue/river/rivertype" ) @@ -25,13 +26,17 @@ type wrapperWorkUnit[T JobArgs] struct { worker Worker[T] } -func (w *wrapperWorkUnit[T]) NextRetry() time.Time { return w.worker.NextRetry(w.job) } -func (w *wrapperWorkUnit[T]) Timeout() time.Duration { return w.worker.Timeout(w.job) } -func (w *wrapperWorkUnit[T]) Work(ctx context.Context) error { return w.worker.Work(ctx, w.job) } +func (w *wrapperWorkUnit[T]) HookLookup(lookup *hooklookup.JobHookLookup) hooklookup.HookLookupInterface { + var job T + return lookup.ByJobArgs(job) +} func (w *wrapperWorkUnit[T]) Middleware() []rivertype.WorkerMiddleware { return w.worker.Middleware(w.jobRow) } +func (w *wrapperWorkUnit[T]) NextRetry() time.Time { return w.worker.NextRetry(w.job) } +func (w *wrapperWorkUnit[T]) Timeout() time.Duration { return w.worker.Timeout(w.job) } +func (w *wrapperWorkUnit[T]) Work(ctx context.Context) error { return w.worker.Work(ctx, w.job) } func (w *wrapperWorkUnit[T]) UnmarshalJob() error { w.job = &Job[T]{ diff --git a/worker.go b/worker.go index 54625d50..4e7131cc 100644 --- a/worker.go +++ b/worker.go @@ -74,6 +74,8 @@ type Worker[T JobArgs] interface { // struct to make it fulfill the Worker interface with default values. type WorkerDefaults[T JobArgs] struct{} +func (w WorkerDefaults[T]) Hooks(*rivertype.JobRow) []rivertype.Hook { return nil } + func (w WorkerDefaults[T]) Middleware(*rivertype.JobRow) []rivertype.WorkerMiddleware { return nil } // NextRetry returns an empty time.Time{} to avoid setting any job or