From f1d0d24df390eb578907b8a8a673b0777038e1fb Mon Sep 17 00:00:00 2001 From: cyli Date: Mon, 3 Apr 2017 18:52:21 -0700 Subject: [PATCH 1/5] Add a reconciliation loop to ca.Server which, on root rotation, updates all the nodes to have an IssuanceStateRotate to trigger all the nodes to get new certificates. When all the nodes have rotated their certificates to be signed by the desired issuer, complete root rotation. Signed-off-by: cyli --- ca/server.go | 195 +++++++++++- ca/server_test.go | 673 +++++++++++++++++++++++++++++++++++++++- ca/testutils/cautils.go | 6 +- 3 files changed, 864 insertions(+), 10 deletions(-) diff --git a/ca/server.go b/ca/server.go index f3fa9cb6c8..85c18eb75d 100644 --- a/ca/server.go +++ b/ca/server.go @@ -4,10 +4,12 @@ import ( "bytes" "crypto/subtle" "crypto/x509" + "fmt" "sync" "time" "github.com/Sirupsen/logrus" + "github.com/cloudflare/cfssl/helpers" "github.com/docker/swarmkit/api" "github.com/docker/swarmkit/api/equality" "github.com/docker/swarmkit/identity" @@ -22,6 +24,7 @@ import ( const ( defaultReconciliationRetryInterval = 10 * time.Second + defaultRootReconciliationInterval = 3 * time.Second ) // APISecurityConfigUpdater knows how to update a SecurityConfig from an api.Cluster object @@ -63,6 +66,11 @@ type Server struct { // before we update the security config with the new root CA, we need to be able to save the root certs rootPaths CertPaths + + // lets us know if we're already doing a reconcile loop to tell nodes to rotate their + // TLS certificates and watching for root rotation completion + rootCAReconciliationInProgress bool + rootReconciliationRetryInterval time.Duration } // DefaultCAConfig returns the default CA Config, with a default expiration. @@ -75,12 +83,13 @@ func DefaultCAConfig() api.CAConfig { // NewServer creates a CA API server. 
func NewServer(store *store.MemoryStore, securityConfig *SecurityConfig, rootCAPaths CertPaths) *Server { return &Server{ - store: store, - securityConfig: securityConfig, - pending: make(map[string]*api.Node), - started: make(chan struct{}), - reconciliationRetryInterval: defaultReconciliationRetryInterval, - rootPaths: rootCAPaths, + store: store, + securityConfig: securityConfig, + pending: make(map[string]*api.Node), + started: make(chan struct{}), + reconciliationRetryInterval: defaultReconciliationRetryInterval, + rootReconciliationRetryInterval: defaultRootReconciliationInterval, + rootPaths: rootCAPaths, } } @@ -90,6 +99,12 @@ func (s *Server) SetReconciliationRetryInterval(reconciliationRetryInterval time s.reconciliationRetryInterval = reconciliationRetryInterval } +// SetRootReconciliationInterval changes the time interval between root rotation +// reconciliation attempts. This function must be called before Run. +func (s *Server) SetRootReconciliationInterval(interval time.Duration) { + s.rootReconciliationRetryInterval = interval +} + // GetUnlockKey is responsible for returning the current unlock key used for encrypting TLS private keys and // other at rest data. Access to this RPC call should only be allowed via mutual TLS from managers. 
func (s *Server) GetUnlockKey(ctx context.Context, request *api.GetUnlockKeyRequest) (*api.GetUnlockKeyResponse, error) { @@ -446,6 +461,7 @@ func (s *Server) Run(ctx context.Context) error { "method": "(*Server).Run", }).WithError(err).Errorf("error attempting to reconcile certificates") } + s.reconcileNodeRootsAndCerts() ticker := time.NewTicker(s.reconciliationRetryInterval) defer ticker.Stop() @@ -581,7 +597,6 @@ func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { if signingKey == nil { signingCert = nil } - updatedRootCA, err := NewRootCA(rCA.CACert, signingCert, signingKey, expiry, intermediates) if err != nil { return errors.Wrap(err, "invalid Root CA object in cluster") @@ -646,6 +661,9 @@ func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { s.securityConfig.externalCA.UpdateURLs(cfsslURLs...) s.lastSeenExternalCAs = cluster.Spec.CAConfig.Copy().ExternalCAs } + if rootCAChanged && s.lastSeenClusterRootCA.RootRotation != nil { + s.reconcileNodeRootsAndCerts() + } return nil } @@ -808,3 +826,166 @@ func isFinalState(status api.IssuanceStatus) bool { return false } + +// This function assumes that the expected root CA has root rotation. This is intended to be used by +// `reconcileNodeRootsAndCerts`, which uses the root CA from the `lastSeenClusterRootCA`, and checks +// that it has a root rotation before calling this function. +func (s *Server) finishRootRotation(tx store.Tx, expectedRootCA *api.RootCA) error { + clusterID := s.securityConfig.ClientTLSCreds.Organization() + cluster := store.GetCluster(tx, clusterID) + if cluster == nil { + return fmt.Errorf("unable to get cluster %s", clusterID) + } + + // If the RootCA object has changed (because another root rotation was started or because some other node + // had finished the root rotation), we cannot finish the root rotation that we were working on. 
+ if !equality.RootCAEqualStable(expectedRootCA, &cluster.RootCA) { + return errors.New("target root rotation has changed") + } + + var signerCert []byte + if len(cluster.RootCA.RootRotation.CAKey) > 0 { + signerCert = cluster.RootCA.RootRotation.CACert + } + // we don't actually have to parse out the default node expiration from the cluster - we are just using + // the ca.RootCA object to generate new tokens and the digest + updatedRootCA, err := NewRootCA(cluster.RootCA.RootRotation.CACert, signerCert, cluster.RootCA.RootRotation.CAKey, + DefaultNodeCertExpiration, nil) + if err != nil { + return errors.Wrap(err, "invalid cluster root rotation object") + } + cluster.RootCA = api.RootCA{ + CACert: cluster.RootCA.RootRotation.CACert, + CAKey: cluster.RootCA.RootRotation.CAKey, + CACertHash: updatedRootCA.Digest.String(), + JoinTokens: api.JoinTokens{ + Worker: GenerateJoinToken(&updatedRootCA), + Manager: GenerateJoinToken(&updatedRootCA), + }, + LastForcedRotation: cluster.RootCA.LastForcedRotation, + } + return store.UpdateCluster(tx, cluster) +} + +func (s *Server) reconcileNodeRootsAndCerts() { + s.mu.Lock() + if !s.isRunning() || s.rootCAReconciliationInProgress { + s.mu.Unlock() + return + } + ctx := s.ctx + s.rootCAReconciliationInProgress = true + s.mu.Unlock() + + s.wg.Add(1) + logger := log.G(ctx).WithField("method", "(*Server).reconcileNodeRootsAndCerts") + + go func(retryInterval time.Duration) { + defer func() { + s.wg.Done() + s.mu.Lock() + s.rootCAReconciliationInProgress = false + s.mu.Unlock() + }() + + for { + s.secConfigMu.Lock() + rootCA := s.lastSeenClusterRootCA + s.secConfigMu.Unlock() + if rootCA == nil { + return + } + + wantedIssuer := rootCA.CACert + if rootCA.RootRotation != nil { + wantedIssuer = rootCA.RootRotation.CACert + } + issuerCert, err := helpers.ParseCertificatePEM(wantedIssuer) + if err != nil { + logger.WithError(err).Error("invalid certificate in cluster root CA object") + } + + // If the last seen root rotation is nil, 
then possibly during a leadership transfer a different manager's + // CA node started root rotation reconciliation and finished the root rotation, so we should also end the + // root reconciliation + if rootCA.RootRotation == nil { + return + } + + var allNodes []*api.Node + + s.store.View(func(tx store.ReadTx) { + allNodes, err = store.FindNodes(tx, store.ByMembership(api.NodeMembershipAccepted)) + }) + if err != nil { + logger.WithError(err).Error("could not find all nodes in the store") + } + + var rootRotationDone bool + _, err = s.store.Batch(func(batch *store.Batch) error { + converged := true // this value is true if all the node certs are issued by the correct issuer + for _, node := range allNodes { + var needUpdate bool + err := batch.Update(func(tx store.Tx) error { + n := store.GetNode(tx, node.ID) + if n == nil || (n.Description != nil && n.Description.TLSInfo != nil && + bytes.Equal(n.Description.TLSInfo.CertIssuerPublicKey, issuerCert.RawSubjectPublicKeyInfo) && + bytes.Equal(n.Description.TLSInfo.CertIssuerSubject, issuerCert.RawSubject)) { + return nil + } + converged = false + + // If we are already waiting for the node to rotate (or the node's about to get issued a new cert by us), + // so don't bother telling the node to rotate again. If a cert is in renew or pending, the cert that + // the node will get should be issued by the correct issuer. However, in all likelihood the node will + // not get a chance to report back its TLS before the reconciliation loop will mark it as rotate, and + // it will get a new certificate, which is fine, but results in an extra certificate renewal round. 
+ issuanceState := n.Certificate.Status.State + if issuanceState == api.IssuanceStateRotate || issuanceState == api.IssuanceStateRenew || issuanceState == api.IssuanceStatePending { + return nil + } + n.Certificate.Status.State = api.IssuanceStateRotate + needUpdate = true + return store.UpdateNode(tx, n) + }) + if err != nil { + logger.WithError(err).Debugf("could not update node %s to IssuanceStateRotate", node.ID) + } else if needUpdate { + logger.Debugf("updated node %s to IssuanceStateRotate", node.ID) + } + } + // It's possible that between getting all nodes and finishing root rotation, new nodes have joined. by the time we + // fetch all nodes, the CA signer should have already been changed to be the desired issuer by the time we list all + // nodes, so any new nodes that appear after that will only get certificates signed by the correct issuer. + if converged { + // when the nodes are all converged on having TLS certificates issued by the correct entity, we can + // try to finish the root rotation + err := batch.Update(func(tx store.Tx) error { + return s.finishRootRotation(tx, rootCA) + }) + if err != nil { + logger.WithError(err).Errorf("could not complete root rotation") + } else { + rootRotationDone = true + } + } + + return nil + }) + if err != nil { + logger.WithError(err).Errorf("failed to check nodes for desired certificate issuer") + } + if rootRotationDone { + logger.Infof("completed root rotation on cluster %s", s.securityConfig.ServerTLSCreds.Organization()) + return + } + + select { + case <-ctx.Done(): + logger.Info("stopping incomplete root rotation reconciliation due to context being stopped") + return + case <-time.After(retryInterval): + } + } + }(s.rootReconciliationRetryInterval) +} diff --git a/ca/server_test.go b/ca/server_test.go index c80279665b..9bc1cdf1b0 100644 --- a/ca/server_test.go +++ b/ca/server_test.go @@ -6,19 +6,22 @@ import ( "fmt" "io/ioutil" "os" + "reflect" "testing" "time" - "golang.org/x/net/context" - 
"github.com/cloudflare/cfssl/helpers" "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/api/equality" "github.com/docker/swarmkit/ca" cautils "github.com/docker/swarmkit/ca/testutils" "github.com/docker/swarmkit/manager/state/store" "github.com/docker/swarmkit/testutils" + "github.com/opencontainers/go-digest" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) @@ -551,3 +554,669 @@ func TestCAServerUpdateRootCA(t *testing.T) { tc.CAServer.UpdateRootCA(context.Background(), fakeClusterSpec(cautils.ECDSA256SHA256Cert, cautils.ECDSA256Key, nil, nil)) require.Equal(t, tc.RootCA.Certs, tc.ServingSecurityConfig.RootCA().Certs) } + +type rootRotationTester struct { + tc *cautils.TestCA + t *testing.T +} + +// go through all the nodes and update/create the ones we want, and delete the ones +// we don't +func (r *rootRotationTester) convergeWantedNodes(wantNodes map[string]*api.Node, descr string) { + updatedOldNodes := make(map[string]struct{}) + createdNewNodes := make(map[string]struct{}) + require.NoError(r.t, testutils.PollFuncWithTimeout(nil, func() error { + return r.tc.MemoryStore.Update(func(tx store.Tx) error { + nodes, err := store.FindNodes(tx, store.All) + if err != nil { + return err + } + for _, node := range nodes { + wanted, inWanted := wantNodes[node.ID] + _, done := updatedOldNodes[node.ID] + + if inWanted && !done { + node.Description = wanted.Description + node.Certificate = wanted.Certificate + if err := store.UpdateNode(tx, node); err != nil { + return err + } + updatedOldNodes[node.ID] = struct{}{} + } else if !inWanted { + if err := store.DeleteNode(tx, node.ID); err != nil { + return err + } + updatedOldNodes[node.ID] = struct{}{} + } + } + for nodeID, wanted := range wantNodes { + _, createdAlready := createdNewNodes[nodeID] + if _, ok := updatedOldNodes[nodeID]; !ok && !createdAlready { + if err := 
store.CreateNode(tx, wanted); err != nil { + return err + } + createdNewNodes[nodeID] = struct{}{} + } + } + return nil + }) + }, 2*time.Second), descr) +} + +func (r *rootRotationTester) convergeRootCA(wantRootCA *api.RootCA, descr string) { + require.NoError(r.t, testutils.PollFuncWithTimeout(nil, func() error { + return r.tc.MemoryStore.Update(func(tx store.Tx) error { + clusters, err := store.FindClusters(tx, store.All) + if err != nil || len(clusters) != 1 { + return errors.Wrap(err, "unable to find cluster") + } + clusters[0].RootCA = *wantRootCA + return store.UpdateCluster(tx, clusters[0]) + }) + }, time.Second), descr) +} + +func (r *rootRotationTester) getFakeAPINode(id string, state api.IssuanceStatus_State, tlsInfo *api.NodeTLSInfo, member bool) *api.Node { + node := &api.Node{ + ID: id, + Certificate: api.Certificate{ + Status: api.IssuanceStatus{ + State: state, + }, + }, + Spec: api.NodeSpec{ + Membership: api.NodeMembershipAccepted, + }, + } + if !member { + node.Spec.Membership = api.NodeMembershipPending + } + // the CA server will immediately pick these up, so generate CSRs for the CA server to sign + if state == api.IssuanceStateRenew || state == api.IssuanceStatePending { + csr, _, err := ca.GenerateNewCSR() + require.NoError(r.t, err) + node.Certificate.CSR = csr + } + if tlsInfo != nil { + node.Description = &api.NodeDescription{TLSInfo: tlsInfo} + } + return node +} + +func startCAServer(caServer *ca.Server) { + alreadyRunning := make(chan struct{}) + go func() { + if err := caServer.Run(context.Background()); err != nil { + close(alreadyRunning) + } + }() + select { + case <-caServer.Ready(): + case <-alreadyRunning: + } +} + +func getRotationInfo(t *testing.T, rotationCert []byte, rootCA *ca.RootCA) ([]byte, *api.NodeTLSInfo) { + parsedNewRoot, err := helpers.ParseCertificatePEM(rotationCert) + require.NoError(t, err) + crossSigned, err := rootCA.CrossSignCACertificate(rotationCert) + require.NoError(t, err) + return crossSigned, 
&api.NodeTLSInfo{ + TrustRoot: rootCA.Certs, + CertIssuerPublicKey: parsedNewRoot.RawSubjectPublicKeyInfo, + CertIssuerSubject: parsedNewRoot.RawSubject, + } +} + +// These are the root rotation test cases where we expect there to be a change in the FindNodes +// or root CA values after converging. +func TestRootRotationReconciliationWithChanges(t *testing.T) { + t.Parallel() + if cautils.External { + // the external CA functionality is unrelated to testing the reconciliation loop + return + } + + tc := cautils.NewTestCA(t) + defer tc.Stop() + rt := rootRotationTester{ + tc: tc, + t: t, + } + + rotationCerts := [][]byte{cautils.ECDSA256SHA256Cert, cautils.ECDSACertChain[2]} + rotationKeys := [][]byte{cautils.ECDSA256Key, cautils.ECDSACertChainKeys[2]} + var ( + rotationCrossSigned [][]byte + rotationTLSInfo []*api.NodeTLSInfo + ) + for _, cert := range rotationCerts { + cross, info := getRotationInfo(t, cert, &tc.RootCA) + rotationCrossSigned = append(rotationCrossSigned, cross) + rotationTLSInfo = append(rotationTLSInfo, info) + } + + oldNodeTLSInfo := &api.NodeTLSInfo{ + TrustRoot: tc.RootCA.Certs, + CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey, + CertIssuerSubject: tc.ServingSecurityConfig.IssuerInfo().Subject, + } + + var startCluster *api.Cluster + tc.MemoryStore.View(func(tx store.ReadTx) { + startCluster = store.GetCluster(tx, tc.Organization) + }) + require.NotNil(t, startCluster) + + testcases := []struct { + nodes map[string]*api.Node // what nodes we should start with + rootCA *api.RootCA // what root CA we should start with + expectedNodes map[string]*api.Node // what nodes we expect in the end, if nil, then unchanged from the start + expectedRootCA *api.RootCA // what root CA we expect in the end, if nil, then unchanged from the start + caServerRestart bool // whether to stop the CA server before making the node and root changes and restart after + descr string + }{ + { + descr: ("If there is no TLS info, the reconciliation 
cycle tells the nodes to rotate if they're not already getting " + + "a new cert. Any renew/pending nodes will have certs issued, but because the TLS info is nil, they will " + + `go "rotate" state`), + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, nil, true), + "2": rt.getFakeAPINode("2", api.IssuanceStateRenew, nil, true), + "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true), + "4": rt.getFakeAPINode("4", api.IssuanceStatePending, nil, true), + "5": rt.getFakeAPINode("5", api.IssuanceStateFailed, nil, true), + }, + rootCA: &api.RootCA{ + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCerts[0], + CAKey: rotationKeys[0], + CrossSignedCACert: rotationCrossSigned[0], + }, + }, + expectedNodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateRotate, nil, true), + "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, nil, true), + "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true), + "4": rt.getFakeAPINode("4", api.IssuanceStateRotate, nil, true), + "5": rt.getFakeAPINode("5", api.IssuanceStateRotate, nil, true), + }, + }, + { + descr: ("Assume all of the nodes have gotten certs, but some of them are the wrong cert " + + "(going by the TLS info), which shouldn't really happen. 
the rotation reconciliation " + + "will tell the wrong ones to rotate a second time"), + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "2": rt.getFakeAPINode("2", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, oldNodeTLSInfo, true), + }, + rootCA: &api.RootCA{ // no change in root CA from previous + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCerts[0], + CAKey: rotationKeys[0], + CrossSignedCACert: rotationCrossSigned[0], + }, + }, + expectedNodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": rt.getFakeAPINode("5", api.IssuanceStateRotate, oldNodeTLSInfo, true), + }, + }, + { + descr: ("New nodes that are added will also be picked up and told to rotate"), + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "6": rt.getFakeAPINode("6", 
api.IssuanceStateRenew, nil, true), + }, + rootCA: &api.RootCA{ // no change in root CA from previous + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCerts[0], + CAKey: rotationKeys[0], + CrossSignedCACert: rotationCrossSigned[0], + }, + }, + expectedNodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "6": rt.getFakeAPINode("6", api.IssuanceStateRotate, nil, true), + }, + }, + { + descr: ("Even if root rotation isn't finished, if the root changes again to a " + + "different cert, all the nodes with the old root rotation cert will be told " + + "to rotate again."), + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[0], true), + }, + rootCA: &api.RootCA{ // new root rotation + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCerts[1], + CAKey: rotationKeys[1], + CrossSignedCACert: rotationCrossSigned[1], + }, + }, + expectedNodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), 
+ "1": rt.getFakeAPINode("1", api.IssuanceStateRotate, rotationTLSInfo[0], true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateRotate, rotationTLSInfo[0], true), + "5": rt.getFakeAPINode("5", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "6": rt.getFakeAPINode("6", api.IssuanceStateRotate, rotationTLSInfo[0], true), + }, + }, + { + descr: ("Once all nodes have rotated to their desired TLS info (even if it's because " + + "a node with the wrong TLS info has been removed, the root rotation is completed."), + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[1], true), + }, + rootCA: &api.RootCA{ + // no change in root CA from previous - even if root rotation gets completed after + // the nodes are first set, and we just add the root rotation again because of this + // test order, because the TLS info is correct for all nodes it will be completed again + // anyway) + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCerts[1], + CAKey: rotationKeys[1], + CrossSignedCACert: rotationCrossSigned[1], + }, + }, + expectedRootCA: &api.RootCA{ + CACert: rotationCerts[1], + CAKey: rotationKeys[1], + CACertHash: digest.FromBytes(rotationCerts[1]).String(), + // ignore the join tokens - we aren't comparing them + }, + }, + { + descr: ("If a root rotation happens when the CA server is down, so long as it saw the change " + + "it will start reconciling the nodes as soon as it's started up again"), + caServerRestart: 
true, + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[1], true), + }, + rootCA: &api.RootCA{ + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCerts[0], + CAKey: rotationKeys[0], + CrossSignedCACert: rotationCrossSigned[0], + }, + }, + expectedNodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateRotate, rotationTLSInfo[1], true), + "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, rotationTLSInfo[1], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateRotate, rotationTLSInfo[1], true), + "6": rt.getFakeAPINode("6", api.IssuanceStateRotate, rotationTLSInfo[1], true), + }, + }, + } + + for _, testcase := range testcases { + if testcase.caServerRestart { + rt.tc.CAServer.Stop() + } + + rt.convergeRootCA(testcase.rootCA, testcase.descr) + rt.convergeWantedNodes(testcase.nodes, testcase.descr) + + if testcase.expectedNodes == nil { + testcase.expectedNodes = testcase.nodes + } + if testcase.expectedRootCA == nil { + testcase.expectedRootCA = testcase.rootCA + } + + if testcase.caServerRestart { + startCAServer(rt.tc.CAServer) + } + + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + var ( + nodes []*api.Node + cluster *api.Cluster + err error + ) + tc.MemoryStore.View(func(tx store.ReadTx) { + nodes, err = store.FindNodes(tx, store.All) + cluster = store.GetCluster(tx, tc.Organization) + }) + if err != nil { + return err + } + if cluster == nil { + return errors.New("no 
cluster found") + } + + if !equality.RootCAEqualStable(&cluster.RootCA, testcase.expectedRootCA) { + return fmt.Errorf("root CAs not equal:\n\texpected: %v\n\tactual: %v", *testcase.expectedRootCA, cluster.RootCA) + } + if len(nodes) != len(testcase.expectedNodes) { + return fmt.Errorf("number of expected nodes (%d) does not equal number of actual nodes (%d)", + len(testcase.expectedNodes), len(nodes)) + } + for _, node := range nodes { + expected, ok := testcase.expectedNodes[node.ID] + if !ok { + return fmt.Errorf("node %s is present and was unexpected", node.ID) + } + if !reflect.DeepEqual(expected.Description, node.Description) { + return fmt.Errorf("the node description of node %s is not expected:\n\texpected: %v\n\tactual: %v", node.ID, + expected.Description, node.Description) + } + if !reflect.DeepEqual(expected.Certificate.Status, node.Certificate.Status) { + return fmt.Errorf("the certificate status of node %s is not expected:\n\texpected: %v\n\tactual: %v", node.ID, + expected.Certificate, node.Certificate) + } + } + return nil + }, 5*time.Second), testcase.descr) + } +} + +// These are the root rotation test cases where we expect there to be no changes made to either +// the nodes or the root CA object +func TestRootRotationReconciliationNoChanges(t *testing.T) { + t.Parallel() + if cautils.External { + // the external CA functionality is unrelated to testing the reconciliation loop + return + } + + tc := cautils.NewTestCA(t) + defer tc.Stop() + rt := rootRotationTester{ + tc: tc, + t: t, + } + + rotationCert := cautils.ECDSA256SHA256Cert + rotationKey := cautils.ECDSA256Key + rotationCrossSigned, rotationTLSInfo := getRotationInfo(t, rotationCert, &tc.RootCA) + + oldNodeTLSInfo := &api.NodeTLSInfo{ + TrustRoot: tc.RootCA.Certs, + CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey, + CertIssuerSubject: tc.ServingSecurityConfig.IssuerInfo().Subject, + } + + var startCluster *api.Cluster + tc.MemoryStore.View(func(tx store.ReadTx) { + 
startCluster = store.GetCluster(tx, tc.Organization) + }) + require.NotNil(t, startCluster) + + testcases := []struct { + nodes map[string]*api.Node // what nodes we should start with + rootCA *api.RootCA // what root CA we should start with + descr string + caServerStopped bool // if the server is running, only then will a reconciliation loop happen + }{ + { + descr: ("If the CA server is not running no reconciliation happens even if a root rotation " + + "is in progress"), + caServerStopped: true, + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "2": rt.getFakeAPINode("2", api.IssuanceStateRenew, nil, true), + "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true), + "4": rt.getFakeAPINode("4", api.IssuanceStatePending, nil, true), + "5": rt.getFakeAPINode("5", api.IssuanceStateFailed, nil, true), + }, + rootCA: &api.RootCA{ + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCert, + CAKey: rotationKey, + CrossSignedCACert: rotationCrossSigned, + }, + }, + }, + { + descr: ("If all nodes have the right TLS info or are already rotated (or are not members), the " + + "there will be no changes needed"), + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo, true), + "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, rotationTLSInfo, true), + }, + rootCA: &api.RootCA{ // no change in root CA from previous + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCert, + CAKey: rotationKey, + 
CrossSignedCACert: rotationCrossSigned, + }, + }, + }, + { + descr: ("Nodes already in rotate state, even if they currently have the correct TLS issuer, will be " + + "left in the rotate state even if root rotation is aborted because we don't know if they're already " + + "in the process of getting a new cert. Even if they're issued by a different issuer, they will be " + + "left alone because they'll have an interemdiate that chains up to the old issuer."), + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo, true), + "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true), + }, + rootCA: &api.RootCA{ // no change in root CA from previous + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + }, + }, + } + + for _, testcase := range testcases { + if testcase.caServerStopped { + rt.tc.CAServer.Stop() + } else { + startCAServer(rt.tc.CAServer) + } + + rt.convergeRootCA(testcase.rootCA, testcase.descr) + rt.convergeWantedNodes(testcase.nodes, testcase.descr) + + time.Sleep(500 * time.Millisecond) + + var ( + nodes []*api.Node + cluster *api.Cluster + err error + ) + + tc.MemoryStore.View(func(tx store.ReadTx) { + nodes, err = store.FindNodes(tx, store.All) + cluster = store.GetCluster(tx, tc.Organization) + }) + require.NoError(t, err) + require.NotNil(t, cluster) + require.Equal(t, cluster.RootCA, *testcase.rootCA, testcase.descr) + require.Len(t, nodes, len(testcase.nodes), testcase.descr) + for _, node := range nodes { + expected, ok := testcase.nodes[node.ID] + require.True(t, ok, "node %s: %s", node.ID, testcase.descr) + require.Equal(t, expected.Description, node.Description, "node %s: %s", node.ID, testcase.descr) + require.Equal(t, expected.Certificate.Status, node.Certificate.Status, "node %s: %s", node.ID, testcase.descr) + } + } +} + +// Tests if the root 
rotation changes while the reconciliation loop is going, eventually the root rotation will finish +// successfully (even if there's a competing reconciliation loop, for instance if there's a bug during leadership handoff). +func TestRootRotationReconciliationRace(t *testing.T) { + t.Parallel() + if cautils.External { + // the external CA functionality is unrelated to testing the reconciliation loop + return + } + + tc := cautils.NewTestCA(t) + defer tc.Stop() + rt := rootRotationTester{ + tc: tc, + t: t, + } + // start a competing CA server + tempDir, err := ioutil.TempDir("", "competing-ca-server") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + paths := ca.NewConfigPaths(tempDir) + competingSecConfig, err := tc.NewNodeConfig(ca.ManagerRole) + require.NoError(t, err) + + var otherServers []*ca.Server + for i := 0; i < 3; i++ { // to make sure we get some collision + otherServer := ca.NewServer(tc.MemoryStore, competingSecConfig, paths.RootCA) + // offset each server's reconciliation interval somewhat so that some will + // pre-empt others + otherServer.SetRootReconciliationInterval(time.Millisecond * time.Duration((i+1)*10)) + startCAServer(otherServer) + defer otherServer.Stop() + otherServers = append(otherServers, otherServer) + } + clusterWatch, clusterWatchCancel, err := store.ViewAndWatch( + tc.MemoryStore, func(tx store.ReadTx) error { + cluster := store.GetCluster(tx, tc.Organization) + for _, s := range otherServers { + s.UpdateRootCA(context.Background(), cluster) + } + return nil + }, + api.EventUpdateCluster{ + Cluster: &api.Cluster{ID: tc.Organization}, + Checks: []api.ClusterCheckFunc{api.ClusterCheckID}, + }, + ) + require.NoError(t, err) + defer clusterWatchCancel() + + done := make(chan struct{}) + defer close(done) + go func() { + for { + select { + case event := <-clusterWatch: + clusterEvent := event.(api.EventUpdateCluster) + for _, s := range otherServers { + s.UpdateRootCA(context.Background(), clusterEvent.Cluster) + } + case 
<-done: + return + } + } + }() + + oldNodeTLSInfo := &api.NodeTLSInfo{ + TrustRoot: tc.RootCA.Certs, + CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey, + CertIssuerSubject: tc.ServingSecurityConfig.IssuerInfo().Subject, + } + + nodes := make(map[string]*api.Node) + for i := 0; i < 5; i++ { + nodeID := fmt.Sprintf("%d", i) + nodes[nodeID] = rt.getFakeAPINode(nodeID, api.IssuanceStateIssued, oldNodeTLSInfo, true) + } + rt.convergeWantedNodes(nodes, "setting up nodes for root rotation race condition test") + + for i := 0; i < 10; i++ { + var ( + rotationCrossSigned []byte + rotationTLSInfo *api.NodeTLSInfo + ) + rotationCert, rotationKey, err := cautils.CreateRootCertAndKey(fmt.Sprintf("root cn %d", i)) + require.NoError(t, err) + require.NoError(t, tc.MemoryStore.Update(func(tx store.Tx) error { + cluster := store.GetCluster(tx, tc.Organization) + if cluster == nil { + return errors.New("cluster has disappeared") + } + rootCA := cluster.RootCA.Copy() + caRootCA, err := ca.NewRootCA(rootCA.CACert, rootCA.CACert, rootCA.CAKey, ca.DefaultNodeCertExpiration, nil) + if err != nil { + return err + } + rotationCrossSigned, rotationTLSInfo = getRotationInfo(t, rotationCert, &caRootCA) + rootCA.RootRotation = &api.RootRotation{ + CACert: rotationCert, + CAKey: rotationKey, + CrossSignedCACert: rotationCrossSigned, + } + cluster.RootCA = *rootCA + return store.UpdateCluster(tx, cluster) + })) + for _, node := range nodes { + node.Description.TLSInfo = rotationTLSInfo + } + rt.convergeWantedNodes(nodes, fmt.Sprintf("iteration %d", i)) + } + + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + var cluster *api.Cluster + tc.MemoryStore.View(func(tx store.ReadTx) { + cluster = store.GetCluster(tx, tc.Organization) + }) + if cluster == nil { + return errors.New("cluster has disappeared") + } + if cluster.RootCA.RootRotation != nil { + return fmt.Errorf("root rotation is still present") + } + return nil + }, time.Second)) +} diff --git 
a/ca/testutils/cautils.go b/ca/testutils/cautils.go index 61024320c2..532943181c 100644 --- a/ca/testutils/cautils.go +++ b/ca/testutils/cautils.go @@ -21,6 +21,7 @@ import ( "github.com/docker/swarmkit/connectionbroker" "github.com/docker/swarmkit/identity" "github.com/docker/swarmkit/ioutils" + "github.com/docker/swarmkit/log" "github.com/docker/swarmkit/manager/state/store" stateutils "github.com/docker/swarmkit/manager/state/testutils" "github.com/docker/swarmkit/remotes" @@ -178,6 +179,7 @@ func NewTestCAFromRootCA(t *testing.T, tempBaseDir string, rootCA ca.RootCA, krw caServer := ca.NewServer(s, managerConfig, paths.RootCA) caServer.SetReconciliationRetryInterval(50 * time.Millisecond) + caServer.SetRootReconciliationInterval(50 * time.Millisecond) api.RegisterCAServer(grpcServer, caServer) api.RegisterNodeCAServer(grpcServer, caServer) @@ -200,7 +202,9 @@ func NewTestCAFromRootCA(t *testing.T, tempBaseDir string, rootCA ca.RootCA, krw select { case event := <-clusterWatch: clusterEvent := event.(api.EventUpdateCluster) - caServer.UpdateRootCA(ctx, clusterEvent.Cluster) + if err := caServer.UpdateRootCA(ctx, clusterEvent.Cluster); err != nil { + log.G(ctx).WithError(err).Error("ca utils CA server could not update root CA") + } case <-ctx.Done(): clusterWatchCancel() return From 225bea52378ff0890c6dca06385ae2deffc962d0 Mon Sep 17 00:00:00 2001 From: cyli Date: Wed, 5 Apr 2017 15:53:01 -0700 Subject: [PATCH 2/5] Add root rotation integration tests Signed-off-by: cyli --- integration/api.go | 10 +- integration/cluster.go | 51 ++++++--- integration/integration_test.go | 181 ++++++++++++++++++++++++++++++++ integration/node.go | 5 + 4 files changed, 229 insertions(+), 18 deletions(-) diff --git a/integration/api.go b/integration/api.go index b7a838eeb9..0f44037887 100644 --- a/integration/api.go +++ b/integration/api.go @@ -131,6 +131,12 @@ func (a *dummyAPI) ListClusters(ctx context.Context, r *api.ListClustersRequest) return cli.ListClusters(ctx, r) } -func (a 
*dummyAPI) UpdateCluster(context.Context, *api.UpdateClusterRequest) (*api.UpdateClusterResponse, error) { - panic("not implemented") +func (a *dummyAPI) UpdateCluster(ctx context.Context, r *api.UpdateClusterRequest) (*api.UpdateClusterResponse, error) { + ctx, cancel := context.WithTimeout(ctx, opsTimeout) + defer cancel() + cli, err := a.c.RandomManager().ControlClient(ctx) + if err != nil { + return nil, err + } + return cli.UpdateCluster(ctx, r) } diff --git a/integration/cluster.go b/integration/cluster.go index f56351d7e5..6bc00925d6 100644 --- a/integration/cluster.go +++ b/integration/cluster.go @@ -98,14 +98,11 @@ func (c *testCluster) AddManager(lateBind bool, rootCA *ca.RootCA) error { if err != nil { return err } - clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{}) + clusterInfo, err := c.GetClusterInfo() if err != nil { return err } - if len(clusterInfo.Clusters) == 0 { - return fmt.Errorf("joining manager: there is no cluster created in storage") - } - node, err := newTestNode(joinAddr, clusterInfo.Clusters[0].RootCA.JoinTokens.Manager, false, nil) + node, err := newTestNode(joinAddr, clusterInfo.RootCA.JoinTokens.Manager, false, nil) if err != nil { return err } @@ -137,14 +134,9 @@ func (c *testCluster) AddManager(lateBind bool, rootCA *ca.RootCA) error { if lateBind { // Verify that the control API works - clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{}) - if err != nil { + if _, err := c.GetClusterInfo(); err != nil { return err } - if len(clusterInfo.Clusters) == 0 { - return fmt.Errorf("joining manager: there is no cluster created in storage") - } - return n.node.BindRemote(context.Background(), "127.0.0.1:0", "") } @@ -162,14 +154,11 @@ func (c *testCluster) AddAgent() error { if err != nil { return err } - clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{}) + clusterInfo, err := c.GetClusterInfo() if err != nil { return err } - 
if len(clusterInfo.Clusters) == 0 { - return fmt.Errorf("joining agent: there is no cluster created in storage") - } - node, err := newTestNode(joinAddr, clusterInfo.Clusters[0].RootCA.JoinTokens.Worker, false, nil) + node, err := newTestNode(joinAddr, clusterInfo.RootCA.JoinTokens.Worker, false, nil) if err != nil { return err } @@ -383,3 +372,33 @@ func (c *testCluster) StartNode(id string) error { } return nil } + +func (c *testCluster) GetClusterInfo() (*api.Cluster, error) { + clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{}) + if err != nil { + return nil, err + } + if len(clusterInfo.Clusters) != 1 { + return nil, fmt.Errorf("number of clusters in storage: %d; expected 1", len(clusterInfo.Clusters)) + } + return clusterInfo.Clusters[0], nil +} + +func (c *testCluster) RotateRootCA(cert, key []byte) error { + // poll in case something else changes the cluster before we can update it + return testutils.PollFuncWithTimeout(nil, func() error { + clusterInfo, err := c.GetClusterInfo() + if err != nil { + return err + } + newSpec := clusterInfo.Spec.Copy() + newSpec.CAConfig.SigningCACert = cert + newSpec.CAConfig.SigningCAKey = key + _, err = c.api.UpdateCluster(context.Background(), &api.UpdateClusterRequest{ + ClusterID: clusterInfo.ID, + Spec: newSpec, + ClusterVersion: &clusterInfo.Meta.Version, + }) + return err + }, opsTimeout) +} diff --git a/integration/integration_test.go b/integration/integration_test.go index dd13f5667c..d676987afe 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -1,6 +1,7 @@ package integration import ( + "bytes" "flag" "fmt" "io/ioutil" @@ -13,6 +14,8 @@ import ( "golang.org/x/net/context" + "reflect" + "github.com/Sirupsen/logrus" "github.com/cloudflare/cfssl/helpers" events "github.com/docker/go-events" @@ -99,6 +102,9 @@ func pollClusterReady(t *testing.T, c *testCluster, numWorker, numManager int) { return fmt.Errorf("worker node %s should not have 
manager status, returned %s", n.ID, n.ManagerStatus) } } + if n.Description.TLSInfo == nil { + return fmt.Errorf("node %s has not reported its TLS info yet", n.ID) + } } if !leaderFound { return fmt.Errorf("leader of cluster is not found") @@ -547,3 +553,178 @@ func TestForceNewCluster(t *testing.T) { require.NoError(t, ioutil.WriteFile(managerCertFile, expiredCertPEM, 0644)) require.Error(t, cl.StartNode(nodeID)) } + +func pollRootRotationDone(t *testing.T, cl *testCluster) { + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + clusterInfo, err := cl.GetClusterInfo() + if err != nil { + return err + } + if clusterInfo.RootCA.RootRotation != nil { + return errors.New("root rotation not done") + } + return nil + }, opsTimeout)) +} + +func TestSuccessfulRootRotation(t *testing.T) { + t.Parallel() + numWorker, numManager := 2, 3 + cl := newCluster(t, numWorker, numManager) + defer func() { + require.NoError(t, cl.Stop()) + }() + pollClusterReady(t, cl, numWorker, numManager) + + // Take down one of managers and both workers, so we can't actually ever finish root rotation. 
+ resp, err := cl.api.ListNodes(context.Background(), &api.ListNodesRequest{}) + require.NoError(t, err) + var ( + downManagerID string + downWorkerIDs []string + oldTLSInfo *api.NodeTLSInfo + ) + for _, n := range resp.Nodes { + if oldTLSInfo != nil { + require.Equal(t, oldTLSInfo, n.Description.TLSInfo) + } else { + oldTLSInfo = n.Description.TLSInfo + } + if n.Role == api.NodeRoleManager { + if !n.ManagerStatus.Leader && downManagerID == "" { + downManagerID = n.ID + require.NoError(t, cl.nodes[n.ID].Pause(false)) + } + continue + } + downWorkerIDs = append(downWorkerIDs, n.ID) + require.NoError(t, cl.nodes[n.ID].Pause(false)) + } + + // perform a root rotation, and wait until all the nodes that are up have newly issued certs + newRootCert, newRootKey, err := cautils.CreateRootCertAndKey("newRootCN") + require.NoError(t, err) + require.NoError(t, cl.RotateRootCA(newRootCert, newRootKey)) + + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + resp, err := cl.api.ListNodes(context.Background(), &api.ListNodesRequest{}) + if err != nil { + return err + } + for _, n := range resp.Nodes { + isDown := n.ID == downManagerID || n.ID == downWorkerIDs[0] || n.ID == downWorkerIDs[1] + if reflect.DeepEqual(n.Description.TLSInfo, oldTLSInfo) != isDown { + return fmt.Errorf("expected TLS info to have changed: %v", !isDown) + } + } + + // root rotation isn't done + clusterInfo, err := cl.GetClusterInfo() + if err != nil { + return err + } + require.NotNil(t, clusterInfo.RootCA.RootRotation) // if root rotation is already done, fail and finish the test here + return nil + }, opsTimeout)) + + // Kill the leader, and bring the other manager back to show that a new + // leader can pick up the reconciliation loop and complete the root rotation. + // Bring one worker back, kill the other worker, and add a new worker - show + // that we can converge on a root rotation. 
+ downLeader, err := cl.Leader() + require.NoError(t, err) + leaderID := downLeader.node.NodeID() + + require.NoError(t, cl.StartNode(downManagerID)) + require.NoError(t, cl.StartNode(downWorkerIDs[0])) + require.NoError(t, cl.RemoveNode(downWorkerIDs[1], false)) + require.NoError(t, cl.AddAgent()) + require.NoError(t, downLeader.Pause(false)) + + // we can finish root rotation even though the previous leader was down because it had + // already rotated its cert + pollRootRotationDone(t, cl) + + // bring the previous leader back up so it can get the new root + require.NoError(t, cl.StartNode(leaderID)) + + // wait until all the nodes have gotten their new certs and trust roots + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + resp, err = cl.api.ListNodes(context.Background(), &api.ListNodesRequest{}) + if err != nil { + return err + } + var newTLSInfo *api.NodeTLSInfo + for _, n := range resp.Nodes { + if newTLSInfo == nil { + newTLSInfo = n.Description.TLSInfo + if bytes.Equal(newTLSInfo.CertIssuerPublicKey, oldTLSInfo.CertIssuerPublicKey) || + bytes.Equal(newTLSInfo.CertIssuerSubject, oldTLSInfo.CertIssuerSubject) { + return errors.New("expecting the issuer to have changed") + } + if !bytes.Equal(newTLSInfo.TrustRoot, newRootCert) { + return errors.New("expecting the the root certificate to have changed") + } + } else if !reflect.DeepEqual(newTLSInfo, n.Description.TLSInfo) { + return fmt.Errorf("the nodes have not converged yet, particularly %s", n.ID) + } + + if n.Certificate.Status.State != api.IssuanceStateIssued { + return errors.New("nodes have yet to finish renewing their TLS certificates") + } + } + return nil + }, opsTimeout)) +} + +func TestRepeatedRootRotation(t *testing.T) { + t.Parallel() + numWorker, numManager := 3, 1 + cl := newCluster(t, numWorker, numManager) + defer func() { + require.NoError(t, cl.Stop()) + }() + pollClusterReady(t, cl, numWorker, numManager) + + resp, err := cl.api.ListNodes(context.Background(), 
&api.ListNodesRequest{}) + require.NoError(t, err) + var oldTLSInfo *api.NodeTLSInfo + for _, n := range resp.Nodes { + if oldTLSInfo != nil { + require.Equal(t, oldTLSInfo, n.Description.TLSInfo) + } else { + oldTLSInfo = n.Description.TLSInfo + } + } + + // perform multiple root rotations, wait a second between each + var newRootCert, newRootKey []byte + for i := 0; i < 3; i++ { + newRootCert, newRootKey, err = cautils.CreateRootCertAndKey("newRootCN") + require.NoError(t, err) + require.NoError(t, cl.RotateRootCA(newRootCert, newRootKey)) + time.Sleep(time.Second) + } + + pollRootRotationDone(t, cl) + + // wait until all the nodes are stabilized back to the latest issuer + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + resp, err = cl.api.ListNodes(context.Background(), &api.ListNodesRequest{}) + if err != nil { + return err + } + for _, n := range resp.Nodes { + if reflect.DeepEqual(n.Description.TLSInfo, oldTLSInfo) { + return errors.New("nodes have not changed TLS info") + } + if n.Certificate.Status.State != api.IssuanceStateIssued { + return errors.New("nodes have yet to finish renewing their TLS certificates") + } + if !bytes.Equal(n.Description.TLSInfo.TrustRoot, newRootCert) { + return errors.New("nodes do not all trust the new root yet") + } + } + return nil + }, opsTimeout)) +} diff --git a/integration/node.go b/integration/node.go index 17868ecc95..f4f27a2d2a 100644 --- a/integration/node.go +++ b/integration/node.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "google.golang.org/grpc" @@ -115,6 +116,10 @@ func (n *testNode) stop() error { defer cancel() isManager := n.IsManager() if err := n.node.Stop(ctx); err != nil { + // if the error is from trying to stop an already stopped node, ignore the error + if strings.Contains(err.Error(), "node: not started") { + return nil + } // TODO(aaronl): This stack dumping may be removed in the // future once context deadline issues while shutting down // 
nodes are resolved. From 2a91941cda647a1d662709e937bcf1277d4499a9 Mon Sep 17 00:00:00 2001 From: cyli Date: Thu, 6 Apr 2017 17:45:59 -0700 Subject: [PATCH 3/5] Abstract the reconciliation loop to its own object. Rather than polling the store for all nodes at intervals, rely on the cluster and node watches to update an in-memory mapping of the current nodes. At regular intervals, update the store to tell a throttled number of the unconverged nodes to rotate their certificates. Also, remove the leader rotation part of the root rotation integration test, as that takes a very long time. There are server tests to ensure that multiple CA servers running reconciliation loops, and starting a CA server from a stopped state, does not break root reconciliation. Signed-off-by: cyli --- ca/reconciler.go | 271 +++++++++++++++++++++++++ ca/server.go | 202 +++--------------- ca/server_test.go | 349 ++++++++++++++++++++------------ integration/integration_test.go | 14 +- 4 files changed, 521 insertions(+), 315 deletions(-) create mode 100644 ca/reconciler.go diff --git a/ca/reconciler.go b/ca/reconciler.go new file mode 100644 index 0000000000..04b2aef211 --- /dev/null +++ b/ca/reconciler.go @@ -0,0 +1,271 @@ +package ca + +import ( + "context" + "fmt" + "sync" + "time" + + "bytes" + + "reflect" + + "github.com/cloudflare/cfssl/helpers" + "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/api/equality" + "github.com/docker/swarmkit/log" + "github.com/docker/swarmkit/manager/state/store" + "github.com/pkg/errors" +) + +// IssuanceStateRotateBatchSize is the maximum number of nodes we'll tell to rotate their certificates in any given update +const IssuanceStateRotateBatchSize = 30 + +func hasIssuer(n *api.Node, info *IssuerInfo) bool { + if n.Description == nil || n.Description.TLSInfo == nil { + return false + } + return bytes.Equal(info.Subject, n.Description.TLSInfo.CertIssuerSubject) && bytes.Equal(info.PublicKey, n.Description.TLSInfo.CertIssuerPublicKey) +} + +// 
IssuerFromAPIRootCA returns the desired issuer given an API root CA object +func IssuerFromAPIRootCA(rootCA *api.RootCA) (*IssuerInfo, error) { + wantedIssuer := rootCA.CACert + if rootCA.RootRotation != nil { + wantedIssuer = rootCA.RootRotation.CACert + } + issuerCerts, err := helpers.ParseCertificatesPEM(wantedIssuer) + if err != nil { + return nil, errors.Wrap(err, "invalid certificate in cluster root CA object") + } + if len(issuerCerts) == 0 { + return nil, errors.New("no certificates found in cluster root CA object") + } + return &IssuerInfo{ + Subject: issuerCerts[0].RawSubject, + PublicKey: issuerCerts[0].RawSubjectPublicKeyInfo, + }, nil +} + +var errRootRotationChanged = errors.New("target root rotation has changed") + +// rootRotationReconciler keeps track of all the nodes in the store so that we can determine which ones need reconciliation when nodes are updated +// or the root CA is updated. This is meant to be used with watches on nodes and the cluster, and provides functions to be called when the +// cluster's RootCA has changed and when a node is added, updated, or removed. +type rootRotationReconciler struct { + mu sync.Mutex + clusterID string + batchUpdateInterval time.Duration + ctx context.Context + store *store.MemoryStore + + currentRootCA *api.RootCA + currentIssuer IssuerInfo + allNodes map[string]*api.Node + unconvergedNodes map[string]struct{} + + wg sync.WaitGroup + cancel func() +} + +func newReconciler(ctx context.Context, clusterID string, interval time.Duration, s *store.MemoryStore, rootCA *api.RootCA, nodes []*api.Node) *rootRotationReconciler { + r := &rootRotationReconciler{ + ctx: ctx, + clusterID: clusterID, + store: s, + batchUpdateInterval: interval, + allNodes: make(map[string]*api.Node), + unconvergedNodes: make(map[string]struct{}), + } + r.UpdateRootCA(rootCA) + r.UpdateNodes(nodes...) 
+ return r +} + +func (r *rootRotationReconciler) UpdateRootCA(newRootCA *api.RootCA) { + if newRootCA == nil { + return + } + issuerInfo, err := IssuerFromAPIRootCA(newRootCA) + if err != nil { + log.G(r.ctx).WithError(err).Error("unable to update process the current root CA") + } + r.mu.Lock() + r.currentRootCA = newRootCA + // check if the issuer has changed, first + if reflect.DeepEqual(&r.currentIssuer, issuerInfo) { + r.mu.Unlock() + return + } + // If the issuer has changed, iterate through all the nodes to figure out which ones need rotation + r.currentIssuer = *issuerInfo + r.unconvergedNodes = make(map[string]struct{}) + var ( + hasRootRotation bool + ctx context.Context + wg *sync.WaitGroup + ) + if r.currentRootCA.RootRotation != nil { + hasRootRotation = true + for _, n := range r.allNodes { + if hasIssuer(n, &r.currentIssuer) { + continue + } + r.unconvergedNodes[n.ID] = struct{}{} + } + if r.cancel != nil { // there's already a loop going, so cancel it + r.cancel() + wg = &r.wg + } + ctx, r.cancel = context.WithCancel(r.ctx) + } + r.mu.Unlock() + + if hasRootRotation { + if wg != nil { + wg.Wait() + } + go r.runReconcilerLoop(ctx, newRootCA) + } +} + +func (r *rootRotationReconciler) UpdateNodes(nodes ...*api.Node) { + r.mu.Lock() + for _, n := range nodes { + if n == nil || n.Spec.Membership != api.NodeMembershipAccepted { + continue + } + r.allNodes[n.ID] = n + if r.currentRootCA == nil || r.currentRootCA.RootRotation == nil { + continue + } + if hasIssuer(n, &r.currentIssuer) { + delete(r.unconvergedNodes, n.ID) + } else { + r.unconvergedNodes[n.ID] = struct{}{} + } + } + r.mu.Unlock() +} + +func (r *rootRotationReconciler) DeleteNode(node *api.Node) { + if node == nil { + return + } + r.mu.Lock() + delete(r.allNodes, node.ID) + delete(r.unconvergedNodes, node.ID) + r.mu.Unlock() +} + +func (r *rootRotationReconciler) runReconcilerLoop(ctx context.Context, loopRootCA *api.RootCA) { + r.wg.Add(1) + defer r.wg.Done() + for { + r.mu.Lock() + if 
len(r.unconvergedNodes) == 0 { + r.mu.Unlock() + + err := r.store.Update(func(tx store.Tx) error { + return r.finishRootRotation(tx, loopRootCA) + }) + if err == nil { + log.G(r.ctx).Infof("completed root rotation on cluster %s", r.clusterID) + return + } + log.G(r.ctx).WithError(err).Errorf("could not complete root rotation") + if err == errRootRotationChanged { + // if the root rotation has changed, this loop will be cancelled anyway, so may as well abort early + return + } + } else { + var toUpdate []*api.Node + for nodeID := range r.unconvergedNodes { + n, ok := r.allNodes[nodeID] + if !ok { // should never happen + continue + } + iState := n.Certificate.Status.State + if iState != api.IssuanceStateRenew && iState != api.IssuanceStatePending && iState != api.IssuanceStateRotate { + n = n.Copy() + n.Certificate.Status.State = api.IssuanceStateRotate + toUpdate = append(toUpdate, n) + if len(toUpdate) >= IssuanceStateRotateBatchSize { + break + } + } + } + r.mu.Unlock() + + if err := r.batchUpdateNodes(toUpdate); err != nil { + log.G(r.ctx).WithError(err).Errorf("store error when trying to batch update %d nodes to request certificate rotation", len(toUpdate)) + } + } + + select { + case <-r.ctx.Done(): + return + case <-time.After(r.batchUpdateInterval): + } + } +} + +// This function assumes that the expected root CA has root rotation. This is intended to be used by +// `reconcileNodeRootsAndCerts`, which uses the root CA from the `lastSeenClusterRootCA`, and checks +// that it has a root rotation before calling this function. 
+func (r *rootRotationReconciler) finishRootRotation(tx store.Tx, expectedRootCA *api.RootCA) error { + cluster := store.GetCluster(tx, r.clusterID) + if cluster == nil { + return fmt.Errorf("unable to get cluster %s", r.clusterID) + } + + // If the RootCA object has changed (because another root rotation was started or because some other node + // had finished the root rotation), we cannot finish the root rotation that we were working on. + if !equality.RootCAEqualStable(expectedRootCA, &cluster.RootCA) { + return errRootRotationChanged + } + + var signerCert []byte + if len(cluster.RootCA.RootRotation.CAKey) > 0 { + signerCert = cluster.RootCA.RootRotation.CACert + } + // we don't actually have to parse out the default node expiration from the cluster - we are just using + // the ca.RootCA object to generate new tokens and the digest + updatedRootCA, err := NewRootCA(cluster.RootCA.RootRotation.CACert, signerCert, cluster.RootCA.RootRotation.CAKey, + DefaultNodeCertExpiration, nil) + if err != nil { + return errors.Wrap(err, "invalid cluster root rotation object") + } + cluster.RootCA = api.RootCA{ + CACert: cluster.RootCA.RootRotation.CACert, + CAKey: cluster.RootCA.RootRotation.CAKey, + CACertHash: updatedRootCA.Digest.String(), + JoinTokens: api.JoinTokens{ + Worker: GenerateJoinToken(&updatedRootCA), + Manager: GenerateJoinToken(&updatedRootCA), + }, + LastForcedRotation: cluster.RootCA.LastForcedRotation, + } + return store.UpdateCluster(tx, cluster) +} + +func (r *rootRotationReconciler) batchUpdateNodes(toUpdate []*api.Node) error { + if len(toUpdate) == 0 { + return nil + } + _, err := r.store.Batch(func(batch *store.Batch) error { + // Directly update the nodes rather than get + update, and ignore version errors. Since + // `rootRotationReconciler` should be hooked up to all node update/delete/create vents, we should have + // close to the latest versions of all the nodes. 
If not, the node will be updated later and the + // next batch of updates should catch it. + for _, n := range toUpdate { + if err := batch.Update(func(tx store.Tx) error { + return store.UpdateNode(tx, n) + }); err != nil { + log.G(r.ctx).WithError(err).Debugf("unable to update node %s to request a certificate rotation", n.ID) + } + } + return nil + }) + return err +} diff --git a/ca/server.go b/ca/server.go index 85c18eb75d..4e7aef1df3 100644 --- a/ca/server.go +++ b/ca/server.go @@ -4,12 +4,10 @@ import ( "bytes" "crypto/subtle" "crypto/x509" - "fmt" "sync" "time" "github.com/Sirupsen/logrus" - "github.com/cloudflare/cfssl/helpers" "github.com/docker/swarmkit/api" "github.com/docker/swarmkit/api/equality" "github.com/docker/swarmkit/identity" @@ -67,9 +65,8 @@ type Server struct { // before we update the security config with the new root CA, we need to be able to save the root certs rootPaths CertPaths - // lets us know if we're already doing a reconcile loop to tell nodes to rotate their - // TLS certificates and watching for root rotation completion - rootCAReconciliationInProgress bool + // lets us monitor and finish root rotations + rootReconciler *rootRotationReconciler rootReconciliationRetryInterval time.Duration } @@ -434,6 +431,7 @@ func (s *Server) Run(ctx context.Context) error { }, api.EventCreateNode{}, api.EventUpdateNode{}, + api.EventDeleteNode{}, ) // Do this after updateCluster has been called, so isRunning never @@ -442,7 +440,19 @@ func (s *Server) Run(ctx context.Context) error { s.ctx, s.cancel = context.WithCancel(ctx) ctx = s.ctx close(s.started) + // we need to set it on the server, because `Server.UpdateRootCA` can be called from outside the Run function + s.rootReconciler = newReconciler( + log.WithField(ctx, "method", "(*Server).rootRotationReconciler"), + s.securityConfig.ClientTLSCreds.Organization(), + s.rootReconciliationRetryInterval, + s.store, s.lastSeenClusterRootCA, nodes) + rootReconciler := s.rootReconciler s.mu.Unlock() + defer 
func() { + s.mu.Lock() + s.rootReconciler = nil + s.mu.Unlock() + }() if err != nil { log.G(ctx).WithFields(logrus.Fields{ @@ -461,7 +471,6 @@ func (s *Server) Run(ctx context.Context) error { "method": "(*Server).Run", }).WithError(err).Errorf("error attempting to reconcile certificates") } - s.reconcileNodeRootsAndCerts() ticker := time.NewTicker(s.reconciliationRetryInterval) defer ticker.Stop() @@ -480,13 +489,18 @@ func (s *Server) Run(ctx context.Context) error { switch v := event.(type) { case api.EventCreateNode: s.evaluateAndSignNodeCert(ctx, v.Node) + rootReconciler.UpdateNodes(v.Node) case api.EventUpdateNode: // If this certificate is already at a final state // no need to evaluate and sign it. if !isFinalState(v.Node.Certificate.Status) { s.evaluateAndSignNodeCert(ctx, v.Node) } + rootReconciler.UpdateNodes(v.Node) + case api.EventDeleteNode: + rootReconciler.DeleteNode(v.Node) } + case <-ticker.C: for _, node := range s.pending { if err := s.evaluateAndSignNodeCert(ctx, node); err != nil { @@ -557,12 +571,16 @@ func (s *Server) isRunning() bool { func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { s.mu.Lock() s.joinTokens = cluster.RootCA.JoinTokens.Copy() + reconciler := s.rootReconciler s.mu.Unlock() + rCA := cluster.RootCA.Copy() + if reconciler != nil { + reconciler.UpdateRootCA(rCA) + } s.secConfigMu.Lock() defer s.secConfigMu.Unlock() - rCA := cluster.RootCA - rootCAChanged := len(rCA.CACert) != 0 && !equality.RootCAEqualStable(s.lastSeenClusterRootCA, &cluster.RootCA) + rootCAChanged := len(rCA.CACert) != 0 && !equality.RootCAEqualStable(s.lastSeenClusterRootCA, rCA) externalCAChanged := !equality.ExternalCAsEqualStable(s.lastSeenExternalCAs, cluster.Spec.CAConfig.ExternalCAs) logger := log.G(ctx).WithFields(logrus.Fields{ "cluster.id": cluster.ID, @@ -619,7 +637,7 @@ func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { } // only update the server cache if we've successfully updated the 
root CA logger.Debug("Root CA updated successfully") - s.lastSeenClusterRootCA = cluster.RootCA.Copy() + s.lastSeenClusterRootCA = rCA } // we want to update if the external CA changed, or if the root CA changed because the root CA could affect what @@ -661,9 +679,6 @@ func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { s.securityConfig.externalCA.UpdateURLs(cfsslURLs...) s.lastSeenExternalCAs = cluster.Spec.CAConfig.Copy().ExternalCAs } - if rootCAChanged && s.lastSeenClusterRootCA.RootRotation != nil { - s.reconcileNodeRootsAndCerts() - } return nil } @@ -826,166 +841,3 @@ func isFinalState(status api.IssuanceStatus) bool { return false } - -// This function assumes that the expected root CA has root rotation. This is intended to be used by -// `reconcileNodeRootsAndCerts`, which uses the root CA from the `lastSeenClusterRootCA`, and checks -// that it has a root rotation before calling this function. -func (s *Server) finishRootRotation(tx store.Tx, expectedRootCA *api.RootCA) error { - clusterID := s.securityConfig.ClientTLSCreds.Organization() - cluster := store.GetCluster(tx, clusterID) - if cluster == nil { - return fmt.Errorf("unable to get cluster %s", clusterID) - } - - // If the RootCA object has changed (because another root rotation was started or because some other node - // had finished the root rotation), we cannot finish the root rotation that we were working on. 
- if !equality.RootCAEqualStable(expectedRootCA, &cluster.RootCA) { - return errors.New("target root rotation has changed") - } - - var signerCert []byte - if len(cluster.RootCA.RootRotation.CAKey) > 0 { - signerCert = cluster.RootCA.RootRotation.CACert - } - // we don't actually have to parse out the default node expiration from the cluster - we are just using - // the ca.RootCA object to generate new tokens and the digest - updatedRootCA, err := NewRootCA(cluster.RootCA.RootRotation.CACert, signerCert, cluster.RootCA.RootRotation.CAKey, - DefaultNodeCertExpiration, nil) - if err != nil { - return errors.Wrap(err, "invalid cluster root rotation object") - } - cluster.RootCA = api.RootCA{ - CACert: cluster.RootCA.RootRotation.CACert, - CAKey: cluster.RootCA.RootRotation.CAKey, - CACertHash: updatedRootCA.Digest.String(), - JoinTokens: api.JoinTokens{ - Worker: GenerateJoinToken(&updatedRootCA), - Manager: GenerateJoinToken(&updatedRootCA), - }, - LastForcedRotation: cluster.RootCA.LastForcedRotation, - } - return store.UpdateCluster(tx, cluster) -} - -func (s *Server) reconcileNodeRootsAndCerts() { - s.mu.Lock() - if !s.isRunning() || s.rootCAReconciliationInProgress { - s.mu.Unlock() - return - } - ctx := s.ctx - s.rootCAReconciliationInProgress = true - s.mu.Unlock() - - s.wg.Add(1) - logger := log.G(ctx).WithField("method", "(*Server).reconcileNodeRootsAndCerts") - - go func(retryInterval time.Duration) { - defer func() { - s.wg.Done() - s.mu.Lock() - s.rootCAReconciliationInProgress = false - s.mu.Unlock() - }() - - for { - s.secConfigMu.Lock() - rootCA := s.lastSeenClusterRootCA - s.secConfigMu.Unlock() - if rootCA == nil { - return - } - - wantedIssuer := rootCA.CACert - if rootCA.RootRotation != nil { - wantedIssuer = rootCA.RootRotation.CACert - } - issuerCert, err := helpers.ParseCertificatePEM(wantedIssuer) - if err != nil { - logger.WithError(err).Error("invalid certificate in cluster root CA object") - } - - // If the last seen root rotation is nil, 
then possibly during a leadership transfer a different manager's - // CA node started root rotation reconciliation and finished the root rotation, so we should also end the - // root reconciliation - if rootCA.RootRotation == nil { - return - } - - var allNodes []*api.Node - - s.store.View(func(tx store.ReadTx) { - allNodes, err = store.FindNodes(tx, store.ByMembership(api.NodeMembershipAccepted)) - }) - if err != nil { - logger.WithError(err).Error("could not find all nodes in the store") - } - - var rootRotationDone bool - _, err = s.store.Batch(func(batch *store.Batch) error { - converged := true // this value is true if all the node certs are issued by the correct issuer - for _, node := range allNodes { - var needUpdate bool - err := batch.Update(func(tx store.Tx) error { - n := store.GetNode(tx, node.ID) - if n == nil || (n.Description != nil && n.Description.TLSInfo != nil && - bytes.Equal(n.Description.TLSInfo.CertIssuerPublicKey, issuerCert.RawSubjectPublicKeyInfo) && - bytes.Equal(n.Description.TLSInfo.CertIssuerSubject, issuerCert.RawSubject)) { - return nil - } - converged = false - - // If we are already waiting for the node to rotate (or the node's about to get issued a new cert by us), - // so don't bother telling the node to rotate again. If a cert is in renew or pending, the cert that - // the node will get should be issued by the correct issuer. However, in all likelihood the node will - // not get a chance to report back its TLS before the reconciliation loop will mark it as rotate, and - // it will get a new certificate, which is fine, but results in an extra certificate renewal round. 
- issuanceState := n.Certificate.Status.State - if issuanceState == api.IssuanceStateRotate || issuanceState == api.IssuanceStateRenew || issuanceState == api.IssuanceStatePending { - return nil - } - n.Certificate.Status.State = api.IssuanceStateRotate - needUpdate = true - return store.UpdateNode(tx, n) - }) - if err != nil { - logger.WithError(err).Debugf("could not update node %s to IssuanceStateRotate", node.ID) - } else if needUpdate { - logger.Debugf("updated node %s to IssuanceStateRotate", node.ID) - } - } - // It's possible that between getting all nodes and finishing root rotation, new nodes have joined. by the time we - // fetch all nodes, the CA signer should have already been changed to be the desired issuer by the time we list all - // nodes, so any new nodes that appear after that will only get certificates signed by the correct issuer. - if converged { - // when the nodes are all converged on having TLS certificates issued by the correct entity, we can - // try to finish the root rotation - err := batch.Update(func(tx store.Tx) error { - return s.finishRootRotation(tx, rootCA) - }) - if err != nil { - logger.WithError(err).Errorf("could not complete root rotation") - } else { - rootRotationDone = true - } - } - - return nil - }) - if err != nil { - logger.WithError(err).Errorf("failed to check nodes for desired certificate issuer") - } - if rootRotationDone { - logger.Infof("completed root rotation on cluster %s", s.securityConfig.ServerTLSCreds.Organization()) - return - } - - select { - case <-ctx.Done(): - logger.Info("stopping incomplete root rotation reconciliation due to context being stopped") - return - case <-time.After(retryInterval): - } - } - }(s.rootReconciliationRetryInterval) -} diff --git a/ca/server_test.go b/ca/server_test.go index 9bc1cdf1b0..9f41545653 100644 --- a/ca/server_test.go +++ b/ca/server_test.go @@ -563,60 +563,50 @@ type rootRotationTester struct { // go through all the nodes and update/create the ones we want, and 
delete the ones // we don't func (r *rootRotationTester) convergeWantedNodes(wantNodes map[string]*api.Node, descr string) { - updatedOldNodes := make(map[string]struct{}) - createdNewNodes := make(map[string]struct{}) - require.NoError(r.t, testutils.PollFuncWithTimeout(nil, func() error { - return r.tc.MemoryStore.Update(func(tx store.Tx) error { - nodes, err := store.FindNodes(tx, store.All) - if err != nil { - return err - } - for _, node := range nodes { - wanted, inWanted := wantNodes[node.ID] - _, done := updatedOldNodes[node.ID] - - if inWanted && !done { - node.Description = wanted.Description - node.Certificate = wanted.Certificate - if err := store.UpdateNode(tx, node); err != nil { - return err - } - updatedOldNodes[node.ID] = struct{}{} - } else if !inWanted { - if err := store.DeleteNode(tx, node.ID); err != nil { - return err - } - updatedOldNodes[node.ID] = struct{}{} + // update existing and create new nodes first before deleting nodes, else a root rotation + // may finish early if all the nodes get deleted when the root rotation happens + require.NoError(r.t, r.tc.MemoryStore.Update(func(tx store.Tx) error { + for nodeID, wanted := range wantNodes { + node := store.GetNode(tx, nodeID) + if node == nil { + if err := store.CreateNode(tx, wanted); err != nil { + return err } + continue + } + node.Description = wanted.Description + node.Certificate = wanted.Certificate + if err := store.UpdateNode(tx, node); err != nil { + return err } - for nodeID, wanted := range wantNodes { - _, createdAlready := createdNewNodes[nodeID] - if _, ok := updatedOldNodes[nodeID]; !ok && !createdAlready { - if err := store.CreateNode(tx, wanted); err != nil { - return err - } - createdNewNodes[nodeID] = struct{}{} + } + nodes, err := store.FindNodes(tx, store.All) + if err != nil { + return err + } + for _, node := range nodes { + if _, inWanted := wantNodes[node.ID]; !inWanted { + if err := store.DeleteNode(tx, node.ID); err != nil { + return err } } - return nil - }) - 
}, 2*time.Second), descr) + } + return nil + }), descr) } func (r *rootRotationTester) convergeRootCA(wantRootCA *api.RootCA, descr string) { - require.NoError(r.t, testutils.PollFuncWithTimeout(nil, func() error { - return r.tc.MemoryStore.Update(func(tx store.Tx) error { - clusters, err := store.FindClusters(tx, store.All) - if err != nil || len(clusters) != 1 { - return errors.Wrap(err, "unable to find cluster") - } - clusters[0].RootCA = *wantRootCA - return store.UpdateCluster(tx, clusters[0]) - }) - }, time.Second), descr) + require.NoError(r.t, r.tc.MemoryStore.Update(func(tx store.Tx) error { + clusters, err := store.FindClusters(tx, store.All) + if err != nil || len(clusters) != 1 { + return errors.Wrap(err, "unable to find cluster") + } + clusters[0].RootCA = *wantRootCA + return store.UpdateCluster(tx, clusters[0]) + }), descr) } -func (r *rootRotationTester) getFakeAPINode(id string, state api.IssuanceStatus_State, tlsInfo *api.NodeTLSInfo, member bool) *api.Node { +func getFakeAPINode(t *testing.T, id string, state api.IssuanceStatus_State, tlsInfo *api.NodeTLSInfo, member bool) *api.Node { node := &api.Node{ ID: id, Certificate: api.Certificate{ @@ -634,7 +624,7 @@ func (r *rootRotationTester) getFakeAPINode(id string, state api.IssuanceStatus_ // the CA server will immediately pick these up, so generate CSRs for the CA server to sign if state == api.IssuanceStateRenew || state == api.IssuanceStatePending { csr, _, err := ca.GenerateNewCSR() - require.NoError(r.t, err) + require.NoError(t, err) node.Certificate.CSR = csr } if tlsInfo != nil { @@ -721,12 +711,12 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { "a new cert. 
Any renew/pending nodes will have certs issued, but because the TLS info is nil, they will " + `go "rotate" state`), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, nil, true), - "2": rt.getFakeAPINode("2", api.IssuanceStateRenew, nil, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true), - "4": rt.getFakeAPINode("4", api.IssuanceStatePending, nil, true), - "5": rt.getFakeAPINode("5", api.IssuanceStateFailed, nil, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, nil, true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRenew, nil, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateRotate, nil, true), + "4": getFakeAPINode(t, "4", api.IssuanceStatePending, nil, true), + "5": getFakeAPINode(t, "5", api.IssuanceStateFailed, nil, true), }, rootCA: &api.RootCA{ CACert: startCluster.RootCA.CACert, @@ -739,12 +729,12 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { }, }, expectedNodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateRotate, nil, true), - "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, nil, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true), - "4": rt.getFakeAPINode("4", api.IssuanceStateRotate, nil, true), - "5": rt.getFakeAPINode("5", api.IssuanceStateRotate, nil, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateRotate, nil, true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRotate, nil, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateRotate, nil, true), + "4": getFakeAPINode(t, "4", api.IssuanceStateRotate, nil, true), + "5": getFakeAPINode(t, "5", api.IssuanceStateRotate, nil, true), }, }, { @@ -752,12 +742,12 @@ func 
TestRootRotationReconciliationWithChanges(t *testing.T) { "(going by the TLS info), which shouldn't really happen. the rotation reconciliation " + "will tell the wrong ones to rotate a second time"), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "2": rt.getFakeAPINode("2", api.IssuanceStateIssued, oldNodeTLSInfo, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "2": getFakeAPINode(t, "2", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, oldNodeTLSInfo, true), }, rootCA: &api.RootCA{ // no change in root CA from previous CACert: startCluster.RootCA.CACert, @@ -770,23 +760,23 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { }, }, expectedNodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, 
rotationTLSInfo[0], true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateRotate, oldNodeTLSInfo, true), }, }, { descr: ("New nodes that are added will also be picked up and told to rotate"), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "6": rt.getFakeAPINode("6", api.IssuanceStateRenew, nil, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateRenew, nil, true), }, rootCA: &api.RootCA{ // no change in root CA from previous CACert: startCluster.RootCA.CACert, @@ -799,12 +789,12 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { }, }, expectedNodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, 
rotationTLSInfo[0], true), - "6": rt.getFakeAPINode("6", api.IssuanceStateRotate, nil, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateRotate, nil, true), }, }, { @@ -812,12 +802,12 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { "different cert, all the nodes with the old root rotation cert will be told " + "to rotate again."), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, oldNodeTLSInfo, true), - "6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[0], true), }, rootCA: &api.RootCA{ // new root rotation CACert: startCluster.RootCA.CACert, @@ -830,23 +820,23 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { }, }, expectedNodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, 
nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateRotate, rotationTLSInfo[0], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateRotate, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateRotate, oldNodeTLSInfo, true), - "6": rt.getFakeAPINode("6", api.IssuanceStateRotate, rotationTLSInfo[0], true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateRotate, rotationTLSInfo[0], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateRotate, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "6": getFakeAPINode(t, "6", api.IssuanceStateRotate, rotationTLSInfo[0], true), }, }, { descr: ("Once all nodes have rotated to their desired TLS info (even if it's because " + "a node with the wrong TLS info has been removed, the root rotation is completed."), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[1], true), }, rootCA: &api.RootCA{ // no change in root CA from previous - even if root rotation gets completed 
after @@ -874,11 +864,11 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { "it will start reconciling the nodes as soon as it's started up again"), caServerRestart: true, nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[1], true), }, rootCA: &api.RootCA{ CACert: startCluster.RootCA.CACert, @@ -891,11 +881,11 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { }, }, expectedNodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateRotate, rotationTLSInfo[1], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, rotationTLSInfo[1], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateRotate, rotationTLSInfo[1], true), - "6": rt.getFakeAPINode("6", api.IssuanceStateRotate, rotationTLSInfo[1], true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateRotate, rotationTLSInfo[1], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateRotate, rotationTLSInfo[1], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateRotate, rotationTLSInfo[1], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateRotate, rotationTLSInfo[1], true), }, }, } 
@@ -1005,12 +995,12 @@ func TestRootRotationReconciliationNoChanges(t *testing.T) { "is in progress"), caServerStopped: true, nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, oldNodeTLSInfo, true), - "2": rt.getFakeAPINode("2", api.IssuanceStateRenew, nil, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true), - "4": rt.getFakeAPINode("4", api.IssuanceStatePending, nil, true), - "5": rt.getFakeAPINode("5", api.IssuanceStateFailed, nil, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRenew, nil, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateRotate, nil, true), + "4": getFakeAPINode(t, "4", api.IssuanceStatePending, nil, true), + "5": getFakeAPINode(t, "5", api.IssuanceStateFailed, nil, true), }, rootCA: &api.RootCA{ CACert: startCluster.RootCA.CACert, @@ -1027,10 +1017,10 @@ func TestRootRotationReconciliationNoChanges(t *testing.T) { descr: ("If all nodes have the right TLS info or are already rotated (or are not members), the " + "there will be no changes needed"), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo, true), - "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, rotationTLSInfo, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo, true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateRotate, rotationTLSInfo, true), }, rootCA: &api.RootCA{ // no change in root CA from previous CACert: 
startCluster.RootCA.CACert, @@ -1049,9 +1039,9 @@ func TestRootRotationReconciliationNoChanges(t *testing.T) { "in the process of getting a new cert. Even if they're issued by a different issuer, they will be " + "left alone because they'll have an interemdiate that chains up to the old issuer."), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo, true), - "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo, true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true), }, rootCA: &api.RootCA{ // no change in root CA from previous CACert: startCluster.RootCA.CACert, @@ -1131,10 +1121,7 @@ func TestRootRotationReconciliationRace(t *testing.T) { } clusterWatch, clusterWatchCancel, err := store.ViewAndWatch( tc.MemoryStore, func(tx store.ReadTx) error { - cluster := store.GetCluster(tx, tc.Organization) - for _, s := range otherServers { - s.UpdateRootCA(context.Background(), cluster) - } + // don't bother getting the cluster - the CA servers have already done that when first running return nil }, api.EventUpdateCluster{ @@ -1170,7 +1157,7 @@ func TestRootRotationReconciliationRace(t *testing.T) { nodes := make(map[string]*api.Node) for i := 0; i < 5; i++ { nodeID := fmt.Sprintf("%d", i) - nodes[nodeID] = rt.getFakeAPINode(nodeID, api.IssuanceStateIssued, oldNodeTLSInfo, true) + nodes[nodeID] = getFakeAPINode(t, nodeID, api.IssuanceStateIssued, oldNodeTLSInfo, true) } rt.convergeWantedNodes(nodes, "setting up nodes for root rotation race condition test") @@ -1218,5 +1205,111 @@ func TestRootRotationReconciliationRace(t *testing.T) { return fmt.Errorf("root rotation is still present") } return nil - }, time.Second)) + }, 5*time.Second)) +} + +// If there are a lot of 
nodes, we only update a small number of them at once. +func TestRootRotationReconciliationThrottled(t *testing.T) { + t.Parallel() + if cautils.External { + // the external CA functionality is unrelated to testing the reconciliation loop + return + } + + tc := cautils.NewTestCA(t) + defer tc.Stop() + // immediately stop the CA server - we want to run our own + tc.CAServer.Stop() + + caServer := ca.NewServer(tc.MemoryStore, tc.ServingSecurityConfig, tc.Paths.RootCA) + // set the reconciliation interval to something ridiculous, so we can make sure the first + // batch does not update all of them + caServer.SetRootReconciliationInterval(time.Hour) + startCAServer(caServer) + defer caServer.Stop() + + var nodes []*api.Node + clusterWatch, clusterWatchCancel, err := store.ViewAndWatch( + tc.MemoryStore, func(tx store.ReadTx) error { + // don't bother getting the cluster - the CA server has already done that when first running + var err error + nodes, err = store.FindNodes(tx, store.ByMembership(api.NodeMembershipAccepted)) + return err + }, + api.EventUpdateCluster{ + Cluster: &api.Cluster{ID: tc.Organization}, + Checks: []api.ClusterCheckFunc{api.ClusterCheckID}, + }, + ) + require.NoError(t, err) + defer clusterWatchCancel() + + done := make(chan struct{}) + defer close(done) + go func() { + for { + select { + case event := <-clusterWatch: + clusterEvent := event.(api.EventUpdateCluster) + caServer.UpdateRootCA(context.Background(), clusterEvent.Cluster) + case <-done: + return + } + } + }() + + // create twice the batch size of nodes + _, err = tc.MemoryStore.Batch(func(batch *store.Batch) error { + for i := len(nodes); i < ca.IssuanceStateRotateBatchSize*2; i++ { + nodeID := fmt.Sprintf("%d", i) + err := batch.Update(func(tx store.Tx) error { + return store.CreateNode(tx, getFakeAPINode(t, nodeID, api.IssuanceStateIssued, nil, true)) + }) + if err != nil { + return err + } + } + return nil + }) + require.NoError(t, err) + + rotationCert := cautils.ECDSA256SHA256Cert + 
rotationKey := cautils.ECDSA256Key + rotationCrossSigned, _ := getRotationInfo(t, rotationCert, &tc.RootCA) + + require.NoError(t, tc.MemoryStore.Update(func(tx store.Tx) error { + cluster := store.GetCluster(tx, tc.Organization) + if cluster == nil { + return errors.New("cluster has disappeared") + } + rootCA := cluster.RootCA.Copy() + rootCA.RootRotation = &api.RootRotation{ + CACert: rotationCert, + CAKey: rotationKey, + CrossSignedCACert: rotationCrossSigned, + } + cluster.RootCA = *rootCA + return store.UpdateCluster(tx, cluster) + })) + + checkRotationNumber := func() error { + tc.MemoryStore.View(func(tx store.ReadTx) { + nodes, err = store.FindNodes(tx, store.All) + }) + var issuanceRotate int + for _, n := range nodes { + if n.Certificate.Status.State == api.IssuanceStateRotate { + issuanceRotate += 1 + } + } + if issuanceRotate != ca.IssuanceStateRotateBatchSize { + return fmt.Errorf("expected %d, got %d", ca.IssuanceStateRotateBatchSize, issuanceRotate) + } + return nil + } + + require.NoError(t, testutils.PollFuncWithTimeout(nil, checkRotationNumber, 5*time.Second)) + // prove that it's not just because the updates haven't finished + time.Sleep(time.Second) + require.NoError(t, checkRotationNumber()) } diff --git a/integration/integration_test.go b/integration/integration_test.go index d676987afe..d521222ec8 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -627,27 +627,17 @@ func TestSuccessfulRootRotation(t *testing.T) { return nil }, opsTimeout)) - // Kill the leader, and bring the other manager back to show that a new - // leader can pick up the reconciliation loop and complete the root rotation. - // Bring one worker back, kill the other worker, and add a new worker - show - // that we can converge on a root rotation. - downLeader, err := cl.Leader() - require.NoError(t, err) - leaderID := downLeader.node.NodeID() - + // Bring the other manager back. 
Also bring one worker back, kill the other worker, + // and add a new worker - show that we can converge on a root rotation. require.NoError(t, cl.StartNode(downManagerID)) require.NoError(t, cl.StartNode(downWorkerIDs[0])) require.NoError(t, cl.RemoveNode(downWorkerIDs[1], false)) require.NoError(t, cl.AddAgent()) - require.NoError(t, downLeader.Pause(false)) // we can finish root rotation even though the previous leader was down because it had // already rotated its cert pollRootRotationDone(t, cl) - // bring the previous leader back up so it can get the new root - require.NoError(t, cl.StartNode(leaderID)) - // wait until all the nodes have gotten their new certs and trust roots require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { resp, err = cl.api.ListNodes(context.Background(), &api.ListNodesRequest{}) From 61e3041d168a5d1265153bda44cd4d746d9eac93 Mon Sep 17 00:00:00 2001 From: cyli Date: Fri, 7 Apr 2017 13:28:54 -0700 Subject: [PATCH 4/5] Add support for rotating root certificates in the swarmctl CLI. 
Signed-off-by: cyli --- cmd/swarmctl/cluster/inspect.go | 6 + cmd/swarmctl/cluster/root_rotation.go | 201 ++++++++++++++++++++++++++ cmd/swarmctl/cluster/update.go | 9 ++ 3 files changed, 216 insertions(+) create mode 100644 cmd/swarmctl/cluster/root_rotation.go diff --git a/cmd/swarmctl/cluster/inspect.go b/cmd/swarmctl/cluster/inspect.go index 409b3d0dcf..4e0f848850 100644 --- a/cmd/swarmctl/cluster/inspect.go +++ b/cmd/swarmctl/cluster/inspect.go @@ -10,6 +10,7 @@ import ( "github.com/docker/swarmkit/api" "github.com/docker/swarmkit/cmd/swarmctl/common" gogotypes "github.com/gogo/protobuf/types" + "github.com/opencontainers/go-digest" "github.com/spf13/cobra" ) @@ -37,6 +38,11 @@ func printClusterSummary(cluster *api.Cluster) { fmt.Fprintf(w, " Certificate Validity Duration: %s\n", clusterDuration.String()) } } + + if len(cluster.Spec.CAConfig.SigningCACert) > 0 { + fmt.Fprintf(w, " Desired CA Cert Digest: %s\n", digest.FromBytes(cluster.Spec.CAConfig.SigningCACert).String()) + } + fmt.Fprintf(w, " ForceRotate number: %d\n", cluster.Spec.CAConfig.ForceRotate) if len(cluster.Spec.CAConfig.ExternalCAs) > 0 { fmt.Fprintln(w, " External CAs:") for _, ca := range cluster.Spec.CAConfig.ExternalCAs { diff --git a/cmd/swarmctl/cluster/root_rotation.go b/cmd/swarmctl/cluster/root_rotation.go new file mode 100644 index 0000000000..1108ee682a --- /dev/null +++ b/cmd/swarmctl/cluster/root_rotation.go @@ -0,0 +1,201 @@ +package cluster + +import ( + "bytes" + "fmt" + "io" + "os" + "os/signal" + "time" + + "github.com/docker/docker/cli/command" + "github.com/docker/docker/pkg/jsonmessage" + "github.com/docker/docker/pkg/progress" + "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/ca" + digest "github.com/opencontainers/go-digest" + "golang.org/x/net/context" +) + +const ( + certsRotatedStr = " nodes rotated TLS certificates" + rootsRotatedStr = " nodes rotated CA certificates" +) + +// rootRotationProgress outputs 
// rootRotationProgress outputs progress information for convergence of a root rotation.
// It polls the control API every 200ms and renders two progress bars (nodes whose TLS
// certificates are signed by the desired issuer, and nodes that trust the desired root).
// Once the updater reports convergence, the rotation must remain converged for a short
// verification window ("monitor") before this returns nil. A SIGINT also returns nil;
// the rotation itself keeps progressing server-side in that case.
func rootRotationProgress(ctx context.Context, client api.ControlClient, clusterID string, progressWriter io.WriteCloser) error {
	defer progressWriter.Close()

	progressOut := streamformatter.NewJSONStreamFormatter().NewProgressOutput(progressWriter, false)

	// Catch Ctrl-C so we can print a "continuing in background" note instead of dying.
	sigint := make(chan os.Signal, 1)
	signal.Notify(sigint, os.Interrupt)
	defer signal.Stop(sigint)

	var (
		updater     *rootRotationProgressUpdater
		converged   bool      // last observation from updater.done
		convergedAt time.Time // when we first saw convergence; zero while unconverged
		monitor     = 3 * time.Second
	)

	for {
		clusterResp, err := client.GetCluster(ctx, &api.GetClusterRequest{ClusterID: clusterID})
		if err != nil {
			return err
		}

		// The desired TLS state is derived fresh each iteration, so a new rotation
		// started mid-wait is picked up automatically.
		issuerInfo, err := ca.IssuerFromAPIRootCA(&clusterResp.Cluster.RootCA)
		if err != nil {
			return err
		}
		desiredTLSInfo := api.NodeTLSInfo{
			TrustRoot:           clusterResp.Cluster.RootCA.CACert,
			CertIssuerPublicKey: issuerInfo.PublicKey,
			CertIssuerSubject:   issuerInfo.Subject,
		}

		if updater == nil {
			updater = &rootRotationProgressUpdater{
				progressOut: progressOut,
			}
		}

		// Convergence has held for the whole verification window: done.
		if converged && time.Since(convergedAt) >= monitor {
			return nil
		}

		nodesListResp, err := client.ListNodes(ctx, &api.ListNodesRequest{})
		if err != nil {
			return err
		}

		// RootRotation == nil means the server considers the rotation complete.
		updater.update(&desiredTLSInfo, nodesListResp.Nodes, clusterResp.Cluster.RootCA.RootRotation == nil)
		converged = updater.done
		if converged {
			if convergedAt.IsZero() {
				convergedAt = time.Now()
			}
			wait := monitor - time.Since(convergedAt)
			if wait >= 0 {
				progressOut.WriteProgress(progress.Progress{
					// Ideally this would have no ID, but
					// the progress rendering code behaves
					// poorly on an "action" with no ID. It
					// returns the cursor to the beginning
					// of the line, so the first character
					// may be difficult to read. Then the
					// output is overwritten by the shell
					// prompt when the command finishes.
					ID:     "verify",
					Action: fmt.Sprintf("Waiting %d seconds to verify that the roots are stable...", wait/time.Second+1),
				})
			}
		} else {
			// A non-zero convergedAt here means we regressed: another rotation
			// was started while we were in the verification window.
			if !convergedAt.IsZero() {
				progressOut.WriteProgress(progress.Progress{
					ID:     "verify",
					Action: "detected another root rotation change",
				})
			}
			convergedAt = time.Time{}
		}

		select {
		case <-time.After(200 * time.Millisecond):
		case <-sigint:
			if !converged {
				progress.Message(progressOut, "", "Operation continuing in background.")
				progress.Message(progressOut, "", "Use `swarmctl cluster inspect default` to check progress.")
			}
			return nil
		}
	}
}

// rootRotationProgressUpdater renders the per-iteration progress bars and tracks
// whether the rotation has fully converged.
type rootRotationProgressUpdater struct {
	progressOut progress.Output
	initialized bool // whether the two progress bars have been drawn yet
	done        bool // true when every node has the desired cert AND trust root
}

// update redraws progress from the given node list against the desired TLS state.
// rootRotationDone reflects whether the server has finished the rotation
// (cluster RootCA.RootRotation == nil); trust-root progress is only meaningful
// once certificates have rotated and the server has completed the rotation.
func (r *rootRotationProgressUpdater) update(desiredTLSInfo *api.NodeTLSInfo, nodes []*api.Node, rootRotationDone bool) {
	// write the current desired root cert
	r.progressOut.WriteProgress(progress.Progress{
		ID:     "desired root digest",
		Action: digest.FromBytes(desiredTLSInfo.TrustRoot).String(),
	})

	if !r.initialized {
		// draw 2 progress bars, 1 for nodes with the correct cert, 1 for nodes with the correct trust root
		progress.Update(r.progressOut, certsRotatedStr, " ")
		progress.Update(r.progressOut, rootsRotatedStr, " ")
		r.initialized = true
	}

	// Count nodes matching the desired issuer (cert rotated) and the desired
	// trust root. Nodes that have not reported TLS info are counted as neither.
	var certsRight, trustRootsRight int64
	for _, n := range nodes {
		if n.Description == nil || n.Description.TLSInfo == nil {
			continue
		}

		if bytes.Equal(n.Description.TLSInfo.CertIssuerPublicKey, desiredTLSInfo.CertIssuerPublicKey) &&
			bytes.Equal(n.Description.TLSInfo.CertIssuerSubject, desiredTLSInfo.CertIssuerSubject) {
			certsRight++
		}

		if bytes.Equal(n.Description.TLSInfo.TrustRoot, desiredTLSInfo.TrustRoot) {
			trustRootsRight++
		}
	}

	total := int64(len(nodes))
	certsAction := fmt.Sprintf("%d/%d done", certsRight, total)
	r.progressOut.WriteProgress(progress.Progress{
		ID:         certsRotatedStr,
		Action:     certsAction,
		Current:    certsRight,
		Total:      total,
		HideCounts: true,
	})

	if certsRight == total && rootRotationDone {
		// All certs rotated and the server finished the rotation: now trust-root
		// progress is reportable, and we are done when every node trusts the new root.
		rootsAction := fmt.Sprintf("%d/%d done", trustRootsRight, total)
		r.progressOut.WriteProgress(progress.Progress{
			ID:         rootsRotatedStr,
			Action:     rootsAction,
			Current:    trustRootsRight,
			Total:      total,
			HideCounts: true,
		})
		r.done = certsRight == total && trustRootsRight == total
	} else {
		// Rotation not complete yet: show the trust-root bar at zero rather than
		// a possibly-regressing intermediate count.
		rootsAction := fmt.Sprintf("%d/%d done", 0, total)
		r.progressOut.WriteProgress(progress.Progress{
			ID:         rootsRotatedStr,
			Action:     rootsAction,
			Current:    0,
			Total:      total,
			HideCounts: true,
		})
		r.done = false
	}
}
+func WaitOnRootRotation(ctx context.Context, client api.ControlClient, clusterID string) error { + errChan := make(chan error, 1) + pipeReader, pipeWriter := io.Pipe() + + go func() { + errChan <- rootRotationProgress(ctx, client, clusterID, pipeWriter) + }() + + err := jsonmessage.DisplayJSONMessagesToStream(pipeReader, command.NewOutStream(os.Stdout), nil) + if err == nil { + err = <-errChan + } + return err +} diff --git a/cmd/swarmctl/cluster/update.go b/cmd/swarmctl/cluster/update.go index defaf86c78..959ec2c128 100644 --- a/cmd/swarmctl/cluster/update.go +++ b/cmd/swarmctl/cluster/update.go @@ -101,6 +101,10 @@ var ( } spec.TaskDefaults.LogDriver = driver + if flags.Changed("rotate-ca") { + spec.CAConfig.ForceRotate++ + } + r, err := c.UpdateCluster(common.Context(cmd), &api.UpdateClusterRequest{ ClusterID: cluster.ID, ClusterVersion: &cluster.Meta.Version, @@ -115,6 +119,10 @@ var ( if rotation.ManagerUnlockKey { return displayUnlockKey(cmd) } + + if flags.Changed("rotate-ca") { + WaitOnRootRotation(common.Context(cmd), c, cluster.ID) + } return nil }, } @@ -131,4 +139,5 @@ func init() { updateCmd.Flags().String("rotate-join-token", "", "Rotate join token for worker or manager") updateCmd.Flags().Bool("rotate-unlock-key", false, "Rotate manager unlock key") updateCmd.Flags().Bool("autolock", false, "Enable or disable manager autolocking (requiring an unlock key to start a stopped manager)") + updateCmd.Flags().Bool("rotate-ca", false, "Rotate the root CA certificate and key for the cluster") } From 9745dafd4c506e4ebf8bff643d6d79d6aceae282 Mon Sep 17 00:00:00 2001 From: cyli Date: Fri, 7 Apr 2017 14:00:31 -0700 Subject: [PATCH 5/5] Include TLS information when listing and inspecting nodes Signed-off-by: cyli --- cmd/swarmctl/node/common.go | 13 +++++++++++++ cmd/swarmctl/node/inspect.go | 29 +++++++++++++++++++++++++++-- cmd/swarmctl/node/list.go | 28 ++++++++++++++++++++++++++-- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git 
a/cmd/swarmctl/node/common.go b/cmd/swarmctl/node/common.go index 2d0beb9ad3..9da4b53043 100644 --- a/cmd/swarmctl/node/common.go +++ b/cmd/swarmctl/node/common.go @@ -214,3 +214,16 @@ func updateNode(cmd *cobra.Command, args []string) error { return nil } + +func getCluster(ctx context.Context, c api.ControlClient) (*api.Cluster, error) { + rl, err := c.ListClusters(ctx, &api.ListClustersRequest{}) + if err != nil { + return nil, err + } + + if len(rl.Clusters) == 0 { + return nil, fmt.Errorf("no clusters found") + } + + return rl.Clusters[0], nil +} diff --git a/cmd/swarmctl/node/inspect.go b/cmd/swarmctl/node/inspect.go index eb4f8acebb..62309e3bfe 100644 --- a/cmd/swarmctl/node/inspect.go +++ b/cmd/swarmctl/node/inspect.go @@ -1,6 +1,7 @@ package node import ( + "bytes" "errors" "fmt" "os" @@ -8,13 +9,14 @@ import ( "text/tabwriter" "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/ca" "github.com/docker/swarmkit/cmd/swarmctl/common" "github.com/docker/swarmkit/cmd/swarmctl/task" "github.com/dustin/go-humanize" "github.com/spf13/cobra" ) -func printNodeSummary(node *api.Node) { +func printNodeSummary(node *api.Node, clusterIssuer *ca.IssuerInfo, clusterCACert []byte) { w := tabwriter.NewWriter(os.Stdout, 8, 8, 8, ' ', 0) defer func() { // Ignore flushing errors - there's nothing we can do. 
@@ -105,6 +107,20 @@ func printNodeSummary(node *api.Node) { fmt.Fprintln(w) } } + + if desc.TLSInfo != nil { + fmt.Fprintln(w, "TLS status\t:") + if bytes.Equal(clusterCACert, desc.TLSInfo.TrustRoot) { + fmt.Fprintln(w, " Trusts current cluster root CA") + } else { + fmt.Fprintln(w, " Does not trust current cluster root CA") + } + if bytes.Equal(clusterIssuer.Subject, desc.TLSInfo.CertIssuerSubject) && bytes.Equal(clusterIssuer.PublicKey, desc.TLSInfo.CertIssuerPublicKey) { + fmt.Fprintln(w, " Certificate issued by desired root CA") + } else { + fmt.Fprintln(w, " Certificate not issued by desired root CA") + } + } } var ( @@ -137,6 +153,15 @@ var ( return err } + cluster, err := getCluster(common.Context(cmd), c) + if err != nil { + return err + } + clusterIssuer, err := ca.IssuerFromAPIRootCA(&cluster.RootCA) + if err != nil { + return err + } + r, err := c.ListTasks(common.Context(cmd), &api.ListTasksRequest{ Filters: &api.ListTasksRequest_Filters{ @@ -147,7 +172,7 @@ var ( return err } - printNodeSummary(node) + printNodeSummary(node, clusterIssuer, cluster.RootCA.CACert) if len(r.Tasks) > 0 { fmt.Println() task.Print(r.Tasks, all, common.NewResolver(cmd, c)) diff --git a/cmd/swarmctl/node/list.go b/cmd/swarmctl/node/list.go index 3e577a8bab..d5a5f67743 100644 --- a/cmd/swarmctl/node/list.go +++ b/cmd/swarmctl/node/list.go @@ -6,7 +6,10 @@ import ( "os" "text/tabwriter" + "reflect" + "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/ca" "github.com/docker/swarmkit/cmd/swarmctl/common" "github.com/spf13/cobra" ) @@ -38,13 +41,27 @@ var ( var output func(n *api.Node) + cluster, err := getCluster(common.Context(cmd), c) + if err != nil { + return err + } + clusterIssuer, err := ca.IssuerFromAPIRootCA(&cluster.RootCA) + if err != nil { + return err + } + desiredTLSInfo := &api.NodeTLSInfo{ + CertIssuerPublicKey: clusterIssuer.PublicKey, + CertIssuerSubject: clusterIssuer.Subject, + TrustRoot: cluster.RootCA.CACert, + } + if !quiet { w := 
tabwriter.NewWriter(os.Stdout, 0, 4, 2, ' ', 0) defer func() { // Ignore flushing errors - there's nothing we can do. _ = w.Flush() }() - common.PrintHeader(w, "ID", "Name", "Membership", "Status", "Availability", "Manager Status") + common.PrintHeader(w, "ID", "Name", "Membership", "Status", "Availability", "Manager Status", "TLS Status") output = func(n *api.Node) { spec := &n.Spec name := spec.Annotations.Name @@ -64,14 +81,21 @@ var ( if reachability == "" && spec.DesiredRole == api.NodeRoleManager { reachability = "UNKNOWN" } + tlsStatus := "OUTDATED" + if n.Description == nil || n.Description.TLSInfo == nil { + tlsStatus = "UNKNOWN" + } else if reflect.DeepEqual(n.Description.TLSInfo, desiredTLSInfo) { + tlsStatus = "CURRENT" + } - fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n", + fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\t%s\n", n.ID, name, membership, n.Status.State.String(), availability, reachability, + tlsStatus, ) } } else {