Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions crypto/spiffe/spiffe.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ const (
type SVIDResponse struct {
X509Certificates []*x509.Certificate
JWT *string
PerAudienceJWT map[string]string
}

// Identity contains both X.509 and JWT SVIDs for a workload.
type Identity struct {
X509SVID *x509svid.SVID
JWTSVID *jwtsvid.SVID
X509SVID *x509svid.SVID
JWTSVID *jwtsvid.SVID
PerAudienceJWTSVID map[string]*jwtsvid.SVID
}

type (
Expand All @@ -77,8 +79,11 @@ type Options struct {
// Used to manage workload SVIDs, and share read-only interfaces to consumers.
type SPIFFE struct {
currentX509SVID *x509svid.SVID
currentJWTSVID *jwtsvid.SVID
requestSVIDFn RequestSVIDFn

currentBaseJWTSVID *jwtsvid.SVID
currentPerAudJWTSVID map[string]*jwtsvid.SVID

requestSVIDFn RequestSVIDFn

dir *dir.Dir
trustAnchors trustanchors.Interface
Expand Down Expand Up @@ -124,7 +129,8 @@ func (s *SPIFFE) Run(ctx context.Context) error {
}

s.currentX509SVID = initialIdentity.X509SVID
s.currentJWTSVID = initialIdentity.JWTSVID
s.currentBaseJWTSVID = initialIdentity.JWTSVID
s.currentPerAudJWTSVID = initialIdentity.PerAudienceJWTSVID
close(s.readyCh)
s.lock.Unlock()

Expand Down Expand Up @@ -169,7 +175,7 @@ func (s *SPIFFE) runRotation(ctx context.Context) {

s.lock.RLock()
cert := s.currentX509SVID.Certificates[0]
jwtSVID := s.currentJWTSVID
jwtSVID := s.currentBaseJWTSVID
s.lock.RUnlock()
Comment on lines 176 to 179
Copy link

Copilot AI Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rotation scheduler still computes renewal time using only the base JWT SVID. Per-audience JWT SVIDs may have earlier expiries, which could leave expired per-audience tokens being served until the next cert/base-JWT renewal. Consider incorporating the earliest expiry among all JWT SVIDs (base + per-audience) when calculating renewTime.

Copilot uses AI. Check for mistakes.

renewTime := calculateRenewalTime(time.Now(), cert, jwtSVID)
Expand Down Expand Up @@ -197,7 +203,8 @@ func (s *SPIFFE) runRotation(ctx context.Context) {

s.lock.Lock()
s.currentX509SVID = identity.X509SVID
s.currentJWTSVID = identity.JWTSVID
s.currentBaseJWTSVID = identity.JWTSVID
s.currentPerAudJWTSVID = identity.PerAudienceJWTSVID
cert = identity.X509SVID.Certificates[0]
jwtSVID = identity.JWTSVID
s.lock.Unlock()
Expand Down Expand Up @@ -265,6 +272,19 @@ func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) {
s.log.Infof("Successfully received JWT SVID with expiry: %s", jwtSvid.Expiry.String())
}

for aud, token := range svidResponse.PerAudienceJWT {
jwtSvid, err := jwtsvid.ParseInsecure(token, []string{aud})
if err != nil {
return nil, fmt.Errorf("failed to parse JWT SVID: %w", err)
}
Comment on lines +275 to +279
Copy link

Copilot AI Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When parsing per-audience JWTs, the returned error doesn't include which audience key failed. Including the audience (and possibly whether it's from PerAudienceJWT vs the base JWT) would make troubleshooting malformed tokens significantly easier.

Copilot uses AI. Check for mistakes.

if identity.PerAudienceJWTSVID == nil {
identity.PerAudienceJWTSVID = make(map[string]*jwtsvid.SVID)
}
identity.PerAudienceJWTSVID[aud] = jwtSvid
s.log.Infof("Successfully received per-audience JWT SVID for audience %s with expiry: %s", aud, jwtSvid.Expiry.String())
}
Comment on lines +275 to +286
Copy link

Copilot AI Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new loop adds behavior to parse and store per-audience JWT SVIDs, but there doesn't appear to be test coverage validating that (1) PerAudienceJWT values from RequestSVIDFn are parsed into Identity.PerAudienceJWTSVID and (2) those entries are then served by FetchJWTSVID. Adding a focused unit test around fetchIdentity/Run would help prevent regressions.

Copilot uses AI. Check for mistakes.

if s.dir != nil {
pkPEM, err := pem.EncodePrivateKey(key)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion crypto/spiffe/svidsource.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ func (s *svidSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*
case <-s.spiffe.readyCh:
}

svid := s.spiffe.currentJWTSVID
svid, ok := s.spiffe.currentPerAudJWTSVID[params.Audience]
if !ok {
svid = s.spiffe.currentBaseJWTSVID
}
if svid == nil {
return nil, errNoJWTSVIDAvailable
}
Expand Down
83 changes: 71 additions & 12 deletions crypto/spiffe/svidsource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ func TestFetchJWTSVID(t *testing.T) {
t.Run("should return error when no JWT SVID available", func(t *testing.T) {
s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: nil,
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentBaseJWTSVID: nil,
},
}
close(s.spiffe.readyCh) // Mark as ready
Expand All @@ -94,9 +94,39 @@ func TestFetchJWTSVID(t *testing.T) {

s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentBaseJWTSVID: mockJWTSVID,
},
}
close(s.spiffe.readyCh) // Mark as ready

svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
Audience: "requested-audience",
})

require.Nil(t, svid)
require.Error(t, err)

// Verify the specific error type and contents
audienceErr, ok := err.(*audienceMismatchError)
require.True(t, ok, "Expected audienceMismatchError")
require.Equal(t, "JWT SVID has different audiences than requested: expected requested-audience, got actual-audience", audienceErr.Error())
})

t.Run("PER: should return error when audience doesn't match", func(t *testing.T) {
// Create a mock SVID with a specific audience
mockJWTSVID, err := createMockJWTSVID([]string{"actual-audience"})
require.NoError(t, err)

s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentBaseJWTSVID: mockJWTSVID,
currentPerAudJWTSVID: map[string]*jwtsvid.SVID{
"actual-audience": mockJWTSVID,
},
},
}
close(s.spiffe.readyCh) // Mark as ready
Expand All @@ -120,9 +150,9 @@ func TestFetchJWTSVID(t *testing.T) {

s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentBaseJWTSVID: mockJWTSVID,
},
}
close(s.spiffe.readyCh) // Mark as ready
Expand All @@ -135,16 +165,45 @@ func TestFetchJWTSVID(t *testing.T) {
require.Equal(t, mockJWTSVID, svid)
})

