diff --git a/manager/dispatcher/dispatcher.go b/manager/dispatcher/dispatcher.go index 7032a1a889..04e39a7ded 100644 --- a/manager/dispatcher/dispatcher.go +++ b/manager/dispatcher/dispatcher.go @@ -113,12 +113,8 @@ type clusterUpdate struct { // Dispatcher is responsible for dispatching tasks and tracking agent health. type Dispatcher struct { - // mu is a lock to provide mutually exclusive access to dispatcher fields - // e.g. lastSeenManagers, networkBootstrapKeys, lastSeenRootCert etc. - mu sync.Mutex - // shutdownWait is used by stop() to wait for existing operations to finish. - shutdownWait sync.WaitGroup - + mu sync.Mutex + wg sync.WaitGroup nodes *nodeStore store *store.MemoryStore lastSeenManagers []*api.WeightedPeer @@ -235,11 +231,8 @@ func (d *Dispatcher) Run(ctx context.Context) error { defer cancel() d.ctx, d.cancel = context.WithCancel(ctx) ctx = d.ctx - - // If Stop() is called, it should wait - // for Run() to complete. - d.shutdownWait.Add(1) - defer d.shutdownWait.Done() + d.wg.Add(1) + defer d.wg.Done() d.mu.Unlock() publishManagers := func(peers []*api.Peer) { @@ -302,14 +295,11 @@ func (d *Dispatcher) Stop() error { return errors.New("dispatcher is already stopped") } - // Cancel dispatcher context. - // This should also close the the streams in Tasks(), Assignments(). + log := log.G(d.ctx).WithField("method", "(*Dispatcher).Stop") + log.Info("dispatcher stopping") d.cancel() d.mu.Unlock() - // Wait for the RPCs that are in-progress to finish. - d.shutdownWait.Wait() - d.nodes.Clean() d.processUpdatesLock.Lock() @@ -320,6 +310,9 @@ func (d *Dispatcher) Stop() error { d.processUpdatesLock.Unlock() d.clusterUpdateQueue.Close() + + d.wg.Wait() + return nil } @@ -465,6 +458,7 @@ func nodeIPFromContext(ctx context.Context) (string, error) { // register is used for registration of node with particular dispatcher. func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, error) { + // prevent register until we're ready to accept it dctx, err := d.isRunningLocked() if err != nil { return "", err @@ -516,21 +510,6 @@ func (d *Dispatcher) register(ctx context.Context, nodeID string, description *a // UpdateTaskStatus updates status of task. Node should send such updates // on every status change of its tasks. func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) { - // shutdownWait.Add() followed by isRunning() to ensures that - // if this rpc sees the dispatcher running, - // it will already have called Add() on the shutdownWait wait, - // which ensures that Stop() will wait for this rpc to complete. - // Note that Stop() first does Dispatcher.ctx.cancel() followed by - // shutdownWait.Wait() to make sure new rpc's don't start before waiting - // for existing ones to finish. - d.shutdownWait.Add(1) - defer d.shutdownWait.Done() - - dctx, err := d.isRunningLocked() - if err != nil { - return nil, err - } - nodeInfo, err := ca.RemoteNode(ctx) if err != nil { return nil, err @@ -546,6 +525,11 @@ func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStat } log := log.G(ctx).WithFields(fields) + dctx, err := d.isRunningLocked() + if err != nil { + return nil, err + } + if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil { return nil, err } @@ -704,26 +688,16 @@ func (d *Dispatcher) processUpdates(ctx context.Context) { // of tasks which should be run on node, if task is not present in that list, // it should be terminated. func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error { - // shutdownWait.Add() followed by isRunning() to ensures that - // if this rpc sees the dispatcher running, - // it will already have called Add() on the shutdownWait wait, - // which ensures that Stop() will wait for this rpc to complete. - // Note that Stop() first does Dispatcher.ctx.cancel() followed by - // shutdownWait.Wait() to make sure new rpc's don't start before waiting - // for existing ones to finish. - d.shutdownWait.Add(1) - defer d.shutdownWait.Done() - - dctx, err := d.isRunningLocked() + nodeInfo, err := ca.RemoteNode(stream.Context()) if err != nil { return err } + nodeID := nodeInfo.NodeID - nodeInfo, err := ca.RemoteNode(stream.Context()) + dctx, err := d.isRunningLocked() if err != nil { return err } - nodeID := nodeInfo.NodeID fields := logrus.Fields{ "node.id": nodeID, @@ -837,26 +811,16 @@ func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServe // Assignments is a stream of assignments for a node. Each message contains // either full list of tasks and secrets for the node, or an incremental update. func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatcher_AssignmentsServer) error { - // shutdownWait.Add() followed by isRunning() to ensures that - // if this rpc sees the dispatcher running, - // it will already have called Add() on the shutdownWait wait, - // which ensures that Stop() will wait for this rpc to complete. - // Note that Stop() first does Dispatcher.ctx.cancel() followed by - // shutdownWait.Wait() to make sure new rpc's don't start before waiting - // for existing ones to finish. - d.shutdownWait.Add(1) - defer d.shutdownWait.Done() - - dctx, err := d.isRunningLocked() + nodeInfo, err := ca.RemoteNode(stream.Context()) if err != nil { return err } + nodeID := nodeInfo.NodeID - nodeInfo, err := ca.RemoteNode(stream.Context()) + dctx, err := d.isRunningLocked() if err != nil { return err } - nodeID := nodeInfo.NodeID fields := logrus.Fields{ "node.id": nodeID, @@ -1099,24 +1063,6 @@ func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, mes // Node should send new heartbeat earlier than now + TTL, otherwise it will // be deregistered from dispatcher and its status will be updated to NodeStatus_DOWN func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) { - // shutdownWait.Add() followed by isRunning() to ensures that - // if this rpc sees the dispatcher running, - // it will already have called Add() on the shutdownWait wait, - // which ensures that Stop() will wait for this rpc to complete. - // Note that Stop() first does Dispatcher.ctx.cancel() followed by - // shutdownWait.Wait() to make sure new rpc's don't start before waiting - // for existing ones to finish. - d.shutdownWait.Add(1) - defer d.shutdownWait.Done() - - // isRunningLocked() is not needed since its OK if - // the dispatcher context is cancelled while this call is in progress - // since Stop() which cancels the dispatcher context will wait for - // Heartbeat() to complete. - if !d.isRunning() { - return nil, grpc.Errorf(codes.Aborted, "dispatcher is stopped") - } - nodeInfo, err := ca.RemoteNode(ctx) if err != nil { return nil, err @@ -1149,27 +1095,17 @@ func (d *Dispatcher) getRootCACert() []byte { // a special boolean field Disconnect which if true indicates that node should // reconnect to another Manager immediately. func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error { - // shutdownWait.Add() followed by isRunning() to ensures that - // if this rpc sees the dispatcher running, - // it will already have called Add() on the shutdownWait wait, - // which ensures that Stop() will wait for this rpc to complete. - // Note that Stop() first does Dispatcher.ctx.cancel() followed by - // shutdownWait.Wait() to make sure new rpc's don't start before waiting - // for existing ones to finish. - d.shutdownWait.Add(1) - defer d.shutdownWait.Done() - - dctx, err := d.isRunningLocked() + ctx := stream.Context() + nodeInfo, err := ca.RemoteNode(ctx) if err != nil { return err } + nodeID := nodeInfo.NodeID - ctx := stream.Context() - nodeInfo, err := ca.RemoteNode(ctx) + dctx, err := d.isRunningLocked() if err != nil { return err } - nodeID := nodeInfo.NodeID var sessionID string if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {