diff --git a/manager/dispatcher/dispatcher.go b/manager/dispatcher/dispatcher.go index c6530fad76..5e2754ccc5 100644 --- a/manager/dispatcher/dispatcher.go +++ b/manager/dispatcher/dispatcher.go @@ -133,7 +133,8 @@ type Dispatcher struct { // has finished initializing the dispatcher. wg sync.WaitGroup // This RWMutex synchronizes RPC handlers and the dispatcher stop(). - // The RPC handlers use the read lock while stop() uses the write lock + // Used to serialize read-write access to the dispatcher context. + // Also, the RPC handlers use the read lock while stop() uses the write lock // and acts as a barrier to shutdown. rpcRW sync.RWMutex nodes *nodeStore @@ -224,11 +225,11 @@ func (d *Dispatcher) Run(ctx context.Context) error { d.nodeUpdates = make(map[string]nodeUpdate) d.nodeUpdatesLock.Unlock() - d.mu.Lock() - if d.isRunning() { - d.mu.Unlock() + if _, err := d.context(); err == nil { return errors.New("dispatcher is already running") } + + d.mu.Lock() if err := d.markNodesUnknown(ctx); err != nil { log.G(ctx).Errorf(`failed to move all nodes to "unknown" state: %v`, err) } @@ -265,12 +266,15 @@ func (d *Dispatcher) Run(ctx context.Context) error { d.lastSeenManagers = getWeightedPeers(d.cluster) defer cancel() - d.ctx, d.cancel = context.WithCancel(ctx) - ctx = d.ctx d.wg.Add(1) defer d.wg.Done() d.mu.Unlock() + d.rpcRW.Lock() + d.ctx, d.cancel = context.WithCancel(ctx) + ctx = d.ctx + d.rpcRW.Unlock() + publishManagers := func(peers []*api.Peer) { var mgrs []*api.WeightedPeer for _, p := range peers { @@ -325,7 +329,11 @@ func (d *Dispatcher) Run(ctx context.Context) error { // Stop stops dispatcher and closes all grpc streams. func (d *Dispatcher) Stop() error { - d.mu.Lock() + // RPCs that start after rpcRW.Unlock() should find the context + // cancelled and should fail organically. + // NOTE(review): RPC handlers hold rpcRW.RLock() until they return; confirm + // they can finish before d.cancel() runs, or this Lock() may block forever. 
+ d.rpcRW.Lock() if !d.isRunning() { - d.mu.Unlock() + d.rpcRW.Unlock() return errors.New("dispatcher is already stopped") @@ -333,14 +341,10 @@ func (d *Dispatcher) Stop() error { log := log.G(d.ctx).WithField("method", "(*Dispatcher).Stop") log.Info("dispatcher stopping") - d.cancel() - d.mu.Unlock() // The active nodes list can be cleaned out only when all // existing RPCs have finished. - // RPCs that start after rpcRW.Unlock() should find the context - // cancelled and should fail organically. - d.rpcRW.Lock() + d.cancel() d.nodes.Clean() d.downNodes.Clean() d.rpcRW.Unlock() @@ -364,14 +368,16 @@ func (d *Dispatcher) Stop() error { return nil } -func (d *Dispatcher) isRunningLocked() (context.Context, error) { - d.mu.Lock() +// context returns the dispatcher context. +// NOTE(review): RPC handlers call this while already holding rpcRW.RLock(); +// sync.RWMutex prohibits recursive read-locking, so a pending Stop() can deadlock. +func (d *Dispatcher) context() (context.Context, error) { + d.rpcRW.RLock() + defer d.rpcRW.RUnlock() if !d.isRunning() { - d.mu.Unlock() return nil, status.Errorf(codes.Aborted, "dispatcher is stopped") } ctx := d.ctx - d.mu.Unlock() return ctx, nil } @@ -510,7 +516,7 @@ func nodeIPFromContext(ctx context.Context) (string, error) { func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, error) { logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register") // prevent register until we're ready to accept it - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return "", err } @@ -565,7 +571,7 @@ func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStat d.rpcRW.RLock() defer d.rpcRW.RUnlock() - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return nil, err } @@ -759,7 +765,7 @@ func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServe d.rpcRW.RLock() defer d.rpcRW.RUnlock() - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return err } @@ -885,7 +891,7 @@ func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream 
api.Dispatche d.rpcRW.RLock() defer d.rpcRW.RUnlock() - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return err } @@ -1080,7 +1086,7 @@ func (d *Dispatcher) moveTasksToOrphaned(nodeID string) error { func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, message string) error { logLocal := log.G(d.ctx).WithField("method", "(*Dispatcher).markNodeNotReady") - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return err } @@ -1190,7 +1196,7 @@ func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_Sessio d.rpcRW.RLock() defer d.rpcRW.RUnlock() - dctx, err := d.isRunningLocked() + dctx, err := d.context() if err != nil { return err }