Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 139 additions & 9 deletions x509util/certificate_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
"crypto/x509/pkix"
"encoding/asn1"
"encoding/json"
"errors"
"fmt"

"github.com/pkg/errors"
"go.step.sm/crypto/internal/utils"
"golang.org/x/crypto/cryptobyte"
cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1"
Expand Down Expand Up @@ -53,6 +54,10 @@
URIs MultiURL `json:"uris"`
SANs []SubjectAlternativeName `json:"sans"`
Extensions []Extension `json:"extensions"`
KeyUsage KeyUsage `json:"keyUsage"`
ExtKeyUsage ExtKeyUsage `json:"extKeyUsage"`
UnknownExtKeyUsage UnknownExtKeyUsage `json:"unknownExtKeyUsage"`
BasicConstraints *BasicConstraints `json:"basicConstraints"`
SignatureAlgorithm SignatureAlgorithm `json:"signatureAlgorithm"`
ChallengePassword string `json:"-"`
PublicKey interface{} `json:"-"`
Expand Down Expand Up @@ -83,7 +88,7 @@
// With templates
var cr CertificateRequest
if err := json.NewDecoder(o.CertBuffer).Decode(&cr); err != nil {
return nil, errors.Wrap(err, "error unmarshaling certificate")
return nil, fmt.Errorf("error unmarshaling certificate: %w", err)
}
cr.PublicKey = pub
cr.Signer = signer
Expand All @@ -100,6 +105,33 @@
cr.Extensions = append([]Extension{ext}, cr.Extensions...)
}

// Add KeyUsage extension if necessary.
if cr.KeyUsage != 0 && !cr.hasExtension(oidExtensionKeyUsage) {
ext, err := cr.KeyUsage.Extension()
if err != nil {
return nil, err
}

Check warning on line 113 in x509util/certificate_request.go

View check run for this annotation

Codecov / codecov/patch

x509util/certificate_request.go#L112-L113

Added lines #L112 - L113 were not covered by tests
cr.Extensions = append([]Extension{ext}, cr.Extensions...)
}

// Add ExtKeyUsage extension if necessary.
if len(cr.ExtKeyUsage) > 0 || len(cr.UnknownExtKeyUsage) > 0 {
ext, err := cr.ExtKeyUsage.Extension(cr.UnknownExtKeyUsage)
if err != nil {
return nil, err
}

Check warning on line 122 in x509util/certificate_request.go

View check run for this annotation

Codecov / codecov/patch

x509util/certificate_request.go#L121-L122

Added lines #L121 - L122 were not covered by tests
cr.Extensions = append([]Extension{ext}, cr.Extensions...)
}

// Add BasicConstraints extension if necessary.
if cr.BasicConstraints != nil {
ext, err := cr.BasicConstraints.Extension()
if err != nil {
return nil, err
}

Check warning on line 131 in x509util/certificate_request.go

View check run for this annotation

Codecov / codecov/patch

x509util/certificate_request.go#L130-L131

Added lines #L130 - L131 were not covered by tests
cr.Extensions = append([]Extension{ext}, cr.Extensions...)
}

return &cr, nil
}

Expand All @@ -114,6 +146,12 @@
func NewCertificateRequestFromX509(cr *x509.CertificateRequest) *CertificateRequest {
// Set SubjectAltName extension as critical if Subject is empty.
fixSubjectAltName(cr)
// Extracts key usage, extended key usage, and basic constraints from the
// certificate extensions. For backward compatibility, this method does not
// return an error if an extension is improperly encoded or cannot be
// decoded. In such cases, the extension is simply ignored.
parsed, _ := parseCertificateRequestExtensions(cr.Extensions)

return &CertificateRequest{
Version: cr.Version,
Subject: newSubject(cr.Subject),
Expand All @@ -123,6 +161,10 @@
IPAddresses: cr.IPAddresses,
URIs: cr.URIs,
Extensions: newExtensions(cr.Extensions),
KeyUsage: parsed.KeyUsage,
ExtKeyUsage: parsed.ExtKeyUsage,
UnknownExtKeyUsage: parsed.UnknownExtKeyUsage,
BasicConstraints: parsed.BasicConstraints,
PublicKey: cr.PublicKey,
PublicKeyAlgorithm: cr.PublicKeyAlgorithm,
Signature: cr.Signature,
Expand All @@ -146,7 +188,7 @@
SignatureAlgorithm: x509.SignatureAlgorithm(c.SignatureAlgorithm),
}, c.Signer)
if err != nil {
return nil, errors.Wrap(err, "error creating certificate request")
return nil, fmt.Errorf("error creating certificate request: %w", err)
}

// If a challenge password is provided, encode and prepend it as a challenge
Expand Down Expand Up @@ -193,7 +235,7 @@

