diff --git a/cmd/authtoken/main.go b/cmd/authtoken/main.go index fe743e485..c3f94162c 100644 --- a/cmd/authtoken/main.go +++ b/cmd/authtoken/main.go @@ -50,15 +50,19 @@ func parseArgs() (interfaces.AuthTokenProvider, error) { _ = secretCmd.MarkFlagRequired("namespace") var clientID string + var scope string azureCmd := &cobra.Command{ Use: "azure", Args: cobra.NoArgs, Run: func(_ *cobra.Command, args []string) { - tokenProvider = azure.New(clientID) + tokenProvider = azure.New(clientID, scope) }, } azureCmd.Flags().StringVar(&clientID, "clientid", "", "Azure AAD client ID (required)") + // TODO: this scope argument is specific for Azure provider. We should allow registering and parsing provider specific argument + // in provider level, instead of global level. + azureCmd.Flags().StringVar(&scope, "scope", "", "Azure AAD token scope (optional)") _ = azureCmd.MarkFlagRequired("clientid") rootCmd.AddCommand(secretCmd, azureCmd) diff --git a/cmd/authtoken/main_test.go b/cmd/authtoken/main_test.go new file mode 100644 index 000000000..ad00526d0 --- /dev/null +++ b/cmd/authtoken/main_test.go @@ -0,0 +1,39 @@ +package main + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "go.goms.io/fleet/pkg/authtoken/providers/azure" +) + +func TestParseArgs(t *testing.T) { + t.Run("all arguments", func(t *testing.T) { + os.Args = []string{"refreshtoken", "azure", "--clientid=test-client-id", "--scope=test-scope"} + t.Cleanup(func() { + os.Args = nil + }) + tokenProvider, err := parseArgs() + assert.NotNil(t, tokenProvider) + assert.Nil(t, err) + + azTokenProvider, ok := tokenProvider.(*azure.AuthTokenProvider) + assert.Equal(t, true, ok) + assert.Equal(t, "test-scope", azTokenProvider.Scope) + }) + t.Run("no optional arguments", func(t *testing.T) { + os.Args = []string{"refreshtoken", "azure", "--clientid=test-client-id"} + t.Cleanup(func() { + os.Args = nil + }) + tokenProvider, err := parseArgs() + assert.NotNil(t, tokenProvider) + assert.Nil(t, err) + + azTokenProvider, ok := tokenProvider.(*azure.AuthTokenProvider) + assert.Equal(t, true, ok) + assert.Equal(t, "6dae42f8-4368-4678-94ff-3960e28e3630", azTokenProvider.Scope) + }) +} diff --git a/pkg/authtoken/providers/azure/azure_msi.go b/pkg/authtoken/providers/azure/azure_msi.go index c01d7f7a1..c481ec008 100644 --- a/pkg/authtoken/providers/azure/azure_msi.go +++ b/pkg/authtoken/providers/azure/azure_msi.go @@ -21,22 +21,27 @@ const ( aksScope = "6dae42f8-4368-4678-94ff-3960e28e3630" ) -type azureAuthTokenProvider struct { - clientID string +type AuthTokenProvider struct { + ClientID string + Scope string } -func New(clientID string) interfaces.AuthTokenProvider { - return &azureAuthTokenProvider{ - clientID: clientID, +func New(clientID, scope string) interfaces.AuthTokenProvider { + if scope == "" { + scope = aksScope + } + return &AuthTokenProvider{ + ClientID: clientID, + Scope: scope, } } // FetchToken gets a new token to make request to the associated fleet' hub cluster. -func (a *azureAuthTokenProvider) FetchToken(ctx context.Context) (interfaces.AuthToken, error) { +func (a *AuthTokenProvider) FetchToken(ctx context.Context) (interfaces.AuthToken, error) { token := interfaces.AuthToken{} - opts := &azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ClientID(a.clientID)} + opts := &azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ClientID(a.ClientID)} - klog.V(2).InfoS("FetchToken", "client ID", a.clientID) + klog.V(2).InfoS("FetchToken", "client ID", a.ClientID) credential, err := azidentity.NewManagedIdentityCredential(opts) if err != nil { return token, fmt.Errorf("failed to create managed identity cred: %w", err) @@ -48,7 +53,7 @@ func (a *azureAuthTokenProvider) FetchToken(ctx context.Context) (interfaces.Aut }, func() error { klog.V(2).InfoS("GetToken start", "credential", credential) azToken, err = credential.GetToken(ctx, policy.TokenRequestOptions{ - Scopes: []string{aksScope}, + Scopes: []string{a.Scope}, }) if err != nil { klog.ErrorS(err, "Failed to GetToken")