diff --git a/manager/orchestrator/replicated/restart_test.go b/manager/orchestrator/replicated/restart_test.go index 7d2f0bd1cd..2d621c36a8 100644 --- a/manager/orchestrator/replicated/restart_test.go +++ b/manager/orchestrator/replicated/restart_test.go @@ -10,6 +10,7 @@ import ( "github.com/docker/swarmkit/manager/state/store" gogotypes "github.com/gogo/protobuf/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/context" ) @@ -442,6 +443,9 @@ func TestOrchestratorRestartMaxAttempts(t *testing.T) { }, }, }, + SpecVersion: &api.Version{ + Index: 1, + }, } assert.NoError(t, store.CreateService(tx, j1)) return nil @@ -453,89 +457,126 @@ func TestOrchestratorRestartMaxAttempts(t *testing.T) { assert.NoError(t, orchestrator.Run(ctx)) }() - observedTask1 := testutils.WatchTaskCreate(t, watch) - assert.Equal(t, observedTask1.Status.State, api.TaskStateNew) - assert.Equal(t, observedTask1.ServiceAnnotations.Name, "name1") - - observedTask2 := testutils.WatchTaskCreate(t, watch) - assert.Equal(t, observedTask2.Status.State, api.TaskStateNew) - assert.Equal(t, observedTask2.ServiceAnnotations.Name, "name1") - - // Fail the first task. Confirm that it gets restarted. 
- updatedTask1 := observedTask1.Copy() - updatedTask1.Status = api.TaskStatus{State: api.TaskStateFailed} - before := time.Now() - err = s.Update(func(tx store.Tx) error { - assert.NoError(t, store.UpdateTask(tx, updatedTask1)) - return nil - }) - assert.NoError(t, err) - testutils.Expect(t, watch, state.EventCommit{}) - testutils.Expect(t, watch, api.EventUpdateTask{}) - testutils.Expect(t, watch, state.EventCommit{}) - testutils.Expect(t, watch, api.EventUpdateTask{}) - - observedTask3 := testutils.WatchTaskCreate(t, watch) - testutils.Expect(t, watch, state.EventCommit{}) - assert.Equal(t, observedTask3.Status.State, api.TaskStateNew) - assert.Equal(t, observedTask3.DesiredState, api.TaskStateReady) - assert.Equal(t, observedTask3.ServiceAnnotations.Name, "name1") + testRestart := func() { + observedTask1 := testutils.WatchTaskCreate(t, watch) + assert.Equal(t, observedTask1.Status.State, api.TaskStateNew) + assert.Equal(t, observedTask1.ServiceAnnotations.Name, "name1") + + observedTask2 := testutils.WatchTaskCreate(t, watch) + assert.Equal(t, observedTask2.Status.State, api.TaskStateNew) + assert.Equal(t, observedTask2.ServiceAnnotations.Name, "name1") + + testutils.Expect(t, watch, state.EventCommit{}) + + // Fail the first task. Confirm that it gets restarted. 
+ updatedTask1 := observedTask1.Copy() + updatedTask1.Status = api.TaskStatus{State: api.TaskStateFailed} + before := time.Now() + err = s.Update(func(tx store.Tx) error { + assert.NoError(t, store.UpdateTask(tx, updatedTask1)) + return nil + }) + assert.NoError(t, err) + testutils.Expect(t, watch, api.EventUpdateTask{}) + testutils.Expect(t, watch, state.EventCommit{}) + testutils.Expect(t, watch, api.EventUpdateTask{}) + + observedTask3 := testutils.WatchTaskCreate(t, watch) + testutils.Expect(t, watch, state.EventCommit{}) + assert.Equal(t, observedTask3.Status.State, api.TaskStateNew) + assert.Equal(t, observedTask3.DesiredState, api.TaskStateReady) + assert.Equal(t, observedTask3.ServiceAnnotations.Name, "name1") + + observedTask4 := testutils.WatchTaskUpdate(t, watch) + testutils.Expect(t, watch, state.EventCommit{}) + after := time.Now() + + // At least 100 ms should have elapsed. Only check the lower bound, + // because the system may be slow and it could have taken longer. + if after.Sub(before) < 100*time.Millisecond { + t.Fatal("restart delay should have elapsed") + } - observedTask4 := testutils.WatchTaskUpdate(t, watch) - after := time.Now() + assert.Equal(t, observedTask4.Status.State, api.TaskStateNew) + assert.Equal(t, observedTask4.DesiredState, api.TaskStateRunning) + assert.Equal(t, observedTask4.ServiceAnnotations.Name, "name1") + + // Fail the second task. Confirm that it gets restarted. 
+ updatedTask2 := observedTask2.Copy() + updatedTask2.Status = api.TaskStatus{State: api.TaskStateFailed} + err = s.Update(func(tx store.Tx) error { + assert.NoError(t, store.UpdateTask(tx, updatedTask2)) + return nil + }) + assert.NoError(t, err) + testutils.Expect(t, watch, api.EventUpdateTask{}) + testutils.Expect(t, watch, state.EventCommit{}) + testutils.Expect(t, watch, api.EventUpdateTask{}) + + observedTask5 := testutils.WatchTaskCreate(t, watch) + testutils.Expect(t, watch, state.EventCommit{}) + assert.Equal(t, observedTask5.Status.State, api.TaskStateNew) + assert.Equal(t, observedTask5.DesiredState, api.TaskStateReady) + + observedTask6 := testutils.WatchTaskUpdate(t, watch) // task gets started after a delay + testutils.Expect(t, watch, state.EventCommit{}) + assert.Equal(t, observedTask6.Status.State, api.TaskStateNew) + assert.Equal(t, observedTask6.DesiredState, api.TaskStateRunning) + assert.Equal(t, observedTask6.ServiceAnnotations.Name, "name1") + + // Fail the first instance again. It should not be restarted. + updatedTask1 = observedTask3.Copy() + updatedTask1.Status = api.TaskStatus{State: api.TaskStateFailed} + err = s.Update(func(tx store.Tx) error { + assert.NoError(t, store.UpdateTask(tx, updatedTask1)) + return nil + }) + assert.NoError(t, err) + testutils.Expect(t, watch, api.EventUpdateTask{}) + testutils.Expect(t, watch, state.EventCommit{}) + testutils.Expect(t, watch, api.EventUpdateTask{}) + testutils.Expect(t, watch, state.EventCommit{}) + + select { + case <-watch: + t.Fatal("got unexpected event") + case <-time.After(200 * time.Millisecond): + } - // At least 100 ms should have elapsed. Only check the lower bound, - // because the system may be slow and it could have taken longer. - if after.Sub(before) < 100*time.Millisecond { - t.Fatal("restart delay should have elapsed") + // Fail the second instance again. It should not be restarted. 
+ updatedTask2 = observedTask5.Copy() + updatedTask2.Status = api.TaskStatus{State: api.TaskStateFailed} + err = s.Update(func(tx store.Tx) error { + assert.NoError(t, store.UpdateTask(tx, updatedTask2)) + return nil + }) + assert.NoError(t, err) + testutils.Expect(t, watch, api.EventUpdateTask{}) + testutils.Expect(t, watch, state.EventCommit{}) + testutils.Expect(t, watch, api.EventUpdateTask{}) + testutils.Expect(t, watch, state.EventCommit{}) + + select { + case <-watch: + t.Fatal("got unexpected event") + case <-time.After(200 * time.Millisecond): + } } - assert.Equal(t, observedTask4.Status.State, api.TaskStateNew) - assert.Equal(t, observedTask4.DesiredState, api.TaskStateRunning) - assert.Equal(t, observedTask4.ServiceAnnotations.Name, "name1") - - // Fail the second task. Confirm that it gets restarted. - updatedTask2 := observedTask2.Copy() - updatedTask2.Status = api.TaskStatus{State: api.TaskStateFailed} - err = s.Update(func(tx store.Tx) error { - assert.NoError(t, store.UpdateTask(tx, updatedTask2)) - return nil - }) - assert.NoError(t, err) - testutils.Expect(t, watch, state.EventCommit{}) - testutils.Expect(t, watch, api.EventUpdateTask{}) - testutils.Expect(t, watch, state.EventCommit{}) - testutils.Expect(t, watch, api.EventUpdateTask{}) - - observedTask5 := testutils.WatchTaskCreate(t, watch) - testutils.Expect(t, watch, state.EventCommit{}) - assert.Equal(t, observedTask5.Status.State, api.TaskStateNew) - assert.Equal(t, observedTask5.DesiredState, api.TaskStateReady) + testRestart() - observedTask6 := testutils.WatchTaskUpdate(t, watch) // task gets started after a delay - testutils.Expect(t, watch, state.EventCommit{}) - assert.Equal(t, observedTask6.Status.State, api.TaskStateNew) - assert.Equal(t, observedTask6.DesiredState, api.TaskStateRunning) - assert.Equal(t, observedTask6.ServiceAnnotations.Name, "name1") - - // Fail the first instance again. It should not be restarted. 
- updatedTask1 = observedTask3.Copy() - updatedTask1.Status = api.TaskStatus{State: api.TaskStateFailed} + // Update the service spec err = s.Update(func(tx store.Tx) error { - assert.NoError(t, store.UpdateTask(tx, updatedTask1)) + s := store.GetService(tx, "id1") + require.NotNil(t, s) + s.Spec.Task.GetContainer().Image = "newimage" + s.SpecVersion.Index = 2 + assert.NoError(t, store.UpdateService(tx, s)) return nil }) assert.NoError(t, err) - testutils.Expect(t, watch, api.EventUpdateTask{}) - testutils.Expect(t, watch, state.EventCommit{}) - testutils.Expect(t, watch, api.EventUpdateTask{}) - testutils.Expect(t, watch, state.EventCommit{}) - select { - case <-watch: - t.Fatal("got unexpected event") - case <-time.After(200 * time.Millisecond): - } + testRestart() } func TestOrchestratorRestartWindow(t *testing.T) { diff --git a/manager/orchestrator/restart/restart.go b/manager/orchestrator/restart/restart.go index 6167552d4a..eed28f8202 100644 --- a/manager/orchestrator/restart/restart.go +++ b/manager/orchestrator/restart/restart.go @@ -30,6 +30,13 @@ type instanceRestartInfo struct { // Restart.MaxAttempts and Restart.Window are both // nonzero. restartedInstances *list.List + // Why is specVersion in this structure and not in the map key? While + // putting it in the key would be a very simple solution, it wouldn't + // be easy to clean up map entries corresponding to old specVersions. + // Making the key version-agnostic and clearing the value whenever the + // version changes avoids the issue of stale map entries for old + // versions. 
+ specVersion api.Version } type delayedStart struct { @@ -54,8 +61,7 @@ type Supervisor struct { mu sync.Mutex store *store.MemoryStore delays map[string]*delayedStart - history map[instanceTuple]*instanceRestartInfo - historyByService map[string]map[instanceTuple]struct{} + historyByService map[string]map[instanceTuple]*instanceRestartInfo TaskTimeout time.Duration } @@ -64,8 +70,7 @@ func NewSupervisor(store *store.MemoryStore) *Supervisor { return &Supervisor{ store: store, delays: make(map[string]*delayedStart), - history: make(map[instanceTuple]*instanceRestartInfo), - historyByService: make(map[string]map[instanceTuple]struct{}), + historyByService: make(map[string]map[instanceTuple]*instanceRestartInfo), TaskTimeout: defaultOldTaskTimeout, } } @@ -214,8 +219,8 @@ func (r *Supervisor) shouldRestart(ctx context.Context, t *api.Task, service *ap r.mu.Lock() defer r.mu.Unlock() - restartInfo := r.history[instanceTuple] - if restartInfo == nil { + restartInfo := r.historyByService[t.ServiceID][instanceTuple] + if restartInfo == nil || (t.SpecVersion != nil && *t.SpecVersion != restartInfo.specVersion) { return true } @@ -268,17 +273,26 @@ func (r *Supervisor) recordRestartHistory(restartTask *api.Task) { r.mu.Lock() defer r.mu.Unlock() - if r.history[tuple] == nil { - r.history[tuple] = &instanceRestartInfo{} + if r.historyByService[restartTask.ServiceID] == nil { + r.historyByService[restartTask.ServiceID] = make(map[instanceTuple]*instanceRestartInfo) + } + if r.historyByService[restartTask.ServiceID][tuple] == nil { + r.historyByService[restartTask.ServiceID][tuple] = &instanceRestartInfo{} } - restartInfo := r.history[tuple] - restartInfo.totalRestarts++ + restartInfo := r.historyByService[restartTask.ServiceID][tuple] - if r.historyByService[restartTask.ServiceID] == nil { - r.historyByService[restartTask.ServiceID] = make(map[instanceTuple]struct{}) + if restartTask.SpecVersion != nil && *restartTask.SpecVersion != restartInfo.specVersion { + // This task 
has a different SpecVersion from the one we're
+		// tracking. Most likely, the service was updated. Past failures
+		// shouldn't count against the new service definition, so clear
+		// the history for this instance.
+		*restartInfo = instanceRestartInfo{
+			specVersion: *restartTask.SpecVersion,
+		}
 	}
-	r.historyByService[restartTask.ServiceID][tuple] = struct{}{}
+
+	restartInfo.totalRestarts++
 
 	if restartTask.Spec.Restart.Window != nil && (restartTask.Spec.Restart.Window.Seconds != 0 || restartTask.Spec.Restart.Window.Nanos != 0) {
 		if restartInfo.restartedInstances == nil {
@@ -432,16 +446,6 @@ func (r *Supervisor) CancelAll() {
 // ClearServiceHistory forgets restart history related to a given service ID.
 func (r *Supervisor) ClearServiceHistory(serviceID string) {
 	r.mu.Lock()
-	defer r.mu.Unlock()
-
-	tuples := r.historyByService[serviceID]
-	if tuples == nil {
-		return
-	}
-	delete(r.historyByService, serviceID)
-
-	for t := range tuples {
-		delete(r.history, t)
-	}
+	delete(r.historyByService, serviceID)
+	r.mu.Unlock()
 }