diff --git a/ca/certificates.go b/ca/certificates.go index 7258e30593..acc6dc6a27 100644 --- a/ca/certificates.go +++ b/ca/certificates.go @@ -151,6 +151,22 @@ func (rca *RootCA) IssueAndSaveNewCertificates(kw KeyWriter, cn, ou, org string) return &tlsKeyPair, nil } +// Normally we can just call cert.Verify(opts), but since we actually want more information about +// whether a certificate is not yet valid or expired, we also need to perform the expiry checks ourselves. +func verifyCertificate(cert *x509.Certificate, opts x509.VerifyOptions) error { + _, err := cert.Verify(opts) + if invalidErr, ok := err.(x509.CertificateInvalidError); ok && invalidErr.Reason == x509.Expired { + now := time.Now().UTC() + if now.Before(cert.NotBefore) { + return errors.Wrapf(err, "certificate not valid before %s, and it is currently %s", + cert.NotBefore.UTC().Format(time.RFC1123), now.Format(time.RFC1123)) + } + return errors.Wrapf(err, "certificate expires at %s, and it is currently %s", + cert.NotAfter.UTC().Format(time.RFC1123), now.Format(time.RFC1123)) + } + return err +} + // RequestAndSaveNewCertificates gets new certificates issued, either by signing them locally if a signer is // available, or by requesting them from the remote server at remoteAddr. func (rca *RootCA) RequestAndSaveNewCertificates(ctx context.Context, kw KeyWriter, config CertificateRequestConfig) (*tls.Certificate, error) { @@ -199,7 +215,7 @@ func (rca *RootCA) RequestAndSaveNewCertificates(ctx context.Context, kw KeyWrit Roots: rca.Pool, } // Check to see if this certificate was signed by our CA, and isn't expired - if _, err := X509Cert.Verify(opts); err != nil { + if err := verifyCertificate(X509Cert, opts); err != nil { return nil, err } diff --git a/ca/config.go b/ca/config.go index d2664bd635..48d4431e5f 100644 --- a/ca/config.go +++ b/ca/config.go @@ -189,7 +189,7 @@ func GenerateJoinToken(rootCA *RootCA) string { func getCAHashFromToken(token string) (digest.Digest, error) { split := strings.Split(token, "-") - if len(split) != 4 || split[0] != "SWMTKN" || split[1] != "1" { + if len(split) != 4 || split[0] != "SWMTKN" || split[1] != "1" || len(split[2]) != base36DigestLen || len(split[3]) != maxGeneratedSecretLength { return "", errors.New("invalid join token") } @@ -273,7 +273,7 @@ func LoadSecurityConfig(ctx context.Context, rootCA RootCA, krw *KeyReadWriter) } // Check to see if this certificate was signed by our CA, and isn't expired - if _, err := X509Cert.Verify(opts); err != nil { + if err := verifyCertificate(X509Cert, opts); err != nil { return nil, err } diff --git a/ca/config_test.go b/ca/config_test.go index 2e5bea9f46..cb85d0fb46 100644 --- a/ca/config_test.go +++ b/ca/config_test.go @@ -54,9 +54,14 @@ func TestDownloadRootCAWrongCAHash(t *testing.T) { os.RemoveAll(tc.Paths.RootCA.Cert) // invalid token - _, err := ca.DownloadRootCA(tc.Context, tc.Paths.RootCA, "invalidtoken", tc.ConnBroker) - require.Error(t, err) - require.Contains(t, err.Error(), "invalid join token") + for _, invalid := range []string{ + "invalidtoken", // completely invalid + "SWMTKN-1-3wkodtpeoipd1u1hi0ykdcdwhw16dk73ulqqtn14b3indz68rf-4myj5xihyto11dg1cn55w8p6", // mistyped + } { + _, err := ca.DownloadRootCA(tc.Context, tc.Paths.RootCA, invalid, tc.ConnBroker) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid join token") + } // invalid hash token splitToken := strings.Split(tc.ManagerToken, "-") @@ -65,7 +70,7 @@ func TestDownloadRootCAWrongCAHash(t *testing.T) { os.RemoveAll(tc.Paths.RootCA.Cert) - _, err = ca.DownloadRootCA(tc.Context, tc.Paths.RootCA, replacementToken, tc.ConnBroker) + _, err := ca.DownloadRootCA(tc.Context, tc.Paths.RootCA, replacementToken, tc.ConnBroker) require.Error(t, err) require.Contains(t, err.Error(), "remote CA does not match fingerprint.") }