From 6a9946b3caf85d0c6a41c307eabdb84bd6ff8617 Mon Sep 17 00:00:00 2001 From: Aaron Lehmann Date: Fri, 14 Jul 2017 13:59:32 -0700 Subject: [PATCH] watchapi: Terminate RPCs when the manager shuts down Because we use GracefulStop, RPCs need to finish on their own when the manager shuts down. The watchapi server formerly was completely stateless, and didn't have a notion of starting or stopping. This adds Start and Stop functions that control a context, adapted from the logbroker code. When the context is cancelled, outstanding RPCs will finish. There are also some related cleanups and fixes in the logbroker code. Signed-off-by: Aaron Lehmann --- manager/logbroker/broker.go | 37 +++++++++++++++++----------- manager/logbroker/broker_test.go | 6 ++--- manager/manager.go | 20 +++++++++------- manager/watchapi/server.go | 41 +++++++++++++++++++++++++++++++- manager/watchapi/server_test.go | 5 ++++ manager/watchapi/watch.go | 9 +++++++ 6 files changed, 92 insertions(+), 26 deletions(-) diff --git a/manager/logbroker/broker.go b/manager/logbroker/broker.go index f5ec2b30bd..5eededfc05 100644 --- a/manager/logbroker/broker.go +++ b/manager/logbroker/broker.go @@ -57,12 +57,12 @@ func New(store *store.MemoryStore) *LogBroker { } } -// Run the log broker -func (lb *LogBroker) Run(ctx context.Context) error { +// Start starts the log broker +func (lb *LogBroker) Start(ctx context.Context) error { lb.mu.Lock() + defer lb.mu.Unlock() if lb.cancelAll != nil { - lb.mu.Unlock() return errAlreadyRunning } @@ -71,12 +71,7 @@ func (lb *LogBroker) Run(ctx context.Context) error { lb.subscriptionQueue = watch.NewQueue() lb.registeredSubscriptions = make(map[string]*subscription) lb.subscriptionsByNode = make(map[string]map[*subscription]struct{}) - lb.mu.Unlock() - - select { - case <-lb.pctx.Done(): - return lb.pctx.Err() - } + return nil } // Stop stops the log broker @@ -234,8 +229,15 @@ func (lb *LogBroker) SubscribeLogs(request *api.SubscribeLogsRequest, stream api return err } + lb.mu.Lock() + pctx := lb.pctx + lb.mu.Unlock() + if pctx == nil { + return errNotRunning + } + subscription := lb.newSubscription(request.Selector, request.Options) - subscription.Run(lb.pctx) + subscription.Run(pctx) defer subscription.Stop() log := log.G(ctx).WithFields( @@ -257,8 +259,8 @@ func (lb *LogBroker) SubscribeLogs(request *api.SubscribeLogsRequest, stream api select { case <-ctx.Done(): return ctx.Err() - case <-lb.pctx.Done(): - return lb.pctx.Err() + case <-pctx.Done(): + return pctx.Err() case event := <-publishCh: publish := event.(*logMessage) if publish.completed { @@ -308,6 +310,13 @@ func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest return err } + lb.mu.Lock() + pctx := lb.pctx + lb.mu.Unlock() + if pctx == nil { + return errNotRunning + } + lb.nodeConnected(remote.NodeID) defer lb.nodeDisconnected(remote.NodeID) @@ -329,7 +338,7 @@ func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest select { case <-stream.Context().Done(): return stream.Context().Err() - case <-lb.pctx.Done(): + case <-pctx.Done(): return nil default: } @@ -362,7 +371,7 @@ func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest } case <-stream.Context().Done(): return stream.Context().Err() - case <-lb.pctx.Done(): + case <-pctx.Done(): return nil } } diff --git a/manager/logbroker/broker_test.go b/manager/logbroker/broker_test.go index 53147d8590..ca825525f7 100644 --- a/manager/logbroker/broker_test.go +++ b/manager/logbroker/broker_test.go @@ -126,8 +126,8 @@ func TestLogBrokerLogs(t *testing.T) { wg.Wait() - // Make sure double Run throws an error - require.EqualError(t, broker.Run(ctx), errAlreadyRunning.Error()) + // Make sure double Start throws an error + require.EqualError(t, broker.Start(ctx), errAlreadyRunning.Error()) // Stop should work require.NoError(t, broker.Stop()) // Double stopping should fail @@ -780,7 +780,7 @@ func testLogBrokerEnv(t *testing.T) (context.Context, *testutils.TestCA, *LogBro } }() - go broker.Run(ctx) + require.NoError(t, broker.Start(ctx)) return ctx, tca, broker, logListener.Addr().String(), brokerListener.Addr().String(), func() { broker.Stop() diff --git a/manager/manager.go b/manager/manager.go index d17e8ec231..705576b0ac 100644 --- a/manager/manager.go +++ b/manager/manager.go @@ -129,6 +129,7 @@ type Manager struct { caserver *ca.Server dispatcher *dispatcher.Dispatcher logbroker *logbroker.LogBroker + watchServer *watchapi.Server replicatedOrchestrator *replicated.Orchestrator globalOrchestrator *global.Orchestrator taskReaper *taskreaper.TaskReaper @@ -220,6 +221,7 @@ func New(config *Config) (*Manager, error) { caserver: ca.NewServer(raftNode.MemoryStore(), config.SecurityConfig, config.RootCAPaths), dispatcher: dispatcher.New(raftNode, dispatcher.DefaultConfig()), logbroker: logbroker.New(raftNode.MemoryStore()), + watchServer: watchapi.NewServer(raftNode.MemoryStore()), server: grpc.NewServer(opts...), localserver: grpc.NewServer(opts...), raftNode: raftNode, @@ -397,13 +399,12 @@ func (m *Manager) Run(parent context.Context) error { } baseControlAPI := controlapi.NewServer(m.raftNode.MemoryStore(), m.raftNode, m.config.SecurityConfig, m.caserver, m.config.PluginGetter) - baseWatchAPI := watchapi.NewServer(m.raftNode.MemoryStore()) baseResourceAPI := resourceapi.New(m.raftNode.MemoryStore()) healthServer := health.NewHealthServer() localHealthServer := health.NewHealthServer() authenticatedControlAPI := api.NewAuthenticatedWrapperControlServer(baseControlAPI, authorize) - authenticatedWatchAPI := api.NewAuthenticatedWrapperWatchServer(baseWatchAPI, authorize) + authenticatedWatchAPI := api.NewAuthenticatedWrapperWatchServer(m.watchServer, authorize) authenticatedResourceAPI := api.NewAuthenticatedWrapperResourceAllocatorServer(baseResourceAPI, authorize) authenticatedLogsServerAPI := api.NewAuthenticatedWrapperLogsServer(m.logbroker, authorize) authenticatedLogBrokerAPI := api.NewAuthenticatedWrapperLogBrokerServer(m.logbroker, authorize) @@ -476,7 +477,7 @@ func (m *Manager) Run(parent context.Context) error { grpc_prometheus.Register(m.server) api.RegisterControlServer(m.localserver, localProxyControlAPI) - api.RegisterWatchServer(m.localserver, baseWatchAPI) + api.RegisterWatchServer(m.localserver, m.watchServer) api.RegisterLogsServer(m.localserver, localProxyLogsAPI) api.RegisterHealthServer(m.localserver, localHealthServer) api.RegisterDispatcherServer(m.localserver, localProxyDispatcherAPI) @@ -1000,11 +1001,13 @@ func (m *Manager) becomeLeader(ctx context.Context) { } }(m.dispatcher) - go func(lb *logbroker.LogBroker) { - if err := lb.Run(ctx); err != nil { - log.G(ctx).WithError(err).Error("LogBroker exited with an error") - } - }(m.logbroker) + if err := m.logbroker.Start(ctx); err != nil { + log.G(ctx).WithError(err).Error("LogBroker failed to start") + } + + if err := m.watchServer.Start(ctx); err != nil { + log.G(ctx).WithError(err).Error("watch server failed to start") + } go func(server *ca.Server) { if err := server.Run(ctx); err != nil { @@ -1058,6 +1061,7 @@ func (m *Manager) becomeLeader(ctx context.Context) { func (m *Manager) becomeFollower() { m.dispatcher.Stop() m.logbroker.Stop() + m.watchServer.Stop() m.caserver.Stop() if m.allocator != nil { diff --git a/manager/watchapi/server.go b/manager/watchapi/server.go index 07cdedbb36..6d49dca715 100644 --- a/manager/watchapi/server.go +++ b/manager/watchapi/server.go @@ -1,12 +1,24 @@ package watchapi import ( + "errors" + "sync" + "github.com/docker/swarmkit/manager/state/store" + "golang.org/x/net/context" +) + +var ( + errAlreadyRunning = errors.New("broker is already running") + errNotRunning = errors.New("broker is not running") ) // Server is the store API gRPC server. type Server struct { - store *store.MemoryStore + store *store.MemoryStore + mu sync.Mutex + pctx context.Context + cancelAll func() } // NewServer creates a store API server. @@ -15,3 +27,30 @@ func NewServer(store *store.MemoryStore) *Server { store: store, } } + +// Start starts the watch server. +func (s *Server) Start(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.cancelAll != nil { + return errAlreadyRunning + } + + s.pctx, s.cancelAll = context.WithCancel(ctx) + return nil +} + +// Stop stops the watch server. +func (s *Server) Stop() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.cancelAll == nil { + return errNotRunning + } + s.cancelAll() + s.cancelAll = nil + + return nil +} diff --git a/manager/watchapi/server_test.go b/manager/watchapi/server_test.go index cd2e43d896..8e03b7efc7 100644 --- a/manager/watchapi/server_test.go +++ b/manager/watchapi/server_test.go @@ -14,6 +14,8 @@ import ( "github.com/docker/swarmkit/manager/state/store" stateutils "github.com/docker/swarmkit/manager/state/testutils" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/grpclog" ) @@ -30,6 +32,7 @@ type testServer struct { } func (ts *testServer) Stop() { + ts.Server.Stop() ts.clientConn.Close() ts.grpcServer.Stop() ts.Store.Close() @@ -48,6 +51,8 @@ func newTestServer(t *testing.T) *testServer { ts.Server = NewServer(ts.Store) assert.NotNil(t, ts.Server) + require.NoError(t, ts.Server.Start(context.Background())) + temp, err := ioutil.TempFile("", "test-socket") assert.NoError(t, err) assert.NoError(t, temp.Close()) diff --git a/manager/watchapi/watch.go b/manager/watchapi/watch.go index 555b899743..53bed49f1c 100644 --- a/manager/watchapi/watch.go +++ b/manager/watchapi/watch.go @@ -17,6 +17,13 @@ import ( func (s *Server) Watch(request *api.WatchRequest, stream api.Watch_WatchServer) error { ctx := stream.Context() + s.mu.Lock() + pctx := s.pctx + s.mu.Unlock() + if pctx == nil { + return errNotRunning + } + watchArgs, err := api.ConvertWatchArgs(request.Entries) if err != nil { return grpc.Errorf(codes.InvalidArgument, "%s", err.Error()) @@ -39,6 +46,8 @@ func (s *Server) Watch(request *api.WatchRequest, stream api.Watch_WatchServer) select { case <-ctx.Done(): return ctx.Err() + case <-pctx.Done(): + return pctx.Err() case event := <-watch: if commitEvent, ok := event.(state.EventCommit); ok && len(events) > 0 { if err := stream.Send(&api.WatchMessage{Events: events, Version: commitEvent.Version}); err != nil {