diff --git a/crypto/spiffe/spiffe.go b/crypto/spiffe/spiffe.go index 966a73d..5be423d 100644 --- a/crypto/spiffe/spiffe.go +++ b/crypto/spiffe/spiffe.go @@ -46,12 +46,14 @@ const ( type SVIDResponse struct { X509Certificates []*x509.Certificate JWT *string + PerAudienceJWT map[string]string } // Identity contains both X.509 and JWT SVIDs for a workload. type Identity struct { - X509SVID *x509svid.SVID - JWTSVID *jwtsvid.SVID + X509SVID *x509svid.SVID + JWTSVID *jwtsvid.SVID + PerAudienceJWTSVID map[string]*jwtsvid.SVID } type ( @@ -77,8 +79,11 @@ type Options struct { // Used to manage workload SVIDs, and share read-only interfaces to consumers. type SPIFFE struct { currentX509SVID *x509svid.SVID - currentJWTSVID *jwtsvid.SVID - requestSVIDFn RequestSVIDFn + + currentBaseJWTSVID *jwtsvid.SVID + currentPerAudJWTSVID map[string]*jwtsvid.SVID + + requestSVIDFn RequestSVIDFn dir *dir.Dir trustAnchors trustanchors.Interface @@ -124,7 +129,8 @@ func (s *SPIFFE) Run(ctx context.Context) error { } s.currentX509SVID = initialIdentity.X509SVID - s.currentJWTSVID = initialIdentity.JWTSVID + s.currentBaseJWTSVID = initialIdentity.JWTSVID + s.currentPerAudJWTSVID = initialIdentity.PerAudienceJWTSVID close(s.readyCh) s.lock.Unlock() @@ -169,7 +175,7 @@ func (s *SPIFFE) runRotation(ctx context.Context) { s.lock.RLock() cert := s.currentX509SVID.Certificates[0] - jwtSVID := s.currentJWTSVID + jwtSVID := s.currentBaseJWTSVID s.lock.RUnlock() renewTime := calculateRenewalTime(time.Now(), cert, jwtSVID) @@ -197,7 +203,8 @@ func (s *SPIFFE) runRotation(ctx context.Context) { s.lock.Lock() s.currentX509SVID = identity.X509SVID - s.currentJWTSVID = identity.JWTSVID + s.currentBaseJWTSVID = identity.JWTSVID + s.currentPerAudJWTSVID = identity.PerAudienceJWTSVID cert = identity.X509SVID.Certificates[0] jwtSVID = identity.JWTSVID s.lock.Unlock() @@ -265,6 +272,19 @@ func (s *SPIFFE) fetchIdentity(ctx context.Context) (*Identity, error) { s.log.Infof("Successfully received JWT SVID with expiry: %s", jwtSvid.Expiry.String()) } + for aud, token := range svidResponse.PerAudienceJWT { + jwtSvid, err := jwtsvid.ParseInsecure(token, []string{aud}) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT SVID: %w", err) + } + + if identity.PerAudienceJWTSVID == nil { + identity.PerAudienceJWTSVID = make(map[string]*jwtsvid.SVID) + } + identity.PerAudienceJWTSVID[aud] = jwtSvid + s.log.Infof("Successfully received per-audience JWT SVID for audience %s with expiry: %s", aud, jwtSvid.Expiry.String()) + } + if s.dir != nil { pkPEM, err := pem.EncodePrivateKey(key) if err != nil { diff --git a/crypto/spiffe/svidsource.go b/crypto/spiffe/svidsource.go index 568638a..6409ce4 100644 --- a/crypto/spiffe/svidsource.go +++ b/crypto/spiffe/svidsource.go @@ -77,7 +77,10 @@ func (s *svidSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (* case <-s.spiffe.readyCh: } - svid := s.spiffe.currentJWTSVID + svid, ok := s.spiffe.currentPerAudJWTSVID[params.Audience] + if !ok { + svid = s.spiffe.currentBaseJWTSVID + } if svid == nil { return nil, errNoJWTSVIDAvailable } diff --git a/crypto/spiffe/svidsource_test.go b/crypto/spiffe/svidsource_test.go index fb26ebe..a04126d 100644 --- a/crypto/spiffe/svidsource_test.go +++ b/crypto/spiffe/svidsource_test.go @@ -72,9 +72,9 @@ func TestFetchJWTSVID(t *testing.T) { t.Run("should return error when no JWT SVID available", func(t *testing.T) { s := &svidSource{ spiffe: &SPIFFE{ - readyCh: make(chan struct{}), - lock: sync.RWMutex{}, - currentJWTSVID: nil, + readyCh: make(chan struct{}), + lock: sync.RWMutex{}, + currentBaseJWTSVID: nil, }, } close(s.spiffe.readyCh) // Mark as ready @@ -94,9 +94,39 @@ func TestFetchJWTSVID(t *testing.T) { s := &svidSource{ spiffe: &SPIFFE{ - readyCh: make(chan struct{}), - lock: sync.RWMutex{}, - currentJWTSVID: mockJWTSVID, + readyCh: make(chan struct{}), + lock: sync.RWMutex{}, + currentBaseJWTSVID: mockJWTSVID, + }, + } + close(s.spiffe.readyCh) // Mark as ready + + svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{ + Audience: "requested-audience", + }) + + require.Nil(t, svid) + require.Error(t, err) + + // Verify the specific error type and contents + audienceErr, ok := err.(*audienceMismatchError) + require.True(t, ok, "Expected audienceMismatchError") + require.Equal(t, "JWT SVID has different audiences than requested: expected requested-audience, got actual-audience", audienceErr.Error()) + }) + + t.Run("PER: should return error when audience doesn't match", func(t *testing.T) { + // Create a mock SVID with a specific audience + mockJWTSVID, err := createMockJWTSVID([]string{"actual-audience"}) + require.NoError(t, err) + + s := &svidSource{ + spiffe: &SPIFFE{ + readyCh: make(chan struct{}), + lock: sync.RWMutex{}, + currentBaseJWTSVID: mockJWTSVID, + currentPerAudJWTSVID: map[string]*jwtsvid.SVID{ + "actual-audience": mockJWTSVID, + }, }, } close(s.spiffe.readyCh) // Mark as ready @@ -120,9 +150,9 @@ func TestFetchJWTSVID(t *testing.T) { s := &svidSource{ spiffe: &SPIFFE{ - readyCh: make(chan struct{}), - lock: sync.RWMutex{}, - currentJWTSVID: mockJWTSVID, + readyCh: make(chan struct{}), + lock: sync.RWMutex{}, + currentBaseJWTSVID: mockJWTSVID, }, } close(s.spiffe.readyCh) // Mark as ready @@ -135,6 +165,35 @@ func TestFetchJWTSVID(t *testing.T) { require.Equal(t, mockJWTSVID, svid) }) + t.Run("PER: should return JWT SVID when audience matches", func(t *testing.T) { + mockJWTSVID1, err := createMockJWTSVID([]string{"test-audience", "extra-audience"}) + require.NoError(t, err) + mockJWTSVID2, err := createMockJWTSVID([]string{"test-audience"}) + require.NoError(t, err) + mockJWTSVID3, err := createMockJWTSVID([]string{"extra-audience"}) + require.NoError(t, err) + + s := &svidSource{ + spiffe: &SPIFFE{ + readyCh: make(chan struct{}), + lock: sync.RWMutex{}, + currentBaseJWTSVID: mockJWTSVID1, + currentPerAudJWTSVID: map[string]*jwtsvid.SVID{ + "test-audience": mockJWTSVID2, + "extra-audience": mockJWTSVID3, + }, + }, + } + close(s.spiffe.readyCh) // Mark as ready + + svid, err := s.FetchJWTSVID(t.Context(), jwtsvid.Params{ + Audience: "test-audience", + }) + + require.NoError(t, err) + require.Equal(t, mockJWTSVID2, svid) + }) + t.Run("should wait for readyCh before checking SVID", func(t *testing.T) { mockJWTSVID, err := createMockJWTSVID([]string{"test-audience"}) require.NoError(t, err) @@ -142,9 +201,9 @@ func TestFetchJWTSVID(t *testing.T) { readyCh := make(chan struct{}) s := &svidSource{ spiffe: &SPIFFE{ - readyCh: readyCh, - lock: sync.RWMutex{}, - currentJWTSVID: mockJWTSVID, + readyCh: readyCh, + lock: sync.RWMutex{}, + currentBaseJWTSVID: mockJWTSVID, }, }