Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pkg/provisioner/terraform/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ func (s *TerraformStack) Up(blueprint *blueprintv1alpha1.Blueprint) error {
return fmt.Errorf("error removing backend override file for %s: %w", component.Path, err)
}
}

providersOverridePath := filepath.Join(component.FullPath, "providers_override.tf")
if _, err := s.shims.Stat(providersOverridePath); err == nil {
if err := s.shims.Remove(providersOverridePath); err != nil {
return fmt.Errorf("error removing providers override file for %s: %w", component.Path, err)
}
}
}

return nil
Expand Down
26 changes: 26 additions & 0 deletions pkg/provisioner/terraform/stack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,32 @@ func TestStack_Up(t *testing.T) {
t.Errorf("Expected remove error, got: %v", err)
}
})

t.Run("ErrorRemovingProvidersOverride", func(t *testing.T) {
stack, mocks := setup(t)
projectRoot := os.Getenv("WINDSOR_PROJECT_ROOT")
providersOverridePath := filepath.Join(projectRoot, ".windsor", "contexts", "local", "remote", "path", "providers_override.tf")
if err := os.MkdirAll(filepath.Dir(providersOverridePath), 0755); err != nil {
t.Fatalf("Failed to create directory: %v", err)
}
if err := os.WriteFile(providersOverridePath, []byte("test"), 0644); err != nil {
t.Fatalf("Failed to create providers override file: %v", err)
}
mocks.Shims.Remove = func(path string) error {
if strings.Contains(path, "providers_override.tf") {
return fmt.Errorf("remove error")
}
return nil
}
blueprint := createTestBlueprint()
err := stack.Up(blueprint)
if err == nil {
t.Error("Expected error, got nil")
}
if !strings.Contains(err.Error(), "error removing providers override file") {
t.Errorf("Expected remove error, got: %v", err)
}
})
}

func TestStack_Down(t *testing.T) {
Expand Down
20 changes: 10 additions & 10 deletions pkg/runtime/env/terraform_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,15 +546,15 @@ func (e *TerraformEnvPrinter) generateProvidersOverrideTf(directory ...string) e
}

azureClientSecret := e.shims.Getenv("AZURE_CLIENT_SECRET")
azureFederatedTokenFile := e.shims.Getenv("AZURE_FEDERATED_TOKEN_FILE")

