diff --git a/manager/logbroker/broker.go b/manager/logbroker/broker.go index f5ec2b30bd..5eededfc05 100644 --- a/manager/logbroker/broker.go +++ b/manager/logbroker/broker.go @@ -57,12 +57,12 @@ func New(store *store.MemoryStore) *LogBroker { } } -// Run the log broker -func (lb *LogBroker) Run(ctx context.Context) error { +// Start starts the log broker +func (lb *LogBroker) Start(ctx context.Context) error { lb.mu.Lock() + defer lb.mu.Unlock() if lb.cancelAll != nil { - lb.mu.Unlock() return errAlreadyRunning } @@ -71,12 +71,7 @@ func (lb *LogBroker) Run(ctx context.Context) error { lb.subscriptionQueue = watch.NewQueue() lb.registeredSubscriptions = make(map[string]*subscription) lb.subscriptionsByNode = make(map[string]map[*subscription]struct{}) - lb.mu.Unlock() - - select { - case <-lb.pctx.Done(): - return lb.pctx.Err() - } + return nil } // Stop stops the log broker @@ -234,8 +229,15 @@ func (lb *LogBroker) SubscribeLogs(request *api.SubscribeLogsRequest, stream api return err } + lb.mu.Lock() + pctx := lb.pctx + lb.mu.Unlock() + if pctx == nil { + return errNotRunning + } + subscription := lb.newSubscription(request.Selector, request.Options) - subscription.Run(lb.pctx) + subscription.Run(pctx) defer subscription.Stop() log := log.G(ctx).WithFields( @@ -257,8 +259,8 @@ func (lb *LogBroker) SubscribeLogs(request *api.SubscribeLogsRequest, stream api select { case <-ctx.Done(): return ctx.Err() - case <-lb.pctx.Done(): - return lb.pctx.Err() + case <-pctx.Done(): + return pctx.Err() case event := <-publishCh: publish := event.(*logMessage) if publish.completed { @@ -308,6 +310,13 @@ func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest return err } + lb.mu.Lock() + pctx := lb.pctx + lb.mu.Unlock() + if pctx == nil { + return errNotRunning + } + lb.nodeConnected(remote.NodeID) defer lb.nodeDisconnected(remote.NodeID) @@ -329,7 +338,7 @@ func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest select { case <-stream.Context().Done(): return stream.Context().Err() - case <-lb.pctx.Done(): + case <-pctx.Done(): return nil default: } @@ -362,7 +371,7 @@ func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest } case <-stream.Context().Done(): return stream.Context().Err() - case <-lb.pctx.Done(): + case <-pctx.Done(): return nil } } diff --git a/manager/logbroker/broker_test.go b/manager/logbroker/broker_test.go index 53147d8590..ca825525f7 100644 --- a/manager/logbroker/broker_test.go +++ b/manager/logbroker/broker_test.go @@ -126,8 +126,8 @@ func TestLogBrokerLogs(t *testing.T) { wg.Wait() - // Make sure double Run throws an error - require.EqualError(t, broker.Run(ctx), errAlreadyRunning.Error()) + // Make sure double Start throws an error + require.EqualError(t, broker.Start(ctx), errAlreadyRunning.Error()) // Stop should work require.NoError(t, broker.Stop()) // Double stopping should fail @@ -780,7 +780,7 @@ func testLogBrokerEnv(t *testing.T) (context.Context, *testutils.TestCA, *LogBro } }() - go broker.Run(ctx) + require.NoError(t, broker.Start(ctx)) return ctx, tca, broker, logListener.Addr().String(), brokerListener.Addr().String(), func() { broker.Stop() diff --git a/manager/manager.go b/manager/manager.go index d17e8ec231..705576b0ac 100644 --- a/manager/manager.go +++ b/manager/manager.go @@ -129,6 +129,7 @@ type Manager struct { caserver *ca.Server dispatcher *dispatcher.Dispatcher logbroker *logbroker.LogBroker + watchServer *watchapi.Server replicatedOrchestrator *replicated.Orchestrator globalOrchestrator *global.Orchestrator taskReaper *taskreaper.TaskReaper @@ -220,6 +221,7 @@ func New(config *Config) (*Manager, error) { caserver: ca.NewServer(raftNode.MemoryStore(), config.SecurityConfig, config.RootCAPaths), dispatcher: dispatcher.New(raftNode, dispatcher.DefaultConfig()), logbroker: logbroker.New(raftNode.MemoryStore()), + watchServer: watchapi.NewServer(raftNode.MemoryStore()), server: grpc.NewServer(opts...), localserver: grpc.NewServer(opts...), raftNode: raftNode, @@ -397,13 +399,12 @@ func (m *Manager) Run(parent context.Context) error { } baseControlAPI := controlapi.NewServer(m.raftNode.MemoryStore(), m.raftNode, m.config.SecurityConfig, m.caserver, m.config.PluginGetter) - baseWatchAPI := watchapi.NewServer(m.raftNode.MemoryStore()) baseResourceAPI := resourceapi.New(m.raftNode.MemoryStore()) healthServer := health.NewHealthServer() localHealthServer := health.NewHealthServer() authenticatedControlAPI := api.NewAuthenticatedWrapperControlServer(baseControlAPI, authorize) - authenticatedWatchAPI := api.NewAuthenticatedWrapperWatchServer(baseWatchAPI, authorize) + authenticatedWatchAPI := api.NewAuthenticatedWrapperWatchServer(m.watchServer, authorize) authenticatedResourceAPI := api.NewAuthenticatedWrapperResourceAllocatorServer(baseResourceAPI, authorize) authenticatedLogsServerAPI := api.NewAuthenticatedWrapperLogsServer(m.logbroker, authorize) authenticatedLogBrokerAPI := api.NewAuthenticatedWrapperLogBrokerServer(m.logbroker, authorize) @@ -476,7 +477,7 @@ func (m *Manager) Run(parent context.Context) error { grpc_prometheus.Register(m.server) api.RegisterControlServer(m.localserver, localProxyControlAPI) - api.RegisterWatchServer(m.localserver, baseWatchAPI) + api.RegisterWatchServer(m.localserver, m.watchServer) api.RegisterLogsServer(m.localserver, localProxyLogsAPI) api.RegisterHealthServer(m.localserver, localHealthServer) api.RegisterDispatcherServer(m.localserver, localProxyDispatcherAPI) @@ -1000,11 +1001,13 @@ func (m *Manager) becomeLeader(ctx context.Context) { } }(m.dispatcher) - go func(lb *logbroker.LogBroker) { - if err := lb.Run(ctx); err != nil { - log.G(ctx).WithError(err).Error("LogBroker exited with an error") - } - }(m.logbroker) + if err := m.logbroker.Start(ctx); err != nil { + log.G(ctx).WithError(err).Error("LogBroker failed to start") + } + + if err := m.watchServer.Start(ctx); err != nil { + log.G(ctx).WithError(err).Error("watch server failed to start") + } go func(server *ca.Server) { if err := server.Run(ctx); err != nil { @@ -1058,6 +1061,7 @@ func (m *Manager) becomeLeader(ctx context.Context) { func (m *Manager) becomeFollower() { m.dispatcher.Stop() m.logbroker.Stop() + m.watchServer.Stop() m.caserver.Stop() if m.allocator != nil { diff --git a/manager/watchapi/server.go b/manager/watchapi/server.go index 07cdedbb36..6d49dca715 100644 --- a/manager/watchapi/server.go +++ b/manager/watchapi/server.go @@ -1,12 +1,24 @@ package watchapi import ( + "errors" + "sync" + "github.com/docker/swarmkit/manager/state/store" + "golang.org/x/net/context" +) + +var ( + errAlreadyRunning = errors.New("broker is already running") + errNotRunning = errors.New("broker is not running") ) // Server is the store API gRPC server. type Server struct { - store *store.MemoryStore + store *store.MemoryStore + mu sync.Mutex + pctx context.Context + cancelAll func() } // NewServer creates a store API server. @@ -15,3 +27,30 @@ func NewServer(store *store.MemoryStore) *Server { store: store, } } + +// Start starts the watch server. +func (s *Server) Start(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.cancelAll != nil { + return errAlreadyRunning + } + + s.pctx, s.cancelAll = context.WithCancel(ctx) + return nil +} + +// Stop stops the watch server. +func (s *Server) Stop() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.cancelAll == nil { + return errNotRunning + } + s.cancelAll() + s.cancelAll = nil + + return nil +} diff --git a/manager/watchapi/server_test.go b/manager/watchapi/server_test.go index cd2e43d896..8e03b7efc7 100644 --- a/manager/watchapi/server_test.go +++ b/manager/watchapi/server_test.go @@ -14,6 +14,8 @@ import ( "github.com/docker/swarmkit/manager/state/store" stateutils "github.com/docker/swarmkit/manager/state/testutils" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/grpclog" ) @@ -30,6 +32,7 @@ type testServer struct { } func (ts *testServer) Stop() { + ts.Server.Stop() ts.clientConn.Close() ts.grpcServer.Stop() ts.Store.Close() @@ -48,6 +51,8 @@ func newTestServer(t *testing.T) *testServer { ts.Server = NewServer(ts.Store) assert.NotNil(t, ts.Server) + require.NoError(t, ts.Server.Start(context.Background())) + temp, err := ioutil.TempFile("", "test-socket") assert.NoError(t, err) assert.NoError(t, temp.Close()) diff --git a/manager/watchapi/watch.go b/manager/watchapi/watch.go index 555b899743..53bed49f1c 100644 --- a/manager/watchapi/watch.go +++ b/manager/watchapi/watch.go @@ -17,6 +17,13 @@ import ( func (s *Server) Watch(request *api.WatchRequest, stream api.Watch_WatchServer) error { ctx := stream.Context() + s.mu.Lock() + pctx := s.pctx + s.mu.Unlock() + if pctx == nil { + return errNotRunning + } + watchArgs, err := api.ConvertWatchArgs(request.Entries) if err != nil { return grpc.Errorf(codes.InvalidArgument, "%s", err.Error()) @@ -39,6 +46,8 @@ func (s *Server) Watch(request *api.WatchRequest, stream api.Watch_WatchServer) select { case <-ctx.Done(): return ctx.Err() + case <-pctx.Done(): + return pctx.Err() case event := <-watch: if commitEvent, ok := event.(state.EventCommit); ok && len(events) > 0 { if err := stream.Send(&api.WatchMessage{Events: events, Version: commitEvent.Version}); err != nil {