diff --git a/ca/server.go b/ca/server.go index 4057edd31f..562fcdd0e4 100644 --- a/ca/server.go +++ b/ca/server.go @@ -392,14 +392,12 @@ func (s *Server) Run(ctx context.Context) error { if len(clusters) != 1 { return errors.New("could not find cluster object") } - s.updateCluster(ctx, clusters[0]) - + s.UpdateRootCA(ctx, clusters[0]) // call once to ensure that the join tokens are always set nodes, err = store.FindNodes(readTx, store.All) return err }, state.EventCreateNode{}, state.EventUpdateNode{}, - state.EventUpdateCluster{}, ) // Do this after updateCluster has been called, so isRunning never @@ -434,6 +432,12 @@ func (s *Server) Run(ctx context.Context) error { // Watch for new nodes being created, new nodes being updated, and changes // to the cluster for { + select { + case <-ctx.Done(): + return nil + default: + } + select { case event := <-updates: switch v := event.(type) { @@ -445,8 +449,6 @@ func (s *Server) Run(ctx context.Context) error { if !isFinalState(v.Node.Certificate.Status) { s.evaluateAndSignNodeCert(ctx, v.Node) } - case state.EventUpdateCluster: - s.updateCluster(ctx, v.Cluster) } case <-ticker.C: for _, node := range s.pending { @@ -512,9 +514,10 @@ func (s *Server) isRunning() bool { return true } -// updateCluster is called when there are cluster changes, and it ensures that the local RootCA is -// always aware of changes in clusterExpiry and the Root CA key material -func (s *Server) updateCluster(ctx context.Context, cluster *api.Cluster) { +// UpdateRootCA is called when there are cluster changes, and it ensures that the local RootCA is +// always aware of changes in clusterExpiry and the Root CA key material - this can be called by +// anything to update the root CA material +func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) { s.mu.Lock() s.joinTokens = cluster.RootCA.JoinTokens.Copy() s.mu.Unlock() diff --git a/ca/testutils/cautils.go b/ca/testutils/cautils.go index d0f155a157..881fce6e6a 100644 --- 
a/ca/testutils/cautils.go +++ b/ca/testutils/cautils.go @@ -21,6 +21,7 @@ import ( "github.com/docker/swarmkit/connectionbroker" "github.com/docker/swarmkit/identity" "github.com/docker/swarmkit/ioutils" + "github.com/docker/swarmkit/manager/state" "github.com/docker/swarmkit/manager/state/store" "github.com/docker/swarmkit/remotes" "github.com/opencontainers/go-digest" @@ -48,10 +49,12 @@ type TestCA struct { ManagerToken string ConnBroker *connectionbroker.Broker KeyReadWriter *ca.KeyReadWriter + watchCancel func() } // Stop cleans up after TestCA func (tc *TestCA) Stop() { + tc.watchCancel() os.RemoveAll(tc.TempDir) for _, conn := range tc.Conns { conn.Close() @@ -174,6 +177,31 @@ func NewTestCA(t *testing.T, krwGenerators ...func(ca.CertPaths) *ca.KeyReadWrit ctx := context.Background() + clusterWatch, clusterWatchCancel, err := store.ViewAndWatch( + s, func(tx store.ReadTx) error { + cluster := store.GetCluster(tx, organization) + caServer.UpdateRootCA(ctx, cluster) + return nil + }, + state.EventUpdateCluster{ + Cluster: &api.Cluster{ID: organization}, + Checks: []state.ClusterCheckFunc{state.ClusterCheckID}, + }, + ) + assert.NoError(t, err) + go func() { + for { + select { + case event := <-clusterWatch: + clusterEvent := event.(state.EventUpdateCluster) + caServer.UpdateRootCA(ctx, clusterEvent.Cluster) + case <-ctx.Done(): + clusterWatchCancel() + return + } + } + }() + go grpcServer.Serve(l) go caServer.Run(ctx) @@ -202,6 +230,7 @@ func NewTestCA(t *testing.T, krwGenerators ...func(ca.CertPaths) *ca.KeyReadWrit ManagerToken: managerToken, ConnBroker: connectionbroker.New(remotes), KeyReadWriter: krw, + watchCancel: clusterWatchCancel, } } diff --git a/manager/manager.go b/manager/manager.go index 27496df7d0..7d5fcd0983 100644 --- a/manager/manager.go +++ b/manager/manager.go @@ -514,7 +514,7 @@ func (m *Manager) Run(parent context.Context) error { } raftConfig := c.Spec.Raft - if err := m.watchForKEKChanges(ctx); err != nil { + if err := 
m.watchForClusterChanges(ctx); err != nil { return err } @@ -679,7 +679,7 @@ func (m *Manager) updateKEK(ctx context.Context, cluster *api.Cluster) error { return nil } -func (m *Manager) watchForKEKChanges(ctx context.Context) error { +func (m *Manager) watchForClusterChanges(ctx context.Context) error { clusterID := m.config.SecurityConfig.ClientTLSCreds.Organization() clusterWatch, clusterWatchCancel, err := store.ViewAndWatch(m.raftNode.MemoryStore(), func(tx store.ReadTx) error { @@ -687,6 +687,7 @@ func (m *Manager) watchForKEKChanges(ctx context.Context) error { if cluster == nil { return fmt.Errorf("unable to get current cluster") } + m.caserver.UpdateRootCA(ctx, cluster) return m.updateKEK(ctx, cluster) }, state.EventUpdateCluster{ @@ -702,6 +703,7 @@ func (m *Manager) watchForKEKChanges(ctx context.Context) error { select { case event := <-clusterWatch: clusterEvent := event.(state.EventUpdateCluster) + m.caserver.UpdateRootCA(ctx, clusterEvent.Cluster) m.updateKEK(ctx, clusterEvent.Cluster) case <-ctx.Done(): clusterWatchCancel() diff --git a/manager/manager_test.go b/manager/manager_test.go index 1d0bd43d9d..2406dfc484 100644 --- a/manager/manager_test.go +++ b/manager/manager_test.go @@ -400,3 +400,110 @@ func TestManagerLockUnlock(t *testing.T) { // error. 
<-done } + +// If the root CA material is updated in the memory store, a manager will update its own +// security configs even if it's "not the leader" (which we will fake by calling `becomeFollower`) +func TestManagerUpdatesSecurityConfig(t *testing.T) { + ctx := context.Background() + + temp, err := ioutil.TempFile("", "test-manager-update-security-config") + require.NoError(t, err) + require.NoError(t, temp.Close()) + require.NoError(t, os.Remove(temp.Name())) + + defer os.RemoveAll(temp.Name()) + + stateDir, err := ioutil.TempDir("", "test-raft") + require.NoError(t, err) + defer os.RemoveAll(stateDir) + + tc := testutils.NewTestCA(t) + defer tc.Stop() + + managerSecurityConfig, err := tc.NewNodeConfig(ca.ManagerRole) + require.NoError(t, err) + + _, _, err = managerSecurityConfig.KeyReader().Read() + require.NoError(t, err) + + m, err := New(&Config{ + RemoteAPI: &RemoteAddrs{ListenAddr: "127.0.0.1:0"}, + ControlAPI: temp.Name(), + StateDir: stateDir, + SecurityConfig: managerSecurityConfig, + }) + require.NoError(t, err) + require.NotNil(t, m) + + done := make(chan error) + defer close(done) + go func() { + done <- m.Run(ctx) + }() + + // wait until the CA server is running + opts := []grpc.DialOption{ + grpc.WithTimeout(10 * time.Second), + grpc.WithTransportCredentials(managerSecurityConfig.ClientTLSCreds), + } + + conn, err := grpc.Dial(m.Addr(), opts...) 
+ require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close()) + }() + + client := api.NewCAClient(conn) + + require.NoError(t, raftutils.PollFuncWithTimeout(nil, func() error { + ctx, _ := context.WithTimeout(context.Background(), 500*time.Millisecond) + _, err := client.GetRootCACertificate(ctx, &api.GetRootCACertificateRequest{}) + return err + }, time.Second)) + + // wait until the cluster is up + var clusters []*api.Cluster + + require.NoError(t, raftutils.PollFuncWithTimeout(nil, func() error { + var err error + m.raftNode.MemoryStore().View(func(tx store.ReadTx) { + clusters, err = store.FindClusters(tx, store.ByName(store.DefaultClusterName)) + }) + if err != nil { + return err + } + if len(clusters) == 0 { + return fmt.Errorf("cluster not ready yet") + } + return nil + }, 1*time.Second)) + + // stop running CA server and other leader functions + m.becomeFollower() + + newRootCert, _, err := testutils.CreateRootCertAndKey("rootOther") + require.NoError(t, err) + updatedCA := append(tc.RootCA.Cert, newRootCert...) + + // Update the RootCA to have a bundle + require.NoError(t, m.raftNode.MemoryStore().Update(func(tx store.Tx) error { + cluster := store.GetCluster(tx, clusters[0].ID) + cluster.RootCA.CACert = updatedCA + return store.UpdateCluster(tx, cluster) + })) + + // wait for the manager's security config to be updated + require.NoError(t, raftutils.PollFuncWithTimeout(nil, func() error { + if !bytes.Equal(managerSecurityConfig.RootCA().Cert, updatedCA) { + return fmt.Errorf("root CA not updated yet") + } + return nil + }, 1*time.Second)) + + m.Stop(ctx, false) + + // After stopping we MAY receive an error from ListenAndServe if + // all this happened before WaitForLeader completed, so don't check the + // error. + <-done +}