Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ windsor.yaml
.windsor/
.volumes/
terraform/**/backend_override.tf
terraform/**/provider_override.tf
contexts/**/.terraform/
contexts/**/.tfstate/
contexts/**/.kube/
Expand Down
36 changes: 36 additions & 0 deletions pkg/di/mock_injector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,39 @@ func TestMockContainer_ResolveAll(t *testing.T) {
}
})
}

func TestMockInjector_Resolve(t *testing.T) {
t.Run("Success", func(t *testing.T) {
// Given a new mock injector
injector := NewMockInjector()

// And a mock service registered
mockService := &MockItemImpl{}
injector.Register("mockService", mockService)

// When resolving the service by name
resolvedInstance := injector.Resolve("mockService")

// Then the resolved instance should match the registered service
if resolvedInstance != mockService {
t.Fatalf("expected %v, got %v", mockService, resolvedInstance)
}
})

t.Run("ResolveError", func(t *testing.T) {
// Given a new mock injector
injector := NewMockInjector()

// And a resolve error set for a specific service name
expectedError := errors.New("resolve error")
injector.SetResolveError("mockService", expectedError)

// When resolving the service by name
resolvedInstance := injector.Resolve("mockService")

// Then the resolved instance should be the expected error
if resolvedInstance != expectedError {
t.Fatalf("expected error %v, got %v", expectedError, resolvedInstance)
}
})
}
3 changes: 3 additions & 0 deletions pkg/env/shims.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,6 @@ var execLookPath = exec.LookPath

// Define a variable for os.LookupEnv for easier testing
var osLookupEnv = os.LookupEnv

// Define a variable for os.Remove for easier testing
var osRemove = os.Remove
170 changes: 122 additions & 48 deletions pkg/env/terraform_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ import (
"sort"
"strings"

"github.com/hashicorp/hcl/v2/hclwrite"
"github.com/windsorcli/cli/pkg/constants"
"github.com/windsorcli/cli/pkg/di"
svc "github.com/windsorcli/cli/pkg/services"
"github.com/zclconf/go-cty/cty"
)

// TerraformEnvPrinter simulates a Terraform environment for testing purposes.
Expand Down Expand Up @@ -93,9 +97,33 @@ func (e *TerraformEnvPrinter) GetEnvVars() (map[string]string, error) {
return envVars, nil
}

// PostEnvHook executes operations after setting the environment variables.
// PostEnvHook finalizes the environment setup by generating necessary override configurations
// if the current directory is within a Terraform project and Localstack is enabled.
func (e *TerraformEnvPrinter) PostEnvHook() error {
return e.generateBackendOverrideTf()
currentPath, err := getwd()
if err != nil {
return fmt.Errorf("error getting current directory: %w", err)
}

projectPath, err := findRelativeTerraformProjectPath()
if err != nil {
return fmt.Errorf("error finding Terraform project path: %w", err)
}
if projectPath == "" {
return nil
}

if err := e.generateBackendOverrideTf(currentPath); err != nil {
return err
}

if e.configHandler.GetBool("aws.localstack.enabled", false) {
if err := e.generateProviderOverrideTf(currentPath); err != nil {
return err
}
}

return nil
}

