diff --git a/agent/agent_test.go b/agent/agent_test.go index bc74698ff7..9dc1a0c4cf 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -19,6 +19,7 @@ import ( "github.com/docker/swarmkit/ca" cautils "github.com/docker/swarmkit/ca/testutils" "github.com/docker/swarmkit/connectionbroker" + "github.com/docker/swarmkit/log" "github.com/docker/swarmkit/remotes" "github.com/docker/swarmkit/testutils" "github.com/docker/swarmkit/xnet" @@ -89,7 +90,7 @@ func TestAgentStartStop(t *testing.T) { require.NoError(t, err) assert.NotNil(t, agent) - ctx, _ := context.WithTimeout(context.Background(), 5000*time.Millisecond) + ctx, _ := context.WithTimeout(tc.Context, 5000*time.Millisecond) assert.Equal(t, errAgentNotStarted, agent.Stop(ctx)) assert.NoError(t, agent.Start(ctx)) @@ -403,12 +404,12 @@ func TestAgentExitsBasedOnSessionTracker(t *testing.T) { tracker := testSessionTracker{} tester.agent.config.SessionTracker = &tracker - go tester.agent.Start(context.Background()) - defer tester.agent.Stop(context.Background()) + go tester.agent.Start(tester.testCA.Context) + defer tester.agent.Stop(tester.testCA.Context) getErr := make(chan error) go func() { - getErr <- tester.agent.Err(context.Background()) + getErr <- tester.agent.Err(tester.testCA.Context) }() select { @@ -443,10 +444,19 @@ func TestAgentRegistersSessionsWithSessionTracker(t *testing.T) { defer tester.StartAgent(t)() - establishedSessions, errCounter, closeClounter := tracker.Stats() - require.Equal(t, establishedSessions, 1) + var establishedSessions, errCounter, closeCounter int + // poll because session tracker gets called after the ready channel is closed + // (so there may be edge cases where the stats are called before the session + // tracker is called) + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + establishedSessions, errCounter, closeCounter = tracker.Stats() + if establishedSessions != 1 { + return errors.New("sessiontracker hasn't been called yet") + } + return nil + }, 3*time.Millisecond)) require.Equal(t, errCounter, 0) - require.Equal(t, closeClounter, 0) + require.Equal(t, closeCounter, 0) currSession, closedSessions := tester.dispatcher.GetSessions() require.NotNil(t, currSession) require.Len(t, closedSessions, 0) @@ -462,11 +472,11 @@ type agentTester struct { } func (a *agentTester) StartAgent(t *testing.T) func() { - go a.agent.Start(context.Background()) + go a.agent.Start(a.testCA.Context) getErr := make(chan error) go func() { - getErr <- a.agent.Err(context.Background()) + getErr <- a.agent.Err(a.testCA.Context) }() select { case err := <-getErr: @@ -477,7 +487,7 @@ func (a *agentTester) StartAgent(t *testing.T) func() { } return func() { - a.agent.Stop(context.Background()) + a.agent.Stop(a.testCA.Context) } } @@ -485,6 +495,7 @@ func agentTestEnv(t *testing.T, nodeChangeCh chan *NodeChanges, tlsChangeCh chan var cleanup []func() tc := cautils.NewTestCA(t) cleanup = append(cleanup, tc.Stop) + tc.Context = log.WithLogger(tc.Context, log.G(tc.Context).WithField("localDispatcher", localDispatcher)) agentSecurityConfig, err := tc.NewNodeConfig(ca.WorkerRole) require.NoError(t, err)