diff --git a/.gitignore b/.gitignore index 391d230bd..49f5a8ed9 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,7 @@ windsor.yaml .windsor/ .volumes/ terraform/**/backend_override.tf +terraform/**/provider_override.tf contexts/**/.terraform/ contexts/**/.tfstate/ contexts/**/.kube/ diff --git a/pkg/di/mock_injector_test.go b/pkg/di/mock_injector_test.go index 2df77b893..76697abfe 100644 --- a/pkg/di/mock_injector_test.go +++ b/pkg/di/mock_injector_test.go @@ -92,3 +92,39 @@ func TestMockContainer_ResolveAll(t *testing.T) { } }) } + +func TestMockInjector_Resolve(t *testing.T) { + t.Run("Success", func(t *testing.T) { + // Given a new mock injector + injector := NewMockInjector() + + // And a mock service registered + mockService := &MockItemImpl{} + injector.Register("mockService", mockService) + + // When resolving the service by name + resolvedInstance := injector.Resolve("mockService") + + // Then the resolved instance should match the registered service + if resolvedInstance != mockService { + t.Fatalf("expected %v, got %v", mockService, resolvedInstance) + } + }) + + t.Run("ResolveError", func(t *testing.T) { + // Given a new mock injector + injector := NewMockInjector() + + // And a resolve error set for a specific service name + expectedError := errors.New("resolve error") + injector.SetResolveError("mockService", expectedError) + + // When resolving the service by name + resolvedInstance := injector.Resolve("mockService") + + // Then the resolved instance should be the expected error + if resolvedInstance != expectedError { + t.Fatalf("expected error %v, got %v", expectedError, resolvedInstance) + } + }) +} diff --git a/pkg/env/shims.go b/pkg/env/shims.go index 231d94560..f4019e902 100644 --- a/pkg/env/shims.go +++ b/pkg/env/shims.go @@ -59,3 +59,6 @@ var execLookPath = exec.LookPath // Define a variable for os.LookupEnv for easier testing var osLookupEnv = os.LookupEnv + +// Define a variable for os.Remove for easier testing +var osRemove = os.Remove diff --git a/pkg/env/terraform_env.go b/pkg/env/terraform_env.go index 0364f6968..75321824d 100644 --- a/pkg/env/terraform_env.go +++ b/pkg/env/terraform_env.go @@ -8,7 +8,11 @@ import ( "sort" "strings" + "github.com/hashicorp/hcl/v2/hclwrite" + "github.com/windsorcli/cli/pkg/constants" "github.com/windsorcli/cli/pkg/di" + svc "github.com/windsorcli/cli/pkg/services" + "github.com/zclconf/go-cty/cty" ) // TerraformEnvPrinter simulates a Terraform environment for testing purposes. @@ -93,9 +97,33 @@ func (e *TerraformEnvPrinter) GetEnvVars() (map[string]string, error) { return envVars, nil } -// PostEnvHook executes operations after setting the environment variables. +// PostEnvHook finalizes the environment setup by generating necessary override configurations +// if the current directory is within a Terraform project and Localstack is enabled. func (e *TerraformEnvPrinter) PostEnvHook() error { - return e.generateBackendOverrideTf() + currentPath, err := getwd() + if err != nil { + return fmt.Errorf("error getting current directory: %w", err) + } + + projectPath, err := findRelativeTerraformProjectPath() + if err != nil { + return fmt.Errorf("error finding Terraform project path: %w", err) + } + if projectPath == "" { + return nil + } + + if err := e.generateBackendOverrideTf(currentPath); err != nil { + return err + } + + if e.configHandler.GetBool("aws.localstack.enabled", false) { + if err := e.generateProviderOverrideTf(currentPath); err != nil { + return err + } + } + + return nil } // Print outputs the environment variables for the Terraform environment. @@ -120,47 +148,106 @@ func (e *TerraformEnvPrinter) getAlias() (map[string]string, error) { // generateBackendOverrideTf creates the backend_override.tf file for the project by determining // the backend type and writing the appropriate configuration to the file. -func (e *TerraformEnvPrinter) generateBackendOverrideTf() error { - currentPath, err := getwd() - if err != nil { - return fmt.Errorf("error getting current directory: %w", err) +func (e *TerraformEnvPrinter) generateBackendOverrideTf(projectPath string) error { + if projectPath == "" { + return nil } - projectPath, err := findRelativeTerraformProjectPath() + backendType := e.configHandler.GetString("terraform.backend.type", "local") + + backendOverridePath := filepath.Join(projectPath, "backend_override.tf") + backendConfig := fmt.Sprintf(`terraform { + backend "%s" {} +}`, backendType) + + err := writeFile(backendOverridePath, []byte(backendConfig), os.ModePerm) if err != nil { - return fmt.Errorf("error finding project path: %w", err) + return fmt.Errorf("error writing backend_override.tf: %w", err) } + return nil +} + +// generateProviderOverrideTf creates the provider_override.tf file for the project by determining +// the provider configuration and writing the appropriate configuration to the file. +func (e *TerraformEnvPrinter) generateProviderOverrideTf(projectPath string) error { if projectPath == "" { return nil } - contextConfig := e.configHandler.GetConfig() - backend := contextConfig.Terraform.Backend + overridePath := filepath.Join(projectPath, "provider_override.tf") - backendOverridePath := filepath.Join(currentPath, "backend_override.tf") - var backendConfig string + // Check if localstack is enabled + if !e.configHandler.GetBool("aws.localstack.enabled", false) { + // If localstack isn't enabled, delete provider_override.tf if it exists + if _, err := stat(overridePath); err == nil { + if err := osRemove(overridePath); err != nil { + return fmt.Errorf("error deleting provider_override.tf: %w", err) + } + } + return nil + } - switch backend.Type { - case "local": - backendConfig = fmt.Sprintf(`terraform { - backend "local" {} -}`) - case "s3": - backendConfig = fmt.Sprintf(`terraform { - backend "s3" {} -}`) - case "kubernetes": - backendConfig = fmt.Sprintf(`terraform { - backend "kubernetes" {} -}`) - default: - return fmt.Errorf("unsupported backend: %s", backend.Type) + region := e.configHandler.GetString("aws.region", "us-east-1") + + // Derive the AWS endpoint URL as done in AWSGenerator + service, ok := e.injector.Resolve("localstackService").(svc.Service) + if !ok { + return fmt.Errorf("localstackService not found") + } + tld := e.configHandler.GetString("dns.domain", "test") + fullName := service.GetName() + "." + tld + localstackPort := constants.DEFAULT_AWS_LOCALSTACK_PORT + localstackEndpoint := "http://" + fullName + ":" + localstackPort + + // Determine the list of AWS services to use + var awsServices []string + configuredAwsServices := e.configHandler.GetStringSlice("aws.localstack.services", nil) + if len(configuredAwsServices) > 0 { + awsServices = configuredAwsServices + } else { + awsServices = svc.ValidLocalstackServiceNames + } + + // Filter out invalid Terraform AWS service names + validAwsServices := make([]string, 0, len(awsServices)) + invalidServiceSet := make(map[string]struct{}, len(svc.InvalidTerraformAwsServiceNames)) + for _, invalidService := range svc.InvalidTerraformAwsServiceNames { + invalidServiceSet[invalidService] = struct{}{} + } + for _, awsService := range awsServices { + if _, isInvalid := invalidServiceSet[awsService]; !isInvalid { + validAwsServices = append(validAwsServices, awsService) + } + } + + // Create a new HCL file for the provider configuration + providerContent := hclwrite.NewEmptyFile() + body := providerContent.Body() + + // Append a new block for the provider "aws" + providerBlock := body.AppendNewBlock("provider", []string{"aws"}) + providerBody := providerBlock.Body() + + // Set provider attributes + providerBody.SetAttributeValue("access_key", cty.StringVal("test")) + providerBody.SetAttributeValue("secret_key", cty.StringVal("test")) + providerBody.SetAttributeValue("skip_credentials_validation", cty.BoolVal(true)) + providerBody.SetAttributeValue("skip_metadata_api_check", cty.BoolVal(true)) + providerBody.SetAttributeValue("skip_requesting_account_id", cty.BoolVal(true)) + providerBody.SetAttributeValue("region", cty.StringVal(region)) + + // Create a block for endpoints + endpointsBlock := providerBody.AppendNewBlock("endpoints", nil) + endpointsBody := endpointsBlock.Body() + for _, awsService := range validAwsServices { + endpointsBody.SetAttributeValue(awsService, cty.StringVal(localstackEndpoint)) } - err = writeFile(backendOverridePath, []byte(backendConfig), os.ModePerm) + // Write the provider configuration to the file + err := writeFile(overridePath, providerContent.Bytes(), os.ModePerm) if err != nil { - return fmt.Errorf("error writing backend_override.tf: %w", err) + return fmt.Errorf("error writing provider_override.tf: %w", err) } return nil @@ -171,20 +258,7 @@ func (e *TerraformEnvPrinter) generateBackendOverrideTf() error { // The function supports local, s3, and kubernetes backends. // It also includes backend.tfvars if present in the context directory. func (e *TerraformEnvPrinter) generateBackendConfigArgs(projectPath, configRoot string) ([]string, error) { - backend := e.configHandler.GetConfig().Terraform.Backend - backendType := e.configHandler.GetString("terraform.backend.type", "") - if backendType == "" { - switch { - case backend.S3 != nil: - backendType = "s3" - case backend.Kubernetes != nil: - backendType = "kubernetes" - case backend.Local != nil: - backendType = "local" - default: - backendType = "local" - } - } + backendType := e.configHandler.GetString("terraform.backend.type", "local") var backendConfigArgs []string @@ -206,20 +280,20 @@ func (e *TerraformEnvPrinter) generateBackendConfigArgs(projectPath, configRoot addBackendConfigArg("path", filepath.ToSlash(filepath.Join(configRoot, ".tfstate", projectPath, "terraform.tfstate"))) case "s3": addBackendConfigArg("key", filepath.ToSlash(filepath.Join(projectPath, "terraform.tfstate"))) - if backend.S3 != nil { - if err := processBackendConfig(backend.S3, addBackendConfigArg); err != nil { + if backend := e.configHandler.GetConfig().Terraform.Backend.S3; backend != nil { + if err := processBackendConfig(backend, addBackendConfigArg); err != nil { return nil, fmt.Errorf("error processing S3 backend config: %w", err) } } case "kubernetes": addBackendConfigArg("secret_suffix", sanitizeForK8s(projectPath)) - if backend.Kubernetes != nil { - if err := processBackendConfig(backend.Kubernetes, addBackendConfigArg); err != nil { + if backend := e.configHandler.GetConfig().Terraform.Backend.Kubernetes; backend != nil { + if err := processBackendConfig(backend, addBackendConfigArg); err != nil { return nil, fmt.Errorf("error processing Kubernetes backend config: %w", err) } } default: - return nil, fmt.Errorf("unsupported backend: %s", backend.Type) + return nil, fmt.Errorf("unsupported backend: %s", backendType) } return backendConfigArgs, nil diff --git a/pkg/env/terraform_env_test.go b/pkg/env/terraform_env_test.go index ed9984ed0..72af91ba8 100644 --- a/pkg/env/terraform_env_test.go +++ b/pkg/env/terraform_env_test.go @@ -13,7 +13,9 @@ import ( "github.com/windsorcli/cli/api/v1alpha1/aws" "github.com/windsorcli/cli/api/v1alpha1/terraform" "github.com/windsorcli/cli/pkg/config" + "github.com/windsorcli/cli/pkg/di" + "github.com/windsorcli/cli/pkg/services" "github.com/windsorcli/cli/pkg/shell" ) @@ -41,22 +43,82 @@ func setupSafeTerraformEnvMocks(injector ...di.Injector) *TerraformEnvMocks { return &v1alpha1.Context{ Terraform: &terraform.TerraformConfig{ Backend: &terraform.BackendConfig{ - Type: "local", + Type: "local", + Local: &terraform.LocalBackend{}, + S3: &terraform.S3Backend{}, + Kubernetes: &terraform.KubernetesBackend{}, + }, + }, + AWS: &aws.AWSConfig{ + Localstack: &aws.LocalstackConfig{ + Enabled: boolPtr(true), + Services: []string{"s3", "sns"}, }, + Region: stringPtr("us-east-1"), }, } } mockConfigHandler.GetContextFunc = func() string { return "mock-context" } + mockConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + switch key { + case "aws.localstack.enabled": + return true + case "aws.localstack.create": + return true + default: + if len(defaultValue) > 0 { + return defaultValue[0] + } + return false + } + } + mockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + switch key { + case "aws.region": + return "us-east-1" + case "dns.domain": + return "test" + default: + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + } + mockConfigHandler.GetStringSliceFunc = func(key string, defaultValue ...[]string) []string { + if key == "aws.localstack.services" { + return []string{"s3", "sns"} + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return nil + } mockInjector.Register("shell", mockShell) mockInjector.Register("configHandler", mockConfigHandler) + mockLocalstackService := services.NewMockService() + mockLocalstackService.GetNameFunc = func() string { + return "localstack" + } + mockInjector.Register("localstackService", mockLocalstackService) + stat = func(name string) (os.FileInfo, error) { return nil, nil } + // Mock os.Remove to simulate successful file removal + osRemove = func(name string) error { + // Simulate successful removal of provider_override.tf + if strings.Contains(name, "provider_override.tf") { + return nil + } + return fmt.Errorf("mock error removing file: %s", name) + } + return &TerraformEnvMocks{ Injector: mockInjector, Shell: mockShell, @@ -162,13 +224,14 @@ func TestTerraformEnv_GetEnvVars(t *testing.T) { }) t.Run("NoProjectPathFound", func(t *testing.T) { + mocks := setupSafeTerraformEnvMocks() + // Given a mocked getwd function returning a specific path originalGetwd := getwd defer func() { getwd = originalGetwd }() getwd = func() (string, error) { return filepath.FromSlash("/mock/project/root"), nil } - mocks := setupSafeTerraformEnvMocks() // When the GetEnvVars function is called terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) @@ -384,12 +447,12 @@ func TestTerraformEnv_PostEnvHook(t *testing.T) { } }) - t.Run("ErrorFindingProjectPath", func(t *testing.T) { - // Given a mocked glob function returning an error - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - return nil, fmt.Errorf("mock error finding project path") + t.Run("ErrorFindingRelativeTerraformProjectPath", func(t *testing.T) { + // Given a mocked findRelativeTerraformProjectPath function returning an error + originalFindRelativeTerraformProjectPath := findRelativeTerraformProjectPath + defer func() { findRelativeTerraformProjectPath = originalFindRelativeTerraformProjectPath }() + findRelativeTerraformProjectPath = func() (string, error) { + return "", fmt.Errorf("mock error finding Terraform project path") } // When the PostEnvHook function is called @@ -402,46 +465,29 @@ func TestTerraformEnv_PostEnvHook(t *testing.T) { if err == nil { t.Errorf("Expected error, got nil") } - if !strings.Contains(err.Error(), "error finding project path") { - t.Errorf("Expected error message to contain 'error finding project path', got %v", err) + expectedError := "error finding Terraform project path: mock error finding Terraform project path" + if err.Error() != expectedError { + t.Errorf("Expected error message to be '%s', got '%v'", expectedError, err.Error()) } }) - t.Run("UnsupportedBackend", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "unsupported", - }, - }, - } - } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("mock/project/root/terraform/project/path"), nil - } - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - return []string{filepath.FromSlash("mock/project/root/terraform/project/path/main.tf")}, nil + t.Run("NotInATerraformProject", func(t *testing.T) { + // Given a mocked findRelativeTerraformProjectPath function returning an empty string + originalFindRelativeTerraformProjectPath := findRelativeTerraformProjectPath + defer func() { findRelativeTerraformProjectPath = originalFindRelativeTerraformProjectPath }() + findRelativeTerraformProjectPath = func() (string, error) { + return "", nil } // When the PostEnvHook function is called + mocks := setupSafeTerraformEnvMocks() terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) terraformEnvPrinter.Initialize() err := terraformEnvPrinter.PostEnvHook() - // Then the error should contain the expected message - if err == nil { - t.Errorf("Expected error, got nil") - } - if !strings.Contains(err.Error(), "unsupported backend") { - t.Errorf("Expected error message to contain 'unsupported backend', got %v", err) + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) } }) @@ -608,14 +654,11 @@ func TestTerraformEnv_getAlias(t *testing.T) { mocks.ConfigHandler.GetContextFunc = func() string { return "local" } - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - AWS: &aws.AWSConfig{ - Localstack: &aws.LocalstackConfig{ - Enabled: boolPtr(false), - }, - }, + mocks.ConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + if key == "aws.localstack.enabled" { + return false } + return false } // When getAlias is called @@ -807,38 +850,17 @@ func TestTerraformEnv_sanitizeForK8s(t *testing.T) { func TestTerraformEnv_generateBackendOverrideTf(t *testing.T) { t.Run("Success", func(t *testing.T) { + // Use setupSafeTerraformEnvMocks to create mocks mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { - return filepath.FromSlash("/mock/config/root"), nil - } - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "local", - }, - }, - } - } - // Given a mocked getwd function simulating being in a terraform project root + // Mocked getwd function originalGetwd := getwd defer func() { getwd = originalGetwd }() getwd = func() (string, error) { return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil } - // And a mocked glob function simulating finding Terraform files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - expectedPattern := filepath.FromSlash("/mock/project/root/terraform/project/path/*.tf") - if pattern == expectedPattern { - return []string{filepath.FromSlash("/mock/project/root/terraform/project/path/main.tf")}, nil - } - return nil, nil - } - // And a mocked writeFile function to capture the output + // Mocked writeFile function to capture the output var writtenData []byte originalWriteFile := writeFile defer func() { writeFile = originalWriteFile }() @@ -850,206 +872,55 @@ func TestTerraformEnv_generateBackendOverrideTf(t *testing.T) { // When generateBackendOverrideTf is called terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.generateBackendOverrideTf() + err := terraformEnvPrinter.generateBackendOverrideTf("project/path") // Then no error should occur and the expected backend config should be written if err != nil { t.Errorf("Expected no error, got %v", err) } - expectedContent := `terraform { - backend "local" {} -}` + expectedContent := "terraform {\n backend \"local\" {}\n}" if string(writtenData) != expectedContent { t.Errorf("Expected backend config %q, got %q", expectedContent, string(writtenData)) } }) - t.Run("S3Backend", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "s3", - }, - }, - } - } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil - } - // And a mocked glob function simulating finding Terraform files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - if pattern == filepath.FromSlash("/mock/project/root/terraform/project/path/*.tf") { - return []string{filepath.FromSlash("/mock/project/root/terraform/project/path/main.tf")}, nil - } - return nil, nil - } - - // And a mocked writeFile function to capture the output - var writtenData []byte + t.Run("NoProjectPath", func(t *testing.T) { + // Mock writeFile to ensure it never gets called originalWriteFile := writeFile defer func() { writeFile = originalWriteFile }() writeFile = func(filename string, data []byte, perm os.FileMode) error { - writtenData = data + t.Errorf("writeFile should not be called") return nil } - // When generateBackendOverrideTf is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.generateBackendOverrideTf() - - // Then no error should occur and the expected backend config should be written - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - expectedContent := `terraform { - backend "s3" {} -}` - if string(writtenData) != expectedContent { - t.Errorf("Expected backend config %q, got %q", expectedContent, string(writtenData)) + if err := NewTerraformEnvPrinter(setupSafeTerraformEnvMocks().Injector).generateBackendOverrideTf(""); err != nil { + t.Errorf("Expected nil, got %v", err) } }) - t.Run("KubernetesBackend", func(t *testing.T) { + t.Run("ErrorHandling", func(t *testing.T) { + // Use setupSafeTerraformEnvMocks to create mocks mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "kubernetes", - }, - }, - } - } - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil - } - // And a mocked glob function simulating finding Terraform files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - if pattern == filepath.FromSlash("/mock/project/root/terraform/project/path/*.tf") { - return []string{filepath.FromSlash("/mock/project/root/terraform/project/path/main.tf")}, nil - } - return nil, nil - } - - // And a mocked writeFile function to capture the output - var writtenData []byte + // Mocked writeFile function to simulate an error originalWriteFile := writeFile defer func() { writeFile = originalWriteFile }() writeFile = func(filename string, data []byte, perm os.FileMode) error { - writtenData = data - return nil - } - - // When generateBackendOverrideTf is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.generateBackendOverrideTf() - - // Then no error should occur and the expected backend config should be written - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - expectedContent := `terraform { - backend "kubernetes" {} -}` - if string(writtenData) != expectedContent { - t.Errorf("Expected backend config %q, got %q", expectedContent, string(writtenData)) - } - }) - - t.Run("UnsupportedBackend", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "unsupported", - }, - }, - } - } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil - } - // And a mocked glob function simulating finding Terraform files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - if pattern == filepath.FromSlash("/mock/project/root/terraform/project/path/*.tf") { - return []string{filepath.FromSlash("/mock/project/root/terraform/project/path/main.tf")}, nil - } - return nil, nil + return fmt.Errorf("mock error writing backend_override.tf file") } // When generateBackendOverrideTf is called terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.generateBackendOverrideTf() + err := terraformEnvPrinter.generateBackendOverrideTf("project/path") - // Then the error should contain the expected message + // Then an error should occur if err == nil { t.Errorf("Expected error, got nil") } - if !strings.Contains(err.Error(), "unsupported backend: unsupported") { - t.Errorf("Expected error message to contain 'unsupported backend: unsupported', got %v", err) - } - }) - - t.Run("NoTerraformFiles", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "local", - }, - }, - } - } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil - } - // And a mocked glob function simulating no Terraform files found - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - return nil, nil - } - - // When generateBackendOverrideTf is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.generateBackendOverrideTf() - - // Then no error should occur - if err != nil { - t.Errorf("Expected no error, got %v", err) + if !strings.Contains(err.Error(), "mock error writing backend_override.tf file") { + t.Errorf("Expected error message to contain 'mock error writing backend_override.tf file', got %v", err) } }) } @@ -1121,8 +992,6 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { S3: &terraform.S3Backend{ Bucket: stringPtr("mock-bucket"), Region: stringPtr("mock-region"), - AccessKey: stringPtr("mock-access-key"), - SecretKey: stringPtr("mock-secret-key"), MaxRetries: intPtr(5), SkipCredentialsValidation: boolPtr(true), }, @@ -1130,6 +999,15 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { }, } } + mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "terraform.backend.type" { + return "s3" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) terraformEnvPrinter.Initialize() @@ -1144,11 +1022,9 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { expectedArgs := []string{ fmt.Sprintf(`-backend-config="%s"`, filepath.ToSlash(filepath.Join(configRoot, "terraform", "backend.tfvars"))), `-backend-config="key=project/path/terraform.tfstate"`, - `-backend-config="access_key=mock-access-key"`, `-backend-config="bucket=mock-bucket"`, `-backend-config="max_retries=5"`, `-backend-config="region=mock-region"`, - `-backend-config="secret_key=mock-secret-key"`, `-backend-config="skip_credentials_validation=true"`, } @@ -1171,6 +1047,15 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { }, } } + mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "terraform.backend.type" { + return "kubernetes" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) terraformEnvPrinter.Initialize() @@ -1222,15 +1107,11 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { t.Run("ErrorMarshallingBackendConfig", func(t *testing.T) { mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "s3", - S3: &terraform.S3Backend{}, - }, - }, + mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "terraform.backend.type" { + return "s3" } + return "" } // Mock yamlMarshal to return an error @@ -1269,6 +1150,16 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { } } + mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "terraform.backend.type" { + return "kubernetes" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + // Mock processBackendConfig to return an error originalProcessBackendConfig := processBackendConfig defer func() { processBackendConfig = originalProcessBackendConfig }() @@ -1401,3 +1292,207 @@ func TestTerraformEnv_processBackendConfig(t *testing.T) { } }) } + +func TestTerraformEnv_generateProviderOverrideTf(t *testing.T) { + t.Run("NoProjectPath", func(t *testing.T) { + // Mock writeFile to ensure it never gets called + originalWriteFile := writeFile + defer func() { writeFile = originalWriteFile }() + writeFile = func(filename string, data []byte, perm os.FileMode) error { + t.Errorf("writeFile should not be called") + return nil + } + + // Given a TerraformEnvPrinter with no project path + terraformEnvPrinter := NewTerraformEnvPrinter(setupSafeTerraformEnvMocks().Injector) + + // When generateProviderOverrideTf is called with an empty project path + err := terraformEnvPrinter.generateProviderOverrideTf("") + + // Then no error should occur + if err != nil { + t.Errorf("Expected nil, got %v", err) + } + }) + + t.Run("LocalstackEnabled", func(t *testing.T) { + mocks := setupSafeTerraformEnvMocks() + + // Given a mocked writeFile function to capture the output + var writtenData []byte + originalWriteFile := writeFile + defer func() { writeFile = originalWriteFile }() + writeFile = func(filename string, data []byte, perm os.FileMode) error { + writtenData = data + return nil + } + + // When generateProviderOverrideTf is called + terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) + terraformEnvPrinter.Initialize() + err := terraformEnvPrinter.generateProviderOverrideTf("project/path") + + // Then no error should occur and the provider config should be validated + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + // Validate the returned provider config structure + providerConfig := string(writtenData) + if !strings.Contains(providerConfig, `provider "aws"`) { + t.Errorf("Expected provider config to contain 'provider \"aws\"', got %q", providerConfig) + } + if !strings.Contains(providerConfig, `endpoints {`) { + t.Errorf("Expected provider config to contain 'endpoints {', got %q", providerConfig) + } + if !strings.Contains(providerConfig, `s3 = "http://localstack.test:4566"`) { + t.Errorf("Expected provider config to contain 's3 = \"http://localstack.test:4566\"', got %q", providerConfig) + } + if !strings.Contains(providerConfig, `sns = "http://localstack.test:4566"`) { + t.Errorf("Expected provider config to contain 'sns = \"http://localstack.test:4566\"', got %q", providerConfig) + } + }) + + t.Run("LocalstackDisabled", func(t *testing.T) { + mocks := setupSafeTerraformEnvMocks() + mocks.ConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + if key == "aws.localstack.enabled" { + return false + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return false + } + + // Given a mocked writeFile function to capture the output + var writtenData []byte + originalWriteFile := writeFile + defer func() { writeFile = originalWriteFile }() + writeFile = func(filename string, data []byte, perm os.FileMode) error { + writtenData = data + return nil + } + + // When generateProviderOverrideTf is called + terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) + terraformEnvPrinter.Initialize() + err := terraformEnvPrinter.generateProviderOverrideTf("project/path") + + // Then no error should occur and no provider config should be written + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if len(writtenData) != 0 { + t.Errorf("Expected no provider config to be written, got %q", string(writtenData)) + } + }) + + t.Run("ErrorRemovingProviderOverrideTf", func(t *testing.T) { + mocks := setupSafeTerraformEnvMocks() + mocks.ConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + if key == "aws.localstack.enabled" { + return false + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return false + } + + // Mock osRemove to simulate an error + originalOsRemove := osRemove + defer func() { osRemove = originalOsRemove }() + osRemove = func(name string) error { + return fmt.Errorf("mock error removing provider_override.tf") + } + + // When generateProviderOverrideTf is called + terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) + terraformEnvPrinter.Initialize() + err := terraformEnvPrinter.generateProviderOverrideTf("project/path") + + // Then an error should occur + if err == nil { + t.Errorf("Expected error, got nil") + } + if !strings.Contains(err.Error(), "mock error removing provider_override.tf") { + t.Errorf("Expected error message to contain 'mock error removing provider_override.tf', got %v", err) + } + }) + + t.Run("ErrorResolvingLocalstackService", func(t *testing.T) { + mocks := setupSafeTerraformEnvMocks() + mocks.Injector.Register("localstackService", nil) + mocks.ConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + return true + } + + terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) + terraformEnvPrinter.Initialize() + err := terraformEnvPrinter.generateProviderOverrideTf("project/path") + + if err == nil { + t.Errorf("Expected error, got nil") + } + if !strings.Contains(err.Error(), "localstackService not found") { + t.Errorf("Expected error message to contain 'localstackService not found', got %v", err) + } + }) + + t.Run("UsesAllLocalstackServicesByDefault", func(t *testing.T) { + mocks := setupSafeTerraformEnvMocks() + mocks.ConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + if key == "aws.localstack.enabled" { + return true + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return false + } + mocks.ConfigHandler.GetStringSliceFunc = func(key string, defaultValue ...[]string) []string { + if key == "aws.localstack.services" { + return nil + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return nil + } + + terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) + terraformEnvPrinter.Initialize() + err := terraformEnvPrinter.generateProviderOverrideTf("project/path") + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + t.Run("ErrorWritingProviderOverrideTf", func(t *testing.T) { + mocks := setupSafeTerraformEnvMocks() + mocks.ConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + return true + } + + // Mocked writeFile function to simulate an error + originalWriteFile := writeFile + defer func() { writeFile = originalWriteFile }() + writeFile = func(filename string, data []byte, perm os.FileMode) error { + return fmt.Errorf("mock error writing provider_override.tf file") + } + + terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) + terraformEnvPrinter.Initialize() + err := terraformEnvPrinter.generateProviderOverrideTf("project/path") + + if err == nil { + t.Errorf("Expected error, got nil") + } + if !strings.Contains(err.Error(), "mock error writing provider_override.tf file") { + t.Errorf("Expected error message to contain 'mock error writing provider_override.tf file', got %v", err) + } + }) +} diff --git a/pkg/generators/git_generator.go b/pkg/generators/git_generator.go index 5261e46dd..4b1f2d905 100644 --- a/pkg/generators/git_generator.go +++ b/pkg/generators/git_generator.go @@ -15,6 +15,7 @@ var gitIgnoreLines = []string{ ".windsor/", ".volumes/", "terraform/**/backend_override.tf", + "terraform/**/provider_override.tf", "contexts/**/.terraform/", "contexts/**/.tfstate/", "contexts/**/.kube/", diff --git a/pkg/generators/git_generator_test.go b/pkg/generators/git_generator_test.go index 2bb353202..b6f6fb0f1 100644 --- a/pkg/generators/git_generator_test.go +++ b/pkg/generators/git_generator_test.go @@ -18,6 +18,7 @@ const ( .windsor/ .volumes/ terraform/**/backend_override.tf +terraform/**/provider_override.tf contexts/**/.terraform/ contexts/**/.tfstate/ contexts/**/.kube/ diff --git a/pkg/services/localstack_service.go b/pkg/services/localstack_service.go index b031c978d..f53716896 100644 --- a/pkg/services/localstack_service.go +++ b/pkg/services/localstack_service.go @@ -11,6 +11,20 @@ import ( "github.com/windsorcli/cli/pkg/di" ) +// Valid AWS service names that use the same endpoint +var ValidLocalstackServiceNames = []string{ + "acm", "apigateway", "cloudformation", "cloudwatch", "config", "dynamodb", "dynamodbstreams", + "ec2", "es", "events", "firehose", "iam", "kinesis", "kms", "lambda", "logs", "opensearch", + "redshift", "resource-groups", "resourcegroupstaggingapi", "route53", "route53resolver", "s3", + "s3control", "scheduler", "secretsmanager", "ses", "sns", "sqs", "ssm", "stepfunctions", "sts", + "support", "swf", "transcribe", +} + +// Invalid Terraform AWS service names that do not get an endpoint configuration +var InvalidTerraformAwsServiceNames = []string{ + "dynamodbstreams", "resource-groups", "support", "logs", "opensearch", "scheduler", +} + // LocalstackService is a service struct that provides Localstack-specific utility functions type LocalstackService struct { BaseService @@ -42,7 +56,12 @@ func (s *LocalstackService) GetComposeConfig() (*types.Config, error) { servicesList := "" if contextConfig.AWS.Localstack.Services != nil { - servicesList = strings.Join(contextConfig.AWS.Localstack.Services, ",") + services := s.configHandler.GetStringSlice("aws.localstack.services", []string{}) + validServices, invalidServices := validateServices(services) + if len(invalidServices) > 0 { + return nil, fmt.Errorf("invalid services found: %s", strings.Join(invalidServices, ", ")) + } + servicesList = strings.Join(validServices, ",") } tld := s.configHandler.GetString("dns.domain", "test") @@ -50,6 +69,7 @@ func (s *LocalstackService) GetComposeConfig() (*types.Config, error) { port, err := strconv.ParseUint(constants.DEFAULT_AWS_LOCALSTACK_PORT, 10, 32) if err != nil { + // Can't test this error until the port is configurable return nil, fmt.Errorf("invalid port format: %w", err) } port32 := uint32(port) @@ -88,5 +108,24 @@ func (s *LocalstackService) GetComposeConfig() (*types.Config, error) { return &types.Config{Services: services}, nil } +// validateServices checks the input services and returns valid and invalid services. +func validateServices(services []string) ([]string, []string) { + validServicesMap := make(map[string]struct{}, len(ValidLocalstackServiceNames)) + for _, serviceName := range ValidLocalstackServiceNames { + validServicesMap[serviceName] = struct{}{} + } + + var validServices []string + var invalidServices []string + for _, service := range services { + if _, exists := validServicesMap[service]; exists { + validServices = append(validServices, service) + } else { + invalidServices = append(invalidServices, service) + } + } + return validServices, invalidServices +} + // Ensure LocalstackService implements Service interface var _ Service = (*LocalstackService)(nil) diff --git a/pkg/services/localstack_service_test.go b/pkg/services/localstack_service_test.go index cdb38e967..e31939a92 100644 --- a/pkg/services/localstack_service_test.go +++ b/pkg/services/localstack_service_test.go @@ -3,6 +3,7 @@ package services import ( "os" "path/filepath" + "strings" "testing" "github.com/windsorcli/cli/api/v1alpha1" @@ -50,6 +51,29 @@ func createLocalstackServiceMocks(mockInjector ...di.Injector) *LocalstackServic mockConfigHandler.SetContextFunc = func(context string) error { return nil } mockConfigHandler.GetConfigRootFunc = func() (string, error) { return filepath.FromSlash("/mock/config/root"), nil } + // Mock GetConfig to return a valid Localstack configuration with SERVICES set + mockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { + return &v1alpha1.Context{ + AWS: &aws.AWSConfig{ + Localstack: &aws.LocalstackConfig{ + Enabled: ptrBool(true), + Services: []string{"s3", "dynamodb"}, + }, + }, + } + } + + // Mock GetStringSlice to return a list of services for Localstack + mockConfigHandler.GetStringSliceFunc = func(key string, defaultValue ...[]string) []string { + if key == "aws.localstack.services" { + return []string{"s3", "dynamodb"} + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return nil + } + // Register mocks in the injector injector.Register("configHandler", mockConfigHandler) injector.Register("shell", mockShell) @@ -66,18 +90,6 @@ func TestLocalstackService_GetComposeConfig(t *testing.T) { // Create mock injector with necessary mocks mocks := createLocalstackServiceMocks() - // Mock GetConfig to return a valid Localstack configuration - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - AWS: &aws.AWSConfig{ - Localstack: &aws.LocalstackConfig{ - Enabled: ptrBool(true), - Services: []string{"s3", "dynamodb"}, - }, - }, - } - } - // Create an instance of LocalstackService localstackService := NewLocalstackService(mocks.Injector) @@ -150,4 +162,41 @@ func TestLocalstackService_GetComposeConfig(t *testing.T) { t.Errorf("expected service to have LOCALSTACK_AUTH_TOKEN environment variable, got %v", service.Environment["LOCALSTACK_AUTH_TOKEN"]) } }) + + t.Run("InvalidServicesDetected", func(t *testing.T) { + // Create mock injector with necessary mocks + mocks := createLocalstackServiceMocks() + + // Mock GetStringSlice to return an invalid Localstack configuration + mocks.ConfigHandler.GetStringSliceFunc = func(key string, defaultValue ...[]string) []string { + if key == "aws.localstack.services" { + return []string{"invalidService"} + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return nil + } + + // Create an instance of LocalstackService + localstackService := NewLocalstackService(mocks.Injector) + + // Initialize the service + if err := localstackService.Initialize(); err != nil { + t.Fatalf("Initialize() error = %v", err) + } + + // When: GetComposeConfig is called + _, err := localstackService.GetComposeConfig() + + // Then: an error should be returned indicating invalid services + if err == nil { + t.Fatalf("expected error due to invalid services, got nil") + } + + expectedError := "invalid services found: invalidService" + if !strings.Contains(err.Error(), expectedError) { + t.Errorf("expected error to contain %q, got %v", expectedError, err) + } + }) }