t.Run("PER: should return JWT SVID when audience matches", func(t *testing.T) {
mockJWTSVID1, err := createMockJWTSVID([]string{"test-audience", "extra-audience"})
require.NoError(t, err)
mockJWTSVID2, err := createMockJWTSVID([]string{"test-audience"})
require.NoError(t, err)
mockJWTSVID3, err := createMockJWTSVID([]string{"extra-audience"})
require.NoError(t, err)

s := &svidSource{
spiffe: &SPIFFE{
readyCh: make(chan struct{}),
lock: sync.RWMutex{},
currentBaseJWTSVID: mockJWTSVID1,
currentPerAudJWTSVID: map[string]*jwtsvid.SVID{
"test-audience": mockJWTSVID2,
"extra-audience": mockJWTSVID3,
},
},
}
close(s.spiffe.readyCh) // Mark as ready

svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{
Audience: "test-audience",
})

require.NoError(t, err)
require.Equal(t, mockJWTSVID2, svid)
})

t.Run("should wait for readyCh before checking SVID", func(t *testing.T) {
mockJWTSVID, err := createMockJWTSVID([]string{"test-audience"})
require.NoError(t, err)

readyCh := make(chan struct{})
s := &svidSource{
spiffe: &SPIFFE{
readyCh: readyCh,
lock: sync.RWMutex{},
currentJWTSVID: mockJWTSVID,
readyCh: readyCh,
lock: sync.RWMutex{},
currentBaseJWTSVID: mockJWTSVID,
},
}

Expand Down