if azureClientSecret == "" {
providersOverridePath := filepath.Join(currentPath, "providers_override.tf")
if _, err := e.shims.Stat(providersOverridePath); err == nil {
if err := e.shims.Remove(providersOverridePath); err != nil {
return fmt.Errorf("error removing providers_override.tf: %w", err)
}
}
return nil
var loginMode string
if azureFederatedTokenFile != "" {
loginMode = "workloadidentity"
} else if azureClientSecret != "" {
loginMode = "spn"
} else {
loginMode = "azurecli"
}

providerConfig := fmt.Sprintf(`provider "kubernetes" {
Expand All @@ -563,13 +563,13 @@ func (e *TerraformEnvPrinter) generateProvidersOverrideTf(directory ...string) e
command = "kubelogin"
args = [
"get-token",
"--login", "spn",
"--login", "%s",
"--environment", "%s",
"--server-id", "%s",
]
}
}
`, azureEnv, constants.DefaultAKSOIDCServerID)
`, loginMode, azureEnv, constants.DefaultAKSOIDCServerID)

providersOverridePath := filepath.Join(currentPath, "providers_override.tf")

Expand Down
182 changes: 158 additions & 24 deletions pkg/runtime/env/terraform_env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,100 @@ func TestTerraformEnv_generateProvidersOverrideTf(t *testing.T) {
return printer, mocks
}

t.Run("Success", func(t *testing.T) {
t.Run("SuccessWithWorkloadIdentity", func(t *testing.T) {
// Given a TerraformEnvPrinter with Azure + AKS enabled and AZURE_FEDERATED_TOKEN_FILE set
printer, mocks := setup(t)
mocks.ConfigHandler.Set("azure.enabled", true)
mocks.ConfigHandler.Set("cluster.driver", "aks")
mocks.Shims.Getenv = func(key string) string {
if key == "AZURE_FEDERATED_TOKEN_FILE" {
return "/path/to/token/file"
}
return ""
}

// Mock WriteFile to capture the output
var writtenData []byte
mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error {
writtenData = data
return nil
}

// When generateProvidersOverrideTf is called
err := printer.generateProvidersOverrideTf()

// Then no error should occur and the expected provider config with Workload Identity should be written
if err != nil {
t.Errorf("Expected no error, got %v", err)
}

expectedContent := fmt.Sprintf(`provider "kubernetes" {
exec {
api_version = "client.authentication.k8s.io/v1beta1"
command = "kubelogin"
args = [
"get-token",
"--login", "workloadidentity",
"--environment", "AzurePublicCloud",
"--server-id", "%s",
]
}
}
`, constants.DefaultAKSOIDCServerID)
if string(writtenData) != expectedContent {
t.Errorf("Expected provider config %q, got %q", expectedContent, string(writtenData))
}
})

t.Run("WorkloadIdentityPriorityOverSPN", func(t *testing.T) {
// Given a TerraformEnvPrinter with both AZURE_FEDERATED_TOKEN_FILE and AZURE_CLIENT_SECRET set
printer, mocks := setup(t)
mocks.ConfigHandler.Set("azure.enabled", true)
mocks.ConfigHandler.Set("cluster.driver", "aks")
mocks.Shims.Getenv = func(key string) string {
if key == "AZURE_FEDERATED_TOKEN_FILE" {
return "/path/to/token/file"
}
if key == "AZURE_CLIENT_SECRET" {
return "test-secret"
}
return ""
}

// Mock WriteFile to capture the output
var writtenData []byte
mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error {
writtenData = data
return nil
}

// When generateProvidersOverrideTf is called
err := printer.generateProvidersOverrideTf()

// Then no error should occur and Workload Identity should be used (higher priority)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}

expectedContent := fmt.Sprintf(`provider "kubernetes" {
exec {
api_version = "client.authentication.k8s.io/v1beta1"
command = "kubelogin"
args = [
"get-token",
"--login", "workloadidentity",
"--environment", "AzurePublicCloud",
"--server-id", "%s",
]
}
}
`, constants.DefaultAKSOIDCServerID)
if string(writtenData) != expectedContent {
t.Errorf("Expected provider config %q, got %q", expectedContent, string(writtenData))
}
})

t.Run("SuccessWithSPN", func(t *testing.T) {
// Given a TerraformEnvPrinter with Azure + AKS enabled and AZURE_CLIENT_SECRET set
printer, mocks := setup(t)
mocks.ConfigHandler.Set("azure.enabled", true)
Expand All @@ -949,7 +1042,7 @@ func TestTerraformEnv_generateProvidersOverrideTf(t *testing.T) {
// When generateProvidersOverrideTf is called
err := printer.generateProvidersOverrideTf()

// Then no error should occur and the expected provider config should be written
// Then no error should occur and the expected provider config with SPN should be written
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
Expand All @@ -972,6 +1065,48 @@ func TestTerraformEnv_generateProvidersOverrideTf(t *testing.T) {
}
})

t.Run("SuccessWithAzureCLI", func(t *testing.T) {
// Given a TerraformEnvPrinter with Azure + AKS enabled but no AZURE_CLIENT_SECRET (fallback to Azure CLI)
printer, mocks := setup(t)
mocks.ConfigHandler.Set("azure.enabled", true)
mocks.ConfigHandler.Set("cluster.driver", "aks")
mocks.Shims.Getenv = func(key string) string {
return ""
}

// Mock WriteFile to capture the output
var writtenData []byte
mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error {
writtenData = data
return nil
}

// When generateProvidersOverrideTf is called
err := printer.generateProvidersOverrideTf()

// Then no error should occur and the expected provider config with azurecli should be written
if err != nil {
t.Errorf("Expected no error, got %v", err)
}

expectedContent := fmt.Sprintf(`provider "kubernetes" {
exec {
api_version = "client.authentication.k8s.io/v1beta1"
command = "kubelogin"
args = [
"get-token",
"--login", "azurecli",
"--environment", "AzurePublicCloud",
"--server-id", "%s",
]
}
}
`, constants.DefaultAKSOIDCServerID)
if string(writtenData) != expectedContent {
t.Errorf("Expected provider config %q, got %q", expectedContent, string(writtenData))
}
})

t.Run("AzureNotEnabled", func(t *testing.T) {
// Given a TerraformEnvPrinter with Azure disabled
printer, mocks := setup(t)
Expand Down Expand Up @@ -1095,37 +1230,36 @@ func TestTerraformEnv_generateProvidersOverrideTf(t *testing.T) {
return ""
}

// Mock Stat and Remove to verify file deletion
fileExists := true
mocks.Shims.Stat = func(name string) (os.FileInfo, error) {
if strings.Contains(name, "providers_override.tf") {
if fileExists {
return nil, nil
}
return nil, os.ErrNotExist
}
return nil, os.ErrNotExist
}

var fileRemoved bool
mocks.Shims.Remove = func(name string) error {
if strings.Contains(name, "providers_override.tf") {
fileRemoved = true
fileExists = false
return nil
}
// Mock WriteFile to capture the output
var writtenData []byte
mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error {
writtenData = data
return nil
}

// When generateProvidersOverrideTf is called
err := printer.generateProvidersOverrideTf()

// Then no error should occur and the file should be removed
// Then no error should occur and the expected provider config with azurecli should be written
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if !fileRemoved {
t.Error("Expected providers_override.tf to be removed")

expectedContent := fmt.Sprintf(`provider "kubernetes" {
exec {
api_version = "client.authentication.k8s.io/v1beta1"
command = "kubelogin"
args = [
"get-token",
"--login", "azurecli",
"--environment", "AzurePublicCloud",
"--server-id", "%s",
]
}
}
`, constants.DefaultAKSOIDCServerID)
if string(writtenData) != expectedContent {
t.Errorf("Expected provider config %q, got %q", expectedContent, string(writtenData))
}
})

Expand Down
29 changes: 19 additions & 10 deletions pkg/runtime/tools/tools_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,25 @@ func (t *BaseToolsManager) checkKubelogin() error {
return fmt.Errorf("kubelogin version %s is below the minimum required version %s", version, constants.MinimumVersionKubelogin)
}

azureClientSecret := os.Getenv("AZURE_CLIENT_SECRET")
if azureClientSecret != "" {
azureClientID := os.Getenv("AZURE_CLIENT_ID")
azureTenantID := os.Getenv("AZURE_TENANT_ID")

if azureClientID == "" {
return fmt.Errorf("AZURE_CLIENT_SECRET is set but AZURE_CLIENT_ID is missing - both are required for SPN authentication")
}
if azureTenantID == "" {
return fmt.Errorf("AZURE_CLIENT_SECRET is set but AZURE_TENANT_ID is missing - both are required for SPN authentication")
validationRules := []struct {
triggerVar string
authMethod string
}{
{"AZURE_FEDERATED_TOKEN_FILE", "Workload Identity"},
{"AZURE_CLIENT_SECRET", "SPN"},
}

for _, rule := range validationRules {
if os.Getenv(rule.triggerVar) != "" {
azureClientID := os.Getenv("AZURE_CLIENT_ID")
azureTenantID := os.Getenv("AZURE_TENANT_ID")

if azureClientID == "" {
return fmt.Errorf("%s is set but AZURE_CLIENT_ID is missing - both are required for %s authentication", rule.triggerVar, rule.authMethod)
}
if azureTenantID == "" {
return fmt.Errorf("%s is set but AZURE_TENANT_ID is missing - both are required for %s authentication", rule.triggerVar, rule.authMethod)
}
}
}

Expand Down
Loading
Loading