Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cmd/authtoken/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,19 @@ func parseArgs() (interfaces.AuthTokenProvider, error) {
_ = secretCmd.MarkFlagRequired("namespace")

var clientID string
var scope string
azureCmd := &cobra.Command{
Use: "azure",
Args: cobra.NoArgs,
Run: func(_ *cobra.Command, args []string) {
tokenProvider = azure.New(clientID)
tokenProvider = azure.New(clientID, scope)
},
}

azureCmd.Flags().StringVar(&clientID, "clientid", "", "Azure AAD client ID (required)")
// TODO: this scope argument is specific for Azure provider. We should allow registering and parsing provider specific argument
// in provider level, instead of global level.
azureCmd.Flags().StringVar(&scope, "scope", "", "Azure AAD token scope (optional)")
_ = azureCmd.MarkFlagRequired("clientid")

rootCmd.AddCommand(secretCmd, azureCmd)
Expand Down
39 changes: 39 additions & 0 deletions cmd/authtoken/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package main

import (
"os"
"testing"

"github.com/stretchr/testify/assert"

"go.goms.io/fleet/pkg/authtoken/providers/azure"
)

func TestParseArgs(t *testing.T) {
t.Run("all arguments", func(t *testing.T) {
os.Args = []string{"refreshtoken", "azure", "--clientid=test-client-id", "--scope=test-scope"}
t.Cleanup(func() {
os.Args = nil
})
tokenProvider, err := parseArgs()
assert.NotNil(t, tokenProvider)
assert.Nil(t, err)

azTokenProvider, ok := tokenProvider.(*azure.AuthTokenProvider)
assert.Equal(t, true, ok)
assert.Equal(t, "test-scope", azTokenProvider.Scope)
})
t.Run("no optional arguments", func(t *testing.T) {
os.Args = []string{"refreshtoken", "azure", "--clientid=test-client-id"}
t.Cleanup(func() {
os.Args = nil
})
tokenProvider, err := parseArgs()
assert.NotNil(t, tokenProvider)
assert.Nil(t, err)

azTokenProvider, ok := tokenProvider.(*azure.AuthTokenProvider)
assert.Equal(t, true, ok)
assert.Equal(t, "6dae42f8-4368-4678-94ff-3960e28e3630", azTokenProvider.Scope)
})
}
23 changes: 14 additions & 9 deletions pkg/authtoken/providers/azure/azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,27 @@ const (
aksScope = "6dae42f8-4368-4678-94ff-3960e28e3630"
)

type azureAuthTokenProvider struct {
clientID string
type AuthTokenProvider struct {
ClientID string
Scope string
}

func New(clientID string) interfaces.AuthTokenProvider {
return &azureAuthTokenProvider{
clientID: clientID,
func New(clientID, scope string) interfaces.AuthTokenProvider {
if scope == "" {
scope = aksScope
}
return &AuthTokenProvider{
ClientID: clientID,
Scope: scope,
}
}

// FetchToken gets a new token to make request to the associated fleet' hub cluster.
func (a *azureAuthTokenProvider) FetchToken(ctx context.Context) (interfaces.AuthToken, error) {
func (a *AuthTokenProvider) FetchToken(ctx context.Context) (interfaces.AuthToken, error) {
token := interfaces.AuthToken{}
opts := &azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ClientID(a.clientID)}
opts := &azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ClientID(a.ClientID)}

klog.V(2).InfoS("FetchToken", "client ID", a.clientID)
klog.V(2).InfoS("FetchToken", "client ID", a.ClientID)
credential, err := azidentity.NewManagedIdentityCredential(opts)
if err != nil {
return token, fmt.Errorf("failed to create managed identity cred: %w", err)
Expand All @@ -48,7 +53,7 @@ func (a *azureAuthTokenProvider) FetchToken(ctx context.Context) (interfaces.Aut
}, func() error {
klog.V(2).InfoS("GetToken start", "credential", credential)
azToken, err = credential.GetToken(ctx, policy.TokenRequestOptions{
Scopes: []string{aksScope},
Scopes: []string{a.Scope},
})
if err != nil {
klog.ErrorS(err, "Failed to GetToken")
Expand Down