diff --git a/pkg/provisioner/terraform/stack.go b/pkg/provisioner/terraform/stack.go index 25aef2fd5..41804b49d 100644 --- a/pkg/provisioner/terraform/stack.go +++ b/pkg/provisioner/terraform/stack.go @@ -159,6 +159,13 @@ func (s *TerraformStack) Up(blueprint *blueprintv1alpha1.Blueprint) error { return fmt.Errorf("error removing backend override file for %s: %w", component.Path, err) } } + + providersOverridePath := filepath.Join(component.FullPath, "providers_override.tf") + if _, err := s.shims.Stat(providersOverridePath); err == nil { + if err := s.shims.Remove(providersOverridePath); err != nil { + return fmt.Errorf("error removing providers override file for %s: %w", component.Path, err) + } + } } return nil diff --git a/pkg/provisioner/terraform/stack_test.go b/pkg/provisioner/terraform/stack_test.go index 94d4116ea..42510f0a8 100644 --- a/pkg/provisioner/terraform/stack_test.go +++ b/pkg/provisioner/terraform/stack_test.go @@ -444,6 +444,32 @@ func TestStack_Up(t *testing.T) { t.Errorf("Expected remove error, got: %v", err) } }) + + t.Run("ErrorRemovingProvidersOverride", func(t *testing.T) { + stack, mocks := setup(t) + projectRoot := os.Getenv("WINDSOR_PROJECT_ROOT") + providersOverridePath := filepath.Join(projectRoot, ".windsor", "contexts", "local", "remote", "path", "providers_override.tf") + if err := os.MkdirAll(filepath.Dir(providersOverridePath), 0755); err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + if err := os.WriteFile(providersOverridePath, []byte("test"), 0644); err != nil { + t.Fatalf("Failed to create providers override file: %v", err) + } + mocks.Shims.Remove = func(path string) error { + if strings.Contains(path, "providers_override.tf") { + return fmt.Errorf("remove error") + } + return nil + } + blueprint := createTestBlueprint() + err := stack.Up(blueprint) + if err == nil { + t.Fatal("Expected error, got nil") // Fatal, not Error: err.Error() below would panic on a nil err + } + if !strings.Contains(err.Error(), "error removing providers override file") { + t.Errorf("Expected 
remove error, got: %v", err) + } + }) } func TestStack_Down(t *testing.T) { diff --git a/pkg/runtime/env/terraform_env.go b/pkg/runtime/env/terraform_env.go index 356e5ff21..125f87be8 100644 --- a/pkg/runtime/env/terraform_env.go +++ b/pkg/runtime/env/terraform_env.go @@ -546,15 +546,15 @@ func (e *TerraformEnvPrinter) generateProvidersOverrideTf(directory ...string) e } azureClientSecret := e.shims.Getenv("AZURE_CLIENT_SECRET") + azureFederatedTokenFile := e.shims.Getenv("AZURE_FEDERATED_TOKEN_FILE") - if azureClientSecret == "" { - providersOverridePath := filepath.Join(currentPath, "providers_override.tf") - if _, err := e.shims.Stat(providersOverridePath); err == nil { - if err := e.shims.Remove(providersOverridePath); err != nil { - return fmt.Errorf("error removing providers_override.tf: %w", err) - } - } - return nil + var loginMode string + if azureFederatedTokenFile != "" { + loginMode = "workloadidentity" + } else if azureClientSecret != "" { + loginMode = "spn" + } else { + loginMode = "azurecli" } providerConfig := fmt.Sprintf(`provider "kubernetes" { @@ -563,13 +563,13 @@ func (e *TerraformEnvPrinter) generateProvidersOverrideTf(directory ...string) e command = "kubelogin" args = [ "get-token", - "--login", "spn", + "--login", "%s", "--environment", "%s", "--server-id", "%s", ] } } -`, azureEnv, constants.DefaultAKSOIDCServerID) +`, loginMode, azureEnv, constants.DefaultAKSOIDCServerID) providersOverridePath := filepath.Join(currentPath, "providers_override.tf") diff --git a/pkg/runtime/env/terraform_env_test.go b/pkg/runtime/env/terraform_env_test.go index 6f768a7e0..d528c3edd 100644 --- a/pkg/runtime/env/terraform_env_test.go +++ b/pkg/runtime/env/terraform_env_test.go @@ -927,7 +927,100 @@ func TestTerraformEnv_generateProvidersOverrideTf(t *testing.T) { return printer, mocks } - t.Run("Success", func(t *testing.T) { + t.Run("SuccessWithWorkloadIdentity", func(t *testing.T) { + // Given a TerraformEnvPrinter with Azure + AKS enabled and 
AZURE_FEDERATED_TOKEN_FILE set + printer, mocks := setup(t) + mocks.ConfigHandler.Set("azure.enabled", true) + mocks.ConfigHandler.Set("cluster.driver", "aks") + mocks.Shims.Getenv = func(key string) string { + if key == "AZURE_FEDERATED_TOKEN_FILE" { + return "/path/to/token/file" + } + return "" + } + + // Mock WriteFile to capture the output + var writtenData []byte + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { + writtenData = data + return nil + } + + // When generateProvidersOverrideTf is called + err := printer.generateProvidersOverrideTf() + + // Then no error should occur and the expected provider config with Workload Identity should be written + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + expectedContent := fmt.Sprintf(`provider "kubernetes" { + exec { + api_version = "client.authentication.k8s.io/v1beta1" + command = "kubelogin" + args = [ + "get-token", + "--login", "workloadidentity", + "--environment", "AzurePublicCloud", + "--server-id", "%s", + ] + } +} +`, constants.DefaultAKSOIDCServerID) + if string(writtenData) != expectedContent { + t.Errorf("Expected provider config %q, got %q", expectedContent, string(writtenData)) + } + }) + + t.Run("WorkloadIdentityPriorityOverSPN", func(t *testing.T) { + // Given a TerraformEnvPrinter with both AZURE_FEDERATED_TOKEN_FILE and AZURE_CLIENT_SECRET set + printer, mocks := setup(t) + mocks.ConfigHandler.Set("azure.enabled", true) + mocks.ConfigHandler.Set("cluster.driver", "aks") + mocks.Shims.Getenv = func(key string) string { + if key == "AZURE_FEDERATED_TOKEN_FILE" { + return "/path/to/token/file" + } + if key == "AZURE_CLIENT_SECRET" { + return "test-secret" + } + return "" + } + + // Mock WriteFile to capture the output + var writtenData []byte + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { + writtenData = data + return nil + } + + // When generateProvidersOverrideTf is called + err := 
printer.generateProvidersOverrideTf() + + // Then no error should occur and Workload Identity should be used (higher priority) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + expectedContent := fmt.Sprintf(`provider "kubernetes" { + exec { + api_version = "client.authentication.k8s.io/v1beta1" + command = "kubelogin" + args = [ + "get-token", + "--login", "workloadidentity", + "--environment", "AzurePublicCloud", + "--server-id", "%s", + ] + } +} +`, constants.DefaultAKSOIDCServerID) + if string(writtenData) != expectedContent { + t.Errorf("Expected provider config %q, got %q", expectedContent, string(writtenData)) + } + }) + + t.Run("SuccessWithSPN", func(t *testing.T) { // Given a TerraformEnvPrinter with Azure + AKS enabled and AZURE_CLIENT_SECRET set printer, mocks := setup(t) mocks.ConfigHandler.Set("azure.enabled", true) @@ -949,7 +1042,7 @@ func TestTerraformEnv_generateProvidersOverrideTf(t *testing.T) { // When generateProvidersOverrideTf is called err := printer.generateProvidersOverrideTf() - // Then no error should occur and the expected provider config should be written + // Then no error should occur and the expected provider config with SPN should be written if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -972,6 +1065,48 @@ func TestTerraformEnv_generateProvidersOverrideTf(t *testing.T) { } }) + t.Run("SuccessWithAzureCLI", func(t *testing.T) { + // Given a TerraformEnvPrinter with Azure + AKS enabled but no AZURE_CLIENT_SECRET (fallback to Azure CLI) + printer, mocks := setup(t) + mocks.ConfigHandler.Set("azure.enabled", true) + mocks.ConfigHandler.Set("cluster.driver", "aks") + mocks.Shims.Getenv = func(key string) string { + return "" + } + + // Mock WriteFile to capture the output + var writtenData []byte + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { + writtenData = data + return nil + } + + // When generateProvidersOverrideTf is called + err := 
printer.generateProvidersOverrideTf() + + // Then no error should occur and the expected provider config with azurecli should be written + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + expectedContent := fmt.Sprintf(`provider "kubernetes" { + exec { + api_version = "client.authentication.k8s.io/v1beta1" + command = "kubelogin" + args = [ + "get-token", + "--login", "azurecli", + "--environment", "AzurePublicCloud", + "--server-id", "%s", + ] + } +} +`, constants.DefaultAKSOIDCServerID) + if string(writtenData) != expectedContent { + t.Errorf("Expected provider config %q, got %q", expectedContent, string(writtenData)) + } + }) + t.Run("AzureNotEnabled", func(t *testing.T) { // Given a TerraformEnvPrinter with Azure disabled printer, mocks := setup(t) @@ -1095,37 +1230,36 @@ func TestTerraformEnv_generateProvidersOverrideTf(t *testing.T) { return "" } - // Mock Stat and Remove to verify file deletion - fileExists := true - mocks.Shims.Stat = func(name string) (os.FileInfo, error) { - if strings.Contains(name, "providers_override.tf") { - if fileExists { - return nil, nil - } - return nil, os.ErrNotExist - } - return nil, os.ErrNotExist - } - - var fileRemoved bool - mocks.Shims.Remove = func(name string) error { - if strings.Contains(name, "providers_override.tf") { - fileRemoved = true - fileExists = false - return nil - } + // Mock WriteFile to capture the output + var writtenData []byte + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { + writtenData = data return nil } // When generateProvidersOverrideTf is called err := printer.generateProvidersOverrideTf() - // Then no error should occur and the file should be removed + // Then no error should occur and the expected provider config with azurecli should be written if err != nil { t.Errorf("Expected no error, got %v", err) } - if !fileRemoved { - t.Error("Expected providers_override.tf to be removed") + + expectedContent := fmt.Sprintf(`provider 
"kubernetes" { + exec { + api_version = "client.authentication.k8s.io/v1beta1" + command = "kubelogin" + args = [ + "get-token", + "--login", "azurecli", + "--environment", "AzurePublicCloud", + "--server-id", "%s", + ] + } +} +`, constants.DefaultAKSOIDCServerID) + if string(writtenData) != expectedContent { + t.Errorf("Expected provider config %q, got %q", expectedContent, string(writtenData)) } }) diff --git a/pkg/runtime/tools/tools_manager.go b/pkg/runtime/tools/tools_manager.go index e8af66e43..9c550d342 100644 --- a/pkg/runtime/tools/tools_manager.go +++ b/pkg/runtime/tools/tools_manager.go @@ -258,16 +258,25 @@ func (t *BaseToolsManager) checkKubelogin() error { return fmt.Errorf("kubelogin version %s is below the minimum required version %s", version, constants.MinimumVersionKubelogin) } - azureClientSecret := os.Getenv("AZURE_CLIENT_SECRET") - if azureClientSecret != "" { - azureClientID := os.Getenv("AZURE_CLIENT_ID") - azureTenantID := os.Getenv("AZURE_TENANT_ID") - - if azureClientID == "" { - return fmt.Errorf("AZURE_CLIENT_SECRET is set but AZURE_CLIENT_ID is missing - both are required for SPN authentication") - } - if azureTenantID == "" { - return fmt.Errorf("AZURE_CLIENT_SECRET is set but AZURE_TENANT_ID is missing - both are required for SPN authentication") + validationRules := []struct { + triggerVar string + authMethod string + }{ + {"AZURE_FEDERATED_TOKEN_FILE", "Workload Identity"}, + {"AZURE_CLIENT_SECRET", "SPN"}, + } + + for _, rule := range validationRules { + if os.Getenv(rule.triggerVar) != "" { + azureClientID := os.Getenv("AZURE_CLIENT_ID") + azureTenantID := os.Getenv("AZURE_TENANT_ID") + + if azureClientID == "" { + return fmt.Errorf("%s is set but AZURE_CLIENT_ID is missing - both are required for %s authentication", rule.triggerVar, rule.authMethod) + } + if azureTenantID == "" { + return fmt.Errorf("%s is set but AZURE_TENANT_ID is missing - both are required for %s authentication", rule.triggerVar, rule.authMethod) + } } } 
diff --git a/pkg/runtime/tools/tools_manager_test.go b/pkg/runtime/tools/tools_manager_test.go index 4e46ebdf7..0c7ce9b91 100644 --- a/pkg/runtime/tools/tools_manager_test.go +++ b/pkg/runtime/tools/tools_manager_test.go @@ -983,6 +983,65 @@ func TestToolsManager_checkKubelogin(t *testing.T) { } }) + t.Run("AZURE_FEDERATED_TOKEN_FILESetButAZURE_CLIENT_IDMissing", func(t *testing.T) { + // Given AZURE_FEDERATED_TOKEN_FILE is set but AZURE_CLIENT_ID is missing + mocks, toolsManager := setup(t) + originalExecLookPath := execLookPath + execLookPath = func(name string) (string, error) { + if name == "kubelogin" { + return "/usr/bin/kubelogin", nil + } + return originalExecLookPath(name) + } + mocks.Shell.ExecSilentFunc = func(name string, args ...string) (string, error) { + if name == "kubelogin" && args[0] == "--version" { + return fmt.Sprintf("kubelogin version %s", constants.MinimumVersionKubelogin), nil + } + return "", fmt.Errorf("command not found") + } + os.Setenv("AZURE_FEDERATED_TOKEN_FILE", "/path/to/token") + defer os.Unsetenv("AZURE_FEDERATED_TOKEN_FILE") + os.Unsetenv("AZURE_CLIENT_ID") + os.Unsetenv("AZURE_TENANT_ID") + // When checking kubelogin + err := toolsManager.checkKubelogin() + // Then an error indicating AZURE_CLIENT_ID is missing should be returned + if err == nil || !strings.Contains(err.Error(), "AZURE_FEDERATED_TOKEN_FILE is set but AZURE_CLIENT_ID is missing") { + t.Errorf("Expected AZURE_CLIENT_ID missing error, got %v", err) + } + }) + + t.Run("AZURE_FEDERATED_TOKEN_FILESetButAZURE_TENANT_IDMissing", func(t *testing.T) { + // Given AZURE_FEDERATED_TOKEN_FILE is set but AZURE_TENANT_ID is missing + mocks, toolsManager := setup(t) + originalExecLookPath := execLookPath + execLookPath = func(name string) (string, error) { + if name == "kubelogin" { + return "/usr/bin/kubelogin", nil + } + return originalExecLookPath(name) + } + mocks.Shell.ExecSilentFunc = func(name string, args ...string) (string, error) { + if name == "kubelogin" && args[0] 
== "--version" { + return fmt.Sprintf("kubelogin version %s", constants.MinimumVersionKubelogin), nil + } + return "", fmt.Errorf("command not found") + } + os.Setenv("AZURE_FEDERATED_TOKEN_FILE", "/path/to/token") + os.Setenv("AZURE_CLIENT_ID", "test-client-id") + defer func() { + os.Unsetenv("AZURE_FEDERATED_TOKEN_FILE") + os.Unsetenv("AZURE_CLIENT_ID") + }() + os.Unsetenv("AZURE_TENANT_ID") + // When checking kubelogin + err := toolsManager.checkKubelogin() + // Then an error indicating AZURE_TENANT_ID is missing should be returned + if err == nil || !strings.Contains(err.Error(), "AZURE_FEDERATED_TOKEN_FILE is set but AZURE_TENANT_ID is missing") { + t.Errorf("Expected AZURE_TENANT_ID missing error, got %v", err) + } + }) + t.Run("AZURE_CLIENT_SECRETSetButAZURE_CLIENT_IDMissing", func(t *testing.T) { // Given AZURE_CLIENT_SECRET is set but AZURE_CLIENT_ID is missing mocks, toolsManager := setup(t)