From f1d0d24df390eb578907b8a8a673b0777038e1fb Mon Sep 17 00:00:00 2001 From: cyli Date: Mon, 3 Apr 2017 18:52:21 -0700 Subject: [PATCH 1/3] Add a reconciliation loop to ca.Server which, on root rotation, updates all the nodes to have an IssuanceStateRotate to trigger all the nodes to get new certificates. When all the nodes have rotated their certificates to be signed by the desired issuer, complete root rotation. Signed-off-by: cyli --- ca/server.go | 195 +++++++++++- ca/server_test.go | 673 +++++++++++++++++++++++++++++++++++++++- ca/testutils/cautils.go | 6 +- 3 files changed, 864 insertions(+), 10 deletions(-) diff --git a/ca/server.go b/ca/server.go index f3fa9cb6c8..85c18eb75d 100644 --- a/ca/server.go +++ b/ca/server.go @@ -4,10 +4,12 @@ import ( "bytes" "crypto/subtle" "crypto/x509" + "fmt" "sync" "time" "github.com/Sirupsen/logrus" + "github.com/cloudflare/cfssl/helpers" "github.com/docker/swarmkit/api" "github.com/docker/swarmkit/api/equality" "github.com/docker/swarmkit/identity" @@ -22,6 +24,7 @@ import ( const ( defaultReconciliationRetryInterval = 10 * time.Second + defaultRootReconciliationInterval = 3 * time.Second ) // APISecurityConfigUpdater knows how to update a SecurityConfig from an api.Cluster object @@ -63,6 +66,11 @@ type Server struct { // before we update the security config with the new root CA, we need to be able to save the root certs rootPaths CertPaths + + // lets us know if we're already doing a reconcile loop to tell nodes to rotate their + // TLS certificates and watching for root rotation completion + rootCAReconciliationInProgress bool + rootReconciliationRetryInterval time.Duration } // DefaultCAConfig returns the default CA Config, with a default expiration. @@ -75,12 +83,13 @@ func DefaultCAConfig() api.CAConfig { // NewServer creates a CA API server. func NewServer(store *store.MemoryStore, securityConfig *SecurityConfig, rootCAPaths CertPaths) *Server { return &Server{ - store: store, - securityConfig: securityConfig, - pending: make(map[string]*api.Node), - started: make(chan struct{}), - reconciliationRetryInterval: defaultReconciliationRetryInterval, - rootPaths: rootCAPaths, + store: store, + securityConfig: securityConfig, + pending: make(map[string]*api.Node), + started: make(chan struct{}), + reconciliationRetryInterval: defaultReconciliationRetryInterval, + rootReconciliationRetryInterval: defaultRootReconciliationInterval, + rootPaths: rootCAPaths, } } @@ -90,6 +99,12 @@ func (s *Server) SetReconciliationRetryInterval(reconciliationRetryInterval time s.reconciliationRetryInterval = reconciliationRetryInterval } +// SetRootReconciliationInterval changes the time interval between root rotation +// reconciliation attempts. This function must be called before Run. +func (s *Server) SetRootReconciliationInterval(interval time.Duration) { + s.rootReconciliationRetryInterval = interval +} + // GetUnlockKey is responsible for returning the current unlock key used for encrypting TLS private keys and // other at rest data. Access to this RPC call should only be allowed via mutual TLS from managers. 
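The root reconciliation interval only takes effect if it is set before Run is called. As a point of reference, the test utilities later in this series shorten both intervals; a caller would do roughly the following (a minimal sketch - `memStore`, `securityConfig`, and `paths` stand in for the caller's own setup):

```go
// Speed up certificate and root-rotation reconciliation for tests; like
// SetReconciliationRetryInterval, SetRootReconciliationInterval must be
// called before Run.
caServer := ca.NewServer(memStore, securityConfig, paths.RootCA)
caServer.SetReconciliationRetryInterval(50 * time.Millisecond)
caServer.SetRootReconciliationInterval(50 * time.Millisecond)
```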
func (s *Server) GetUnlockKey(ctx context.Context, request *api.GetUnlockKeyRequest) (*api.GetUnlockKeyResponse, error) { @@ -446,6 +461,7 @@ func (s *Server) Run(ctx context.Context) error { "method": "(*Server).Run", }).WithError(err).Errorf("error attempting to reconcile certificates") } + s.reconcileNodeRootsAndCerts() ticker := time.NewTicker(s.reconciliationRetryInterval) defer ticker.Stop() @@ -581,7 +597,6 @@ func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { if signingKey == nil { signingCert = nil } - updatedRootCA, err := NewRootCA(rCA.CACert, signingCert, signingKey, expiry, intermediates) if err != nil { return errors.Wrap(err, "invalid Root CA object in cluster") @@ -646,6 +661,9 @@ func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { s.securityConfig.externalCA.UpdateURLs(cfsslURLs...) s.lastSeenExternalCAs = cluster.Spec.CAConfig.Copy().ExternalCAs } + if rootCAChanged && s.lastSeenClusterRootCA.RootRotation != nil { + s.reconcileNodeRootsAndCerts() + } return nil } @@ -808,3 +826,166 @@ func isFinalState(status api.IssuanceStatus) bool { return false } + +// This function assumes that the expected root CA has root rotation. This is intended to be used by +// `reconcileNodeRootsAndCerts`, which uses the root CA from the `lastSeenClusterRootCA`, and checks +// that it has a root rotation before calling this function. +func (s *Server) finishRootRotation(tx store.Tx, expectedRootCA *api.RootCA) error { + clusterID := s.securityConfig.ClientTLSCreds.Organization() + cluster := store.GetCluster(tx, clusterID) + if cluster == nil { + return fmt.Errorf("unable to get cluster %s", clusterID) + } + + // If the RootCA object has changed (because another root rotation was started or because some other node + // had finished the root rotation), we cannot finish the root rotation that we were working on. 
+	if !equality.RootCAEqualStable(expectedRootCA, &cluster.RootCA) {
+		return errors.New("target root rotation has changed")
+	}
+
+	var signerCert []byte
+	if len(cluster.RootCA.RootRotation.CAKey) > 0 {
+		signerCert = cluster.RootCA.RootRotation.CACert
+	}
+	// we don't actually have to parse out the default node expiration from the cluster - we are just using
+	// the ca.RootCA object to generate new tokens and the digest
+	updatedRootCA, err := NewRootCA(cluster.RootCA.RootRotation.CACert, signerCert, cluster.RootCA.RootRotation.CAKey,
+		DefaultNodeCertExpiration, nil)
+	if err != nil {
+		return errors.Wrap(err, "invalid cluster root rotation object")
+	}
+	cluster.RootCA = api.RootCA{
+		CACert:     cluster.RootCA.RootRotation.CACert,
+		CAKey:      cluster.RootCA.RootRotation.CAKey,
+		CACertHash: updatedRootCA.Digest.String(),
+		JoinTokens: api.JoinTokens{
+			Worker:  GenerateJoinToken(&updatedRootCA),
+			Manager: GenerateJoinToken(&updatedRootCA),
+		},
+		LastForcedRotation: cluster.RootCA.LastForcedRotation,
+	}
+	return store.UpdateCluster(tx, cluster)
+}
+
+func (s *Server) reconcileNodeRootsAndCerts() {
+	s.mu.Lock()
+	if !s.isRunning() || s.rootCAReconciliationInProgress {
+		s.mu.Unlock()
+		return
+	}
+	ctx := s.ctx
+	s.rootCAReconciliationInProgress = true
+	s.mu.Unlock()
+
+	s.wg.Add(1)
+	logger := log.G(ctx).WithField("method", "(*Server).reconcileNodeRootsAndCerts")
+
+	go func(retryInterval time.Duration) {
+		defer func() {
+			s.wg.Done()
+			s.mu.Lock()
+			s.rootCAReconciliationInProgress = false
+			s.mu.Unlock()
+		}()
+
+		for {
+			s.secConfigMu.Lock()
+			rootCA := s.lastSeenClusterRootCA
+			s.secConfigMu.Unlock()
+			if rootCA == nil {
+				return
+			}
+
+			wantedIssuer := rootCA.CACert
+			if rootCA.RootRotation != nil {
+				wantedIssuer = rootCA.RootRotation.CACert
+			}
+			issuerCert, err := helpers.ParseCertificatePEM(wantedIssuer)
+			if err != nil {
+				logger.WithError(err).Error("invalid certificate in cluster root CA object")
+				// we cannot reconcile against an unparseable issuer certificate
+				return
+			}
+
+			// If the last seen root rotation is nil, then possibly during a leadership transfer a different manager's
+			// CA node started root rotation reconciliation and finished the root rotation, so we should also end the
+			// root reconciliation
+			if rootCA.RootRotation == nil {
+				return
+			}
+
+			var allNodes []*api.Node
+
+			s.store.View(func(tx store.ReadTx) {
+				allNodes, err = store.FindNodes(tx, store.ByMembership(api.NodeMembershipAccepted))
+			})
+			if err != nil {
+				logger.WithError(err).Error("could not find all nodes in the store")
+			}
+
+			var rootRotationDone bool
+			_, err = s.store.Batch(func(batch *store.Batch) error {
+				converged := true // this value is true if all the node certs are issued by the correct issuer
+				for _, node := range allNodes {
+					var needUpdate bool
+					err := batch.Update(func(tx store.Tx) error {
+						n := store.GetNode(tx, node.ID)
+						if n == nil || (n.Description != nil && n.Description.TLSInfo != nil &&
+							bytes.Equal(n.Description.TLSInfo.CertIssuerPublicKey, issuerCert.RawSubjectPublicKeyInfo) &&
+							bytes.Equal(n.Description.TLSInfo.CertIssuerSubject, issuerCert.RawSubject)) {
+							return nil
+						}
+						converged = false
+
+						// If we are already waiting for the node to rotate (or the node is about to be issued a new
+						// cert by us), don't bother telling the node to rotate again.  If a cert is in the renew or
+						// pending state, the cert the node gets should be issued by the correct issuer.  However, in
+						// all likelihood the node will not get a chance to report its new TLS info before the
+						// reconciliation loop marks it as rotate, and it will get yet another certificate - which is
+						// fine, but results in an extra certificate renewal round.
+						issuanceState := n.Certificate.Status.State
+						if issuanceState == api.IssuanceStateRotate || issuanceState == api.IssuanceStateRenew || issuanceState == api.IssuanceStatePending {
+							return nil
+						}
+						n.Certificate.Status.State = api.IssuanceStateRotate
+						needUpdate = true
+						return store.UpdateNode(tx, n)
+					})
+					if err != nil {
+						logger.WithError(err).Debugf("could not update node %s to IssuanceStateRotate", node.ID)
+					} else if needUpdate {
+						logger.Debugf("updated node %s to IssuanceStateRotate", node.ID)
+					}
+				}
+				// It's possible that between listing all the nodes and finishing root rotation, new nodes have joined.
+				// However, the CA signer should already have been changed to the desired issuer by the time we list
+				// the nodes, so any new nodes that appear after that will only get certificates signed by the correct
+				// issuer.
+				if converged {
+					// when the nodes are all converged on having TLS certificates issued by the correct entity, we can
+					// try to finish the root rotation
+					err := batch.Update(func(tx store.Tx) error {
+						return s.finishRootRotation(tx, rootCA)
+					})
+					if err != nil {
+						logger.WithError(err).Errorf("could not complete root rotation")
+					} else {
+						rootRotationDone = true
+					}
+				}
+
+				return nil
+			})
+			if err != nil {
+				logger.WithError(err).Errorf("failed to check nodes for desired certificate issuer")
+			}
+			if rootRotationDone {
+				logger.Infof("completed root rotation on cluster %s", s.securityConfig.ServerTLSCreds.Organization())
+				return
+			}
+
+			select {
+			case <-ctx.Done():
+				logger.Info("stopping incomplete root rotation reconciliation due to context being stopped")
+				return
+			case <-time.After(retryInterval):
+			}
+		}
+	}(s.rootReconciliationRetryInterval)
+}
diff --git a/ca/server_test.go b/ca/server_test.go
index c80279665b..9bc1cdf1b0 100644
--- a/ca/server_test.go
+++ b/ca/server_test.go
@@ -6,19 +6,22 @@ import (
 	"fmt"
 	"io/ioutil"
 	"os"
+	"reflect"
 	"testing"
 	"time"
 
-	"golang.org/x/net/context"
-
 	"github.com/cloudflare/cfssl/helpers"
 	"github.com/docker/swarmkit/api"
+	"github.com/docker/swarmkit/api/equality"
 	"github.com/docker/swarmkit/ca"
 	cautils "github.com/docker/swarmkit/ca/testutils"
 	"github.com/docker/swarmkit/manager/state/store"
 	"github.com/docker/swarmkit/testutils"
+	"github.com/opencontainers/go-digest"
+	"github.com/pkg/errors"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"golang.org/x/net/context"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 )
@@ -551,3 +554,669 @@ func TestCAServerUpdateRootCA(t *testing.T) {
 	tc.CAServer.UpdateRootCA(context.Background(), fakeClusterSpec(cautils.ECDSA256SHA256Cert, cautils.ECDSA256Key, nil, nil))
 	require.Equal(t, tc.RootCA.Certs, tc.ServingSecurityConfig.RootCA().Certs)
 }
+
+type rootRotationTester struct {
+	tc *cautils.TestCA
+	t  *testing.T
+}
+
+// go through all the nodes and update/create the ones we want, and delete the ones
+// we don't
+func (r *rootRotationTester) convergeWantedNodes(wantNodes map[string]*api.Node, descr string) {
+	updatedOldNodes := make(map[string]struct{})
+	createdNewNodes := make(map[string]struct{})
+	require.NoError(r.t, testutils.PollFuncWithTimeout(nil, func() error {
+		return r.tc.MemoryStore.Update(func(tx store.Tx) error {
+			nodes, err := store.FindNodes(tx,
store.All) + if err != nil { + return err + } + for _, node := range nodes { + wanted, inWanted := wantNodes[node.ID] + _, done := updatedOldNodes[node.ID] + + if inWanted && !done { + node.Description = wanted.Description + node.Certificate = wanted.Certificate + if err := store.UpdateNode(tx, node); err != nil { + return err + } + updatedOldNodes[node.ID] = struct{}{} + } else if !inWanted { + if err := store.DeleteNode(tx, node.ID); err != nil { + return err + } + updatedOldNodes[node.ID] = struct{}{} + } + } + for nodeID, wanted := range wantNodes { + _, createdAlready := createdNewNodes[nodeID] + if _, ok := updatedOldNodes[nodeID]; !ok && !createdAlready { + if err := store.CreateNode(tx, wanted); err != nil { + return err + } + createdNewNodes[nodeID] = struct{}{} + } + } + return nil + }) + }, 2*time.Second), descr) +} + +func (r *rootRotationTester) convergeRootCA(wantRootCA *api.RootCA, descr string) { + require.NoError(r.t, testutils.PollFuncWithTimeout(nil, func() error { + return r.tc.MemoryStore.Update(func(tx store.Tx) error { + clusters, err := store.FindClusters(tx, store.All) + if err != nil || len(clusters) != 1 { + return errors.Wrap(err, "unable to find cluster") + } + clusters[0].RootCA = *wantRootCA + return store.UpdateCluster(tx, clusters[0]) + }) + }, time.Second), descr) +} + +func (r *rootRotationTester) getFakeAPINode(id string, state api.IssuanceStatus_State, tlsInfo *api.NodeTLSInfo, member bool) *api.Node { + node := &api.Node{ + ID: id, + Certificate: api.Certificate{ + Status: api.IssuanceStatus{ + State: state, + }, + }, + Spec: api.NodeSpec{ + Membership: api.NodeMembershipAccepted, + }, + } + if !member { + node.Spec.Membership = api.NodeMembershipPending + } + // the CA server will immediately pick these up, so generate CSRs for the CA server to sign + if state == api.IssuanceStateRenew || state == api.IssuanceStatePending { + csr, _, err := ca.GenerateNewCSR() + require.NoError(r.t, err) + node.Certificate.CSR = csr + } + if tlsInfo != nil { + node.Description = &api.NodeDescription{TLSInfo: tlsInfo} + } + return node +} + +func startCAServer(caServer *ca.Server) { + alreadyRunning := make(chan struct{}) + go func() { + if err := caServer.Run(context.Background()); err != nil { + close(alreadyRunning) + } + }() + select { + case <-caServer.Ready(): + case <-alreadyRunning: + } +} + +func getRotationInfo(t *testing.T, rotationCert []byte, rootCA *ca.RootCA) ([]byte, *api.NodeTLSInfo) { + parsedNewRoot, err := helpers.ParseCertificatePEM(rotationCert) + require.NoError(t, err) + crossSigned, err := rootCA.CrossSignCACertificate(rotationCert) + require.NoError(t, err) + return crossSigned, &api.NodeTLSInfo{ + TrustRoot: rootCA.Certs, + CertIssuerPublicKey: parsedNewRoot.RawSubjectPublicKeyInfo, + CertIssuerSubject: parsedNewRoot.RawSubject, + } +} + +// These are the root rotation test cases where we expect there to be a change in the FindNodes +// or root CA values after converging. 
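Each of these cases seeds the store with a cluster root CA that carries a pending rotation, assembled from the helpers above. Schematically (a sketch only - `rotationCert`, `rotationKey`, and `startCluster` stand in for the fixtures the tests set up):

```go
// Build a RootCA under rotation: the current CA cert/key plus a RootRotation
// whose cross-signed cert comes from getRotationInfo; the returned TLS info
// is what a fully converged node would report in its Description.
crossSigned, wantTLSInfo := getRotationInfo(t, rotationCert, &tc.RootCA)
rootCA := &api.RootCA{
	CACert:     startCluster.RootCA.CACert,
	CAKey:      startCluster.RootCA.CAKey,
	CACertHash: startCluster.RootCA.CACertHash,
	RootRotation: &api.RootRotation{
		CACert:            rotationCert,
		CAKey:             rotationKey,
		CrossSignedCACert: crossSigned,
	},
}
rt.convergeRootCA(rootCA, "seed a pending root rotation")
_ = wantTLSInfo // nodes whose TLS info matches this count as converged
```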
+func TestRootRotationReconciliationWithChanges(t *testing.T) {
+	t.Parallel()
+	if cautils.External {
+		// the external CA functionality is unrelated to testing the reconciliation loop
+		return
+	}
+
+	tc := cautils.NewTestCA(t)
+	defer tc.Stop()
+	rt := rootRotationTester{
+		tc: tc,
+		t:  t,
+	}
+
+	rotationCerts := [][]byte{cautils.ECDSA256SHA256Cert, cautils.ECDSACertChain[2]}
+	rotationKeys := [][]byte{cautils.ECDSA256Key, cautils.ECDSACertChainKeys[2]}
+	var (
+		rotationCrossSigned [][]byte
+		rotationTLSInfo     []*api.NodeTLSInfo
+	)
+	for _, cert := range rotationCerts {
+		cross, info := getRotationInfo(t, cert, &tc.RootCA)
+		rotationCrossSigned = append(rotationCrossSigned, cross)
+		rotationTLSInfo = append(rotationTLSInfo, info)
+	}
+
+	oldNodeTLSInfo := &api.NodeTLSInfo{
+		TrustRoot:           tc.RootCA.Certs,
+		CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey,
+		CertIssuerSubject:   tc.ServingSecurityConfig.IssuerInfo().Subject,
+	}
+
+	var startCluster *api.Cluster
+	tc.MemoryStore.View(func(tx store.ReadTx) {
+		startCluster = store.GetCluster(tx, tc.Organization)
+	})
+	require.NotNil(t, startCluster)
+
+	testcases := []struct {
+		nodes           map[string]*api.Node // what nodes we should start with
+		rootCA          *api.RootCA          // what root CA we should start with
+		expectedNodes   map[string]*api.Node // what nodes we expect in the end; if nil, unchanged from the start
+		expectedRootCA  *api.RootCA          // what root CA we expect in the end; if nil, unchanged from the start
+		caServerRestart bool                 // whether to stop the CA server before making the node and root changes and restart after
+		descr           string
+	}{
+		{
+			descr: ("If there is no TLS info, the reconciliation cycle tells the nodes to rotate if they're not already getting " +
+				"a new cert.  Any renew/pending nodes will have certs issued, but because the TLS info is nil, they will " +
+				`go into the "rotate" state`),
+			nodes: map[string]*api.Node{
+				"0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false),
+				"1": rt.getFakeAPINode("1", api.IssuanceStateIssued, nil, true),
+				"2": rt.getFakeAPINode("2", api.IssuanceStateRenew, nil, true),
+				"3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true),
+				"4": rt.getFakeAPINode("4", api.IssuanceStatePending, nil, true),
+				"5": rt.getFakeAPINode("5", api.IssuanceStateFailed, nil, true),
+			},
+			rootCA: &api.RootCA{
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCerts[0],
+					CAKey:             rotationKeys[0],
+					CrossSignedCACert: rotationCrossSigned[0],
+				},
+			},
+			expectedNodes: map[string]*api.Node{
+				"0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false),
+				"1": rt.getFakeAPINode("1", api.IssuanceStateRotate, nil, true),
+				"2": rt.getFakeAPINode("2", api.IssuanceStateRotate, nil, true),
+				"3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true),
+				"4": rt.getFakeAPINode("4", api.IssuanceStateRotate, nil, true),
+				"5": rt.getFakeAPINode("5", api.IssuanceStateRotate, nil, true),
+			},
+		},
+		{
+			descr: ("Assume all of the nodes have gotten certs, but some of them are the wrong cert " +
+				"(going by the TLS info), which shouldn't really happen.
the rotation reconciliation " + + "will tell the wrong ones to rotate a second time"), + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "2": rt.getFakeAPINode("2", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, oldNodeTLSInfo, true), + }, + rootCA: &api.RootCA{ // no change in root CA from previous + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCerts[0], + CAKey: rotationKeys[0], + CrossSignedCACert: rotationCrossSigned[0], + }, + }, + expectedNodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": rt.getFakeAPINode("5", api.IssuanceStateRotate, oldNodeTLSInfo, true), + }, + }, + { + descr: ("New nodes that are added will also be picked up and told to rotate"), + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "6": rt.getFakeAPINode("6", api.IssuanceStateRenew, nil, true), + }, + rootCA: &api.RootCA{ // no change in root CA from previous + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCerts[0], + CAKey: rotationKeys[0], + CrossSignedCACert: rotationCrossSigned[0], + }, + }, + expectedNodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "6": rt.getFakeAPINode("6", api.IssuanceStateRotate, nil, true), + }, + }, + { + descr: ("Even if root rotation isn't finished, if the root changes again to a " + + "different cert, all the nodes with the old root rotation cert will be told " + + "to rotate again."), + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[0], 
true),
+			},
+			rootCA: &api.RootCA{ // new root rotation
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCerts[1],
+					CAKey:             rotationKeys[1],
+					CrossSignedCACert: rotationCrossSigned[1],
+				},
+			},
+			expectedNodes: map[string]*api.Node{
+				"0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false),
+				"1": rt.getFakeAPINode("1", api.IssuanceStateRotate, rotationTLSInfo[0], true),
+				"3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"4": rt.getFakeAPINode("4", api.IssuanceStateRotate, rotationTLSInfo[0], true),
+				"5": rt.getFakeAPINode("5", api.IssuanceStateRotate, oldNodeTLSInfo, true),
+				"6": rt.getFakeAPINode("6", api.IssuanceStateRotate, rotationTLSInfo[0], true),
+			},
+		},
+		{
+			descr: ("Once all nodes have rotated to their desired TLS info (even if it's because " +
+				"a node with the wrong TLS info has been removed), the root rotation is completed."),
+			nodes: map[string]*api.Node{
+				"0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false),
+				"1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+			},
+			rootCA: &api.RootCA{
+				// no change in root CA from previous - even if the root rotation gets completed after
+				// the nodes are first set and we just re-add the root rotation due to test ordering,
+				// the TLS info is correct for all nodes, so it will simply be completed again
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCerts[1],
+					CAKey:             rotationKeys[1],
+					CrossSignedCACert: rotationCrossSigned[1],
+				},
+			},
+			expectedRootCA: &api.RootCA{
+				CACert:     rotationCerts[1],
+				CAKey:      rotationKeys[1],
+				CACertHash: digest.FromBytes(rotationCerts[1]).String(),
+				// ignore the join tokens - we aren't comparing them
+			},
+		},
+		{
+			descr: ("If a root rotation happens when the CA server is down, so long as it saw the change " +
+				"it will start reconciling the nodes as soon as it's started up again"),
+			caServerRestart: true,
+			nodes: map[string]*api.Node{
+				"0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false),
+				"1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+			},
+			rootCA: &api.RootCA{
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCerts[0],
+					CAKey:             rotationKeys[0],
+					CrossSignedCACert: rotationCrossSigned[0],
+				},
+			},
+			expectedNodes: map[string]*api.Node{
+				"0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false),
+				"1": rt.getFakeAPINode("1", api.IssuanceStateRotate, rotationTLSInfo[1], true),
+				"3": rt.getFakeAPINode("3", api.IssuanceStateRotate, rotationTLSInfo[1], true),
+				"4": rt.getFakeAPINode("4", api.IssuanceStateRotate, rotationTLSInfo[1], true),
+				"6": rt.getFakeAPINode("6",
api.IssuanceStateRotate, rotationTLSInfo[1], true), + }, + }, + } + + for _, testcase := range testcases { + if testcase.caServerRestart { + rt.tc.CAServer.Stop() + } + + rt.convergeRootCA(testcase.rootCA, testcase.descr) + rt.convergeWantedNodes(testcase.nodes, testcase.descr) + + if testcase.expectedNodes == nil { + testcase.expectedNodes = testcase.nodes + } + if testcase.expectedRootCA == nil { + testcase.expectedRootCA = testcase.rootCA + } + + if testcase.caServerRestart { + startCAServer(rt.tc.CAServer) + } + + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + var ( + nodes []*api.Node + cluster *api.Cluster + err error + ) + tc.MemoryStore.View(func(tx store.ReadTx) { + nodes, err = store.FindNodes(tx, store.All) + cluster = store.GetCluster(tx, tc.Organization) + }) + if err != nil { + return err + } + if cluster == nil { + return errors.New("no cluster found") + } + + if !equality.RootCAEqualStable(&cluster.RootCA, testcase.expectedRootCA) { + return fmt.Errorf("root CAs not equal:\n\texpected: %v\n\tactual: %v", *testcase.expectedRootCA, cluster.RootCA) + } + if len(nodes) != len(testcase.expectedNodes) { + return fmt.Errorf("number of expected nodes (%d) does not equal number of actual nodes (%d)", + len(testcase.expectedNodes), len(nodes)) + } + for _, node := range nodes { + expected, ok := testcase.expectedNodes[node.ID] + if !ok { + return fmt.Errorf("node %s is present and was unexpected", node.ID) + } + if !reflect.DeepEqual(expected.Description, node.Description) { + return fmt.Errorf("the node description of node %s is not expected:\n\texpected: %v\n\tactual: %v", node.ID, + expected.Description, node.Description) + } + if !reflect.DeepEqual(expected.Certificate.Status, node.Certificate.Status) { + return fmt.Errorf("the certificate status of node %s is not expected:\n\texpected: %v\n\tactual: %v", node.ID, + expected.Certificate, node.Certificate) + } + } + return nil + }, 5*time.Second), testcase.descr) + } +} + +// These are the root rotation test cases where we expect there to be no changes made to either +// the nodes or the root CA object +func TestRootRotationReconciliationNoChanges(t *testing.T) { + t.Parallel() + if cautils.External { + // the external CA functionality is unrelated to testing the reconciliation loop + return + } + + tc := cautils.NewTestCA(t) + defer tc.Stop() + rt := rootRotationTester{ + tc: tc, + t: t, + } + + rotationCert := cautils.ECDSA256SHA256Cert + rotationKey := cautils.ECDSA256Key + rotationCrossSigned, rotationTLSInfo := getRotationInfo(t, rotationCert, &tc.RootCA) + + oldNodeTLSInfo := &api.NodeTLSInfo{ + TrustRoot: tc.RootCA.Certs, + CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey, + CertIssuerSubject: tc.ServingSecurityConfig.IssuerInfo().Subject, + } + + var startCluster *api.Cluster + tc.MemoryStore.View(func(tx store.ReadTx) { + startCluster = store.GetCluster(tx, tc.Organization) + }) + require.NotNil(t, startCluster) + + testcases := []struct { + nodes map[string]*api.Node // what nodes we should start with + rootCA *api.RootCA // what root CA we should start with + descr string + caServerStopped bool // if the server is running, only then will a reconciliation loop happen + }{ + { + descr: ("If the CA server is not running no reconciliation happens even if a root rotation " + + "is in progress"), + caServerStopped: true, + nodes: map[string]*api.Node{ + "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), + "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, 
oldNodeTLSInfo, true),
+				"2": rt.getFakeAPINode("2", api.IssuanceStateRenew, nil, true),
+				"3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true),
+				"4": rt.getFakeAPINode("4", api.IssuanceStatePending, nil, true),
+				"5": rt.getFakeAPINode("5", api.IssuanceStateFailed, nil, true),
+			},
+			rootCA: &api.RootCA{
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCert,
+					CAKey:             rotationKey,
+					CrossSignedCACert: rotationCrossSigned,
+				},
+			},
+		},
+		{
+			descr: ("If all nodes have the right TLS info or are already rotated (or are not members), " +
+				"there will be no changes needed"),
+			nodes: map[string]*api.Node{
+				"0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false),
+				"1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo, true),
+				"2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true),
+				"3": rt.getFakeAPINode("3", api.IssuanceStateRotate, rotationTLSInfo, true),
+			},
+			rootCA: &api.RootCA{ // no change in root CA from previous
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCert,
+					CAKey:             rotationKey,
+					CrossSignedCACert: rotationCrossSigned,
+				},
+			},
+		},
+		{
+			descr: ("Nodes already in rotate state, even if they currently have the correct TLS issuer, will be " +
+				"left in the rotate state even if root rotation is aborted, because we don't know if they're already " +
+				"in the process of getting a new cert.  Even if they're issued by a different issuer, they will be " +
+				"left alone because they'll have an intermediate that chains up to the old issuer."),
+			nodes: map[string]*api.Node{
+				"0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false),
+				"1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo, true),
+				"2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true),
+			},
+			rootCA: &api.RootCA{ // no change in root CA from previous
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+			},
+		},
+	}
+
+	for _, testcase := range testcases {
+		if testcase.caServerStopped {
+			rt.tc.CAServer.Stop()
+		} else {
+			startCAServer(rt.tc.CAServer)
+		}
+
+		rt.convergeRootCA(testcase.rootCA, testcase.descr)
+		rt.convergeWantedNodes(testcase.nodes, testcase.descr)
+
+		time.Sleep(500 * time.Millisecond)
+
+		var (
+			nodes   []*api.Node
+			cluster *api.Cluster
+			err     error
+		)
+
+		tc.MemoryStore.View(func(tx store.ReadTx) {
+			nodes, err = store.FindNodes(tx, store.All)
+			cluster = store.GetCluster(tx, tc.Organization)
+		})
+		require.NoError(t, err)
+		require.NotNil(t, cluster)
+		require.Equal(t, cluster.RootCA, *testcase.rootCA, testcase.descr)
+		require.Len(t, nodes, len(testcase.nodes), testcase.descr)
+		for _, node := range nodes {
+			expected, ok := testcase.nodes[node.ID]
+			require.True(t, ok, "node %s: %s", node.ID, testcase.descr)
+			require.Equal(t, expected.Description, node.Description, "node %s: %s", node.ID, testcase.descr)
+			require.Equal(t, expected.Certificate.Status, node.Certificate.Status, "node %s: %s", node.ID, testcase.descr)
+		}
+	}
+}
+
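Racing reconcilers are safe because completing a rotation is guarded inside the store transaction: finishRootRotation re-reads the cluster and aborts if its RootCA is no longer the rotation this loop was working on, so at most one reconciler can complete any given rotation. The guard is the check introduced earlier in this series:

```go
// If the RootCA object has changed (another rotation was started, or some
// other reconciler already finished this one), do not complete it here.
if !equality.RootCAEqualStable(expectedRootCA, &cluster.RootCA) {
	return errors.New("target root rotation has changed")
}
```

+// Tests that if the root rotation changes while the reconciliation loop is running, the root rotation will eventually
+// finish successfully (even if there's a competing reconciliation loop, for instance if there's a bug during leadership
+// handoff).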
+func TestRootRotationReconciliationRace(t *testing.T) { + t.Parallel() + if cautils.External { + // the external CA functionality is unrelated to testing the reconciliation loop + return + } + + tc := cautils.NewTestCA(t) + defer tc.Stop() + rt := rootRotationTester{ + tc: tc, + t: t, + } + // start a competing CA server + tempDir, err := ioutil.TempDir("", "competing-ca-server") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + paths := ca.NewConfigPaths(tempDir) + competingSecConfig, err := tc.NewNodeConfig(ca.ManagerRole) + require.NoError(t, err) + + var otherServers []*ca.Server + for i := 0; i < 3; i++ { // to make sure we get some collision + otherServer := ca.NewServer(tc.MemoryStore, competingSecConfig, paths.RootCA) + // offset each server's reconciliation interval somewhat so that some will + // pre-empt others + otherServer.SetRootReconciliationInterval(time.Millisecond * time.Duration((i+1)*10)) + startCAServer(otherServer) + defer otherServer.Stop() + otherServers = append(otherServers, otherServer) + } + clusterWatch, clusterWatchCancel, err := store.ViewAndWatch( + tc.MemoryStore, func(tx store.ReadTx) error { + cluster := store.GetCluster(tx, tc.Organization) + for _, s := range otherServers { + s.UpdateRootCA(context.Background(), cluster) + } + return nil + }, + api.EventUpdateCluster{ + Cluster: &api.Cluster{ID: tc.Organization}, + Checks: []api.ClusterCheckFunc{api.ClusterCheckID}, + }, + ) + require.NoError(t, err) + defer clusterWatchCancel() + + done := make(chan struct{}) + defer close(done) + go func() { + for { + select { + case event := <-clusterWatch: + clusterEvent := event.(api.EventUpdateCluster) + for _, s := range otherServers { + s.UpdateRootCA(context.Background(), clusterEvent.Cluster) + } + case <-done: + return + } + } + }() + + oldNodeTLSInfo := &api.NodeTLSInfo{ + TrustRoot: tc.RootCA.Certs, + CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey, + CertIssuerSubject: tc.ServingSecurityConfig.IssuerInfo().Subject, + } + + nodes := make(map[string]*api.Node) + for i := 0; i < 5; i++ { + nodeID := fmt.Sprintf("%d", i) + nodes[nodeID] = rt.getFakeAPINode(nodeID, api.IssuanceStateIssued, oldNodeTLSInfo, true) + } + rt.convergeWantedNodes(nodes, "setting up nodes for root rotation race condition test") + + for i := 0; i < 10; i++ { + var ( + rotationCrossSigned []byte + rotationTLSInfo *api.NodeTLSInfo + ) + rotationCert, rotationKey, err := cautils.CreateRootCertAndKey(fmt.Sprintf("root cn %d", i)) + require.NoError(t, err) + require.NoError(t, tc.MemoryStore.Update(func(tx store.Tx) error { + cluster := store.GetCluster(tx, tc.Organization) + if cluster == nil { + return errors.New("cluster has disappeared") + } + rootCA := cluster.RootCA.Copy() + caRootCA, err := ca.NewRootCA(rootCA.CACert, rootCA.CACert, rootCA.CAKey, ca.DefaultNodeCertExpiration, nil) + if err != nil { + return err + } + rotationCrossSigned, rotationTLSInfo = getRotationInfo(t, rotationCert, &caRootCA) + rootCA.RootRotation = &api.RootRotation{ + CACert: rotationCert, + CAKey: rotationKey, + CrossSignedCACert: rotationCrossSigned, + } + cluster.RootCA = *rootCA + return store.UpdateCluster(tx, cluster) + })) + for _, node := range nodes { + node.Description.TLSInfo = rotationTLSInfo + } + rt.convergeWantedNodes(nodes, fmt.Sprintf("iteration %d", i)) + } + + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + var cluster *api.Cluster + tc.MemoryStore.View(func(tx store.ReadTx) { + cluster = store.GetCluster(tx, tc.Organization) + }) + if 
cluster == nil {
+			return errors.New("cluster has disappeared")
+		}
+		if cluster.RootCA.RootRotation != nil {
+			return fmt.Errorf("root rotation is still present")
+		}
+		return nil
+	}, time.Second))
+}
diff --git a/ca/testutils/cautils.go b/ca/testutils/cautils.go
index 61024320c2..532943181c 100644
--- a/ca/testutils/cautils.go
+++ b/ca/testutils/cautils.go
@@ -21,6 +21,7 @@ import (
 	"github.com/docker/swarmkit/connectionbroker"
 	"github.com/docker/swarmkit/identity"
 	"github.com/docker/swarmkit/ioutils"
+	"github.com/docker/swarmkit/log"
 	"github.com/docker/swarmkit/manager/state/store"
 	stateutils "github.com/docker/swarmkit/manager/state/testutils"
 	"github.com/docker/swarmkit/remotes"
@@ -178,6 +179,7 @@ func NewTestCAFromRootCA(t *testing.T, tempBaseDir string, rootCA ca.RootCA, krw
 
 	caServer := ca.NewServer(s, managerConfig, paths.RootCA)
 	caServer.SetReconciliationRetryInterval(50 * time.Millisecond)
+	caServer.SetRootReconciliationInterval(50 * time.Millisecond)
 	api.RegisterCAServer(grpcServer, caServer)
 	api.RegisterNodeCAServer(grpcServer, caServer)
 
@@ -200,7 +202,9 @@ func NewTestCAFromRootCA(t *testing.T, tempBaseDir string, rootCA ca.RootCA, krw
 		select {
 		case event := <-clusterWatch:
 			clusterEvent := event.(api.EventUpdateCluster)
-			caServer.UpdateRootCA(ctx, clusterEvent.Cluster)
+			if err := caServer.UpdateRootCA(ctx, clusterEvent.Cluster); err != nil {
+				log.G(ctx).WithError(err).Error("ca utils CA server could not update root CA")
+			}
 		case <-ctx.Done():
 			clusterWatchCancel()
 			return

From 225bea52378ff0890c6dca06385ae2deffc962d0 Mon Sep 17 00:00:00 2001
From: cyli
Date: Wed, 5 Apr 2017 15:53:01 -0700
Subject: [PATCH 2/3] Add root rotation integration tests

Signed-off-by: cyli
---
 integration/api.go              |  10 +-
 integration/cluster.go          |  51 ++++++---
 integration/integration_test.go | 181 ++++++++++++++++++++++++++++++++
 integration/node.go             |   5 +
 4 files changed, 229 insertions(+), 18 deletions(-)

diff --git a/integration/api.go b/integration/api.go
index b7a838eeb9..0f44037887 100644
--- a/integration/api.go
+++ b/integration/api.go
@@ -131,6 +131,12 @@ func (a *dummyAPI) ListClusters(ctx context.Context, r *api.ListClustersRequest)
 	return cli.ListClusters(ctx, r)
 }
 
-func (a *dummyAPI) UpdateCluster(context.Context, *api.UpdateClusterRequest) (*api.UpdateClusterResponse, error) {
-	panic("not implemented")
+func (a *dummyAPI) UpdateCluster(ctx context.Context, r *api.UpdateClusterRequest) (*api.UpdateClusterResponse, error) {
+	ctx, cancel := context.WithTimeout(ctx, opsTimeout)
+	defer cancel()
+	cli, err := a.c.RandomManager().ControlClient(ctx)
+	if err != nil {
+		return nil, err
+	}
+	return cli.UpdateCluster(ctx, r)
 }
diff --git a/integration/cluster.go b/integration/cluster.go
index f56351d7e5..6bc00925d6 100644
--- a/integration/cluster.go
+++ b/integration/cluster.go
@@ -98,14 +98,11 @@ func (c *testCluster) AddManager(lateBind bool, rootCA *ca.RootCA) error {
 		if err != nil {
 			return err
 		}
-		clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{})
+		clusterInfo, err := c.GetClusterInfo()
 		if err != nil {
 			return err
 		}
-		if len(clusterInfo.Clusters) == 0 {
-			return fmt.Errorf("joining manager: there is no cluster created in storage")
-		}
-		node, err := newTestNode(joinAddr, clusterInfo.Clusters[0].RootCA.JoinTokens.Manager, false, nil)
+		node, err := newTestNode(joinAddr, clusterInfo.RootCA.JoinTokens.Manager, false, nil)
 		if err != nil {
 			return err
 		}
@@ -137,14 +134,9 @@ func (c *testCluster) AddManager(lateBind bool, rootCA *ca.RootCA) error {
 	if
lateBind { // Verify that the control API works - clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{}) - if err != nil { + if _, err := c.GetClusterInfo(); err != nil { return err } - if len(clusterInfo.Clusters) == 0 { - return fmt.Errorf("joining manager: there is no cluster created in storage") - } - return n.node.BindRemote(context.Background(), "127.0.0.1:0", "") } @@ -162,14 +154,11 @@ func (c *testCluster) AddAgent() error { if err != nil { return err } - clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{}) + clusterInfo, err := c.GetClusterInfo() if err != nil { return err } - if len(clusterInfo.Clusters) == 0 { - return fmt.Errorf("joining agent: there is no cluster created in storage") - } - node, err := newTestNode(joinAddr, clusterInfo.Clusters[0].RootCA.JoinTokens.Worker, false, nil) + node, err := newTestNode(joinAddr, clusterInfo.RootCA.JoinTokens.Worker, false, nil) if err != nil { return err } @@ -383,3 +372,33 @@ func (c *testCluster) StartNode(id string) error { } return nil } + +func (c *testCluster) GetClusterInfo() (*api.Cluster, error) { + clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{}) + if err != nil { + return nil, err + } + if len(clusterInfo.Clusters) != 1 { + return nil, fmt.Errorf("number of clusters in storage: %d; expected 1", len(clusterInfo.Clusters)) + } + return clusterInfo.Clusters[0], nil +} + +func (c *testCluster) RotateRootCA(cert, key []byte) error { + // poll in case something else changes the cluster before we can update it + return testutils.PollFuncWithTimeout(nil, func() error { + clusterInfo, err := c.GetClusterInfo() + if err != nil { + return err + } + newSpec := clusterInfo.Spec.Copy() + newSpec.CAConfig.SigningCACert = cert + newSpec.CAConfig.SigningCAKey = key + _, err = c.api.UpdateCluster(context.Background(), &api.UpdateClusterRequest{ + ClusterID: clusterInfo.ID, + Spec: newSpec, + ClusterVersion: &clusterInfo.Meta.Version, + }) + return err + }, opsTimeout) +} diff --git a/integration/integration_test.go b/integration/integration_test.go index dd13f5667c..d676987afe 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -1,6 +1,7 @@ package integration import ( + "bytes" "flag" "fmt" "io/ioutil" @@ -13,6 +14,8 @@ import ( "golang.org/x/net/context" + "reflect" + "github.com/Sirupsen/logrus" "github.com/cloudflare/cfssl/helpers" events "github.com/docker/go-events" @@ -99,6 +102,9 @@ func pollClusterReady(t *testing.T, c *testCluster, numWorker, numManager int) { return fmt.Errorf("worker node %s should not have manager status, returned %s", n.ID, n.ManagerStatus) } } + if n.Description.TLSInfo == nil { + return fmt.Errorf("node %s has not reported its TLS info yet", n.ID) + } } if !leaderFound { return fmt.Errorf("leader of cluster is not found") @@ -547,3 +553,178 @@ func TestForceNewCluster(t *testing.T) { require.NoError(t, ioutil.WriteFile(managerCertFile, expiredCertPEM, 0644)) require.Error(t, cl.StartNode(nodeID)) } + +func pollRootRotationDone(t *testing.T, cl *testCluster) { + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + clusterInfo, err := cl.GetClusterInfo() + if err != nil { + return err + } + if clusterInfo.RootCA.RootRotation != nil { + return errors.New("root rotation not done") + } + return nil + }, opsTimeout)) +} + +func TestSuccessfulRootRotation(t *testing.T) { + t.Parallel() + numWorker, numManager := 2, 3 + cl := newCluster(t, numWorker, 
numManager) + defer func() { + require.NoError(t, cl.Stop()) + }() + pollClusterReady(t, cl, numWorker, numManager) + + // Take down one of managers and both workers, so we can't actually ever finish root rotation. + resp, err := cl.api.ListNodes(context.Background(), &api.ListNodesRequest{}) + require.NoError(t, err) + var ( + downManagerID string + downWorkerIDs []string + oldTLSInfo *api.NodeTLSInfo + ) + for _, n := range resp.Nodes { + if oldTLSInfo != nil { + require.Equal(t, oldTLSInfo, n.Description.TLSInfo) + } else { + oldTLSInfo = n.Description.TLSInfo + } + if n.Role == api.NodeRoleManager { + if !n.ManagerStatus.Leader && downManagerID == "" { + downManagerID = n.ID + require.NoError(t, cl.nodes[n.ID].Pause(false)) + } + continue + } + downWorkerIDs = append(downWorkerIDs, n.ID) + require.NoError(t, cl.nodes[n.ID].Pause(false)) + } + + // perform a root rotation, and wait until all the nodes that are up have newly issued certs + newRootCert, newRootKey, err := cautils.CreateRootCertAndKey("newRootCN") + require.NoError(t, err) + require.NoError(t, cl.RotateRootCA(newRootCert, newRootKey)) + + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + resp, err := cl.api.ListNodes(context.Background(), &api.ListNodesRequest{}) + if err != nil { + return err + } + for _, n := range resp.Nodes { + isDown := n.ID == downManagerID || n.ID == downWorkerIDs[0] || n.ID == downWorkerIDs[1] + if reflect.DeepEqual(n.Description.TLSInfo, oldTLSInfo) != isDown { + return fmt.Errorf("expected TLS info to have changed: %v", !isDown) + } + } + + // root rotation isn't done + clusterInfo, err := cl.GetClusterInfo() + if err != nil { + return err + } + require.NotNil(t, clusterInfo.RootCA.RootRotation) // if root rotation is already done, fail and finish the test here + return nil + }, opsTimeout)) + + // Kill the leader, and bring the other manager back to show that a new + // leader can pick up the reconciliation loop and complete the root rotation. + // Bring one worker back, kill the other worker, and add a new worker - show + // that we can converge on a root rotation. 
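The asymmetry between restarting, removing, and pausing nodes below matters for convergence: only accepted nodes present in the store are counted, so a removed worker stops blocking the rotation, while a merely paused one keeps blocking it with its stale TLS info. A condensed restatement of the server's convergence check follows (not the actual helper; `converged` is a name invented for this sketch):

```go
// converged reports whether every accepted node's certificate was issued by
// the desired issuer. A paused node still appears in the store with stale
// TLS info, whereas a deleted node drops out of the set entirely.
func converged(nodes []*api.Node, issuer *x509.Certificate) bool {
	for _, n := range nodes { // nodes listed via store.ByMembership(api.NodeMembershipAccepted)
		if n.Description == nil || n.Description.TLSInfo == nil {
			return false
		}
		if !bytes.Equal(n.Description.TLSInfo.CertIssuerPublicKey, issuer.RawSubjectPublicKeyInfo) ||
			!bytes.Equal(n.Description.TLSInfo.CertIssuerSubject, issuer.RawSubject) {
			return false
		}
	}
	return true
}
```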
+	downLeader, err := cl.Leader()
+	require.NoError(t, err)
+	leaderID := downLeader.node.NodeID()
+
+	require.NoError(t, cl.StartNode(downManagerID))
+	require.NoError(t, cl.StartNode(downWorkerIDs[0]))
+	require.NoError(t, cl.RemoveNode(downWorkerIDs[1], false))
+	require.NoError(t, cl.AddAgent())
+	require.NoError(t, downLeader.Pause(false))
+
+	// we can finish root rotation even though the previous leader was down because it had
+	// already rotated its cert
+	pollRootRotationDone(t, cl)
+
+	// bring the previous leader back up so it can get the new root
+	require.NoError(t, cl.StartNode(leaderID))
+
+	// wait until all the nodes have gotten their new certs and trust roots
+	require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error {
+		resp, err = cl.api.ListNodes(context.Background(), &api.ListNodesRequest{})
+		if err != nil {
+			return err
+		}
+		var newTLSInfo *api.NodeTLSInfo
+		for _, n := range resp.Nodes {
+			if newTLSInfo == nil {
+				newTLSInfo = n.Description.TLSInfo
+				if bytes.Equal(newTLSInfo.CertIssuerPublicKey, oldTLSInfo.CertIssuerPublicKey) ||
+					bytes.Equal(newTLSInfo.CertIssuerSubject, oldTLSInfo.CertIssuerSubject) {
+					return errors.New("expecting the issuer to have changed")
+				}
+				if !bytes.Equal(newTLSInfo.TrustRoot, newRootCert) {
+					return errors.New("expecting the root certificate to have changed")
+				}
+			} else if !reflect.DeepEqual(newTLSInfo, n.Description.TLSInfo) {
+				return fmt.Errorf("the nodes have not converged yet, particularly %s", n.ID)
+			}
+
+			if n.Certificate.Status.State != api.IssuanceStateIssued {
+				return errors.New("nodes have yet to finish renewing their TLS certificates")
+			}
+		}
+		return nil
+	}, opsTimeout))
+}
+
+func TestRepeatedRootRotation(t *testing.T) {
+	t.Parallel()
+	numWorker, numManager := 3, 1
+	cl := newCluster(t, numWorker, numManager)
+	defer func() {
+		require.NoError(t, cl.Stop())
+	}()
+	pollClusterReady(t, cl, numWorker, numManager)
+
+	resp, err := cl.api.ListNodes(context.Background(), &api.ListNodesRequest{})
+	require.NoError(t, err)
+	var oldTLSInfo *api.NodeTLSInfo
+	for _, n := range resp.Nodes {
+		if oldTLSInfo != nil {
+			require.Equal(t, oldTLSInfo, n.Description.TLSInfo)
+		} else {
+			oldTLSInfo = n.Description.TLSInfo
+		}
+	}
+
+	// perform multiple root rotations, waiting a second between each
+	var newRootCert, newRootKey []byte
+	for i := 0; i < 3; i++ {
+		newRootCert, newRootKey, err = cautils.CreateRootCertAndKey("newRootCN")
+		require.NoError(t, err)
+		require.NoError(t, cl.RotateRootCA(newRootCert, newRootKey))
+		time.Sleep(time.Second)
+	}
+
+	pollRootRotationDone(t, cl)
+
+	// wait until all the nodes are stabilized back to the latest issuer
+	require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error {
+		resp, err = cl.api.ListNodes(context.Background(), &api.ListNodesRequest{})
+		if err != nil {
+			return err
+		}
+		for _, n := range resp.Nodes {
+			if reflect.DeepEqual(n.Description.TLSInfo, oldTLSInfo) {
+				return errors.New("nodes have not changed TLS info")
+			}
+			if n.Certificate.Status.State != api.IssuanceStateIssued {
+				return errors.New("nodes have yet to finish renewing their TLS certificates")
+			}
+			if !bytes.Equal(n.Description.TLSInfo.TrustRoot, newRootCert) {
+				return errors.New("nodes do not all trust the new root yet")
+			}
+		}
+		return nil
+	}, opsTimeout))
+}
diff --git a/integration/node.go b/integration/node.go
index 17868ecc95..f4f27a2d2a 100644
--- a/integration/node.go
+++ b/integration/node.go
@@ -6,6 +6,7 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
+	"strings"
 
"google.golang.org/grpc" @@ -115,6 +116,10 @@ func (n *testNode) stop() error { defer cancel() isManager := n.IsManager() if err := n.node.Stop(ctx); err != nil { + // if the error is from trying to stop an already stopped stopped node, ignore the error + if strings.Contains(err.Error(), "node: not started") { + return nil + } // TODO(aaronl): This stack dumping may be removed in the // future once context deadline issues while shutting down // nodes are resolved. From 0ba0da4c9cac7e0bf0d1066906858acbed62fbe5 Mon Sep 17 00:00:00 2001 From: cyli Date: Thu, 6 Apr 2017 17:45:59 -0700 Subject: [PATCH 3/3] Abstract the reconciliation loop to its own object. Rather than polling the store for all nodes at intervals, rely on the cluster and node watches to update an in-memory mapping of the current nodes. At regular intervals, update the store to tell a throttled number of the unconverged nodes to rotate their certificates. Also, remove the leader rotation part of the root rotation integration test, as that takes a very long time. There are server tests to ensure that multiple CA servers running reconciliation loops, and starting a CA server from a stopped state, does not break root reconciliation. Signed-off-by: cyli --- ca/reconciler.go | 259 ++++++++++++++++++++ ca/server.go | 210 +++------------- ca/server_test.go | 411 +++++++++++++++++++++----------- integration/integration_test.go | 14 +- 4 files changed, 567 insertions(+), 327 deletions(-) create mode 100644 ca/reconciler.go diff --git a/ca/reconciler.go b/ca/reconciler.go new file mode 100644 index 0000000000..a35ae7cc41 --- /dev/null +++ b/ca/reconciler.go @@ -0,0 +1,259 @@ +package ca + +import ( + "bytes" + "context" + "fmt" + "reflect" + "sync" + "time" + + "github.com/cloudflare/cfssl/helpers" + "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/api/equality" + "github.com/docker/swarmkit/log" + "github.com/docker/swarmkit/manager/state/store" + "github.com/pkg/errors" +) + +// IssuanceStateRotateMaxBatchSize is the maximum number of nodes we'll tell to rotate their certificates in any given update +const IssuanceStateRotateMaxBatchSize = 30 + +func hasIssuer(n *api.Node, info *IssuerInfo) bool { + if n.Description == nil || n.Description.TLSInfo == nil { + return false + } + return bytes.Equal(info.Subject, n.Description.TLSInfo.CertIssuerSubject) && bytes.Equal(info.PublicKey, n.Description.TLSInfo.CertIssuerPublicKey) +} + +var errRootRotationChanged = errors.New("target root rotation has changed") + +// rootRotationReconciler keeps track of all the nodes in the store so that we can determine which ones need reconciliation when nodes are updated +// or the root CA is updated. This is meant to be used with watches on nodes and the cluster, and provides functions to be called when the +// cluster's RootCA has changed and when a node is added, updated, or removed. 
+type rootRotationReconciler struct { + mu sync.Mutex + clusterID string + batchUpdateInterval time.Duration + ctx context.Context + store *store.MemoryStore + + currentRootCA *api.RootCA + currentIssuer IssuerInfo + unconvergedNodes map[string]*api.Node + + wg sync.WaitGroup + cancel func() +} + +// IssuerFromAPIRootCA returns the desired issuer given an API root CA object +func IssuerFromAPIRootCA(rootCA *api.RootCA) (*IssuerInfo, error) { + wantedIssuer := rootCA.CACert + if rootCA.RootRotation != nil { + wantedIssuer = rootCA.RootRotation.CACert + } + issuerCerts, err := helpers.ParseCertificatesPEM(wantedIssuer) + if err != nil { + return nil, errors.Wrap(err, "invalid certificate in cluster root CA object") + } + if len(issuerCerts) == 0 { + return nil, errors.New("invalid certificate in cluster root CA object") + } + return &IssuerInfo{ + Subject: issuerCerts[0].RawSubject, + PublicKey: issuerCerts[0].RawSubjectPublicKeyInfo, + }, nil +} + +// assumption: UpdateRootCA will never be called with a `nil` root CA because the caller will be acting in response to +// a store update event +func (r *rootRotationReconciler) UpdateRootCA(newRootCA *api.RootCA) { + issuerInfo, err := IssuerFromAPIRootCA(newRootCA) + if err != nil { + log.G(r.ctx).WithError(err).Error("unable to update process the current root CA") + return + } + + var ( + shouldStartNewLoop, waitForPrevLoop bool + loopCtx context.Context + ) + r.mu.Lock() + defer func() { + r.mu.Unlock() + if shouldStartNewLoop { + if waitForPrevLoop { + r.wg.Wait() + } + go r.runReconcilerLoop(loopCtx, newRootCA) + } + }() + + // check if the issuer has changed, first + if reflect.DeepEqual(&r.currentIssuer, issuerInfo) { + r.currentRootCA = newRootCA + return + } + // If the issuer has changed, iterate through all the nodes to figure out which ones need rotation + if newRootCA.RootRotation != nil { + var nodes []*api.Node + r.store.View(func(tx store.ReadTx) { + nodes, err = store.FindNodes(tx, store.ByMembership(api.NodeMembershipAccepted)) + }) + if err != nil { + log.G(r.ctx).WithError(err).Error("unable to list nodes, so unable to process the current root CA") + return + } + + // from here on out, there will be no more errors that cause us to have to abandon updating the Root CA, + // so we can start making changes to r's fields + r.unconvergedNodes = make(map[string]*api.Node) + for _, n := range nodes { + if !hasIssuer(n, issuerInfo) { + r.unconvergedNodes[n.ID] = n + } + } + shouldStartNewLoop = true + if r.cancel != nil { // there's already a loop going, so cancel it + r.cancel() + waitForPrevLoop = true + } + loopCtx, r.cancel = context.WithCancel(r.ctx) + } else { + r.unconvergedNodes = nil + } + r.currentRootCA = newRootCA + r.currentIssuer = *issuerInfo +} + +// assumption: UpdateNode will never be called with a `nil` node because the caller will be acting in response to +// a store update event +func (r *rootRotationReconciler) UpdateNode(node *api.Node) { + r.mu.Lock() + defer r.mu.Unlock() + // if we're not in the middle of a root rotation, or if this node does not have membership, ignore it + if r.currentRootCA == nil || r.currentRootCA.RootRotation == nil || node.Spec.Membership != api.NodeMembershipAccepted { + return + } + if hasIssuer(node, &r.currentIssuer) { + delete(r.unconvergedNodes, node.ID) + } else { + r.unconvergedNodes[node.ID] = node + } +} + +// assumption: DeleteNode will never be called with a `nil` node because the caller will be acting in response to +// a store update event +func (r 
*rootRotationReconciler) DeleteNode(node *api.Node) { + r.mu.Lock() + delete(r.unconvergedNodes, node.ID) + r.mu.Unlock() +} + +func (r *rootRotationReconciler) runReconcilerLoop(ctx context.Context, loopRootCA *api.RootCA) { + r.wg.Add(1) + defer r.wg.Done() + for { + r.mu.Lock() + if len(r.unconvergedNodes) == 0 { + r.mu.Unlock() + + err := r.store.Update(func(tx store.Tx) error { + return r.finishRootRotation(tx, loopRootCA) + }) + if err == nil { + log.G(r.ctx).Info("completed root rotation") + return + } + log.G(r.ctx).WithError(err).Error("could not complete root rotation") + if err == errRootRotationChanged { + // if the root rotation has changed, this loop will be cancelled anyway, so may as well abort early + return + } + } else { + var toUpdate []*api.Node + for _, n := range r.unconvergedNodes { + iState := n.Certificate.Status.State + if iState != api.IssuanceStateRenew && iState != api.IssuanceStatePending && iState != api.IssuanceStateRotate { + n = n.Copy() + n.Certificate.Status.State = api.IssuanceStateRotate + toUpdate = append(toUpdate, n) + if len(toUpdate) >= IssuanceStateRotateMaxBatchSize { + break + } + } + } + r.mu.Unlock() + + if err := r.batchUpdateNodes(toUpdate); err != nil { + log.G(r.ctx).WithError(err).Errorf("store error when trying to batch update %d nodes to request certificate rotation", len(toUpdate)) + } + } + + select { + case <-ctx.Done(): + return + case <-time.After(r.batchUpdateInterval): + } + } +} + +// This function assumes that the expected root CA has root rotation. This is intended to be used by +// `reconcileNodeRootsAndCerts`, which uses the root CA from the `lastSeenClusterRootCA`, and checks +// that it has a root rotation before calling this function. +func (r *rootRotationReconciler) finishRootRotation(tx store.Tx, expectedRootCA *api.RootCA) error { + cluster := store.GetCluster(tx, r.clusterID) + if cluster == nil { + return fmt.Errorf("unable to get cluster %s", r.clusterID) + } + + // If the RootCA object has changed (because another root rotation was started or because some other node + // had finished the root rotation), we cannot finish the root rotation that we were working on. + if !equality.RootCAEqualStable(expectedRootCA, &cluster.RootCA) { + return errRootRotationChanged + } + + var signerCert []byte + if len(cluster.RootCA.RootRotation.CAKey) > 0 { + signerCert = cluster.RootCA.RootRotation.CACert + } + // we don't actually have to parse out the default node expiration from the cluster - we are just using + // the ca.RootCA object to generate new tokens and the digest + updatedRootCA, err := NewRootCA(cluster.RootCA.RootRotation.CACert, signerCert, cluster.RootCA.RootRotation.CAKey, + DefaultNodeCertExpiration, nil) + if err != nil { + return errors.Wrap(err, "invalid cluster root rotation object") + } + cluster.RootCA = api.RootCA{ + CACert: cluster.RootCA.RootRotation.CACert, + CAKey: cluster.RootCA.RootRotation.CAKey, + CACertHash: updatedRootCA.Digest.String(), + JoinTokens: api.JoinTokens{ + Worker: GenerateJoinToken(&updatedRootCA), + Manager: GenerateJoinToken(&updatedRootCA), + }, + LastForcedRotation: cluster.RootCA.LastForcedRotation, + } + return store.UpdateCluster(tx, cluster) +} + +func (r *rootRotationReconciler) batchUpdateNodes(toUpdate []*api.Node) error { + if len(toUpdate) == 0 { + return nil + } + _, err := r.store.Batch(func(batch *store.Batch) error { + // Directly update the nodes rather than get + update, and ignore version errors. 
+		// `rootRotationReconciler` should be hooked up to all node update/delete/create events, we should have
+		// close to the latest versions of all the nodes.  If not, the node will be updated later and the
+		// next batch of updates should catch it.
+		for _, n := range toUpdate {
+			if err := batch.Update(func(tx store.Tx) error {
+				return store.UpdateNode(tx, n)
+			}); err != nil && err != store.ErrSequenceConflict {
+				log.G(r.ctx).WithError(err).Errorf("unable to update node %s to request a certificate rotation", n.ID)
+			}
+		}
+		return nil
+	})
+	return err
+}
diff --git a/ca/server.go b/ca/server.go
index 85c18eb75d..982b0e2e83 100644
--- a/ca/server.go
+++ b/ca/server.go
@@ -4,12 +4,10 @@ import (
 	"bytes"
 	"crypto/subtle"
 	"crypto/x509"
-	"fmt"
 	"sync"
 	"time"
 
 	"github.com/Sirupsen/logrus"
-	"github.com/cloudflare/cfssl/helpers"
 	"github.com/docker/swarmkit/api"
 	"github.com/docker/swarmkit/api/equality"
 	"github.com/docker/swarmkit/identity"
@@ -67,9 +65,8 @@ type Server struct {
 	// before we update the security config with the new root CA, we need to be able to save the root certs
 	rootPaths CertPaths
 
-	// lets us know if we're already doing a reconcile loop to tell nodes to rotate their
-	// TLS certificates and watching for root rotation completion
-	rootCAReconciliationInProgress bool
+	// lets us monitor and finish root rotations
+	rootReconciler                  *rootRotationReconciler
 	rootReconciliationRetryInterval time.Duration
 }
 
@@ -410,14 +407,28 @@ func (s *Server) Run(ctx context.Context) error {
 		return errors.New("CA signer is already running")
 	}
 	s.wg.Add(1)
+	s.ctx, s.cancel = context.WithCancel(log.WithModule(ctx, "ca"))
+	ctx = s.ctx
+	// we need to set it on the server, because `Server.UpdateRootCA` can be called from outside the Run function
+	s.rootReconciler = &rootRotationReconciler{
+		ctx:                 log.WithField(ctx, "method", "(*Server).rootRotationReconciler"),
+		clusterID:           s.securityConfig.ClientTLSCreds.Organization(),
+		store:               s.store,
+		batchUpdateInterval: s.rootReconciliationRetryInterval,
+	}
+	rootReconciler := s.rootReconciler
 	s.mu.Unlock()
-	defer s.wg.Done()
-	ctx = log.WithModule(ctx, "ca")
+	defer func() {
+		s.mu.Lock()
+		s.rootReconciler = nil
+		s.mu.Unlock()
+	}()
 
 	// Retrieve the channels to keep track of changes in the cluster
 	// Retrieve all the currently registered nodes
 	var nodes []*api.Node
+
 	updates, cancel, err := store.ViewAndWatch(
 		s.store,
 		func(readTx store.ReadTx) error {
@@ -434,13 +445,12 @@
 		},
 		api.EventCreateNode{},
 		api.EventUpdateNode{},
+		api.EventDeleteNode{},
 	)
 
 	// Do this after updateCluster has been called, so isRunning never
 	// returns true without joinTokens being set correctly.
 	s.mu.Lock()
-	s.ctx, s.cancel = context.WithCancel(ctx)
-	ctx = s.ctx
 	close(s.started)
 	s.mu.Unlock()
 
@@ -461,7 +471,6 @@ func (s *Server) Run(ctx context.Context) error {
 			"method": "(*Server).Run",
 		}).WithError(err).Errorf("error attempting to reconcile certificates")
 	}
-	s.reconcileNodeRootsAndCerts()
 
 	ticker := time.NewTicker(s.reconciliationRetryInterval)
 	defer ticker.Stop()
@@ -480,13 +489,18 @@ func (s *Server) Run(ctx context.Context) error {
 			switch v := event.(type) {
 			case api.EventCreateNode:
 				s.evaluateAndSignNodeCert(ctx, v.Node)
+				rootReconciler.UpdateNode(v.Node)
 			case api.EventUpdateNode:
 				// If this certificate is already at a final state
 				// no need to evaluate and sign it.
if !isFinalState(v.Node.Certificate.Status) { s.evaluateAndSignNodeCert(ctx, v.Node) } + rootReconciler.UpdateNode(v.Node) + case api.EventDeleteNode: + rootReconciler.DeleteNode(v.Node) } + case <-ticker.C: for _, node := range s.pending { if err := s.evaluateAndSignNodeCert(ctx, node); err != nil { @@ -557,12 +571,16 @@ func (s *Server) isRunning() bool { func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { s.mu.Lock() s.joinTokens = cluster.RootCA.JoinTokens.Copy() + reconciler := s.rootReconciler s.mu.Unlock() + rCA := cluster.RootCA.Copy() + if reconciler != nil { + reconciler.UpdateRootCA(rCA) + } s.secConfigMu.Lock() defer s.secConfigMu.Unlock() - rCA := cluster.RootCA - rootCAChanged := len(rCA.CACert) != 0 && !equality.RootCAEqualStable(s.lastSeenClusterRootCA, &cluster.RootCA) + rootCAChanged := len(rCA.CACert) != 0 && !equality.RootCAEqualStable(s.lastSeenClusterRootCA, rCA) externalCAChanged := !equality.ExternalCAsEqualStable(s.lastSeenExternalCAs, cluster.Spec.CAConfig.ExternalCAs) logger := log.G(ctx).WithFields(logrus.Fields{ "cluster.id": cluster.ID, @@ -619,7 +637,7 @@ func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { } // only update the server cache if we've successfully updated the root CA logger.Debug("Root CA updated successfully") - s.lastSeenClusterRootCA = cluster.RootCA.Copy() + s.lastSeenClusterRootCA = rCA } // we want to update if the external CA changed, or if the root CA changed because the root CA could affect what @@ -661,9 +679,6 @@ func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { s.securityConfig.externalCA.UpdateURLs(cfsslURLs...) s.lastSeenExternalCAs = cluster.Spec.CAConfig.Copy().ExternalCAs } - if rootCAChanged && s.lastSeenClusterRootCA.RootRotation != nil { - s.reconcileNodeRootsAndCerts() - } return nil } @@ -826,166 +841,3 @@ func isFinalState(status api.IssuanceStatus) bool { return false } - -// This function assumes that the expected root CA has root rotation. This is intended to be used by -// `reconcileNodeRootsAndCerts`, which uses the root CA from the `lastSeenClusterRootCA`, and checks -// that it has a root rotation before calling this function. -func (s *Server) finishRootRotation(tx store.Tx, expectedRootCA *api.RootCA) error { - clusterID := s.securityConfig.ClientTLSCreds.Organization() - cluster := store.GetCluster(tx, clusterID) - if cluster == nil { - return fmt.Errorf("unable to get cluster %s", clusterID) - } - - // If the RootCA object has changed (because another root rotation was started or because some other node - // had finished the root rotation), we cannot finish the root rotation that we were working on. 
- if !equality.RootCAEqualStable(expectedRootCA, &cluster.RootCA) { - return errors.New("target root rotation has changed") - } - - var signerCert []byte - if len(cluster.RootCA.RootRotation.CAKey) > 0 { - signerCert = cluster.RootCA.RootRotation.CACert - } - // we don't actually have to parse out the default node expiration from the cluster - we are just using - // the ca.RootCA object to generate new tokens and the digest - updatedRootCA, err := NewRootCA(cluster.RootCA.RootRotation.CACert, signerCert, cluster.RootCA.RootRotation.CAKey, - DefaultNodeCertExpiration, nil) - if err != nil { - return errors.Wrap(err, "invalid cluster root rotation object") - } - cluster.RootCA = api.RootCA{ - CACert: cluster.RootCA.RootRotation.CACert, - CAKey: cluster.RootCA.RootRotation.CAKey, - CACertHash: updatedRootCA.Digest.String(), - JoinTokens: api.JoinTokens{ - Worker: GenerateJoinToken(&updatedRootCA), - Manager: GenerateJoinToken(&updatedRootCA), - }, - LastForcedRotation: cluster.RootCA.LastForcedRotation, - } - return store.UpdateCluster(tx, cluster) -} - -func (s *Server) reconcileNodeRootsAndCerts() { - s.mu.Lock() - if !s.isRunning() || s.rootCAReconciliationInProgress { - s.mu.Unlock() - return - } - ctx := s.ctx - s.rootCAReconciliationInProgress = true - s.mu.Unlock() - - s.wg.Add(1) - logger := log.G(ctx).WithField("method", "(*Server).reconcileNodeRootsAndCerts") - - go func(retryInterval time.Duration) { - defer func() { - s.wg.Done() - s.mu.Lock() - s.rootCAReconciliationInProgress = false - s.mu.Unlock() - }() - - for { - s.secConfigMu.Lock() - rootCA := s.lastSeenClusterRootCA - s.secConfigMu.Unlock() - if rootCA == nil { - return - } - - wantedIssuer := rootCA.CACert - if rootCA.RootRotation != nil { - wantedIssuer = rootCA.RootRotation.CACert - } - issuerCert, err := helpers.ParseCertificatePEM(wantedIssuer) - if err != nil { - logger.WithError(err).Error("invalid certificate in cluster root CA object") - } - - // If the last seen root rotation is nil, then possibly during a leadership transfer a different manager's - // CA node started root rotation reconciliation and finished the root rotation, so we should also end the - // root reconciliation - if rootCA.RootRotation == nil { - return - } - - var allNodes []*api.Node - - s.store.View(func(tx store.ReadTx) { - allNodes, err = store.FindNodes(tx, store.ByMembership(api.NodeMembershipAccepted)) - }) - if err != nil { - logger.WithError(err).Error("could not find all nodes in the store") - } - - var rootRotationDone bool - _, err = s.store.Batch(func(batch *store.Batch) error { - converged := true // this value is true if all the node certs are issued by the correct issuer - for _, node := range allNodes { - var needUpdate bool - err := batch.Update(func(tx store.Tx) error { - n := store.GetNode(tx, node.ID) - if n == nil || (n.Description != nil && n.Description.TLSInfo != nil && - bytes.Equal(n.Description.TLSInfo.CertIssuerPublicKey, issuerCert.RawSubjectPublicKeyInfo) && - bytes.Equal(n.Description.TLSInfo.CertIssuerSubject, issuerCert.RawSubject)) { - return nil - } - converged = false - - // If we are already waiting for the node to rotate (or the node's about to get issued a new cert by us), - // so don't bother telling the node to rotate again. If a cert is in renew or pending, the cert that - // the node will get should be issued by the correct issuer. 
However, in all likelihood the node will - // not get a chance to report back its TLS before the reconciliation loop will mark it as rotate, and - // it will get a new certificate, which is fine, but results in an extra certificate renewal round. - issuanceState := n.Certificate.Status.State - if issuanceState == api.IssuanceStateRotate || issuanceState == api.IssuanceStateRenew || issuanceState == api.IssuanceStatePending { - return nil - } - n.Certificate.Status.State = api.IssuanceStateRotate - needUpdate = true - return store.UpdateNode(tx, n) - }) - if err != nil { - logger.WithError(err).Debugf("could not update node %s to IssuanceStateRotate", node.ID) - } else if needUpdate { - logger.Debugf("updated node %s to IssuanceStateRotate", node.ID) - } - } - // It's possible that between getting all nodes and finishing root rotation, new nodes have joined. by the time we - // fetch all nodes, the CA signer should have already been changed to be the desired issuer by the time we list all - // nodes, so any new nodes that appear after that will only get certificates signed by the correct issuer. - if converged { - // when the nodes are all converged on having TLS certificates issued by the correct entity, we can - // try to finish the root rotation - err := batch.Update(func(tx store.Tx) error { - return s.finishRootRotation(tx, rootCA) - }) - if err != nil { - logger.WithError(err).Errorf("could not complete root rotation") - } else { - rootRotationDone = true - } - } - - return nil - }) - if err != nil { - logger.WithError(err).Errorf("failed to check nodes for desired certificate issuer") - } - if rootRotationDone { - logger.Infof("completed root rotation on cluster %s", s.securityConfig.ServerTLSCreds.Organization()) - return - } - - select { - case <-ctx.Done(): - logger.Info("stopping incomplete root rotation reconciliation due to context being stopped") - return - case <-time.After(retryInterval): - } - } - }(s.rootReconciliationRetryInterval) -} diff --git a/ca/server_test.go b/ca/server_test.go index 9bc1cdf1b0..ce53f16170 100644 --- a/ca/server_test.go +++ b/ca/server_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "os" + "path/filepath" "reflect" "testing" "time" @@ -563,60 +564,50 @@ type rootRotationTester struct { // go through all the nodes and update/create the ones we want, and delete the ones // we don't func (r *rootRotationTester) convergeWantedNodes(wantNodes map[string]*api.Node, descr string) { - updatedOldNodes := make(map[string]struct{}) - createdNewNodes := make(map[string]struct{}) - require.NoError(r.t, testutils.PollFuncWithTimeout(nil, func() error { - return r.tc.MemoryStore.Update(func(tx store.Tx) error { - nodes, err := store.FindNodes(tx, store.All) - if err != nil { - return err - } - for _, node := range nodes { - wanted, inWanted := wantNodes[node.ID] - _, done := updatedOldNodes[node.ID] - - if inWanted && !done { - node.Description = wanted.Description - node.Certificate = wanted.Certificate - if err := store.UpdateNode(tx, node); err != nil { - return err - } - updatedOldNodes[node.ID] = struct{}{} - } else if !inWanted { - if err := store.DeleteNode(tx, node.ID); err != nil { - return err - } - updatedOldNodes[node.ID] = struct{}{} + // update existing and create new nodes first before deleting nodes, else a root rotation + // may finish early if all the nodes get deleted when the root rotation happens + require.NoError(r.t, r.tc.MemoryStore.Update(func(tx store.Tx) error { + for nodeID, wanted := range wantNodes { + node := store.GetNode(tx, 
nodeID) + if node == nil { + if err := store.CreateNode(tx, wanted); err != nil { + return err } + continue } - for nodeID, wanted := range wantNodes { - _, createdAlready := createdNewNodes[nodeID] - if _, ok := updatedOldNodes[nodeID]; !ok && !createdAlready { - if err := store.CreateNode(tx, wanted); err != nil { - return err - } - createdNewNodes[nodeID] = struct{}{} + node.Description = wanted.Description + node.Certificate = wanted.Certificate + if err := store.UpdateNode(tx, node); err != nil { + return err + } + } + nodes, err := store.FindNodes(tx, store.All) + if err != nil { + return err + } + for _, node := range nodes { + if _, inWanted := wantNodes[node.ID]; !inWanted { + if err := store.DeleteNode(tx, node.ID); err != nil { + return err } } - return nil - }) - }, 2*time.Second), descr) + } + return nil + }), descr) } func (r *rootRotationTester) convergeRootCA(wantRootCA *api.RootCA, descr string) { - require.NoError(r.t, testutils.PollFuncWithTimeout(nil, func() error { - return r.tc.MemoryStore.Update(func(tx store.Tx) error { - clusters, err := store.FindClusters(tx, store.All) - if err != nil || len(clusters) != 1 { - return errors.Wrap(err, "unable to find cluster") - } - clusters[0].RootCA = *wantRootCA - return store.UpdateCluster(tx, clusters[0]) - }) - }, time.Second), descr) + require.NoError(r.t, r.tc.MemoryStore.Update(func(tx store.Tx) error { + clusters, err := store.FindClusters(tx, store.All) + if err != nil || len(clusters) != 1 { + return errors.Wrap(err, "unable to find cluster") + } + clusters[0].RootCA = *wantRootCA + return store.UpdateCluster(tx, clusters[0]) + }), descr) } -func (r *rootRotationTester) getFakeAPINode(id string, state api.IssuanceStatus_State, tlsInfo *api.NodeTLSInfo, member bool) *api.Node { +func getFakeAPINode(t *testing.T, id string, state api.IssuanceStatus_State, tlsInfo *api.NodeTLSInfo, member bool) *api.Node { node := &api.Node{ ID: id, Certificate: api.Certificate{ @@ -634,7 +625,7 @@ func (r *rootRotationTester) getFakeAPINode(id string, state api.IssuanceStatus_ // the CA server will immediately pick these up, so generate CSRs for the CA server to sign if state == api.IssuanceStateRenew || state == api.IssuanceStatePending { csr, _, err := ca.GenerateNewCSR() - require.NoError(r.t, err) + require.NoError(t, err) node.Certificate.CSR = csr } if tlsInfo != nil { @@ -721,12 +712,12 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { "a new cert. 
Any renew/pending nodes will have certs issued, but because the TLS info is nil, they will " + `go "rotate" state`), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, nil, true), - "2": rt.getFakeAPINode("2", api.IssuanceStateRenew, nil, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true), - "4": rt.getFakeAPINode("4", api.IssuanceStatePending, nil, true), - "5": rt.getFakeAPINode("5", api.IssuanceStateFailed, nil, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, nil, true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRenew, nil, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateRotate, nil, true), + "4": getFakeAPINode(t, "4", api.IssuanceStatePending, nil, true), + "5": getFakeAPINode(t, "5", api.IssuanceStateFailed, nil, true), }, rootCA: &api.RootCA{ CACert: startCluster.RootCA.CACert, @@ -739,12 +730,12 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { }, }, expectedNodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateRotate, nil, true), - "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, nil, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true), - "4": rt.getFakeAPINode("4", api.IssuanceStateRotate, nil, true), - "5": rt.getFakeAPINode("5", api.IssuanceStateRotate, nil, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateRotate, nil, true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRotate, nil, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateRotate, nil, true), + "4": getFakeAPINode(t, "4", api.IssuanceStateRotate, nil, true), + "5": getFakeAPINode(t, "5", api.IssuanceStateRotate, nil, true), }, }, { @@ -752,12 +743,12 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { "(going by the TLS info), which shouldn't really happen. 
the rotation reconciliation " + "will tell the wrong ones to rotate a second time"), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "2": rt.getFakeAPINode("2", api.IssuanceStateIssued, oldNodeTLSInfo, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "2": getFakeAPINode(t, "2", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, oldNodeTLSInfo, true), }, rootCA: &api.RootCA{ // no change in root CA from previous CACert: startCluster.RootCA.CACert, @@ -770,23 +761,23 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { }, }, expectedNodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateRotate, oldNodeTLSInfo, true), }, }, { descr: ("New nodes that are added will also be picked up and told to rotate"), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "6": rt.getFakeAPINode("6", api.IssuanceStateRenew, nil, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateRenew, nil, true), }, rootCA: &api.RootCA{ // no change in root CA from previous CACert: startCluster.RootCA.CACert, @@ -799,12 +790,12 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { }, }, expectedNodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, 
false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "6": rt.getFakeAPINode("6", api.IssuanceStateRotate, nil, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateRotate, nil, true), }, }, { @@ -812,12 +803,12 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { "different cert, all the nodes with the old root rotation cert will be told " + "to rotate again."), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateIssued, oldNodeTLSInfo, true), - "6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[0], true), }, rootCA: &api.RootCA{ // new root rotation CACert: startCluster.RootCA.CACert, @@ -830,23 +821,23 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { }, }, expectedNodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateRotate, rotationTLSInfo[0], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateRotate, rotationTLSInfo[0], true), - "5": rt.getFakeAPINode("5", api.IssuanceStateRotate, oldNodeTLSInfo, true), - "6": rt.getFakeAPINode("6", api.IssuanceStateRotate, rotationTLSInfo[0], true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateRotate, rotationTLSInfo[0], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateRotate, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "6": getFakeAPINode(t, "6", api.IssuanceStateRotate, rotationTLSInfo[0], true), }, }, { descr: ("Once all nodes have rotated to their desired TLS info (even if it's because " + "a node with the wrong TLS info has been removed, the root rotation is completed."), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[1], 
true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[1], true), }, rootCA: &api.RootCA{ // no change in root CA from previous - even if root rotation gets completed after @@ -874,11 +865,11 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { "it will start reconciling the nodes as soon as it's started up again"), caServerRestart: true, nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateIssued, rotationTLSInfo[1], true), - "6": rt.getFakeAPINode("6", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[1], true), }, rootCA: &api.RootCA{ CACert: startCluster.RootCA.CACert, @@ -891,11 +882,11 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { }, }, expectedNodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateRotate, rotationTLSInfo[1], true), - "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, rotationTLSInfo[1], true), - "4": rt.getFakeAPINode("4", api.IssuanceStateRotate, rotationTLSInfo[1], true), - "6": rt.getFakeAPINode("6", api.IssuanceStateRotate, rotationTLSInfo[1], true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateRotate, rotationTLSInfo[1], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateRotate, rotationTLSInfo[1], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateRotate, rotationTLSInfo[1], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateRotate, rotationTLSInfo[1], true), }, }, } @@ -956,6 +947,19 @@ func TestRootRotationReconciliationWithChanges(t *testing.T) { return fmt.Errorf("the certificate status of node %s is not expected:\n\texpected: %v\n\tactual: %v", node.ID, expected.Certificate, node.Certificate) } + + // ensure that the security config's root CA object has the same expected key + expectedKey := testcase.expectedRootCA.CAKey + if testcase.expectedRootCA.RootRotation != nil { + expectedKey = testcase.expectedRootCA.RootRotation.CAKey + } + s, err := rt.tc.ServingSecurityConfig.RootCA().Signer() + if err != nil { + return err + } + if !bytes.Equal(s.Key, expectedKey) { + return fmt.Errorf("the security config has not been updated correctly") + } } return nil }, 5*time.Second), testcase.descr) @@ -1005,12 +1009,12 @@ func 
TestRootRotationReconciliationNoChanges(t *testing.T) { "is in progress"), caServerStopped: true, nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, oldNodeTLSInfo, true), - "2": rt.getFakeAPINode("2", api.IssuanceStateRenew, nil, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, nil, true), - "4": rt.getFakeAPINode("4", api.IssuanceStatePending, nil, true), - "5": rt.getFakeAPINode("5", api.IssuanceStateFailed, nil, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRenew, nil, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateRotate, nil, true), + "4": getFakeAPINode(t, "4", api.IssuanceStatePending, nil, true), + "5": getFakeAPINode(t, "5", api.IssuanceStateFailed, nil, true), }, rootCA: &api.RootCA{ CACert: startCluster.RootCA.CACert, @@ -1024,13 +1028,13 @@ func TestRootRotationReconciliationNoChanges(t *testing.T) { }, }, { - descr: ("If all nodes have the right TLS info or are already rotated (or are not members), the " + + descr: ("If all nodes have the right TLS info or are already rotated (or are not members), " + "there will be no changes needed"), nodes: map[string]*api.Node{ - "0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false), - "1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo, true), - "2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true), - "3": rt.getFakeAPINode("3", api.IssuanceStateRotate, rotationTLSInfo, true), + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo, true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateRotate, rotationTLSInfo, true), }, rootCA: &api.RootCA{ // no change in root CA from previous CACert: startCluster.RootCA.CACert, @@ -1049,9 +1053,9 @@ func TestRootRotationReconciliationNoChanges(t *testing.T) { "in the process of getting a new cert. 
Even if they're issued by a different issuer, they will be " +
 				"left alone because they'll have an interemdiate that chains up to the old issuer."),
 			nodes: map[string]*api.Node{
-				"0": rt.getFakeAPINode("0", api.IssuanceStatePending, nil, false),
-				"1": rt.getFakeAPINode("1", api.IssuanceStateIssued, rotationTLSInfo, true),
-				"2": rt.getFakeAPINode("2", api.IssuanceStateRotate, oldNodeTLSInfo, true),
+				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
+				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo, true),
+				"2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true),
 			},
 			rootCA: &api.RootCA{ // no change in root CA from previous
 				CACert: startCluster.RootCA.CACert,
@@ -1093,6 +1097,15 @@ func TestRootRotationReconciliationNoChanges(t *testing.T) {
 			require.Equal(t, expected.Description, node.Description, "node %s: %s", node.ID, testcase.descr)
 			require.Equal(t, expected.Certificate.Status, node.Certificate.Status, "node %s: %s", node.ID, testcase.descr)
 		}
+
+		// ensure that the security config's root CA object has the same expected key
+		expectedKey := testcase.rootCA.CAKey
+		if testcase.rootCA.RootRotation != nil {
+			expectedKey = testcase.rootCA.RootRotation.CAKey
+		}
+		s, err := rt.tc.ServingSecurityConfig.RootCA().Signer()
+		require.NoError(t, err, testcase.descr)
+		require.Equal(t, s.Key, expectedKey, testcase.descr)
 	}
 }
 
@@ -1111,16 +1124,21 @@ func TestRootRotationReconciliationRace(t *testing.T) {
 		tc: tc,
 		t:  t,
 	}
-	// start a competing CA server
+
 	tempDir, err := ioutil.TempDir("", "competing-ca-server")
 	require.NoError(t, err)
 	defer os.RemoveAll(tempDir)
-	paths := ca.NewConfigPaths(tempDir)
-	competingSecConfig, err := tc.NewNodeConfig(ca.ManagerRole)
-	require.NoError(t, err)
 
 	var otherServers []*ca.Server
+	var secConfigs []*ca.SecurityConfig
 	for i := 0; i < 3; i++ { // to make sure we get some collision
+		// start a competing CA server
+		competingSecConfig, err := tc.NewNodeConfig(ca.ManagerRole)
+		require.NoError(t, err)
+		secConfigs = append(secConfigs, competingSecConfig)
+
+		paths := ca.NewConfigPaths(filepath.Join(tempDir, fmt.Sprintf("%d", i)))
+
 		otherServer := ca.NewServer(tc.MemoryStore, competingSecConfig, paths.RootCA)
 		// offset each server's reconciliation interval somewhat so that some will
 		// pre-empt others
@@ -1131,10 +1149,7 @@
 	}
 	clusterWatch, clusterWatchCancel, err := store.ViewAndWatch(
 		tc.MemoryStore, func(tx store.ReadTx) error {
-			cluster := store.GetCluster(tx, tc.Organization)
-			for _, s := range otherServers {
-				s.UpdateRootCA(context.Background(), cluster)
-			}
+			// don't bother getting the cluster - the CA servers have already done that when first running
 			return nil
 		},
 		api.EventUpdateCluster{
@@ -1170,16 +1185,17 @@
 	nodes := make(map[string]*api.Node)
 	for i := 0; i < 5; i++ {
 		nodeID := fmt.Sprintf("%d", i)
-		nodes[nodeID] = rt.getFakeAPINode(nodeID, api.IssuanceStateIssued, oldNodeTLSInfo, true)
+		nodes[nodeID] = getFakeAPINode(t, nodeID, api.IssuanceStateIssued, oldNodeTLSInfo, true)
 	}
 	rt.convergeWantedNodes(nodes, "setting up nodes for root rotation race condition test")
 
+	var rotationCert, rotationKey []byte
 	for i := 0; i < 10; i++ {
 		var (
 			rotationCrossSigned []byte
 			rotationTLSInfo     *api.NodeTLSInfo
 		)
-		rotationCert, rotationKey, err := cautils.CreateRootCertAndKey(fmt.Sprintf("root cn %d", i))
+		rotationCert, rotationKey, err = cautils.CreateRootCertAndKey(fmt.Sprintf("root cn %d", i))
 		require.NoError(t, err)
 		require.NoError(t, tc.MemoryStore.Update(func(tx store.Tx) error {
 			cluster := store.GetCluster(tx, tc.Organization)
@@ -1214,9 +1230,132 @@
 		if cluster == nil {
 			return errors.New("cluster has disappeared")
 		}
-		if cluster.RootCA.RootRotation == nil {
-			return fmt.Errorf("root rotation is still present")
+		if cluster.RootCA.RootRotation != nil {
+			return errors.New("root rotation is still present")
+		}
+		if !bytes.Equal(cluster.RootCA.CACert, rotationCert) {
+			return errors.New("expected root cert is wrong")
+		}
+		if !bytes.Equal(cluster.RootCA.CAKey, rotationKey) {
+			return errors.New("expected root key is wrong")
+		}
+		for _, secConfig := range secConfigs {
+			s, err := secConfig.RootCA().Signer()
+			if err != nil {
+				return err
+			}
+			if !bytes.Equal(s.Key, rotationKey) {
+				return errors.New("all the sec configs haven't been updated yet")
+			}
+		}
+		return nil
+	}, 5*time.Second))
+
+	// all of the ca servers have the appropriate cert and key
+}
+
+// If there are a lot of nodes, we only update a small number of them at once.
+func TestRootRotationReconciliationThrottled(t *testing.T) {
+	t.Parallel()
+	if cautils.External {
+		// the external CA functionality is unrelated to testing the reconciliation loop
+		return
+	}
+
+	tc := cautils.NewTestCA(t)
+	defer tc.Stop()
+	// immediately stop the CA server - we want to run our own
+	tc.CAServer.Stop()
+
+	caServer := ca.NewServer(tc.MemoryStore, tc.ServingSecurityConfig, tc.Paths.RootCA)
+	// set the reconciliation interval to something ridiculous, so we can make sure only the
+	// first batch of updates runs before we check the results
+	caServer.SetRootReconciliationInterval(time.Hour)
+	startCAServer(caServer)
+	defer caServer.Stop()
+
+	var nodes []*api.Node
+	clusterWatch, clusterWatchCancel, err := store.ViewAndWatch(
+		tc.MemoryStore, func(tx store.ReadTx) error {
+			// don't bother getting the cluster - the CA server has already done that when first running
+			var err error
+			nodes, err = store.FindNodes(tx, store.ByMembership(api.NodeMembershipAccepted))
+			return err
+		},
+		api.EventUpdateCluster{
+			Cluster: &api.Cluster{ID: tc.Organization},
+			Checks:  []api.ClusterCheckFunc{api.ClusterCheckID},
+		},
+	)
+	require.NoError(t, err)
+	defer clusterWatchCancel()
+
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		for {
+			select {
+			case event := <-clusterWatch:
+				clusterEvent := event.(api.EventUpdateCluster)
+				caServer.UpdateRootCA(context.Background(), clusterEvent.Cluster)
+			case <-done:
+				return
+			}
+		}
+	}()
+
+	// create twice the batch size of nodes
+	_, err = tc.MemoryStore.Batch(func(batch *store.Batch) error {
+		for i := len(nodes); i < ca.IssuanceStateRotateMaxBatchSize*2; i++ {
+			nodeID := fmt.Sprintf("%d", i)
+			err := batch.Update(func(tx store.Tx) error {
+				return store.CreateNode(tx, getFakeAPINode(t, nodeID, api.IssuanceStateIssued, nil, true))
+			})
+			if err != nil {
+				return err
+			}
+		}
+		return nil
+	})
+	require.NoError(t, err)
+
+	rotationCert := cautils.ECDSA256SHA256Cert
+	rotationKey := cautils.ECDSA256Key
+	rotationCrossSigned, _ := getRotationInfo(t, rotationCert, &tc.RootCA)
+
+	require.NoError(t, tc.MemoryStore.Update(func(tx store.Tx) error {
+		cluster := store.GetCluster(tx, tc.Organization)
+		if cluster == nil {
+			return errors.New("cluster has disappeared")
+		}
+		rootCA := cluster.RootCA.Copy()
+		rootCA.RootRotation = &api.RootRotation{
+			CACert:            rotationCert,
+			CAKey:             rotationKey,
+			CrossSignedCACert: rotationCrossSigned,
+		}
+		cluster.RootCA 
= *rootCA + return store.UpdateCluster(tx, cluster) + })) + + checkRotationNumber := func() error { + tc.MemoryStore.View(func(tx store.ReadTx) { + nodes, err = store.FindNodes(tx, store.All) + }) + var issuanceRotate int + for _, n := range nodes { + if n.Certificate.Status.State == api.IssuanceStateRotate { + issuanceRotate += 1 + } + } + if issuanceRotate != ca.IssuanceStateRotateMaxBatchSize { + return fmt.Errorf("expected %d, got %d", ca.IssuanceStateRotateMaxBatchSize, issuanceRotate) } return nil - }, time.Second)) + } + + require.NoError(t, testutils.PollFuncWithTimeout(nil, checkRotationNumber, 5*time.Second)) + // prove that it's not just because the updates haven't finished + time.Sleep(time.Second) + require.NoError(t, checkRotationNumber()) } diff --git a/integration/integration_test.go b/integration/integration_test.go index d676987afe..d521222ec8 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -627,27 +627,17 @@ func TestSuccessfulRootRotation(t *testing.T) { return nil }, opsTimeout)) - // Kill the leader, and bring the other manager back to show that a new - // leader can pick up the reconciliation loop and complete the root rotation. - // Bring one worker back, kill the other worker, and add a new worker - show - // that we can converge on a root rotation. - downLeader, err := cl.Leader() - require.NoError(t, err) - leaderID := downLeader.node.NodeID() - + // Bring the other manager back. Also bring one worker back, kill the other worker, + // and add a new worker - show that we can converge on a root rotation. require.NoError(t, cl.StartNode(downManagerID)) require.NoError(t, cl.StartNode(downWorkerIDs[0])) require.NoError(t, cl.RemoveNode(downWorkerIDs[1], false)) require.NoError(t, cl.AddAgent()) - require.NoError(t, downLeader.Pause(false)) // we can finish root rotation even though the previous leader was down because it had // already rotated its cert pollRootRotationDone(t, cl) - // bring the previous leader back up so it can get the new root - require.NoError(t, cl.StartNode(leaderID)) - // wait until all the nodes have gotten their new certs and trust roots require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { resp, err = cl.api.ListNodes(context.Background(), &api.ListNodesRequest{})
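
Two identifiers referenced by the reconciler above - `errRootRotationChanged` and `hasIssuer` - are
defined in hunks outside this excerpt. Judging from the inline logic that this patch removes (the old
loop returned errors.New("target root rotation has changed") and compared the node's reported issuer
subject and public key against the wanted issuer), they presumably look like the following sketch,
which assumes the same package imports (`bytes`, `errors`) as ca/server.go; the exact definitions live
elsewhere in the series.

var errRootRotationChanged = errors.New("target root rotation has changed")

// hasIssuer reports whether the node's latest-known TLS certificate was issued by the given issuer.
// Nodes that have not yet reported any TLS info are treated as unconverged.
func hasIssuer(n *api.Node, info *IssuerInfo) bool {
	if n.Description == nil || n.Description.TLSInfo == nil {
		return false
	}
	return bytes.Equal(info.Subject, n.Description.TLSInfo.CertIssuerSubject) &&
		bytes.Equal(info.PublicKey, n.Description.TLSInfo.CertIssuerPublicKey)
}

Keeping the `unconvergedNodes` map up to date from node events means each event costs only one
constant-time issuer comparison, rather than re-listing every node on each reconciliation pass.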