diff --git a/pkg/command/commandbuilder.go b/pkg/command/commandbuilder.go index 7b1452b0..ea02c12e 100644 --- a/pkg/command/commandbuilder.go +++ b/pkg/command/commandbuilder.go @@ -111,6 +111,10 @@ func appendCloudProviderOptions( "--cloud-provider", "azure-blob-storage") + if useDefaultAzureCredentials(ctx) { + break + } + if !credentials.Azure.InheritFromAzureAD { break } @@ -143,3 +147,26 @@ func appendCloudProviderOptions( return options, nil } + +type contextKey string + +// contextKeyUseDefaultAzureCredentials contains a bool indicating if the default azure credentials should be used +const contextKeyUseDefaultAzureCredentials contextKey = "useDefaultAzureCredentials" + +func useDefaultAzureCredentials(ctx context.Context) bool { + v := ctx.Value(contextKeyUseDefaultAzureCredentials) + if v == nil { + return false + } + result, ok := v.(bool) + if !ok { + return false + } + return result +} + +// ContextWithDefaultAzureCredentials create a context that contains the contextKeyUseDefaultAzureCredentials flag. +// When set to true barman-cloud will use the default Azure credentials. +func ContextWithDefaultAzureCredentials(ctx context.Context, enabled bool) context.Context { + return context.WithValue(ctx, contextKeyUseDefaultAzureCredentials, enabled) +} diff --git a/pkg/command/commandbuilder_test.go b/pkg/command/commandbuilder_test.go index 00d661c0..9f7e16e4 100644 --- a/pkg/command/commandbuilder_test.go +++ b/pkg/command/commandbuilder_test.go @@ -17,6 +17,7 @@ limitations under the License. package command import ( + "context" "strings" barmanApi "github.com/cloudnative-pg/barman-cloud/pkg/api" @@ -57,3 +58,24 @@ var _ = Describe("barmanCloudWalRestoreOptions", func() { )) }) }) + +var _ = Describe("useDefaultAzureCredentials", func() { + It("should be false by default", func(ctx SpecContext) { + Expect(useDefaultAzureCredentials(ctx)).To(BeFalse()) + }) + + It("should be false if ctx contains an invalid value", func(ctx SpecContext) { + newCtx := context.WithValue(ctx, contextKeyUseDefaultAzureCredentials, "invalidValue") + Expect(useDefaultAzureCredentials(newCtx)).To(BeFalse()) + }) + + It("should be false if ctx contains false value", func(ctx SpecContext) { + newCtx := context.WithValue(ctx, contextKeyUseDefaultAzureCredentials, false) + Expect(useDefaultAzureCredentials(newCtx)).To(BeFalse()) + }) + + It("should be true only if ctx contains true value", func(ctx SpecContext) { + newCtx := context.WithValue(ctx, contextKeyUseDefaultAzureCredentials, true) + Expect(useDefaultAzureCredentials(newCtx)).To(BeTrue()) + }) +})