// Print outputs the environment variables for the Terraform environment.
Expand All @@ -120,47 +148,106 @@ func (e *TerraformEnvPrinter) getAlias() (map[string]string, error) {

// generateBackendOverrideTf creates the backend_override.tf file for the project by determining
// the backend type and writing the appropriate configuration to the file.
func (e *TerraformEnvPrinter) generateBackendOverrideTf() error {
currentPath, err := getwd()
if err != nil {
return fmt.Errorf("error getting current directory: %w", err)
func (e *TerraformEnvPrinter) generateBackendOverrideTf(projectPath string) error {
if projectPath == "" {
return nil
}

projectPath, err := findRelativeTerraformProjectPath()
backendType := e.configHandler.GetString("terraform.backend.type", "local")

backendOverridePath := filepath.Join(projectPath, "backend_override.tf")
backendConfig := fmt.Sprintf(`terraform {
backend "%s" {}
}`, backendType)

err := writeFile(backendOverridePath, []byte(backendConfig), os.ModePerm)
if err != nil {
return fmt.Errorf("error finding project path: %w", err)
return fmt.Errorf("error writing backend_override.tf: %w", err)
}

return nil
}

// generateProviderOverrideTf creates the provider_override.tf file for the project by determining
// the provider configuration and writing the appropriate configuration to the file.
func (e *TerraformEnvPrinter) generateProviderOverrideTf(projectPath string) error {
if projectPath == "" {
return nil
}

contextConfig := e.configHandler.GetConfig()
backend := contextConfig.Terraform.Backend
overridePath := filepath.Join(projectPath, "provider_override.tf")

backendOverridePath := filepath.Join(currentPath, "backend_override.tf")
var backendConfig string
// Check if localstack is enabled
if !e.configHandler.GetBool("aws.localstack.enabled", false) {
// If localstack isn't enabled, delete provider_override.tf if it exists
if _, err := stat(overridePath); err == nil {
if err := osRemove(overridePath); err != nil {
return fmt.Errorf("error deleting provider_override.tf: %w", err)
}
}
return nil
}

switch backend.Type {
case "local":
backendConfig = fmt.Sprintf(`terraform {
backend "local" {}
}`)
case "s3":
backendConfig = fmt.Sprintf(`terraform {
backend "s3" {}
}`)
case "kubernetes":
backendConfig = fmt.Sprintf(`terraform {
backend "kubernetes" {}
}`)
default:
return fmt.Errorf("unsupported backend: %s", backend.Type)
region := e.configHandler.GetString("aws.region", "us-east-1")

// Derive the AWS endpoint URL as done in AWSGenerator
service, ok := e.injector.Resolve("localstackService").(svc.Service)
if !ok {
return fmt.Errorf("localstackService not found")
}
tld := e.configHandler.GetString("dns.domain", "test")
fullName := service.GetName() + "." + tld
localstackPort := constants.DEFAULT_AWS_LOCALSTACK_PORT
localstackEndpoint := "http://" + fullName + ":" + localstackPort

// Determine the list of AWS services to use
var awsServices []string
configuredAwsServices := e.configHandler.GetStringSlice("aws.localstack.services", nil)
if len(configuredAwsServices) > 0 {
awsServices = configuredAwsServices
} else {
awsServices = svc.ValidLocalstackServiceNames
}

// Filter out invalid Terraform AWS service names
validAwsServices := make([]string, 0, len(awsServices))
invalidServiceSet := make(map[string]struct{}, len(svc.InvalidTerraformAwsServiceNames))
for _, invalidService := range svc.InvalidTerraformAwsServiceNames {
invalidServiceSet[invalidService] = struct{}{}
}
for _, awsService := range awsServices {
if _, isInvalid := invalidServiceSet[awsService]; !isInvalid {
validAwsServices = append(validAwsServices, awsService)
}
}

// Create a new HCL file for the provider configuration
providerContent := hclwrite.NewEmptyFile()
body := providerContent.Body()

// Append a new block for the provider "aws"
providerBlock := body.AppendNewBlock("provider", []string{"aws"})
providerBody := providerBlock.Body()

// Set provider attributes
providerBody.SetAttributeValue("access_key", cty.StringVal("test"))
providerBody.SetAttributeValue("secret_key", cty.StringVal("test"))
providerBody.SetAttributeValue("skip_credentials_validation", cty.BoolVal(true))
providerBody.SetAttributeValue("skip_metadata_api_check", cty.BoolVal(true))
providerBody.SetAttributeValue("skip_requesting_account_id", cty.BoolVal(true))
providerBody.SetAttributeValue("region", cty.StringVal(region))

// Create a block for endpoints
endpointsBlock := providerBody.AppendNewBlock("endpoints", nil)
endpointsBody := endpointsBlock.Body()
for _, awsService := range validAwsServices {
endpointsBody.SetAttributeValue(awsService, cty.StringVal(localstackEndpoint))
}

err = writeFile(backendOverridePath, []byte(backendConfig), os.ModePerm)
// Write the provider configuration to the file
err := writeFile(overridePath, providerContent.Bytes(), os.ModePerm)
if err != nil {
return fmt.Errorf("error writing backend_override.tf: %w", err)
return fmt.Errorf("error writing provider_override.tf: %w", err)
}

return nil
Expand All @@ -171,20 +258,7 @@ func (e *TerraformEnvPrinter) generateBackendOverrideTf() error {
// The function supports local, s3, and kubernetes backends.
// It also includes backend.tfvars if present in the context directory.
func (e *TerraformEnvPrinter) generateBackendConfigArgs(projectPath, configRoot string) ([]string, error) {
backend := e.configHandler.GetConfig().Terraform.Backend
backendType := e.configHandler.GetString("terraform.backend.type", "")
if backendType == "" {
switch {
case backend.S3 != nil:
backendType = "s3"
case backend.Kubernetes != nil:
backendType = "kubernetes"
case backend.Local != nil:
backendType = "local"
default:
backendType = "local"
}
}
backendType := e.configHandler.GetString("terraform.backend.type", "local")

var backendConfigArgs []string

Expand All @@ -206,20 +280,20 @@ func (e *TerraformEnvPrinter) generateBackendConfigArgs(projectPath, configRoot
addBackendConfigArg("path", filepath.ToSlash(filepath.Join(configRoot, ".tfstate", projectPath, "terraform.tfstate")))
case "s3":
addBackendConfigArg("key", filepath.ToSlash(filepath.Join(projectPath, "terraform.tfstate")))
if backend.S3 != nil {
if err := processBackendConfig(backend.S3, addBackendConfigArg); err != nil {
if backend := e.configHandler.GetConfig().Terraform.Backend.S3; backend != nil {
if err := processBackendConfig(backend, addBackendConfigArg); err != nil {
return nil, fmt.Errorf("error processing S3 backend config: %w", err)
}
}
case "kubernetes":
addBackendConfigArg("secret_suffix", sanitizeForK8s(projectPath))
if backend.Kubernetes != nil {
if err := processBackendConfig(backend.Kubernetes, addBackendConfigArg); err != nil {
if backend := e.configHandler.GetConfig().Terraform.Backend.Kubernetes; backend != nil {
if err := processBackendConfig(backend, addBackendConfigArg); err != nil {
return nil, fmt.Errorf("error processing Kubernetes backend config: %w", err)
}
}
default:
return nil, fmt.Errorf("unsupported backend: %s", backend.Type)
return nil, fmt.Errorf("unsupported backend: %s", backendType)
}

return backendConfigArgs, nil
Expand Down
Loading
Loading