diff --git a/manager/orchestrator/global/global.go b/manager/orchestrator/global/global.go
index bced3dd189..d314678144 100644
--- a/manager/orchestrator/global/global.go
+++ b/manager/orchestrator/global/global.go
@@ -187,7 +187,7 @@ func (g *Orchestrator) FixTask(ctx context.Context, batch *store.Batch, t *api.T
 		node = g.nodes[t.NodeID]
 	}
 	// if the node no longer valid, remove the task
-	if t.NodeID == "" || orchestrator.InvalidNode(node) {
+	if node == nil || node.Spec.Availability == api.NodeAvailabilityDrain {
 		g.shutdownTask(ctx, batch, t)
 		return
 	}
@@ -295,9 +295,14 @@ func (g *Orchestrator) reconcileServices(ctx context.Context, serviceIDs []strin
 				continue
 			}
 
-			if node.Spec.Availability == api.NodeAvailabilityPause {
-				// the node is paused, so we won't add or update
-				// any tasks
+			if node.Spec.Availability == api.NodeAvailabilityPause ||
+				node.Status.State == api.NodeStatus_DOWN {
+				// The node is paused or down, so we
+				// won't add or update any tasks. When
+				// the node is unpaused or comes back
+				// up, it will trigger node
+				// reconciliation, correcting anything
+				// we might have skipped here.
				continue
			}
 
@@ -334,7 +339,7 @@ func (g *Orchestrator) reconcileServices(ctx context.Context, serviceIDs []strin
 
 // updateNode updates g.nodes based on the current node value
 func (g *Orchestrator) updateNode(node *api.Node) {
-	if node.Spec.Availability == api.NodeAvailabilityDrain || node.Status.State == api.NodeStatus_DOWN {
+	if node.Spec.Availability == api.NodeAvailabilityDrain {
 		delete(g.nodes, node.ID)
 	} else {
 		g.nodes[node.ID] = node
@@ -363,14 +368,12 @@ func (g *Orchestrator) reconcileOneNode(ctx context.Context, node *api.Node) {
 		return
 	}
 
-	if node.Status.State == api.NodeStatus_DOWN {
-		log.G(ctx).Debugf("global orchestrator: node %s is down, shutting down its tasks", node.ID)
-		g.foreachTaskFromNode(ctx, node, g.shutdownTask)
-		return
-	}
-
-	if node.Spec.Availability == api.NodeAvailabilityPause {
-		// the node is paused, so we won't add or update tasks
+	if node.Spec.Availability == api.NodeAvailabilityPause ||
+		node.Status.State == api.NodeStatus_DOWN {
+		// The node is paused or down, so we won't add or update any
+		// tasks. When the node is unpaused or comes back up, it will
+		// trigger node reconciliation, correcting anything we might
+		// have skipped here.
 		return
 	}
 
@@ -490,13 +493,12 @@ func (g *Orchestrator) tickTasks(ctx context.Context) {
 					return nil
 				}
 
-				if node.Spec.Availability == api.NodeAvailabilityPause ||
-					!constraint.NodeMatches(serviceEntry.constraints, node) {
+				if !constraint.NodeMatches(serviceEntry.constraints, node) {
 					t.DesiredState = api.TaskStateShutdown
 					return store.UpdateTask(tx, t)
 				}
 
-				return g.restarts.Restart(ctx, tx, g.cluster, service, *t)
+				return g.restarts.Restart(ctx, tx, g.cluster, service, *t, false)
 			})
 			if err != nil {
 				log.G(ctx).WithError(err).Errorf("orchestrator restartTask transaction failed")
diff --git a/manager/orchestrator/global/global_test.go b/manager/orchestrator/global/global_test.go
index 51f4eda376..9933b514be 100644
--- a/manager/orchestrator/global/global_test.go
+++ b/manager/orchestrator/global/global_test.go
@@ -266,16 +266,20 @@ func TestNodeState(t *testing.T) {
 	defer orchestrator.Stop()
 
 	testutils.WatchTaskCreate(t, watch)
+	testutils.Expect(t, watch, state.EventCommit{})
 
 	// set node1 to down
 	updateNodeState(t, store, node1, api.NodeStatus_DOWN)
-
-	// task should be set to dead
-	observedTask1 := testutils.WatchShutdownTask(t, watch)
-	assert.Equal(t, observedTask1.ServiceAnnotations.Name, "name1")
-	assert.Equal(t, observedTask1.NodeID, "nodeid1")
+	testutils.Expect(t, watch, api.EventUpdateNode{})
 	testutils.Expect(t, watch, state.EventCommit{})
 
+	// nothing should happen
+	select {
+	case event := <-watch:
+		t.Fatalf("got unexpected event %T: %+v", event, event)
+	case <-time.After(100 * time.Millisecond):
+	}
+
 	// updating the service shouldn't restart the task
 	updateService(t, store, service1)
 	testutils.Expect(t, watch, api.EventUpdateService{})
@@ -288,7 +292,7 @@ func TestNodeState(t *testing.T) {
 
 	// set node1 to ready
 	updateNodeState(t, store, node1, api.NodeStatus_READY)
-	// task should be added back
+	// task should be updated now
 	observedTask2 := testutils.WatchTaskCreate(t, watch)
 	assert.Equal(t, observedTask2.Status.State, api.TaskStateNew)
 	assert.Equal(t, observedTask2.ServiceAnnotations.Name, "name1")
@@ -414,9 +418,6 @@ func TestTaskFailure(t *testing.T) {
 	failTask(t, store, observedTask3)
 	testutils.Expect(t, watch, api.EventUpdateTask{})
 	testutils.Expect(t, watch, state.EventCommit{})
-	observedTask4 := testutils.WatchTaskUpdate(t, watch)
-	assert.Equal(t, observedTask4.DesiredState, api.TaskStateShutdown)
-	testutils.Expect(t, watch, state.EventCommit{})
 
 	// the task should not be recreated
 	select {
@@ -430,10 +431,14 @@ func TestTaskFailure(t *testing.T) {
 	testutils.Expect(t, watch, api.EventUpdateService{})
 	testutils.Expect(t, watch, state.EventCommit{})
 
-	observedTask5 := testutils.WatchTaskCreate(t, watch)
-	assert.Equal(t, observedTask5.Status.State, api.TaskStateNew)
-	assert.Equal(t, observedTask5.ServiceAnnotations.Name, "norestart")
-	assert.Equal(t, observedTask5.NodeID, "nodeid1")
+	observedTask4 := testutils.WatchTaskCreate(t, watch)
+	assert.Equal(t, observedTask4.Status.State, api.TaskStateNew)
+	assert.Equal(t, observedTask4.ServiceAnnotations.Name, "norestart")
+	assert.Equal(t, observedTask4.NodeID, "nodeid1")
+
+	// old task gets shut down as the new one is created
+	observedTask5 := testutils.WatchTaskUpdate(t, watch)
+	assert.Equal(t, observedTask5.DesiredState, api.TaskStateShutdown)
 
 	testutils.Expect(t, watch, state.EventCommit{})
 }
diff --git a/manager/orchestrator/replicated/drain_test.go b/manager/orchestrator/replicated/drain_test.go
index 58a76c4a7f..dcb62edcda 100644
--- a/manager/orchestrator/replicated/drain_test.go
+++ b/manager/orchestrator/replicated/drain_test.go
@@ -216,20 +216,17 @@ func TestDrain(t *testing.T) {
 		assert.NoError(t, orchestrator.Run(ctx))
 	}()
 
-	// id2 and id5 should be killed immediately
+	// id5 should be killed immediately
 	deletion1 := testutils.WatchShutdownTask(t, watch)
-	deletion2 := testutils.WatchShutdownTask(t, watch)
-	assert.Regexp(t, "id(2|5)", deletion1.ID)
-	assert.Regexp(t, "id(2|5)", deletion1.NodeID)
-	assert.Regexp(t, "id(2|5)", deletion2.ID)
-	assert.Regexp(t, "id(2|5)", deletion2.NodeID)
+	assert.Equal(t, "id5", deletion1.ID)
+	assert.Equal(t, "id5", deletion1.NodeID)
 
-	// Create a new task, assigned to node id2
+	// Create a new task, assigned to node id5
 	err = s.Update(func(tx store.Tx) error {
 		task := initialTaskSet[2].Copy()
 		task.ID = "newtask"
-		task.NodeID = "id2"
+		task.NodeID = "id5"
 		assert.NoError(t, store.CreateTask(tx, task))
 		return nil
 	})
@@ -237,7 +234,7 @@ func TestDrain(t *testing.T) {
 
 	deletion3 := testutils.WatchShutdownTask(t, watch)
 	assert.Equal(t, "newtask", deletion3.ID)
-	assert.Equal(t, "id2", deletion3.NodeID)
+	assert.Equal(t, "id5", deletion3.NodeID)
 
 	// Set node id4 to the DRAINED state
 	err = s.Update(func(tx store.Tx) error {
diff --git a/manager/orchestrator/replicated/replicated.go b/manager/orchestrator/replicated/replicated.go
index 18b8e24aba..9adb05a1ce 100644
--- a/manager/orchestrator/replicated/replicated.go
+++ b/manager/orchestrator/replicated/replicated.go
@@ -15,7 +15,12 @@ type Orchestrator struct {
 	store *store.MemoryStore
 
 	reconcileServices map[string]*api.Service
-	restartTasks      map[string]struct{}
+
+	// restartTasks' keys are tasks that need to have Restart called.
+	// The value is whether to force the desired state to shutdown
+	// even when the restart policy does not call for this, for example
+	// when a node is drained.
+	restartTasks map[string]bool
 
 	// stopChan signals to the state machine to stop running.
 	stopChan chan struct{}
@@ -37,7 +42,7 @@ func NewReplicatedOrchestrator(store *store.MemoryStore) *Orchestrator {
 		stopChan:          make(chan struct{}),
 		doneChan:          make(chan struct{}),
 		reconcileServices: make(map[string]*api.Service),
-		restartTasks:      make(map[string]struct{}),
+		restartTasks:      make(map[string]bool),
 		updater:           updater,
 		restarts:          restartSupervisor,
 	}
diff --git a/manager/orchestrator/replicated/restart_test.go b/manager/orchestrator/replicated/restart_test.go
index 2d621c36a8..345fbb3096 100644
--- a/manager/orchestrator/replicated/restart_test.go
+++ b/manager/orchestrator/replicated/restart_test.go
@@ -208,8 +208,6 @@ func TestOrchestratorRestartOnFailure(t *testing.T) {
 	testutils.Expect(t, watch, state.EventCommit{})
 	testutils.Expect(t, watch, api.EventUpdateTask{})
 	testutils.Expect(t, watch, state.EventCommit{})
-	testutils.Expect(t, watch, api.EventUpdateTask{})
-	testutils.Expect(t, watch, state.EventCommit{})
 
 	select {
 	case <-watch:
@@ -284,8 +282,6 @@ func TestOrchestratorRestartOnNone(t *testing.T) {
 	testutils.Expect(t, watch, state.EventCommit{})
 	testutils.Expect(t, watch, api.EventUpdateTask{})
 	testutils.Expect(t, watch, state.EventCommit{})
-	testutils.Expect(t, watch, api.EventUpdateTask{})
-	testutils.Expect(t, watch, state.EventCommit{})
 
 	select {
 	case <-watch:
@@ -303,8 +299,6 @@ func TestOrchestratorRestartOnNone(t *testing.T) {
 	assert.NoError(t, err)
 	testutils.Expect(t, watch, api.EventUpdateTask{})
 	testutils.Expect(t, watch, state.EventCommit{})
-	testutils.Expect(t, watch, api.EventUpdateTask{})
-	testutils.Expect(t, watch, state.EventCommit{})
 
 	select {
 	case <-watch:
@@ -414,7 +408,7 @@ func TestOrchestratorRestartMaxAttempts(t *testing.T) {
 	orchestrator := NewReplicatedOrchestrator(s)
 	defer orchestrator.Stop()
 
-	watch, cancel := state.Watch(s.WatchQueue() /*api.EventCreateTask{}, api.EventUpdateTask{}*/)
+	watch, cancel := state.Watch(s.WatchQueue(), api.EventCreateTask{}, api.EventUpdateTask{})
 	defer cancel()
 
 	// Create a service with two instances specified before the orchestrator is
@@ -457,16 +451,44 @@ func TestOrchestratorRestartMaxAttempts(t *testing.T) {
 		assert.NoError(t, orchestrator.Run(ctx))
 	}()
 
-	testRestart := func() {
+	testRestart := func(serviceUpdated bool) {
 		observedTask1 := testutils.WatchTaskCreate(t, watch)
 		assert.Equal(t, observedTask1.Status.State, api.TaskStateNew)
 		assert.Equal(t, observedTask1.ServiceAnnotations.Name, "name1")
+		if serviceUpdated {
+			shutdownTask := testutils.WatchTaskUpdate(t, watch)
+			runnableTask := testutils.WatchTaskUpdate(t, watch)
+
+			assert.Equal(t, api.TaskStateShutdown, shutdownTask.DesiredState)
+			err = s.Update(func(tx store.Tx) error {
+				task := shutdownTask.Copy()
+				task.Status.State = api.TaskStateShutdown
+				assert.NoError(t, store.UpdateTask(tx, task))
+				return nil
+			})
+			assert.NoError(t, err)
+
+			testutils.Expect(t, watch, api.EventUpdateTask{})
+
+			assert.Equal(t, api.TaskStateRunning, runnableTask.DesiredState)
+			err = s.Update(func(tx store.Tx) error {
+				task := runnableTask.Copy()
+				task.Status.State = api.TaskStateRunning
+				assert.NoError(t, store.UpdateTask(tx, task))
+				return nil
+			})
+			assert.NoError(t, err)
+
+			testutils.Expect(t, watch, api.EventUpdateTask{})
+		}
 
 		observedTask2 := testutils.WatchTaskCreate(t, watch)
 		assert.Equal(t, observedTask2.Status.State, api.TaskStateNew)
 		assert.Equal(t, observedTask2.ServiceAnnotations.Name, "name1")
-
-		testutils.Expect(t, watch, state.EventCommit{})
+		if serviceUpdated {
+			testutils.Expect(t, watch, api.EventUpdateTask{})
+			testutils.Expect(t, watch, api.EventUpdateTask{})
+		}
 
 		// Fail the first task. Confirm that it gets restarted.
 		updatedTask1 := observedTask1.Copy()
@@ -478,17 +500,14 @@ func TestOrchestratorRestartMaxAttempts(t *testing.T) {
 		})
 		assert.NoError(t, err)
 		testutils.Expect(t, watch, api.EventUpdateTask{})
-		testutils.Expect(t, watch, state.EventCommit{})
 		testutils.Expect(t, watch, api.EventUpdateTask{})
 
 		observedTask3 := testutils.WatchTaskCreate(t, watch)
-		testutils.Expect(t, watch, state.EventCommit{})
 		assert.Equal(t, observedTask3.Status.State, api.TaskStateNew)
 		assert.Equal(t, observedTask3.DesiredState, api.TaskStateReady)
 		assert.Equal(t, observedTask3.ServiceAnnotations.Name, "name1")
 
 		observedTask4 := testutils.WatchTaskUpdate(t, watch)
-		testutils.Expect(t, watch, state.EventCommit{})
 
 		after := time.Now()
 		// At least 100 ms should have elapsed. Only check the lower bound,
@@ -510,16 +529,13 @@ func TestOrchestratorRestartMaxAttempts(t *testing.T) {
 		})
 		assert.NoError(t, err)
 		testutils.Expect(t, watch, api.EventUpdateTask{})
-		testutils.Expect(t, watch, state.EventCommit{})
 		testutils.Expect(t, watch, api.EventUpdateTask{})
 
 		observedTask5 := testutils.WatchTaskCreate(t, watch)
-		testutils.Expect(t, watch, state.EventCommit{})
 		assert.Equal(t, observedTask5.Status.State, api.TaskStateNew)
 		assert.Equal(t, observedTask5.DesiredState, api.TaskStateReady)
 
 		observedTask6 := testutils.WatchTaskUpdate(t, watch) // task gets started after a delay
-		testutils.Expect(t, watch, state.EventCommit{})
 		assert.Equal(t, observedTask6.Status.State, api.TaskStateNew)
 		assert.Equal(t, observedTask6.DesiredState, api.TaskStateRunning)
 		assert.Equal(t, observedTask6.ServiceAnnotations.Name, "name1")
@@ -533,9 +549,6 @@ func TestOrchestratorRestartMaxAttempts(t *testing.T) {
 		})
 		assert.NoError(t, err)
 		testutils.Expect(t, watch, api.EventUpdateTask{})
-		testutils.Expect(t, watch, state.EventCommit{})
-		testutils.Expect(t, watch, api.EventUpdateTask{})
-		testutils.Expect(t, watch, state.EventCommit{})
 
 		select {
 		case <-watch:
@@ -552,9 +565,6 @@ func TestOrchestratorRestartMaxAttempts(t *testing.T) {
 		})
 		assert.NoError(t, err)
 		testutils.Expect(t, watch, api.EventUpdateTask{})
-		testutils.Expect(t, watch, state.EventCommit{})
-		testutils.Expect(t, watch, api.EventUpdateTask{})
-		testutils.Expect(t, watch, state.EventCommit{})
 
 		select {
 		case <-watch:
@@ -563,7 +573,7 @@ func TestOrchestratorRestartMaxAttempts(t *testing.T) {
 		}
 	}
 
-	testRestart()
+	testRestart(false)
 
 	// Update the service spec
 	err = s.Update(func(tx store.Tx) error {
@@ -576,7 +586,7 @@ func TestOrchestratorRestartMaxAttempts(t *testing.T) {
 	})
 	assert.NoError(t, err)
 
-	testRestart()
+	testRestart(true)
 }
 
 func TestOrchestratorRestartWindow(t *testing.T) {
@@ -704,8 +714,6 @@ func TestOrchestratorRestartWindow(t *testing.T) {
 	assert.NoError(t, err)
 	testutils.Expect(t, watch, api.EventUpdateTask{})
 	testutils.Expect(t, watch, state.EventCommit{})
-	testutils.Expect(t, watch, api.EventUpdateTask{})
-	testutils.Expect(t, watch, state.EventCommit{})
 
 	select {
 	case <-watch:
diff --git a/manager/orchestrator/replicated/tasks.go b/manager/orchestrator/replicated/tasks.go
index 66000e5d86..e61abba04c 100644
--- a/manager/orchestrator/replicated/tasks.go
+++ b/manager/orchestrator/replicated/tasks.go
@@ -22,7 +22,7 @@ func (r *Orchestrator) initTasks(ctx context.Context, readTx store.ReadTx) error
 func (r *Orchestrator) handleTaskEvent(ctx context.Context, event events.Event) {
 	switch v := event.(type) {
 	case api.EventDeleteNode:
-		r.restartTasksByNodeID(ctx, v.Node.ID)
+		r.restartTasksByNodeID(ctx, v.Node.ID, true)
 	case api.EventCreateNode:
 		r.handleNodeChange(ctx, v.Node)
 	case api.EventUpdateNode:
@@ -46,7 +46,7 @@ func (r *Orchestrator) handleTaskEvent(ctx context.Context, event events.Event)
 func (r *Orchestrator) tickTasks(ctx context.Context) {
 	if len(r.restartTasks) > 0 {
 		err := r.store.Batch(func(batch *store.Batch) error {
-			for taskID := range r.restartTasks {
+			for taskID, forceShutdownState := range r.restartTasks {
 				err := batch.Update(func(tx store.Tx) error {
 					// TODO(aaronl): optimistic update?
 					t := store.GetTask(tx, taskID)
@@ -61,7 +61,7 @@ func (r *Orchestrator) tickTasks(ctx context.Context) {
 					}
 
 					// Restart task if applicable
-					if err := r.restarts.Restart(ctx, tx, r.cluster, service, *t); err != nil {
+					if err := r.restarts.Restart(ctx, tx, r.cluster, service, *t, forceShutdownState); err != nil {
 						return err
 					}
 				}
@@ -78,11 +78,11 @@ func (r *Orchestrator) tickTasks(ctx context.Context) {
 			log.G(ctx).WithError(err).Errorf("orchestrator task removal batch failed")
 		}
 
-		r.restartTasks = make(map[string]struct{})
+		r.restartTasks = make(map[string]bool)
 	}
 }
 
-func (r *Orchestrator) restartTasksByNodeID(ctx context.Context, nodeID string) {
+func (r *Orchestrator) restartTasksByNodeID(ctx context.Context, nodeID string, forceShutdownState bool) {
 	var err error
 	r.store.View(func(tx store.ReadTx) {
 		var tasks []*api.Task
@@ -97,7 +97,7 @@ func (r *Orchestrator) restartTasksByNodeID(ctx context.Context, nodeID string)
 			}
 			service := store.GetService(tx, t.ServiceID)
 			if orchestrator.IsReplicatedService(service) {
-				r.restartTasks[t.ID] = struct{}{}
+				r.restartTasks[t.ID] = forceShutdownState
 			}
 		}
 	})
@@ -107,11 +107,11 @@ func (r *Orchestrator) restartTasksByNodeID(ctx context.Context, nodeID string)
 }
 
 func (r *Orchestrator) handleNodeChange(ctx context.Context, n *api.Node) {
-	if !orchestrator.InvalidNode(n) {
-		return
+	if n.Spec.Availability == api.NodeAvailabilityDrain {
+		r.restartTasksByNodeID(ctx, n.ID, true)
+	} else if n.Status.State == api.NodeStatus_DOWN {
+		r.restartTasksByNodeID(ctx, n.ID, false)
 	}
-
-	r.restartTasksByNodeID(ctx, n.ID)
 }
 
 // handleTaskChange defines what orchestrator does when a task is updated by agent.
@@ -123,25 +123,37 @@ func (r *Orchestrator) handleTaskChange(ctx context.Context, t *api.Task) {
 	}
 
 	var (
-		n       *api.Node
+		node    *api.Node
 		service *api.Service
 	)
 
 	r.store.View(func(tx store.ReadTx) {
 		if t.NodeID != "" {
-			n = store.GetNode(tx, t.NodeID)
+			node = store.GetNode(tx, t.NodeID)
 		}
 		if t.ServiceID != "" {
 			service = store.GetService(tx, t.ServiceID)
 		}
 	})
 
+	r.maybeRestartTask(t, service, node)
+}
+
+func (r *Orchestrator) maybeRestartTask(t *api.Task, service *api.Service, node *api.Node) {
 	if !orchestrator.IsReplicatedService(service) {
 		return
 	}
 
-	if t.Status.State > api.TaskStateRunning ||
-		(t.NodeID != "" && orchestrator.InvalidNode(n)) {
-		r.restartTasks[t.ID] = struct{}{}
+	if t.Status.State > api.TaskStateRunning {
+		r.restartTasks[t.ID] = false
+	}
+	if t.NodeID != "" {
+		if node == nil {
+			r.restartTasks[t.ID] = false
+		} else if node.Spec.Availability == api.NodeAvailabilityDrain {
+			r.restartTasks[t.ID] = true
+		} else if node.Status.State == api.NodeStatus_DOWN {
+			r.restartTasks[t.ID] = false
+		}
 	}
 }
@@ -155,12 +167,12 @@ func (r *Orchestrator) FixTask(ctx context.Context, batch *store.Batch, t *api.T
 	}
 
 	var (
-		n       *api.Node
+		node    *api.Node
 		service *api.Service
 	)
 	batch.Update(func(tx store.Tx) error {
 		if t.NodeID != "" {
-			n = store.GetNode(tx, t.NodeID)
+			node = store.GetNode(tx, t.NodeID)
 		}
 		if t.ServiceID != "" {
 			service = store.GetService(tx, t.ServiceID)
@@ -168,13 +180,5 @@ func (r *Orchestrator) FixTask(ctx context.Context, batch *store.Batch, t *api.T
 		return nil
 	})
 
-	if !orchestrator.IsReplicatedService(service) {
-		return
-	}
-
-	if t.Status.State > api.TaskStateRunning ||
-		(t.NodeID != "" && orchestrator.InvalidNode(n)) {
-		r.restartTasks[t.ID] = struct{}{}
-		return
-	}
+	r.maybeRestartTask(t, service, node)
 }
diff --git a/manager/orchestrator/restart/restart.go b/manager/orchestrator/restart/restart.go
index eed28f8202..1dec04e574 100644
--- a/manager/orchestrator/restart/restart.go
+++ b/manager/orchestrator/restart/restart.go
@@ -96,7 +96,7 @@ func (r *Supervisor) waitRestart(ctx context.Context, oldDelay *delayedStart, cl
 		if service == nil {
 			return nil
 		}
-		return r.Restart(ctx, tx, cluster, service, *t)
+		return r.Restart(ctx, tx, cluster, service, *t, false)
 	})
 
 	if err != nil {
@@ -106,7 +106,7 @@ func (r *Supervisor) waitRestart(ctx context.Context, oldDelay *delayedStart, cl
 
 // Restart initiates a new task to replace t if appropriate under the service's
 // restart policy.
-func (r *Supervisor) Restart(ctx context.Context, tx store.Tx, cluster *api.Cluster, service *api.Service, t api.Task) error {
+func (r *Supervisor) Restart(ctx context.Context, tx store.Tx, cluster *api.Cluster, service *api.Service, t api.Task, forceShutdownState bool) error {
 	// TODO(aluzzardi): This function should not depend on `service`.
 
 	// Is the old task still in the process of restarting? If so, wait for
@@ -131,6 +131,12 @@ func (r *Supervisor) Restart(ctx context.Context, tx store.Tx, cluster *api.Clus
 		return errors.New("Restart called on task that was already shut down")
 	}
 
+	shouldRestart := r.shouldRestart(ctx, &t, service)
+
+	if !forceShutdownState && !shouldRestart {
+		return nil
+	}
+
 	t.DesiredState = api.TaskStateShutdown
 	err := store.UpdateTask(tx, &t)
 	if err != nil {
@@ -138,23 +144,27 @@ func (r *Supervisor) Restart(ctx context.Context, tx store.Tx, cluster *api.Clus
 		return err
 	}
 
-	if !r.shouldRestart(ctx, &t, service) {
+	if !shouldRestart {
 		return nil
 	}
 
+	n := store.GetNode(tx, t.NodeID)
+
 	var restartTask *api.Task
 	if orchestrator.IsReplicatedService(service) {
 		restartTask = orchestrator.NewTask(cluster, service, t.Slot, "")
 	} else if orchestrator.IsGlobalService(service) {
+		if n != nil && (n.Status.State == api.NodeStatus_DOWN || n.Spec.Availability == api.NodeAvailabilityPause) {
+			// We don't restart global service tasks on a node that's down or paused
+			return nil
+		}
 		restartTask = orchestrator.NewTask(cluster, service, 0, t.NodeID)
 	} else {
 		log.G(ctx).Error("service not supported by restart supervisor")
 		return nil
 	}
 
-	n := store.GetNode(tx, t.NodeID)
-
 	restartTask.DesiredState = api.TaskStateReady
 
 	var restartDelay time.Duration
diff --git a/manager/orchestrator/task.go b/manager/orchestrator/task.go
index 32a22d5f5a..c75755827d 100644
--- a/manager/orchestrator/task.go
+++ b/manager/orchestrator/task.go
@@ -70,10 +70,3 @@ func IsTaskDirty(s *api.Service, t *api.Task) bool {
 	return !reflect.DeepEqual(s.Spec.Task, t.Spec) ||
 		(t.Endpoint != nil && !reflect.DeepEqual(s.Spec.Endpoint, t.Endpoint.Spec))
 }
-
-// InvalidNode is true if the node is nil, down, or drained
-func InvalidNode(n *api.Node) bool {
-	return n == nil ||
-		n.Status.State == api.NodeStatus_DOWN ||
-		n.Spec.Availability == api.NodeAvailabilityDrain
-}
diff --git a/manager/orchestrator/update/updater.go b/manager/orchestrator/update/updater.go
index 349bccabb9..cff5202d21 100644
--- a/manager/orchestrator/update/updater.go
+++ b/manager/orchestrator/update/updater.go
@@ -584,28 +584,70 @@ func (u *Updater) rollbackUpdate(ctx context.Context, serviceID, message string)
 	log.G(ctx).Debugf("starting rollback of service %s", serviceID)
 
 	var service *api.Service
-	err := u.store.Update(func(tx store.Tx) error {
-		service = store.GetService(tx, serviceID)
-		if service == nil {
-			return nil
+	err := u.store.Batch(func(batch *store.Batch) error {
+		var serviceTasks []*api.Task
+
+		err := batch.Update(func(tx store.Tx) error {
+			var err error
+			serviceTasks, err = store.FindTasks(tx, store.ByServiceID(serviceID))
+			return err
+		})
+		if err != nil {
+			return err
 		}
-		if service.UpdateStatus == nil {
-			// The service was updated since we started this update
-			return nil
+
+		// Shut down any failed tasks. This ensures that the rollback
+		// will bring the service back to a converged state, instead of
+		// skipping tasks failed before the update and subsequent
+		// rollback.
+		for _, task := range serviceTasks {
+			if task.DesiredState <= api.TaskStateRunning && task.Status.State > api.TaskStateRunning {
+				err = batch.Update(func(tx store.Tx) error {
+					task.DesiredState = api.TaskStateShutdown
+					return store.UpdateTask(tx, task)
+				})
+				if err != nil {
+					return err
+				}
+			}
 		}
 
-		service.UpdateStatus.State = api.UpdateStatus_ROLLBACK_STARTED
-		service.UpdateStatus.Message = message
+		err = batch.Update(func(tx store.Tx) error {
+			service = store.GetService(tx, serviceID)
+			if service == nil {
+				return nil
+			}
+			if service.UpdateStatus == nil {
+				// The service was updated since we started this update
+				return nil
+			}
+
+			service.UpdateStatus.State = api.UpdateStatus_ROLLBACK_STARTED
+			service.UpdateStatus.Message = message
+
+			if service.PreviousSpec == nil {
+				return errors.New("cannot roll back service because no previous spec is available")
+			}
+
+			var err error
+			serviceTasks, err = store.FindTasks(tx, store.ByServiceID(serviceID))
+			if err != nil {
+				return err
+			}
+
+			service.Spec = *service.PreviousSpec
+			service.SpecVersion = service.PreviousSpecVersion.Copy()
+			service.PreviousSpec = nil
+			service.PreviousSpecVersion = nil
 
-		if service.PreviousSpec == nil {
-			return errors.New("cannot roll back service because no previous spec is available")
+			return store.UpdateService(tx, service)
+
+		})
+		if err != nil {
+			return err
 		}
 
-		service.Spec = *service.PreviousSpec
-		service.SpecVersion = service.PreviousSpecVersion.Copy()
-		service.PreviousSpec = nil
-		service.PreviousSpecVersion = nil
-		return store.UpdateService(tx, service)
+		return nil
 	})
 	if err != nil {