diff --git a/ca/reconciler.go b/ca/reconciler.go
new file mode 100644
index 0000000000..a35ae7cc41
--- /dev/null
+++ b/ca/reconciler.go
@@ -0,0 +1,259 @@
+package ca
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"reflect"
+	"sync"
+	"time"
+
+	"github.com/cloudflare/cfssl/helpers"
+	"github.com/docker/swarmkit/api"
+	"github.com/docker/swarmkit/api/equality"
+	"github.com/docker/swarmkit/log"
+	"github.com/docker/swarmkit/manager/state/store"
+	"github.com/pkg/errors"
+)
+
+// IssuanceStateRotateMaxBatchSize is the maximum number of nodes we'll tell to rotate their certificates in any given update
+const IssuanceStateRotateMaxBatchSize = 30
+
+func hasIssuer(n *api.Node, info *IssuerInfo) bool {
+	if n.Description == nil || n.Description.TLSInfo == nil {
+		return false
+	}
+	return bytes.Equal(info.Subject, n.Description.TLSInfo.CertIssuerSubject) && bytes.Equal(info.PublicKey, n.Description.TLSInfo.CertIssuerPublicKey)
+}
+
+var errRootRotationChanged = errors.New("target root rotation has changed")
+
+// rootRotationReconciler keeps track of all the nodes in the store so that we can determine which ones need reconciliation when nodes are updated
+// or the root CA is updated.  This is meant to be used with watches on nodes and the cluster, and provides functions to be called when the
+// cluster's RootCA has changed and when a node is added, updated, or removed.
+type rootRotationReconciler struct {
+	mu                  sync.Mutex
+	clusterID           string
+	batchUpdateInterval time.Duration
+	ctx                 context.Context
+	store               *store.MemoryStore
+
+	currentRootCA    *api.RootCA
+	currentIssuer    IssuerInfo
+	unconvergedNodes map[string]*api.Node
+
+	wg     sync.WaitGroup
+	cancel func()
+}
+
+// IssuerFromAPIRootCA returns the desired issuer given an API root CA object
+func IssuerFromAPIRootCA(rootCA *api.RootCA) (*IssuerInfo, error) {
+	wantedIssuer := rootCA.CACert
+	if rootCA.RootRotation != nil {
+		wantedIssuer = rootCA.RootRotation.CACert
+	}
+	issuerCerts, err := helpers.ParseCertificatesPEM(wantedIssuer)
+	if err != nil {
+		return nil, errors.Wrap(err, "invalid certificate in cluster root CA object")
+	}
+	if len(issuerCerts) == 0 {
+		return nil, errors.New("invalid certificate in cluster root CA object")
+	}
+	return &IssuerInfo{
+		Subject:   issuerCerts[0].RawSubject,
+		PublicKey: issuerCerts[0].RawSubjectPublicKeyInfo,
+	}, nil
+}
+
+// assumption: UpdateRootCA will never be called with a `nil` root CA because the caller will be acting in response to
+// a store update event
+func (r *rootRotationReconciler) UpdateRootCA(newRootCA *api.RootCA) {
+	issuerInfo, err := IssuerFromAPIRootCA(newRootCA)
+	if err != nil {
+		log.G(r.ctx).WithError(err).Error("unable to process the updated root CA")
+		return
+	}
+
+	var (
+		shouldStartNewLoop, waitForPrevLoop bool
+		loopCtx                             context.Context
+	)
+	r.mu.Lock()
+	defer func() {
+		r.mu.Unlock()
+		if shouldStartNewLoop {
+			if waitForPrevLoop {
+				r.wg.Wait()
+			}
+			go r.runReconcilerLoop(loopCtx, newRootCA)
+		}
+	}()
+
+	// first, check whether the issuer has changed
+	if reflect.DeepEqual(&r.currentIssuer, issuerInfo) {
+		r.currentRootCA = newRootCA
+		return
+	}
+	// If the issuer has changed, iterate through all the nodes to figure out which ones need rotation
+	if newRootCA.RootRotation != nil {
+		var nodes []*api.Node
+		r.store.View(func(tx store.ReadTx) {
+			nodes, err = store.FindNodes(tx, store.ByMembership(api.NodeMembershipAccepted))
+		})
+		if err != nil {
+			log.G(r.ctx).WithError(err).Error("unable to list nodes, so unable to process the current root CA")
+			return
+		}
+
+		// 
from here on out, there will be no more errors that cause us to have to abandon updating the Root CA, + // so we can start making changes to r's fields + r.unconvergedNodes = make(map[string]*api.Node) + for _, n := range nodes { + if !hasIssuer(n, issuerInfo) { + r.unconvergedNodes[n.ID] = n + } + } + shouldStartNewLoop = true + if r.cancel != nil { // there's already a loop going, so cancel it + r.cancel() + waitForPrevLoop = true + } + loopCtx, r.cancel = context.WithCancel(r.ctx) + } else { + r.unconvergedNodes = nil + } + r.currentRootCA = newRootCA + r.currentIssuer = *issuerInfo +} + +// assumption: UpdateNode will never be called with a `nil` node because the caller will be acting in response to +// a store update event +func (r *rootRotationReconciler) UpdateNode(node *api.Node) { + r.mu.Lock() + defer r.mu.Unlock() + // if we're not in the middle of a root rotation, or if this node does not have membership, ignore it + if r.currentRootCA == nil || r.currentRootCA.RootRotation == nil || node.Spec.Membership != api.NodeMembershipAccepted { + return + } + if hasIssuer(node, &r.currentIssuer) { + delete(r.unconvergedNodes, node.ID) + } else { + r.unconvergedNodes[node.ID] = node + } +} + +// assumption: DeleteNode will never be called with a `nil` node because the caller will be acting in response to +// a store update event +func (r *rootRotationReconciler) DeleteNode(node *api.Node) { + r.mu.Lock() + delete(r.unconvergedNodes, node.ID) + r.mu.Unlock() +} + +func (r *rootRotationReconciler) runReconcilerLoop(ctx context.Context, loopRootCA *api.RootCA) { + r.wg.Add(1) + defer r.wg.Done() + for { + r.mu.Lock() + if len(r.unconvergedNodes) == 0 { + r.mu.Unlock() + + err := r.store.Update(func(tx store.Tx) error { + return r.finishRootRotation(tx, loopRootCA) + }) + if err == nil { + log.G(r.ctx).Info("completed root rotation") + return + } + log.G(r.ctx).WithError(err).Error("could not complete root rotation") + if err == errRootRotationChanged { + // if the root rotation has changed, this loop will be cancelled anyway, so may as well abort early + return + } + } else { + var toUpdate []*api.Node + for _, n := range r.unconvergedNodes { + iState := n.Certificate.Status.State + if iState != api.IssuanceStateRenew && iState != api.IssuanceStatePending && iState != api.IssuanceStateRotate { + n = n.Copy() + n.Certificate.Status.State = api.IssuanceStateRotate + toUpdate = append(toUpdate, n) + if len(toUpdate) >= IssuanceStateRotateMaxBatchSize { + break + } + } + } + r.mu.Unlock() + + if err := r.batchUpdateNodes(toUpdate); err != nil { + log.G(r.ctx).WithError(err).Errorf("store error when trying to batch update %d nodes to request certificate rotation", len(toUpdate)) + } + } + + select { + case <-ctx.Done(): + return + case <-time.After(r.batchUpdateInterval): + } + } +} + +// This function assumes that the expected root CA has root rotation. This is intended to be used by +// `reconcileNodeRootsAndCerts`, which uses the root CA from the `lastSeenClusterRootCA`, and checks +// that it has a root rotation before calling this function. +func (r *rootRotationReconciler) finishRootRotation(tx store.Tx, expectedRootCA *api.RootCA) error { + cluster := store.GetCluster(tx, r.clusterID) + if cluster == nil { + return fmt.Errorf("unable to get cluster %s", r.clusterID) + } + + // If the RootCA object has changed (because another root rotation was started or because some other node + // had finished the root rotation), we cannot finish the root rotation that we were working on. 
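+	// Returning a dedicated sentinel error lets the reconciliation loop distinguish this benign
+	// case from a real store failure and abort early rather than retrying.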
+	if !equality.RootCAEqualStable(expectedRootCA, &cluster.RootCA) {
+		return errRootRotationChanged
+	}
+
+	var signerCert []byte
+	if len(cluster.RootCA.RootRotation.CAKey) > 0 {
+		signerCert = cluster.RootCA.RootRotation.CACert
+	}
+	// we don't actually have to parse out the default node expiration from the cluster - we are just using
+	// the ca.RootCA object to generate new tokens and the digest
+	updatedRootCA, err := NewRootCA(cluster.RootCA.RootRotation.CACert, signerCert, cluster.RootCA.RootRotation.CAKey,
+		DefaultNodeCertExpiration, nil)
+	if err != nil {
+		return errors.Wrap(err, "invalid cluster root rotation object")
+	}
+	cluster.RootCA = api.RootCA{
+		CACert:     cluster.RootCA.RootRotation.CACert,
+		CAKey:      cluster.RootCA.RootRotation.CAKey,
+		CACertHash: updatedRootCA.Digest.String(),
+		JoinTokens: api.JoinTokens{
+			Worker:  GenerateJoinToken(&updatedRootCA),
+			Manager: GenerateJoinToken(&updatedRootCA),
+		},
+		LastForcedRotation: cluster.RootCA.LastForcedRotation,
+	}
+	return store.UpdateCluster(tx, cluster)
+}
+
+func (r *rootRotationReconciler) batchUpdateNodes(toUpdate []*api.Node) error {
+	if len(toUpdate) == 0 {
+		return nil
+	}
+	_, err := r.store.Batch(func(batch *store.Batch) error {
+		// Directly update the nodes rather than get + update, and ignore version errors.  Since
+		// `rootRotationReconciler` should be hooked up to all node update/delete/create events, we should have
+		// close to the latest versions of all the nodes.  If not, the node will be updated later and the
+		// next batch of updates should catch it.
+		for _, n := range toUpdate {
+			if err := batch.Update(func(tx store.Tx) error {
+				return store.UpdateNode(tx, n)
+			}); err != nil && err != store.ErrSequenceConflict {
+				log.G(r.ctx).WithError(err).Errorf("unable to update node %s to request a certificate rotation", n.ID)
+			}
+		}
+		return nil
+	})
+	return err
+}
diff --git a/ca/server.go b/ca/server.go
index f3fa9cb6c8..982b0e2e83 100644
--- a/ca/server.go
+++ b/ca/server.go
@@ -22,6 +22,7 @@ import (
 
 const (
 	defaultReconciliationRetryInterval = 10 * time.Second
+	defaultRootReconciliationInterval  = 3 * time.Second
 )
 
 // APISecurityConfigUpdater knows how to update a SecurityConfig from an api.Cluster object
@@ -63,6 +64,10 @@ type Server struct {
 	// before we update the security config with the new root CA, we need to be able to save the root certs
 	rootPaths CertPaths
+
+	// lets us monitor and finish root rotations
+	rootReconciler                  *rootRotationReconciler
+	rootReconciliationRetryInterval time.Duration
 }
 
 // DefaultCAConfig returns the default CA Config, with a default expiration.
@@ -75,12 +80,13 @@ func DefaultCAConfig() api.CAConfig {
 
 // NewServer creates a CA API server. 
func NewServer(store *store.MemoryStore, securityConfig *SecurityConfig, rootCAPaths CertPaths) *Server { return &Server{ - store: store, - securityConfig: securityConfig, - pending: make(map[string]*api.Node), - started: make(chan struct{}), - reconciliationRetryInterval: defaultReconciliationRetryInterval, - rootPaths: rootCAPaths, + store: store, + securityConfig: securityConfig, + pending: make(map[string]*api.Node), + started: make(chan struct{}), + reconciliationRetryInterval: defaultReconciliationRetryInterval, + rootReconciliationRetryInterval: defaultRootReconciliationInterval, + rootPaths: rootCAPaths, } } @@ -90,6 +96,12 @@ func (s *Server) SetReconciliationRetryInterval(reconciliationRetryInterval time s.reconciliationRetryInterval = reconciliationRetryInterval } +// SetRootReconciliationInterval changes the time interval between root rotation +// reconciliation attempts. This function must be called before Run. +func (s *Server) SetRootReconciliationInterval(interval time.Duration) { + s.rootReconciliationRetryInterval = interval +} + // GetUnlockKey is responsible for returning the current unlock key used for encrypting TLS private keys and // other at rest data. Access to this RPC call should only be allowed via mutual TLS from managers. func (s *Server) GetUnlockKey(ctx context.Context, request *api.GetUnlockKeyRequest) (*api.GetUnlockKeyResponse, error) { @@ -395,14 +407,28 @@ func (s *Server) Run(ctx context.Context) error { return errors.New("CA signer is already running") } s.wg.Add(1) + s.ctx, s.cancel = context.WithCancel(log.WithModule(ctx, "ca")) + ctx = s.ctx + // we need to set it on the server, because `Server.UpdateRootCA` can be called from outside the Run function + s.rootReconciler = &rootRotationReconciler{ + ctx: log.WithField(ctx, "method", "(*Server).rootRotationReconciler"), + clusterID: s.securityConfig.ClientTLSCreds.Organization(), + store: s.store, + batchUpdateInterval: s.rootReconciliationRetryInterval, + } + rootReconciler := s.rootReconciler s.mu.Unlock() - defer s.wg.Done() - ctx = log.WithModule(ctx, "ca") + defer func() { + s.mu.Lock() + s.rootReconciler = nil + s.mu.Unlock() + }() // Retrieve the channels to keep track of changes in the cluster // Retrieve all the currently registered nodes var nodes []*api.Node + updates, cancel, err := store.ViewAndWatch( s.store, func(readTx store.ReadTx) error { @@ -419,13 +445,12 @@ func (s *Server) Run(ctx context.Context) error { }, api.EventCreateNode{}, api.EventUpdateNode{}, + api.EventDeleteNode{}, ) // Do this after updateCluster has been called, so isRunning never // returns true without joinTokens being set correctly. s.mu.Lock() - s.ctx, s.cancel = context.WithCancel(ctx) - ctx = s.ctx close(s.started) s.mu.Unlock() @@ -464,13 +489,18 @@ func (s *Server) Run(ctx context.Context) error { switch v := event.(type) { case api.EventCreateNode: s.evaluateAndSignNodeCert(ctx, v.Node) + rootReconciler.UpdateNode(v.Node) case api.EventUpdateNode: // If this certificate is already at a final state // no need to evaluate and sign it. 
if !isFinalState(v.Node.Certificate.Status) { s.evaluateAndSignNodeCert(ctx, v.Node) } + rootReconciler.UpdateNode(v.Node) + case api.EventDeleteNode: + rootReconciler.DeleteNode(v.Node) } + case <-ticker.C: for _, node := range s.pending { if err := s.evaluateAndSignNodeCert(ctx, node); err != nil { @@ -541,12 +571,16 @@ func (s *Server) isRunning() bool { func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { s.mu.Lock() s.joinTokens = cluster.RootCA.JoinTokens.Copy() + reconciler := s.rootReconciler s.mu.Unlock() + rCA := cluster.RootCA.Copy() + if reconciler != nil { + reconciler.UpdateRootCA(rCA) + } s.secConfigMu.Lock() defer s.secConfigMu.Unlock() - rCA := cluster.RootCA - rootCAChanged := len(rCA.CACert) != 0 && !equality.RootCAEqualStable(s.lastSeenClusterRootCA, &cluster.RootCA) + rootCAChanged := len(rCA.CACert) != 0 && !equality.RootCAEqualStable(s.lastSeenClusterRootCA, rCA) externalCAChanged := !equality.ExternalCAsEqualStable(s.lastSeenExternalCAs, cluster.Spec.CAConfig.ExternalCAs) logger := log.G(ctx).WithFields(logrus.Fields{ "cluster.id": cluster.ID, @@ -581,7 +615,6 @@ func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { if signingKey == nil { signingCert = nil } - updatedRootCA, err := NewRootCA(rCA.CACert, signingCert, signingKey, expiry, intermediates) if err != nil { return errors.Wrap(err, "invalid Root CA object in cluster") @@ -604,7 +637,7 @@ func (s *Server) UpdateRootCA(ctx context.Context, cluster *api.Cluster) error { } // only update the server cache if we've successfully updated the root CA logger.Debug("Root CA updated successfully") - s.lastSeenClusterRootCA = cluster.RootCA.Copy() + s.lastSeenClusterRootCA = rCA } // we want to update if the external CA changed, or if the root CA changed because the root CA could affect what diff --git a/ca/server_test.go b/ca/server_test.go index c80279665b..ce53f16170 100644 --- a/ca/server_test.go +++ b/ca/server_test.go @@ -6,19 +6,23 @@ import ( "fmt" "io/ioutil" "os" + "path/filepath" + "reflect" "testing" "time" - "golang.org/x/net/context" - "github.com/cloudflare/cfssl/helpers" "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/api/equality" "github.com/docker/swarmkit/ca" cautils "github.com/docker/swarmkit/ca/testutils" "github.com/docker/swarmkit/manager/state/store" "github.com/docker/swarmkit/testutils" + "github.com/opencontainers/go-digest" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) @@ -551,3 +555,807 @@ func TestCAServerUpdateRootCA(t *testing.T) { tc.CAServer.UpdateRootCA(context.Background(), fakeClusterSpec(cautils.ECDSA256SHA256Cert, cautils.ECDSA256Key, nil, nil)) require.Equal(t, tc.RootCA.Certs, tc.ServingSecurityConfig.RootCA().Certs) } + +type rootRotationTester struct { + tc *cautils.TestCA + t *testing.T +} + +// go through all the nodes and update/create the ones we want, and delete the ones +// we don't +func (r *rootRotationTester) convergeWantedNodes(wantNodes map[string]*api.Node, descr string) { + // update existing and create new nodes first before deleting nodes, else a root rotation + // may finish early if all the nodes get deleted when the root rotation happens + require.NoError(r.t, r.tc.MemoryStore.Update(func(tx store.Tx) error { + for nodeID, wanted := range wantNodes { + node := store.GetNode(tx, nodeID) + if node == nil { + if err := store.CreateNode(tx, wanted); err 
!= nil { + return err + } + continue + } + node.Description = wanted.Description + node.Certificate = wanted.Certificate + if err := store.UpdateNode(tx, node); err != nil { + return err + } + } + nodes, err := store.FindNodes(tx, store.All) + if err != nil { + return err + } + for _, node := range nodes { + if _, inWanted := wantNodes[node.ID]; !inWanted { + if err := store.DeleteNode(tx, node.ID); err != nil { + return err + } + } + } + return nil + }), descr) +} + +func (r *rootRotationTester) convergeRootCA(wantRootCA *api.RootCA, descr string) { + require.NoError(r.t, r.tc.MemoryStore.Update(func(tx store.Tx) error { + clusters, err := store.FindClusters(tx, store.All) + if err != nil || len(clusters) != 1 { + return errors.Wrap(err, "unable to find cluster") + } + clusters[0].RootCA = *wantRootCA + return store.UpdateCluster(tx, clusters[0]) + }), descr) +} + +func getFakeAPINode(t *testing.T, id string, state api.IssuanceStatus_State, tlsInfo *api.NodeTLSInfo, member bool) *api.Node { + node := &api.Node{ + ID: id, + Certificate: api.Certificate{ + Status: api.IssuanceStatus{ + State: state, + }, + }, + Spec: api.NodeSpec{ + Membership: api.NodeMembershipAccepted, + }, + } + if !member { + node.Spec.Membership = api.NodeMembershipPending + } + // the CA server will immediately pick these up, so generate CSRs for the CA server to sign + if state == api.IssuanceStateRenew || state == api.IssuanceStatePending { + csr, _, err := ca.GenerateNewCSR() + require.NoError(t, err) + node.Certificate.CSR = csr + } + if tlsInfo != nil { + node.Description = &api.NodeDescription{TLSInfo: tlsInfo} + } + return node +} + +func startCAServer(caServer *ca.Server) { + alreadyRunning := make(chan struct{}) + go func() { + if err := caServer.Run(context.Background()); err != nil { + close(alreadyRunning) + } + }() + select { + case <-caServer.Ready(): + case <-alreadyRunning: + } +} + +func getRotationInfo(t *testing.T, rotationCert []byte, rootCA *ca.RootCA) ([]byte, *api.NodeTLSInfo) { + parsedNewRoot, err := helpers.ParseCertificatePEM(rotationCert) + require.NoError(t, err) + crossSigned, err := rootCA.CrossSignCACertificate(rotationCert) + require.NoError(t, err) + return crossSigned, &api.NodeTLSInfo{ + TrustRoot: rootCA.Certs, + CertIssuerPublicKey: parsedNewRoot.RawSubjectPublicKeyInfo, + CertIssuerSubject: parsedNewRoot.RawSubject, + } +} + +// These are the root rotation test cases where we expect there to be a change in the FindNodes +// or root CA values after converging. 
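+// Each case sets the cluster to a given starting combination of nodes and root CA, then polls until the
+// reconciler has driven the store to the expected end state.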
+func TestRootRotationReconciliationWithChanges(t *testing.T) {
+	t.Parallel()
+	if cautils.External {
+		// the external CA functionality is unrelated to testing the reconciliation loop
+		return
+	}
+
+	tc := cautils.NewTestCA(t)
+	defer tc.Stop()
+	rt := rootRotationTester{
+		tc: tc,
+		t:  t,
+	}
+
+	rotationCerts := [][]byte{cautils.ECDSA256SHA256Cert, cautils.ECDSACertChain[2]}
+	rotationKeys := [][]byte{cautils.ECDSA256Key, cautils.ECDSACertChainKeys[2]}
+	var (
+		rotationCrossSigned [][]byte
+		rotationTLSInfo     []*api.NodeTLSInfo
+	)
+	for _, cert := range rotationCerts {
+		cross, info := getRotationInfo(t, cert, &tc.RootCA)
+		rotationCrossSigned = append(rotationCrossSigned, cross)
+		rotationTLSInfo = append(rotationTLSInfo, info)
+	}
+
+	oldNodeTLSInfo := &api.NodeTLSInfo{
+		TrustRoot:           tc.RootCA.Certs,
+		CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey,
+		CertIssuerSubject:   tc.ServingSecurityConfig.IssuerInfo().Subject,
+	}
+
+	var startCluster *api.Cluster
+	tc.MemoryStore.View(func(tx store.ReadTx) {
+		startCluster = store.GetCluster(tx, tc.Organization)
+	})
+	require.NotNil(t, startCluster)
+
+	testcases := []struct {
+		nodes           map[string]*api.Node // what nodes we should start with
+		rootCA          *api.RootCA          // what root CA we should start with
+		expectedNodes   map[string]*api.Node // what nodes we expect in the end, if nil, then unchanged from the start
+		expectedRootCA  *api.RootCA          // what root CA we expect in the end, if nil, then unchanged from the start
+		caServerRestart bool                 // whether to stop the CA server before making the node and root changes and restart after
+		descr           string
+	}{
+		{
+			descr: ("If there is no TLS info, the reconciliation cycle tells the nodes to rotate if they're not already getting " +
+				"a new cert.  Any renew/pending nodes will have certs issued, but because the TLS info is nil, they will " +
+				`go into the "rotate" state`),
+			nodes: map[string]*api.Node{
+				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
+				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, nil, true),
+				"2": getFakeAPINode(t, "2", api.IssuanceStateRenew, nil, true),
+				"3": getFakeAPINode(t, "3", api.IssuanceStateRotate, nil, true),
+				"4": getFakeAPINode(t, "4", api.IssuanceStatePending, nil, true),
+				"5": getFakeAPINode(t, "5", api.IssuanceStateFailed, nil, true),
+			},
+			rootCA: &api.RootCA{
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCerts[0],
+					CAKey:             rotationKeys[0],
+					CrossSignedCACert: rotationCrossSigned[0],
+				},
+			},
+			expectedNodes: map[string]*api.Node{
+				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
+				"1": getFakeAPINode(t, "1", api.IssuanceStateRotate, nil, true),
+				"2": getFakeAPINode(t, "2", api.IssuanceStateRotate, nil, true),
+				"3": getFakeAPINode(t, "3", api.IssuanceStateRotate, nil, true),
+				"4": getFakeAPINode(t, "4", api.IssuanceStateRotate, nil, true),
+				"5": getFakeAPINode(t, "5", api.IssuanceStateRotate, nil, true),
+			},
+		},
+		{
+			descr: ("Assume all of the nodes have gotten certs, but some of them are the wrong cert " +
+				"(going by the TLS info), which shouldn't really happen. 
the rotation reconciliation " + + "will tell the wrong ones to rotate a second time"), + nodes: map[string]*api.Node{ + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "2": getFakeAPINode(t, "2", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, oldNodeTLSInfo, true), + }, + rootCA: &api.RootCA{ // no change in root CA from previous + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCerts[0], + CAKey: rotationKeys[0], + CrossSignedCACert: rotationCrossSigned[0], + }, + }, + expectedNodes: map[string]*api.Node{ + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateRotate, oldNodeTLSInfo, true), + }, + }, + { + descr: ("New nodes that are added will also be picked up and told to rotate"), + nodes: map[string]*api.Node{ + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateRenew, nil, true), + }, + rootCA: &api.RootCA{ // no change in root CA from previous + CACert: startCluster.RootCA.CACert, + CAKey: startCluster.RootCA.CAKey, + CACertHash: startCluster.RootCA.CACertHash, + RootRotation: &api.RootRotation{ + CACert: rotationCerts[0], + CAKey: rotationKeys[0], + CrossSignedCACert: rotationCrossSigned[0], + }, + }, + expectedNodes: map[string]*api.Node{ + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "6": getFakeAPINode(t, "6", api.IssuanceStateRotate, nil, true), + }, + }, + { + descr: ("Even if root rotation isn't finished, if the root changes again to a " + + "different cert, all the nodes with the old root rotation cert will be told " + + "to rotate again."), + nodes: map[string]*api.Node{ + "0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false), + "1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true), + "4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[0], true), + "5": getFakeAPINode(t, "5", api.IssuanceStateIssued, oldNodeTLSInfo, true), + "6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[0], 
true),
+			},
+			rootCA: &api.RootCA{ // new root rotation
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCerts[1],
+					CAKey:             rotationKeys[1],
+					CrossSignedCACert: rotationCrossSigned[1],
+				},
+			},
+			expectedNodes: map[string]*api.Node{
+				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
+				"1": getFakeAPINode(t, "1", api.IssuanceStateRotate, rotationTLSInfo[0], true),
+				"3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"4": getFakeAPINode(t, "4", api.IssuanceStateRotate, rotationTLSInfo[0], true),
+				"5": getFakeAPINode(t, "5", api.IssuanceStateRotate, oldNodeTLSInfo, true),
+				"6": getFakeAPINode(t, "6", api.IssuanceStateRotate, rotationTLSInfo[0], true),
+			},
+		},
+		{
+			descr: ("Once all nodes have rotated to their desired TLS info (even if it's because " +
+				"a node with the wrong TLS info has been removed), the root rotation is completed."),
+			nodes: map[string]*api.Node{
+				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
+				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+			},
+			rootCA: &api.RootCA{
+				// no change in root CA from previous - even if the root rotation got completed after
+				// the nodes were first set, and we just re-add the same root rotation because of this
+				// test ordering, it will be completed again anyway because the TLS info is correct
+				// for all nodes
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCerts[1],
+					CAKey:             rotationKeys[1],
+					CrossSignedCACert: rotationCrossSigned[1],
+				},
+			},
+			expectedRootCA: &api.RootCA{
+				CACert:     rotationCerts[1],
+				CAKey:      rotationKeys[1],
+				CACertHash: digest.FromBytes(rotationCerts[1]).String(),
+				// ignore the join tokens - we aren't comparing them
+			},
+		},
+		{
+			descr: ("If a root rotation happens when the CA server is down, so long as it saw the change " +
+				"it will start reconciling the nodes as soon as it's started up again"),
+			caServerRestart: true,
+			nodes: map[string]*api.Node{
+				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
+				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"3": getFakeAPINode(t, "3", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"4": getFakeAPINode(t, "4", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+				"6": getFakeAPINode(t, "6", api.IssuanceStateIssued, rotationTLSInfo[1], true),
+			},
+			rootCA: &api.RootCA{
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCerts[0],
+					CAKey:             rotationKeys[0],
+					CrossSignedCACert: rotationCrossSigned[0],
+				},
+			},
+			expectedNodes: map[string]*api.Node{
+				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
+				"1": getFakeAPINode(t, "1", api.IssuanceStateRotate, rotationTLSInfo[1], true),
+				"3": getFakeAPINode(t, "3", api.IssuanceStateRotate, rotationTLSInfo[1], true),
+				"4": getFakeAPINode(t, "4", api.IssuanceStateRotate, rotationTLSInfo[1], true),
+				"6": getFakeAPINode(t, "6",
api.IssuanceStateRotate, rotationTLSInfo[1], true), + }, + }, + } + + for _, testcase := range testcases { + if testcase.caServerRestart { + rt.tc.CAServer.Stop() + } + + rt.convergeRootCA(testcase.rootCA, testcase.descr) + rt.convergeWantedNodes(testcase.nodes, testcase.descr) + + if testcase.expectedNodes == nil { + testcase.expectedNodes = testcase.nodes + } + if testcase.expectedRootCA == nil { + testcase.expectedRootCA = testcase.rootCA + } + + if testcase.caServerRestart { + startCAServer(rt.tc.CAServer) + } + + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + var ( + nodes []*api.Node + cluster *api.Cluster + err error + ) + tc.MemoryStore.View(func(tx store.ReadTx) { + nodes, err = store.FindNodes(tx, store.All) + cluster = store.GetCluster(tx, tc.Organization) + }) + if err != nil { + return err + } + if cluster == nil { + return errors.New("no cluster found") + } + + if !equality.RootCAEqualStable(&cluster.RootCA, testcase.expectedRootCA) { + return fmt.Errorf("root CAs not equal:\n\texpected: %v\n\tactual: %v", *testcase.expectedRootCA, cluster.RootCA) + } + if len(nodes) != len(testcase.expectedNodes) { + return fmt.Errorf("number of expected nodes (%d) does not equal number of actual nodes (%d)", + len(testcase.expectedNodes), len(nodes)) + } + for _, node := range nodes { + expected, ok := testcase.expectedNodes[node.ID] + if !ok { + return fmt.Errorf("node %s is present and was unexpected", node.ID) + } + if !reflect.DeepEqual(expected.Description, node.Description) { + return fmt.Errorf("the node description of node %s is not expected:\n\texpected: %v\n\tactual: %v", node.ID, + expected.Description, node.Description) + } + if !reflect.DeepEqual(expected.Certificate.Status, node.Certificate.Status) { + return fmt.Errorf("the certificate status of node %s is not expected:\n\texpected: %v\n\tactual: %v", node.ID, + expected.Certificate, node.Certificate) + } + + // ensure that the security config's root CA object has the same expected key + expectedKey := testcase.expectedRootCA.CAKey + if testcase.expectedRootCA.RootRotation != nil { + expectedKey = testcase.expectedRootCA.RootRotation.CAKey + } + s, err := rt.tc.ServingSecurityConfig.RootCA().Signer() + if err != nil { + return err + } + if !bytes.Equal(s.Key, expectedKey) { + return fmt.Errorf("the security config has not been updated correctly") + } + } + return nil + }, 5*time.Second), testcase.descr) + } +} + +// These are the root rotation test cases where we expect there to be no changes made to either +// the nodes or the root CA object +func TestRootRotationReconciliationNoChanges(t *testing.T) { + t.Parallel() + if cautils.External { + // the external CA functionality is unrelated to testing the reconciliation loop + return + } + + tc := cautils.NewTestCA(t) + defer tc.Stop() + rt := rootRotationTester{ + tc: tc, + t: t, + } + + rotationCert := cautils.ECDSA256SHA256Cert + rotationKey := cautils.ECDSA256Key + rotationCrossSigned, rotationTLSInfo := getRotationInfo(t, rotationCert, &tc.RootCA) + + oldNodeTLSInfo := &api.NodeTLSInfo{ + TrustRoot: tc.RootCA.Certs, + CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey, + CertIssuerSubject: tc.ServingSecurityConfig.IssuerInfo().Subject, + } + + var startCluster *api.Cluster + tc.MemoryStore.View(func(tx store.ReadTx) { + startCluster = store.GetCluster(tx, tc.Organization) + }) + require.NotNil(t, startCluster) + + testcases := []struct { + nodes map[string]*api.Node // what nodes we should start with + rootCA *api.RootCA // what 
root CA we should start with
+		descr           string
+		caServerStopped bool // a reconciliation loop only runs while the CA server is running
+	}{
+		{
+			descr: ("If the CA server is not running, no reconciliation happens even if a root rotation " +
+				"is in progress"),
+			caServerStopped: true,
+			nodes: map[string]*api.Node{
+				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
+				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, oldNodeTLSInfo, true),
+				"2": getFakeAPINode(t, "2", api.IssuanceStateRenew, nil, true),
+				"3": getFakeAPINode(t, "3", api.IssuanceStateRotate, nil, true),
+				"4": getFakeAPINode(t, "4", api.IssuanceStatePending, nil, true),
+				"5": getFakeAPINode(t, "5", api.IssuanceStateFailed, nil, true),
+			},
+			rootCA: &api.RootCA{
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCert,
+					CAKey:             rotationKey,
+					CrossSignedCACert: rotationCrossSigned,
+				},
+			},
+		},
+		{
+			descr: ("If all nodes have the right TLS info or are already rotated (or are not members), " +
+				"there will be no changes needed"),
+			nodes: map[string]*api.Node{
+				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
+				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo, true),
+				"2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true),
+				"3": getFakeAPINode(t, "3", api.IssuanceStateRotate, rotationTLSInfo, true),
+			},
+			rootCA: &api.RootCA{ // no change in root CA from previous
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+				RootRotation: &api.RootRotation{
+					CACert:            rotationCert,
+					CAKey:             rotationKey,
+					CrossSignedCACert: rotationCrossSigned,
+				},
+			},
+		},
+		{
+			descr: ("Nodes already in rotate state, even if they currently have the correct TLS issuer, will be " +
+				"left in the rotate state even if root rotation is aborted because we don't know if they're already " +
+				"in the process of getting a new cert. 
Even if they're issued by a different issuer, they will be " +
+				"left alone because they'll have an intermediate that chains up to the old issuer."),
+			nodes: map[string]*api.Node{
+				"0": getFakeAPINode(t, "0", api.IssuanceStatePending, nil, false),
+				"1": getFakeAPINode(t, "1", api.IssuanceStateIssued, rotationTLSInfo, true),
+				"2": getFakeAPINode(t, "2", api.IssuanceStateRotate, oldNodeTLSInfo, true),
+			},
+			rootCA: &api.RootCA{ // no root rotation - the previous rotation has been aborted
+				CACert:     startCluster.RootCA.CACert,
+				CAKey:      startCluster.RootCA.CAKey,
+				CACertHash: startCluster.RootCA.CACertHash,
+			},
+		},
+	}
+
+	for _, testcase := range testcases {
+		if testcase.caServerStopped {
+			rt.tc.CAServer.Stop()
+		} else {
+			startCAServer(rt.tc.CAServer)
+		}
+
+		rt.convergeRootCA(testcase.rootCA, testcase.descr)
+		rt.convergeWantedNodes(testcase.nodes, testcase.descr)
+
+		time.Sleep(500 * time.Millisecond)
+
+		var (
+			nodes   []*api.Node
+			cluster *api.Cluster
+			err     error
+		)
+
+		tc.MemoryStore.View(func(tx store.ReadTx) {
+			nodes, err = store.FindNodes(tx, store.All)
+			cluster = store.GetCluster(tx, tc.Organization)
+		})
+		require.NoError(t, err)
+		require.NotNil(t, cluster)
+		require.Equal(t, cluster.RootCA, *testcase.rootCA, testcase.descr)
+		require.Len(t, nodes, len(testcase.nodes), testcase.descr)
+		for _, node := range nodes {
+			expected, ok := testcase.nodes[node.ID]
+			require.True(t, ok, "node %s: %s", node.ID, testcase.descr)
+			require.Equal(t, expected.Description, node.Description, "node %s: %s", node.ID, testcase.descr)
+			require.Equal(t, expected.Certificate.Status, node.Certificate.Status, "node %s: %s", node.ID, testcase.descr)
+		}
+
+		// ensure that the security config's root CA object has the same expected key
+		expectedKey := testcase.rootCA.CAKey
+		if testcase.rootCA.RootRotation != nil {
+			expectedKey = testcase.rootCA.RootRotation.CAKey
+		}
+		s, err := rt.tc.ServingSecurityConfig.RootCA().Signer()
+		require.NoError(t, err, testcase.descr)
+		require.Equal(t, s.Key, expectedKey, testcase.descr)
+	}
+}
+
+// Tests that, if the root rotation changes while the reconciliation loop is running, the root rotation will
+// eventually finish successfully (even if there's a competing reconciliation loop, for instance if there's a
+// bug during leadership handoff).
+func TestRootRotationReconciliationRace(t *testing.T) {
+	t.Parallel()
+	if cautils.External {
+		// the external CA functionality is unrelated to testing the reconciliation loop
+		return
+	}
+
+	tc := cautils.NewTestCA(t)
+	defer tc.Stop()
+	rt := rootRotationTester{
+		tc: tc,
+		t:  t,
+	}
+
+	tempDir, err := ioutil.TempDir("", "competing-ca-server")
+	require.NoError(t, err)
+	defer os.RemoveAll(tempDir)
+
+	var otherServers []*ca.Server
+	var secConfigs []*ca.SecurityConfig
+	for i := 0; i < 3; i++ { // to make sure we get some collisions
+		// start a competing CA server
+		competingSecConfig, err := tc.NewNodeConfig(ca.ManagerRole)
+		require.NoError(t, err)
+		secConfigs = append(secConfigs, competingSecConfig)
+
+		paths := ca.NewConfigPaths(filepath.Join(tempDir, fmt.Sprintf("%d", i)))
+
+		otherServer := ca.NewServer(tc.MemoryStore, competingSecConfig, paths.RootCA)
+		// offset each server's reconciliation interval somewhat so that some will
+		// pre-empt others
+		otherServer.SetRootReconciliationInterval(time.Millisecond * time.Duration((i+1)*10))
+		startCAServer(otherServer)
+		defer otherServer.Stop()
+		otherServers = append(otherServers, otherServer)
+	}
+	clusterWatch, clusterWatchCancel, err := store.ViewAndWatch(
+		tc.MemoryStore, func(tx store.ReadTx) error {
+			// don't bother getting the cluster - the CA servers have already done that when first running
+			return nil
+		},
+		api.EventUpdateCluster{
+			Cluster: &api.Cluster{ID: tc.Organization},
+			Checks:  []api.ClusterCheckFunc{api.ClusterCheckID},
+		},
+	)
+	require.NoError(t, err)
+	defer clusterWatchCancel()
+
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		for {
+			select {
+			case event := <-clusterWatch:
+				clusterEvent := event.(api.EventUpdateCluster)
+				for _, s := range otherServers {
+					s.UpdateRootCA(context.Background(), clusterEvent.Cluster)
+				}
+			case <-done:
+				return
+			}
+		}
+	}()
+
+	oldNodeTLSInfo := &api.NodeTLSInfo{
+		TrustRoot:           tc.RootCA.Certs,
+		CertIssuerPublicKey: tc.ServingSecurityConfig.IssuerInfo().PublicKey,
+		CertIssuerSubject:   tc.ServingSecurityConfig.IssuerInfo().Subject,
+	}
+
+	nodes := make(map[string]*api.Node)
+	for i := 0; i < 5; i++ {
+		nodeID := fmt.Sprintf("%d", i)
+		nodes[nodeID] = getFakeAPINode(t, nodeID, api.IssuanceStateIssued, oldNodeTLSInfo, true)
+	}
+	rt.convergeWantedNodes(nodes, "setting up nodes for root rotation race condition test")
+
+	var rotationCert, rotationKey []byte
+	for i := 0; i < 10; i++ {
+		var (
+			rotationCrossSigned []byte
+			rotationTLSInfo     *api.NodeTLSInfo
+		)
+		rotationCert, rotationKey, err = cautils.CreateRootCertAndKey(fmt.Sprintf("root cn %d", i))
+		require.NoError(t, err)
+		require.NoError(t, tc.MemoryStore.Update(func(tx store.Tx) error {
+			cluster := store.GetCluster(tx, tc.Organization)
+			if cluster == nil {
+				return errors.New("cluster has disappeared")
+			}
+			rootCA := cluster.RootCA.Copy()
+			caRootCA, err := ca.NewRootCA(rootCA.CACert, rootCA.CACert, rootCA.CAKey, ca.DefaultNodeCertExpiration, nil)
+			if err != nil {
+				return err
+			}
+			rotationCrossSigned, rotationTLSInfo = getRotationInfo(t, rotationCert, &caRootCA)
+			rootCA.RootRotation = &api.RootRotation{
+				CACert:            rotationCert,
+				CAKey:             rotationKey,
+				CrossSignedCACert: rotationCrossSigned,
+			}
+			cluster.RootCA = *rootCA
+			return store.UpdateCluster(tx, cluster)
+		}))
+		for _, node := range nodes {
+			node.Description.TLSInfo = rotationTLSInfo
+		}
+		rt.convergeWantedNodes(nodes, fmt.Sprintf("iteration %d", i))
+	}
+
+	require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error {
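+		// this poll succeeds only once the rotation has fully converged: the RootRotation field is
+		// cleared, the cluster's cert and key match the last requested rotation, and every competing
+		// CA server's signer has caught up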
+		var cluster *api.Cluster
+		tc.MemoryStore.View(func(tx store.ReadTx) {
+			cluster = store.GetCluster(tx, tc.Organization)
+		})
+		if cluster == nil {
+			return errors.New("cluster has disappeared")
+		}
+		if cluster.RootCA.RootRotation != nil {
+			return errors.New("root rotation is still present")
+		}
+		if !bytes.Equal(cluster.RootCA.CACert, rotationCert) {
+			return errors.New("expected root cert is wrong")
+		}
+		if !bytes.Equal(cluster.RootCA.CAKey, rotationKey) {
+			return errors.New("expected root key is wrong")
+		}
+		for _, secConfig := range secConfigs {
+			s, err := secConfig.RootCA().Signer()
+			if err != nil {
+				return err
+			}
+			if !bytes.Equal(s.Key, rotationKey) {
+				return errors.New("all the sec configs haven't been updated yet")
+			}
+		}
+		return nil
+	}, 5*time.Second))
+
+	// all of the ca servers have the appropriate cert and key
+}
+
+// If there are a lot of nodes, we only update a small number of them at once.
+func TestRootRotationReconciliationThrottled(t *testing.T) {
+	t.Parallel()
+	if cautils.External {
+		// the external CA functionality is unrelated to testing the reconciliation loop
+		return
+	}
+
+	tc := cautils.NewTestCA(t)
+	defer tc.Stop()
+	// immediately stop the CA server - we want to run our own
+	tc.CAServer.Stop()
+
+	caServer := ca.NewServer(tc.MemoryStore, tc.ServingSecurityConfig, tc.Paths.RootCA)
+	// set the reconciliation interval to something ridiculous, so that only the first batch
+	// of updates will run during this test
+	caServer.SetRootReconciliationInterval(time.Hour)
+	startCAServer(caServer)
+	defer caServer.Stop()
+
+	var nodes []*api.Node
+	clusterWatch, clusterWatchCancel, err := store.ViewAndWatch(
+		tc.MemoryStore, func(tx store.ReadTx) error {
+			// don't bother getting the cluster - the CA server has already done that when first running
+			var err error
+			nodes, err = store.FindNodes(tx, store.ByMembership(api.NodeMembershipAccepted))
+			return err
+		},
+		api.EventUpdateCluster{
+			Cluster: &api.Cluster{ID: tc.Organization},
+			Checks:  []api.ClusterCheckFunc{api.ClusterCheckID},
+		},
+	)
+	require.NoError(t, err)
+	defer clusterWatchCancel()
+
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		for {
+			select {
+			case event := <-clusterWatch:
+				clusterEvent := event.(api.EventUpdateCluster)
+				caServer.UpdateRootCA(context.Background(), clusterEvent.Cluster)
+			case <-done:
+				return
+			}
+		}
+	}()
+
+	// create twice the batch size of nodes
+	_, err = tc.MemoryStore.Batch(func(batch *store.Batch) error {
+		for i := len(nodes); i < ca.IssuanceStateRotateMaxBatchSize*2; i++ {
+			nodeID := fmt.Sprintf("%d", i)
+			err := batch.Update(func(tx store.Tx) error {
+				return store.CreateNode(tx, getFakeAPINode(t, nodeID, api.IssuanceStateIssued, nil, true))
+			})
+			if err != nil {
+				return err
+			}
+		}
+		return nil
+	})
+	require.NoError(t, err)
+
+	rotationCert := cautils.ECDSA256SHA256Cert
+	rotationKey := cautils.ECDSA256Key
+	rotationCrossSigned, _ := getRotationInfo(t, rotationCert, &tc.RootCA)
+
+	require.NoError(t, tc.MemoryStore.Update(func(tx store.Tx) error {
+		cluster := store.GetCluster(tx, tc.Organization)
+		if cluster == nil {
+			return errors.New("cluster has disappeared")
+		}
+		rootCA := cluster.RootCA.Copy()
+		rootCA.RootRotation = &api.RootRotation{
+			CACert:            rotationCert,
+			CAKey:             rotationKey,
+			CrossSignedCACert: rotationCrossSigned,
+		}
+		cluster.RootCA = *rootCA
+		return store.UpdateCluster(tx, cluster)
+	}))
+
+	checkRotationNumber := func() error {
+		tc.MemoryStore.View(func(tx store.ReadTx) {
+			nodes, err = store.FindNodes(tx,
store.All) + }) + var issuanceRotate int + for _, n := range nodes { + if n.Certificate.Status.State == api.IssuanceStateRotate { + issuanceRotate += 1 + } + } + if issuanceRotate != ca.IssuanceStateRotateMaxBatchSize { + return fmt.Errorf("expected %d, got %d", ca.IssuanceStateRotateMaxBatchSize, issuanceRotate) + } + return nil + } + + require.NoError(t, testutils.PollFuncWithTimeout(nil, checkRotationNumber, 5*time.Second)) + // prove that it's not just because the updates haven't finished + time.Sleep(time.Second) + require.NoError(t, checkRotationNumber()) +} diff --git a/ca/testutils/cautils.go b/ca/testutils/cautils.go index 61024320c2..532943181c 100644 --- a/ca/testutils/cautils.go +++ b/ca/testutils/cautils.go @@ -21,6 +21,7 @@ import ( "github.com/docker/swarmkit/connectionbroker" "github.com/docker/swarmkit/identity" "github.com/docker/swarmkit/ioutils" + "github.com/docker/swarmkit/log" "github.com/docker/swarmkit/manager/state/store" stateutils "github.com/docker/swarmkit/manager/state/testutils" "github.com/docker/swarmkit/remotes" @@ -178,6 +179,7 @@ func NewTestCAFromRootCA(t *testing.T, tempBaseDir string, rootCA ca.RootCA, krw caServer := ca.NewServer(s, managerConfig, paths.RootCA) caServer.SetReconciliationRetryInterval(50 * time.Millisecond) + caServer.SetRootReconciliationInterval(50 * time.Millisecond) api.RegisterCAServer(grpcServer, caServer) api.RegisterNodeCAServer(grpcServer, caServer) @@ -200,7 +202,9 @@ func NewTestCAFromRootCA(t *testing.T, tempBaseDir string, rootCA ca.RootCA, krw select { case event := <-clusterWatch: clusterEvent := event.(api.EventUpdateCluster) - caServer.UpdateRootCA(ctx, clusterEvent.Cluster) + if err := caServer.UpdateRootCA(ctx, clusterEvent.Cluster); err != nil { + log.G(ctx).WithError(err).Error("ca utils CA server could not update root CA") + } case <-ctx.Done(): clusterWatchCancel() return diff --git a/integration/api.go b/integration/api.go index b7a838eeb9..0f44037887 100644 --- a/integration/api.go +++ b/integration/api.go @@ -131,6 +131,12 @@ func (a *dummyAPI) ListClusters(ctx context.Context, r *api.ListClustersRequest) return cli.ListClusters(ctx, r) } -func (a *dummyAPI) UpdateCluster(context.Context, *api.UpdateClusterRequest) (*api.UpdateClusterResponse, error) { - panic("not implemented") +func (a *dummyAPI) UpdateCluster(ctx context.Context, r *api.UpdateClusterRequest) (*api.UpdateClusterResponse, error) { + ctx, cancel := context.WithTimeout(ctx, opsTimeout) + defer cancel() + cli, err := a.c.RandomManager().ControlClient(ctx) + if err != nil { + return nil, err + } + return cli.UpdateCluster(ctx, r) } diff --git a/integration/cluster.go b/integration/cluster.go index f56351d7e5..6bc00925d6 100644 --- a/integration/cluster.go +++ b/integration/cluster.go @@ -98,14 +98,11 @@ func (c *testCluster) AddManager(lateBind bool, rootCA *ca.RootCA) error { if err != nil { return err } - clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{}) + clusterInfo, err := c.GetClusterInfo() if err != nil { return err } - if len(clusterInfo.Clusters) == 0 { - return fmt.Errorf("joining manager: there is no cluster created in storage") - } - node, err := newTestNode(joinAddr, clusterInfo.Clusters[0].RootCA.JoinTokens.Manager, false, nil) + node, err := newTestNode(joinAddr, clusterInfo.RootCA.JoinTokens.Manager, false, nil) if err != nil { return err } @@ -137,14 +134,9 @@ func (c *testCluster) AddManager(lateBind bool, rootCA *ca.RootCA) error { if lateBind { // Verify that the control API works - 
clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{}) - if err != nil { + if _, err := c.GetClusterInfo(); err != nil { return err } - if len(clusterInfo.Clusters) == 0 { - return fmt.Errorf("joining manager: there is no cluster created in storage") - } - return n.node.BindRemote(context.Background(), "127.0.0.1:0", "") } @@ -162,14 +154,11 @@ func (c *testCluster) AddAgent() error { if err != nil { return err } - clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{}) + clusterInfo, err := c.GetClusterInfo() if err != nil { return err } - if len(clusterInfo.Clusters) == 0 { - return fmt.Errorf("joining agent: there is no cluster created in storage") - } - node, err := newTestNode(joinAddr, clusterInfo.Clusters[0].RootCA.JoinTokens.Worker, false, nil) + node, err := newTestNode(joinAddr, clusterInfo.RootCA.JoinTokens.Worker, false, nil) if err != nil { return err } @@ -383,3 +372,33 @@ func (c *testCluster) StartNode(id string) error { } return nil } + +func (c *testCluster) GetClusterInfo() (*api.Cluster, error) { + clusterInfo, err := c.api.ListClusters(context.Background(), &api.ListClustersRequest{}) + if err != nil { + return nil, err + } + if len(clusterInfo.Clusters) != 1 { + return nil, fmt.Errorf("number of clusters in storage: %d; expected 1", len(clusterInfo.Clusters)) + } + return clusterInfo.Clusters[0], nil +} + +func (c *testCluster) RotateRootCA(cert, key []byte) error { + // poll in case something else changes the cluster before we can update it + return testutils.PollFuncWithTimeout(nil, func() error { + clusterInfo, err := c.GetClusterInfo() + if err != nil { + return err + } + newSpec := clusterInfo.Spec.Copy() + newSpec.CAConfig.SigningCACert = cert + newSpec.CAConfig.SigningCAKey = key + _, err = c.api.UpdateCluster(context.Background(), &api.UpdateClusterRequest{ + ClusterID: clusterInfo.ID, + Spec: newSpec, + ClusterVersion: &clusterInfo.Meta.Version, + }) + return err + }, opsTimeout) +} diff --git a/integration/integration_test.go b/integration/integration_test.go index dd13f5667c..d521222ec8 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -1,6 +1,7 @@ package integration import ( + "bytes" "flag" "fmt" "io/ioutil" @@ -13,6 +14,8 @@ import ( "golang.org/x/net/context" + "reflect" + "github.com/Sirupsen/logrus" "github.com/cloudflare/cfssl/helpers" events "github.com/docker/go-events" @@ -99,6 +102,9 @@ func pollClusterReady(t *testing.T, c *testCluster, numWorker, numManager int) { return fmt.Errorf("worker node %s should not have manager status, returned %s", n.ID, n.ManagerStatus) } } + if n.Description.TLSInfo == nil { + return fmt.Errorf("node %s has not reported its TLS info yet", n.ID) + } } if !leaderFound { return fmt.Errorf("leader of cluster is not found") @@ -547,3 +553,168 @@ func TestForceNewCluster(t *testing.T) { require.NoError(t, ioutil.WriteFile(managerCertFile, expiredCertPEM, 0644)) require.Error(t, cl.StartNode(nodeID)) } + +func pollRootRotationDone(t *testing.T, cl *testCluster) { + require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error { + clusterInfo, err := cl.GetClusterInfo() + if err != nil { + return err + } + if clusterInfo.RootCA.RootRotation != nil { + return errors.New("root rotation not done") + } + return nil + }, opsTimeout)) +} + +func TestSuccessfulRootRotation(t *testing.T) { + t.Parallel() + numWorker, numManager := 2, 3 + cl := newCluster(t, numWorker, numManager) + defer func() { + require.NoError(t, 
cl.Stop())
+	}()
+	pollClusterReady(t, cl, numWorker, numManager)
+
+	// Take down one of the managers and both workers, so we can't actually ever finish root rotation.
+	resp, err := cl.api.ListNodes(context.Background(), &api.ListNodesRequest{})
+	require.NoError(t, err)
+	var (
+		downManagerID string
+		downWorkerIDs []string
+		oldTLSInfo    *api.NodeTLSInfo
+	)
+	for _, n := range resp.Nodes {
+		if oldTLSInfo != nil {
+			require.Equal(t, oldTLSInfo, n.Description.TLSInfo)
+		} else {
+			oldTLSInfo = n.Description.TLSInfo
+		}
+		if n.Role == api.NodeRoleManager {
+			if !n.ManagerStatus.Leader && downManagerID == "" {
+				downManagerID = n.ID
+				require.NoError(t, cl.nodes[n.ID].Pause(false))
+			}
+			continue
+		}
+		downWorkerIDs = append(downWorkerIDs, n.ID)
+		require.NoError(t, cl.nodes[n.ID].Pause(false))
+	}
+
+	// perform a root rotation, and wait until all the nodes that are up have newly issued certs
+	newRootCert, newRootKey, err := cautils.CreateRootCertAndKey("newRootCN")
+	require.NoError(t, err)
+	require.NoError(t, cl.RotateRootCA(newRootCert, newRootKey))
+
+	require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error {
+		resp, err := cl.api.ListNodes(context.Background(), &api.ListNodesRequest{})
+		if err != nil {
+			return err
+		}
+		for _, n := range resp.Nodes {
+			isDown := n.ID == downManagerID || n.ID == downWorkerIDs[0] || n.ID == downWorkerIDs[1]
+			if reflect.DeepEqual(n.Description.TLSInfo, oldTLSInfo) != isDown {
+				return fmt.Errorf("expected TLS info to have changed: %v", !isDown)
+			}
+		}
+
+		// root rotation isn't done
+		clusterInfo, err := cl.GetClusterInfo()
+		if err != nil {
+			return err
+		}
+		require.NotNil(t, clusterInfo.RootCA.RootRotation) // if root rotation is already done, fail and finish the test here
+		return nil
+	}, opsTimeout))
+
+	// Bring the other manager back.  Also bring one worker back, kill the other worker,
+	// and add a new worker - show that we can converge on a root rotation.
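+	// (the worker that stays down is removed outright, so its stale TLS info cannot block the
+	// rotation from completing)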
+	require.NoError(t, cl.StartNode(downManagerID))
+	require.NoError(t, cl.StartNode(downWorkerIDs[0]))
+	require.NoError(t, cl.RemoveNode(downWorkerIDs[1], false))
+	require.NoError(t, cl.AddAgent())
+
+	// we can finish root rotation even though the previous leader was down because it had
+	// already rotated its cert
+	pollRootRotationDone(t, cl)
+
+	// wait until all the nodes have gotten their new certs and trust roots
+	require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error {
+		resp, err = cl.api.ListNodes(context.Background(), &api.ListNodesRequest{})
+		if err != nil {
+			return err
+		}
+		var newTLSInfo *api.NodeTLSInfo
+		for _, n := range resp.Nodes {
+			if newTLSInfo == nil {
+				newTLSInfo = n.Description.TLSInfo
+				if bytes.Equal(newTLSInfo.CertIssuerPublicKey, oldTLSInfo.CertIssuerPublicKey) ||
+					bytes.Equal(newTLSInfo.CertIssuerSubject, oldTLSInfo.CertIssuerSubject) {
+					return errors.New("expecting the issuer to have changed")
+				}
+				if !bytes.Equal(newTLSInfo.TrustRoot, newRootCert) {
+					return errors.New("expecting the root certificate to have changed")
+				}
+			} else if !reflect.DeepEqual(newTLSInfo, n.Description.TLSInfo) {
+				return fmt.Errorf("the nodes have not converged yet, particularly %s", n.ID)
+			}
+
+			if n.Certificate.Status.State != api.IssuanceStateIssued {
+				return errors.New("nodes have yet to finish renewing their TLS certificates")
+			}
+		}
+		return nil
+	}, opsTimeout))
+}
+
+func TestRepeatedRootRotation(t *testing.T) {
+	t.Parallel()
+	numWorker, numManager := 3, 1
+	cl := newCluster(t, numWorker, numManager)
+	defer func() {
+		require.NoError(t, cl.Stop())
+	}()
+	pollClusterReady(t, cl, numWorker, numManager)
+
+	resp, err := cl.api.ListNodes(context.Background(), &api.ListNodesRequest{})
+	require.NoError(t, err)
+	var oldTLSInfo *api.NodeTLSInfo
+	for _, n := range resp.Nodes {
+		if oldTLSInfo != nil {
+			require.Equal(t, oldTLSInfo, n.Description.TLSInfo)
+		} else {
+			oldTLSInfo = n.Description.TLSInfo
+		}
+	}
+
+	// perform multiple root rotations, wait a second between each
+	var newRootCert, newRootKey []byte
+	for i := 0; i < 3; i++ {
+		newRootCert, newRootKey, err = cautils.CreateRootCertAndKey("newRootCN")
+		require.NoError(t, err)
+		require.NoError(t, cl.RotateRootCA(newRootCert, newRootKey))
+		time.Sleep(time.Second)
+	}
+
+	pollRootRotationDone(t, cl)
+
+	// wait until all the nodes are stabilized back to the latest issuer
+	require.NoError(t, testutils.PollFuncWithTimeout(nil, func() error {
+		resp, err = cl.api.ListNodes(context.Background(), &api.ListNodesRequest{})
+		if err != nil {
+			return err
+		}
+		for _, n := range resp.Nodes {
+			if reflect.DeepEqual(n.Description.TLSInfo, oldTLSInfo) {
+				return errors.New("nodes have not changed TLS info")
+			}
+			if n.Certificate.Status.State != api.IssuanceStateIssued {
+				return errors.New("nodes have yet to finish renewing their TLS certificates")
+			}
+			if !bytes.Equal(n.Description.TLSInfo.TrustRoot, newRootCert) {
+				return errors.New("nodes do not all trust the new root yet")
+			}
+		}
+		return nil
+	}, opsTimeout))
+}
diff --git a/integration/node.go b/integration/node.go
index 17868ecc95..f4f27a2d2a 100644
--- a/integration/node.go
+++ b/integration/node.go
@@ -6,6 +6,7 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
+	"strings"
 
 	"google.golang.org/grpc"
 
@@ -115,6 +116,10 @@ func (n *testNode) stop() error {
 	defer cancel()
 	isManager := n.IsManager()
 	if err := n.node.Stop(ctx); err != nil {
+		// if the error is from trying to stop an already-stopped node, ignore the error
+		if 
strings.Contains(err.Error(), "node: not started") { + return nil + } // TODO(aaronl): This stack dumping may be removed in the // future once context deadline issues while shutting down // nodes are resolved.
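For orientation, the reconciler added in ca/reconciler.go is driven entirely from event plumbing like that in (*Server).Run and (*Server).UpdateRootCA above. The sketch below is illustrative only and is not part of the patch: driveReconciler and its watch channel are hypothetical stand-ins for the store.ViewAndWatch wiring.

package ca

import (
	"context"

	events "github.com/docker/go-events"
	"github.com/docker/swarmkit/api"
)

// driveReconciler feeds store events into the reconciler's three entry points.
// Hypothetical illustration; the real wiring lives in (*Server).Run and in the
// cluster watch that calls (*Server).UpdateRootCA.
func driveReconciler(ctx context.Context, r *rootRotationReconciler, watch <-chan events.Event) {
	for {
		select {
		case event := <-watch:
			switch v := event.(type) {
			case api.EventUpdateCluster:
				// any cluster update may carry a new RootCA, including a new target root rotation
				r.UpdateRootCA(&v.Cluster.RootCA)
			case api.EventCreateNode:
				r.UpdateNode(v.Node)
			case api.EventUpdateNode:
				// a node converges once its reported TLS info matches the target issuer
				r.UpdateNode(v.Node)
			case api.EventDeleteNode:
				// deleted nodes can no longer block a rotation from finishing
				r.DeleteNode(v.Node)
			}
		case <-ctx.Done():
			return
		}
	}
}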