Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 35 additions & 51 deletions ca/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,28 +137,40 @@ func (s *SecurityConfig) UpdateRootCA(cert, key []byte, certExpiry time.Duration
return err
}

// the RootCA pool should validate against the TLS certificate in the credentials
if s.ClientTLSCreds != nil {
s.ClientTLSCreds.UpdateCAs(rootCA.Pool, nil)
s.rootCA = &rootCA
clientTLSConfig := s.ClientTLSCreds.Config()
return s.updateTLSCredentials(clientTLSConfig.Certificates)
}

// updateTLSCredentials updates the client, server, and TLS credentials on a security config. This function expects
// something else to have taken out a lock on the SecurityConfig.
func (s *SecurityConfig) updateTLSCredentials(certificates []tls.Certificate) error {
clientConfig, err := NewClientTLSConfig(certificates, s.rootCA.Pool, ManagerRole)
if err != nil {
return errors.Wrap(err, "failed to create a new client config using the new root CA")
}

if s.ServerTLSCreds != nil {
s.ServerTLSCreds.UpdateCAs(rootCA.Pool, rootCA.Pool)
serverConfig, err := NewServerTLSConfig(certificates, s.rootCA.Pool)
if err != nil {
return errors.Wrap(err, "failed to create a new server config using the new root CA")
}

if s.externalCA != nil {
clientTLSConfig := s.ClientTLSCreds.Config()
if err := s.ClientTLSCreds.loadNewTLSConfig(clientConfig); err != nil {
return errors.Wrap(err, "failed to update the client credentials")
}

externalCATLSConfig := &tls.Config{
Certificates: clientTLSConfig.Certificates,
RootCAs: rootCA.Pool,
MinVersion: tls.VersionTLS12,
}
// Update the external CA to use the new client TLS
// config using a copy without a serverName specified.
s.externalCA.UpdateTLSConfig(&tls.Config{
Certificates: certificates,
RootCAs: s.rootCA.Pool,
MinVersion: tls.VersionTLS12,
})

s.externalCA.UpdateTLSConfig(externalCATLSConfig)
if err := s.ServerTLSCreds.loadNewTLSConfig(serverConfig); err != nil {
return errors.Wrap(err, "failed to update the server TLS credentials")
}

s.rootCA = &rootCA
return nil
}

Expand Down Expand Up @@ -424,37 +436,9 @@ func RenewTLSConfigNow(ctx context.Context, s *SecurityConfig, connBroker *conne
log.WithError(err).Errorf("failed to renew the certificate")
return err
}

clientTLSConfig, err := NewClientTLSConfig(tlsKeyPair, rootCA.Pool, CARole)
if err != nil {
log.WithError(err).Errorf("failed to create a new client config")
return err
}
serverTLSConfig, err := NewServerTLSConfig(tlsKeyPair, rootCA.Pool)
if err != nil {
log.WithError(err).Errorf("failed to create a new server config")
return err
}

if err = s.ClientTLSCreds.LoadNewTLSConfig(clientTLSConfig); err != nil {
log.WithError(err).Errorf("failed to update the client credentials")
return err
}

// Update the external CA to use the new client TLS
// config using a copy without a serverName specified.
s.externalCA.UpdateTLSConfig(&tls.Config{
Certificates: clientTLSConfig.Certificates,
RootCAs: clientTLSConfig.RootCAs,
MinVersion: tls.VersionTLS12,
})

if err = s.ServerTLSCreds.LoadNewTLSConfig(serverTLSConfig); err != nil {
log.WithError(err).Errorf("failed to update the server TLS credentials")
return err
}

return nil
s.mu.Lock()
defer s.mu.Unlock()
return s.updateTLSCredentials([]tls.Certificate{*tlsKeyPair})
}

// RenewTLSConfig will continuously monitor for the necessity of renewing the local certificates, either by
Expand Down Expand Up @@ -574,13 +558,13 @@ func calculateRandomExpiry(validFrom, validUntil time.Time) time.Duration {

// NewServerTLSConfig returns a tls.Config configured for a TLS Server, given a tls.Certificate
// and the PEM-encoded root CA Certificate
func NewServerTLSConfig(cert *tls.Certificate, rootCAPool *x509.CertPool) (*tls.Config, error) {
func NewServerTLSConfig(certs []tls.Certificate, rootCAPool *x509.CertPool) (*tls.Config, error) {
if rootCAPool == nil {
return nil, errors.New("valid root CA pool required")
}

return &tls.Config{
Certificates: []tls.Certificate{*cert},
Certificates: certs,
// Since we're using the same CA server to issue Certificates to new nodes, we can't
// use tls.RequireAndVerifyClientCert
ClientAuth: tls.VerifyClientCertIfGiven,
Expand All @@ -593,14 +577,14 @@ func NewServerTLSConfig(cert *tls.Certificate, rootCAPool *x509.CertPool) (*tls.

// NewClientTLSConfig returns a tls.Config configured for a TLS Client, given a tls.Certificate
// the PEM-encoded root CA Certificate, and the name of the remote server the client wants to connect to.
func NewClientTLSConfig(cert *tls.Certificate, rootCAPool *x509.CertPool, serverName string) (*tls.Config, error) {
func NewClientTLSConfig(certs []tls.Certificate, rootCAPool *x509.CertPool, serverName string) (*tls.Config, error) {
if rootCAPool == nil {
return nil, errors.New("valid root CA pool required")
}

return &tls.Config{
ServerName: serverName,
Certificates: []tls.Certificate{*cert},
Certificates: certs,
RootCAs: rootCAPool,
MinVersion: tls.VersionTLS12,
}, nil
Expand All @@ -609,7 +593,7 @@ func NewClientTLSConfig(cert *tls.Certificate, rootCAPool *x509.CertPool, server
// NewClientTLSCredentials returns GRPC credentials for a TLS GRPC client, given a tls.Certificate
// a PEM-Encoded root CA Certificate, and the name of the remote server the client wants to connect to.
func (rootCA *RootCA) NewClientTLSCredentials(cert *tls.Certificate, serverName string) (*MutableTLSCreds, error) {
tlsConfig, err := NewClientTLSConfig(cert, rootCA.Pool, serverName)
tlsConfig, err := NewClientTLSConfig([]tls.Certificate{*cert}, rootCA.Pool, serverName)
if err != nil {
return nil, err
}
Expand All @@ -622,7 +606,7 @@ func (rootCA *RootCA) NewClientTLSCredentials(cert *tls.Certificate, serverName
// NewServerTLSCredentials returns GRPC credentials for a TLS GRPC client, given a tls.Certificate
// a PEM-Encoded root CA Certificate, and the name of the remote server the client wants to connect to.
func (rootCA *RootCA) NewServerTLSCredentials(cert *tls.Certificate) (*MutableTLSCreds, error) {
tlsConfig, err := NewServerTLSConfig(cert, rootCA.Pool)
tlsConfig, err := NewServerTLSConfig([]tls.Certificate{*cert}, rootCA.Pool)
if err != nil {
return nil, err
}
Expand Down
52 changes: 52 additions & 0 deletions ca/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ca_test
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"os"
Expand Down Expand Up @@ -371,6 +372,57 @@ func TestSecurityConfigUpdateRootCA(t *testing.T) {
}
}

// enforce that no matter what order updating the root CA and updating TLS credential happens, we
// end up with a security config that has updated certs, and an updated root pool
func TestRenewTLSConfigUpdateRootCARace(t *testing.T) {
tc := testutils.NewTestCA(t)
defer tc.Stop()
paths := ca.NewConfigPaths(tc.TempDir)

secConfig, err := tc.WriteNewNodeConfig(ca.WorkerRole)
require.NoError(t, err)

leafCert, err := ioutil.ReadFile(paths.Node.Cert)
require.NoError(t, err)

for i := 0; i < 5; i++ {
cert, _, err := testutils.CreateRootCertAndKey(fmt.Sprintf("root %d", i+2))
require.NoError(t, err)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

done1, done2 := make(chan struct{}), make(chan struct{})
rootCA := secConfig.RootCA()
go func() {
defer close(done1)
var key []byte
if rootCA.Signer != nil {
key = rootCA.Signer.Key
}
require.NoError(t, secConfig.UpdateRootCA(append(rootCA.Cert, cert...), key, ca.DefaultNodeCertExpiration))
}()

go func() {
defer close(done2)
require.NoError(t, ca.RenewTLSConfigNow(ctx, secConfig, tc.ConnBroker))
}()

<-done1
<-done2

newCert, err := ioutil.ReadFile(paths.Node.Cert)
require.NoError(t, err)

require.NotEqual(t, newCert, leafCert)
leafCert = newCert

// at the start of this loop had i+1 certs, afterward should have added one more
require.Len(t, secConfig.ClientTLSCreds.Config().RootCAs.Subjects(), i+2)
require.Len(t, secConfig.ServerTLSCreds.Config().RootCAs.Subjects(), i+2)
}
}

func TestRenewTLSConfigWorker(t *testing.T) {
t.Parallel()

Expand Down
15 changes: 2 additions & 13 deletions ca/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"strings"
"sync"

"github.com/docker/docker/pkg/tlsconfig"
"github.com/pkg/errors"
"golang.org/x/net/context"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -121,8 +120,8 @@ func (c *MutableTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentia
return conn, credentials.TLSInfo{State: conn.ConnectionState()}, nil
}

// LoadNewTLSConfig replaces the currently loaded TLS config with a new one
func (c *MutableTLSCreds) LoadNewTLSConfig(newConfig *tls.Config) error {
// loadNewTLSConfig replaces the currently loaded TLS config with a new one
func (c *MutableTLSCreds) loadNewTLSConfig(newConfig *tls.Config) error {
newSubject, err := GetAndValidateCertificateSubject(newConfig.Certificates)
if err != nil {
return err
Expand All @@ -136,16 +135,6 @@ func (c *MutableTLSCreds) LoadNewTLSConfig(newConfig *tls.Config) error {
return nil
}

// UpdateCAs updates the root CAs and client CAs of the existing TLS config in place
func (c *MutableTLSCreds) UpdateCAs(rootCAs, clientCAs *x509.CertPool) {
c.Lock()
defer c.Unlock()
config := tlsconfig.Clone(c.config)
config.RootCAs = rootCAs
config.ClientCAs = clientCAs
c.config = config
}

// Config returns the current underlying TLS config.
func (c *MutableTLSCreds) Config() *tls.Config {
c.Lock()
Expand Down
67 changes: 43 additions & 24 deletions ca/transport_test.go
Original file line number Diff line number Diff line change
@@ -1,66 +1,85 @@
package ca_test
package ca

import (
"crypto/tls"
"io/ioutil"
"os"
"testing"

"github.com/docker/swarmkit/ca"
"github.com/docker/swarmkit/ca/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewMutableTLS(t *testing.T) {
tc := testutils.NewTestCA(t)
defer tc.Stop()
tempdir, err := ioutil.TempDir("", "test-transport")
require.NoError(t, err)
defer os.RemoveAll(tempdir)
paths := NewConfigPaths(tempdir)
krw := NewKeyReadWriter(paths.Node, nil, nil)

cert, err := tc.RootCA.IssueAndSaveNewCertificates(tc.KeyReadWriter, "CN", ca.ManagerRole, tc.Organization)
rootCA, err := CreateRootCA("rootCN", paths.RootCA)
require.NoError(t, err)

cert, err := rootCA.IssueAndSaveNewCertificates(krw, "CN", ManagerRole, "org")
assert.NoError(t, err)

tlsConfig, err := ca.NewServerTLSConfig(cert, tc.RootCA.Pool)
tlsConfig, err := NewServerTLSConfig([]tls.Certificate{*cert}, rootCA.Pool)
assert.NoError(t, err)
creds, err := ca.NewMutableTLS(tlsConfig)
creds, err := NewMutableTLS(tlsConfig)
assert.NoError(t, err)
assert.Equal(t, ca.ManagerRole, creds.Role())
assert.Equal(t, ManagerRole, creds.Role())
assert.Equal(t, "CN", creds.NodeID())
}

func TestGetAndValidateCertificateSubject(t *testing.T) {
tc := testutils.NewTestCA(t)
defer tc.Stop()
tempdir, err := ioutil.TempDir("", "test-transport")
require.NoError(t, err)
defer os.RemoveAll(tempdir)
paths := NewConfigPaths(tempdir)
krw := NewKeyReadWriter(paths.Node, nil, nil)

rootCA, err := CreateRootCA("rootCN", paths.RootCA)
require.NoError(t, err)

cert, err := tc.RootCA.IssueAndSaveNewCertificates(tc.KeyReadWriter, "CN", ca.ManagerRole, tc.Organization)
cert, err := rootCA.IssueAndSaveNewCertificates(krw, "CN", ManagerRole, "org")
assert.NoError(t, err)

name, err := ca.GetAndValidateCertificateSubject([]tls.Certificate{*cert})
name, err := GetAndValidateCertificateSubject([]tls.Certificate{*cert})
assert.NoError(t, err)
assert.Equal(t, "CN", name.CommonName)
assert.Len(t, name.OrganizationalUnit, 1)
assert.Equal(t, ca.ManagerRole, name.OrganizationalUnit[0])
assert.Equal(t, ManagerRole, name.OrganizationalUnit[0])
}

func TestLoadNewTLSConfig(t *testing.T) {
tc := testutils.NewTestCA(t)
defer tc.Stop()
tempdir, err := ioutil.TempDir("", "test-transport")
require.NoError(t, err)
defer os.RemoveAll(tempdir)
paths := NewConfigPaths(tempdir)
krw := NewKeyReadWriter(paths.Node, nil, nil)

rootCA, err := CreateRootCA("rootCN", paths.RootCA)
require.NoError(t, err)

// Create two different certs and two different TLS configs
cert1, err := tc.RootCA.IssueAndSaveNewCertificates(tc.KeyReadWriter, "CN1", ca.ManagerRole, tc.Organization)
cert1, err := rootCA.IssueAndSaveNewCertificates(krw, "CN1", ManagerRole, "org")
assert.NoError(t, err)
cert2, err := tc.RootCA.IssueAndSaveNewCertificates(tc.KeyReadWriter, "CN2", ca.WorkerRole, tc.Organization)
cert2, err := rootCA.IssueAndSaveNewCertificates(krw, "CN2", WorkerRole, "org")
assert.NoError(t, err)
tlsConfig1, err := ca.NewServerTLSConfig(cert1, tc.RootCA.Pool)
tlsConfig1, err := NewServerTLSConfig([]tls.Certificate{*cert1}, rootCA.Pool)
assert.NoError(t, err)
tlsConfig2, err := ca.NewServerTLSConfig(cert2, tc.RootCA.Pool)
tlsConfig2, err := NewServerTLSConfig([]tls.Certificate{*cert2}, rootCA.Pool)
assert.NoError(t, err)

// Load the first TLS config into a MutableTLS
creds, err := ca.NewMutableTLS(tlsConfig1)
creds, err := NewMutableTLS(tlsConfig1)
assert.NoError(t, err)
assert.Equal(t, ca.ManagerRole, creds.Role())
assert.Equal(t, ManagerRole, creds.Role())
assert.Equal(t, "CN1", creds.NodeID())

// Load the new Config and assert it changed
err = creds.LoadNewTLSConfig(tlsConfig2)
err = creds.loadNewTLSConfig(tlsConfig2)
assert.NoError(t, err)
assert.Equal(t, ca.WorkerRole, creds.Role())
assert.Equal(t, WorkerRole, creds.Role())
assert.Equal(t, "CN2", creds.NodeID())
}