From 51309e85cefe5d5a9f305ff3ab2410c0ab617e09 Mon Sep 17 00:00:00 2001 From: cyli Date: Fri, 21 Apr 2017 18:31:34 -0700 Subject: [PATCH] =?UTF-8?q?Connections=20to=20a=20local=20dispatcher=20can?= =?UTF-8?q?=E2=80=99t=20really=20be=20closed,=20so=20a=20session=20can?= =?UTF-8?q?=E2=80=99t=20really=20be=20restarted=20because=20closing=20a=20?= =?UTF-8?q?session=20just=20closes=20the=20connection.=20=20When=20this=20?= =?UTF-8?q?happens,=20it=20just=20starts=20up=20another=20session=20withou?= =?UTF-8?q?t=20closing=20the=20previous=20one.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since we need to restart a session to push new TLS data up to the dispatcher from the agent, change "closing" a session to mean first shutting down all the clients with a context.cancel before closing the connection. Signed-off-by: cyli --- agent/agent_test.go | 165 ++++++++++++++++++++++++++++++++++++--- agent/session.go | 18 ++++- agent/testutils/fakes.go | 52 ++++++++++-- node/node_test.go | 2 +- 4 files changed, 215 insertions(+), 22 deletions(-) diff --git a/agent/agent_test.go b/agent/agent_test.go index 7966ae5674..2f8ca0c4c8 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1,11 +1,18 @@ package agent import ( + "crypto/tls" "errors" "fmt" + "net" + "os" + "sync" "testing" "time" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + events "github.com/docker/go-events" agentutils "github.com/docker/swarmkit/agent/testutils" "github.com/docker/swarmkit/api" @@ -14,11 +21,27 @@ import ( "github.com/docker/swarmkit/connectionbroker" "github.com/docker/swarmkit/remotes" "github.com/docker/swarmkit/testutils" + "github.com/docker/swarmkit/xnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/context" ) +var localDispatcher = false + +// TestMain runs every test in this file twice - once with a local dispatcher, and +// once again with a remote dispatcher +func 
TestMain(m *testing.M) { + localDispatcher = false + dispatcherRPCTimeout = 500 * time.Millisecond + if status := m.Run(); status != 0 { + os.Exit(status) + } + + localDispatcher = true + os.Exit(m.Run()) +} + func TestAgent(t *testing.T) { // TODO(stevvooe): The current agent is fairly monolithic, making it hard // to test without implementing or mocking an entire master. We'd like to @@ -237,6 +260,13 @@ func TestSessionRestartedOnNodeDescriptionChange(t *testing.T) { require.Equal(t, "testAgent", gotSession.Description.Hostname) currSession = gotSession + // If nothing changes, the session is not re-established + tlsCh <- gotSession.Description.TLSInfo + time.Sleep(1 * time.Second) + gotSession, closedSessions = tester.dispatcher.GetSessions() + require.Equal(t, currSession, gotSession) + require.Len(t, closedSessions, 1) + newTLSInfo := &api.NodeTLSInfo{ TrustRoot: cautils.ECDSA256SHA256Cert, CertIssuerPublicKey: []byte("public key"), @@ -259,12 +289,71 @@ func TestSessionRestartedOnNodeDescriptionChange(t *testing.T) { require.Equal(t, newTLSInfo, gotSession.Description.TLSInfo) } +// If the dispatcher returns an error, if it times out, or if it's unreachable, no matter +// what, the agent attempts to reconnect and rebuild a new session. 
+func TestSessionReconnectsIfDispatcherErrors(t *testing.T) { + tlsCh := make(chan events.Event, 1) + defer close(tlsCh) + + tester := agentTestEnv(t, nil, tlsCh) + defer tester.cleanup() + + // create a second dispatcher we can fall back on + anotherConfig, err := tester.testCA.NewNodeConfig(ca.ManagerRole) + require.NoError(t, err) + anotherDispatcher, stop := agentutils.NewMockDispatcher(t, anotherConfig, false) // this one is not local, because the other one may be + defer stop() + + var counter int + anotherDispatcher.SetSessionHandler(func(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error { + if counter == 0 { + counter++ + return errors.New("terminate immediately") + } + // hang forever until the other side cancels, and then set the session to nil so we use the default one + defer anotherDispatcher.SetSessionHandler(nil) + <-stream.Context().Done() + return stream.Context().Err() + }) + + // ok, agent should have connected to the first dispatcher by now - if it has, kill the first dispatcher and ensure + // the agent connects to the second one + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + gotSession, closedSessions := tester.dispatcher.GetSessions() + if gotSession == nil { + return errors.New("no current session") + } + if len(closedSessions) != 0 { + return fmt.Errorf("expecting 0 closed sessions, got %d", len(closedSessions)) + } + return nil + }, 2*time.Second)) + tester.stopDispatcher() + tester.remotes.setPeer(api.Peer{Addr: anotherDispatcher.Addr}) + tester.agent.config.ConnBroker.SetLocalConn(nil) + + // It should have connected with the second dispatcher 3 times - first because the first dispatcher died, + // second because the dispatcher returned an error, third time because the session timed out. So there should + // be 2 closed sessions. 
+ require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + gotSession, closedSessions := anotherDispatcher.GetSessions() + if gotSession == nil { + return errors.New("no current session") + } + if len(closedSessions) != 2 { + return fmt.Errorf("expecting 2 closed sessions, got %d", len(closedSessions)) + } + return nil + }, 5*time.Second)) +} + type agentTester struct { - agent *Agent - dispatcher *agentutils.MockDispatcher - executor *agentutils.TestExecutor - cleanup func() - testCA *cautils.TestCA + agent *Agent + dispatcher *agentutils.MockDispatcher + executor *agentutils.TestExecutor + stopDispatcher, cleanup func() + testCA *cautils.TestCA + remotes *fakeRemotes } func agentTestEnv(t *testing.T, nodeChangeCh chan *NodeChanges, tlsChangeCh chan events.Event) *agentTester { @@ -277,10 +366,28 @@ func agentTestEnv(t *testing.T, nodeChangeCh chan *NodeChanges, tlsChangeCh chan managerSecurityConfig, err := tc.NewNodeConfig(ca.ManagerRole) require.NoError(t, err) - mockDispatcher, mockDispatcherStop := agentutils.NewMockDispatcher(t, managerSecurityConfig) + mockDispatcher, mockDispatcherStop := agentutils.NewMockDispatcher(t, managerSecurityConfig, localDispatcher) cleanup = append(cleanup, mockDispatcherStop) - remotes := remotes.NewRemotes(api.Peer{Addr: mockDispatcher.Addr}) + fr := &fakeRemotes{} + broker := connectionbroker.New(fr) + if localDispatcher { + insecureCreds := credentials.NewTLS(&tls.Config{InsecureSkipVerify: true}) + conn, err := grpc.Dial( + mockDispatcher.Addr, + grpc.WithTransportCredentials(insecureCreds), + grpc.WithDialer( + func(addr string, timeout time.Duration) (net.Conn, error) { + return xnet.DialTimeoutLocal(addr, timeout) + }), + ) + require.NoError(t, err) + cleanup = append(cleanup, func() { conn.Close() }) + + broker.SetLocalConn(conn) + } else { + fr.setPeer(api.Peer{Addr: mockDispatcher.Addr}) + } db, cleanupStorage := storageTestEnv(t) cleanup = append(cleanup, func() { cleanupStorage() }) @@ -289,7 +396,7 @@ 
func agentTestEnv(t *testing.T, nodeChangeCh chan *NodeChanges, tlsChangeCh chan agent, err := New(&Config{ Executor: executor, - ConnBroker: connectionbroker.New(remotes), + ConnBroker: broker, Credentials: agentSecurityConfig.ClientTLSCreds, DB: db, NotifyNodeChange: nodeChangeCh, @@ -321,15 +428,49 @@ func agentTestEnv(t *testing.T, nodeChangeCh chan *NodeChanges, tlsChangeCh chan } return &agentTester{ - agent: agent, - dispatcher: mockDispatcher, - executor: executor, - testCA: tc, + agent: agent, + dispatcher: mockDispatcher, + stopDispatcher: mockDispatcherStop, + executor: executor, + testCA: tc, cleanup: func() { // go in reverse order for i := len(cleanup) - 1; i >= 0; i-- { cleanup[i]() } }, + remotes: fr, } } + +// fakeRemotes is a Remotes interface that just always selects the current remote until +// it is switched out +type fakeRemotes struct { + mu sync.Mutex + peer api.Peer +} + +func (f *fakeRemotes) Weights() map[api.Peer]int { + f.mu.Lock() + defer f.mu.Unlock() + return map[api.Peer]int{f.peer: 1} +} + +func (f *fakeRemotes) Select(...string) (api.Peer, error) { + f.mu.Lock() + defer f.mu.Unlock() + return f.peer, nil +} + +// do nothing +func (f *fakeRemotes) Observe(peer api.Peer, weight int) {} +func (f *fakeRemotes) ObserveIfExists(peer api.Peer, weight int) {} +func (f *fakeRemotes) Remove(addrs ...api.Peer) {} + +func (f *fakeRemotes) setPeer(p api.Peer) { + f.mu.Lock() + f.peer = p + f.mu.Unlock() +} + +var _ remotes.Remotes = &fakeRemotes{} diff --git a/agent/session.go b/agent/session.go index e15714a705..36a3a375cb 100644 --- a/agent/session.go +++ b/agent/session.go @@ -14,9 +14,8 @@ import ( "google.golang.org/grpc/codes" ) -const dispatcherRPCTimeout = 5 * time.Second - var ( + dispatcherRPCTimeout = 5 * time.Second errSessionDisconnect = errors.New("agent: session disconnect") // instructed to disconnect errSessionClosed = errors.New("agent: session closed") ) @@ -39,12 +38,14 @@ type session struct { assignments chan 
*api.AssignmentsMessage subscriptions chan *api.SubscriptionMessage + cancel func() // this is assumed to be never nil, and set whenever a session is created registered chan struct{} // closed registration closed chan struct{} closeOnce sync.Once } func newSession(ctx context.Context, agent *Agent, delay time.Duration, sessionID string, description *api.NodeDescription) *session { + sessionCtx, sessionCancel := context.WithCancel(ctx) s := &session{ agent: agent, sessionID: sessionID, @@ -54,6 +55,7 @@ func newSession(ctx context.Context, agent *Agent, delay time.Duration, sessionI subscriptions: make(chan *api.SubscriptionMessage), registered: make(chan struct{}), closed: make(chan struct{}), + cancel: sessionCancel, } // TODO(stevvooe): Need to move connection management up a level or create @@ -69,7 +71,7 @@ func newSession(ctx context.Context, agent *Agent, delay time.Duration, sessionI } s.conn = cc - go s.run(ctx, delay, description) + go s.run(sessionCtx, delay, description) return s } @@ -114,6 +116,14 @@ func (s *session) start(ctx context.Context, description *api.NodeDescription) e // Note: we don't defer cancellation of this context, because the // streaming RPC is used after this function returned. We only cancel // it in the timeout case to make sure the goroutine completes. + + // We also fork this context again from the `run` context, because on + // `dispatcherRPCTimeout`, we want to cancel establishing a session and + // return an error. If we cancel the `run` context instead of forking, + // then in `run` it's possible that we just terminate the function because + // `ctx` is done and hence fail to propagate the timeout error to the agent. + // If the error is not propagated to the agent, the agent will not close + // the session or rebuild a new session. 
sessionCtx, cancelSession := context.WithCancel(ctx) // Need to run Session in a goroutine since there's no way to set a @@ -402,10 +412,10 @@ func (s *session) sendError(err error) { // of event loop. func (s *session) close() error { s.closeOnce.Do(func() { + s.cancel() if s.conn != nil { s.conn.Close(false) } - close(s.closed) }) diff --git a/agent/testutils/fakes.go b/agent/testutils/fakes.go index 31625bbac8..d8de14cc84 100644 --- a/agent/testutils/fakes.go +++ b/agent/testutils/fakes.go @@ -1,7 +1,10 @@ package testutils import ( + "io/ioutil" "net" + "os" + "path/filepath" "sync" "testing" "time" @@ -114,12 +117,16 @@ func (t *TestController) Close() error { return nil } +// SessionHandler is an injectable function that can be used to handle session requests +type SessionHandler func(*api.SessionRequest, api.Dispatcher_SessionServer) error + // MockDispatcher is a fake dispatcher that one agent at a time can connect to type MockDispatcher struct { mu sync.Mutex sessionCh chan *api.SessionMessage openSession *api.SessionRequest closedSessions []*api.SessionRequest + sessionHandler SessionHandler Addr string } @@ -153,6 +160,7 @@ func (m *MockDispatcher) Heartbeat(context.Context, *api.HeartbeatRequest) (*api // Session allows a session to be established, and sends the node info func (m *MockDispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error { m.mu.Lock() + handler := m.sessionHandler m.openSession = r m.mu.Unlock() defer func() { @@ -162,6 +170,10 @@ func (m *MockDispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_Se m.openSession = nil }() + if handler != nil { + return handler(r, stream) + } + // send the initial message first if err := stream.Send(&api.SessionMessage{ SessionID: r.SessionID, @@ -200,11 +212,35 @@ func (m *MockDispatcher) SessionMessageChannel() chan<- *api.SessionMessage { return m.sessionCh } +// SetSessionHandler lets you inject a custom function to handle session requests +func (m 
*MockDispatcher) SetSessionHandler(s SessionHandler) { + m.mu.Lock() + defer m.mu.Unlock() + m.sessionHandler = s +} + // NewMockDispatcher starts and returns a mock dispatcher instance that can be connected to -func NewMockDispatcher(t *testing.T, secConfig *ca.SecurityConfig) (*MockDispatcher, func()) { - l, err := net.Listen("tcp", "127.0.0.1:0") - addr := l.Addr().String() - require.NoError(t, err) +func NewMockDispatcher(t *testing.T, secConfig *ca.SecurityConfig, local bool) (*MockDispatcher, func()) { + var ( + l net.Listener + err error + addr string + cleanup func() + ) + if local { + tempDir, err := ioutil.TempDir("", "local-dispatcher-socket") + require.NoError(t, err) + addr = filepath.Join(tempDir, "socket") + l, err = net.Listen("unix", addr) + require.NoError(t, err) + cleanup = func() { + os.RemoveAll(tempDir) + } + } else { + l, err = net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr = l.Addr().String() + } serverOpts := []grpc.ServerOption{grpc.Creds(secConfig.ServerTLSCreds)} s := grpc.NewServer(serverOpts...) @@ -215,5 +251,11 @@ func NewMockDispatcher(t *testing.T, secConfig *ca.SecurityConfig) (*MockDispatc } api.RegisterDispatcherServer(s, m) go s.Serve(l) - return m, s.Stop + return m, func() { + l.Close() + s.Stop() + if cleanup != nil { + cleanup() + } + } } diff --git a/node/node_test.go b/node/node_test.go index b4ae8d58bc..83653673c3 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -318,7 +318,7 @@ func TestAgentRespectsDispatcherRootCAUpdate(t *testing.T) { ca.WorkerRole, managerSecConfig.ServerTLSCreds.Organization()) require.NoError(t, err) - mockDispatcher, cleanup := agentutils.NewMockDispatcher(t, managerSecConfig) + mockDispatcher, cleanup := agentutils.NewMockDispatcher(t, managerSecConfig, false) defer cleanup() cfg := &Config{