b, err := builder.Bytes()
if err != nil {
return nil, errors.Wrap(err, "error marshaling challenge password")
return nil, fmt.Errorf("error marshaling challenge password: %w", err)
}
challengePasswordAttr := asn1.RawValue{
FullBytes: b,
Expand Down Expand Up @@ -223,7 +265,7 @@
// Marshal tbsCertificateRequest
tbsCSRContents, err := asn1.Marshal(tbsCSR)
if err != nil {
return nil, errors.Wrap(err, "error creating certificate request")
return nil, fmt.Errorf("error creating certificate request: %w", err)

Check warning on line 268 in x509util/certificate_request.go

View check run for this annotation

Codecov / codecov/patch

x509util/certificate_request.go#L268

Added line #L268 was not covered by tests
}
tbsCSR.Raw = tbsCSRContents

Expand All @@ -239,7 +281,7 @@
}
}
if !found {
return nil, errors.Errorf("error creating certificate request: unsupported signature algorithm %s", sigAlgoOID)
return nil, fmt.Errorf("error creating certificate request: unsupported signature algorithm %q", sigAlgoOID)
}

// Sign tbsCertificateRequest
Expand All @@ -253,7 +295,7 @@
var signature []byte
signature, err = c.Signer.Sign(rand.Reader, signed, hashFunc)
if err != nil {
return nil, errors.Wrap(err, "error creating certificate request")
return nil, fmt.Errorf("error creating certificate request: %w", err)
}

// Build new certificate request and marshal
Expand All @@ -266,7 +308,7 @@
},
})
if err != nil {
return nil, errors.Wrap(err, "error creating certificate request")
return nil, fmt.Errorf("error creating certificate request: %w", err)

Check warning on line 311 in x509util/certificate_request.go

View check run for this annotation

Codecov / codecov/patch

x509util/certificate_request.go#L311

Added line #L311 was not covered by tests
}
return asn1Data, nil
}
Expand Down Expand Up @@ -351,7 +393,7 @@
URIs: uris,
}, signer)
if err != nil {
return nil, errors.Wrap(err, "error creating certificate request")
return nil, fmt.Errorf("error creating certificate request: %w", err)
}
// This should not fail
return x509.ParseCertificateRequest(asn1Data)
Expand All @@ -368,3 +410,91 @@
}
}
}

type certificateRequestParsedExtensions struct {
KeyUsage KeyUsage
ExtKeyUsage ExtKeyUsage
UnknownExtKeyUsage UnknownExtKeyUsage
BasicConstraints *BasicConstraints
}

func parseCertificateRequestExtensions(exts []pkix.Extension) (cr certificateRequestParsedExtensions, errs error) {
var err error
for _, ext := range exts {
switch {
case ext.Id.Equal(oidExtensionKeyUsage):
if cr.KeyUsage, err = parseKeyUsageExtension(ext.Value); err != nil {
errs = errors.Join(errs, err)
}
case ext.Id.Equal(oidExtensionExtendedKeyUsage):
if cr.ExtKeyUsage, cr.UnknownExtKeyUsage, err = parseExtKeyUsageExtension(ext.Value); err != nil {
errs = errors.Join(errs, err)
}
case ext.Id.Equal(oidExtensionBasicConstraints):
if cr.BasicConstraints, err = parseBasicConstraintsExtension(ext.Value); err != nil {
errs = errors.Join(errs, err)
}
}
}

return
}

func parseKeyUsageExtension(der cryptobyte.String) (KeyUsage, error) {
var usageBits asn1.BitString
if !der.ReadASN1BitString(&usageBits) {
return 0, errors.New("invalid key usage")
}

var usage int
for i := 0; i < 9; i++ {
if usageBits.At(i) != 0 {
usage |= 1 << uint(i)
}
}

return KeyUsage(usage), nil
}

func parseExtKeyUsageExtension(der cryptobyte.String) (ExtKeyUsage, UnknownExtKeyUsage, error) {
var extKeyUsages ExtKeyUsage
var unknownUsages UnknownExtKeyUsage
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return nil, nil, errors.New("invalid extended key usages")
}
for !der.Empty() {
var eku asn1.ObjectIdentifier
if !der.ReadASN1ObjectIdentifier(&eku) {
return nil, nil, errors.New("invalid extended key usages")
}
if extKeyUsage, ok := extKeyUsageFromOID(eku); ok {
extKeyUsages = append(extKeyUsages, extKeyUsage)
} else {
unknownUsages = append(unknownUsages, eku)
}
}

return extKeyUsages, unknownUsages, nil
}

func parseBasicConstraintsExtension(der cryptobyte.String) (*BasicConstraints, error) {
var isCA bool
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("invalid basic constraints")
}
if der.PeekASN1Tag(cryptobyte_asn1.BOOLEAN) {
if !der.ReadASN1Boolean(&isCA) {
return nil, errors.New("invalid basic constraints")
}
}
maxPathLen := -1
if der.PeekASN1Tag(cryptobyte_asn1.INTEGER) {
if !der.ReadASN1Integer(&maxPathLen) {
return nil, errors.New("invalid basic constraints")
}
}

return &BasicConstraints{
IsCA: isCA, MaxPathLen: maxPathLen,
}, nil
}
Loading
Loading