From feaaced63c4a7cbfb71dd17e0f90f10ab1136e1c Mon Sep 17 00:00:00 2001 From: Ryan VanGundy <85766511+rmvangun@users.noreply.github.com> Date: Wed, 10 Dec 2025 00:38:23 -0500 Subject: [PATCH] feat(terraform): Generate kubelogin spn auth override for Azure AKS For Azure AKS, service principals are expected to use kubelogin to authenticate with an AKS cluster when accessing the cluster using Entra AD. Now, a `provider_override.tf` file gets generated when using AKS that includes appropriate values for properly using non-interactive kubelogin to authenticate with the cluster. Signed-off-by: Ryan VanGundy <85766511+rmvangun@users.noreply.github.com> --- pkg/composer/composer.go | 1 + pkg/composer/composer_test.go | 2 +- pkg/constants/constants.go | 10 + pkg/provisioner/terraform/stack.go | 4 + pkg/provisioner/terraform/stack_test.go | 26 ++ pkg/runtime/env/terraform_env.go | 91 +++++- pkg/runtime/env/terraform_env_test.go | 382 ++++++++++++++++++++++++ pkg/runtime/tools/tools_manager.go | 47 +++ pkg/runtime/tools/tools_manager_test.go | 215 +++++++++++++ 9 files changed, 776 insertions(+), 2 deletions(-) diff --git a/pkg/composer/composer.go b/pkg/composer/composer.go index dd16b51af..ec06f8339 100644 --- a/pkg/composer/composer.go +++ b/pkg/composer/composer.go @@ -176,6 +176,7 @@ func (r *Composer) generateGitignore() error { ".windsor/", ".volumes/", "terraform/**/backend_override.tf", + "terraform/**/providers_override.tf", "contexts/**/.kube/", "contexts/**/.talos/", "contexts/**/.omni/", diff --git a/pkg/composer/composer_test.go b/pkg/composer/composer_test.go index 75a4794bd..12339eccc 100644 --- a/pkg/composer/composer_test.go +++ b/pkg/composer/composer_test.go @@ -1114,7 +1114,7 @@ func TestComposer_generateGitignore(t *testing.T) { } contentStr := string(content) - requiredEntries := []string{".windsor/", ".volumes/", "terraform/**/backend_override.tf"} + requiredEntries := []string{".windsor/", ".volumes/", "terraform/**/backend_override.tf", "terraform/**/providers_override.tf"} for _, entry := range requiredEntries { if !strings.Contains(contentStr, entry) { t.Errorf("Expected .gitignore to contain %s", entry) diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 136f9a301..b9c0d762c 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -112,6 +112,16 @@ const MinimumVersion1Password = "2.15.0" const MinimumVersionAWSCLI = "2.15.0" +const MinimumVersionKubelogin = "0.1.7" + +// DefaultAKSOIDCServerID is the standard Azure AKS OIDC server ID (application ID of the +// Microsoft-managed enterprise application "Azure Kubernetes Service AAD Server"). +// This is the same for all AKS clusters with AKS-managed Azure AD enabled. +const DefaultAKSOIDCServerID = "6dae42f8-4368-4678-94ff-3960e28e3630" + +// DefaultAKSOIDCClientID is the standard Azure AKS OIDC client ID used for all AKS clusters. +const DefaultAKSOIDCClientID = "80faf920-1908-4b52-b5ef-a8e7bedfc67a" + const DefaultNodeHealthCheckTimeout = 5 * time.Minute const DefaultNodeHealthCheckPollInterval = 10 * time.Second diff --git a/pkg/provisioner/terraform/stack.go b/pkg/provisioner/terraform/stack.go index 3d8517cb2..25aef2fd5 100644 --- a/pkg/provisioner/terraform/stack.go +++ b/pkg/provisioner/terraform/stack.go @@ -247,6 +247,10 @@ func (s *TerraformStack) Down(blueprint *blueprintv1alpha1.Blueprint) error { if err := s.shims.Remove(filepath.Join(component.FullPath, "backend_override.tf")); err != nil && !os.IsNotExist(err) { return fmt.Errorf("error removing backend_override.tf from %s: %w", component.Path, err) } + + if err := s.shims.Remove(filepath.Join(component.FullPath, "providers_override.tf")); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("error removing providers_override.tf from %s: %w", component.Path, err) + } } return nil diff --git a/pkg/provisioner/terraform/stack_test.go b/pkg/provisioner/terraform/stack_test.go index 0a328604b..94d4116ea 100644 --- a/pkg/provisioner/terraform/stack_test.go +++ b/pkg/provisioner/terraform/stack_test.go @@ -697,6 +697,32 @@ func TestStack_Down(t *testing.T) { t.Errorf("Expected remove error, got: %v", err) } }) + + t.Run("ErrorRemovingProvidersOverride", func(t *testing.T) { + stack, mocks := setup(t) + projectRoot := os.Getenv("WINDSOR_PROJECT_ROOT") + providersOverridePath := filepath.Join(projectRoot, ".windsor", "contexts", "local", "remote", "path", "providers_override.tf") + if err := os.MkdirAll(filepath.Dir(providersOverridePath), 0755); err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + if err := os.WriteFile(providersOverridePath, []byte("test"), 0644); err != nil { + t.Fatalf("Failed to create providers override file: %v", err) + } + mocks.Shims.Remove = func(path string) error { + if strings.Contains(path, "providers_override.tf") { + return fmt.Errorf("remove error") + } + return nil + } + blueprint := createTestBlueprint() + err := stack.Down(blueprint) + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "error removing providers_override.tf") { + t.Errorf("Expected remove error, got: %v", err) + } + }) } func TestNewShims(t *testing.T) { diff --git a/pkg/runtime/env/terraform_env.go b/pkg/runtime/env/terraform_env.go index 457e16273..356e5ff21 100644 --- a/pkg/runtime/env/terraform_env.go +++ b/pkg/runtime/env/terraform_env.go @@ -15,6 +15,7 @@ import ( "github.com/goccy/go-yaml" blueprintv1alpha1 "github.com/windsorcli/cli/api/v1alpha1" + "github.com/windsorcli/cli/pkg/constants" "github.com/windsorcli/cli/pkg/runtime/config" "github.com/windsorcli/cli/pkg/runtime/shell" ) @@ -106,7 +107,10 @@ func (e *TerraformEnvPrinter) GetEnvVars() (map[string]string, error) { // PostEnvHook executes operations after setting the environment variables. func (e *TerraformEnvPrinter) PostEnvHook(directory ...string) error { - return e.generateBackendOverrideTf(directory...) + if err := e.generateBackendOverrideTf(directory...); err != nil { + return err + } + return e.generateProvidersOverrideTf(directory...) } // GenerateTerraformArgs constructs Terraform CLI arguments and environment variables for given project and module paths. @@ -492,6 +496,91 @@ func (e *TerraformEnvPrinter) generateBackendOverrideTf(directory ...string) err return nil } +// generateProvidersOverrideTf creates a providers_override.tf file when using Azure + AKS +// to configure Kubernetes provider authentication via Entra ID using kubelogin with SPN authentication. +// Detects SPN mode when AZURE_CLIENT_SECRET is set and validates required environment variables. +// This enables generic Kubernetes modules to work with AKS clusters using Entra AD authentication +// without requiring provider blocks in the modules themselves. +func (e *TerraformEnvPrinter) generateProvidersOverrideTf(directory ...string) error { + var currentPath string + if len(directory) > 0 { + currentPath = filepath.Clean(directory[0]) + } else { + var err error + currentPath, err = e.shims.Getwd() + if err != nil { + return fmt.Errorf("error getting current directory: %w", err) + } + } + + projectPath, err := e.findRelativeTerraformProjectPath(directory...) + if err != nil { + return fmt.Errorf("error finding project path: %w", err) + } + + if projectPath == "" { + return nil + } + + azureEnabled := e.configHandler.GetBool("azure.enabled", false) + clusterDriver := e.configHandler.GetString("cluster.driver", "") + + if !azureEnabled || clusterDriver != "aks" { + providersOverridePath := filepath.Join(currentPath, "providers_override.tf") + if _, err := e.shims.Stat(providersOverridePath); err == nil { + if err := e.shims.Remove(providersOverridePath); err != nil { + return fmt.Errorf("error removing providers_override.tf: %w", err) + } + } + return nil + } + + config := e.configHandler.GetConfig() + if config == nil || config.Azure == nil { + return nil + } + + azureEnv := "AzurePublicCloud" + if config.Azure.Environment != nil { + azureEnv = *config.Azure.Environment + } + + azureClientSecret := e.shims.Getenv("AZURE_CLIENT_SECRET") + + if azureClientSecret == "" { + providersOverridePath := filepath.Join(currentPath, "providers_override.tf") + if _, err := e.shims.Stat(providersOverridePath); err == nil { + if err := e.shims.Remove(providersOverridePath); err != nil { + return fmt.Errorf("error removing providers_override.tf: %w", err) + } + } + return nil + } + + providerConfig := fmt.Sprintf(`provider "kubernetes" { + exec { + api_version = "client.authentication.k8s.io/v1beta1" + command = "kubelogin" + args = [ + "get-token", + "--login", "spn", + "--environment", "%s", + "--server-id", "%s", + ] + } +} +`, azureEnv, constants.DefaultAKSOIDCServerID) + + providersOverridePath := filepath.Join(currentPath, "providers_override.tf") + + err = e.shims.WriteFile(providersOverridePath, []byte(providerConfig), os.ModePerm) + if err != nil { + return fmt.Errorf("error writing providers_override.tf: %w", err) + } + + return nil +} + // generateBackendConfigArgs constructs backend config args for terraform commands. // It reads the backend type from the config and adds relevant key-value pairs. // The function supports local, s3, kubernetes, and azurerm backends. diff --git a/pkg/runtime/env/terraform_env_test.go b/pkg/runtime/env/terraform_env_test.go index 5bb6d471d..6f768a7e0 100644 --- a/pkg/runtime/env/terraform_env_test.go +++ b/pkg/runtime/env/terraform_env_test.go @@ -10,6 +10,9 @@ import ( "testing" blueprintv1alpha1 "github.com/windsorcli/cli/api/v1alpha1" + v1alpha1 "github.com/windsorcli/cli/api/v1alpha1" + "github.com/windsorcli/cli/api/v1alpha1/azure" + "github.com/windsorcli/cli/pkg/constants" "github.com/windsorcli/cli/pkg/runtime/config" ) @@ -915,6 +918,385 @@ func TestTerraformEnv_generateBackendOverrideTf(t *testing.T) { }) } +func TestTerraformEnv_generateProvidersOverrideTf(t *testing.T) { + setup := func(t *testing.T) (*TerraformEnvPrinter, *EnvTestMocks) { + t.Helper() + mocks := setupTerraformEnvMocks(t) + printer := NewTerraformEnvPrinter(mocks.Shell, mocks.ConfigHandler) + printer.shims = mocks.Shims + return printer, mocks + } + + t.Run("Success", func(t *testing.T) { + // Given a TerraformEnvPrinter with Azure + AKS enabled and AZURE_CLIENT_SECRET set + printer, mocks := setup(t) + mocks.ConfigHandler.Set("azure.enabled", true) + mocks.ConfigHandler.Set("cluster.driver", "aks") + mocks.Shims.Getenv = func(key string) string { + if key == "AZURE_CLIENT_SECRET" { + return "test-secret" + } + return "" + } + + // Mock WriteFile to capture the output + var writtenData []byte + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { + writtenData = data + return nil + } + + // When generateProvidersOverrideTf is called + err := printer.generateProvidersOverrideTf() + + // Then no error should occur and the expected provider config should be written + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + expectedContent := fmt.Sprintf(`provider "kubernetes" { + exec { + api_version = "client.authentication.k8s.io/v1beta1" + command = "kubelogin" + args = [ + "get-token", + "--login", "spn", + "--environment", "AzurePublicCloud", + "--server-id", "%s", + ] + } +} +`, constants.DefaultAKSOIDCServerID) + if string(writtenData) != expectedContent { + t.Errorf("Expected provider config %q, got %q", expectedContent, string(writtenData)) + } + }) + + t.Run("AzureNotEnabled", func(t *testing.T) { + // Given a TerraformEnvPrinter with Azure disabled + printer, mocks := setup(t) + mocks.ConfigHandler.Set("azure.enabled", false) + mocks.ConfigHandler.Set("cluster.driver", "aks") + + // Mock Stat and Remove to verify file deletion + fileExists := true + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if strings.Contains(name, "providers_override.tf") { + if fileExists { + return nil, nil + } + return nil, os.ErrNotExist + } + return nil, os.ErrNotExist + } + + var fileRemoved bool + mocks.Shims.Remove = func(name string) error { + if strings.Contains(name, "providers_override.tf") { + fileRemoved = true + fileExists = false + return nil + } + return nil + } + + // When generateProvidersOverrideTf is called + err := printer.generateProvidersOverrideTf() + + // Then no error should occur and the file should be removed + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if !fileRemoved { + t.Error("Expected providers_override.tf to be removed") + } + }) + + t.Run("NotAKS", func(t *testing.T) { + // Given a TerraformEnvPrinter with Azure enabled but not AKS + printer, mocks := setup(t) + mocks.ConfigHandler.Set("azure.enabled", true) + mocks.ConfigHandler.Set("cluster.driver", "eks") + + // Mock Stat and Remove to verify file deletion + fileExists := true + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if strings.Contains(name, "providers_override.tf") { + if fileExists { + return nil, nil + } + return nil, os.ErrNotExist + } + return nil, os.ErrNotExist + } + + var fileRemoved bool + mocks.Shims.Remove = func(name string) error { + if strings.Contains(name, "providers_override.tf") { + fileRemoved = true + fileExists = false + return nil + } + return nil + } + + // When generateProvidersOverrideTf is called + err := printer.generateProvidersOverrideTf() + + // Then no error should occur and the file should be removed + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if !fileRemoved { + t.Error("Expected providers_override.tf to be removed") + } + }) + + t.Run("NoAzureConfig", func(t *testing.T) { + // Given a TerraformEnvPrinter with Azure enabled but no Azure config + printer, mocks := setup(t) + mocks.ConfigHandler.Set("azure.enabled", true) + mocks.ConfigHandler.Set("cluster.driver", "aks") + + // Mock GetConfig to return nil + mockConfigHandler := config.NewMockConfigHandler() + mockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { + return nil + } + mockConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + if key == "azure.enabled" { + return true + } + return false + } + mockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "cluster.driver" { + return "aks" + } + return "" + } + printer.configHandler = mockConfigHandler + + // When generateProvidersOverrideTf is called + err := printer.generateProvidersOverrideTf() + + // Then no error should occur + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + t.Run("AZURE_CLIENT_SECRETNotSet", func(t *testing.T) { + // Given a TerraformEnvPrinter with Azure + AKS enabled but no AZURE_CLIENT_SECRET + printer, mocks := setup(t) + mocks.ConfigHandler.Set("azure.enabled", true) + mocks.ConfigHandler.Set("cluster.driver", "aks") + mocks.Shims.Getenv = func(key string) string { + return "" + } + + // Mock Stat and Remove to verify file deletion + fileExists := true + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if strings.Contains(name, "providers_override.tf") { + if fileExists { + return nil, nil + } + return nil, os.ErrNotExist + } + return nil, os.ErrNotExist + } + + var fileRemoved bool + mocks.Shims.Remove = func(name string) error { + if strings.Contains(name, "providers_override.tf") { + fileRemoved = true + fileExists = false + return nil + } + return nil + } + + // When generateProvidersOverrideTf is called + err := printer.generateProvidersOverrideTf() + + // Then no error should occur and the file should be removed + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if !fileRemoved { + t.Error("Expected providers_override.tf to be removed") + } + }) + + t.Run("CustomAzureEnvironment", func(t *testing.T) { + // Given a TerraformEnvPrinter with custom Azure environment + printer, mocks := setup(t) + mocks.ConfigHandler.Set("azure.enabled", true) + mocks.ConfigHandler.Set("cluster.driver", "aks") + mocks.Shims.Getenv = func(key string) string { + if key == "AZURE_CLIENT_SECRET" { + return "test-secret" + } + return "" + } + + // Mock config with custom environment + mockConfigHandler := config.NewMockConfigHandler() + mockConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + if key == "azure.enabled" { + return true + } + return false + } + mockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "cluster.driver" { + return "aks" + } + return "" + } + azureEnv := "AzureUSGovernment" + mockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { + return &v1alpha1.Context{ + Azure: &azure.AzureConfig{ + Environment: &azureEnv, + }, + } + } + printer.configHandler = mockConfigHandler + + // Mock WriteFile to capture the output + var writtenData []byte + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { + writtenData = data + return nil + } + + // When generateProvidersOverrideTf is called + err := printer.generateProvidersOverrideTf() + + // Then no error should occur and the expected provider config should be written with custom environment + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + expectedContent := fmt.Sprintf(`provider "kubernetes" { + exec { + api_version = "client.authentication.k8s.io/v1beta1" + command = "kubelogin" + args = [ + "get-token", + "--login", "spn", + "--environment", "AzureUSGovernment", + "--server-id", "%s", + ] + } +} +`, constants.DefaultAKSOIDCServerID) + if string(writtenData) != expectedContent { + t.Errorf("Expected provider config %q, got %q", expectedContent, string(writtenData)) + } + }) + + t.Run("NoProjectPath", func(t *testing.T) { + // Given a TerraformEnvPrinter with no Terraform project path + printer, mocks := setup(t) + mocks.Shims.Glob = func(pattern string) ([]string, error) { + return nil, nil + } + + // When generateProvidersOverrideTf is called + err := printer.generateProvidersOverrideTf() + + // Then no error should occur + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + t.Run("ErrorGettingCurrentDirectory", func(t *testing.T) { + // Given a TerraformEnvPrinter with failing Getwd + printer, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "", fmt.Errorf("mock error getting current directory") + } + + // When generateProvidersOverrideTf is called + err := printer.generateProvidersOverrideTf() + + // Then an error should be returned + if err == nil { + t.Errorf("Expected error, got nil") + } + if !strings.Contains(err.Error(), "error getting current directory") { + t.Errorf("Expected error message to contain 'error getting current directory', got %v", err) + } + }) + + t.Run("ErrorWritingFile", func(t *testing.T) { + // Given a TerraformEnvPrinter with Azure + AKS enabled and AZURE_CLIENT_SECRET set + printer, mocks := setup(t) + mocks.ConfigHandler.Set("azure.enabled", true) + mocks.ConfigHandler.Set("cluster.driver", "aks") + mocks.Shims.Getenv = func(key string) string { + if key == "AZURE_CLIENT_SECRET" { + return "test-secret" + } + return "" + } + + // Mock WriteFile to return error + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { + return fmt.Errorf("mock error writing file") + } + + // When generateProvidersOverrideTf is called + err := printer.generateProvidersOverrideTf() + + // Then an error should be returned + if err == nil { + t.Errorf("Expected error, got nil") + } + if !strings.Contains(err.Error(), "error writing providers_override.tf") { + t.Errorf("Expected error message to contain 'error writing providers_override.tf', got %v", err) + } + }) + + t.Run("ErrorRemovingFile", func(t *testing.T) { + // Given a TerraformEnvPrinter with Azure not enabled and existing file + printer, mocks := setup(t) + mocks.ConfigHandler.Set("azure.enabled", false) + mocks.ConfigHandler.Set("cluster.driver", "aks") + + // Mock Stat to return file exists + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if strings.Contains(name, "providers_override.tf") { + return nil, nil + } + return nil, os.ErrNotExist + } + + // Mock Remove to return error + mocks.Shims.Remove = func(name string) error { + if strings.Contains(name, "providers_override.tf") { + return fmt.Errorf("mock error removing file") + } + return nil + } + + // When generateProvidersOverrideTf is called + err := printer.generateProvidersOverrideTf() + + // Then an error should be returned + if err == nil { + t.Errorf("Expected error, got nil") + } + if !strings.Contains(err.Error(), "error removing providers_override.tf") { + t.Errorf("Expected error message to contain 'error removing providers_override.tf', got %v", err) + } + }) +} + func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { setup := func(t *testing.T) (*TerraformEnvPrinter, *EnvTestMocks) { t.Helper() diff --git a/pkg/runtime/tools/tools_manager.go b/pkg/runtime/tools/tools_manager.go index a531f8ac8..e8af66e43 100644 --- a/pkg/runtime/tools/tools_manager.go +++ b/pkg/runtime/tools/tools_manager.go @@ -101,6 +101,14 @@ func (t *BaseToolsManager) Check() error { } } + if t.configHandler.GetBool("azure.enabled") { + if err := t.checkKubelogin(); err != nil { + spin.Stop() + fmt.Fprintf(os.Stderr, "\033[31m✗ %s - Failed\033[0m\n", message) + return fmt.Errorf("kubelogin check failed: %v", err) + } + } + spin.Stop() fmt.Fprintf(os.Stderr, "\033[32m✔\033[0m %s - \033[32mDone\033[0m\n", message) return nil @@ -227,6 +235,45 @@ func (t *BaseToolsManager) checkOnePassword() error { return nil } +// checkKubelogin ensures kubelogin is available in the system's PATH using execLookPath. +// It checks for 'kubelogin' in the system's PATH, verifies its version, and validates +// required environment variables for SPN authentication if AZURE_CLIENT_SECRET is set. +// Returns nil if found and meets the minimum version requirement, else an error indicating it is not available or outdated. +func (t *BaseToolsManager) checkKubelogin() error { + if _, err := execLookPath("kubelogin"); err != nil { + return fmt.Errorf("kubelogin is not available in the PATH") + } + + out, err := t.shell.ExecSilent("kubelogin", "--version") + if err != nil { + return fmt.Errorf("kubelogin is not available in the PATH") + } + + version := extractVersion(out) + if version == "" { + return fmt.Errorf("failed to extract kubelogin version") + } + + if compareVersion(version, constants.MinimumVersionKubelogin) < 0 { + return fmt.Errorf("kubelogin version %s is below the minimum required version %s", version, constants.MinimumVersionKubelogin) + } + + azureClientSecret := os.Getenv("AZURE_CLIENT_SECRET") + if azureClientSecret != "" { + azureClientID := os.Getenv("AZURE_CLIENT_ID") + azureTenantID := os.Getenv("AZURE_TENANT_ID") + + if azureClientID == "" { + return fmt.Errorf("AZURE_CLIENT_SECRET is set but AZURE_CLIENT_ID is missing - both are required for SPN authentication") + } + if azureTenantID == "" { + return fmt.Errorf("AZURE_CLIENT_SECRET is set but AZURE_TENANT_ID is missing - both are required for SPN authentication") + } + } + + return nil +} + // compareVersion is a helper function to compare two version strings. // It returns -1 if version1 < version2, 0 if version1 == version2, and 1 if version1 > version2. func compareVersion(version1, version2 string) int { diff --git a/pkg/runtime/tools/tools_manager_test.go b/pkg/runtime/tools/tools_manager_test.go index 36adad75c..4e46ebdf7 100644 --- a/pkg/runtime/tools/tools_manager_test.go +++ b/pkg/runtime/tools/tools_manager_test.go @@ -860,6 +860,221 @@ func TestToolsManager_checkOnePassword(t *testing.T) { }) } +// Tests for kubelogin version validation +func TestToolsManager_checkKubelogin(t *testing.T) { + setup := func(t *testing.T) (*Mocks, *BaseToolsManager) { + t.Helper() + mocks := setupMocks(t) + toolsManager := NewToolsManager(mocks.ConfigHandler, mocks.Shell) + return mocks, toolsManager + } + + t.Run("Success", func(t *testing.T) { + // Given kubelogin is available with correct version + mocks, toolsManager := setup(t) + originalExecLookPath := execLookPath + execLookPath = func(name string) (string, error) { + if name == "kubelogin" { + return "/usr/bin/kubelogin", nil + } + return originalExecLookPath(name) + } + mocks.Shell.ExecSilentFunc = func(name string, args ...string) (string, error) { + if name == "kubelogin" && args[0] == "--version" { + return fmt.Sprintf("kubelogin version %s", constants.MinimumVersionKubelogin), nil + } + return "", fmt.Errorf("command not found") + } + // When checking kubelogin version + err := toolsManager.checkKubelogin() + // Then no error should be returned + if err != nil { + t.Errorf("Expected checkKubelogin to succeed, but got error: %v", err) + } + }) + + t.Run("KubeloginNotAvailable", func(t *testing.T) { + // Given kubelogin is not found in PATH + _, toolsManager := setup(t) + originalExecLookPath := execLookPath + execLookPath = func(name string) (string, error) { + if name == "kubelogin" { + return "", exec.ErrNotFound + } + return originalExecLookPath(name) + } + // When checking kubelogin version + err := toolsManager.checkKubelogin() + // Then an error indicating kubelogin is not available should be returned + if err == nil || !strings.Contains(err.Error(), "kubelogin is not available in the PATH") { + t.Errorf("Expected kubelogin not available error, got %v", err) + } + }) + + t.Run("KubeloginVersionInvalidResponse", func(t *testing.T) { + // Given kubelogin version response is invalid + mocks, toolsManager := setup(t) + originalExecLookPath := execLookPath + execLookPath = func(name string) (string, error) { + if name == "kubelogin" { + return "/usr/bin/kubelogin", nil + } + return originalExecLookPath(name) + } + mocks.Shell.ExecSilentFunc = func(name string, args ...string) (string, error) { + if name == "kubelogin" && args[0] == "--version" { + return "Invalid version response", nil + } + return "", fmt.Errorf("command not found") + } + // When checking kubelogin version + err := toolsManager.checkKubelogin() + // Then an error indicating version extraction failed should be returned + if err == nil || !strings.Contains(err.Error(), "failed to extract kubelogin version") { + t.Errorf("Expected failed to extract kubelogin version error, got %v", err) + } + }) + + t.Run("KubeloginVersionTooLow", func(t *testing.T) { + // Given kubelogin version is below minimum required version + mocks, toolsManager := setup(t) + originalExecLookPath := execLookPath + execLookPath = func(name string) (string, error) { + if name == "kubelogin" { + return "/usr/bin/kubelogin", nil + } + return originalExecLookPath(name) + } + mocks.Shell.ExecSilentFunc = func(name string, args ...string) (string, error) { + if name == "kubelogin" && args[0] == "--version" { + return "kubelogin version 0.1.0", nil + } + return "", fmt.Errorf("command not found") + } + // When checking kubelogin version + err := toolsManager.checkKubelogin() + // Then an error indicating version is too low should be returned + if err == nil || !strings.Contains(err.Error(), "kubelogin version 0.1.0 is below the minimum required version") { + t.Errorf("Expected kubelogin version too low error, got %v", err) + } + }) + + t.Run("KubeloginCommandError", func(t *testing.T) { + // Given kubelogin command execution fails + mocks, toolsManager := setup(t) + originalExecLookPath := execLookPath + execLookPath = func(name string) (string, error) { + if name == "kubelogin" { + return "/usr/bin/kubelogin", nil + } + return originalExecLookPath(name) + } + mocks.Shell.ExecSilentFunc = func(name string, args ...string) (string, error) { + if name == "kubelogin" && args[0] == "--version" { + return "", fmt.Errorf("kubelogin is not available in the PATH") + } + return "", fmt.Errorf("command not found") + } + // When checking kubelogin version + err := toolsManager.checkKubelogin() + // Then an error indicating kubelogin is not available should be returned + if err == nil || !strings.Contains(err.Error(), "kubelogin is not available in the PATH") { + t.Errorf("Expected kubelogin is not available in the PATH error, got %v", err) + } + }) + + t.Run("AZURE_CLIENT_SECRETSetButAZURE_CLIENT_IDMissing", func(t *testing.T) { + // Given AZURE_CLIENT_SECRET is set but AZURE_CLIENT_ID is missing + mocks, toolsManager := setup(t) + originalExecLookPath := execLookPath + execLookPath = func(name string) (string, error) { + if name == "kubelogin" { + return "/usr/bin/kubelogin", nil + } + return originalExecLookPath(name) + } + mocks.Shell.ExecSilentFunc = func(name string, args ...string) (string, error) { + if name == "kubelogin" && args[0] == "--version" { + return fmt.Sprintf("kubelogin version %s", constants.MinimumVersionKubelogin), nil + } + return "", fmt.Errorf("command not found") + } + os.Setenv("AZURE_CLIENT_SECRET", "test-secret") + defer os.Unsetenv("AZURE_CLIENT_SECRET") + os.Unsetenv("AZURE_CLIENT_ID") + os.Unsetenv("AZURE_TENANT_ID") + // When checking kubelogin + err := toolsManager.checkKubelogin() + // Then an error indicating AZURE_CLIENT_ID is missing should be returned + if err == nil || !strings.Contains(err.Error(), "AZURE_CLIENT_SECRET is set but AZURE_CLIENT_ID is missing") { + t.Errorf("Expected AZURE_CLIENT_ID missing error, got %v", err) + } + }) + + t.Run("AZURE_CLIENT_SECRETSetButAZURE_TENANT_IDMissing", func(t *testing.T) { + // Given AZURE_CLIENT_SECRET is set but AZURE_TENANT_ID is missing + mocks, toolsManager := setup(t) + originalExecLookPath := execLookPath + execLookPath = func(name string) (string, error) { + if name == "kubelogin" { + return "/usr/bin/kubelogin", nil + } + return originalExecLookPath(name) + } + mocks.Shell.ExecSilentFunc = func(name string, args ...string) (string, error) { + if name == "kubelogin" && args[0] == "--version" { + return fmt.Sprintf("kubelogin version %s", constants.MinimumVersionKubelogin), nil + } + return "", fmt.Errorf("command not found") + } + os.Setenv("AZURE_CLIENT_SECRET", "test-secret") + os.Setenv("AZURE_CLIENT_ID", "test-client-id") + defer func() { + os.Unsetenv("AZURE_CLIENT_SECRET") + os.Unsetenv("AZURE_CLIENT_ID") + }() + os.Unsetenv("AZURE_TENANT_ID") + // When checking kubelogin + err := toolsManager.checkKubelogin() + // Then an error indicating AZURE_TENANT_ID is missing should be returned + if err == nil || !strings.Contains(err.Error(), "AZURE_CLIENT_SECRET is set but AZURE_TENANT_ID is missing") { + t.Errorf("Expected AZURE_TENANT_ID missing error, got %v", err) + } + }) + + t.Run("AZURE_CLIENT_SECRETSetWithAllRequiredVars", func(t *testing.T) { + // Given AZURE_CLIENT_SECRET is set with all required environment variables + mocks, toolsManager := setup(t) + originalExecLookPath := execLookPath + execLookPath = func(name string) (string, error) { + if name == "kubelogin" { + return "/usr/bin/kubelogin", nil + } + return originalExecLookPath(name) + } + mocks.Shell.ExecSilentFunc = func(name string, args ...string) (string, error) { + if name == "kubelogin" && args[0] == "--version" { + return fmt.Sprintf("kubelogin version %s", constants.MinimumVersionKubelogin), nil + } + return "", fmt.Errorf("command not found") + } + os.Setenv("AZURE_CLIENT_SECRET", "test-secret") + os.Setenv("AZURE_CLIENT_ID", "test-client-id") + os.Setenv("AZURE_TENANT_ID", "test-tenant-id") + defer func() { + os.Unsetenv("AZURE_CLIENT_SECRET") + os.Unsetenv("AZURE_CLIENT_ID") + os.Unsetenv("AZURE_TENANT_ID") + }() + // When checking kubelogin + err := toolsManager.checkKubelogin() + // Then no error should be returned + if err != nil { + t.Errorf("Expected checkKubelogin to succeed with all required env vars, but got error: %v", err) + } + }) +} + // ============================================================================= // Test Helpers // =============================================================================