diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 682a8b3b1..e4ac98718 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -64,6 +64,11 @@ const ( DEFAULT_NETWORK_CIDR = "10.5.0.0/16" ) +// Kubernetes settings +const ( + KUBERNETES_SHORT_TIMEOUT = 200 * time.Millisecond +) + // Minimum versions for tools const ( MINIMUM_VERSION_COLIMA = "0.7.0" diff --git a/pkg/env/aws_env.go b/pkg/env/aws_env.go index 5255da1b1..e2b5a66a8 100644 --- a/pkg/env/aws_env.go +++ b/pkg/env/aws_env.go @@ -1,3 +1,8 @@ +// The AwsEnvPrinter is a specialized component that manages AWS environment configuration. +// It provides AWS-specific environment variable management and configuration, +// The AwsEnvPrinter handles AWS profile, endpoint, and S3 configuration settings, +// ensuring proper AWS CLI integration and environment setup for AWS operations. + package env import ( @@ -8,20 +13,30 @@ import ( "github.com/windsorcli/cli/pkg/di" ) -// AwsEnvPrinter is a struct that simulates an AWS environment for testing purposes. +// ============================================================================= +// Types +// ============================================================================= + +// AwsEnvPrinter is a struct that implements AWS environment configuration type AwsEnvPrinter struct { BaseEnvPrinter } -// NewAwsEnvPrinter initializes a new awsEnv instance using the provided dependency injector. +// ============================================================================= +// Constructor +// ============================================================================= + +// NewAwsEnvPrinter creates a new AwsEnvPrinter instance func NewAwsEnvPrinter(injector di.Injector) *AwsEnvPrinter { return &AwsEnvPrinter{ - BaseEnvPrinter: BaseEnvPrinter{ - injector: injector, - }, + BaseEnvPrinter: *NewBaseEnvPrinter(injector), } } +// ============================================================================= +// Public Methods +// ============================================================================= + // GetEnvVars retrieves the environment variables for the AWS environment. func (e *AwsEnvPrinter) GetEnvVars() (map[string]string, error) { envVars := make(map[string]string) @@ -42,13 +57,13 @@ func (e *AwsEnvPrinter) GetEnvVars() (map[string]string, error) { // Construct the path to the AWS configuration file and verify its existence. awsConfigPath := filepath.Join(configRoot, ".aws", "config") - if _, err := stat(awsConfigPath); os.IsNotExist(err) { + if _, err := e.shims.Stat(awsConfigPath); os.IsNotExist(err) { awsConfigPath = "" } // Populate environment variables with AWS configuration data. if awsConfigPath != "" { - envVars["AWS_CONFIG_FILE"] = awsConfigPath + envVars["AWS_CONFIG_FILE"] = filepath.ToSlash(awsConfigPath) } if contextConfigData.AWS.AWSProfile != nil { envVars["AWS_PROFILE"] = *contextConfigData.AWS.AWSProfile @@ -77,5 +92,5 @@ func (e *AwsEnvPrinter) Print() error { return e.BaseEnvPrinter.Print(envVars) } -// Ensure awsEnv implements the EnvPrinter interface +// Ensure AwsEnvPrinter implements the EnvPrinter interface var _ EnvPrinter = (*AwsEnvPrinter)(nil) diff --git a/pkg/env/aws_env_test.go b/pkg/env/aws_env_test.go index c40a5ad59..74d27c8de 100644 --- a/pkg/env/aws_env_test.go +++ b/pkg/env/aws_env_test.go @@ -1,265 +1,274 @@ package env import ( - "errors" "fmt" "os" "path/filepath" "reflect" + "strings" "testing" + "github.com/goccy/go-yaml" "github.com/windsorcli/cli/api/v1alpha1" - "github.com/windsorcli/cli/api/v1alpha1/aws" "github.com/windsorcli/cli/pkg/config" - "github.com/windsorcli/cli/pkg/di" - "github.com/windsorcli/cli/pkg/shell" ) -type AwsEnvMocks struct { - Injector di.Injector - ConfigHandler *config.MockConfigHandler - Shell *shell.MockShell -} - -func setupSafeAwsEnvMocks(injector ...di.Injector) *AwsEnvMocks { - // Use the provided injector or create a new one if not provided - var mockInjector di.Injector - if len(injector) > 0 { - mockInjector = injector[0] - } else { - mockInjector = di.NewMockInjector() +// ============================================================================= +// Test Setup +// ============================================================================= + +func setupAwsEnvMocks(t *testing.T, opts ...*SetupOptions) *Mocks { + t.Helper() + if len(opts) == 0 || opts[0].ConfigStr == "" { + opts = []*SetupOptions{{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + aws: + aws_profile: default + aws_endpoint_url: https://aws.endpoint + s3_hostname: s3.amazonaws.com + mwaa_endpoint: https://mwaa.endpoint +`, + }} } - // Create a mock ConfigHandler using its constructor + // Create a mock config handler mockConfigHandler := config.NewMockConfigHandler() - mockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - AWS: &aws.AWSConfig{ - AWSProfile: stringPtr("default"), - AWSEndpointURL: stringPtr("https://aws.endpoint"), - S3Hostname: stringPtr("s3.amazonaws.com"), - MWAAEndpoint: stringPtr("https://mwaa.endpoint"), - }, - } - } + + // Set up the GetConfigRoot function to return a mock path mockConfigHandler.GetConfigRootFunc = func() (string, error) { - return filepath.FromSlash("/mock/config/root"), nil - } - mockConfigHandler.GetContextFunc = func() string { - return "test-context" + return "/mock/config/root", nil } - // Create a mock Shell using its constructor - mockShell := shell.NewMockShell() + // Set up the GetConfig function to return a mock config + mockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { + // Parse the config string + var config v1alpha1.Config + if err := yaml.Unmarshal([]byte(opts[0].ConfigStr), &config); err != nil { + t.Fatalf("Failed to unmarshal config: %v", err) + } - // Register the mocks in the DI injector - mockInjector.Register("configHandler", mockConfigHandler) - mockInjector.Register("shell", mockShell) + // Return the context for the test-context + if ctx, ok := config.Contexts["test-context"]; ok { + return ctx + } + return &v1alpha1.Context{} + } - return &AwsEnvMocks{ - Injector: mockInjector, + // Create mocks with the mock config handler + mocks := setupMocks(t, &SetupOptions{ ConfigHandler: mockConfigHandler, - Shell: mockShell, + }) + + if err := mocks.ConfigHandler.Initialize(); err != nil { + t.Fatalf("Failed to initialize config handler: %v", err) + } + if err := mocks.ConfigHandler.SetContext("test-context"); err != nil { + t.Fatalf("Failed to set context: %v", err) } + + // Set up shims for AWS config file check + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == filepath.FromSlash("/mock/config/root/.aws/config") { + return nil, nil + } + return nil, os.ErrNotExist + } + + return mocks } -func TestAwsEnv_GetEnvVars(t *testing.T) { - t.Run("Success", func(t *testing.T) { - // Use setupSafeAwsEnvMocks to create mocks - mocks := setupSafeAwsEnvMocks() +// ============================================================================= +// Test Public Methods +// ============================================================================= - // Mock the stat function to simulate the existence of the AWS config file - stat = func(name string) (os.FileInfo, error) { - if name == filepath.FromSlash("/mock/config/root/.aws/config") { - return nil, nil // Simulate that the file exists - } - return nil, os.ErrNotExist +// TestAwsEnv_GetEnvVars tests the GetEnvVars method of the AwsEnvPrinter +func TestAwsEnv_GetEnvVars(t *testing.T) { + setup := func() (*AwsEnvPrinter, *Mocks) { + mocks := setupAwsEnvMocks(t) + env := NewAwsEnvPrinter(mocks.Injector) + if err := env.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) } + env.shims = mocks.Shims + return env, mocks + } - awsEnvPrinter := NewAwsEnvPrinter(mocks.Injector) - awsEnvPrinter.Initialize() + t.Run("Success", func(t *testing.T) { + env, _ := setup() - // When calling GetEnvVars - envVars, err := awsEnvPrinter.GetEnvVars() + envVars, err := env.GetEnvVars() if err != nil { - t.Fatalf("GetEnvVars returned an error: %v", err) + t.Errorf("GetEnvVars returned an error: %v", err) } - // Then the environment variables should be set correctly - expectedConfigFile := filepath.FromSlash("/mock/config/root/.aws/config") - if envVars["AWS_CONFIG_FILE"] != expectedConfigFile { - t.Errorf("AWS_CONFIG_FILE = %v, want %v", envVars["AWS_CONFIG_FILE"], expectedConfigFile) + expected := map[string]string{ + "AWS_PROFILE": "default", + "AWS_ENDPOINT_URL": "https://aws.endpoint", + "S3_HOSTNAME": "s3.amazonaws.com", + "MWAA_ENDPOINT": "https://mwaa.endpoint", + "AWS_CONFIG_FILE": "/mock/config/root/.aws/config", } - }) - t.Run("MissingConfiguration", func(t *testing.T) { - // Use setupSafeAwsEnvMocks to create mocks - mocks := setupSafeAwsEnvMocks() - - // Override the GetConfigFunc to return nil for AWS configuration - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{AWS: nil} + if !reflect.DeepEqual(envVars, expected) { + t.Errorf("GetEnvVars returned %v, want %v", envVars, expected) } + }) - mockInjector := mocks.Injector - - awsEnvPrinter := NewAwsEnvPrinter(mockInjector) - awsEnvPrinter.Initialize() + t.Run("NonExistentConfigFile", func(t *testing.T) { + env, _ := setup() - // Capture stdout - output := captureStdout(t, func() { - // When calling GetEnvVars - _, err := awsEnvPrinter.GetEnvVars() - if err != nil { - fmt.Println(err) - } - }) - - // Then the output should indicate the missing configuration - expectedOutput := "context configuration or AWS configuration is missing\n" - if output != expectedOutput { - t.Errorf("output = %v, want %v", output, expectedOutput) + // Override shims to make AWS config file not exist + env.shims.Stat = func(name string) (os.FileInfo, error) { + return nil, os.ErrNotExist } - }) - t.Run("NoAwsConfigFile", func(t *testing.T) { - // Use setupSafeAwsEnvMocks to create mocks - mocks := setupSafeAwsEnvMocks() - - // Override the GetConfigFunc to return a valid AWS configuration - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - AWS: &aws.AWSConfig{ - AWSProfile: stringPtr("default"), - AWSEndpointURL: stringPtr("https://example.com"), - S3Hostname: stringPtr("s3.example.com"), - MWAAEndpoint: stringPtr("mwaa.example.com"), - }, - } + envVars, err := env.GetEnvVars() + if err != nil { + t.Errorf("GetEnvVars returned an error: %v", err) } - // Override the GetConfigRootFunc to return a valid path - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { - return "/non/existent/path", nil + expected := map[string]string{ + "AWS_PROFILE": "default", + "AWS_ENDPOINT_URL": "https://aws.endpoint", + "S3_HOSTNAME": "s3.amazonaws.com", + "MWAA_ENDPOINT": "https://mwaa.endpoint", } - mockInjector := mocks.Injector - - awsEnvPrinter := NewAwsEnvPrinter(mockInjector) - awsEnvPrinter.Initialize() + if !reflect.DeepEqual(envVars, expected) { + t.Errorf("GetEnvVars returned %v, want %v", envVars, expected) + } + }) - // Capture stdout - output := captureStdout(t, func() { - // When calling GetEnvVars - _, err := awsEnvPrinter.GetEnvVars() - if err != nil { - fmt.Println(err) - } + t.Run("MissingConfiguration", func(t *testing.T) { + mocks := setupAwsEnvMocks(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: {} +`, }) + env := NewAwsEnvPrinter(mocks.Injector) + if err := env.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) + } - // Then the output should not include AWS_CONFIG_FILE and should not indicate an error - if output != "" { - t.Errorf("output = %v, want empty output", output) + _, err := env.GetEnvVars() + if err == nil { + t.Error("GetEnvVars did not return an error") + } + if !strings.Contains(err.Error(), "context configuration or AWS configuration is missing") { + t.Errorf("GetEnvVars returned error %v, want error containing 'context configuration or AWS configuration is missing'", err) } }) t.Run("GetConfigRootError", func(t *testing.T) { - // Use setupSafeAwsEnvMocks to create mocks - mocks := setupSafeAwsEnvMocks() + mocks := setupAwsEnvMocks(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + aws: + aws_profile: default +`, + }) - // Override the GetConfigRootFunc to simulate an error - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { - return "", errors.New("mock context error") + // Mock the GetConfigRoot function to return an error + mockConfigHandler := mocks.ConfigHandler.(*config.MockConfigHandler) + mockConfigHandler.GetConfigRootFunc = func() (string, error) { + return "", fmt.Errorf("error retrieving configuration root directory") } - mockInjector := mocks.Injector - - awsEnvPrinter := NewAwsEnvPrinter(mockInjector) - awsEnvPrinter.Initialize() - - // Capture stdout - output := captureStdout(t, func() { - // When calling GetEnvVars - _, err := awsEnvPrinter.GetEnvVars() - if err != nil { - fmt.Println(err) - } - }) + env := NewAwsEnvPrinter(mocks.Injector) + if err := env.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) + } - // Then the output should indicate the error - expectedOutput := "error retrieving configuration root directory: mock context error\n" - if output != expectedOutput { - t.Errorf("output = %v, want %v", output, expectedOutput) + _, err := env.GetEnvVars() + if err == nil { + t.Error("GetEnvVars did not return an error") + } + if !strings.Contains(err.Error(), "error retrieving configuration root directory") { + t.Errorf("GetEnvVars returned error %v, want error containing 'error retrieving configuration root directory'", err) } }) } +// TestAwsEnv_Print tests the Print method of the AwsEnvPrinter func TestAwsEnv_Print(t *testing.T) { + setup := func() (*AwsEnvPrinter, *Mocks) { + mocks := setupAwsEnvMocks(t) + env := NewAwsEnvPrinter(mocks.Injector) + if err := env.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) + } + env.shims = mocks.Shims + return env, mocks + } + t.Run("Success", func(t *testing.T) { - // Use setupSafeAwsEnvMocks to create mocks - mocks := setupSafeAwsEnvMocks() - mockInjector := mocks.Injector - awsEnvPrinter := NewAwsEnvPrinter(mockInjector) - awsEnvPrinter.Initialize() - - // Mock the stat function to simulate the existence of the AWS config file - stat = func(name string) (os.FileInfo, error) { + env, mocks := setup() + + // Mock stat function to make AWS config file exist + env.shims.Stat = func(name string) (os.FileInfo, error) { if name == filepath.FromSlash("/mock/config/root/.aws/config") { - return nil, nil // Simulate that the file exists + return nil, nil } return nil, os.ErrNotExist } - // Mock the PrintEnvVarsFunc to verify it is called with the correct envVars + // Mock PrintEnvVarsFunc to capture printed vars var capturedEnvVars map[string]string mocks.Shell.PrintEnvVarsFunc = func(envVars map[string]string) { capturedEnvVars = envVars } - // Call Print and check for errors - err := awsEnvPrinter.Print() + // When calling Print + err := env.Print() + + // Then no error should be returned if err != nil { - t.Errorf("unexpected error: %v", err) + t.Errorf("Print returned an error: %v", err) } - // Verify that PrintEnvVarsFunc was called with the correct envVars + // And environment variables should be set correctly expectedEnvVars := map[string]string{ - "AWS_CONFIG_FILE": filepath.FromSlash("/mock/config/root/.aws/config"), "AWS_PROFILE": "default", "AWS_ENDPOINT_URL": "https://aws.endpoint", "S3_HOSTNAME": "s3.amazonaws.com", "MWAA_ENDPOINT": "https://mwaa.endpoint", + "AWS_CONFIG_FILE": "/mock/config/root/.aws/config", } if !reflect.DeepEqual(capturedEnvVars, expectedEnvVars) { - t.Errorf("capturedEnvVars = %v, want %v", capturedEnvVars, expectedEnvVars) + t.Errorf("Print set environment variables to %v, want %v", capturedEnvVars, expectedEnvVars) } }) - t.Run("Error", func(t *testing.T) { - // Use setupSafeAwsEnvMocks to create mocks - mocks := setupSafeAwsEnvMocks() - - // Set AWS configuration to nil to simulate the error condition - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - AWS: nil, - } + t.Run("GetConfigError", func(t *testing.T) { + // Given a new AwsEnvPrinter with failing config lookup + mocks := setupAwsEnvMocks(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: {} +`, + }) + env := NewAwsEnvPrinter(mocks.Injector) + if err := env.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) } - mockInjector := mocks.Injector - awsEnvPrinter := NewAwsEnvPrinter(mockInjector) - awsEnvPrinter.Initialize() + // When calling Print + err := env.Print() - // Call Print and expect an error - err := awsEnvPrinter.Print() + // Then appropriate error should be returned if err == nil { - t.Error("expected error, got nil") + t.Error("Print did not return an error") } - - // Verify the error message - expectedError := "error getting environment variables: context configuration or AWS configuration is missing" - if err.Error() != expectedError { - t.Errorf("error = %v, want %v", err.Error(), expectedError) + if !strings.Contains(err.Error(), "context configuration or AWS configuration is missing") { + t.Errorf("Print returned error %v, want error containing 'context configuration or AWS configuration is missing'", err) } }) } diff --git a/pkg/env/docker_env.go b/pkg/env/docker_env.go index 668002099..5fbc0844f 100644 --- a/pkg/env/docker_env.go +++ b/pkg/env/docker_env.go @@ -1,3 +1,8 @@ +// The DockerEnvPrinter is a specialized component that manages Docker environment configuration. +// It provides Docker-specific environment variable management and configuration, +// The DockerEnvPrinter handles Docker host, context, and registry configuration settings, +// ensuring proper Docker CLI integration and environment setup for container operations. + package env import ( @@ -9,20 +14,30 @@ import ( "github.com/windsorcli/cli/pkg/di" ) -// DockerEnvPrinter is a struct that simulates a Docker environment for testing purposes. +// ============================================================================= +// Types +// ============================================================================= + +// DockerEnvPrinter is a struct that implements Docker environment configuration type DockerEnvPrinter struct { BaseEnvPrinter } -// NewDockerEnvPrinter initializes a new dockerEnv instance using the provided dependency injector. +// ============================================================================= +// Constructor +// ============================================================================= + +// NewDockerEnvPrinter creates a new DockerEnvPrinter instance func NewDockerEnvPrinter(injector di.Injector) *DockerEnvPrinter { return &DockerEnvPrinter{ - BaseEnvPrinter: BaseEnvPrinter{ - injector: injector, - }, + BaseEnvPrinter: *NewBaseEnvPrinter(injector), } } +// ============================================================================= +// Public Methods +// ============================================================================= + // GetEnvVars sets Docker-specific env vars, using DOCKER_HOST from vm.driver config or existing env. // Defaults to WINDSORCONFIG or home dir for Docker paths, ensuring config directory exists. // Writes config if content changes, adds DOCKER_CONFIG and REGISTRY_URL, and returns the map. @@ -30,11 +45,11 @@ func NewDockerEnvPrinter(injector di.Injector) *DockerEnvPrinter { func (e *DockerEnvPrinter) GetEnvVars() (map[string]string, error) { envVars := make(map[string]string) - if dockerHostValue, dockerHostExists := osLookupEnv("DOCKER_HOST"); dockerHostExists { + if dockerHostValue, dockerHostExists := e.shims.LookupEnv("DOCKER_HOST"); dockerHostExists { envVars["DOCKER_HOST"] = dockerHostValue } else { vmDriver := e.configHandler.GetString("vm.driver") - homeDir, err := osUserHomeDir() + homeDir, err := e.shims.UserHomeDir() if err != nil { return nil, fmt.Errorf("error retrieving user home directory: %w", err) } @@ -48,29 +63,30 @@ func (e *DockerEnvPrinter) GetEnvVars() (map[string]string, error) { // Determine the Docker context name based on the VM driver var contextName string + configContext := e.configHandler.GetContext() - switch vmDriver { - case "colima": - configContext := e.configHandler.GetContext() - contextName = fmt.Sprintf("colima-windsor-%s", configContext) - dockerHostPath := fmt.Sprintf("unix://%s/.colima/windsor-%s/docker.sock", homeDir, configContext) - envVars["DOCKER_HOST"] = dockerHostPath - - case "docker-desktop": + if e.shims.Goos() == "windows" { contextName = "desktop-linux" - if goos() == "windows" { - envVars["DOCKER_HOST"] = "npipe:////./pipe/docker_engine" - } else { + envVars["DOCKER_HOST"] = "npipe:////./pipe/docker_engine" + } else { + switch vmDriver { + case "colima": + contextName = fmt.Sprintf("colima-windsor-%s", configContext) + dockerHostPath := fmt.Sprintf("unix://%s/.colima/windsor-%s/docker.sock", homeDir, configContext) + envVars["DOCKER_HOST"] = dockerHostPath + + case "docker-desktop": + contextName = "desktop-linux" dockerHostPath := fmt.Sprintf("unix://%s/.docker/run/docker.sock", homeDir) envVars["DOCKER_HOST"] = dockerHostPath - } - case "docker": - contextName = "default" - envVars["DOCKER_HOST"] = "unix:///var/run/docker.sock" + case "docker": + contextName = "default" + envVars["DOCKER_HOST"] = "unix:///var/run/docker.sock" - default: - contextName = "default" + default: + contextName = "default" + } } // Create Docker config content with the determined context name @@ -81,17 +97,17 @@ func (e *DockerEnvPrinter) GetEnvVars() (map[string]string, error) { "features": {} }`, contextName) - if err := mkdirAll(dockerConfigDir, 0755); err != nil { + if err := e.shims.MkdirAll(dockerConfigDir, 0755); err != nil { return nil, fmt.Errorf("error creating docker config directory: %w", err) } - existingContent, err := readFile(dockerConfigPath) + existingContent, err := e.shims.ReadFile(dockerConfigPath) if err != nil || string(existingContent) != dockerConfigContent { - if err := writeFile(dockerConfigPath, []byte(dockerConfigContent), 0644); err != nil { + if err := e.shims.WriteFile(dockerConfigPath, []byte(dockerConfigContent), 0644); err != nil { return nil, fmt.Errorf("error writing docker config file: %w", err) } } - envVars["DOCKER_CONFIG"] = dockerConfigDir + envVars["DOCKER_CONFIG"] = filepath.ToSlash(dockerConfigDir) } registryURL, _ := e.getRegistryURL() @@ -107,7 +123,7 @@ func (e *DockerEnvPrinter) GetEnvVars() (map[string]string, error) { // alias for docker-compose. func (e *DockerEnvPrinter) GetAlias() (map[string]string, error) { aliasMap := make(map[string]string) - if _, err := execLookPath("docker-cli-plugin-docker-compose"); err == nil { + if _, err := e.shims.LookPath("docker-cli-plugin-docker-compose"); err == nil { aliasMap["docker-compose"] = "docker-cli-plugin-docker-compose" } return aliasMap, nil @@ -122,28 +138,47 @@ func (e *DockerEnvPrinter) Print() error { return e.BaseEnvPrinter.Print(envVars) } -// getRegistryURL retrieves a registry URL, appending a port if not present. -// It retrieves the URL from the configuration and checks if it already includes a port. -// If not, it looks for a matching registry configuration to append the host port. -// Returns the constructed URL or an empty string if no URL is configured. +// ============================================================================= +// Private Methods +// ============================================================================= + +// getRegistryURL returns the configured Docker registry URL with port. +// Priority: +// 1. docker.registry_url setting (with port from registry config if needed) +// 2. First non-mirror registry from docker.registries +// +// Returns empty string if no registry is configured. func (e *DockerEnvPrinter) getRegistryURL() (string, error) { - config := e.configHandler.GetConfig() registryURL := e.configHandler.GetString("docker.registry_url") - if registryURL == "" { - return "", nil - } - if _, _, err := net.SplitHostPort(registryURL); err == nil { + if registryURL != "" { + if _, _, err := net.SplitHostPort(registryURL); err == nil { + return registryURL, nil + } + config := e.configHandler.GetConfig() + if config.Docker != nil && config.Docker.Registries != nil { + if registryConfig, exists := config.Docker.Registries[registryURL]; exists { + if registryConfig.HostPort != 0 { + return fmt.Sprintf("%s:%d", registryURL, registryConfig.HostPort), nil + } + } + } return registryURL, nil } + + config := e.configHandler.GetConfig() if config.Docker != nil && config.Docker.Registries != nil { - if registryConfig, exists := config.Docker.Registries[registryURL]; exists { - if registryConfig.HostPort != 0 { - registryURL = fmt.Sprintf("%s:%d", registryURL, registryConfig.HostPort) + for url, registryConfig := range config.Docker.Registries { + if registryConfig.Remote == "" { + if registryConfig.HostPort != 0 { + return fmt.Sprintf("%s:%d", url, registryConfig.HostPort), nil + } + return fmt.Sprintf("%s:5000", url), nil } } } - return registryURL, nil + + return "", nil } -// Ensure dockerEnv implements the EnvPrinter interface +// Ensure DockerEnvPrinter implements the EnvPrinter interface var _ EnvPrinter = (*DockerEnvPrinter)(nil) diff --git a/pkg/env/docker_env_test.go b/pkg/env/docker_env_test.go index 14bafa8f6..dbaa4a6da 100644 --- a/pkg/env/docker_env_test.go +++ b/pkg/env/docker_env_test.go @@ -9,217 +9,240 @@ import ( "strings" "testing" - "github.com/windsorcli/cli/api/v1alpha1" - "github.com/windsorcli/cli/api/v1alpha1/docker" "github.com/windsorcli/cli/pkg/config" "github.com/windsorcli/cli/pkg/di" - "github.com/windsorcli/cli/pkg/services" "github.com/windsorcli/cli/pkg/shell" ) +// ============================================================================= +// Test Setup +// ============================================================================= + +// DockerEnvPrinterMocks holds all mock objects used in Docker environment tests type DockerEnvPrinterMocks struct { Injector di.Injector Shell *shell.MockShell ConfigHandler *config.MockConfigHandler } -func setupSafeDockerEnvPrinterMocks(injector ...di.Injector) *DockerEnvPrinterMocks { - var mockInjector di.Injector - if len(injector) > 0 { - mockInjector = injector[0] - } else { - mockInjector = di.NewInjector() - } - - mockShell := shell.NewMockShell() - - mockConfigHandler := config.NewMockConfigHandler() - mockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - switch key { - case "vm.driver": - return "colima" - case "dns.domain": - return "mock-domain" - case "docker.registry_url": - return "mock-registry-url" - default: - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - } - mockConfigHandler.GetConfigRootFunc = func() (string, error) { - return filepath.Join("mock", "config", "root"), nil - } - mockConfigHandler.GetContextFunc = func() string { - return "mock-context" - } - mockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Docker: &docker.DockerConfig{ - Registries: map[string]docker.RegistryConfig{ - "mock-registry-url": { - HostPort: 5000, - }, - }, - }, - } - } - - mkdirAll = func(path string, perm os.FileMode) error { - return nil +// setupDockerEnvMocks creates a new set of mocks for Docker environment tests +func setupDockerEnvMocks(t *testing.T, opts ...*SetupOptions) *Mocks { + t.Helper() + if len(opts) == 0 || opts[0].ConfigStr == "" { + opts = []*SetupOptions{{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registries: + mock-registry-url: + hostport: 5000 +`, + }} } - writeFile = func(filename string, data []byte, perm os.FileMode) error { - return nil - } + // Create mocks with the config + mocks := setupMocks(t, opts...) - readFile = func(_ string) ([]byte, error) { - return nil, nil - } + // Set the context + mocks.ConfigHandler.SetContext("test-context") - osUserHomeDir = func() (string, error) { - return filepath.ToSlash("/mock/home"), nil + // Set up shims for Docker operations + mocks.Shims.UserHomeDir = func() (string, error) { + return "/mock/home", nil } - // Use the real RegistryService - registryService := services.NewRegistryService(mockInjector) - registryService.SetName("mock-registry") - registryService.SetAddress("mock-registry-url") - - mockInjector.Register("shell", mockShell) - mockInjector.Register("configHandler", mockConfigHandler) - mockInjector.Register("registryService", registryService) - - // Initialize the RegistryService - registryService.Initialize() - - return &DockerEnvPrinterMocks{ - Injector: mockInjector, - Shell: mockShell, - ConfigHandler: mockConfigHandler, - } + return mocks } +// ============================================================================= +// Test Public Methods +// ============================================================================= + +// TestDockerEnvPrinter_GetEnvVars tests the GetEnvVars method of the DockerEnvPrinter func TestDockerEnvPrinter_GetEnvVars(t *testing.T) { // Save original env var and restore after all tests originalDockerHost := os.Getenv("DOCKER_HOST") defer os.Setenv("DOCKER_HOST", originalDockerHost) t.Run("Success", func(t *testing.T) { - // Clear any existing DOCKER_HOST for this test - os.Unsetenv("DOCKER_HOST") + // Given a new DockerEnvPrinter with default configuration + mocks := setupDockerEnvMocks(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registries: + mock-registry-url: + hostport: 5000 +`, + }) + + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() - mocks := setupSafeDockerEnvPrinterMocks() - - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + // When getting environment variables + envVars, err := printer.GetEnvVars() - envVars, err := dockerEnvPrinter.GetEnvVars() + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - expectedDockerHost := fmt.Sprintf("unix://%s/.colima/windsor-mock-context/docker.sock", filepath.ToSlash("/mock/home")) + // And DOCKER_HOST should be set based on vm driver and OS + var expectedDockerHost string + if mocks.Shims.Goos() == "windows" { + expectedDockerHost = "npipe:////./pipe/docker_engine" + } else { + expectedDockerHost = fmt.Sprintf("unix://%s/.colima/windsor-test-context/docker.sock", filepath.ToSlash("/mock/home")) + } if envVars["DOCKER_HOST"] != expectedDockerHost { t.Errorf("DOCKER_HOST = %v, want %v", envVars["DOCKER_HOST"], expectedDockerHost) } - if envVars["REGISTRY_URL"] != "mock-registry-url:5000" { - t.Errorf("REGISTRY_URL = %v, want mock-registry-url:5000", envVars["REGISTRY_URL"]) + // And REGISTRY_URL should be set based on registry configuration + expectedRegistryURL := "mock-registry-url:5000" + if envVars["REGISTRY_URL"] != expectedRegistryURL { + t.Errorf("REGISTRY_URL = %v, want %v", envVars["REGISTRY_URL"], expectedRegistryURL) + } + + // And DOCKER_CONFIG should be set + expectedDockerConfig := filepath.ToSlash("/mock/home/.config/windsor/docker") + if envVars["DOCKER_CONFIG"] != expectedDockerConfig { + t.Errorf("DOCKER_CONFIG = %v, want %v", envVars["DOCKER_CONFIG"], expectedDockerConfig) } }) t.Run("ColimaDriver", func(t *testing.T) { - // Clear any existing DOCKER_HOST for this test + // Given a new DockerEnvPrinter with Colima driver os.Unsetenv("DOCKER_HOST") + mocks := setupDockerEnvMocks(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registries: + mock-registry-url: + hostport: 5000 +`, + }) + + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() - mocks := setupSafeDockerEnvPrinterMocks() - mocks.ConfigHandler.GetContextFunc = func() string { - return "test-context" - } - - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + // When getting environment variables + envVars, err := printer.GetEnvVars() - envVars, err := dockerEnvPrinter.GetEnvVars() + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - expectedDockerHost := fmt.Sprintf("unix://%s/.colima/windsor-%s/docker.sock", filepath.ToSlash("/mock/home"), "test-context") + // And DOCKER_HOST should be set correctly for Colima and OS + var expectedDockerHost string + if mocks.Shims.Goos() == "windows" { + expectedDockerHost = "npipe:////./pipe/docker_engine" + } else { + expectedDockerHost = fmt.Sprintf("unix://%s/.colima/windsor-test-context/docker.sock", filepath.ToSlash("/mock/home")) + } if envVars["DOCKER_HOST"] != expectedDockerHost { t.Errorf("DOCKER_HOST = %v, want %v", envVars["DOCKER_HOST"], expectedDockerHost) } - if envVars["REGISTRY_URL"] != "mock-registry-url:5000" { - t.Errorf("REGISTRY_URL = %v, want mock-registry-url:5000", envVars["REGISTRY_URL"]) + // And REGISTRY_URL should be set correctly + expectedRegistryURL := "mock-registry-url:5000" + if envVars["REGISTRY_URL"] != expectedRegistryURL { + t.Errorf("REGISTRY_URL = %v, want %v", envVars["REGISTRY_URL"], expectedRegistryURL) + } + + // And DOCKER_CONFIG should be set correctly + expectedDockerConfig := filepath.ToSlash("/mock/home/.config/windsor/docker") + if envVars["DOCKER_CONFIG"] != expectedDockerConfig { + t.Errorf("DOCKER_CONFIG = %v, want %v", envVars["DOCKER_CONFIG"], expectedDockerConfig) } }) t.Run("DockerDesktopDriver", func(t *testing.T) { - // Clear any existing DOCKER_HOST for this test + // Given a new DockerEnvPrinter with Docker Desktop driver on Linux os.Unsetenv("DOCKER_HOST") - - mocks := setupSafeDockerEnvPrinterMocks() - - // Mock goos function to simulate different OS environments - originalGoos := goos - defer func() { goos = originalGoos }() - goos = func() string { + mocks := setupDockerEnvMocks(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: docker-desktop + docker: + registries: + mock-registry-url: + hostport: 5000 +`, + }) + + // And Linux OS environment + mocks.Shims.Goos = func() string { return "linux" } - // Mock mkdirAll function - originalMkdirAll := mkdirAll - defer func() { mkdirAll = originalMkdirAll }() + // And mock filesystem operations mkdirAllCalled := false mkdirAllPath := "" - mkdirAll = func(path string, perm os.FileMode) error { + mocks.Shims.MkdirAll = func(path string, perm os.FileMode) error { mkdirAllCalled = true mkdirAllPath = filepath.ToSlash(path) return nil } - // Mock writeFile function - originalWriteFile := writeFile - defer func() { writeFile = originalWriteFile }() writeFileCalled := false writeFilePath := "" - writeFile = func(filename string, data []byte, perm os.FileMode) error { + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { writeFileCalled = true writeFilePath = filepath.ToSlash(filename) return nil } - // Use the existing mockConfigHandler from mocks - originalGetStringFunc := mocks.ConfigHandler.GetStringFunc - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "vm.driver" { - return "docker-desktop" - } - return originalGetStringFunc(key, defaultValue...) - } + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + // When getting environment variables + envVars, err := printer.GetEnvVars() - envVars, err := dockerEnvPrinter.GetEnvVars() + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } + // And DOCKER_HOST should be set correctly for Docker Desktop expectedDockerHost := fmt.Sprintf("unix://%s/.docker/run/docker.sock", filepath.ToSlash("/mock/home")) if envVars["DOCKER_HOST"] != expectedDockerHost { t.Errorf("DOCKER_HOST = %v, want %v", envVars["DOCKER_HOST"], expectedDockerHost) } + // And REGISTRY_URL should be set correctly expectedRegistryURL := "mock-registry-url:5000" if envVars["REGISTRY_URL"] != expectedRegistryURL { t.Errorf("REGISTRY_URL = %v, want %v", envVars["REGISTRY_URL"], expectedRegistryURL) } + // And DOCKER_CONFIG should be set correctly + expectedDockerConfig := filepath.ToSlash("/mock/home/.config/windsor/docker") + if envVars["DOCKER_CONFIG"] != expectedDockerConfig { + t.Errorf("DOCKER_CONFIG = %v, want %v", envVars["DOCKER_CONFIG"], expectedDockerConfig) + } + + // And directory should be created if !mkdirAllCalled { t.Error("mkdirAll was not called") } else { @@ -229,6 +252,7 @@ func TestDockerEnvPrinter_GetEnvVars(t *testing.T) { } } + // And config file should be written if !writeFileCalled { t.Error("writeFile was not called") } else { @@ -240,57 +264,70 @@ func TestDockerEnvPrinter_GetEnvVars(t *testing.T) { }) t.Run("DockerDriver", func(t *testing.T) { - // Clear any existing DOCKER_HOST for this test + // Given a new DockerEnvPrinter with Docker driver os.Unsetenv("DOCKER_HOST") + mocks := setupDockerEnvMocks(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: docker + docker: + registries: + mock-registry-url: + hostport: 5000 +`, + }) + + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() - mocks := setupSafeDockerEnvPrinterMocks() - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "vm.driver" { - return "docker" - } - return "" - } - - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + // When getting environment variables + envVars, err := printer.GetEnvVars() - envVars, err := dockerEnvPrinter.GetEnvVars() + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - expectedDockerHost := "unix:///var/run/docker.sock" + // And DOCKER_HOST should be set correctly for Docker driver and OS + var expectedDockerHost string + if mocks.Shims.Goos() == "windows" { + expectedDockerHost = "npipe:////./pipe/docker_engine" + } else { + expectedDockerHost = "unix:///var/run/docker.sock" + } if envVars["DOCKER_HOST"] != expectedDockerHost { t.Errorf("DOCKER_HOST = %v, want %v", envVars["DOCKER_HOST"], expectedDockerHost) } - if envVars["REGISTRY_URL"] != "" { - t.Errorf("REGISTRY_URL = %v, want empty", envVars["REGISTRY_URL"]) + // And REGISTRY_URL should be set correctly + expectedRegistryURL := "mock-registry-url:5000" + if envVars["REGISTRY_URL"] != expectedRegistryURL { + t.Errorf("REGISTRY_URL = %v, want %v", envVars["REGISTRY_URL"], expectedRegistryURL) } }) t.Run("GetUserHomeDirError", func(t *testing.T) { - // Clear any existing DOCKER_HOST for this test + // Given a new DockerEnvPrinter with failing user home directory lookup os.Unsetenv("DOCKER_HOST") + mocks := setupDockerEnvMocks(t) - mocks := setupSafeDockerEnvPrinterMocks() - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "vm.driver" { - return "colima" - } - return "" - } - - originalUserHomeDir := osUserHomeDir - defer func() { osUserHomeDir = originalUserHomeDir }() - osUserHomeDir = func() (string, error) { + // Override the UserHomeDir shim + mocks.Shims.UserHomeDir = func() (string, error) { return "", errors.New("mock user home dir error") } - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() - _, err := dockerEnvPrinter.GetEnvVars() + // When getting environment variables + _, err := printer.GetEnvVars() + + // Then appropriate error should be returned if err == nil { t.Error("expected an error, got nil") } else if !strings.Contains(err.Error(), "mock user home dir error") { @@ -299,27 +336,23 @@ func TestDockerEnvPrinter_GetEnvVars(t *testing.T) { }) t.Run("MkdirAllError", func(t *testing.T) { - // Clear any existing DOCKER_HOST for this test + // Given a new DockerEnvPrinter with failing directory creation os.Unsetenv("DOCKER_HOST") + mocks := setupDockerEnvMocks(t) - mocks := setupSafeDockerEnvPrinterMocks() - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "vm.driver" { - return "colima" - } - return "" - } - - originalMkdirAll := mkdirAll - defer func() { mkdirAll = originalMkdirAll }() - mkdirAll = func(path string, perm os.FileMode) error { + // Override the MkdirAll shim + mocks.Shims.MkdirAll = func(path string, perm os.FileMode) error { return errors.New("mock mkdirAll error") } - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() - _, err := dockerEnvPrinter.GetEnvVars() + // When getting environment variables + _, err := printer.GetEnvVars() + + // Then appropriate error should be returned if err == nil { t.Error("expected an error, got nil") } else if !strings.Contains(err.Error(), "mock mkdirAll error") { @@ -328,27 +361,23 @@ func TestDockerEnvPrinter_GetEnvVars(t *testing.T) { }) t.Run("WriteFileError", func(t *testing.T) { - // Clear any existing DOCKER_HOST for this test + // Given a new DockerEnvPrinter with failing file write os.Unsetenv("DOCKER_HOST") + mocks := setupDockerEnvMocks(t) - mocks := setupSafeDockerEnvPrinterMocks() - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "vm.driver" { - return "colima" - } - return "" - } - - originalWriteFile := writeFile - defer func() { writeFile = originalWriteFile }() - writeFile = func(filename string, data []byte, perm os.FileMode) error { + // Override the WriteFile shim + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { return errors.New("mock writeFile error") } - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() - _, err := dockerEnvPrinter.GetEnvVars() + // When getting environment variables + _, err := printer.GetEnvVars() + + // Then appropriate error should be returned if err == nil { t.Error("expected an error, got nil") } else if !strings.Contains(err.Error(), "mock writeFile error") { @@ -381,32 +410,39 @@ func TestDockerEnvPrinter_GetEnvVars(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // Clear any existing DOCKER_HOST for this test + // Given a new DockerEnvPrinter with specific OS os.Unsetenv("DOCKER_HOST") - - mocks := setupSafeDockerEnvPrinterMocks() - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "vm.driver" { - return "docker-desktop" - } - return "" - } - - // Save original goos function and restore after test - originalGoos := goos - defer func() { goos = originalGoos }() - goos = func() string { + mocks := setupDockerEnvMocks(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: docker-desktop + docker: + registries: + mock-registry-url: + hostport: 5000 +`, + }) + + mocks.Shims.Goos = func() string { return tc.os } - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() - envVars, err := dockerEnvPrinter.GetEnvVars() + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } + // And DOCKER_HOST should be set correctly for the OS if envVars["DOCKER_HOST"] != tc.expected { t.Errorf("DOCKER_HOST = %v, want %v", envVars["DOCKER_HOST"], tc.expected) } @@ -415,34 +451,44 @@ func TestDockerEnvPrinter_GetEnvVars(t *testing.T) { }) t.Run("DockerHostFromEnvironment", func(t *testing.T) { - // Set a specific DOCKER_HOST for this test + // Given a new DockerEnvPrinter with DOCKER_HOST environment variable os.Setenv("DOCKER_HOST", "tcp://custom-docker-host:2375") defer os.Unsetenv("DOCKER_HOST") - mocks := setupSafeDockerEnvPrinterMocks() - - // Save original lookup function and restore after test - originalLookupEnv := osLookupEnv - defer func() { osLookupEnv = originalLookupEnv }() - - // Mock environment lookup to return a specific DOCKER_HOST value - osLookupEnv = func(key string) (string, bool) { + mocks := setupDockerEnvMocks(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: docker-desktop + docker: + registries: + mock-registry-url: + hostport: 5000 +`, + }) + + mocks.Shims.LookupEnv = func(key string) (string, bool) { if key == "DOCKER_HOST" { return "tcp://custom-docker-host:2375", true } return "", false } - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() // When getting environment variables - envVars, err := dockerEnvPrinter.GetEnvVars() + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - // Then the DOCKER_HOST should match the environment value + // And DOCKER_HOST should match environment value expectedDockerHost := "tcp://custom-docker-host:2375" if envVars["DOCKER_HOST"] != expectedDockerHost { t.Errorf("DOCKER_HOST = %v, want %v", envVars["DOCKER_HOST"], expectedDockerHost) @@ -450,73 +496,85 @@ func TestDockerEnvPrinter_GetEnvVars(t *testing.T) { }) t.Run("DockerHostNotSet", func(t *testing.T) { - // Given a mock setup without DOCKER_HOST environment variable - mocks := setupSafeDockerEnvPrinterMocks() - - // Save original lookup function and restore after test - originalLookupEnv := osLookupEnv - defer func() { osLookupEnv = originalLookupEnv }() - - // Mock environment lookup to return no DOCKER_HOST - osLookupEnv = func(key string) (string, bool) { + // Given a new DockerEnvPrinter without DOCKER_HOST environment variable + mocks := setupDockerEnvMocks(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registries: + mock-registry-url: + hostport: 5000 +`, + }) + + mocks.Shims.LookupEnv = func(key string) (string, bool) { return "", false } - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() // When getting environment variables - envVars, err := dockerEnvPrinter.GetEnvVars() + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - // Then the DOCKER_HOST should be set based on the vm driver configuration - expectedDockerHost := fmt.Sprintf("unix://%s/.colima/windsor-mock-context/docker.sock", filepath.ToSlash("/mock/home")) + // And DOCKER_HOST should be set based on vm driver and OS + var expectedDockerHost string + if mocks.Shims.Goos() == "windows" { + expectedDockerHost = "npipe:////./pipe/docker_engine" + } else { + expectedDockerHost = fmt.Sprintf("unix://%s/.colima/windsor-test-context/docker.sock", filepath.ToSlash("/mock/home")) + } if envVars["DOCKER_HOST"] != expectedDockerHost { t.Errorf("DOCKER_HOST = %v, want %v", envVars["DOCKER_HOST"], expectedDockerHost) } }) t.Run("DockerHostFromEnvironmentOverridesDriver", func(t *testing.T) { - // Given a mock setup with both DOCKER_HOST env var and vm driver configured - mocks := setupSafeDockerEnvPrinterMocks() - - // Save original lookup function and restore after test - originalLookupEnv := osLookupEnv - defer func() { osLookupEnv = originalLookupEnv }() - - // Mock environment lookup to return a specific DOCKER_HOST value - osLookupEnv = func(key string) (string, bool) { + // Given a new DockerEnvPrinter with both DOCKER_HOST and vm driver + mocks := setupDockerEnvMocks(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: docker-desktop + docker: + registries: + mock-registry-url: + hostport: 5000 +`, + }) + + mocks.Shims.LookupEnv = func(key string) (string, bool) { if key == "DOCKER_HOST" { return "tcp://override-host:2375", true } return "", false } - // Configure vm.driver to ensure it would set a different value - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - switch key { - case "vm.driver": - return "docker-desktop" - default: - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - } - - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() // When getting environment variables - envVars, err := dockerEnvPrinter.GetEnvVars() + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - // Then the DOCKER_HOST should match the environment value, not the driver-based value + // And DOCKER_HOST should match environment value, not driver value expectedDockerHost := "tcp://override-host:2375" if envVars["DOCKER_HOST"] != expectedDockerHost { t.Errorf("DOCKER_HOST = %v, want %v", envVars["DOCKER_HOST"], expectedDockerHost) @@ -524,26 +582,32 @@ func TestDockerEnvPrinter_GetEnvVars(t *testing.T) { }) } +// TestDockerEnvPrinter_GetAlias tests the GetAlias method of the DockerEnvPrinter func TestDockerEnvPrinter_GetAlias(t *testing.T) { t.Run("Success", func(t *testing.T) { - mocks := setupSafeDockerEnvPrinterMocks() - originalExecLookPath := execLookPath - defer func() { execLookPath = originalExecLookPath }() - execLookPath = func(file string) (string, error) { + // Given a new DockerEnvPrinter with docker-compose plugin available + mocks := setupDockerEnvMocks(t) + + mocks.Shims.LookPath = func(file string) (string, error) { if file == "docker-cli-plugin-docker-compose" { return "/usr/local/bin/docker-cli-plugin-docker-compose", nil } return "", fmt.Errorf("not found") } - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() - aliasMap, err := dockerEnvPrinter.GetAlias() + // When getting aliases + aliasMap, err := printer.GetAlias() + + // Then no error should be returned if err != nil { t.Fatalf("unexpected error: %v", err) } + // And docker-compose alias should be set correctly expectedAlias := "docker-cli-plugin-docker-compose" if aliasMap["docker-compose"] != expectedAlias { t.Errorf("aliasMap[docker-compose] = %v, want %v", aliasMap["docker-compose"], expectedAlias) @@ -551,75 +615,417 @@ func TestDockerEnvPrinter_GetAlias(t *testing.T) { }) t.Run("Failure", func(t *testing.T) { - mocks := setupSafeDockerEnvPrinterMocks() - originalExecLookPath := execLookPath - defer func() { execLookPath = originalExecLookPath }() - execLookPath = func(file string) (string, error) { + // Given a new DockerEnvPrinter without docker-compose plugin + mocks := setupDockerEnvMocks(t) + mocks.Shims.LookPath = func(file string) (string, error) { return "", fmt.Errorf("not found") } - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() + + // When getting aliases + aliasMap, err := printer.GetAlias() - aliasMap, err := dockerEnvPrinter.GetAlias() + // Then no error should be returned if err != nil { t.Fatalf("unexpected error: %v", err) } + // And alias map should be empty if len(aliasMap) != 0 { t.Errorf("aliasMap = %v, want empty map", aliasMap) } }) } +// TestDockerEnvPrinter_Print tests the Print method of the DockerEnvPrinter func TestDockerEnvPrinter_Print(t *testing.T) { // Save original env var and restore after all tests originalDockerHost := os.Getenv("DOCKER_HOST") defer os.Setenv("DOCKER_HOST", originalDockerHost) t.Run("Success", func(t *testing.T) { - mocks := setupSafeDockerEnvPrinterMocks() - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + // Given a new DockerEnvPrinter + mocks := setupDockerEnvMocks(t) + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() - // Mock the Print method of BaseEnvPrinter to capture the envVars + // And PrintEnvVarsFunc is mocked var capturedEnvVars map[string]string mocks.Shell.PrintEnvVarsFunc = func(envVars map[string]string) { capturedEnvVars = envVars } - // Call Print and check for errors - err := dockerEnvPrinter.Print() + // When calling Print + err := printer.Print() + + // Then no error should be returned if err != nil { t.Errorf("unexpected error: %v", err) } - // Verify that PrintEnvVarsFunc was called with the correct envVars - expectedEnvVars, _ := dockerEnvPrinter.GetEnvVars() + // And environment variables should be set correctly + expectedEnvVars, _ := printer.GetEnvVars() if !reflect.DeepEqual(capturedEnvVars, expectedEnvVars) { t.Errorf("capturedEnvVars = %v, want %v", capturedEnvVars, expectedEnvVars) } }) t.Run("GetEnvVarsError", func(t *testing.T) { - // Clear any existing DOCKER_HOST for this test + // Given a new DockerEnvPrinter with failing user home directory lookup os.Unsetenv("DOCKER_HOST") + mocks := setupDockerEnvMocks(t) - mocks := setupSafeDockerEnvPrinterMocks() - - // Mock osUserHomeDir to return an error - originalOsUserHomeDir := osUserHomeDir - defer func() { osUserHomeDir = originalOsUserHomeDir }() - osUserHomeDir = func() (string, error) { + // Override the UserHomeDir shim + mocks.Shims.UserHomeDir = func() (string, error) { return "", errors.New("mock error") } - dockerEnvPrinter := NewDockerEnvPrinter(mocks.Injector) - dockerEnvPrinter.Initialize() + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + printer.Initialize() + + // When calling Print + err := printer.Print() - err := dockerEnvPrinter.Print() + // Then appropriate error should be returned if err == nil { t.Error("expected an error, got nil") } }) } + +// TestDockerEnvPrinter_getRegistryURL tests the getRegistryURL method of the DockerEnvPrinter +func TestDockerEnvPrinter_getRegistryURL(t *testing.T) { + // setup creates a new DockerEnvPrinter with the given configuration + setup := func(t *testing.T, opts ...*SetupOptions) (*DockerEnvPrinter, *Mocks) { + t.Helper() + mocks := setupDockerEnvMocks(t, opts...) + printer := NewDockerEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + return printer, mocks + } + + t.Run("ValidRegistryURL", func(t *testing.T) { + // Given a DockerEnvPrinter with a valid registry URL in config + printer, _ := setup(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registry_url: registry.example.com:5000 +`, + }) + + // And the registry URL is set in the context + printer.configHandler.SetContextValue("docker.registry_url", "registry.example.com:5000") + + // When getting the registry URL + url, err := printer.getRegistryURL() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + // And the URL should match the config value + if url != "registry.example.com:5000" { + t.Errorf("Expected URL 'registry.example.com:5000', got %q", url) + } + }) + + t.Run("RegistryURLWithConfig", func(t *testing.T) { + // Given a DockerEnvPrinter with a registry URL and matching config + printer, _ := setup(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registry_url: registry.example.com + registries: + registry.example.com: + hostport: 5000 +`, + }) + + // And the registry URL is set in the context + printer.configHandler.SetContextValue("docker.registry_url", "registry.example.com") + + // When getting the registry URL + url, err := printer.getRegistryURL() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + // And the URL should include the hostport from config + if url != "registry.example.com:5000" { + t.Errorf("Expected URL 'registry.example.com:5000', got %q", url) + } + }) + + t.Run("EmptyRegistryURL", func(t *testing.T) { + // Given a DockerEnvPrinter with no registry URL but with registries config + printer, _ := setup(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registries: + mock-registry-url: + hostport: 5000 +`, + }) + + // When getting the registry URL + url, err := printer.getRegistryURL() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + // And the URL should be taken from the first registry in config + if url != "mock-registry-url:5000" { + t.Errorf("Expected URL 'mock-registry-url:5000', got %q", url) + } + }) + + t.Run("EmptyConfig", func(t *testing.T) { + // Given a DockerEnvPrinter with empty registries config + printer, _ := setup(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registries: {} +`, + }) + + // When getting the registry URL + url, err := printer.getRegistryURL() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + // And the URL should be empty + if url != "" { + t.Errorf("Expected empty URL, got %q", url) + } + }) + + t.Run("RegistryURLWithoutPortNoConfig", func(t *testing.T) { + // Given a DockerEnvPrinter with a registry URL without port and no matching config + printer, _ := setup(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registry_url: registry.example.com + registries: + other-registry: + hostport: 5000 +`, + }) + + // And the registry URL is set in the context + printer.configHandler.SetContextValue("docker.registry_url", "registry.example.com") + + // When getting the registry URL + url, err := printer.getRegistryURL() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + // And the URL should be returned as-is without a port + if url != "registry.example.com" { + t.Errorf("Expected URL 'registry.example.com', got %q", url) + } + }) + + t.Run("RegistryURLInvalidPort", func(t *testing.T) { + // Given a DockerEnvPrinter with a registry URL with invalid port + printer, _ := setup(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registry_url: registry.example.com:invalid +`, + }) + + // When getting the registry URL + url, err := printer.getRegistryURL() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + // And the URL should be the same as the config value + if url != "registry.example.com:invalid" { + t.Errorf("Expected URL 'registry.example.com:invalid', got %q", url) + } + }) + + t.Run("RegistryURLNoPortNoHostPort", func(t *testing.T) { + // Given a DockerEnvPrinter with a registry URL without port and no hostport in config + printer, _ := setup(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registry_url: registry.example.com + registries: + registry.example.com: {} +`, + }) + + // When getting the registry URL + url, err := printer.getRegistryURL() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + // And the URL should be the same as the config value + if url != "registry.example.com" { + t.Errorf("Expected URL 'registry.example.com', got %q", url) + } + }) + + t.Run("RegistryURLEmptyRegistries", func(t *testing.T) { + // Given a DockerEnvPrinter with a registry URL and empty registries config + printer, _ := setup(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registry_url: registry.example.com + registries: {} +`, + }) + + // When getting the registry URL + url, err := printer.getRegistryURL() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + // And the URL should be the same as the config value + if url != "registry.example.com" { + t.Errorf("Expected URL 'registry.example.com', got %q", url) + } + }) + + t.Run("NilDockerConfig", func(t *testing.T) { + // Given a DockerEnvPrinter with no Docker config + printer, _ := setup(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima +`, + }) + + // When getting the registry URL + url, err := printer.getRegistryURL() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + // And the URL should be empty + if url != "" { + t.Errorf("Expected empty URL, got %q", url) + } + }) + + t.Run("NilRegistriesWithURL", func(t *testing.T) { + // Given a DockerEnvPrinter with a registry URL but no registries config + printer, _ := setup(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registry_url: registry.example.com +`, + }) + + // When getting the registry URL + url, err := printer.getRegistryURL() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + // And the URL should be the same as the config value + if url != "registry.example.com" { + t.Errorf("Expected URL 'registry.example.com', got %q", url) + } + }) + + t.Run("RegistryWithoutHostPort", func(t *testing.T) { + // Given a DockerEnvPrinter with a registry without hostport in config + printer, _ := setup(t, &SetupOptions{ + ConfigStr: ` +version: v1alpha1 +contexts: + test-context: + vm: + driver: colima + docker: + registries: + registry.example.com: + remote: "" +`, + }) + + // When getting the registry URL + url, err := printer.getRegistryURL() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + // And the URL should be the registry name with default port 5000 + if url != "registry.example.com:5000" { + t.Errorf("Expected URL 'registry.example.com:5000', got %q", url) + } + }) +} diff --git a/pkg/env/env.go b/pkg/env/env.go index 7ace5e86a..36e640de2 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -1,23 +1,22 @@ +// The EnvPrinter is a core component that manages environment variable state and context. +// It provides a unified interface for loading, printing, and managing environment variables, +// The EnvPrinter acts as the central environment orchestrator for the application, +// coordinating environment variable management, shell integration, and configuration persistence. + package env import ( "fmt" "slices" - "sync" - "github.com/windsorcli/cli/pkg/config" "github.com/windsorcli/cli/pkg/di" "github.com/windsorcli/cli/pkg/shell" ) -// These are the environment variables that are managed by Windsor. -// They are scoped to the current shell session. -var ( - windsorManagedEnv = []string{} - windsorManagedAlias = []string{} - windsorManagedMu sync.Mutex -) +// ============================================================================= +// Types +// ============================================================================= // EnvPrinter defines the method for printing environment variables. type EnvPrinter interface { @@ -33,19 +32,33 @@ type EnvPrinter interface { Reset() } -// Env is a struct that implements the EnvPrinter interface. +// BaseEnvPrinter is a base implementation of the EnvPrinter interface type BaseEnvPrinter struct { EnvPrinter injector di.Injector shell shell.Shell configHandler config.ConfigHandler + shims *Shims + managedEnv []string + managedAlias []string } -// NewBaseEnvPrinter creates a new BaseEnvPrinter instance. +// ============================================================================= +// Constructor +// ============================================================================= + +// NewBaseEnvPrinter creates a new BaseEnvPrinter instance func NewBaseEnvPrinter(injector di.Injector) *BaseEnvPrinter { - return &BaseEnvPrinter{injector: injector} + return &BaseEnvPrinter{ + injector: injector, + shims: NewShims(), + } } +// ============================================================================= +// Public Methods +// ============================================================================= + // Initialize resolves and assigns the shell and configHandler from the injector. func (e *BaseEnvPrinter) Initialize() error { shell, ok := e.injector.Resolve("shell").(shell.Shell) @@ -126,40 +139,34 @@ func (e *BaseEnvPrinter) PostEnvHook() error { // GetManagedEnv returns the environment variables that are managed by Windsor. func (e *BaseEnvPrinter) GetManagedEnv() []string { - windsorManagedMu.Lock() - defer windsorManagedMu.Unlock() - return windsorManagedEnv + return e.managedEnv } // GetManagedAlias returns the shell aliases that are managed by Windsor. func (e *BaseEnvPrinter) GetManagedAlias() []string { - windsorManagedMu.Lock() - defer windsorManagedMu.Unlock() - return windsorManagedAlias + return e.managedAlias } // SetManagedEnv sets the environment variables that are managed by Windsor. func (e *BaseEnvPrinter) SetManagedEnv(env string) { - windsorManagedMu.Lock() - defer windsorManagedMu.Unlock() - if slices.Contains(windsorManagedEnv, env) { + if slices.Contains(e.managedEnv, env) { return } - windsorManagedEnv = append(windsorManagedEnv, env) + e.managedEnv = append(e.managedEnv, env) } // SetManagedAlias sets the shell aliases that are managed by Windsor. func (e *BaseEnvPrinter) SetManagedAlias(alias string) { - windsorManagedMu.Lock() - defer windsorManagedMu.Unlock() - if slices.Contains(windsorManagedAlias, alias) { + if slices.Contains(e.managedAlias, alias) { return } - windsorManagedAlias = append(windsorManagedAlias, alias) + e.managedAlias = append(e.managedAlias, alias) } // Reset removes all managed environment variables and aliases. // It delegates to the shell's Reset method to handle the reset logic. func (e *BaseEnvPrinter) Reset() { + e.managedEnv = make([]string, 0) + e.managedAlias = make([]string, 0) e.shell.Reset() } diff --git a/pkg/env/env_test.go b/pkg/env/env_test.go index 63f407eb5..cbe0358de 100644 --- a/pkg/env/env_test.go +++ b/pkg/env/env_test.go @@ -3,7 +3,6 @@ package env import ( "os" "reflect" - "slices" "testing" "github.com/windsorcli/cli/pkg/config" @@ -11,60 +10,150 @@ import ( "github.com/windsorcli/cli/pkg/shell" ) +// ============================================================================= +// Test Setup +// ============================================================================= + // Mocks holds all the mock objects used in the tests. type Mocks struct { - Injector *di.MockInjector + Injector di.Injector + ConfigHandler config.ConfigHandler Shell *shell.MockShell - ConfigHandler *config.MockConfigHandler - Env *MockEnvPrinter + Shims *Shims +} + +type SetupOptions struct { + Injector di.Injector + ConfigHandler config.ConfigHandler + ConfigStr string +} + +// setupShims creates a new Shims instance with default implementations +func setupShims(t *testing.T) *Shims { + t.Helper() + shims := NewShims() + + shims.LookupEnv = func(key string) (string, bool) { return "", false } + shims.WriteFile = func(name string, data []byte, perm os.FileMode) error { return nil } + shims.ReadFile = func(name string) ([]byte, error) { return []byte{}, nil } + shims.MkdirAll = func(path string, perm os.FileMode) error { return nil } + shims.UserHomeDir = func() (string, error) { return t.TempDir(), nil } + shims.Stat = func(name string) (os.FileInfo, error) { return nil, nil } + shims.Getwd = func() (string, error) { return t.TempDir(), nil } + + return shims } -// setupEnvMockTests sets up the mock injector and returns the Mocks object. -// It takes an optional injector and only creates one if it's not provided. -func setupEnvMockTests(injector *di.MockInjector) *Mocks { - if injector == nil { - injector = di.NewMockInjector() +func setupMocks(t *testing.T, opts ...*SetupOptions) *Mocks { + t.Helper() + + // Store original directory and create temp dir + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get working directory: %v", err) } + + tmpDir := t.TempDir() + if err := os.Chdir(tmpDir); err != nil { + t.Fatalf("Failed to change to temp directory: %v", err) + } + + // Set project root environment variable + os.Setenv("WINDSOR_PROJECT_ROOT", tmpDir) + os.Setenv("WINDSOR_CONTEXT", "mock-context") + + // Process options with defaults + options := &SetupOptions{} + if len(opts) > 0 && opts[0] != nil { + options = opts[0] + } + + // Create injector + var injector di.Injector + if options.Injector == nil { + injector = di.NewInjector() + } else { + injector = options.Injector + } + + // Create shell with project root matching temp dir mockShell := shell.NewMockShell() - mockConfigHandler := config.NewMockConfigHandler() + mockShell.GetProjectRootFunc = func() (string, error) { + return tmpDir, nil + } injector.Register("shell", mockShell) - injector.Register("configHandler", mockConfigHandler) - mockEnv := NewMockEnvPrinter() - injector.Register("env", mockEnv) + + // Create config handler + var configHandler config.ConfigHandler + if options.ConfigHandler == nil { + configHandler = config.NewYamlConfigHandler(injector) + } else { + configHandler = options.ConfigHandler + } + if options.ConfigStr != "" { + configHandler.LoadConfigString(options.ConfigStr) + } + injector.Register("configHandler", configHandler) + + // Setup shims + shims := setupShims(t) + + configHandler.Initialize() + + // Register cleanup to restore original state + t.Cleanup(func() { + os.Unsetenv("WINDSOR_PROJECT_ROOT") + os.Unsetenv("WINDSOR_CONTEXT") + if err := os.Chdir(origDir); err != nil { + t.Logf("Warning: Failed to change back to original directory: %v", err) + } + }) + + // Return mocks return &Mocks{ Injector: injector, Shell: mockShell, - ConfigHandler: mockConfigHandler, - Env: mockEnv, + ConfigHandler: configHandler, + Shims: shims, } } +// ============================================================================= +// Test Public Methods +// ============================================================================= + // TestEnv_Initialize tests the Initialize method of the Env struct func TestEnv_Initialize(t *testing.T) { + setup := func(t *testing.T) (*BaseEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewBaseEnvPrinter(mocks.Injector) + return printer, mocks + } + t.Run("Success", func(t *testing.T) { - mocks := setupEnvMockTests(nil) + // Given a new BaseEnvPrinter + printer, _ := setup(t) - // Use a BaseEnvPrinter for real initialization - envPrinter := NewBaseEnvPrinter(mocks.Injector) + // When calling Initialize + err := printer.Initialize() - // Call Initialize and check for errors - err := envPrinter.Initialize() + // Then no error should be returned if err != nil { t.Errorf("unexpected error: %v", err) } }) t.Run("ErrorResolvingShell", func(t *testing.T) { - mocks := setupEnvMockTests(nil) - - // Register an invalid shell that cannot be cast to shell.Shell - mocks.Injector.Register("shell", "invalid") + // Given a new BaseEnvPrinter with an invalid shell + injector := di.NewMockInjector() + injector.Register("shell", "invalid") + printer := NewBaseEnvPrinter(injector) - // Use a BaseEnvPrinter for real initialization - envPrinter := NewBaseEnvPrinter(mocks.Injector) + // When calling Initialize + err := printer.Initialize() - // Call Initialize and expect an error - err := envPrinter.Initialize() + // Then an error should be returned if err == nil { t.Error("expected error, got nil") } else if err.Error() != "error resolving or casting shell to shell.Shell" { @@ -73,16 +162,16 @@ func TestEnv_Initialize(t *testing.T) { }) t.Run("ErrorCastingCliConfigHandler", func(t *testing.T) { - mocks := setupEnvMockTests(nil) + // Given a new BaseEnvPrinter with an invalid configHandler + injector := di.NewMockInjector() + injector.Register("shell", shell.NewMockShell()) + injector.Register("configHandler", struct{}{}) + printer := NewBaseEnvPrinter(injector) - // Register an invalid configHandler that cannot be cast to config.ConfigHandler - mocks.Injector.Register("configHandler", "invalid") + // When calling Initialize + err := printer.Initialize() - // Use a BaseEnvPrinter for real initialization - envPrinter := NewBaseEnvPrinter(mocks.Injector) - - // Call Initialize and expect an error - err := envPrinter.Initialize() + // Then an error should be returned if err == nil { t.Error("expected error, got nil") } else if err.Error() != "error resolving or casting configHandler to config.ConfigHandler" { @@ -93,23 +182,28 @@ func TestEnv_Initialize(t *testing.T) { // TestBaseEnvPrinter_GetEnvVars tests the GetEnvVars method of the BaseEnvPrinter struct func TestBaseEnvPrinter_GetEnvVars(t *testing.T) { - t.Run("Success", func(t *testing.T) { - mocks := setupEnvMockTests(nil) - - // Create a new BaseEnvPrinter and initialize it - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() + setup := func(t *testing.T) (*BaseEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewBaseEnvPrinter(mocks.Injector) + err := printer.Initialize() if err != nil { t.Errorf("unexpected error during initialization: %v", err) } + return printer, mocks + } + + t.Run("Success", func(t *testing.T) { + // Given a new BaseEnvPrinter + printer, _ := setup(t) + + // When calling GetEnvVars + envVars, err := printer.GetEnvVars() - // Call GetEnvVars and check for errors - envVars, err := envPrinter.GetEnvVars() + // Then no error should be returned and envVars should be empty if err != nil { t.Errorf("unexpected error: %v", err) } - - // Verify that the returned envVars is an empty map expectedEnvVars := map[string]string{} if !reflect.DeepEqual(envVars, expectedEnvVars) { t.Errorf("envVars = %v, want %v", envVars, expectedEnvVars) @@ -119,32 +213,37 @@ func TestBaseEnvPrinter_GetEnvVars(t *testing.T) { // TestEnv_Print tests the Print method of the Env struct func TestEnv_Print(t *testing.T) { - t.Run("Success", func(t *testing.T) { - mocks := setupEnvMockTests(nil) - - // Create a new BaseEnvPrinter and initialize it - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() + setup := func(t *testing.T) (*BaseEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewBaseEnvPrinter(mocks.Injector) + err := printer.Initialize() if err != nil { t.Errorf("unexpected error during initialization: %v", err) } + return printer, mocks + } - // Mock the PrintEnvVarsFunc to verify it is called + t.Run("Success", func(t *testing.T) { + // Given a new BaseEnvPrinter with test environment variables + printer, mocks := setup(t) + + // And a mock PrintEnvVarsFunc var capturedEnvVars map[string]string mocks.Shell.PrintEnvVarsFunc = func(envVars map[string]string) { capturedEnvVars = envVars } - // Set up test environment variables + // And test environment variables testEnvVars := map[string]string{"TEST_VAR": "test_value"} - // Call Print with test environment variables - err = envPrinter.Print(testEnvVars) + // When calling Print with test environment variables + err := printer.Print(testEnvVars) + + // Then no error should be returned and PrintEnvVarsFunc should be called with correct envVars if err != nil { t.Errorf("unexpected error: %v", err) } - - // Verify that PrintEnvVarsFunc was called with the correct envVars expectedEnvVars := map[string]string{"TEST_VAR": "test_value"} if !reflect.DeepEqual(capturedEnvVars, expectedEnvVars) { t.Errorf("capturedEnvVars = %v, want %v", capturedEnvVars, expectedEnvVars) @@ -152,28 +251,22 @@ func TestEnv_Print(t *testing.T) { }) t.Run("NoCustomVars", func(t *testing.T) { - mocks := setupEnvMockTests(nil) - - // Create a new BaseEnvPrinter and initialize it - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() - if err != nil { - t.Errorf("unexpected error during initialization: %v", err) - } + // Given a new BaseEnvPrinter + printer, mocks := setup(t) - // Mock the PrintEnvVarsFunc to verify it is called + // And a mock PrintEnvVarsFunc var capturedEnvVars map[string]string mocks.Shell.PrintEnvVarsFunc = func(envVars map[string]string) { capturedEnvVars = envVars } - // Call Print without custom vars and check for errors - err = envPrinter.Print() + // When calling Print without custom vars + err := printer.Print() + + // Then no error should be returned and PrintEnvVarsFunc should be called with empty map if err != nil { t.Errorf("unexpected error: %v", err) } - - // Verify that PrintEnvVarsFunc was called with an empty map expectedEnvVars := map[string]string{} if !reflect.DeepEqual(capturedEnvVars, expectedEnvVars) { t.Errorf("capturedEnvVars = %v, want %v", capturedEnvVars, expectedEnvVars) @@ -183,32 +276,37 @@ func TestEnv_Print(t *testing.T) { // TestEnv_PrintAlias tests the PrintAlias method of the Env struct func TestEnv_PrintAlias(t *testing.T) { - t.Run("SuccessWithCustomAlias", func(t *testing.T) { - mocks := setupEnvMockTests(nil) - - // Create a new BaseEnvPrinter and initialize it - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() + setup := func(t *testing.T) (*BaseEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewBaseEnvPrinter(mocks.Injector) + err := printer.Initialize() if err != nil { t.Errorf("unexpected error during initialization: %v", err) } + return printer, mocks + } + + t.Run("SuccessWithCustomAlias", func(t *testing.T) { + // Given a new BaseEnvPrinter with test alias + printer, mocks := setup(t) - // Mock the PrintAliasFunc to verify it is called + // And a mock PrintAliasFunc var capturedAlias map[string]string mocks.Shell.PrintAliasFunc = func(alias map[string]string) { capturedAlias = alias } - // Set up test alias + // And test alias testAlias := map[string]string{"alias1": "command1"} - // Call PrintAlias with test alias - err = envPrinter.PrintAlias(testAlias) + // When calling PrintAlias with test alias + err := printer.PrintAlias(testAlias) + + // Then no error should be returned and PrintAliasFunc should be called with correct alias if err != nil { t.Errorf("unexpected error: %v", err) } - - // Verify that PrintAliasFunc was called with the correct alias expectedAlias := map[string]string{"alias1": "command1"} if !reflect.DeepEqual(capturedAlias, expectedAlias) { t.Errorf("capturedAlias = %v, want %v", capturedAlias, expectedAlias) @@ -216,28 +314,22 @@ func TestEnv_PrintAlias(t *testing.T) { }) t.Run("SuccessWithoutCustomAlias", func(t *testing.T) { - mocks := setupEnvMockTests(nil) - - // Create a new BaseEnvPrinter and initialize it - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() - if err != nil { - t.Errorf("unexpected error during initialization: %v", err) - } + // Given a new BaseEnvPrinter + printer, mocks := setup(t) - // Mock the PrintAliasFunc to verify it is called + // And a mock PrintAliasFunc var capturedAlias map[string]string mocks.Shell.PrintAliasFunc = func(alias map[string]string) { capturedAlias = alias } - // Call PrintAlias without custom alias - err = envPrinter.PrintAlias() + // When calling PrintAlias without custom alias + err := printer.PrintAlias() + + // Then no error should be returned and PrintAliasFunc should be called with empty map if err != nil { t.Errorf("unexpected error: %v", err) } - - // Verify that PrintAliasFunc was called with an empty map expectedAlias := map[string]string{} if !reflect.DeepEqual(capturedAlias, expectedAlias) { t.Errorf("capturedAlias = %v, want %v", capturedAlias, expectedAlias) @@ -247,38 +339,39 @@ func TestEnv_PrintAlias(t *testing.T) { // TestBaseEnvPrinter_GetManagedEnv tests the GetManagedEnv method of the BaseEnvPrinter struct func TestBaseEnvPrinter_GetManagedEnv(t *testing.T) { - t.Run("Success", func(t *testing.T) { - // Given a new BaseEnvPrinter - mocks := setupEnvMockTests(nil) - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() + setup := func(t *testing.T) (*BaseEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewBaseEnvPrinter(mocks.Injector) + err := printer.Initialize() if err != nil { t.Errorf("unexpected error during initialization: %v", err) } + return printer, mocks + } + + t.Run("Success", func(t *testing.T) { + // Given a new BaseEnvPrinter + printer, _ := setup(t) - // Save original value to restore it after the test - originalManagedEnv := make([]string, len(windsorManagedEnv)) - copy(originalManagedEnv, windsorManagedEnv) + // And test environment variables + // Store original managed environment variables + originalManagedEnv := make([]string, len(printer.managedEnv)) + copy(originalManagedEnv, printer.managedEnv) defer func() { - windsorManagedMu.Lock() - windsorManagedEnv = originalManagedEnv - windsorManagedMu.Unlock() + printer.managedEnv = originalManagedEnv }() - // Set test variables - windsorManagedMu.Lock() - windsorManagedEnv = []string{"TEST_VAR1", "TEST_VAR2"} - windsorManagedMu.Unlock() + // Set test environment variables + printer.managedEnv = []string{"TEST_VAR1", "TEST_VAR2"} // When calling GetManagedEnv - managedEnv := envPrinter.GetManagedEnv() + managedEnv := printer.GetManagedEnv() // Then the returned list should contain our tracked variables if len(managedEnv) != 2 { t.Errorf("expected 2 variables, got %d", len(managedEnv)) } - - // Verify expected variables are present if managedEnv[0] != "TEST_VAR1" || managedEnv[1] != "TEST_VAR2" { t.Errorf("expected [TEST_VAR1, TEST_VAR2], got %v", managedEnv) } @@ -287,38 +380,39 @@ func TestBaseEnvPrinter_GetManagedEnv(t *testing.T) { // TestBaseEnvPrinter_GetManagedAlias tests the GetManagedAlias method of the BaseEnvPrinter struct func TestBaseEnvPrinter_GetManagedAlias(t *testing.T) { - t.Run("Success", func(t *testing.T) { - // Given a new BaseEnvPrinter - mocks := setupEnvMockTests(nil) - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() + setup := func(t *testing.T) (*BaseEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewBaseEnvPrinter(mocks.Injector) + err := printer.Initialize() if err != nil { t.Errorf("unexpected error during initialization: %v", err) } + return printer, mocks + } - // Save original value to restore it after the test - originalManagedAlias := make([]string, len(windsorManagedAlias)) - copy(originalManagedAlias, windsorManagedAlias) + t.Run("Success", func(t *testing.T) { + // Given a new BaseEnvPrinter + printer, _ := setup(t) + + // And test aliases + // Store original managed aliases + originalManagedAlias := make([]string, len(printer.managedAlias)) + copy(originalManagedAlias, printer.managedAlias) defer func() { - windsorManagedMu.Lock() - windsorManagedAlias = originalManagedAlias - windsorManagedMu.Unlock() + printer.managedAlias = originalManagedAlias }() // Set test aliases - windsorManagedMu.Lock() - windsorManagedAlias = []string{"alias1", "alias2"} - windsorManagedMu.Unlock() + printer.managedAlias = []string{"alias1", "alias2"} // When calling GetManagedAlias - managedAlias := envPrinter.GetManagedAlias() + managedAlias := printer.GetManagedAlias() // Then the returned list should contain our tracked aliases if len(managedAlias) != 2 { t.Errorf("expected 2 aliases, got %d", len(managedAlias)) } - - // Verify expected aliases are present if managedAlias[0] != "alias1" || managedAlias[1] != "alias2" { t.Errorf("expected [alias1, alias2], got %v", managedAlias) } @@ -327,41 +421,40 @@ func TestBaseEnvPrinter_GetManagedAlias(t *testing.T) { // TestBaseEnvPrinter_SetManagedEnv tests the SetManagedEnv method of the BaseEnvPrinter struct func TestBaseEnvPrinter_SetManagedEnv(t *testing.T) { - t.Run("Success", func(t *testing.T) { - // Given a new BaseEnvPrinter - mocks := setupEnvMockTests(nil) - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() + setup := func(t *testing.T) (*BaseEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewBaseEnvPrinter(mocks.Injector) + err := printer.Initialize() if err != nil { t.Errorf("unexpected error during initialization: %v", err) } + return printer, mocks + } - // Save original value to restore it after the test - originalManagedEnv := make([]string, len(windsorManagedEnv)) - copy(originalManagedEnv, windsorManagedEnv) + t.Run("Success", func(t *testing.T) { + // Given a new BaseEnvPrinter + printer, _ := setup(t) + + // And empty managed environment variables + // Store original managed environment variables + originalManagedEnv := make([]string, len(printer.managedEnv)) + copy(originalManagedEnv, printer.managedEnv) defer func() { - windsorManagedMu.Lock() - windsorManagedEnv = originalManagedEnv - windsorManagedMu.Unlock() + printer.managedEnv = originalManagedEnv }() - // Reset managed environment variables for this test - windsorManagedMu.Lock() - windsorManagedEnv = []string{} - windsorManagedMu.Unlock() + // Set empty managed environment variables + printer.managedEnv = []string{} - // Set test variables (one string at a time) - envPrinter.SetManagedEnv("SET_TEST_VAR1") + // When setting a managed environment variable + printer.SetManagedEnv("SET_TEST_VAR1") - // When calling GetManagedEnv to verify - managedEnv := envPrinter.GetManagedEnv() - - // Then the returned list should contain our variables + // Then GetManagedEnv should return the variable + managedEnv := printer.GetManagedEnv() if len(managedEnv) != 1 { t.Errorf("expected 1 variable, got %d", len(managedEnv)) } - - // Verify expected variables are present if managedEnv[0] != "SET_TEST_VAR1" { t.Errorf("expected [SET_TEST_VAR1], got %v", managedEnv) } @@ -369,40 +462,28 @@ func TestBaseEnvPrinter_SetManagedEnv(t *testing.T) { t.Run("Dedupe", func(t *testing.T) { // Given a new BaseEnvPrinter - mocks := setupEnvMockTests(nil) - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() - if err != nil { - t.Errorf("unexpected error during initialization: %v", err) - } + printer, _ := setup(t) - // Save original value to restore it after the test - originalManagedEnv := make([]string, len(windsorManagedEnv)) - copy(originalManagedEnv, windsorManagedEnv) + // And empty managed environment variables + // Store original managed environment variables + originalManagedEnv := make([]string, len(printer.managedEnv)) + copy(originalManagedEnv, printer.managedEnv) defer func() { - windsorManagedMu.Lock() - windsorManagedEnv = originalManagedEnv - windsorManagedMu.Unlock() + printer.managedEnv = originalManagedEnv }() - // Reset managed environment variables for this test - windsorManagedMu.Lock() - windsorManagedEnv = []string{} - windsorManagedMu.Unlock() + // Set empty managed environment variables + printer.managedEnv = []string{} - // Set duplicate test variables - envPrinter.SetManagedEnv("SET_TEST_VAR1") - envPrinter.SetManagedEnv("SET_TEST_VAR1") // Attempt to add duplicate + // When setting duplicate managed environment variables + printer.SetManagedEnv("SET_TEST_VAR1") + printer.SetManagedEnv("SET_TEST_VAR1") - // When calling GetManagedEnv to verify - managedEnv := envPrinter.GetManagedEnv() - - // Then the returned list should contain only one instance of the variable + // Then GetManagedEnv should return only one instance + managedEnv := printer.GetManagedEnv() if len(managedEnv) != 1 { t.Errorf("expected 1 variable, got %d", len(managedEnv)) } - - // Verify expected variables are present if managedEnv[0] != "SET_TEST_VAR1" { t.Errorf("expected [SET_TEST_VAR1], got %v", managedEnv) } @@ -411,41 +492,40 @@ func TestBaseEnvPrinter_SetManagedEnv(t *testing.T) { // TestBaseEnvPrinter_SetManagedAlias tests the SetManagedAlias method of the BaseEnvPrinter struct func TestBaseEnvPrinter_SetManagedAlias(t *testing.T) { - t.Run("Success", func(t *testing.T) { - // Given a new BaseEnvPrinter - mocks := setupEnvMockTests(nil) - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() + setup := func(t *testing.T) (*BaseEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewBaseEnvPrinter(mocks.Injector) + err := printer.Initialize() if err != nil { t.Errorf("unexpected error during initialization: %v", err) } + return printer, mocks + } + + t.Run("Success", func(t *testing.T) { + // Given a new BaseEnvPrinter + printer, _ := setup(t) - // Save original value to restore it after the test - originalManagedAlias := make([]string, len(windsorManagedAlias)) - copy(originalManagedAlias, windsorManagedAlias) + // And empty managed aliases + // Store original managed aliases + originalManagedAlias := make([]string, len(printer.managedAlias)) + copy(originalManagedAlias, printer.managedAlias) defer func() { - windsorManagedMu.Lock() - windsorManagedAlias = originalManagedAlias - windsorManagedMu.Unlock() + printer.managedAlias = originalManagedAlias }() - // Reset managed aliases for this test - windsorManagedMu.Lock() - windsorManagedAlias = []string{} - windsorManagedMu.Unlock() + // Set empty managed aliases + printer.managedAlias = []string{} - // Set test aliases (one string at a time) - envPrinter.SetManagedAlias("set_alias1") + // When setting a managed alias + printer.SetManagedAlias("set_alias1") - // When calling GetManagedAlias to verify - managedAlias := envPrinter.GetManagedAlias() - - // Then the returned list should contain our aliases + // Then GetManagedAlias should return the alias + managedAlias := printer.GetManagedAlias() if len(managedAlias) != 1 { t.Errorf("expected 1 alias, got %d", len(managedAlias)) } - - // Verify expected aliases are present if managedAlias[0] != "set_alias1" { t.Errorf("expected [set_alias1], got %v", managedAlias) } @@ -453,40 +533,28 @@ func TestBaseEnvPrinter_SetManagedAlias(t *testing.T) { t.Run("Dedupe", func(t *testing.T) { // Given a new BaseEnvPrinter - mocks := setupEnvMockTests(nil) - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() - if err != nil { - t.Errorf("unexpected error during initialization: %v", err) - } + printer, _ := setup(t) - // Save original value to restore it after the test - originalManagedAlias := make([]string, len(windsorManagedAlias)) - copy(originalManagedAlias, windsorManagedAlias) + // And empty managed aliases + // Store original managed aliases + originalManagedAlias := make([]string, len(printer.managedAlias)) + copy(originalManagedAlias, printer.managedAlias) defer func() { - windsorManagedMu.Lock() - windsorManagedAlias = originalManagedAlias - windsorManagedMu.Unlock() + printer.managedAlias = originalManagedAlias }() - // Reset managed aliases for this test - windsorManagedMu.Lock() - windsorManagedAlias = []string{} - windsorManagedMu.Unlock() - - // Set duplicate test aliases - envPrinter.SetManagedAlias("set_alias1") - envPrinter.SetManagedAlias("set_alias1") // Attempt to add duplicate + // Set empty managed aliases + printer.managedAlias = []string{} - // When calling GetManagedAlias to verify - managedAlias := envPrinter.GetManagedAlias() + // When setting duplicate managed aliases + printer.SetManagedAlias("set_alias1") + printer.SetManagedAlias("set_alias1") - // Then the returned list should contain only one instance of the alias + // Then GetManagedAlias should return only one instance + managedAlias := printer.GetManagedAlias() if len(managedAlias) != 1 { t.Errorf("expected 1 alias, got %d", len(managedAlias)) } - - // Verify expected aliases are present if managedAlias[0] != "set_alias1" { t.Errorf("expected [set_alias1], got %v", managedAlias) } @@ -495,23 +563,29 @@ func TestBaseEnvPrinter_SetManagedAlias(t *testing.T) { // TestBaseEnvPrinter_Reset tests the Reset method of the BaseEnvPrinter struct func TestBaseEnvPrinter_Reset(t *testing.T) { - t.Run("ResetWithNoEnvVars", func(t *testing.T) { - // Given a new BaseEnvPrinter - mocks := setupEnvMockTests(nil) - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() + setup := func(t *testing.T) (*BaseEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewBaseEnvPrinter(mocks.Injector) + err := printer.Initialize() if err != nil { t.Errorf("unexpected error during initialization: %v", err) } + return printer, mocks + } + + t.Run("ResetWithNoEnvVars", func(t *testing.T) { + // Given a new BaseEnvPrinter + printer, mocks := setup(t) - // Track calls to Reset + // And a mock Reset function resetCalled := false mocks.Shell.ResetFunc = func() { resetCalled = true } // When calling Reset - envPrinter.Reset() + printer.Reset() // Then shell.Reset should be called if !resetCalled { @@ -520,15 +594,10 @@ func TestBaseEnvPrinter_Reset(t *testing.T) { }) t.Run("ResetWithEnvironmentVariables", func(t *testing.T) { - // Given a new BaseEnvPrinter - mocks := setupEnvMockTests(nil) - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() - if err != nil { - t.Errorf("unexpected error during initialization: %v", err) - } + // Given a new BaseEnvPrinter with environment variables + printer, mocks := setup(t) - // Set environment variables + // And environment variables set os.Setenv("WINDSOR_MANAGED_ENV", "ENV1,ENV2, ENV3") os.Setenv("WINDSOR_MANAGED_ALIAS", "alias1,alias2, alias3") defer func() { @@ -536,14 +605,14 @@ func TestBaseEnvPrinter_Reset(t *testing.T) { os.Unsetenv("WINDSOR_MANAGED_ALIAS") }() - // Track calls to Reset + // And a mock Reset function resetCalled := false mocks.Shell.ResetFunc = func() { resetCalled = true } // When calling Reset - envPrinter.Reset() + printer.Reset() // Then shell.Reset should be called if !resetCalled { @@ -551,53 +620,40 @@ func TestBaseEnvPrinter_Reset(t *testing.T) { } }) - t.Run("InternalStatePersistsWithReset", func(t *testing.T) { + t.Run("InternalStateResetsWithReset", func(t *testing.T) { // Given a new BaseEnvPrinter with managed environment variables and aliases - mocks := setupEnvMockTests(nil) - envPrinter := NewBaseEnvPrinter(mocks.Injector) - err := envPrinter.Initialize() - if err != nil { - t.Errorf("unexpected error during initialization: %v", err) - } + printer, mocks := setup(t) - // Set up some managed environment variables and aliases - envPrinter.SetManagedEnv("TEST_ENV1") - envPrinter.SetManagedEnv("TEST_ENV2") - envPrinter.SetManagedAlias("test_alias1") - envPrinter.SetManagedAlias("test_alias2") + // And managed environment variables and aliases set + printer.SetManagedEnv("TEST_ENV1") + printer.SetManagedEnv("TEST_ENV2") + printer.SetManagedAlias("test_alias1") + printer.SetManagedAlias("test_alias2") - // Track calls to Reset + // And a mock Reset function resetCalled := false mocks.Shell.ResetFunc = func() { resetCalled = true } // When calling Reset - envPrinter.Reset() + printer.Reset() // Then shell.Reset should be called if !resetCalled { t.Errorf("expected Shell.Reset to be called, but it wasn't") } - // And the managed environment variables should still be available - managedEnv := envPrinter.GetManagedEnv() - - // Verify that our test variables are in the managed env (without requiring exact count) - for _, env := range []string{"TEST_ENV1", "TEST_ENV2"} { - if !slices.Contains(managedEnv, env) { - t.Errorf("expected GetManagedEnv to contain %s", env) - } + // And the managed environment variables should be empty + managedEnv := printer.GetManagedEnv() + if len(managedEnv) > 0 { + t.Errorf("expected GetManagedEnv to be empty, got %v", managedEnv) } - // And the managed aliases should still be available - managedAlias := envPrinter.GetManagedAlias() - - // Verify that our test aliases are in the managed aliases (without requiring exact count) - for _, alias := range []string{"test_alias1", "test_alias2"} { - if !slices.Contains(managedAlias, alias) { - t.Errorf("expected GetManagedAlias to contain %s", alias) - } + // And the managed aliases should be empty + managedAlias := printer.GetManagedAlias() + if len(managedAlias) > 0 { + t.Errorf("expected GetManagedAlias to be empty, got %v", managedAlias) } }) } diff --git a/pkg/env/kube_env.go b/pkg/env/kube_env.go index 504249d3a..1dd8ee7eb 100644 --- a/pkg/env/kube_env.go +++ b/pkg/env/kube_env.go @@ -1,3 +1,8 @@ +// The KubeEnvPrinter is a specialized component that manages Kubernetes environment configuration. +// It provides Kubernetes-specific environment variable management and configuration, +// The KubeEnvPrinter handles kubeconfig, context, and persistent volume configuration settings, +// ensuring proper kubectl integration and environment setup for Kubernetes operations. + package env import ( @@ -8,6 +13,7 @@ import ( "regexp" "strings" + "github.com/windsorcli/cli/pkg/constants" "github.com/windsorcli/cli/pkg/di" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -15,20 +21,30 @@ import ( "k8s.io/client-go/tools/clientcmd" ) -// KubeEnvPrinter is a struct that simulates a Kubernetes environment for testing purposes. +// ============================================================================= +// Types +// ============================================================================= + +// KubeEnvPrinter is a struct that implements Kubernetes environment configuration type KubeEnvPrinter struct { BaseEnvPrinter } -// NewKubeEnv initializes a new kubeEnv instance using the provided dependency injector. +// ============================================================================= +// Constructor +// ============================================================================= + +// NewKubeEnvPrinter creates a new KubeEnvPrinter instance func NewKubeEnvPrinter(injector di.Injector) *KubeEnvPrinter { return &KubeEnvPrinter{ - BaseEnvPrinter: BaseEnvPrinter{ - injector: injector, - }, + BaseEnvPrinter: *NewBaseEnvPrinter(injector), } } +// ============================================================================= +// Public Methods +// ============================================================================= + // GetEnvVars constructs a map of Kubernetes environment variables by setting // KUBECONFIG and KUBE_CONFIG_PATH based on the configuration root directory. // It checks for a project-specific volume directory and returns current variables @@ -48,22 +64,26 @@ func (e *KubeEnvPrinter) GetEnvVars() (map[string]string, error) { projectRoot := os.Getenv("WINDSOR_PROJECT_ROOT") volumeDir := filepath.Join(projectRoot, ".volumes") - if _, err := stat(volumeDir); os.IsNotExist(err) { - return envVars, nil + _, err = e.shims.Stat(volumeDir) + if err != nil { + if os.IsNotExist(err) { + return envVars, nil + } + return nil, fmt.Errorf("error checking volume directory: %w", err) } - volumeDirs, err := readDir(volumeDir) + volumeDirs, err := e.shims.ReadDir(volumeDir) if err != nil { return nil, fmt.Errorf("error reading volume directories: %w", err) } existingEnvVars := make(map[string]string) - for _, env := range os.Environ() { + for _, env := range e.shims.Environ() { if strings.HasPrefix(env, "PV_") { parts := strings.SplitN(env, "=", 2) if len(parts) == 2 { existingEnvVars[parts[0]] = parts[1] - envVars[parts[0]] = parts[1] // Include existing PV environment variables + envVars[parts[0]] = parts[1] } } } @@ -89,7 +109,10 @@ func (e *KubeEnvPrinter) GetEnvVars() (map[string]string, error) { return envVars, nil } - pvcs, _ := queryPersistentVolumeClaims(kubeConfigPath) // ignores error + pvcs, err := queryPersistentVolumeClaims(kubeConfigPath) + if err != nil { + return nil, fmt.Errorf("error querying persistent volume claims: %w", err) + } if pvcs != nil && pvcs.Items != nil { for _, dir := range volumeDirs { @@ -122,8 +145,9 @@ func (e *KubeEnvPrinter) Print() error { return e.BaseEnvPrinter.Print(envVars) } -// Ensure kubeEnv implements the EnvPrinter interface -var _ EnvPrinter = (*KubeEnvPrinter)(nil) +// ============================================================================= +// Private Methods +// ============================================================================= // sanitizeEnvVar converts a string to uppercase, trims whitespace, and replaces invalid characters with underscores. func sanitizeEnvVar(input string) string { @@ -146,10 +170,19 @@ var queryPersistentVolumeClaims = func(kubeConfigPath string) (*corev1.Persisten return nil, err } - pvcs, err := clientset.CoreV1().PersistentVolumeClaims("").List(context.TODO(), metav1.ListOptions{}) + ctx, cancel := context.WithTimeout(context.Background(), constants.KUBERNETES_SHORT_TIMEOUT) + defer cancel() + + pvcs, err := clientset.CoreV1().PersistentVolumeClaims("").List(ctx, metav1.ListOptions{}) if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("timeout querying PVCs: %w", err) + } return nil, err } return pvcs, nil } + +// Ensure KubeEnvPrinter implements the EnvPrinter interface +var _ EnvPrinter = (*KubeEnvPrinter)(nil) diff --git a/pkg/env/kube_env_test.go b/pkg/env/kube_env_test.go index 221158056..a8a57f6a3 100644 --- a/pkg/env/kube_env_test.go +++ b/pkg/env/kube_env_test.go @@ -4,325 +4,584 @@ import ( "errors" "os" "path/filepath" - "reflect" "strings" "testing" "time" "github.com/windsorcli/cli/pkg/config" - "github.com/windsorcli/cli/pkg/di" - "github.com/windsorcli/cli/pkg/shell" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -type KubeEnvPrinterMocks struct { - Injector di.Injector - ConfigHandler *config.MockConfigHandler - Shell *shell.MockShell +// ============================================================================= +// Test Setup +// ============================================================================= + +// mockDirEntry is a simple mock implementation of os.DirEntry +type mockDirEntry struct { + name string + isDir bool } -func setupSafeKubeEnvPrinterMocks(injector ...di.Injector) *KubeEnvPrinterMocks { - var mockInjector di.Injector - if len(injector) > 0 { - mockInjector = injector[0] - } else { - mockInjector = di.NewMockInjector() - } +func (m mockDirEntry) Name() string { return m.name } +func (m mockDirEntry) IsDir() bool { return m.isDir } +func (m mockDirEntry) Type() os.FileMode { return os.ModeDir } +func (m mockDirEntry) Info() (os.FileInfo, error) { return mockFileInfo{name: m.name}, nil } + +// mockFileInfo is a simple mock implementation of os.FileInfo +type mockFileInfo struct { + name string +} - mockConfigHandler := config.NewMockConfigHandler() - mockConfigHandler.GetConfigRootFunc = func() (string, error) { - return filepath.FromSlash("/mock/config/root"), nil +func (m mockFileInfo) Name() string { return m.name } +func (m mockFileInfo) Size() int64 { return 0 } +func (m mockFileInfo) Mode() os.FileMode { return os.ModeDir } +func (m mockFileInfo) ModTime() time.Time { return time.Time{} } +func (m mockFileInfo) IsDir() bool { return true } +func (m mockFileInfo) Sys() any { return nil } + +// setupKubeEnvMocks creates a base mock setup for Kubernetes environment tests +func setupKubeEnvMocks(t *testing.T, opts ...*SetupOptions) *Mocks { + t.Helper() + if len(opts) == 0 { + opts = []*SetupOptions{{}} } - mockShell := shell.NewMockShell() + mocks := setupMocks(t, opts[0]) + projectRoot, err := mocks.Shell.GetProjectRoot() + if err != nil { + t.Fatalf("Failed to get project root: %v", err) + } - mockInjector.Register("configHandler", mockConfigHandler) - mockInjector.Register("shell", mockShell) + t.Setenv("WINDSOR_PROJECT_ROOT", projectRoot) // Mock readDir to return some valid persistent volume folders - readDir = func(dirname string) ([]os.DirEntry, error) { - if strings.HasSuffix(dirname, ".volumes") { + mocks.Shims.ReadDir = func(dirname string) ([]os.DirEntry, error) { + if strings.HasSuffix(dirname, "volumes") { return []os.DirEntry{ - mockDirEntry{name: "pvc-1234"}, - mockDirEntry{name: "pvc-5678"}, + mockDirEntry{name: "pvc-1234", isDir: true}, + mockDirEntry{name: "pvc-5678", isDir: true}, }, nil } return nil, errors.New("mock readDir error") } - // Mock stat to return nil - stat = func(name string) (os.FileInfo, error) { - if strings.HasSuffix(name, ".kube/config") || strings.HasSuffix(name, ".volumes") { - return nil, nil + // Mock stat to return nil for both kubeconfig and volumes + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if strings.HasSuffix(name, ".kube/config") || strings.HasSuffix(name, "volumes") { + return mockFileInfo{name: filepath.Base(name)}, nil } return nil, os.ErrNotExist } - // Mock queryPersistentVolumeClaims to return appropriate PVC claims - queryPersistentVolumeClaims = func(kubeConfigPath string) (*corev1.PersistentVolumeClaimList, error) { + // Mock queryPersistentVolumeClaims to return some PVCs + queryPersistentVolumeClaims = func(_ string) (*corev1.PersistentVolumeClaimList, error) { return &corev1.PersistentVolumeClaimList{ Items: []corev1.PersistentVolumeClaim{ - {ObjectMeta: metav1.ObjectMeta{UID: "1234", Namespace: "default", Name: "claim1"}}, - {ObjectMeta: metav1.ObjectMeta{UID: "5678", Namespace: "default", Name: "claim2"}}, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "pvc-1", + Namespace: "test-ns", + UID: "1234", + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "pvc-2", + Namespace: "test-ns", + UID: "5678", + }, + }, }, }, nil } - return &KubeEnvPrinterMocks{ - Injector: mockInjector, - ConfigHandler: mockConfigHandler, - Shell: mockShell, - } + return mocks } -// mockDirEntry is a simple mock implementation of os.DirEntry -type mockDirEntry struct { - name string -} +// ============================================================================= +// Test Public Methods +// ============================================================================= -func (m mockDirEntry) Name() string { return m.name } -func (m mockDirEntry) IsDir() bool { return true } -func (m mockDirEntry) Type() os.FileMode { return os.ModeDir } -func (m mockDirEntry) Info() (os.FileInfo, error) { return mockFileInfo{name: m.name}, nil } +// TestKubeEnvPrinter_GetEnvVars tests the GetEnvVars method of the KubeEnvPrinter +func TestKubeEnvPrinter_GetEnvVars(t *testing.T) { + setup := func(t *testing.T) (*KubeEnvPrinter, *Mocks) { + t.Helper() + mocks := setupKubeEnvMocks(t) + printer := NewKubeEnvPrinter(mocks.Injector) + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + printer.shims = mocks.Shims + return printer, mocks + } -// mockFileInfo is a simple mock implementation of os.FileInfo -type mockFileInfo struct { - name string -} + t.Run("SuccessWithKubeConfig", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, mocks := setup(t) -func (m mockFileInfo) Name() string { return m.name } -func (m mockFileInfo) Size() int64 { return 0 } -func (m mockFileInfo) Mode() os.FileMode { return os.ModeDir } -func (m mockFileInfo) ModTime() time.Time { return time.Time{} } -func (m mockFileInfo) IsDir() bool { return true } -func (m mockFileInfo) Sys() interface{} { return nil } + // And a valid config root + configRoot, err := mocks.ConfigHandler.GetConfigRoot() + if err != nil { + t.Fatalf("Failed to get config root: %v", err) + } -func TestKubeEnvPrinter_GetEnvVars(t *testing.T) { - t.Run("Success", func(t *testing.T) { - mocks := setupSafeKubeEnvPrinterMocks() + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } - kubeEnvPrinter := NewKubeEnvPrinter(mocks.Injector) - kubeEnvPrinter.Initialize() + // And KUBECONFIG should be set correctly + expectedKubeConfig := filepath.Join(configRoot, ".kube/config") + if envVars["KUBECONFIG"] != expectedKubeConfig { + t.Errorf("Expected KUBECONFIG=%s, got %s", expectedKubeConfig, envVars["KUBECONFIG"]) + } + }) - envVars, err := kubeEnvPrinter.GetEnvVars() + t.Run("SuccessWithVolumes", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, _ := setup(t) + + // And a valid project root + projectRoot := os.Getenv("WINDSOR_PROJECT_ROOT") + volumeDir := filepath.Join(projectRoot, ".volumes") + + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { - t.Fatalf("GetEnvVars returned an error: %v", err) + t.Errorf("Expected no error, got %v", err) } - expectedPath := filepath.FromSlash("/mock/config/root/.kube/config") - if envVars["KUBECONFIG"] != expectedPath || envVars["KUBE_CONFIG_PATH"] != expectedPath { - t.Errorf("KUBECONFIG = %v, KUBE_CONFIG_PATH = %v, want both to be %v", envVars["KUBECONFIG"], envVars["KUBE_CONFIG_PATH"], expectedPath) + // And volume paths should be set correctly + expectedPaths := map[string]string{ + "PV_TEST_NS_PVC_1": filepath.Join(volumeDir, "pvc-1234"), + "PV_TEST_NS_PVC_2": filepath.Join(volumeDir, "pvc-5678"), + } + + for k, v := range expectedPaths { + if envVars[k] != v { + t.Errorf("Expected %s=%s, got %s", k, v, envVars[k]) + } } }) t.Run("NoKubeConfig", func(t *testing.T) { - mocks := setupSafeKubeEnvPrinterMocks() + // Given a KubeEnvPrinter with valid configuration + printer, mocks := setup(t) - originalStat := stat - defer func() { stat = originalStat }() - stat = func(name string) (os.FileInfo, error) { + // And a valid config root + configRoot, err := mocks.ConfigHandler.GetConfigRoot() + if err != nil { + t.Fatalf("Failed to get config root: %v", err) + } + + // And a mock Stat function that returns ErrNotExist + mocks.Shims.Stat = func(path string) (os.FileInfo, error) { return nil, os.ErrNotExist } - kubeEnvPrinter := NewKubeEnvPrinter(mocks.Injector) - kubeEnvPrinter.Initialize() + // When getting environment variables + envVars, err := printer.GetEnvVars() - envVars, err := kubeEnvPrinter.GetEnvVars() + // Then no error should be returned if err != nil { - t.Fatalf("GetEnvVars returned an error: %v", err) + t.Errorf("Expected no error, got %v", err) } - expectedPath := filepath.FromSlash("/mock/config/root/.kube/config") - if envVars["KUBECONFIG"] != expectedPath || envVars["KUBE_CONFIG_PATH"] != expectedPath { - t.Errorf("KUBECONFIG = %v, KUBE_CONFIG_PATH = %v, want both to be %v", envVars["KUBECONFIG"], envVars["KUBE_CONFIG_PATH"], expectedPath) + // And KUBECONFIG should still be set correctly + expectedKubeConfig := filepath.Join(configRoot, ".kube/config") + if envVars["KUBECONFIG"] != expectedKubeConfig { + t.Errorf("Expected KUBECONFIG=%s, got %s", expectedKubeConfig, envVars["KUBECONFIG"]) } }) t.Run("GetConfigRootError", func(t *testing.T) { - mocks := setupSafeKubeEnvPrinterMocks() - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { - return "", errors.New("mock context error") + // Given a mock ConfigHandler that returns an error + mockConfigHandler := config.NewMockConfigHandler() + mockConfigHandler.GetConfigRootFunc = func() (string, error) { + return "", errors.New("mock config error") } - kubeEnvPrinter := NewKubeEnvPrinter(mocks.Injector) - kubeEnvPrinter.Initialize() + // And a KubeEnvPrinter with the mock ConfigHandler + mocks := setupKubeEnvMocks(t, &SetupOptions{ConfigHandler: mockConfigHandler}) + printer := NewKubeEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then an error should be returned + if err == nil { + t.Error("Expected error, got nil") + } - _, err := kubeEnvPrinter.GetEnvVars() - expectedError := "error retrieving configuration root directory: mock context error" - if err == nil || err.Error() != expectedError { - t.Errorf("error = %v, want %v", err, expectedError) + // And envVars should be nil + if envVars != nil { + t.Errorf("Expected nil envVars, got %v", envVars) } }) t.Run("ErrorReadingVolumes", func(t *testing.T) { - mocks := setupSafeKubeEnvPrinterMocks() - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { - return "/mock/config/root", nil + // Given a KubeEnvPrinter with valid configuration + printer, mocks := setup(t) + + // And a mock ReadDir function that returns an error + mocks.Shims.ReadDir = func(dirname string) ([]os.DirEntry, error) { + if strings.HasSuffix(dirname, "volumes") { + return nil, errors.New("mock readDir error") + } + return nil, nil } - originalReadDir := readDir - defer func() { readDir = originalReadDir }() - readDir = func(dirname string) ([]os.DirEntry, error) { - return nil, errors.New("mock readDir error") + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then an error should be returned + if err == nil { + t.Error("Expected error, got nil") } - kubeEnvPrinter := NewKubeEnvPrinter(mocks.Injector) - kubeEnvPrinter.Initialize() + // And the error should mention reading volume directories + if !strings.Contains(err.Error(), "error reading volume directories") { + t.Errorf("Expected error about reading volume directories, got %v", err) + } - _, err := kubeEnvPrinter.GetEnvVars() - expectedError := "error reading volume directories: mock readDir error" - if err == nil || err.Error() != expectedError { - t.Errorf("error = %v, want %v", err, expectedError) + // And envVars should be nil + if envVars != nil { + t.Errorf("Expected nil envVars, got %v", envVars) } }) - t.Run("SuccessWithExistingPVCEnvVars", func(t *testing.T) { - // Use setupSafeKubeEnvPrinterMocks to create mocks - mocks := setupSafeKubeEnvPrinterMocks() - kubeEnvPrinter := NewKubeEnvPrinter(mocks.Injector) - kubeEnvPrinter.Initialize() + t.Run("ErrorQueryingPVCs", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, _ := setup(t) - // Set up environment variables to simulate existing PVC environment variables - os.Setenv("PV_NAMESPACE_PVCNAME", "/mock/volume/dir/pvc-12345") - defer os.Unsetenv("PV_NAMESPACE_PVCNAME") + // And a mock queryPersistentVolumeClaims function that returns an error + queryPersistentVolumeClaims = func(_ string) (*corev1.PersistentVolumeClaimList, error) { + return nil, errors.New("mock PVC query error") + } - // Mock the readDir function to simulate reading the volume directory - originalReadDir := readDir - defer func() { readDir = originalReadDir }() - readDir = func(dirname string) ([]os.DirEntry, error) { - return []os.DirEntry{ - mockDirEntry{name: "pvc-12345"}, - }, nil + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then an error should be returned + if err == nil { + t.Error("Expected error, got nil") } - // Call GetEnvVars and check for errors - envVars, err := kubeEnvPrinter.GetEnvVars() + // And the error should mention querying PVCs + if !strings.Contains(err.Error(), "error querying persistent volume claims") { + t.Errorf("Expected error about querying PVCs, got %v", err) + } + + // And envVars should be nil + if envVars != nil { + t.Errorf("Expected nil envVars, got %v", envVars) + } + }) + + t.Run("NilPVCList", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, _ := setup(t) + + // And a mock queryPersistentVolumeClaims function that returns nil list + queryPersistentVolumeClaims = func(_ string) (*corev1.PersistentVolumeClaimList, error) { + return nil, nil + } + + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { - t.Errorf("unexpected error: %v", err) + t.Errorf("Expected no error, got %v", err) } - // Verify that GetEnvVars returns the correct envVars - expectedEnvVars := map[string]string{ - "KUBECONFIG": filepath.FromSlash("/mock/config/root/.kube/config"), - "KUBE_CONFIG_PATH": filepath.FromSlash("/mock/config/root/.kube/config"), - "K8S_AUTH_KUBECONFIG": filepath.FromSlash("/mock/config/root/.kube/config"), - "PV_NAMESPACE_PVCNAME": "/mock/volume/dir/pvc-12345", + // And envVars should not be nil + if envVars == nil { + t.Error("Expected non-nil envVars") } - if !reflect.DeepEqual(envVars, expectedEnvVars) { - t.Errorf("envVars = %v, want %v", envVars, expectedEnvVars) + }) + + t.Run("EmptyPVCList", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, _ := setup(t) + + // And a mock queryPersistentVolumeClaims function that returns empty list + queryPersistentVolumeClaims = func(_ string) (*corev1.PersistentVolumeClaimList, error) { + return &corev1.PersistentVolumeClaimList{Items: []corev1.PersistentVolumeClaim{}}, nil + } + + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + // And envVars should not be nil + if envVars == nil { + t.Error("Expected non-nil envVars") } }) - t.Run("AllVolumesAccountedFor", func(t *testing.T) { - mocks := setupSafeKubeEnvPrinterMocks() - kubeEnvPrinter := NewKubeEnvPrinter(mocks.Injector) - kubeEnvPrinter.Initialize() - - // Set up environment variables to simulate all PVCs being accounted for - os.Setenv("PV_DEFAULT_CLAIM1", "/mock/volume/dir/pvc-1234") - os.Setenv("PV_DEFAULT_CLAIM2", "/mock/volume/dir/pvc-5678") - defer os.Unsetenv("PV_DEFAULT_CLAIM1") - defer os.Unsetenv("PV_DEFAULT_CLAIM2") - - // Mock the readDir function to simulate reading the volume directory - originalReadDir := readDir - defer func() { readDir = originalReadDir }() - readDir = func(dirname string) ([]os.DirEntry, error) { - return []os.DirEntry{ - mockDirEntry{name: "pvc-1234"}, - mockDirEntry{name: "pvc-5678"}, - }, nil + t.Run("VolumeDirStatError", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, mocks := setup(t) + + // And a mock Stat function that returns an error for volume directory + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if strings.HasSuffix(name, "volumes") { + return nil, errors.New("mock stat error") + } + return mockFileInfo{name: filepath.Base(name)}, nil + } + + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then an error should be returned + if err == nil { + t.Error("Expected error, got nil") + } + + // And the error should mention checking volume directory + if !strings.Contains(err.Error(), "error checking volume directory") { + t.Errorf("Expected error about checking volume directory, got %v", err) + } + + // And envVars should be nil + if envVars != nil { + t.Errorf("Expected nil envVars, got %v", envVars) } + }) + + t.Run("NoPVCDirectories", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, mocks := setup(t) - // Mock queryPersistentVolumeClaims to verify it is not called - originalQueryPVCs := queryPersistentVolumeClaims - defer func() { queryPersistentVolumeClaims = originalQueryPVCs }() - queryPersistentVolumeClaims = func(kubeConfigPath string) (*corev1.PersistentVolumeClaimList, error) { - t.Error("queryPersistentVolumeClaims should not be called") + // And a mock ReadDir function that returns no PVC directories + mocks.Shims.ReadDir = func(dirname string) ([]os.DirEntry, error) { + if strings.HasSuffix(dirname, "volumes") { + return []os.DirEntry{ + mockDirEntry{name: "other-dir", isDir: true}, + }, nil + } return nil, nil } - // Call GetEnvVars and check for errors - envVars, err := kubeEnvPrinter.GetEnvVars() + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { - t.Errorf("unexpected error: %v", err) + t.Errorf("Expected no error, got %v", err) } - // Verify that GetEnvVars returns the correct envVars without calling queryPersistentVolumeClaims - expectedEnvVars := map[string]string{ - "KUBECONFIG": filepath.FromSlash("/mock/config/root/.kube/config"), - "KUBE_CONFIG_PATH": filepath.FromSlash("/mock/config/root/.kube/config"), - "K8S_AUTH_KUBECONFIG": filepath.FromSlash("/mock/config/root/.kube/config"), - "PV_DEFAULT_CLAIM1": "/mock/volume/dir/pvc-1234", - "PV_DEFAULT_CLAIM2": "/mock/volume/dir/pvc-5678", + // And envVars should not be nil + if envVars == nil { + t.Error("Expected non-nil envVars") } - if !reflect.DeepEqual(envVars, expectedEnvVars) { - t.Errorf("envVars = %v, want %v", envVars, expectedEnvVars) + }) + + t.Run("UnmatchedPVCDirectories", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, mocks := setup(t) + + // And a mock ReadDir function that returns PVC directories that don't match any PVCs + mocks.Shims.ReadDir = func(dirname string) ([]os.DirEntry, error) { + if strings.HasSuffix(dirname, "volumes") { + return []os.DirEntry{ + mockDirEntry{name: "pvc-9999", isDir: true}, + mockDirEntry{name: "pvc-8888", isDir: true}, + }, nil + } + return nil, nil + } + + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + // And envVars should not be nil + if envVars == nil { + t.Error("Expected non-nil envVars") } }) -} -func TestKubeEnvPrinter_Print(t *testing.T) { - t.Run("Success", func(t *testing.T) { - // Use setupSafeKubeEnvPrinterMocks to create mocks - mocks := setupSafeKubeEnvPrinterMocks() - mockInjector := mocks.Injector - kubeEnvPrinter := NewKubeEnvPrinter(mockInjector) - kubeEnvPrinter.Initialize() - - // Mock the stat function to simulate the existence of the kubeconfig file - originalStat := stat - defer func() { stat = originalStat }() - stat = func(name string) (os.FileInfo, error) { - if name == filepath.FromSlash("/mock/config/root/.kube/config") { - return nil, nil // Simulate that the file exists + t.Run("ExistingPVEnvVars", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, mocks := setup(t) + + // And a mock Environ function that returns PV_ prefixed variables + mocks.Shims.Environ = func() []string { + return []string{ + "PV_TEST_NS_PVC_1=/path/to/pvc-1234", + "OTHER_VAR=value", + "PV_TEST_NS_PVC_2=/path/to/pvc-5678", } - return nil, os.ErrNotExist } - // Mock the PrintEnvVarsFunc to verify it is called with the correct envVars - var capturedEnvVars map[string]string - mocks.Shell.PrintEnvVarsFunc = func(envVars map[string]string) { - capturedEnvVars = envVars + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + // And envVars should not be nil + if envVars == nil { + t.Error("Expected non-nil envVars") + } + + // And PV environment variables should be set correctly + if envVars["PV_TEST_NS_PVC_1"] != "/path/to/pvc-1234" { + t.Errorf("Expected PV_TEST_NS_PVC_1=/path/to/pvc-1234, got %s", envVars["PV_TEST_NS_PVC_1"]) + } + if envVars["PV_TEST_NS_PVC_2"] != "/path/to/pvc-5678" { + t.Errorf("Expected PV_TEST_NS_PVC_2=/path/to/pvc-5678, got %s", envVars["PV_TEST_NS_PVC_2"]) + } + }) + + t.Run("EmptyVolumeDir", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, mocks := setup(t) + + // And a mock ReadDir function that returns empty directory + mocks.Shims.ReadDir = func(dirname string) ([]os.DirEntry, error) { + if strings.HasSuffix(dirname, "volumes") { + return []os.DirEntry{}, nil + } + return nil, nil + } + + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + // And envVars should not be nil + if envVars == nil { + t.Error("Expected non-nil envVars") + } + }) + + t.Run("PartiallyMatchedPVCDirectories", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, mocks := setup(t) + + // And a mock ReadDir function that returns mix of matching and non-matching PVC directories + mocks.Shims.ReadDir = func(dirname string) ([]os.DirEntry, error) { + if strings.HasSuffix(dirname, "volumes") { + return []os.DirEntry{ + mockDirEntry{name: "pvc-1234", isDir: true}, // matches + mockDirEntry{name: "pvc-9999", isDir: true}, // doesn't match + }, nil + } + return nil, nil + } + + // And a mock queryPersistentVolumeClaims function that returns specific PVCs + queryPersistentVolumeClaims = func(_ string) (*corev1.PersistentVolumeClaimList, error) { + return &corev1.PersistentVolumeClaimList{ + Items: []corev1.PersistentVolumeClaim{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "pvc-1", + Namespace: "test-ns", + UID: "1234", + }, + }, + }, + }, nil } - // Call Print and check for errors - err := kubeEnvPrinter.Print() + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { - t.Errorf("unexpected error: %v", err) + t.Errorf("Expected no error, got %v", err) } - // Verify that PrintEnvVarsFunc was called with the correct envVars - expectedEnvVars := map[string]string{ - "KUBECONFIG": filepath.FromSlash("/mock/config/root/.kube/config"), - "KUBE_CONFIG_PATH": filepath.FromSlash("/mock/config/root/.kube/config"), - "K8S_AUTH_KUBECONFIG": filepath.FromSlash("/mock/config/root/.kube/config"), + // And envVars should not be nil + if envVars == nil { + t.Error("Expected non-nil envVars") } - if !reflect.DeepEqual(capturedEnvVars, expectedEnvVars) { - t.Errorf("capturedEnvVars = %v, want %v", capturedEnvVars, expectedEnvVars) + + // And only the matching PVC should be in envVars + projectRoot := os.Getenv("WINDSOR_PROJECT_ROOT") + expectedPath := filepath.Join(projectRoot, ".volumes", "pvc-1234") + if envVars["PV_TEST_NS_PVC_1"] != expectedPath { + t.Errorf("Expected PV_TEST_NS_PVC_1=%s, got %s", expectedPath, envVars["PV_TEST_NS_PVC_1"]) } }) +} - t.Run("GetConfigError", func(t *testing.T) { - // Use setupSafeKubeEnvPrinterMocks to create mocks - mocks := setupSafeKubeEnvPrinterMocks() +// TestKubeEnvPrinter_Print tests the Print method of the KubeEnvPrinter +func TestKubeEnvPrinter_Print(t *testing.T) { + setup := func(t *testing.T) (*KubeEnvPrinter, *Mocks) { + t.Helper() + mocks := setupKubeEnvMocks(t) + printer := NewKubeEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + return printer, mocks + } - // Override the GetConfigFunc to simulate an error - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { + t.Run("Success", func(t *testing.T) { + // Given a KubeEnvPrinter with valid configuration + printer, _ := setup(t) + + // When printing environment variables + err := printer.Print() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + t.Run("GetConfigError", func(t *testing.T) { + // Given a mock ConfigHandler that returns an error + mockConfigHandler := config.NewMockConfigHandler() + mockConfigHandler.GetConfigRootFunc = func() (string, error) { return "", errors.New("mock config error") } - mockInjector := mocks.Injector + // And a KubeEnvPrinter with the mock ConfigHandler + mocks := setupKubeEnvMocks(t, &SetupOptions{ConfigHandler: mockConfigHandler}) + printer := NewKubeEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + + // When printing environment variables + err := printer.Print() - kubeEnvPrinter := NewKubeEnvPrinter(mockInjector) - kubeEnvPrinter.Initialize() - // Call Print and check for errors - err := kubeEnvPrinter.Print() + // Then an error should be returned if err == nil { - t.Error("expected error, got nil") - } else if !strings.Contains(err.Error(), "mock config error") { - t.Errorf("unexpected error message: %v", err) + t.Error("Expected error, got nil") } }) } diff --git a/pkg/env/mock_env.go b/pkg/env/mock_env.go index c27e57c1e..0f5297ba0 100644 --- a/pkg/env/mock_env.go +++ b/pkg/env/mock_env.go @@ -1,22 +1,41 @@ +// The MockEnvPrinter is a mock implementation of the EnvPrinter interface. +// It provides a testable implementation of environment variable management, +// The MockEnvPrinter enables testing of environment-dependent functionality, +// allowing for controlled simulation of environment operations in tests. + package env -// MockEnvPrinter is a struct that simulates an environment for testing purposes. +// ============================================================================= +// Types +// ============================================================================= + +// MockEnvPrinter is a struct that implements mock environment configuration type MockEnvPrinter struct { BaseEnvPrinter - InitializeFunc func() error - PrintFunc func() error - PrintAliasFunc func() error - PostEnvHookFunc func() error - GetEnvVarsFunc func() (map[string]string, error) - GetAliasFunc func() (map[string]string, error) - ResetFunc func() + InitializeFunc func() error + PrintFunc func() error + PrintAliasFunc func() error + PostEnvHookFunc func() error + GetEnvVarsFunc func() (map[string]string, error) + GetAliasFunc func() (map[string]string, error) + GetManagedEnvFunc func() []string + GetManagedAliasFunc func() []string + ResetFunc func() } -// NewMockEnvPrinter creates a new instance of MockEnvPrinter. +// ============================================================================= +// Constructor +// ============================================================================= + +// NewMockEnvPrinter creates a new MockEnvPrinter instance func NewMockEnvPrinter() *MockEnvPrinter { return &MockEnvPrinter{} } +// ============================================================================= +// Public Methods +// ============================================================================= + // Initialize calls the custom InitializeFunc if provided. func (m *MockEnvPrinter) Initialize() error { if m.InitializeFunc != nil { @@ -82,5 +101,23 @@ func (m *MockEnvPrinter) Reset() { } } +// GetManagedEnv returns the managed environment variables. +// If a custom GetManagedEnvFunc is provided, it will use that function instead. +func (m *MockEnvPrinter) GetManagedEnv() []string { + if m.GetManagedEnvFunc != nil { + return m.GetManagedEnvFunc() + } + return m.BaseEnvPrinter.GetManagedEnv() +} + +// GetManagedAlias returns the managed aliases. +// If a custom GetManagedAliasFunc is provided, it will use that function instead. +func (m *MockEnvPrinter) GetManagedAlias() []string { + if m.GetManagedAliasFunc != nil { + return m.GetManagedAliasFunc() + } + return m.BaseEnvPrinter.GetManagedAlias() +} + // Ensure MockEnvPrinter implements the EnvPrinter interface var _ EnvPrinter = (*MockEnvPrinter)(nil) diff --git a/pkg/env/mock_env_test.go b/pkg/env/mock_env_test.go index a5df1186b..80c8bf9a9 100644 --- a/pkg/env/mock_env_test.go +++ b/pkg/env/mock_env_test.go @@ -1,131 +1,118 @@ package env import ( - "bytes" "fmt" - "os" "reflect" "testing" ) -func captureStdout(t *testing.T, f func()) string { - // Save the current stdout - old := os.Stdout - // Create a pipe to capture stdout - r, w, err := os.Pipe() - if err != nil { - t.Fatalf("Failed to create pipe: %v", err) - } - // Set stdout to the write end of the pipe - os.Stdout = w - - // Run the function - f() - - // Close the write end of the pipe and restore stdout - w.Close() - os.Stdout = old - - // Read the captured output - var buf bytes.Buffer - if _, err := buf.ReadFrom(r); err != nil { - t.Fatalf("Failed to read from pipe: %v", err) - } - - return buf.String() -} +// ============================================================================= +// Test Public Methods +// ============================================================================= +// TestMockEnvPrinter_Initialize tests the Initialize method of the MockEnvPrinter func TestMockEnvPrinter_Initialize(t *testing.T) { t.Run("Success", func(t *testing.T) { - // Given a mock environment with a custom InitializeFunc - mockEnv := NewMockEnvPrinter() + // Given a mock environment printer + printer := NewMockEnvPrinter() var initialized bool - mockEnv.InitializeFunc = func() error { + printer.InitializeFunc = func() error { initialized = true return nil } - // When calling Initialize - err := mockEnv.Initialize() + // When initializing + err := printer.Initialize() - // Then no error should be returned and initialized should be true + // Then no error should be returned if err != nil { t.Errorf("Initialize() error = %v, want nil", err) } + // And initialized should be true if !initialized { t.Errorf("Initialize() did not set initialized to true") } }) t.Run("DefaultInitialize", func(t *testing.T) { - // Given a mock environment with default Initialize implementation - mockEnv := NewMockEnvPrinter() - // When calling Initialize - if err := mockEnv.Initialize(); err != nil { + // Given a mock environment printer with default implementation + printer := NewMockEnvPrinter() + + // When initializing + err := printer.Initialize() + + // Then no error should be returned + if err != nil { t.Errorf("Initialize() error = %v, want nil", err) } }) } +// TestMockEnvPrinter_NewMockEnvPrinter tests the NewMockEnvPrinter constructor func TestMockEnvPrinter_NewMockEnvPrinter(t *testing.T) { t.Run("CreateMockEnvPrinterWithoutContainer", func(t *testing.T) { // When creating a new mock environment without an injector - mockEnv := NewMockEnvPrinter() + printer := NewMockEnvPrinter() // Then no error should be returned - if mockEnv == nil { - t.Errorf("Expected mockEnv, got nil") + if printer == nil { + t.Errorf("Expected printer, got nil") } }) } +// TestMockEnvPrinter_GetEnvVars tests the GetEnvVars method of the MockEnvPrinter func TestMockEnvPrinter_GetEnvVars(t *testing.T) { t.Run("DefaultGetEnvVars", func(t *testing.T) { - // Given a mock environment with default GetEnvVars implementation - mockEnv := NewMockEnvPrinter() - // When calling GetEnvVars - envVars, err := mockEnv.GetEnvVars() - // Then no error should be returned and envVars should be an empty map + // Given a mock environment printer with default implementation + printer := NewMockEnvPrinter() + + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { t.Errorf("GetEnvVars() error = %v, want nil", err) } + // And envVars should be empty if len(envVars) != 0 { t.Errorf("GetEnvVars() = %v, want empty map", envVars) } }) t.Run("CustomGetEnvVars", func(t *testing.T) { - // Given a mock environment with custom GetEnvVars implementation - mockEnv := NewMockEnvPrinter() + // Given a mock environment printer with custom implementation + printer := NewMockEnvPrinter() expectedEnvVars := map[string]string{ "VAR1": "value1", "VAR2": "value2", } - mockEnv.GetEnvVarsFunc = func() (map[string]string, error) { + printer.GetEnvVarsFunc = func() (map[string]string, error) { return expectedEnvVars, nil } - // When calling GetEnvVars - envVars, err := mockEnv.GetEnvVars() - // Then no error should be returned and envVars should match expectedEnvVars + + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { t.Errorf("GetEnvVars() error = %v, want nil", err) } - if len(envVars) != len(expectedEnvVars) { + // And envVars should match expected values + if !reflect.DeepEqual(envVars, expectedEnvVars) { t.Errorf("GetEnvVars() = %v, want %v", envVars, expectedEnvVars) } - for key, value := range expectedEnvVars { - if envVars[key] != value { - t.Errorf("GetEnvVars()[%v] = %v, want %v", key, envVars[key], value) - } - } }) } +// TestMockEnvPrinter_Print tests the Print method of the MockEnvPrinter func TestMockEnvPrinter_Print(t *testing.T) { t.Run("DefaultPrint", func(t *testing.T) { - // Given a mock environment with default Print implementation - mockEnv := NewMockEnvPrinter() - // When calling Print - err := mockEnv.Print() + // Given a mock environment printer with default implementation + printer := NewMockEnvPrinter() + + // When printing + err := printer.Print() + // Then no error should be returned if err != nil { t.Errorf("Print() error = %v, want nil", err) @@ -133,49 +120,58 @@ func TestMockEnvPrinter_Print(t *testing.T) { }) t.Run("CustomPrint", func(t *testing.T) { - // Given a mock environment with custom Print implementation - mockEnv := NewMockEnvPrinter() + // Given a mock environment printer with custom implementation + printer := NewMockEnvPrinter() expectedError := fmt.Errorf("custom print error") - mockEnv.PrintFunc = func() error { + printer.PrintFunc = func() error { return expectedError } - // When calling Print - err := mockEnv.Print() - // Then the custom error should be returned + + // When printing + err := printer.Print() + + // Then the expected error should be returned if err != expectedError { t.Errorf("Print() error = %v, want %v", err, expectedError) } }) } +// TestMockPrinter_GetAlias tests the GetAlias method of the MockEnvPrinter func TestMockPrinter_GetAlias(t *testing.T) { t.Run("Success", func(t *testing.T) { - // Given a mock environment with custom GetAlias implementation - mockEnv := NewMockEnvPrinter() + // Given a mock environment printer with custom implementation + printer := NewMockEnvPrinter() expectedAlias := map[string]string{"test": "echo test"} - mockEnv.GetAliasFunc = func() (map[string]string, error) { + printer.GetAliasFunc = func() (map[string]string, error) { return expectedAlias, nil } - // When calling GetAlias - alias, err := mockEnv.GetAlias() - // Then no error should be returned and alias should match expectedAlias + + // When getting aliases + alias, err := printer.GetAlias() + + // Then no error should be returned if err != nil { t.Errorf("GetAlias() error = %v, want nil", err) } + // And aliases should match expected values if !reflect.DeepEqual(alias, expectedAlias) { t.Errorf("GetAlias() = %v, want %v", alias, expectedAlias) } }) t.Run("NotImplemented", func(t *testing.T) { - // Given a mock environment with default GetAlias implementation - mockEnv := NewMockEnvPrinter() - // When calling GetAlias - alias, err := mockEnv.GetAlias() - // Then no error should be returned and alias should be an empty map + // Given a mock environment printer with default implementation + printer := NewMockEnvPrinter() + + // When getting aliases + alias, err := printer.GetAlias() + + // Then no error should be returned if err != nil { t.Errorf("GetAlias() error = %v, want nil", err) } + // And an empty map should be returned expectedAlias := map[string]string{} if !reflect.DeepEqual(alias, expectedAlias) { t.Errorf("GetAlias() = %v, want %v", alias, expectedAlias) @@ -183,15 +179,18 @@ func TestMockPrinter_GetAlias(t *testing.T) { }) } +// TestMockEnvPrinter_PrintAlias tests the PrintAlias method of the MockEnvPrinter func TestMockEnvPrinter_PrintAlias(t *testing.T) { t.Run("Success", func(t *testing.T) { - // Given a mock environment with custom PrintAlias implementation - mockEnv := NewMockEnvPrinter() - mockEnv.PrintAliasFunc = func() error { + // Given a mock environment printer with custom implementation + printer := NewMockEnvPrinter() + printer.PrintAliasFunc = func() error { return nil } - // When calling PrintAlias - err := mockEnv.PrintAlias() + + // When printing aliases + err := printer.PrintAlias() + // Then no error should be returned if err != nil { t.Errorf("PrintAlias() error = %v, want nil", err) @@ -199,23 +198,28 @@ func TestMockEnvPrinter_PrintAlias(t *testing.T) { }) t.Run("NotImplemented", func(t *testing.T) { - // Given a mock environment with default GetAlias implementation - mockEnv := NewMockEnvPrinter() - // When calling GetAlias - err := mockEnv.PrintAlias() + // Given a mock environment printer with default implementation + printer := NewMockEnvPrinter() + + // When printing aliases + err := printer.PrintAlias() + // Then no error should be returned if err != nil { - t.Errorf("GetAlias() error = %v, want nil", err) + t.Errorf("PrintAlias() error = %v, want nil", err) } }) } +// TestMockEnvPrinter_PostEnvHook tests the PostEnvHook method of the MockEnvPrinter func TestMockEnvPrinter_PostEnvHook(t *testing.T) { t.Run("DefaultPostEnvHook", func(t *testing.T) { - // Given a mock environment with default PostEnvHook implementation - mockEnv := NewMockEnvPrinter() - // When calling PostEnvHook - err := mockEnv.PostEnvHook() + // Given a mock environment printer with default implementation + printer := NewMockEnvPrinter() + + // When running post-env hook + err := printer.PostEnvHook() + // Then no error should be returned if err != nil { t.Errorf("PostEnvHook() error = %v, want nil", err) @@ -223,45 +227,117 @@ func TestMockEnvPrinter_PostEnvHook(t *testing.T) { }) t.Run("CustomPostEnvHook", func(t *testing.T) { - // Given a mock environment with custom PostEnvHook implementation - mockEnv := NewMockEnvPrinter() + // Given a mock environment printer with custom implementation + printer := NewMockEnvPrinter() expectedError := fmt.Errorf("custom post env hook error") - mockEnv.PostEnvHookFunc = func() error { + printer.PostEnvHookFunc = func() error { return expectedError } - // When calling PostEnvHook - err := mockEnv.PostEnvHook() - // Then the custom error should be returned + + // When running post-env hook + err := printer.PostEnvHook() + + // Then the expected error should be returned if err != expectedError { t.Errorf("PostEnvHook() error = %v, want %v", err, expectedError) } }) } +// TestMockEnvPrinter_Reset tests the Reset method of the MockEnvPrinter func TestMockEnvPrinter_Reset(t *testing.T) { t.Run("DefaultReset", func(t *testing.T) { - // Given a mock environment with default Reset implementation - mockEnv := NewMockEnvPrinter() + // Given a mock environment printer with default implementation + printer := NewMockEnvPrinter() - // When calling Reset (without a custom implementation) - // This is a no-op, so we just call it to ensure it doesn't panic - mockEnv.Reset() + // When resetting + printer.Reset() + + // Then no panic should occur }) t.Run("CustomReset", func(t *testing.T) { - // Given a mock environment with custom Reset implementation - mockEnv := NewMockEnvPrinter() + // Given a mock environment printer with custom implementation + printer := NewMockEnvPrinter() resetCalled := false - mockEnv.ResetFunc = func() { + printer.ResetFunc = func() { resetCalled = true } - // When calling Reset - mockEnv.Reset() + // When resetting + printer.Reset() - // Then it should call the custom reset function + // Then the custom reset function should be called if !resetCalled { t.Error("Reset() did not call ResetFunc") } }) } + +// TestMockEnvPrinter_GetManagedEnv tests the GetManagedEnv method of the MockEnvPrinter +func TestMockEnvPrinter_GetManagedEnv(t *testing.T) { + t.Run("CustomGetManagedEnv", func(t *testing.T) { + // Given a mock environment printer with custom implementation + printer := NewMockEnvPrinter() + expectedEnv := []string{"VAR1", "VAR2"} + printer.GetManagedEnvFunc = func() []string { + return expectedEnv + } + + // When getting managed environment variables + managedEnv := printer.GetManagedEnv() + + // Then the expected environment variables should be returned + if !reflect.DeepEqual(managedEnv, expectedEnv) { + t.Errorf("GetManagedEnv() = %v, want %v", managedEnv, expectedEnv) + } + }) + + t.Run("DefaultGetManagedEnv", func(t *testing.T) { + // Given a mock environment printer with default implementation + printer := NewMockEnvPrinter() + + // When getting managed environment variables + managedEnv := printer.GetManagedEnv() + + // Then the base implementation should be used + baseEnv := printer.BaseEnvPrinter.GetManagedEnv() + if !reflect.DeepEqual(managedEnv, baseEnv) { + t.Errorf("GetManagedEnv() = %v, want %v", managedEnv, baseEnv) + } + }) +} + +// TestMockEnvPrinter_GetManagedAlias tests the GetManagedAlias method of the MockEnvPrinter +func TestMockEnvPrinter_GetManagedAlias(t *testing.T) { + t.Run("CustomGetManagedAlias", func(t *testing.T) { + // Given a mock environment printer with custom implementation + printer := NewMockEnvPrinter() + expectedAlias := []string{"alias1", "alias2"} + printer.GetManagedAliasFunc = func() []string { + return expectedAlias + } + + // When getting managed aliases + managedAlias := printer.GetManagedAlias() + + // Then the expected aliases should be returned + if !reflect.DeepEqual(managedAlias, expectedAlias) { + t.Errorf("GetManagedAlias() = %v, want %v", managedAlias, expectedAlias) + } + }) + + t.Run("DefaultGetManagedAlias", func(t *testing.T) { + // Given a mock environment printer with default implementation + printer := NewMockEnvPrinter() + + // When getting managed aliases + managedAlias := printer.GetManagedAlias() + + // Then the base implementation should be used + baseAlias := printer.BaseEnvPrinter.GetManagedAlias() + if !reflect.DeepEqual(managedAlias, baseAlias) { + t.Errorf("GetManagedAlias() = %v, want %v", managedAlias, baseAlias) + } + }) +} diff --git a/pkg/env/omni_env.go b/pkg/env/omni_env.go index b061a54c4..c864aaa54 100644 --- a/pkg/env/omni_env.go +++ b/pkg/env/omni_env.go @@ -1,3 +1,8 @@ +// The OmniEnvPrinter is a specialized component that manages Omni environment configuration. +// It provides Omni-specific environment variable management and configuration, +// The OmniEnvPrinter handles Omni configuration settings and environment setup, +// ensuring proper Omni CLI integration and environment setup for operations. + package env import ( @@ -7,20 +12,30 @@ import ( "github.com/windsorcli/cli/pkg/di" ) -// OmniEnvPrinter is a struct that simulates a Kubernetes environment for testing purposes. +// ============================================================================= +// Types +// ============================================================================= + +// OmniEnvPrinter is a struct that implements Omni environment configuration type OmniEnvPrinter struct { BaseEnvPrinter } -// NewOmniEnv initializes a new omniEnv instance using the provided dependency injector. +// ============================================================================= +// Constructor +// ============================================================================= + +// NewOmniEnvPrinter creates a new OmniEnvPrinter instance func NewOmniEnvPrinter(injector di.Injector) *OmniEnvPrinter { return &OmniEnvPrinter{ - BaseEnvPrinter: BaseEnvPrinter{ - injector: injector, - }, + BaseEnvPrinter: *NewBaseEnvPrinter(injector), } } +// ============================================================================= +// Public Methods +// ============================================================================= + // GetEnvVars retrieves the environment variables for the Omni environment. func (e *OmniEnvPrinter) GetEnvVars() (map[string]string, error) { envVars := make(map[string]string) diff --git a/pkg/env/omni_env_test.go b/pkg/env/omni_env_test.go index c99c7813f..6abf27e25 100644 --- a/pkg/env/omni_env_test.go +++ b/pkg/env/omni_env_test.go @@ -7,102 +7,96 @@ import ( "reflect" "strings" "testing" - - "github.com/windsorcli/cli/pkg/config" - "github.com/windsorcli/cli/pkg/di" - "github.com/windsorcli/cli/pkg/shell" ) -type OmniEnvPrinterMocks struct { - Injector di.Injector - Shell *shell.MockShell - ConfigHandler *config.MockConfigHandler -} - -func setupSafeOmniEnvPrinterMocks(injector ...di.Injector) *OmniEnvPrinterMocks { - var mockInjector di.Injector - if len(injector) > 0 { - mockInjector = injector[0] - } else { - mockInjector = di.NewMockInjector() - } - - mockConfigHandler := config.NewMockConfigHandler() - mockConfigHandler.GetConfigRootFunc = func() (string, error) { - return filepath.FromSlash("/mock/config/root"), nil - } - - mockShell := shell.NewMockShell() - - mockInjector.Register("configHandler", mockConfigHandler) - mockInjector.Register("shell", mockShell) +// ============================================================================= +// Test Public Methods +// ============================================================================= - return &OmniEnvPrinterMocks{ - Injector: mockInjector, - ConfigHandler: mockConfigHandler, - Shell: mockShell, +// TestOmniEnvPrinter_GetEnvVars tests the GetEnvVars method of the OmniEnvPrinter +func TestOmniEnvPrinter_GetEnvVars(t *testing.T) { + setup := func(t *testing.T) (*OmniEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewOmniEnvPrinter(mocks.Injector) + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) + } + printer.shims = mocks.Shims + return printer, mocks } -} -func TestOmniEnvPrinter_GetEnvVars(t *testing.T) { t.Run("Success", func(t *testing.T) { - mocks := setupSafeOmniEnvPrinterMocks() + // Given a new OmniEnvPrinter with existing Omni config + printer, mocks := setup(t) + + // Get the actual project root path + projectRoot, err := mocks.Shell.GetProjectRoot() + if err != nil { + t.Fatalf("Failed to get project root: %v", err) + } + expectedPath := filepath.Join(projectRoot, "contexts", "mock-context", ".omni", "config") - originalStat := stat - defer func() { stat = originalStat }() - stat = func(name string) (os.FileInfo, error) { - if name == filepath.FromSlash("/mock/config/root/.omni/config") { + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == expectedPath { return nil, nil } return nil, os.ErrNotExist } - omniEnvPrinter := NewOmniEnvPrinter(mocks.Injector) - omniEnvPrinter.Initialize() + // When getting environment variables + envVars, err := printer.GetEnvVars() - envVars, err := omniEnvPrinter.GetEnvVars() + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - if envVars["OMNICONFIG"] != filepath.FromSlash("/mock/config/root/.omni/config") { - t.Errorf("OMNICONFIG = %v, want %v", envVars["OMNICONFIG"], filepath.FromSlash("/mock/config/root/.omni/config")) + // And OMNICONFIG should be set correctly + if envVars["OMNICONFIG"] != expectedPath { + t.Errorf("OMNICONFIG = %v, want %v", envVars["OMNICONFIG"], expectedPath) } }) t.Run("NoOmniConfig", func(t *testing.T) { - mocks := setupSafeOmniEnvPrinterMocks() + // Given a new OmniEnvPrinter without existing Omni config + printer, mocks := setup(t) - originalStat := stat - defer func() { stat = originalStat }() - stat = func(name string) (os.FileInfo, error) { + // Get the actual project root path + projectRoot, err := mocks.Shell.GetProjectRoot() + if err != nil { + t.Fatalf("Failed to get project root: %v", err) + } + expectedPath := filepath.Join(projectRoot, "contexts", "mock-context", ".omni", "config") + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { return nil, os.ErrNotExist } - omniEnvPrinter := NewOmniEnvPrinter(mocks.Injector) - omniEnvPrinter.Initialize() + // When getting environment variables + envVars, err := printer.GetEnvVars() - envVars, err := omniEnvPrinter.GetEnvVars() + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - expectedPath := filepath.FromSlash("/mock/config/root/.omni/config") + // And OMNICONFIG should still be set to default path if envVars["OMNICONFIG"] != expectedPath { t.Errorf("OMNICONFIG = %v, want %v", envVars["OMNICONFIG"], expectedPath) } }) - t.Run("GetConfigRootError", func(t *testing.T) { - mocks := setupSafeOmniEnvPrinterMocks() - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { + t.Run("GetProjectRootError", func(t *testing.T) { + // Given a new OmniEnvPrinter with failing project root lookup + printer, mocks := setup(t) + mocks.Shell.GetProjectRootFunc = func() (string, error) { return "", errors.New("mock context error") } - omniEnvPrinter := NewOmniEnvPrinter(mocks.Injector) - omniEnvPrinter.Initialize() + // When getting environment variables + _, err := printer.GetEnvVars() - _, err := omniEnvPrinter.GetEnvVars() + // Then appropriate error should be returned expectedError := "error retrieving configuration root directory: mock context error" if err == nil || err.Error() != expectedError { t.Errorf("error = %v, want %v", err, expectedError) @@ -110,59 +104,72 @@ func TestOmniEnvPrinter_GetEnvVars(t *testing.T) { }) } +// TestOmniEnvPrinter_Print tests the Print method of the OmniEnvPrinter func TestOmniEnvPrinter_Print(t *testing.T) { + setup := func(t *testing.T) (*OmniEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewOmniEnvPrinter(mocks.Injector) + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) + } + printer.shims = mocks.Shims + return printer, mocks + } + t.Run("Success", func(t *testing.T) { - // Use setupSafeOmniEnvPrinterMocks to create mocks - mocks := setupSafeOmniEnvPrinterMocks() - mockInjector := mocks.Injector - omniEnvPrinter := NewOmniEnvPrinter(mockInjector) - omniEnvPrinter.Initialize() - - // Mock the stat function to simulate the existence of the omniconfig file - stat = func(name string) (os.FileInfo, error) { - if name == filepath.FromSlash("/mock/config/root/.omni/config") { - return nil, nil // Simulate that the file exists + // Given a new OmniEnvPrinter with existing Omni config + printer, mocks := setup(t) + + // Get the actual project root path + projectRoot, err := mocks.Shell.GetProjectRoot() + if err != nil { + t.Fatalf("Failed to get project root: %v", err) + } + expectedPath := filepath.Join(projectRoot, "contexts", "mock-context", ".omni", "config") + + // And Omni config file exists + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == expectedPath { + return nil, nil } return nil, os.ErrNotExist } - // Mock the PrintEnvVarsFunc to verify it is called with the correct envVars + // And PrintEnvVarsFunc is mocked var capturedEnvVars map[string]string mocks.Shell.PrintEnvVarsFunc = func(envVars map[string]string) { capturedEnvVars = envVars } - // Call Print and check for errors - err := omniEnvPrinter.Print() + // When calling Print + err = printer.Print() + + // Then no error should be returned if err != nil { t.Errorf("unexpected error: %v", err) } - // Verify that PrintEnvVarsFunc was called with the correct envVars + // And environment variables should be set correctly expectedEnvVars := map[string]string{ - "OMNICONFIG": filepath.FromSlash("/mock/config/root/.omni/config"), + "OMNICONFIG": expectedPath, } if !reflect.DeepEqual(capturedEnvVars, expectedEnvVars) { t.Errorf("capturedEnvVars = %v, want %v", capturedEnvVars, expectedEnvVars) } }) - t.Run("GetConfigError", func(t *testing.T) { - // Use setupSafeOmniEnvPrinterMocks to create mocks - mocks := setupSafeOmniEnvPrinterMocks() - - // Override the GetConfigFunc to simulate an error - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { + t.Run("GetProjectRootError", func(t *testing.T) { + // Given a new OmniEnvPrinter with failing project root lookup + printer, mocks := setup(t) + mocks.Shell.GetProjectRootFunc = func() (string, error) { return "", errors.New("mock config error") } - mockInjector := mocks.Injector - - omniEnvPrinter := NewOmniEnvPrinter(mockInjector) - omniEnvPrinter.Initialize() + // When calling Print + err := printer.Print() - // Call Print and check for errors - err := omniEnvPrinter.Print() + // Then appropriate error should be returned if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "mock config error") { diff --git a/pkg/env/shims.go b/pkg/env/shims.go index 0e76d1a99..62c2e303a 100644 --- a/pkg/env/shims.go +++ b/pkg/env/shims.go @@ -1,3 +1,8 @@ +// The shims package is a system call abstraction layer +// It provides mockable wrappers around system and runtime functions +// It serves as a testing aid by allowing system calls to be intercepted +// It enables dependency injection and test isolation for system-level operations + package env import ( @@ -10,64 +15,56 @@ import ( "github.com/goccy/go-yaml" ) -// stat is a variable that holds the os.Stat function for mocking -var stat = os.Stat - -// Define a variable for os.Getwd() for easier testing -var getwd = os.Getwd - -// Define a variable for filepath.Glob for easier testing -var glob = filepath.Glob - -// Wrapper function for os.WriteFile -var writeFile = os.WriteFile - -// Wrapper function for os.ReadDir -var readDir = os.ReadDir - -// Wrapper function for yaml.Unmarshal -var yamlUnmarshal = yaml.Unmarshal - -// Wrapper function for yaml.Marshal -var yamlMarshal = yaml.Marshal - -// Wrapper for os.Remove for mocking in tests -var osRemove = os.Remove - -// Wrapper for os.RemoveAll for mocking in tests -var osRemoveAll = os.RemoveAll - -// Wrapper for crypto/rand.Read for mocking in tests -var cryptoRandRead = func(b []byte) (int, error) { - return rand.Read(b) +// ============================================================================= +// Types +// ============================================================================= + +// Shims provides mockable wrappers around system and runtime functions +type Shims struct { + Stat func(string) (os.FileInfo, error) + Getwd func() (string, error) + Glob func(string) ([]string, error) + WriteFile func(string, []byte, os.FileMode) error + ReadDir func(string) ([]os.DirEntry, error) + YamlUnmarshal func([]byte, interface{}) error + YamlMarshal func(interface{}) ([]byte, error) + Remove func(string) error + RemoveAll func(string) error + CryptoRandRead func([]byte) (int, error) + Goos func() string + UserHomeDir func() (string, error) + MkdirAll func(string, os.FileMode) error + ReadFile func(string) ([]byte, error) + LookPath func(string) (string, error) + LookupEnv func(string) (string, bool) + Environ func() []string + Getenv func(string) string } -// intPtr returns a pointer to an int value -func intPtr(i int) *int { - return &i +// ============================================================================= +// Helpers +// ============================================================================= + +// NewShims creates a new Shims instance with default implementations +func NewShims() *Shims { + return &Shims{ + Stat: os.Stat, + Getwd: os.Getwd, + Glob: filepath.Glob, + WriteFile: os.WriteFile, + ReadDir: os.ReadDir, + YamlUnmarshal: yaml.Unmarshal, + YamlMarshal: yaml.Marshal, + Remove: os.Remove, + RemoveAll: os.RemoveAll, + CryptoRandRead: func(b []byte) (int, error) { return rand.Read(b) }, + Goos: func() string { return runtime.GOOS }, + UserHomeDir: os.UserHomeDir, + MkdirAll: os.MkdirAll, + ReadFile: os.ReadFile, + LookPath: exec.LookPath, + LookupEnv: os.LookupEnv, + Environ: os.Environ, + Getenv: os.Getenv, + } } - -// stringPtr returns a pointer to a string value -func stringPtr(s string) *string { - return &s -} - -// Define a variable for runtime.GOOS for easier testing -var goos = func() string { - return runtime.GOOS -} - -// Define a variable for os.UserHomeDir for easier testing -var osUserHomeDir = os.UserHomeDir - -// Define a variable for os.MkdirAll for easier testing -var mkdirAll = os.MkdirAll - -// Define a variable for os.ReadFile for easier testing -var readFile = os.ReadFile - -// Define a variable for exec.LookPath for easier testing -var execLookPath = exec.LookPath - -// Define a variable for os.LookupEnv for easier testing -var osLookupEnv = os.LookupEnv diff --git a/pkg/env/shims_test.go b/pkg/env/shims_test.go index de42abe24..76768cf73 100644 --- a/pkg/env/shims_test.go +++ b/pkg/env/shims_test.go @@ -1,6 +1,67 @@ package env -// boolPtr returns a pointer to a boolean value -func boolPtr(b bool) *bool { - return &b +import ( + "testing" +) + +func TestNewShims(t *testing.T) { + t.Run("CreatesShimsWithAllFunctions", func(t *testing.T) { + shims := NewShims() + + // Verify all shim functions are initialized + if shims.Stat == nil { + t.Error("Stat shim not initialized") + } + if shims.Getwd == nil { + t.Error("Getwd shim not initialized") + } + if shims.Glob == nil { + t.Error("Glob shim not initialized") + } + if shims.WriteFile == nil { + t.Error("WriteFile shim not initialized") + } + if shims.ReadDir == nil { + t.Error("ReadDir shim not initialized") + } + if shims.YamlUnmarshal == nil { + t.Error("YamlUnmarshal shim not initialized") + } + if shims.YamlMarshal == nil { + t.Error("YamlMarshal shim not initialized") + } + if shims.Remove == nil { + t.Error("Remove shim not initialized") + } + if shims.RemoveAll == nil { + t.Error("RemoveAll shim not initialized") + } + if shims.CryptoRandRead == nil { + t.Error("CryptoRandRead shim not initialized") + } + if shims.Goos == nil { + t.Error("Goos shim not initialized") + } + if shims.UserHomeDir == nil { + t.Error("UserHomeDir shim not initialized") + } + if shims.MkdirAll == nil { + t.Error("MkdirAll shim not initialized") + } + if shims.ReadFile == nil { + t.Error("ReadFile shim not initialized") + } + if shims.LookPath == nil { + t.Error("LookPath shim not initialized") + } + if shims.LookupEnv == nil { + t.Error("LookupEnv shim not initialized") + } + if shims.Environ == nil { + t.Error("Environ shim not initialized") + } + if shims.Getenv == nil { + t.Error("Getenv shim not initialized") + } + }) } diff --git a/pkg/env/talos_env.go b/pkg/env/talos_env.go index 771e6e8be..2a7177453 100644 --- a/pkg/env/talos_env.go +++ b/pkg/env/talos_env.go @@ -1,3 +1,8 @@ +// The TalosEnvPrinter is a specialized component that manages Talos environment configuration. +// It provides Talos-specific environment variable management and configuration, +// The TalosEnvPrinter handles Talos configuration settings and environment setup, +// ensuring proper Talos CLI integration and environment setup for operations. + package env import ( @@ -7,20 +12,30 @@ import ( "github.com/windsorcli/cli/pkg/di" ) -// TalosEnvPrinter is a struct that simulates a Kubernetes environment for testing purposes. +// ============================================================================= +// Types +// ============================================================================= + +// TalosEnvPrinter is a struct that implements Talos environment configuration type TalosEnvPrinter struct { BaseEnvPrinter } -// NewTalosEnvPrinter initializes a new talosEnvPrinter instance using the provided dependency injector. +// ============================================================================= +// Constructor +// ============================================================================= + +// NewTalosEnvPrinter creates a new TalosEnvPrinter instance func NewTalosEnvPrinter(injector di.Injector) *TalosEnvPrinter { return &TalosEnvPrinter{ - BaseEnvPrinter: BaseEnvPrinter{ - injector: injector, - }, + BaseEnvPrinter: *NewBaseEnvPrinter(injector), } } +// ============================================================================= +// Public Methods +// ============================================================================= + // GetEnvVars retrieves the environment variables for the Talos environment. func (e *TalosEnvPrinter) GetEnvVars() (map[string]string, error) { envVars := make(map[string]string) diff --git a/pkg/env/talos_env_test.go b/pkg/env/talos_env_test.go index 1bb8e2eb1..e16e77e45 100644 --- a/pkg/env/talos_env_test.go +++ b/pkg/env/talos_env_test.go @@ -7,165 +7,175 @@ import ( "reflect" "strings" "testing" - - "github.com/windsorcli/cli/pkg/config" - "github.com/windsorcli/cli/pkg/di" - "github.com/windsorcli/cli/pkg/shell" ) -type TalosEnvMocks struct { - Injector di.Injector - ConfigHandler *config.MockConfigHandler - Shell *shell.MockShell -} - -func setupSafeTalosEnvMocks(injector ...di.Injector) *TalosEnvMocks { - var mockInjector di.Injector - if len(injector) > 0 { - mockInjector = injector[0] - } else { - mockInjector = di.NewMockInjector() - } - - mockConfigHandler := config.NewMockConfigHandler() - mockConfigHandler.GetConfigRootFunc = func() (string, error) { - return filepath.FromSlash("/mock/config/root"), nil - } - - mockShell := shell.NewMockShell() +// ============================================================================= +// Test Public Methods +// ============================================================================= - mockInjector.Register("configHandler", mockConfigHandler) - mockInjector.Register("shell", mockShell) +// TestTalosEnv_GetEnvVars tests the GetEnvVars method of the TalosEnvPrinter +func TestTalosEnv_GetEnvVars(t *testing.T) { + setup := func(t *testing.T) (*TalosEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewTalosEnvPrinter(mocks.Injector) + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) + } + printer.shims = mocks.Shims - return &TalosEnvMocks{ - Injector: mockInjector, - ConfigHandler: mockConfigHandler, - Shell: mockShell, + return printer, mocks } -} -func TestTalosEnv_GetEnvVars(t *testing.T) { t.Run("Success", func(t *testing.T) { - mocks := setupSafeTalosEnvMocks() + // Given a new TalosEnvPrinter with existing Talos config + printer, mocks := setup(t) - originalStat := stat - defer func() { stat = originalStat }() - stat = func(name string) (os.FileInfo, error) { - if name == filepath.FromSlash("/mock/config/root/.talos/config") { + // Get the project root path + projectRoot, err := mocks.Shell.GetProjectRoot() + if err != nil { + t.Fatalf("Failed to get project root: %v", err) + } + expectedPath := filepath.Join(projectRoot, "contexts", "mock-context", ".talos", "config") + + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == expectedPath { return nil, nil } return nil, os.ErrNotExist } - talosEnvPrinter := NewTalosEnvPrinter(mocks.Injector) - talosEnvPrinter.Initialize() + // When getting environment variables + envVars, err := printer.GetEnvVars() - envVars, err := talosEnvPrinter.GetEnvVars() + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - expectedPath := filepath.FromSlash("/mock/config/root/.talos/config") + // And TALOSCONFIG should be set correctly if envVars["TALOSCONFIG"] != expectedPath { t.Errorf("TALOSCONFIG = %v, want %v", envVars["TALOSCONFIG"], expectedPath) } }) t.Run("NoTalosConfig", func(t *testing.T) { - mocks := setupSafeTalosEnvMocks() + // Given a new TalosEnvPrinter without existing Talos config + printer, mocks := setup(t) - originalStat := stat - defer func() { stat = originalStat }() - stat = func(name string) (os.FileInfo, error) { + // Get the project root path + projectRoot, err := mocks.Shell.GetProjectRoot() + if err != nil { + t.Fatalf("Failed to get project root: %v", err) + } + expectedPath := filepath.Join(projectRoot, "contexts", "mock-context", ".talos", "config") + + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { return nil, os.ErrNotExist } - talosEnvPrinter := NewTalosEnvPrinter(mocks.Injector) - talosEnvPrinter.Initialize() + // When getting environment variables + envVars, err := printer.GetEnvVars() - envVars, err := talosEnvPrinter.GetEnvVars() + // Then no error should be returned if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - expectedPath := filepath.FromSlash("/mock/config/root/.talos/config") + // And TALOSCONFIG should still be set to default path if envVars["TALOSCONFIG"] != expectedPath { t.Errorf("TALOSCONFIG = %v, want %v", envVars["TALOSCONFIG"], expectedPath) } }) - t.Run("GetConfigRootError", func(t *testing.T) { - mocks := setupSafeTalosEnvMocks() - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { - return "", errors.New("mock config error") + t.Run("GetProjectRootError", func(t *testing.T) { + // Given a new TalosEnvPrinter with failing project root lookup + printer, mocks := setup(t) + mocks.Shell.GetProjectRootFunc = func() (string, error) { + return "", errors.New("mock project root error") } - talosEnvPrinter := NewTalosEnvPrinter(mocks.Injector) - talosEnvPrinter.Initialize() + // When getting environment variables + _, err := printer.GetEnvVars() - _, err := talosEnvPrinter.GetEnvVars() - if err == nil || err.Error() != "error retrieving configuration root directory: mock config error" { - t.Errorf("expected error retrieving configuration root directory, got %v", err) + // Then appropriate error should be returned + expectedError := "error retrieving configuration root directory: mock project root error" + if err == nil || err.Error() != expectedError { + t.Errorf("error = %v, want %v", err, expectedError) } }) } +// TestTalosEnv_Print tests the Print method of the TalosEnvPrinter func TestTalosEnv_Print(t *testing.T) { + setup := func(t *testing.T) (*TalosEnvPrinter, *Mocks) { + t.Helper() + mocks := setupMocks(t) + printer := NewTalosEnvPrinter(mocks.Injector) + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) + } + printer.shims = mocks.Shims + + return printer, mocks + } + t.Run("Success", func(t *testing.T) { - // Use setupSafeTalosEnvMocks to create mocks - mocks := setupSafeTalosEnvMocks() - mockInjector := mocks.Injector - talosEnvPrinter := NewTalosEnvPrinter(mockInjector) - talosEnvPrinter.Initialize() - - // Mock the stat function to simulate the existence of the talos config file - stat = func(name string) (os.FileInfo, error) { - if name == filepath.FromSlash("/mock/config/root/.talos/config") { - return nil, nil // Simulate that the file exists + // Given a new TalosEnvPrinter with existing Talos config + printer, mocks := setup(t) + + // Get the project root path + projectRoot, err := mocks.Shell.GetProjectRoot() + if err != nil { + t.Fatalf("Failed to get project root: %v", err) + } + expectedPath := filepath.Join(projectRoot, "contexts", "mock-context", ".talos", "config") + + // And Talos config file exists + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == expectedPath { + return nil, nil } return nil, os.ErrNotExist } - // Mock the PrintEnvVarsFunc to verify it is called with the correct envVars + // And PrintEnvVarsFunc is mocked var capturedEnvVars map[string]string mocks.Shell.PrintEnvVarsFunc = func(envVars map[string]string) { capturedEnvVars = envVars } - // Call Print and check for errors - err := talosEnvPrinter.Print() + // When calling Print + err = printer.Print() + + // Then no error should be returned if err != nil { t.Errorf("unexpected error: %v", err) } - // Verify that PrintEnvVarsFunc was called with the correct envVars + // And environment variables should be set correctly expectedEnvVars := map[string]string{ - "TALOSCONFIG": filepath.FromSlash("/mock/config/root/.talos/config"), + "TALOSCONFIG": expectedPath, } if !reflect.DeepEqual(capturedEnvVars, expectedEnvVars) { t.Errorf("capturedEnvVars = %v, want %v", capturedEnvVars, expectedEnvVars) } }) - t.Run("GetConfigError", func(t *testing.T) { - // Use setupSafeTalosEnvMocks to create mocks - mocks := setupSafeTalosEnvMocks() - - // Override the GetConfigFunc to simulate an error - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { - return "", errors.New("mock config error") + t.Run("GetProjectRootError", func(t *testing.T) { + // Given a new TalosEnvPrinter with failing project root lookup + printer, mocks := setup(t) + mocks.Shell.GetProjectRootFunc = func() (string, error) { + return "", errors.New("mock project root error") } - mockInjector := mocks.Injector - - talosEnvPrinter := NewTalosEnvPrinter(mockInjector) - talosEnvPrinter.Initialize() + // When calling Print + err := printer.Print() - // Call Print and check for errors - err := talosEnvPrinter.Print() + // Then appropriate error should be returned if err == nil { t.Error("expected error, got nil") - } else if !strings.Contains(err.Error(), "mock config error") { + } else if !strings.Contains(err.Error(), "mock project root error") { t.Errorf("unexpected error message: %v", err) } }) diff --git a/pkg/env/terraform_env.go b/pkg/env/terraform_env.go index dac6443f3..bd6b77094 100644 --- a/pkg/env/terraform_env.go +++ b/pkg/env/terraform_env.go @@ -1,3 +1,8 @@ +// The TerraformEnvPrinter is a specialized component that manages Terraform environment configuration. +// It provides Terraform-specific environment variable management and configuration, +// The TerraformEnvPrinter handles backend configuration, variable files, and state management, +// ensuring proper Terraform CLI integration and environment setup for infrastructure operations. + package env import ( @@ -11,20 +16,30 @@ import ( "github.com/windsorcli/cli/pkg/di" ) -// TerraformEnvPrinter simulates a Terraform environment for testing purposes. +// ============================================================================= +// Types +// ============================================================================= + +// TerraformEnvPrinter is a struct that implements Terraform environment configuration type TerraformEnvPrinter struct { BaseEnvPrinter } -// NewTerraformEnvPrinter initializes a new TerraformEnvPrinter instance. +// ============================================================================= +// Constructor +// ============================================================================= + +// NewTerraformEnvPrinter creates a new TerraformEnvPrinter instance func NewTerraformEnvPrinter(injector di.Injector) *TerraformEnvPrinter { return &TerraformEnvPrinter{ - BaseEnvPrinter: BaseEnvPrinter{ - injector: injector, - }, + BaseEnvPrinter: *NewBaseEnvPrinter(injector), } } +// ============================================================================= +// Public Methods +// ============================================================================= + // GetEnvVars retrieves environment variables for Terraform. // It determines the config root and project path, checks for tfvars files, // and sets variables based on the OS. If not in a terraform project folder, @@ -38,7 +53,7 @@ func (e *TerraformEnvPrinter) GetEnvVars() (map[string]string, error) { return nil, fmt.Errorf("error getting config root: %w", err) } - projectPath, err := findRelativeTerraformProjectPath() + projectPath, err := e.findRelativeTerraformProjectPath() if err != nil { return nil, fmt.Errorf("error finding project path: %w", err) } @@ -56,7 +71,7 @@ func (e *TerraformEnvPrinter) GetEnvVars() (map[string]string, error) { } for _, varName := range managedVars { - if _, exists := osLookupEnv(varName); exists { + if _, exists := e.shims.LookupEnv(varName); exists { envVars[varName] = "" } } @@ -73,12 +88,14 @@ func (e *TerraformEnvPrinter) GetEnvVars() (map[string]string, error) { var varFileArgs []string for _, pattern := range patterns { - if _, err := stat(pattern); err != nil { + if _, err := e.shims.Stat(filepath.FromSlash(pattern)); err != nil { if !os.IsNotExist(err) { return nil, fmt.Errorf("error checking file: %w", err) } } else { - varFileArgs = append(varFileArgs, fmt.Sprintf("-var-file=\"%s\"", filepath.ToSlash(pattern))) + // Convert back to slash format for environment variable + slashPath := filepath.ToSlash(pattern) + varFileArgs = append(varFileArgs, fmt.Sprintf("-var-file=\"%s\"", slashPath)) } } @@ -103,7 +120,8 @@ func (e *TerraformEnvPrinter) GetEnvVars() (map[string]string, error) { envVars["TF_CLI_ARGS_destroy"] = strings.TrimSpace(strings.Join(varFileArgs, " ")) envVars["TF_VAR_context_path"] = strings.TrimSpace(filepath.ToSlash(configRoot)) - if goos() == "windows" { + // Set os_type based on the OS + if e.shims.Goos() == "windows" { envVars["TF_VAR_os_type"] = "windows" } else { envVars["TF_VAR_os_type"] = "unix" @@ -126,15 +144,19 @@ func (e *TerraformEnvPrinter) Print() error { return e.BaseEnvPrinter.Print(envVars) } +// ============================================================================= +// Private Methods +// ============================================================================= + // generateBackendOverrideTf creates the backend_override.tf file for the project by determining // the backend type and writing the appropriate configuration to the file. func (e *TerraformEnvPrinter) generateBackendOverrideTf() error { - currentPath, err := getwd() + currentPath, err := e.shims.Getwd() if err != nil { return fmt.Errorf("error getting current directory: %w", err) } - projectPath, err := findRelativeTerraformProjectPath() + projectPath, err := e.findRelativeTerraformProjectPath() if err != nil { return fmt.Errorf("error finding project path: %w", err) } @@ -165,7 +187,7 @@ func (e *TerraformEnvPrinter) generateBackendOverrideTf() error { return fmt.Errorf("unsupported backend: %s", backend) } - err = writeFile(backendOverridePath, []byte(backendConfig), os.ModePerm) + err = e.shims.WriteFile(backendOverridePath, []byte(backendConfig), os.ModePerm) if err != nil { return fmt.Errorf("error writing backend_override.tf: %w", err) } @@ -189,7 +211,7 @@ func (e *TerraformEnvPrinter) generateBackendConfigArgs(projectPath, configRoot if context := e.configHandler.GetContext(); context != "" { backendTfvarsPath := filepath.Join(configRoot, "terraform", "backend.tfvars") - if _, err := stat(backendTfvarsPath); err == nil { + if _, err := e.shims.Stat(backendTfvarsPath); err == nil { backendConfigArgs = append(backendConfigArgs, fmt.Sprintf("-backend-config=\"%s\"", filepath.ToSlash(backendTfvarsPath))) } } @@ -208,7 +230,7 @@ func (e *TerraformEnvPrinter) generateBackendConfigArgs(projectPath, configRoot keyPath := fmt.Sprintf("%s%s", prefix, filepath.ToSlash(filepath.Join(projectPath, "terraform.tfstate"))) addBackendConfigArg("key", keyPath) if backend := e.configHandler.GetConfig().Terraform.Backend.S3; backend != nil { - if err := processBackendConfig(backend, addBackendConfigArg); err != nil { + if err := e.processBackendConfig(backend, addBackendConfigArg); err != nil { return nil, fmt.Errorf("error processing S3 backend config: %w", err) } } @@ -220,7 +242,7 @@ func (e *TerraformEnvPrinter) generateBackendConfigArgs(projectPath, configRoot secretSuffix = sanitizeForK8s(secretSuffix) addBackendConfigArg("secret_suffix", secretSuffix) if backend := e.configHandler.GetConfig().Terraform.Backend.Kubernetes; backend != nil { - if err := processBackendConfig(backend, addBackendConfigArg); err != nil { + if err := e.processBackendConfig(backend, addBackendConfigArg); err != nil { return nil, fmt.Errorf("error processing Kubernetes backend config: %w", err) } } @@ -231,18 +253,15 @@ func (e *TerraformEnvPrinter) generateBackendConfigArgs(projectPath, configRoot return backendConfigArgs, nil } -// Ensure TerraformEnvPrinter implements the EnvPrinter interface. -var _ EnvPrinter = (*TerraformEnvPrinter)(nil) - // processBackendConfig processes the backend config and adds the key-value pairs to the backend config args. -var processBackendConfig = func(backendConfig any, addArg func(key, value string)) error { - yamlData, err := yamlMarshal(backendConfig) +func (e *TerraformEnvPrinter) processBackendConfig(backendConfig any, addArg func(key, value string)) error { + yamlData, err := e.shims.YamlMarshal(backendConfig) if err != nil { return fmt.Errorf("error marshalling backend to YAML: %w", err) } var configMap map[string]any - if err := yamlUnmarshal(yamlData, &configMap); err != nil { + if err := e.shims.YamlUnmarshal(yamlData, &configMap); err != nil { return fmt.Errorf("error unmarshalling backend YAML: %w", err) } @@ -263,7 +282,7 @@ var processBackendConfig = func(backendConfig any, addArg func(key, value string } // processMap processes a map and adds the key-value pairs to the backend config args. -func processMap(prefix string, configMap map[string]interface{}, addArg func(key, value string)) { +func processMap(prefix string, configMap map[string]any, addArg func(key, value string)) { keys := make([]string, 0, len(configMap)) for key := range configMap { keys = append(keys, key) @@ -283,13 +302,13 @@ func processMap(prefix string, configMap map[string]interface{}, addArg func(key addArg(fullKey, fmt.Sprintf("%t", v)) case int, uint64: addArg(fullKey, fmt.Sprintf("%d", v)) - case []interface{}: + case []any: for _, item := range v { if strItem, ok := item.(string); ok { addArg(fullKey, strItem) } } - case map[string]interface{}: + case map[string]any: processMap(fullKey, v, addArg) } } @@ -311,8 +330,8 @@ var sanitizeForK8s = func(input string) string { // findRelativeTerraformProjectPath locates the Terraform project path by checking the current // directory and its ancestors for Terraform files, returning the relative path if found. -var findRelativeTerraformProjectPath = func() (string, error) { - currentPath, err := getwd() +func (e *TerraformEnvPrinter) findRelativeTerraformProjectPath() (string, error) { + currentPath, err := e.shims.Getwd() if err != nil { return "", fmt.Errorf("error getting current directory: %w", err) } @@ -320,7 +339,7 @@ var findRelativeTerraformProjectPath = func() (string, error) { currentPath = filepath.Clean(currentPath) globPattern := filepath.Join(currentPath, "*.tf") - matches, err := glob(globPattern) + matches, err := e.shims.Glob(globPattern) if err != nil { return "", fmt.Errorf("error finding project path: %w", err) } @@ -332,9 +351,12 @@ var findRelativeTerraformProjectPath = func() (string, error) { for i := len(pathParts) - 1; i >= 0; i-- { if strings.EqualFold(pathParts[i], "terraform") || strings.EqualFold(pathParts[i], ".tf_modules") { relativePath := filepath.Join(pathParts[i+1:]...) - return relativePath, nil + return filepath.ToSlash(relativePath), nil } } return "", nil } + +// Ensure TerraformEnvPrinter implements the EnvPrinter interface +var _ EnvPrinter = (*TerraformEnvPrinter)(nil) diff --git a/pkg/env/terraform_env_test.go b/pkg/env/terraform_env_test.go index 727662c45..58bdd8f2d 100644 --- a/pkg/env/terraform_env_test.go +++ b/pkg/env/terraform_env_test.go @@ -9,126 +9,115 @@ import ( "strings" "testing" - "github.com/windsorcli/cli/api/v1alpha1" - "github.com/windsorcli/cli/api/v1alpha1/terraform" "github.com/windsorcli/cli/pkg/config" - "github.com/windsorcli/cli/pkg/di" - "github.com/windsorcli/cli/pkg/shell" ) -type TerraformEnvMocks struct { - Injector di.Injector - Shell *shell.MockShell - ConfigHandler *config.MockConfigHandler -} - -func setupSafeTerraformEnvMocks(injector ...di.Injector) *TerraformEnvMocks { - var mockInjector di.Injector - if len(injector) > 0 { - mockInjector = injector[0] - } else { - mockInjector = di.NewMockInjector() - } +// ============================================================================= +// Test Setup +// ============================================================================= - mockShell := shell.NewMockShell() +// setupTerraformEnvMocks creates and configures mock objects for Terraform environment tests. +func setupTerraformEnvMocks(t *testing.T, opts ...*SetupOptions) *Mocks { + // Pass the mock config handler to setupMocks + mocks := setupMocks(t, opts...) - mockConfigHandler := config.NewMockConfigHandler() - mockConfigHandler.GetConfigRootFunc = func() (string, error) { - return "/mock/config/root", nil + mocks.Shims.Getwd = func() (string, error) { + // Use platform-agnostic path + return filepath.Join("mock", "project", "root", "terraform", "project", "path"), nil } - mockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "terraform.backend.type" { - return "local" - } - if len(defaultValue) > 0 { - return defaultValue[0] + + mocks.Shims.Glob = func(pattern string) ([]string, error) { + if strings.Contains(pattern, "*.tf") { + return []string{ + filepath.Join("real", "terraform", "project", "path", "file1.tf"), + filepath.Join("real", "terraform", "project", "path", "file2.tf"), + }, nil } - return "" - } - mockConfigHandler.GetContextFunc = func() string { - return "mock-context" + return nil, nil } - mockInjector.Register("shell", mockShell) - mockInjector.Register("configHandler", mockConfigHandler) + mocks.ConfigHandler.SetContextValue("terraform.backend.type", "local") - stat = func(name string) (os.FileInfo, error) { - return nil, nil - } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + // Convert paths to slash format for consistent comparison + nameSlash := filepath.ToSlash(name) - return &TerraformEnvMocks{ - Injector: mockInjector, - Shell: mockShell, - ConfigHandler: mockConfigHandler, + // Check for tfvars files in the expected paths + if strings.Contains(nameSlash, "project/path.tfvars") || + strings.Contains(nameSlash, "project/path.tfvars.json") || + strings.Contains(nameSlash, "project\\path.tfvars") || + strings.Contains(nameSlash, "project\\path.tfvars.json") { + return nil, nil + } + if strings.Contains(nameSlash, "project/path_generated.tfvars") { + return nil, os.ErrNotExist + } + return nil, os.ErrNotExist } -} -func TestTerraformEnv_GetEnvVars(t *testing.T) { - t.Run("Success", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() + return mocks +} - expectedEnvVars := map[string]string{ - "TF_DATA_DIR": `/mock/config/root/.terraform/project/path`, - "TF_CLI_ARGS_init": `-backend=true -backend-config="path=/mock/config/root/.tfstate/project/path/terraform.tfstate"`, - "TF_CLI_ARGS_plan": `-out="/mock/config/root/.terraform/project/path/terraform.tfplan" -var-file="/mock/config/root/terraform/project/path.tfvars" -var-file="/mock/config/root/terraform/project/path.tfvars.json"`, - "TF_CLI_ARGS_apply": `"/mock/config/root/.terraform/project/path/terraform.tfplan"`, - "TF_CLI_ARGS_import": `-var-file="/mock/config/root/terraform/project/path.tfvars" -var-file="/mock/config/root/terraform/project/path.tfvars.json"`, - "TF_CLI_ARGS_destroy": `-var-file="/mock/config/root/terraform/project/path.tfvars" -var-file="/mock/config/root/terraform/project/path.tfvars.json"`, - "TF_VAR_context_path": `/mock/config/root`, - } +// ============================================================================= +// Test Public Methods +// ============================================================================= - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() +// TestTerraformEnv_GetEnvVars tests the GetEnvVars method of the TerraformEnvPrinter +func TestTerraformEnv_GetEnvVars(t *testing.T) { + setup := func(t *testing.T) (*TerraformEnvPrinter, *Mocks) { + t.Helper() + mocks := setupTerraformEnvMocks(t) + printer := NewTerraformEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + return printer, mocks + } - // Given a mocked glob function simulating the presence of tf files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - if strings.Contains(pattern, "*.tf") { - return []string{"real/terraform/project/path/file1.tf", "real/terraform/project/path/file2.tf"}, nil - } - return nil, nil - } + t.Run("Success", func(t *testing.T) { + // Given a new TerraformEnvPrinter with mock configuration + printer, mocks := setup(t) - // And a mocked getwd function returning a specific path - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil + // Mock the OS type + osType := "unix" + if mocks.Shims.Goos() == "windows" { + osType = "windows" } - // And a mocked stat function simulating file existence with varied tfvars files - originalStat := stat - defer func() { stat = originalStat }() - stat = func(name string) (os.FileInfo, error) { - // Debugging: Print the path being checked - t.Logf("Checking file: %s", name) - switch name { - case filepath.FromSlash("/mock/config/root/terraform/project/path.tfvars"): - return nil, nil // Simulate file exists - case filepath.FromSlash("/mock/config/root/terraform/project/path.tfvars.json"): - return nil, nil // Simulate file exists - case filepath.FromSlash("/mock/config/root/terraform/project/path_generated.tfvars"): - return nil, os.ErrNotExist // Simulate file does not exist - case filepath.FromSlash("/mock/config/root/terraform/project/path_generated.tfvars.json"): - return nil, os.ErrNotExist // Simulate file does not exist - default: - return nil, os.ErrNotExist // Simulate file does not exist - } + // Get the actual config root + configRoot, err := mocks.ConfigHandler.GetConfigRoot() + if err != nil { + t.Fatalf("Failed to get config root: %v", err) } - // When the GetEnvVars function is called - envVars, err := terraformEnvPrinter.GetEnvVars() + expectedEnvVars := map[string]string{ + "TF_DATA_DIR": filepath.ToSlash(filepath.Join(configRoot, ".terraform/project/path")), + "TF_CLI_ARGS_init": fmt.Sprintf(`-backend=true -backend-config="path=%s"`, filepath.ToSlash(filepath.Join(configRoot, ".tfstate/project/path/terraform.tfstate"))), + "TF_CLI_ARGS_plan": fmt.Sprintf(`-out="%s" -var-file="%s" -var-file="%s"`, + filepath.ToSlash(filepath.Join(configRoot, ".terraform/project/path/terraform.tfplan")), + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars")), + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars.json"))), + "TF_CLI_ARGS_apply": fmt.Sprintf(`"%s"`, filepath.ToSlash(filepath.Join(configRoot, ".terraform/project/path/terraform.tfplan"))), + "TF_CLI_ARGS_import": fmt.Sprintf(`-var-file="%s" -var-file="%s"`, + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars")), + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars.json"))), + "TF_CLI_ARGS_destroy": fmt.Sprintf(`-var-file="%s" -var-file="%s"`, + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars")), + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars.json"))), + "TF_VAR_context_path": filepath.ToSlash(configRoot), + "TF_VAR_os_type": osType, + } + + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { t.Errorf("Expected no error, got %v", err) } - // Debugging: Print the actual envVars on Windows - for key, value := range envVars { - t.Logf("envVar[%s] = %s", key, value) - } - - // Then the expected environment variables should be set + // And environment variables should be set correctly for key, expectedValue := range expectedEnvVars { if value, exists := envVars[key]; !exists || value != expectedValue { t.Errorf("Expected %s to be %s, got %s", key, expectedValue, value) @@ -137,19 +126,15 @@ func TestTerraformEnv_GetEnvVars(t *testing.T) { }) t.Run("ErrorGettingProjectPath", func(t *testing.T) { - // Mock the getwd function to simulate an error - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { + printer, mocks := setup(t) + + // Mock Getwd to return an error + mocks.Shims.Getwd = func() (string, error) { return "", fmt.Errorf("mock error getting current directory") } - mocks := setupSafeTerraformEnvMocks() - - // When the GetEnvVars function is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - _, err := terraformEnvPrinter.GetEnvVars() + // When GetEnvVars is called + _, err := printer.GetEnvVars() // Then the error should contain the expected message if err == nil { @@ -161,58 +146,52 @@ func TestTerraformEnv_GetEnvVars(t *testing.T) { }) t.Run("NoProjectPathFound", func(t *testing.T) { - // Given a mocked getwd function returning a specific path - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { + // Given a new TerraformEnvPrinter with no Terraform project path + printer, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { return filepath.FromSlash("/mock/project/root"), nil } - mocks := setupSafeTerraformEnvMocks() - // When the GetEnvVars function is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - envVars, err := terraformEnvPrinter.GetEnvVars() + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } - // Then it should return an empty map without an error + // And an empty map should be returned if envVars == nil { t.Errorf("Expected an empty map, got nil") } if len(envVars) != 0 { t.Errorf("Expected empty map, got %v", envVars) } - if err != nil { - t.Errorf("Expected no error, got %v", err) - } }) t.Run("ResetEnvVarsWhenNoProjectPathFound", func(t *testing.T) { - // Given a mocked getwd function returning a specific path - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { + // Given a new TerraformEnvPrinter with existing environment variables + printer, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { return filepath.FromSlash("/mock/project/root"), nil } - // And some environment variables that should be reset - originalLookupEnv := osLookupEnv - defer func() { osLookupEnv = originalLookupEnv }() - osLookupEnv = func(key string) (string, bool) { - // Simulate that TF_DATA_DIR and TF_CLI_ARGS_init exist in environment + mocks.Shims.LookupEnv = func(key string) (string, bool) { if key == "TF_DATA_DIR" || key == "TF_CLI_ARGS_init" { return "some-value", true } return "", false } - mocks := setupSafeTerraformEnvMocks() + // When getting environment variables + envVars, err := printer.GetEnvVars() - // When the GetEnvVars function is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - envVars, err := terraformEnvPrinter.GetEnvVars() + // Then no error should be returned + if err != nil { + t.Errorf("Expected no error, got %v", err) + } - // Then it should return a map with only the variables that need to be reset + // And environment variables should be reset if envVars == nil { t.Errorf("Expected a map with reset variables, got nil") } @@ -225,75 +204,53 @@ func TestTerraformEnv_GetEnvVars(t *testing.T) { if val, exists := envVars["TF_CLI_ARGS_init"]; !exists || val != "" { t.Errorf("Expected TF_CLI_ARGS_init to be empty string, got %v", val) } - if err != nil { - t.Errorf("Expected no error, got %v", err) - } }) t.Run("ErrorGettingConfigRoot", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { + configHandler := config.NewMockConfigHandler() + configHandler.GetConfigRootFunc = func() (string, error) { return "", fmt.Errorf("mock error getting config root") } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil + mocks := setupTerraformEnvMocks(t, &SetupOptions{ + ConfigHandler: configHandler, + }) + printer := NewTerraformEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) } - // When the GetEnvVars function is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - _, err := terraformEnvPrinter.GetEnvVars() + // When GetEnvVars is called + _, err := printer.GetEnvVars() - // Then the error should be as expected - expectedErrorMessage := "error getting config root: mock error getting config root" - if err == nil || err.Error() != expectedErrorMessage { - t.Errorf("Expected error %q, got %v", expectedErrorMessage, err) + // Then the error should contain the expected message + if err == nil { + t.Errorf("Expected error, got nil") + } + if !strings.Contains(err.Error(), "error getting config root") { + t.Errorf("Expected error message to contain 'error getting config root', got %v", err) } }) t.Run("ErrorListingTfvarsFiles", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetContextFunc = func() string { - return "mockContext" - } - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{} - } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil - } + // Given a new TerraformEnvPrinter with failing file stat + printer, mocks := setup(t) - // And a mocked glob function succeeding for *.tf files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { + mocks.Shims.Glob = func(pattern string) ([]string, error) { if strings.Contains(pattern, "*.tf") { return []string{"file1.tf", "file2.tf"}, nil } return nil, nil } - // And a mocked stat function returning an error other than os.IsNotExist - originalStat := stat - defer func() { stat = originalStat }() - stat = func(name string) (os.FileInfo, error) { + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { return nil, fmt.Errorf("mock error checking file") } - // When the GetEnvVars function is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - _, err := terraformEnvPrinter.GetEnvVars() + // When getting environment variables + _, err := printer.GetEnvVars() - // Then the error should be as expected + // Then appropriate error should be returned expectedErrorMessage := "error checking file: mock error checking file" if err == nil || err.Error() != expectedErrorMessage { t.Errorf("Expected error %q, got %v", expectedErrorMessage, err) @@ -301,105 +258,103 @@ func TestTerraformEnv_GetEnvVars(t *testing.T) { }) t.Run("TestWindows", func(t *testing.T) { - originalGoos := goos - defer func() { goos = originalGoos }() - goos = func() string { + // Given a new TerraformEnvPrinter on Windows + printer, mocks := setup(t) + + // Mock Windows OS + mocks.Shims.Goos = func() string { return "windows" } - mocks := setupSafeTerraformEnvMocks() - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - - // Mock the getwd function to simulate being in a terraform project path - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { + // Mock filesystem operations + mocks.Shims.Getwd = func() (string, error) { return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil } - // Mock the glob function to simulate the presence of *.tf files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { + mocks.Shims.Glob = func(pattern string) ([]string, error) { if strings.Contains(pattern, "*.tf") { return []string{"main.tf"}, nil } return nil, nil } - // Mock the stat function to simulate the existence of tfvars files - originalStat := stat - defer func() { stat = originalStat }() - stat = func(name string) (os.FileInfo, error) { - if name == filepath.FromSlash("/mock/config/root/terraform/project/path.tfvars") { - return nil, nil // Simulate file exists + // Get the actual config root + configRoot, err := mocks.ConfigHandler.GetConfigRoot() + if err != nil { + t.Fatalf("Failed to get config root: %v", err) + } + + // Mock Stat to handle both tfvars files + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + // Convert paths to slash format for consistent comparison + nameSlash := filepath.ToSlash(name) + + // Check for tfvars files in the expected paths + if strings.Contains(nameSlash, "project/path.tfvars") || + strings.Contains(nameSlash, "project/path.tfvars.json") || + strings.Contains(nameSlash, "project\\path.tfvars") || + strings.Contains(nameSlash, "project\\path.tfvars.json") { + return nil, nil } return nil, os.ErrNotExist } - // Mock the GetEnvVars function to verify it returns the correct envVars - envVars, err := terraformEnvPrinter.GetEnvVars() + // When getting environment variables + envVars, err := printer.GetEnvVars() + + // Then no error should be returned if err != nil { t.Errorf("unexpected error: %v", err) } - // Verify that GetEnvVars returns the correct envVars + // And environment variables should be set correctly expectedEnvVars := map[string]string{ - "TF_VAR_os_type": "windows", + "TF_DATA_DIR": filepath.ToSlash(filepath.Join(configRoot, ".terraform/project/path")), + "TF_CLI_ARGS_init": fmt.Sprintf(`-backend=true -backend-config="path=%s"`, filepath.ToSlash(filepath.Join(configRoot, ".tfstate/project/path/terraform.tfstate"))), + "TF_CLI_ARGS_plan": fmt.Sprintf(`-out="%s" -var-file="%s" -var-file="%s"`, + filepath.ToSlash(filepath.Join(configRoot, ".terraform/project/path/terraform.tfplan")), + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars")), + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars.json"))), + "TF_CLI_ARGS_apply": fmt.Sprintf(`"%s"`, filepath.ToSlash(filepath.Join(configRoot, ".terraform/project/path/terraform.tfplan"))), + "TF_CLI_ARGS_import": fmt.Sprintf(`-var-file="%s" -var-file="%s"`, + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars")), + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars.json"))), + "TF_CLI_ARGS_destroy": fmt.Sprintf(`-var-file="%s" -var-file="%s"`, + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars")), + filepath.ToSlash(filepath.Join(configRoot, "terraform/project/path.tfvars.json"))), + "TF_VAR_context_path": filepath.ToSlash(configRoot), + "TF_VAR_os_type": "windows", } + if envVars == nil { t.Errorf("envVars is nil, expected %v", expectedEnvVars) - } else if value, exists := envVars["TF_VAR_os_type"]; !exists || value != expectedEnvVars["TF_VAR_os_type"] { - t.Errorf("envVars[TF_VAR_os_type] = %v, want %v", value, expectedEnvVars["TF_VAR_os_type"]) + } else { + for key, expectedValue := range expectedEnvVars { + if value, exists := envVars[key]; !exists || value != expectedValue { + t.Errorf("Expected %s to be %s, got %s", key, expectedValue, value) + } + } } }) } func TestTerraformEnv_PostEnvHook(t *testing.T) { - t.Run("Success", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetContextFunc = func() string { - return "mockContext" - } - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "local", - }, - }, - } - } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("mock/project/root/terraform/project/path"), nil - } - - // And a mocked glob function succeeding for *.tf files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - if strings.Contains(pattern, "*.tf") { - return []string{"file1.tf", "file2.tf"}, nil - } - return nil, nil - } + setup := func(t *testing.T) (*TerraformEnvPrinter, *Mocks) { + t.Helper() + mocks := setupTerraformEnvMocks(t) + printer := NewTerraformEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + return printer, mocks + } - // And a mocked writeFile function simulating successful file writing - originalWriteFile := writeFile - defer func() { writeFile = originalWriteFile }() - writeFile = func(filename string, data []byte, perm os.FileMode) error { - return nil - } + t.Run("Success", func(t *testing.T) { + printer, _ := setup(t) // When the PostEnvHook function is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.PostEnvHook() + err := printer.PostEnvHook() // Then no error should occur if err != nil { @@ -408,18 +363,13 @@ func TestTerraformEnv_PostEnvHook(t *testing.T) { }) t.Run("ErrorGettingCurrentDirectory", func(t *testing.T) { - // Given a mocked getwd function returning an error - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { + printer, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { return "", fmt.Errorf("mock error getting current directory") } // When the PostEnvHook function is called - mocks := setupSafeTerraformEnvMocks() - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.PostEnvHook() + err := printer.PostEnvHook() // Then the error should contain the expected message if err == nil { @@ -431,18 +381,13 @@ func TestTerraformEnv_PostEnvHook(t *testing.T) { }) t.Run("ErrorFindingProjectPath", func(t *testing.T) { - // Given a mocked glob function returning an error - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { + printer, mocks := setup(t) + mocks.Shims.Glob = func(pattern string) ([]string, error) { return nil, fmt.Errorf("mock error finding project path") } // When the PostEnvHook function is called - mocks := setupSafeTerraformEnvMocks() - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.PostEnvHook() + err := printer.PostEnvHook() // Then the error should contain the expected message if err == nil { @@ -454,33 +399,11 @@ func TestTerraformEnv_PostEnvHook(t *testing.T) { }) t.Run("UnsupportedBackend", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "terraform.backend.type" { - return "unsupported" - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("mock/project/root/terraform/project/path"), nil - } - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - return []string{filepath.FromSlash("mock/project/root/terraform/project/path/main.tf")}, nil - } + printer, mocks := setup(t) + mocks.ConfigHandler.SetContextValue("terraform.backend.type", "unsupported") // When the PostEnvHook function is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.PostEnvHook() + err := printer.PostEnvHook() // Then the error should contain the expected message if err == nil { @@ -492,30 +415,13 @@ func TestTerraformEnv_PostEnvHook(t *testing.T) { }) t.Run("ErrorWritingBackendOverrideFile", func(t *testing.T) { - // Given a mocked writeFile function returning an error - originalWriteFile := writeFile - defer func() { writeFile = originalWriteFile }() - writeFile = func(filename string, data []byte, perm os.FileMode) error { + printer, mocks := setup(t) + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { return fmt.Errorf("mock error writing backend_override.tf file") } - // And a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("mock/project/root/terraform/project/path"), nil - } - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - return []string{filepath.FromSlash("mock/project/root/terraform/project/path/main.tf")}, nil - } - // When the PostEnvHook function is called - mocks := setupSafeTerraformEnvMocks() - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.PostEnvHook() + err := printer.PostEnvHook() // Then the error should contain the expected message if err == nil { @@ -528,90 +434,89 @@ func TestTerraformEnv_PostEnvHook(t *testing.T) { } func TestTerraformEnv_Print(t *testing.T) { - t.Run("Success", func(t *testing.T) { - // Use setupSafeTerraformEnvMocks to create mocks - mocks := setupSafeTerraformEnvMocks() - mockInjector := mocks.Injector - terraformEnvPrinter := NewTerraformEnvPrinter(mockInjector) - terraformEnvPrinter.Initialize() - - // Mock the stat function to simulate the existence of the terraform config file - originalStat := stat - defer func() { stat = originalStat }() - stat = func(name string) (os.FileInfo, error) { - if name == filepath.FromSlash("/mock/config/root/.terraform/config") { - return nil, nil // Simulate that the file exists - } - return nil, os.ErrNotExist - } - - // Mock the glob function to simulate the presence of *.tf files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - if strings.Contains(pattern, "*.tf") { - return []string{"main.tf"}, nil // Simulate that tf files exist - } - return nil, nil - } + setup := func(t *testing.T) (*TerraformEnvPrinter, *Mocks) { + t.Helper() + mocks := setupTerraformEnvMocks(t) + printer := NewTerraformEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + return printer, mocks + } - // Mock the getwd function to return a path that includes "terraform" - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil - } + t.Run("Success", func(t *testing.T) { + // Given a TerraformEnvPrinter with mock configuration + printer, mocks := setup(t) - // Mock the PrintEnvVarsFunc to verify it is called with the correct envVars var capturedEnvVars map[string]string mocks.Shell.PrintEnvVarsFunc = func(envVars map[string]string) { capturedEnvVars = envVars } - // Call Print and check for errors - err := terraformEnvPrinter.Print() + // When Print is called + err := printer.Print() if err != nil { t.Errorf("unexpected error: %v", err) } - // Determine the expected OS type + // Then the expected environment variables should be set expectedOSType := "unix" - if goos() == "windows" { + if mocks.Shims.Goos() == "windows" { expectedOSType = "windows" } - // Verify that PrintEnvVarsFunc was called with the correct envVars + configRoot, err := mocks.ConfigHandler.GetConfigRoot() + if err != nil { + t.Fatalf("Failed to get config root: %v", err) + } + expectedEnvVars := map[string]string{ - "TF_DATA_DIR": "/mock/config/root/.terraform/project/path", - "TF_CLI_ARGS_init": "-backend=true -backend-config=\"path=/mock/config/root/.tfstate/project/path/terraform.tfstate\"", - "TF_CLI_ARGS_plan": `-out="/mock/config/root/.terraform/project/path/terraform.tfplan"`, - "TF_CLI_ARGS_apply": `"/mock/config/root/.terraform/project/path/terraform.tfplan"`, - "TF_CLI_ARGS_import": "", - "TF_CLI_ARGS_destroy": "", - "TF_VAR_context_path": "/mock/config/root", + "TF_DATA_DIR": filepath.Join(configRoot, ".terraform/project/path"), + "TF_CLI_ARGS_init": fmt.Sprintf(`-backend=true -backend-config="path=%s"`, filepath.Join(configRoot, ".tfstate/project/path/terraform.tfstate")), + "TF_CLI_ARGS_plan": fmt.Sprintf(`-out="%s" -var-file="%s" -var-file="%s"`, + filepath.Join(configRoot, ".terraform/project/path/terraform.tfplan"), + filepath.Join(configRoot, "terraform/project/path.tfvars"), + filepath.Join(configRoot, "terraform/project/path.tfvars.json")), + "TF_CLI_ARGS_apply": fmt.Sprintf(`"%s"`, filepath.Join(configRoot, ".terraform/project/path/terraform.tfplan")), + "TF_CLI_ARGS_import": fmt.Sprintf(`-var-file="%s" -var-file="%s"`, + filepath.Join(configRoot, "terraform/project/path.tfvars"), + filepath.Join(configRoot, "terraform/project/path.tfvars.json")), + "TF_CLI_ARGS_destroy": fmt.Sprintf(`-var-file="%s" -var-file="%s"`, + filepath.Join(configRoot, "terraform/project/path.tfvars"), + filepath.Join(configRoot, "terraform/project/path.tfvars.json")), + "TF_VAR_context_path": configRoot, "TF_VAR_os_type": expectedOSType, } + + for k, v := range expectedEnvVars { + expectedEnvVars[k] = filepath.ToSlash(v) + } + for k, v := range capturedEnvVars { + capturedEnvVars[k] = filepath.ToSlash(v) + } + if !reflect.DeepEqual(capturedEnvVars, expectedEnvVars) { t.Errorf("capturedEnvVars = %v, want %v", capturedEnvVars, expectedEnvVars) } }) t.Run("GetConfigError", func(t *testing.T) { - // Use setupSafeTerraformEnvMocks to create mocks - mocks := setupSafeTerraformEnvMocks() - - // Override the GetConfigFunc to simulate an error - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { + // Given a TerraformEnvPrinter with a failing config handler + configHandler := config.NewMockConfigHandler() + configHandler.GetConfigRootFunc = func() (string, error) { return "", fmt.Errorf("mock config error") } - - mockInjector := mocks.Injector - - terraformEnvPrinter := NewTerraformEnvPrinter(mockInjector) + mocks := setupTerraformEnvMocks(t, &SetupOptions{ + ConfigHandler: configHandler, + }) + terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) terraformEnvPrinter.Initialize() - // Call Print and check for errors + // When Print is called err := terraformEnvPrinter.Print() + + // Then an error should be returned if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "mock config error") { @@ -621,54 +526,43 @@ func TestTerraformEnv_Print(t *testing.T) { } func TestTerraformEnv_findRelativeTerraformProjectPath(t *testing.T) { + setup := func(t *testing.T) (*TerraformEnvPrinter, *Mocks) { + t.Helper() + mocks := setupTerraformEnvMocks(t) + printer := NewTerraformEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + return printer, mocks + } + t.Run("Success", func(t *testing.T) { - // Given a mocked getwd function returning a specific directory path - originalGetwd := getwd - getwd = func() (string, error) { - return filepath.FromSlash("/mock/path/to/terraform/project"), nil - } - defer func() { getwd = originalGetwd }() - - // And a mocked glob function simulating finding Terraform files - originalGlob := glob - glob = func(pattern string) ([]string, error) { - if pattern == filepath.FromSlash("/mock/path/to/terraform/project/*.tf") { - return []string{filepath.FromSlash("/mock/path/to/terraform/project/main.tf")}, nil - } - return nil, nil - } - defer func() { glob = originalGlob }() + // Given a TerraformEnvPrinter with mock configuration + printer, _ := setup(t) // When findRelativeTerraformProjectPath is called - projectPath, err := findRelativeTerraformProjectPath() + projectPath, err := printer.findRelativeTerraformProjectPath() // Then no error should occur and the expected project path should be returned if err != nil { t.Errorf("Expected no error, got %v", err) } - expectedPath := "project" + expectedPath := "project/path" if projectPath != expectedPath { t.Errorf("Expected project path %v, got %v", expectedPath, projectPath) } }) t.Run("NoTerraformFiles", func(t *testing.T) { - // Given a mocked getwd function returning a specific directory path - originalGetwd := getwd - getwd = func() (string, error) { - return filepath.FromSlash("/mock/path/to/terraform/project"), nil - } - defer func() { getwd = originalGetwd }() - - // And a mocked glob function simulating no Terraform files found - originalGlob := glob - glob = func(pattern string) ([]string, error) { + // Given a TerraformEnvPrinter with no Terraform files + printer, mocks := setup(t) + mocks.Shims.Glob = func(pattern string) ([]string, error) { return nil, nil } - defer func() { glob = originalGlob }() // When findRelativeTerraformProjectPath is called - projectPath, err := findRelativeTerraformProjectPath() + projectPath, err := printer.findRelativeTerraformProjectPath() // Then no error should occur and the project path should be empty if err != nil { @@ -680,17 +574,16 @@ func TestTerraformEnv_findRelativeTerraformProjectPath(t *testing.T) { }) t.Run("ErrorGettingCurrentDirectory", func(t *testing.T) { - // Given a mocked getwd function returning an error - originalGetwd := getwd - getwd = func() (string, error) { + // Given a TerraformEnvPrinter with a failing Getwd function + printer, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { return "", fmt.Errorf("mock error getting current directory") } - defer func() { getwd = originalGetwd }() // When findRelativeTerraformProjectPath is called - _, err := findRelativeTerraformProjectPath() + _, err := printer.findRelativeTerraformProjectPath() - // Then the error should contain the expected message + // Then an error should be returned if err == nil { t.Errorf("Expected error, got nil") } @@ -700,25 +593,20 @@ func TestTerraformEnv_findRelativeTerraformProjectPath(t *testing.T) { }) t.Run("NoTerraformDirectoryFound", func(t *testing.T) { - // Given a mocked getwd function returning a specific directory path - originalGetwd := getwd - getwd = func() (string, error) { + // Given a TerraformEnvPrinter with no Terraform directory + printer, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { return filepath.FromSlash("/mock/path/to/project"), nil } - defer func() { getwd = originalGetwd }() - - // And a mocked glob function simulating finding Terraform files - originalGlob := glob - glob = func(pattern string) ([]string, error) { + mocks.Shims.Glob = func(pattern string) ([]string, error) { if pattern == filepath.FromSlash("/mock/path/to/project/*.tf") { return []string{filepath.FromSlash("/mock/path/to/project/main.tf")}, nil } return nil, nil } - defer func() { glob = originalGlob }() // When findRelativeTerraformProjectPath is called - projectPath, err := findRelativeTerraformProjectPath() + projectPath, err := printer.findRelativeTerraformProjectPath() // Then no error should occur and the project path should be empty if err != nil { @@ -792,51 +680,30 @@ func TestTerraformEnv_sanitizeForK8s(t *testing.T) { } func TestTerraformEnv_generateBackendOverrideTf(t *testing.T) { - t.Run("Success", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) { - return filepath.FromSlash("/mock/config/root"), nil - } - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "local", - }, - }, - } - } + setup := func(t *testing.T) (*TerraformEnvPrinter, *Mocks) { + t.Helper() + mocks := setupTerraformEnvMocks(t) + printer := NewTerraformEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + return printer, mocks + } - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil - } - // And a mocked glob function simulating finding Terraform files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - expectedPattern := filepath.FromSlash("/mock/project/root/terraform/project/path/*.tf") - if pattern == expectedPattern { - return []string{filepath.FromSlash("/mock/project/root/terraform/project/path/main.tf")}, nil - } - return nil, nil - } + t.Run("Success", func(t *testing.T) { + // Given a TerraformEnvPrinter with mock configuration + printer, mocks := setup(t) - // And a mocked writeFile function to capture the output + // Mock WriteFile to capture the output var writtenData []byte - originalWriteFile := writeFile - defer func() { writeFile = originalWriteFile }() - writeFile = func(filename string, data []byte, perm os.FileMode) error { + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { writtenData = data return nil } // When generateBackendOverrideTf is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.generateBackendOverrideTf() + err := printer.generateBackendOverrideTf() // Then no error should occur and the expected backend config should be written if err != nil { @@ -852,46 +719,18 @@ func TestTerraformEnv_generateBackendOverrideTf(t *testing.T) { }) t.Run("S3Backend", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "terraform.backend.type" { - return "s3" - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil - } - // And a mocked glob function simulating finding Terraform files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - if pattern == filepath.FromSlash("/mock/project/root/terraform/project/path/*.tf") { - return []string{filepath.FromSlash("/mock/project/root/terraform/project/path/main.tf")}, nil - } - return nil, nil - } + // Given a TerraformEnvPrinter with S3 backend configuration + printer, mocks := setup(t) + mocks.ConfigHandler.SetContextValue("terraform.backend.type", "s3") - // And a mocked writeFile function to capture the output var writtenData []byte - originalWriteFile := writeFile - defer func() { writeFile = originalWriteFile }() - writeFile = func(filename string, data []byte, perm os.FileMode) error { + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { writtenData = data return nil } // When generateBackendOverrideTf is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.generateBackendOverrideTf() + err := printer.generateBackendOverrideTf() // Then no error should occur and the expected backend config should be written if err != nil { @@ -907,46 +746,18 @@ func TestTerraformEnv_generateBackendOverrideTf(t *testing.T) { }) t.Run("KubernetesBackend", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "terraform.backend.type" { - return "kubernetes" - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil - } - // And a mocked glob function simulating finding Terraform files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - if pattern == filepath.FromSlash("/mock/project/root/terraform/project/path/*.tf") { - return []string{filepath.FromSlash("/mock/project/root/terraform/project/path/main.tf")}, nil - } - return nil, nil - } + // Given a TerraformEnvPrinter with Kubernetes backend configuration + printer, mocks := setup(t) + mocks.ConfigHandler.SetContextValue("terraform.backend.type", "kubernetes") - // And a mocked writeFile function to capture the output var writtenData []byte - originalWriteFile := writeFile - defer func() { writeFile = originalWriteFile }() - writeFile = func(filename string, data []byte, perm os.FileMode) error { + mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { writtenData = data return nil } // When generateBackendOverrideTf is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.generateBackendOverrideTf() + err := printer.generateBackendOverrideTf() // Then no error should occur and the expected backend config should be written if err != nil { @@ -962,39 +773,14 @@ func TestTerraformEnv_generateBackendOverrideTf(t *testing.T) { }) t.Run("UnsupportedBackend", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "terraform.backend.type" { - return "unsupported" - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil - } - // And a mocked glob function simulating finding Terraform files - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { - if pattern == filepath.FromSlash("/mock/project/root/terraform/project/path/*.tf") { - return []string{filepath.FromSlash("/mock/project/root/terraform/project/path/main.tf")}, nil - } - return nil, nil - } + // Given a TerraformEnvPrinter with unsupported backend configuration + printer, mocks := setup(t) + mocks.ConfigHandler.SetContextValue("terraform.backend.type", "unsupported") // When generateBackendOverrideTf is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.generateBackendOverrideTf() + err := printer.generateBackendOverrideTf() - // Then the error should contain the expected message + // Then an error should be returned if err == nil { t.Errorf("Expected error, got nil") } @@ -1004,34 +790,14 @@ func TestTerraformEnv_generateBackendOverrideTf(t *testing.T) { }) t.Run("NoTerraformFiles", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "local", - }, - }, - } - } - - // Given a mocked getwd function simulating being in a terraform project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return filepath.FromSlash("/mock/project/root/terraform/project/path"), nil - } - // And a mocked glob function simulating no Terraform files found - originalGlob := glob - defer func() { glob = originalGlob }() - glob = func(pattern string) ([]string, error) { + // Given a TerraformEnvPrinter with no Terraform files + printer, mocks := setup(t) + mocks.Shims.Glob = func(pattern string) ([]string, error) { return nil, nil } // When generateBackendOverrideTf is called - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - err := terraformEnvPrinter.generateBackendOverrideTf() + err := printer.generateBackendOverrideTf() // Then no error should occur if err != nil { @@ -1041,22 +807,33 @@ func TestTerraformEnv_generateBackendOverrideTf(t *testing.T) { } func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { - t.Run("Success", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() + setup := func(t *testing.T) (*TerraformEnvPrinter, *Mocks) { + t.Helper() + mocks := setupTerraformEnvMocks(t) + printer := NewTerraformEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + return printer, mocks + } + t.Run("Success", func(t *testing.T) { + // Given a TerraformEnvPrinter with mock configuration + printer, _ := setup(t) projectPath := "project/path" - configRoot := filepath.FromSlash("/mock/config/root") + configRoot := "/mock/config/root" - backendConfigArgs, err := terraformEnvPrinter.generateBackendConfigArgs(projectPath, configRoot) + // When generateBackendConfigArgs is called + backendConfigArgs, err := printer.generateBackendConfigArgs(projectPath, configRoot) + + // Then no error should occur and the expected arguments should be returned if err != nil { t.Errorf("unexpected error: %v", err) } expectedArgs := []string{ - fmt.Sprintf(`-backend-config="%s"`, filepath.ToSlash(filepath.Join(configRoot, "terraform", "backend.tfvars"))), - fmt.Sprintf(`-backend-config="path=%s"`, filepath.ToSlash(filepath.Join(configRoot, ".tfstate", projectPath, "terraform.tfstate"))), + `-backend-config="path=/mock/config/root/.tfstate/project/path/terraform.tfstate"`, } if !reflect.DeepEqual(backendConfigArgs, expectedArgs) { @@ -1065,30 +842,22 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { }) t.Run("LocalBackendWithPrefix", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "terraform.backend.prefix" { - return "mock-prefix/" - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - + // Given a TerraformEnvPrinter with local backend and prefix configuration + printer, mocks := setup(t) + mocks.ConfigHandler.SetContextValue("terraform.backend.prefix", "mock-prefix/") projectPath := "project/path" - configRoot := filepath.FromSlash("/mock/config/root") + configRoot := "/mock/config/root" - backendConfigArgs, err := terraformEnvPrinter.generateBackendConfigArgs(projectPath, configRoot) + // When generateBackendConfigArgs is called + backendConfigArgs, err := printer.generateBackendConfigArgs(projectPath, configRoot) + + // Then no error should occur and the expected arguments should be returned if err != nil { t.Errorf("unexpected error: %v", err) } expectedArgs := []string{ - fmt.Sprintf(`-backend-config="%s"`, filepath.ToSlash(filepath.Join(configRoot, "terraform", "backend.tfvars"))), - fmt.Sprintf(`-backend-config="path=%s"`, filepath.ToSlash(filepath.Join(configRoot, ".tfstate", "mock-prefix", projectPath, "terraform.tfstate"))), + `-backend-config="path=/mock/config/root/.tfstate/mock-prefix/project/path/terraform.tfstate"`, } if !reflect.DeepEqual(backendConfigArgs, expectedArgs) { @@ -1097,53 +866,29 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { }) t.Run("S3BackendWithPrefix", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - S3: &terraform.S3Backend{ - Bucket: stringPtr("mock-bucket"), - Region: stringPtr("mock-region"), - SecretKey: stringPtr("mock-secret-key"), - MaxRetries: intPtr(5), - SkipCredentialsValidation: boolPtr(true), - }, - }, - }, - } - } - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "terraform.backend.type" { - return "s3" - } - if key == "terraform.backend.prefix" { - return "mock-prefix/" - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() + // Given a TerraformEnvPrinter with S3 backend and prefix configuration + printer, mocks := setup(t) + mocks.ConfigHandler.SetContextValue("terraform.backend.type", "s3") + mocks.ConfigHandler.SetContextValue("terraform.backend.prefix", "mock-prefix/") + mocks.ConfigHandler.SetContextValue("terraform.backend.s3.bucket", "mock-bucket") + mocks.ConfigHandler.SetContextValue("terraform.backend.s3.region", "mock-region") + mocks.ConfigHandler.SetContextValue("terraform.backend.s3.secret_key", "mock-secret-key") + projectPath := "project/path" + configRoot := "/mock/config/root" - projectPath := filepath.FromSlash("project/path") - configRoot := filepath.FromSlash("/mock/config/root") + // When generateBackendConfigArgs is called + backendConfigArgs, err := printer.generateBackendConfigArgs(projectPath, configRoot) - backendConfigArgs, err := terraformEnvPrinter.generateBackendConfigArgs(projectPath, configRoot) + // Then no error should occur and the expected arguments should be returned if err != nil { t.Errorf("unexpected error: %v", err) } expectedArgs := []string{ - fmt.Sprintf(`-backend-config="%s"`, filepath.ToSlash(filepath.Join(configRoot, "terraform", "backend.tfvars"))), `-backend-config="key=mock-prefix/project/path/terraform.tfstate"`, `-backend-config="bucket=mock-bucket"`, - `-backend-config="max_retries=5"`, `-backend-config="region=mock-region"`, `-backend-config="secret_key=mock-secret-key"`, - `-backend-config="skip_credentials_validation=true"`, } if !reflect.DeepEqual(backendConfigArgs, expectedArgs) { @@ -1152,44 +897,23 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { }) t.Run("KubernetesBackendWithPrefix", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Kubernetes: &terraform.KubernetesBackend{ - Namespace: stringPtr("mock-namespace"), - }, - }, - }, - } - } - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - switch key { - case "terraform.backend.type": - return "kubernetes" - case "terraform.backend.prefix": - return "mock-prefix" - default: - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - } - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() + // Given a TerraformEnvPrinter with Kubernetes backend and prefix configuration + printer, mocks := setup(t) + mocks.ConfigHandler.SetContextValue("terraform.backend.type", "kubernetes") + mocks.ConfigHandler.SetContextValue("terraform.backend.prefix", "mock-prefix") + mocks.ConfigHandler.SetContextValue("terraform.backend.kubernetes.namespace", "mock-namespace") + projectPath := "project/path" + configRoot := "/mock/config/root" - projectPath := filepath.FromSlash("project/path") - configRoot := filepath.FromSlash("/mock/config/root") + // When generateBackendConfigArgs is called + backendConfigArgs, err := printer.generateBackendConfigArgs(projectPath, configRoot) - backendConfigArgs, err := terraformEnvPrinter.generateBackendConfigArgs(projectPath, configRoot) + // Then no error should occur and the expected arguments should be returned if err != nil { t.Errorf("unexpected error: %v", err) } expectedArgs := []string{ - fmt.Sprintf(`-backend-config="%s"`, filepath.ToSlash(filepath.Join(configRoot, "terraform", "backend.tfvars"))), `-backend-config="secret_suffix=mock-prefix-project-path"`, `-backend-config="namespace=mock-namespace"`, } @@ -1200,33 +924,23 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { }) t.Run("BackendTfvarsFileExistsWithPrefix", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetContextFunc = func() string { - return "mock-context" - } - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "terraform.backend.prefix" { - return "mock-prefix/" - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() + // Given a TerraformEnvPrinter with backend tfvars file and prefix configuration + printer, mocks := setup(t) + mocks.ConfigHandler.SetContextValue("terraform.backend.prefix", "mock-prefix/") + mocks.ConfigHandler.SetContextValue("context", "mock-context") + projectPath := "project/path" + configRoot := "/mock/config/root" - projectPath := filepath.FromSlash("project/path") - configRoot := filepath.FromSlash("/mock/config/root") + // When generateBackendConfigArgs is called + backendConfigArgs, err := printer.generateBackendConfigArgs(projectPath, configRoot) - backendConfigArgs, err := terraformEnvPrinter.generateBackendConfigArgs(projectPath, configRoot) + // Then no error should occur and the expected arguments should be returned if err != nil { t.Errorf("unexpected error: %v", err) } expectedArgs := []string{ - fmt.Sprintf(`-backend-config="%s"`, filepath.ToSlash(filepath.Join(configRoot, "terraform", "backend.tfvars"))), - fmt.Sprintf(`-backend-config="path=%s"`, filepath.ToSlash(filepath.Join(configRoot, ".tfstate", "mock-prefix/project/path/terraform.tfstate"))), + `-backend-config="path=/mock/config/root/.tfstate/mock-prefix/project/path/terraform.tfstate"`, } if !reflect.DeepEqual(backendConfigArgs, expectedArgs) { @@ -1234,125 +948,84 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { } }) - t.Run("ErrorMarshallingBackendConfig", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "s3", - S3: &terraform.S3Backend{}, - }, - }, - } - } - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "terraform.backend.type" { - return "s3" - } - if len(defaultValue) > 0 { - return defaultValue[0] + t.Run("BackendTfvarsFileExists", func(t *testing.T) { + // Given a TerraformEnvPrinter with a backend.tfvars file + printer, mocks := setup(t) + mocks.ConfigHandler.SetContextValue("context", "mock-context") + projectPath := "project/path" + configRoot := "/mock/config/root" + + // Mock Stat to return nil error for backend.tfvars + backendTfvarsPath := filepath.Join(configRoot, "terraform", "backend.tfvars") + mocks.Shims.Stat = func(path string) (os.FileInfo, error) { + if path == backendTfvarsPath { + return nil, nil } - return "" + return nil, fmt.Errorf("unexpected path: %s", path) } - // Mock yamlMarshal to return an error - originalYamlMarshal := yamlMarshal - defer func() { yamlMarshal = originalYamlMarshal }() - yamlMarshal = func(v any) ([]byte, error) { - return nil, fmt.Errorf("mock marshalling error") - } + // When generateBackendConfigArgs is called + backendConfigArgs, err := printer.generateBackendConfigArgs(projectPath, configRoot) - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - - projectPath := "project/path" - configRoot := filepath.FromSlash("/mock/config/root") + // Then no error should occur and backend.tfvars should be included + if err != nil { + t.Errorf("unexpected error: %v", err) + } - _, err := terraformEnvPrinter.generateBackendConfigArgs(projectPath, configRoot) - if err == nil { - t.Errorf("expected error, got nil") + expectedArgs := []string{ + fmt.Sprintf(`-backend-config="%s"`, filepath.ToSlash(backendTfvarsPath)), + `-backend-config="path=/mock/config/root/.tfstate/project/path/terraform.tfstate"`, } - expectedErrorMsg := "error marshalling backend to YAML: mock marshalling error" - if !strings.Contains(err.Error(), expectedErrorMsg) { - t.Errorf("expected error to contain %v, got %v", expectedErrorMsg, err.Error()) + if !reflect.DeepEqual(backendConfigArgs, expectedArgs) { + t.Errorf("expected %v, got %v", expectedArgs, backendConfigArgs) } }) - t.Run("ErrorProcessingKubernetesBackendConfig", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "kubernetes", - Kubernetes: &terraform.KubernetesBackend{}, - }, - }, - } - } - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "terraform.backend.type" { - return "kubernetes" - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } + t.Run("BackendTfvarsFileDoesNotExist", func(t *testing.T) { + // Given a TerraformEnvPrinter without a backend.tfvars file + printer, mocks := setup(t) + mocks.ConfigHandler.SetContextValue("context", "mock-context") + projectPath := "project/path" + configRoot := "/mock/config/root" - // Mock processBackendConfig to return an error - originalProcessBackendConfig := processBackendConfig - defer func() { processBackendConfig = originalProcessBackendConfig }() - processBackendConfig = func(backendConfig interface{}, addArg func(key, value string)) error { - return fmt.Errorf("mock processing error") + // Mock Stat to return error for backend.tfvars + backendTfvarsPath := filepath.Join(configRoot, "terraform", "backend.tfvars") + mocks.Shims.Stat = func(path string) (os.FileInfo, error) { + if path == backendTfvarsPath { + return nil, fmt.Errorf("file not found") + } + return nil, fmt.Errorf("unexpected path: %s", path) } - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() + // When generateBackendConfigArgs is called + backendConfigArgs, err := printer.generateBackendConfigArgs(projectPath, configRoot) - projectPath := "project/path" - configRoot := filepath.FromSlash("/mock/config/root") + // Then no error should occur and backend.tfvars should not be included + if err != nil { + t.Errorf("unexpected error: %v", err) + } - _, err := terraformEnvPrinter.generateBackendConfigArgs(projectPath, configRoot) - if err == nil { - t.Errorf("expected error, got nil") + expectedArgs := []string{ + `-backend-config="path=/mock/config/root/.tfstate/project/path/terraform.tfstate"`, } - if !strings.Contains(err.Error(), "error processing Kubernetes backend config: mock processing error") { - t.Errorf("expected error to contain %v, got %v", "error processing Kubernetes backend config: mock processing error", err.Error()) + if !reflect.DeepEqual(backendConfigArgs, expectedArgs) { + t.Errorf("expected %v, got %v", expectedArgs, backendConfigArgs) } }) t.Run("UnsupportedBackendType", func(t *testing.T) { - mocks := setupSafeTerraformEnvMocks() - mocks.ConfigHandler.GetConfigFunc = func() *v1alpha1.Context { - return &v1alpha1.Context{ - Terraform: &terraform.TerraformConfig{ - Backend: &terraform.BackendConfig{ - Type: "unsupported", - }, - }, - } - } - mocks.ConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "terraform.backend.type" { - return "unsupported" - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - - terraformEnvPrinter := NewTerraformEnvPrinter(mocks.Injector) - terraformEnvPrinter.Initialize() - + // Given a TerraformEnvPrinter with unsupported backend configuration + printer, mocks := setup(t) + mocks.ConfigHandler.SetContextValue("terraform.backend.type", "unsupported") projectPath := "project/path" - configRoot := filepath.FromSlash("/mock/config/root") + configRoot := "/mock/config/root" - _, err := terraformEnvPrinter.generateBackendConfigArgs(projectPath, configRoot) + // When generateBackendConfigArgs is called + _, err := printer.generateBackendConfigArgs(projectPath, configRoot) + + // Then an error should be returned if err == nil { t.Errorf("expected error, got nil") } @@ -1364,13 +1037,27 @@ func TestTerraformEnv_generateBackendConfigArgs(t *testing.T) { } func TestTerraformEnv_processBackendConfig(t *testing.T) { + setup := func(t *testing.T) (*TerraformEnvPrinter, *Mocks) { + t.Helper() + mocks := setupTerraformEnvMocks(t) + printer := NewTerraformEnvPrinter(mocks.Injector) + printer.shims = mocks.Shims + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize printer: %v", err) + } + return printer, mocks + } + t.Run("Success", func(t *testing.T) { - backendConfig := map[string]interface{}{ + // Given a TerraformEnvPrinter with valid backend configuration + printer, _ := setup(t) + + backendConfig := map[string]any{ "key1": "value1", "key2": true, "key3": 123, - "key4": []interface{}{"item1", "item2"}, - "key5": map[string]interface{}{ + "key4": []any{"item1", "item2"}, + "key5": map[string]any{ "nestedKey1": "nestedValue1", "nestedKey2": "nestedValue2", }, @@ -1381,11 +1068,15 @@ func TestTerraformEnv_processBackendConfig(t *testing.T) { args = append(args, fmt.Sprintf("%s=%s", key, value)) } - err := processBackendConfig(backendConfig, addArg) + // When processing the backend configuration + err := printer.processBackendConfig(backendConfig, addArg) + + // Then no error should occur if err != nil { t.Errorf("unexpected error: %v", err) } + // And all configuration values should be properly formatted expectedArgs := []string{ "key1=value1", "key2=true", @@ -1405,15 +1096,14 @@ func TestTerraformEnv_processBackendConfig(t *testing.T) { }) t.Run("ErrorUnmarshallingBackendConfig", func(t *testing.T) { - originalYamlUnmarshal := yamlUnmarshal - defer func() { yamlUnmarshal = originalYamlUnmarshal }() + // Given a TerraformEnvPrinter with failing YAML unmarshalling + printer, mocks := setup(t) - yamlUnmarshal = func(data []byte, v interface{}) error { - return fmt.Errorf("mocked error") + mocks.Shims.YamlMarshal = func(v any) ([]byte, error) { + return []byte("valid yaml"), nil } - - backendConfig := map[string]interface{}{ - "key1": "value1", + mocks.Shims.YamlUnmarshal = func(data []byte, v any) error { + return fmt.Errorf("mock unmarshal error") } var args []string @@ -1421,14 +1111,18 @@ func TestTerraformEnv_processBackendConfig(t *testing.T) { args = append(args, fmt.Sprintf("%s=%s", key, value)) } - err := processBackendConfig(backendConfig, addArg) + // When processing the backend configuration + err := printer.processBackendConfig(map[string]any{"key1": "value1"}, addArg) + + // Then an error should be returned if err == nil { t.Errorf("expected error, got nil") } - expectedError := "mocked error" - if !strings.Contains(err.Error(), expectedError) { - t.Errorf("expected error to contain %v, got %v", expectedError, err.Error()) + // And the error should contain the expected message + expectedError := "error unmarshalling backend YAML: mock unmarshal error" + if err.Error() != expectedError { + t.Errorf("expected error %q, got %q", expectedError, err.Error()) } }) } diff --git a/pkg/env/windsor_env.go b/pkg/env/windsor_env.go index efd9e64e1..ba93d0624 100644 --- a/pkg/env/windsor_env.go +++ b/pkg/env/windsor_env.go @@ -1,8 +1,12 @@ +// The WindsorEnvPrinter is a specialized component that manages Windsor environment configuration. +// It provides Windsor-specific environment variable management and configuration, +// The WindsorEnvPrinter handles context, project root, and secrets management, +// ensuring proper Windsor CLI integration and environment setup for application operations. + package env import ( "fmt" - "os" "regexp" "strings" @@ -10,6 +14,11 @@ import ( "github.com/windsorcli/cli/pkg/secrets" ) +// ============================================================================= +// Constants +// ============================================================================= + +// WindsorPrefixedVars are the environment variables that are managed by Windsor. var WindsorPrefixedVars = []string{ "WINDSOR_CONTEXT", "WINDSOR_PROJECT_ROOT", @@ -18,21 +27,31 @@ var WindsorPrefixedVars = []string{ "WINDSOR_MANAGED_ALIAS", } -// WindsorEnvPrinter is a struct that simulates a Kubernetes environment for testing purposes. +// ============================================================================= +// Types +// ============================================================================= + +// WindsorEnvPrinter is a struct that implements Windsor environment configuration type WindsorEnvPrinter struct { BaseEnvPrinter secretsProviders []secrets.SecretsProvider } -// NewWindsorEnvPrinter initializes a new WindsorEnvPrinter instance using the provided dependency injector. +// ============================================================================= +// Constructor +// ============================================================================= + +// NewWindsorEnvPrinter creates a new WindsorEnvPrinter instance func NewWindsorEnvPrinter(injector di.Injector) *WindsorEnvPrinter { return &WindsorEnvPrinter{ - BaseEnvPrinter: BaseEnvPrinter{ - injector: injector, - }, + BaseEnvPrinter: *NewBaseEnvPrinter(injector), } } +// ============================================================================= +// Public Methods +// ============================================================================= + // Initialize sets up the WindsorEnvPrinter, including resolving secrets providers. func (e *WindsorEnvPrinter) Initialize() error { if err := e.BaseEnvPrinter.Initialize(); err != nil { @@ -47,11 +66,7 @@ func (e *WindsorEnvPrinter) Initialize() error { secretsProviders := make([]secrets.SecretsProvider, 0, len(instances)) for _, instance := range instances { - secretsProvider, ok := instance.(secrets.SecretsProvider) - if !ok { - return fmt.Errorf("failed to cast instance to SecretsProvider") - } - secretsProviders = append(secretsProviders, secretsProvider) + secretsProviders = append(secretsProviders, instance.(secrets.SecretsProvider)) } e.secretsProviders = secretsProviders @@ -59,7 +74,7 @@ func (e *WindsorEnvPrinter) Initialize() error { return nil } -// GetEnvVars constructs a map of Windsor-specific environment variables including +// GetEnvVars constructs a map of Windsor-specific environment variables by retrieving // the current context, project root, and session token. It resolves secrets in custom // environment variables using configured providers, handles caching of values, and // manages environment variables and aliases. For secrets, it leverages the secrets cache @@ -88,7 +103,7 @@ func (e *WindsorEnvPrinter) GetEnvVars() (map[string]string, error) { re := regexp.MustCompile(`\${{\s*(.*?)\s*}}`) - _, managedEnvExists := osLookupEnv("WINDSOR_MANAGED_ENV") + _, managedEnvExists := e.shims.LookupEnv("WINDSOR_MANAGED_ENV") for k, v := range originalEnvVars { if !managedEnvExists { @@ -96,11 +111,11 @@ func (e *WindsorEnvPrinter) GetEnvVars() (map[string]string, error) { } if re.MatchString(v) { - if existingValue, exists := osLookupEnv(k); exists { + if existingValue, exists := e.shims.LookupEnv(k); exists { if managedEnvExists { e.SetManagedEnv(k) } - if shouldUseCache() && !strings.Contains(existingValue, " 0 { - mockInjector = injector[0] - } else { - mockInjector = di.NewMockInjector() +func setupWindsorEnvMocks(t *testing.T, opts ...*SetupOptions) *Mocks { + t.Helper() + if opts == nil { + opts = []*SetupOptions{{}} + } + if opts[0].ConfigStr == "" { + opts[0].ConfigStr = ` +version: v1alpha1 +contexts: + mock-context: + environment: + TEST_VAR: test_value + SECRET_VAR: "{{secret_name}}" +` } + mocks := setupMocks(t, opts[0]) - mockConfigHandler := config.NewMockConfigHandler() - mockConfigHandler.GetConfigRootFunc = func() (string, error) { - return filepath.FromSlash("/mock/config/root"), nil + // Get the temp dir that was set up in setupMocks + projectRoot, err := mocks.Shell.GetProjectRoot() + if err != nil { + t.Fatalf("Failed to get project root: %v", err) } - mockConfigHandler.GetContextFunc = func() string { - return "mock-context" + + // Set up shims for Windsor operations + mocks.Shims.LookupEnv = func(key string) (string, bool) { + // Use os.LookupEnv to get the real environment variables + val, ok := os.LookupEnv(key) + return val, ok } - mockShell := shell.NewMockShell() - mockShell.GetProjectRootFunc = func() (string, error) { - return filepath.FromSlash("/mock/project/root"), nil + // Mock GetSessionToken + mocks.Shell.GetSessionTokenFunc = func() (string, error) { + return "mock-token", nil } - // Default behavior for GetSessionToken that matches test expectations - mockShell.GetSessionTokenFunc = func() (string, error) { - // If WINDSOR_SESSION_TOKEN is set in the environment, check it - if envToken := os.Getenv("WINDSOR_SESSION_TOKEN"); envToken != "" { - // Check for signal file if env token exists - tokenFilePath := filepath.Join("/mock/project/root", ".windsor", ".session."+envToken) - if _, err := stat(tokenFilePath); err == nil { - // Signal file exists, generate new token - return "abcdefg", nil - } - // Signal file doesn't exist, return environment token - return envToken, nil + // Create and register mock secrets provider + mockSecretsProvider := secrets.NewMockSecretsProvider(mocks.Injector) + mockSecretsProvider.ParseSecretsFunc = func(input string) (string, error) { + if strings.Contains(input, "${{secret_name}}") { + return "parsed_secret_value", nil } - // No env token, return default mock token - return "mock-token", nil + return input, nil } + mocks.Injector.Register("secretsProvider", mockSecretsProvider) - mockInjector.Register("configHandler", mockConfigHandler) - mockInjector.Register("shell", mockShell) + t.Cleanup(func() { + os.Unsetenv("NO_CACHE") + os.Unsetenv("WINDSOR_MANAGED_ENV") + mocks.Shell.ResetSessionToken() + }) - return &WindsorEnvMocks{ - Injector: mockInjector, - ConfigHandler: mockConfigHandler, - Shell: mockShell, - } -} + // Set environment variables using the temp dir + t.Setenv("WINDSOR_CONTEXT", "mock-context") + t.Setenv("WINDSOR_PROJECT_ROOT", projectRoot) -// customMockInjector is a custom injector for testing that returns non-castable objects -type customMockInjector struct { - *di.MockInjector + return mocks } -// ResolveAll overrides the ResolveAll method to return non-castable objects -func (c *customMockInjector) ResolveAll(targetType interface{}) ([]interface{}, error) { - if _, ok := targetType.(*secrets.SecretsProvider); ok { - // Return a non-castable int - return []interface{}{123}, nil - } - return c.MockInjector.ResolveAll(targetType) -} +// ============================================================================= +// Test Public Methods +// ============================================================================= +// TestWindsorEnv_GetEnvVars tests the GetEnvVars method of the WindsorEnvPrinter func TestWindsorEnv_GetEnvVars(t *testing.T) { - originalOsLookupEnv := osLookupEnv - defer func() { - osLookupEnv = originalOsLookupEnv - }() - - // Reset session token before each test - shell.ResetSessionToken() - - // Set up mock environment variables - t.Setenv("NO_CACHE", "") + setup := func(t *testing.T) (*WindsorEnvPrinter, *Mocks) { + t.Helper() + mocks := setupWindsorEnvMocks(t) + printer := NewWindsorEnvPrinter(mocks.Injector) + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) + } + printer.shims = mocks.Shims + return printer, mocks + } t.Run("Success", func(t *testing.T) { - // Reset session token to ensure consistent behavior - shell.ResetSessionToken() - - // Given a WindsorEnvPrinter with mock dependencies - mocks := setupSafeWindsorEnvMocks() - - // Make the shell return a consistent mock token - mocks.Shell.GetSessionTokenFunc = func() (string, error) { - return "mock-token", nil - } - - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + printer, _ := setup(t) + // Given a properly initialized WindsorEnvPrinter // When GetEnvVars is called - envVars, err := windsorEnvPrinter.GetEnvVars() + envVars, err := printer.GetEnvVars() - // Then the result should not contain an error + // Then no error should be returned if err != nil { t.Fatalf("Expected no error, got %v", err) } - // And the environment variables should contain the expected values + // And environment variables should contain expected values expectedContext := "mock-context" if envVars["WINDSOR_CONTEXT"] != expectedContext { t.Errorf("Expected WINDSOR_CONTEXT to be %q, got %q", expectedContext, envVars["WINDSOR_CONTEXT"]) } - expectedProjectRoot := "/mock/project/root" - if filepath.ToSlash(envVars["WINDSOR_PROJECT_ROOT"]) != expectedProjectRoot { - t.Errorf("Expected WINDSOR_PROJECT_ROOT to be %q, got %q", expectedProjectRoot, envVars["WINDSOR_PROJECT_ROOT"]) + // And project root should be set + if envVars["WINDSOR_PROJECT_ROOT"] == "" { + t.Error("Expected WINDSOR_PROJECT_ROOT to be set") } + // And session token should be set expectedSessionToken := "mock-token" if envVars["WINDSOR_SESSION_TOKEN"] != expectedSessionToken { t.Errorf("Expected WINDSOR_SESSION_TOKEN to be %q, got %q", expectedSessionToken, envVars["WINDSOR_SESSION_TOKEN"]) } }) - t.Run("ExistingSessionToken", func(t *testing.T) { - // Reset session token to ensure consistent behavior - shell.ResetSessionToken() + t.Run("ProjectRootError", func(t *testing.T) { + printer, mocks := setup(t) + + // Given a WindsorEnvPrinter with failing project root retrieval + mocks.Shell.GetProjectRootFunc = func() (string, error) { + return "", fmt.Errorf("mock project root error") + } + + // When GetEnvVars is called + _, err := printer.GetEnvVars() + + // Then an error should be returned + if err == nil { + t.Fatal("Expected error from project root retrieval, got nil") + } + if !strings.Contains(err.Error(), "error retrieving project root") { + t.Errorf("Unexpected error message: %v", err) + } + }) + + t.Run("SecretVarWithCacheEnabled", func(t *testing.T) { + printer, mocks := setup(t) + + // Given a WindsorEnvPrinter with cache enabled + t.Setenv("NO_CACHE", "0") + t.Setenv("SECRET_VAR", "cached_value") + t.Setenv("WINDSOR_MANAGED_ENV", "") + + // And mock secrets provider + mockSecretsProvider := secrets.NewMockSecretsProvider(mocks.Injector) + mockSecretsProvider.ParseSecretsFunc = func(input string) (string, error) { + if strings.Contains(input, "{{secret_name}}") { + return "parsed_secret_value", nil + } + return input, nil + } + mocks.Injector.Register("secretsProvider", mockSecretsProvider) + + // And re-initialize printer to pick up new mock + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to re-initialize env: %v", err) + } + + // And config with secret variable + if err := mocks.ConfigHandler.LoadConfigString(` +version: v1alpha1 +contexts: + mock-context: + environment: + SECRET_VAR: "${{secret_name}}" +`); err != nil { + t.Fatalf("LoadConfigString returned error: %v", err) + } + + // When GetEnvVars is called + envVars, err := printer.GetEnvVars() + if err != nil { + t.Fatalf("GetEnvVars returned an error: %v", err) + } + + // Then the cached value should be used and the variable should not be tracked + if _, exists := envVars["SECRET_VAR"]; exists { + t.Error("Expected SECRET_VAR to not be in envVars when caching is enabled") + } + + // And it should be tracked in managed env + managedEnv := envVars["WINDSOR_MANAGED_ENV"] + if !strings.Contains(managedEnv, "SECRET_VAR") { + t.Error("Expected SECRET_VAR to be in managed env when caching is enabled") + } + }) + + t.Run("SecretVarWithCacheDisabled", func(t *testing.T) { + printer, mocks := setup(t) + + // Given a WindsorEnvPrinter with cache disabled + t.Setenv("NO_CACHE", "1") + t.Setenv("SECRET_VAR", "cached_value") + + // And config with secret variable + if err := mocks.ConfigHandler.LoadConfigString(` +version: v1alpha1 +contexts: + mock-context: + environment: + SECRET_VAR: "${{secret_name}}" +`); err != nil { + t.Fatalf("LoadConfigString returned error: %v", err) + } + + // When GetEnvVars is called + envVars, err := printer.GetEnvVars() + if err != nil { + t.Fatalf("GetEnvVars returned an error: %v", err) + } + + // Then the cached value should not be used + if envVars["SECRET_VAR"] == "cached_value" { + t.Error("Expected SECRET_VAR to not use cached value") + } + }) + + t.Run("SecretVarWithErrorInExistingValue", func(t *testing.T) { + printer, mocks := setup(t) + + // Given a WindsorEnvPrinter with error in existing value + t.Setenv("NO_CACHE", "0") + t.Setenv("SECRET_VAR", "secret error") + + // And config with secret variable + if err := mocks.ConfigHandler.LoadConfigString(` +version: v1alpha1 +contexts: + mock-context: + environment: + SECRET_VAR: "${{secret_name}}" +`); err != nil { + t.Fatalf("LoadConfigString returned error: %v", err) + } + + // When GetEnvVars is called + envVars, err := printer.GetEnvVars() + if err != nil { + t.Fatalf("GetEnvVars returned an error: %v", err) + } + + // Then the cached value should not be used + if envVars["SECRET_VAR"] == "secret error" { + t.Error("Expected SECRET_VAR to not use cached error value") + } + }) - mocks := setupSafeWindsorEnvMocks() + t.Run("SecretVarWithManagedEnvExists", func(t *testing.T) { + printer, mocks := setup(t) - // Setup mock shell to simulate token regeneration - mockShell := shell.NewMockShell() + // Given a WindsorEnvPrinter with managed env exists + t.Setenv("NO_CACHE", "0") + t.Setenv("SECRET_VAR", "cached_value") + t.Setenv("WINDSOR_MANAGED_ENV", "SECRET_VAR") + + // And config with secret variable + if err := mocks.ConfigHandler.LoadConfigString(` +version: v1alpha1 +contexts: + mock-context: + environment: + SECRET_VAR: "${{secret_name}}" +`); err != nil { + t.Fatalf("LoadConfigString returned error: %v", err) + } + + // When GetEnvVars is called + envVars, err := printer.GetEnvVars() + if err != nil { + t.Fatalf("GetEnvVars returned an error: %v", err) + } + + // Then the variable should be in managed env + managedEnv := envVars["WINDSOR_MANAGED_ENV"] + if !strings.Contains(managedEnv, "SECRET_VAR") { + t.Error("Expected SECRET_VAR to be in managed env") + } + }) + + t.Run("ExistingSessionToken", func(t *testing.T) { + printer, mocks := setup(t) + + // Given a WindsorEnvPrinter with token regeneration var callCount int - mockShell.GetSessionTokenFunc = func() (string, error) { + mocks.Shell.GetSessionTokenFunc = func() (string, error) { callCount++ if callCount == 1 { return "first-token", nil } return "regenerated-token", nil } - mocks.Injector.Register("shell", mockShell) - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() - - // First call - envVars, err := windsorEnvPrinter.GetEnvVars() + // When GetEnvVars is called twice + envVars, err := printer.GetEnvVars() if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } + + // Then the first token should be returned firstToken := envVars["WINDSOR_SESSION_TOKEN"] if firstToken != "first-token" { t.Errorf("Expected first token to be 'first-token', got %s", firstToken) } - // Second call - envVars, err = windsorEnvPrinter.GetEnvVars() + // And when called again + envVars, err = printer.GetEnvVars() if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } + + // Then a new token should be generated secondToken := envVars["WINDSOR_SESSION_TOKEN"] if secondToken != "regenerated-token" { t.Errorf("Expected second token to be 'regenerated-token', got %s", secondToken) @@ -176,20 +323,17 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { }) t.Run("SessionTokenError", func(t *testing.T) { - mocks := setupSafeWindsorEnvMocks() + printer, mocks := setup(t) - // Setup mock shell to simulate token error - mockShell := shell.NewMockShell() - mockShell.GetSessionTokenFunc = func() (string, error) { + // Given a WindsorEnvPrinter with failing session token generation + mocks.Shell.GetSessionTokenFunc = func() (string, error) { return "", fmt.Errorf("mock session token error") } - mocks.Injector.Register("shell", mockShell) - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // When GetEnvVars is called + _, err := printer.GetEnvVars() - // Call should fail with session token error - _, err := windsorEnvPrinter.GetEnvVars() + // Then an error should be returned if err == nil { t.Fatal("Expected error from session token generation, got nil") } @@ -199,42 +343,46 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { }) t.Run("NoEnvironmentVarsInConfig", func(t *testing.T) { - // Reset managed environment variables and aliases - windsorManagedMu.Lock() - windsorManagedEnv = []string{} - windsorManagedAlias = []string{} - windsorManagedMu.Unlock() - - mocks := setupSafeWindsorEnvMocks() - - // Set GetStringMap to return an empty map to simulate no environment vars in config - mocks.ConfigHandler.GetStringMapFunc = func(key string, defaultValue ...map[string]string) map[string]string { - if key == "environment" { - return map[string]string{} - } - return map[string]string{} + printer, mocks := setup(t) + + // Given a WindsorEnvPrinter with empty environment configuration + projectRoot, err := mocks.Shell.GetProjectRoot() + if err != nil { + t.Fatalf("Failed to get project root: %v", err) } - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // And no managed environment variables or aliases + printer.managedEnv = []string{} + printer.managedAlias = []string{} + + // And empty environment map in config + if err := mocks.ConfigHandler.LoadConfigString(` +version: v1alpha1 +contexts: + test-context: + environment: {} +`); err != nil { + t.Fatalf("LoadConfigString returned error: %v", err) + } - envVars, err := windsorEnvPrinter.GetEnvVars() + // When GetEnvVars is called + envVars, err := printer.GetEnvVars() if err != nil { t.Fatalf("GetEnvVars should not return an error: %v", err) } - // Verify we still have the base environment variables + // Then base environment variables should be set if envVars["WINDSOR_CONTEXT"] != "mock-context" { t.Errorf("WINDSOR_CONTEXT should be set even when no environment vars are in config") } - if filepath.ToSlash(envVars["WINDSOR_PROJECT_ROOT"]) != "/mock/project/root" { - t.Errorf("WINDSOR_PROJECT_ROOT should be set") + if envVars["WINDSOR_PROJECT_ROOT"] != projectRoot { + t.Errorf("WINDSOR_PROJECT_ROOT = %q, want %q", envVars["WINDSOR_PROJECT_ROOT"], projectRoot) } if envVars["WINDSOR_SESSION_TOKEN"] == "" { t.Errorf("Session token should be generated") } - // Verify no additional variables were added from config (since there were none) + // And no additional variables should be added t.Logf("Environment variables: %v", envVars) if len(envVars) != 5 { t.Errorf("Should have five base environment variables") @@ -242,105 +390,114 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { }) t.Run("EnvironmentTokenWithoutSignalFile", func(t *testing.T) { - mocks := setupSafeWindsorEnvMocks() + printer, mocks := setup(t) - // Set up environment variable with a token + // Given a WindsorEnvPrinter with environment token t.Setenv("WINDSOR_SESSION_TOKEN", "envtoken") - // Mock stat to simulate no signal file - stat = func(name string) (os.FileInfo, error) { + // And no signal file exists + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { return nil, os.ErrNotExist } - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // And GetSessionToken configured to handle environment token + mocks.Shell.GetSessionTokenFunc = func() (string, error) { + if envToken, exists := mocks.Shims.LookupEnv("WINDSOR_SESSION_TOKEN"); exists { + return envToken, nil + } + return "mock-token", nil + } - envVars, err := windsorEnvPrinter.GetEnvVars() + // When GetEnvVars is called + envVars, err := printer.GetEnvVars() if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - // Verify the environment token is used + // Then the environment token should be used if envVars["WINDSOR_SESSION_TOKEN"] != "envtoken" { t.Errorf("Expected session token to be 'envtoken', got %s", envVars["WINDSOR_SESSION_TOKEN"]) } }) t.Run("EnvironmentTokenWithStatError", func(t *testing.T) { - mocks := setupSafeWindsorEnvMocks() + printer, mocks := setup(t) - // Set up environment variable with a token + // Given a WindsorEnvPrinter with environment token t.Setenv("WINDSOR_SESSION_TOKEN", "envtoken") - // Mock stat to return an error that is not os.ErrNotExist - stat = func(name string) (os.FileInfo, error) { + // And stat returns a non-ErrNotExist error + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { return nil, fmt.Errorf("mock stat error") } - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // And GetSessionToken configured to handle environment token + mocks.Shell.GetSessionTokenFunc = func() (string, error) { + if envToken, exists := mocks.Shims.LookupEnv("WINDSOR_SESSION_TOKEN"); exists { + return envToken, nil + } + return "mock-token", nil + } - envVars, err := windsorEnvPrinter.GetEnvVars() + // When GetEnvVars is called + envVars, err := printer.GetEnvVars() if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - // Verify the environment token is used, since the error is not specifically ErrNotExist + // Then the environment token should be used if envVars["WINDSOR_SESSION_TOKEN"] != "envtoken" { t.Errorf("Expected session token to be 'envtoken', got %s", envVars["WINDSOR_SESSION_TOKEN"]) } }) t.Run("EnvironmentTokenWithSignalFile", func(t *testing.T) { - // Mock file system functions - originalStat := stat - defer func() { - stat = originalStat - }() + printer, mocks := setup(t) + + // Given a WindsorEnvPrinter with environment token + t.Setenv("WINDSOR_SESSION_TOKEN", "envtoken") - stat = func(name string) (os.FileInfo, error) { + // And signal file exists + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { if strings.Contains(name, ".session.envtoken") { return nil, nil // File exists } return nil, os.ErrNotExist } - originalOsRemoveAll := osRemoveAll - defer func() { - osRemoveAll = originalOsRemoveAll - }() - - osRemoveAll = func(path string) error { + // And RemoveAll succeeds + mocks.Shims.RemoveAll = func(path string) error { return nil } - // Mock crypto functions for predictable output - origCryptoRandRead := cryptoRandRead - defer func() { - cryptoRandRead = origCryptoRandRead - }() - - cryptoRandRead = func(b []byte) (n int, err error) { + // And CryptoRandRead returns predictable output + mocks.Shims.CryptoRandRead = func(b []byte) (n int, err error) { for i := range b { b[i] = byte(i % 62) // Will map to characters in charset } return len(b), nil } - mocks := setupSafeWindsorEnvMocks() - - // Set up environment variable with a token - t.Setenv("WINDSOR_SESSION_TOKEN", "envtoken") - - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // And GetSessionToken configured to handle environment token and signal file + mocks.Shell.GetSessionTokenFunc = func() (string, error) { + if envToken, exists := mocks.Shims.LookupEnv("WINDSOR_SESSION_TOKEN"); exists { + tokenFilePath := filepath.Join("/mock/project/root", ".windsor", ".session."+envToken) + if _, err := mocks.Shims.Stat(tokenFilePath); err == nil { + // Signal file exists, generate new token + return "abcdefg", nil + } + return envToken, nil + } + return "mock-token", nil + } - envVars, err := windsorEnvPrinter.GetEnvVars() + // When GetEnvVars is called + envVars, err := printer.GetEnvVars() if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - // Verify a new token was generated (should be "abcdefg" per our mock) + // Then a new token should be generated if envVars["WINDSOR_SESSION_TOKEN"] == "envtoken" { t.Errorf("Expected a new token to be generated, but got the environment token") } @@ -350,64 +507,52 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { }) t.Run("SignalFileRemovalError", func(t *testing.T) { - // Mock file system functions - originalStat := stat - defer func() { - stat = originalStat - }() + printer, mocks := setup(t) - stat = func(name string) (os.FileInfo, error) { + // Given a WindsorEnvPrinter with environment token + t.Setenv("WINDSOR_SESSION_TOKEN", "envtoken") + + // And signal file exists + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { if strings.Contains(name, ".session.envtoken") { return nil, nil // File exists } return nil, os.ErrNotExist } - // Mock osRemoveAll to return an error - originalOsRemoveAll := osRemoveAll - defer func() { - osRemoveAll = originalOsRemoveAll - }() - - osRemoveAll = func(path string) error { + // And RemoveAll fails + mocks.Shims.RemoveAll = func(path string) error { return fmt.Errorf("mock error removing signal file") } - // Mock crypto functions for predictable output - origCryptoRandRead := cryptoRandRead - defer func() { - cryptoRandRead = origCryptoRandRead - }() - - cryptoRandRead = func(b []byte) (n int, err error) { + // And CryptoRandRead returns predictable output + mocks.Shims.CryptoRandRead = func(b []byte) (n int, err error) { for i := range b { b[i] = byte(i % 62) // Will map to characters in charset } return len(b), nil } - mocks := setupSafeWindsorEnvMocks() - - // Set up environment variable with a token - t.Setenv("WINDSOR_SESSION_TOKEN", "envtoken") - - // We'll redirect stdout to discard any error output - origStdout := os.Stdout - os.Stdout = os.NewFile(0, os.DevNull) - defer func() { - os.Stdout = origStdout - }() - - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // And GetSessionToken configured to handle environment token and signal file + mocks.Shell.GetSessionTokenFunc = func() (string, error) { + if envToken, exists := mocks.Shims.LookupEnv("WINDSOR_SESSION_TOKEN"); exists { + tokenFilePath := filepath.Join("/mock/project/root", ".windsor", ".session."+envToken) + if _, err := mocks.Shims.Stat(tokenFilePath); err == nil { + // Signal file exists, generate new token + return "abcdefg", nil + } + return envToken, nil + } + return "mock-token", nil + } - // Call should not fail (error is deferred and printed to stdout) - envVars, err := windsorEnvPrinter.GetEnvVars() + // When GetEnvVars is called + envVars, err := printer.GetEnvVars() if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } - // Verify a new token was generated (should be "abcdefg" per our mock) + // Then a new token should be generated despite removal error if envVars["WINDSOR_SESSION_TOKEN"] == "envtoken" { t.Errorf("Expected a new token to be generated, but got the environment token") } @@ -417,30 +562,23 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { }) t.Run("ProjectRootErrorDuringEnvTokenSignalFileCheck", func(t *testing.T) { - // Set up environment variable with a token - t.Setenv("WINDSOR_SESSION_TOKEN", "envtoken") - - mocks := setupSafeWindsorEnvMocks() + printer, mocks := setup(t) - // First call succeeds, second fails (during token check) - var callCount int - mocks.Shell.GetProjectRootFunc = func() (string, error) { - callCount++ - return filepath.FromSlash("/mock/project/root"), nil - } + // Given a WindsorEnvPrinter with environment token + t.Setenv("WINDSOR_SESSION_TOKEN", "envtoken") - // Make GetSessionToken return an error during the signal file check + // And GetSessionToken returns an error during signal file check mocks.Shell.GetSessionTokenFunc = func() (string, error) { - if envToken := os.Getenv("WINDSOR_SESSION_TOKEN"); envToken != "" { + if _, exists := mocks.Shims.LookupEnv("WINDSOR_SESSION_TOKEN"); exists { return "", fmt.Errorf("error getting project root: mock error getting project root during token check") } return "mock-token", nil } - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // When GetEnvVars is called + _, err := printer.GetEnvVars() - _, err := windsorEnvPrinter.GetEnvVars() + // Then an error should be returned if err == nil { t.Fatal("Expected error, got nil") } @@ -452,31 +590,24 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { }) t.Run("RandomGenerationError", func(t *testing.T) { - // Mock file system functions - originalStat := stat - defer func() { - stat = originalStat - }() + printer, mocks := setup(t) - stat = func(name string) (os.FileInfo, error) { + // Given a WindsorEnvPrinter with environment token + t.Setenv("WINDSOR_SESSION_TOKEN", "envtoken") + + // And signal file exists + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { if strings.Contains(name, ".session.envtoken") { return nil, nil // File exists } return nil, os.ErrNotExist } - // Set up environment variable to trigger token regeneration - t.Setenv("WINDSOR_SESSION_TOKEN", "envtoken") - - mocks := setupSafeWindsorEnvMocks() - - // Make the shell's GetSessionToken return an error when regenerating token + // And GetSessionToken returns an error during token regeneration mocks.Shell.GetSessionTokenFunc = func() (string, error) { - // Check if we are being called for the environment token check - if envToken := os.Getenv("WINDSOR_SESSION_TOKEN"); envToken != "" { - // We are doing the check for the environment token + if envToken, exists := mocks.Shims.LookupEnv("WINDSOR_SESSION_TOKEN"); exists { tokenFilePath := filepath.Join("/mock/project/root", ".windsor", ".session."+envToken) - if _, err := stat(tokenFilePath); err == nil { + if _, err := mocks.Shims.Stat(tokenFilePath); err == nil { // Signal file exists, mock error during regeneration return "", fmt.Errorf("mock random generation error during token regeneration") } @@ -485,11 +616,10 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { return "mock-token", nil } - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // When GetEnvVars is called + _, err := printer.GetEnvVars() - // This should trigger an error in token regeneration - _, err := windsorEnvPrinter.GetEnvVars() + // Then an error should be returned if err == nil { t.Fatal("Expected error from random generation during token regeneration, got nil") } @@ -499,15 +629,17 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { }) t.Run("GetProjectRootError", func(t *testing.T) { - mocks := setupSafeWindsorEnvMocks() + printer, mocks := setup(t) + + // Given a WindsorEnvPrinter with failing project root lookup mocks.Shell.GetProjectRootFunc = func() (string, error) { return "", fmt.Errorf("mock shell error") } - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // When GetEnvVars is called + _, err := printer.GetEnvVars() - _, err := windsorEnvPrinter.GetEnvVars() + // Then an error should be returned expectedErrorMessage := "error retrieving project root: mock shell error" if err == nil || err.Error() != expectedErrorMessage { t.Errorf("Expected error %q, got %v", expectedErrorMessage, err) @@ -515,23 +647,23 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { }) t.Run("ProjectRootErrorDuringTokenCheck", func(t *testing.T) { - mocks := setupSafeWindsorEnvMocks() + printer, mocks := setup(t) - // Set up environment variable with a token to trigger the token check code path + // Given a WindsorEnvPrinter with environment token t.Setenv("WINDSOR_SESSION_TOKEN", "envtoken") - // Make GetSessionToken return an error during the project root check + // And GetSessionToken returns an error during project root check mocks.Shell.GetSessionTokenFunc = func() (string, error) { - if envToken := os.Getenv("WINDSOR_SESSION_TOKEN"); envToken != "" { + if _, exists := mocks.Shims.LookupEnv("WINDSOR_SESSION_TOKEN"); exists { return "", fmt.Errorf("error getting project root: mock shell error during token check") } return "mock-token", nil } - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // When GetEnvVars is called + _, err := printer.GetEnvVars() - _, err := windsorEnvPrinter.GetEnvVars() + // Then an error should be returned if err == nil { t.Fatal("Expected error, got nil") } @@ -543,42 +675,44 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { }) t.Run("ComprehensiveEnvironmentTokenTest", func(t *testing.T) { - // Mock file system functions to handle various cases - originalStat := stat - defer func() { - stat = originalStat - }() + printer, mocks := setup(t) - stat = func(name string) (os.FileInfo, error) { + // Given a WindsorEnvPrinter with mock file system functions + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { if strings.Contains(name, ".session.testtoken") { return nil, nil // Session file exists } return nil, os.ErrNotExist } - mocks := setupSafeWindsorEnvMocks() - // Phase 1: No environment token present - // Should generate a new token - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // Given no environment token + mocks.Shims.LookupEnv = func(key string) (string, bool) { + return "", false + } - envVars, err := windsorEnvPrinter.GetEnvVars() + // When GetEnvVars is called + envVars, err := printer.GetEnvVars() if err != nil { t.Fatalf("GetEnvVars returned an error: %v", err) } firstToken := envVars["WINDSOR_SESSION_TOKEN"] // Phase 2: Set environment token - // The mock should return this token since no signal file exists for it - t.Setenv("WINDSOR_SESSION_TOKEN", "testtoken") + // Given environment token is set + mocks.Shims.LookupEnv = func(key string) (string, bool) { + if key == "WINDSOR_SESSION_TOKEN" { + return "testtoken", true + } + return "", false + } - // Update the mock to handle the testtoken case + // And GetSessionToken configured to handle testtoken mocks.Shell.GetSessionTokenFunc = func() (string, error) { - if envToken := os.Getenv("WINDSOR_SESSION_TOKEN"); envToken != "" { + if envToken, exists := mocks.Shims.LookupEnv("WINDSOR_SESSION_TOKEN"); exists { // Our testtoken has a signal file tokenFilePath := filepath.Join("/mock/project/root", ".windsor", ".session."+envToken) - if _, err := stat(tokenFilePath); err == nil { + if _, err := mocks.Shims.Stat(tokenFilePath); err == nil { return "newtoken", nil // Return a different token to show regeneration } return envToken, nil @@ -586,12 +720,13 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { return "mock-token", nil } - // Now the test token should come back - envVars, err = windsorEnvPrinter.GetEnvVars() + // When GetEnvVars is called again + envVars, err = printer.GetEnvVars() if err != nil { t.Fatalf("GetEnvVars returned an error in phase 2: %v", err) } + // Then a new token should be generated secondToken := envVars["WINDSOR_SESSION_TOKEN"] if secondToken != "newtoken" { t.Errorf("Expected token 'newtoken', got %q", secondToken) @@ -601,1184 +736,356 @@ func TestWindsorEnv_GetEnvVars(t *testing.T) { t.Errorf("Second token %q should be different from the first token %q", secondToken, firstToken) } }) +} - t.Run("RandomErrorDuringSignalFileRegeneration", func(t *testing.T) { - // Mock file system functions - originalStat := stat - defer func() { - stat = originalStat - }() +// TestWindsorEnv_PostEnvHook tests the PostEnvHook method of the WindsorEnvPrinter +func TestWindsorEnv_PostEnvHook(t *testing.T) { + t.Run("Success", func(t *testing.T) { + // Given a WindsorEnvPrinter + injector := di.NewMockInjector() + windsorEnvPrinter := NewWindsorEnvPrinter(injector) - stat = func(name string) (os.FileInfo, error) { - if strings.Contains(name, ".session.testtoken") { - return nil, nil // File exists - } - return nil, os.ErrNotExist + // When PostEnvHook is called + err := windsorEnvPrinter.PostEnvHook() + + // Then no error should be returned + if err != nil { + t.Errorf("PostEnvHook() returned an error: %v", err) } + }) +} - // Set up environment variable to trigger token regeneration - t.Setenv("WINDSOR_SESSION_TOKEN", "testtoken") +// TestWindsorEnv_Print tests the Print method of the WindsorEnvPrinter +func TestWindsorEnv_Print(t *testing.T) { + setup := func(t *testing.T) (*WindsorEnvPrinter, *Mocks) { + t.Helper() + mocks := setupWindsorEnvMocks(t) + printer := NewWindsorEnvPrinter(mocks.Injector) + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) + } + printer.shims = mocks.Shims + return printer, mocks + } - mocks := setupSafeWindsorEnvMocks() + t.Run("Success", func(t *testing.T) { + printer, mocks := setup(t) - // Make the shell's GetSessionToken return an error when regenerating token - mocks.Shell.GetSessionTokenFunc = func() (string, error) { - // Check if we are being called for the environment token check - if envToken := os.Getenv("WINDSOR_SESSION_TOKEN"); envToken != "" { - // We are doing the check for the environment token - tokenFilePath := filepath.Join("/mock/project/root", ".windsor", ".session."+envToken) - if _, err := stat(tokenFilePath); err == nil { - // Signal file exists, mock error during regeneration - return "", fmt.Errorf("mock random generation error during token regeneration") - } - return envToken, nil - } - return "mock-token", nil + // Given a WindsorEnvPrinter with project root + projectRoot, err := mocks.Shell.GetProjectRoot() + if err != nil { + t.Fatalf("Failed to get project root: %v", err) + } + + // And a mock PrintEnvVars function + var capturedEnvVars map[string]string + mocks.Shell.PrintEnvVarsFunc = func(envVars map[string]string) { + capturedEnvVars = envVars } - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - windsorEnvPrinter.Initialize() + // When Print is called + err = printer.Print() - // This should trigger an error in token regeneration - _, err := windsorEnvPrinter.GetEnvVars() - if err == nil { - t.Fatal("Expected error from random generation during token regeneration, got nil") + // Then no error should be returned + if err != nil { + t.Errorf("unexpected error: %v", err) } - if !strings.Contains(err.Error(), "error retrieving session token") { - t.Errorf("Unexpected error message: %v", err) + + // And core Windsor environment variables should be set correctly + if capturedEnvVars["WINDSOR_CONTEXT"] != "mock-context" { + t.Errorf("WINDSOR_CONTEXT = %q, want %q", capturedEnvVars["WINDSOR_CONTEXT"], "mock-context") } - }) - t.Run("DifferentContextDisablesCache", func(t *testing.T) { - mocks := setupSafeWindsorEnvMocks() - - // Set up test environment - envVarKey := "TEST_VAR_WITH_SECRET" - envVarValue := "value with ${{ secrets.mySecret }}" - - // Save original environment values and restore them after test - originalEnvContext := os.Getenv("WINDSOR_CONTEXT") - originalEnvToken := os.Getenv("WINDSOR_SESSION_TOKEN") - originalTestVar := os.Getenv(envVarKey) - originalNoCache := os.Getenv("NO_CACHE") - - // Setting NO_CACHE=true should disable the cache - t.Setenv("NO_CACHE", "true") - t.Setenv("WINDSOR_CONTEXT", "different-context") - t.Setenv("WINDSOR_SESSION_TOKEN", "") - t.Setenv(envVarKey, "existing-value") - - defer func() { - os.Setenv("WINDSOR_CONTEXT", originalEnvContext) - os.Setenv("WINDSOR_SESSION_TOKEN", originalEnvToken) - os.Setenv(envVarKey, originalTestVar) - os.Setenv("NO_CACHE", originalNoCache) - }() - - // Set up mock config handler to return environment variables - mocks.ConfigHandler.GetStringMapFunc = func(key string, defaultValue ...map[string]string) map[string]string { - if key == "environment" { - return map[string]string{ - envVarKey: envVarValue, - } - } - return map[string]string{} + if capturedEnvVars["WINDSOR_PROJECT_ROOT"] != projectRoot { + t.Errorf("WINDSOR_PROJECT_ROOT = %q, want %q", capturedEnvVars["WINDSOR_PROJECT_ROOT"], projectRoot) } - // Mock secrets provider that will be called regardless of cache - mockSecretsProvider := secrets.NewMockSecretsProvider(mocks.Injector) - mockSecretsProvider.ParseSecretsFunc = func(input string) (string, error) { - if input == envVarValue { - return "resolved-value", nil - } - return input, nil + if capturedEnvVars["WINDSOR_SESSION_TOKEN"] == "" { + t.Errorf("WINDSOR_SESSION_TOKEN is empty") } - // Create WindsorEnvPrinter with mock injector - mockInjector := mocks.Injector - mockInjector.Register("secretsProvider", mockSecretsProvider) - windsorEnvPrinter := NewWindsorEnvPrinter(mockInjector) - err := windsorEnvPrinter.Initialize() - if err != nil { - t.Fatalf("Initialize should not return an error: %v", err) + // And WINDSOR_MANAGED_ENV should include core Windsor variables + managedEnv := capturedEnvVars["WINDSOR_MANAGED_ENV"] + coreVars := []string{"WINDSOR_CONTEXT", "WINDSOR_PROJECT_ROOT", "WINDSOR_SESSION_TOKEN", "WINDSOR_MANAGED_ENV", "WINDSOR_MANAGED_ALIAS"} + for _, v := range coreVars { + if !strings.Contains(managedEnv, v) { + t.Errorf("WINDSOR_MANAGED_ENV should contain %q, got %q", v, managedEnv) + } } + }) - // Make secretsProviders accessible to the test - windsorEnvPrinter.secretsProviders = []secrets.SecretsProvider{mockSecretsProvider} + t.Run("GetProjectRootError", func(t *testing.T) { + printer, mocks := setup(t) - // Get environment variables - envVars, err := windsorEnvPrinter.GetEnvVars() - if err != nil { - t.Fatalf("GetEnvVars should not return an error: %v", err) + // Given a WindsorEnvPrinter with failing project root lookup + mocks.Shell.GetProjectRootFunc = func() (string, error) { + return "", fmt.Errorf("mock project root error") } - // Verify the variable was resolved despite having an existing value in the environment - // This confirms that NO_CACHE=true worked as expected - if envVars[envVarKey] != "resolved-value" { - t.Errorf("Environment variable should be resolved even with existing value when NO_CACHE=true") + // When Print is called + err := printer.Print() + + // Then an error should be returned + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "mock project root error") { + t.Errorf("unexpected error message: %v", err) } }) +} - t.Run("NoCacheEnvVarDisablesCache", func(t *testing.T) { - mocks := setupSafeWindsorEnvMocks() +// TestWindsorEnv_Initialize tests the Initialize method of the WindsorEnvPrinter +func TestWindsorEnv_Initialize(t *testing.T) { + setup := func(t *testing.T) (*WindsorEnvPrinter, *Mocks) { + t.Helper() + mocks := setupWindsorEnvMocks(t) + printer := NewWindsorEnvPrinter(mocks.Injector) + return printer, mocks + } - // Override random generation to avoid token generation errors - origCryptoRandRead := cryptoRandRead - cryptoRandRead = func(b []byte) (n int, err error) { - for i := range b { - b[i] = byte(i%26) + 'a' // Generate predictable letters - } - return len(b), nil - } - defer func() { - cryptoRandRead = origCryptoRandRead - }() - - // Set up test environment variables - envVarKey := "TEST_VAR_WITH_SECRET" - envVarValue := "value with ${{ secrets.mySecret }}" - - // Save original environment values and restore them after test - originalEnvContext := os.Getenv("WINDSOR_CONTEXT") - originalEnvToken := os.Getenv("WINDSOR_SESSION_TOKEN") - originalTestVar := os.Getenv(envVarKey) - originalNoCache := os.Getenv("NO_CACHE") - - // Setting NO_CACHE=true should disable the cache - t.Setenv("NO_CACHE", "true") - t.Setenv("WINDSOR_CONTEXT", "") // Use same context to test NO_CACHE specifically - t.Setenv("WINDSOR_SESSION_TOKEN", "") - t.Setenv(envVarKey, "existing-value-should-be-ignored") - - defer func() { - os.Setenv("WINDSOR_CONTEXT", originalEnvContext) - os.Setenv("WINDSOR_SESSION_TOKEN", originalEnvToken) - os.Setenv(envVarKey, originalTestVar) - os.Setenv("NO_CACHE", originalNoCache) - }() - - // Configure mock config handler - mocks.ConfigHandler.GetStringMapFunc = func(key string, defaultValue ...map[string]string) map[string]string { - if key == "environment" { - return map[string]string{ - envVarKey: envVarValue, - } - } - return map[string]string{} - } + t.Run("Success", func(t *testing.T) { + printer, _ := setup(t) - // Mock secrets provider that will resolve the secret - mockSecretsProvider := secrets.NewMockSecretsProvider(mocks.Injector) - mockSecretsProvider.ParseSecretsFunc = func(input string) (string, error) { - if input == envVarValue { - return "resolved-value", nil - } - return input, nil - } + // When Initialize is called + err := printer.Initialize() - // Create WindsorEnvPrinter with mock injector - mockInjector := mocks.Injector - mockInjector.Register("secretsProvider", mockSecretsProvider) - windsorEnvPrinter := NewWindsorEnvPrinter(mockInjector) - err := windsorEnvPrinter.Initialize() + // Then no error should be returned if err != nil { - t.Fatalf("Initialize should not return an error: %v", err) + t.Fatalf("Initialize returned error: %v", err) } - // Make secretsProviders accessible to the test - windsorEnvPrinter.secretsProviders = []secrets.SecretsProvider{mockSecretsProvider} + // And secretsProviders should be populated + if len(printer.secretsProviders) != 1 { + t.Errorf("Expected 1 secrets provider, got %d", len(printer.secretsProviders)) + } + }) - // Get environment variables - envVars, err := windsorEnvPrinter.GetEnvVars() - if err != nil { - t.Fatalf("GetEnvVars should not return an error: %v", err) - } - - // Verify the variable was resolved despite having an existing value in the environment - // This confirms that NO_CACHE=true worked as expected - if envVars[envVarKey] != "resolved-value" { - t.Errorf("Environment variable should be resolved even with existing value when NO_CACHE=true") - } - }) - - t.Run("RegularEnvironmentVarsWithoutSecrets", func(t *testing.T) { - mocks := setupSafeWindsorEnvMocks() - - // Override random generation to avoid token generation errors - origCryptoRandRead := cryptoRandRead - cryptoRandRead = func(b []byte) (n int, err error) { - for i := range b { - b[i] = byte(i%26) + 'a' // Generate predictable letters - } - return len(b), nil - } - defer func() { - cryptoRandRead = origCryptoRandRead - }() - - // Set up test environment variables with regular values (no secret placeholders) - regularVarKey1 := "REGULAR_ENV_VAR1" - regularVarValue1 := "regular value 1" - regularVarKey2 := "REGULAR_ENV_VAR2" - regularVarValue2 := "regular value 2" - - // Save original environment values - originalEnvContext := os.Getenv("WINDSOR_CONTEXT") - originalEnvToken := os.Getenv("WINDSOR_SESSION_TOKEN") - - // Clean environment for test - t.Setenv("WINDSOR_CONTEXT", "") - t.Setenv("WINDSOR_SESSION_TOKEN", "") - - defer func() { - os.Setenv("WINDSOR_CONTEXT", originalEnvContext) - os.Setenv("WINDSOR_SESSION_TOKEN", originalEnvToken) - }() - - // Configure mock config handler with regular environment variables - mocks.ConfigHandler.GetStringMapFunc = func(key string, defaultValue ...map[string]string) map[string]string { - if key == "environment" { - return map[string]string{ - regularVarKey1: regularVarValue1, - regularVarKey2: regularVarValue2, - "WITH_SECRET": "${{ secrets.mySecret }}", // Include one with secret to test both branches - } - } - return map[string]string{} - } - - // Mock secrets provider - mockSecretsProvider := secrets.NewMockSecretsProvider(mocks.Injector) - mockSecretsProvider.ParseSecretsFunc = func(input string) (string, error) { - if input == "${{ secrets.mySecret }}" { - return "resolved-secret", nil - } - return input, nil - } - - // Create WindsorEnvPrinter with mock injector - mockInjector := mocks.Injector - mockInjector.Register("secretsProvider", mockSecretsProvider) - windsorEnvPrinter := NewWindsorEnvPrinter(mockInjector) - err := windsorEnvPrinter.Initialize() - if err != nil { - t.Fatalf("Initialize should not return an error: %v", err) - } - - // Make secretsProviders accessible to the test - windsorEnvPrinter.secretsProviders = []secrets.SecretsProvider{mockSecretsProvider} - - // Get environment variables - envVars, err := windsorEnvPrinter.GetEnvVars() - if err != nil { - t.Fatalf("GetEnvVars should not return an error: %v", err) - } - - // Verify the regular variables were set directly without parsing - if envVars[regularVarKey1] != regularVarValue1 { - t.Errorf("Regular environment variable should be set directly") - } - if envVars[regularVarKey2] != regularVarValue2 { - t.Errorf("Regular environment variable should be set directly") - } - - // Also verify that the secret was parsed correctly - if envVars["WITH_SECRET"] != "resolved-secret" { - t.Errorf("Environment variable with secret should be resolved") - } - }) - - t.Run("ManagedCustomEnvironmentVars", func(t *testing.T) { - // Save original values - originalManagedEnv := make([]string, len(windsorManagedEnv)) - copy(originalManagedEnv, windsorManagedEnv) - - // Save original environment values - originalEnvContext := os.Getenv("WINDSOR_CONTEXT") - originalEnvToken := os.Getenv("WINDSOR_SESSION_TOKEN") - originalNoCache := os.Getenv("NO_CACHE") - - // Set environment variables for test - ensure NO_CACHE is unset - t.Setenv("WINDSOR_CONTEXT", "") - t.Setenv("WINDSOR_SESSION_TOKEN", "") - t.Setenv("NO_CACHE", "") - - // Restore original environment variables after test - defer func() { - os.Setenv("WINDSOR_CONTEXT", originalEnvContext) - os.Setenv("WINDSOR_SESSION_TOKEN", originalEnvToken) - os.Setenv("NO_CACHE", originalNoCache) - }() - - // Restore original state after test - defer func() { - windsorManagedMu.Lock() - windsorManagedEnv = originalManagedEnv - windsorManagedMu.Unlock() - }() - - // Setup mocks - mocks := setupSafeWindsorEnvMocks() - - // Override random generation to avoid token generation errors - origCryptoRandRead := cryptoRandRead - cryptoRandRead = func(b []byte) (n int, err error) { - for i := range b { - b[i] = byte(i%26) + 'a' // Generate predictable letters - } - return len(b), nil - } - defer func() { - cryptoRandRead = origCryptoRandRead - }() - - // Set up mock config handler to return custom environment variables - mocks.ConfigHandler.GetStringMapFunc = func(key string, defaultValue ...map[string]string) map[string]string { - if key == "environment" { - return map[string]string{ - "CUSTOM_ENV_VAR1": "value1", - "CUSTOM_ENV_VAR2": "value2", - } - } - return map[string]string{} - } + t.Run("ResolveAllError", func(t *testing.T) { + // Given a WindsorEnvPrinter with failing ResolveAll + injector := di.NewMockInjector() + setupWindsorEnvMocks(t, &SetupOptions{ + Injector: injector, + }) - // Track custom variables - windsorManagedMu.Lock() - windsorManagedEnv = []string{"CUSTOM_ENV_VAR1", "CUSTOM_ENV_VAR2"} - windsorManagedMu.Unlock() + // And error set for resolving secrets providers + injector.SetResolveAllError((*secrets.SecretsProvider)(nil), fmt.Errorf("mock error")) - // Create WindsorEnvPrinter and initialize it - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - err := windsorEnvPrinter.Initialize() - if err != nil { - t.Fatalf("Failed to initialize WindsorEnvPrinter: %v", err) - } + // When Initialize is called + printer := NewWindsorEnvPrinter(injector) + err := printer.Initialize() - // Get environment variables - envVars, err := windsorEnvPrinter.GetEnvVars() - if err != nil { - t.Fatalf("GetEnvVars should not return an error: %v", err) - } - - // Verify custom variables are in the environment variables map - if envVars["CUSTOM_ENV_VAR1"] != "value1" { - t.Errorf("CUSTOM_ENV_VAR1 should be set to 'value1'") - } - if envVars["CUSTOM_ENV_VAR2"] != "value2" { - t.Errorf("CUSTOM_ENV_VAR2 should be set to 'value2'") - } - - // Verify that WINDSOR_MANAGED_ENV includes our custom variables and Windsor prefixed vars - managedEnvList := envVars["WINDSOR_MANAGED_ENV"] - expectedVars := []string{ - "CUSTOM_ENV_VAR1", - "CUSTOM_ENV_VAR2", - "WINDSOR_CONTEXT", - "WINDSOR_PROJECT_ROOT", - "WINDSOR_SESSION_TOKEN", - "WINDSOR_MANAGED_ENV", - "WINDSOR_MANAGED_ALIAS", + // Then an error should be returned + if err == nil { + t.Fatal("Expected error, got nil") } - for _, v := range expectedVars { - if !strings.Contains(managedEnvList, v) { - t.Errorf("WINDSOR_MANAGED_ENV should contain %s", v) - } + if !strings.Contains(err.Error(), "failed to resolve secrets providers") { + t.Errorf("Unexpected error message: %v", err) } }) +} - t.Run("ErrorValueBypassesCache", func(t *testing.T) { - // Save original functions - originalCryptoRandRead := cryptoRandRead - - // Restore original function after test - defer func() { - cryptoRandRead = originalCryptoRandRead - }() - - // Set up test mocks - mocks := setupSafeWindsorEnvMocks() - - // Override random generation to avoid token generation errors - cryptoRandRead = func(b []byte) (n int, err error) { - for i := range b { - b[i] = byte(i%26) + 'a' // Generate predictable letters - } - return len(b), nil - } +// TestWindsorEnv_ParseAndCheckSecrets tests the parseAndCheckSecrets method +func TestWindsorEnv_ParseAndCheckSecrets(t *testing.T) { + setup := func(t *testing.T) (*WindsorEnvPrinter, *Mocks) { + t.Helper() + mocks := setupWindsorEnvMocks(t) + printer := NewWindsorEnvPrinter(mocks.Injector) + if err := printer.Initialize(); err != nil { + t.Fatalf("Failed to initialize env: %v", err) + } + printer.shims = mocks.Shims + return printer, mocks + } - // Set up test environment variables - errorVarKey := "TEST_VAR_WITH_ERROR" - normalVarKey := "TEST_VAR_NORMAL" - errorVarValue := "value with ${{ secrets.errorSecret }}" - normalVarValue := "value with ${{ secrets.normalSecret }}" - - // Save original environment values - originalEnvContext := os.Getenv("WINDSOR_CONTEXT") - originalEnvToken := os.Getenv("WINDSOR_SESSION_TOKEN") - originalErrorVar := os.Getenv(errorVarKey) - originalNormalVar := os.Getenv(normalVarKey) - originalNoCache := os.Getenv("NO_CACHE") - - // Explicitly set NO_CACHE=false to enable caching - os.Setenv("NO_CACHE", "false") - os.Setenv("WINDSOR_CONTEXT", "") - os.Setenv("WINDSOR_SESSION_TOKEN", "") - - // Set the existing values - one with error and one normal but with a pattern that will be cached - os.Setenv(errorVarKey, "") - os.Setenv(normalVarKey, "cached-normal-value") - - // Restore original values after test - defer func() { - os.Setenv("WINDSOR_CONTEXT", originalEnvContext) - os.Setenv("WINDSOR_SESSION_TOKEN", originalEnvToken) - os.Setenv(errorVarKey, originalErrorVar) - os.Setenv(normalVarKey, originalNormalVar) - os.Setenv("NO_CACHE", originalNoCache) - }() - - // Verify caching is enabled in this test - if !shouldUseCache() { - t.Fatalf("shouldUseCache() returned false, expected true with NO_CACHE=false") - } - - // Configure mock config handler - mocks.ConfigHandler.GetStringMapFunc = func(key string, defaultValue ...map[string]string) map[string]string { - if key == "environment" { - return map[string]string{ - errorVarKey: errorVarValue, - normalVarKey: normalVarValue, - } - } - return map[string]string{} - } + t.Run("Success", func(t *testing.T) { + printer, mocks := setup(t) - // Mock secrets provider that will resolve the secrets + // Given a mock secrets provider that successfully parses secrets mockSecretsProvider := secrets.NewMockSecretsProvider(mocks.Injector) mockSecretsProvider.ParseSecretsFunc = func(input string) (string, error) { - if input == errorVarValue { - return "resolved-error-value", nil - } - if input == normalVarValue { - return "resolved-normal-value", nil + if input == "value with ${{ secrets.mySecret }}" { + return "value with resolved-secret", nil } return input, nil } + printer.secretsProviders = []secrets.SecretsProvider{mockSecretsProvider} - // Create WindsorEnvPrinter with mock injector - mockInjector := mocks.Injector - mockInjector.Register("secretsProvider", mockSecretsProvider) - windsorEnvPrinter := NewWindsorEnvPrinter(mockInjector) - - // Initialize the WindsorEnvPrinter - err := windsorEnvPrinter.Initialize() - if err != nil { - t.Fatalf("Initialize returned error: %v", err) - } - - // Make secretsProviders accessible to the test - windsorEnvPrinter.secretsProviders = []secrets.SecretsProvider{mockSecretsProvider} + // When parseAndCheckSecrets is called + result := printer.parseAndCheckSecrets("value with ${{ secrets.mySecret }}") - // Get environment variables - envVars, err := windsorEnvPrinter.GetEnvVars() - if err != nil { - t.Fatalf("GetEnvVars returned error: %v", err) - } - - // Check that the variable with was re-resolved - // despite caching being enabled - if got, want := envVars[errorVarKey], "resolved-error-value"; got != want { - t.Errorf("Environment variable with was not properly re-resolved: got %q, want %q", got, want) - } - - // Check that we got the expected keys in the results - expectedKeys := []string{ - "WINDSOR_CONTEXT", - "WINDSOR_PROJECT_ROOT", - "WINDSOR_SESSION_TOKEN", - "WINDSOR_MANAGED_ENV", - "WINDSOR_MANAGED_ALIAS", - errorVarKey, - } - - // Verify all expected keys are present - for _, key := range expectedKeys { - if _, exists := envVars[key]; !exists { - t.Errorf("Expected key %q missing from results", key) - } - } - - // Normal variable should not be in results because it's cached - if _, exists := envVars[normalVarKey]; exists { - t.Errorf("Cached normal variable %q should not be in results", normalVarKey) - } - }) - - t.Run("ManagedEnv", func(t *testing.T) { - // Setup mocks - mocks := setupSafeWindsorEnvMocks() - - // Create a test environment - env := NewWindsorEnvPrinter(mocks.Injector) - env.Initialize() - - // Set up managed environment - env.SetManagedEnv("test-env") - - // Get environment variables - vars, err := env.GetEnvVars() - - // Verify the result - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - // Verify managed environment contains the test-env and Windsor prefixed vars - expectedVars := append([]string{"test-env"}, WindsorPrefixedVars...) - managedEnvVars := strings.Split(vars["WINDSOR_MANAGED_ENV"], ",") - - for _, expected := range expectedVars { - found := false - for _, actual := range managedEnvVars { - if actual == expected { - found = true - break - } - } - if !found { - t.Errorf("Expected WINDSOR_MANAGED_ENV to contain %q", expected) - } + // Then the secret should be resolved + if result != "value with resolved-secret" { + t.Errorf("Expected 'value with resolved-secret', got %q", result) } }) - t.Run("CachedVariableAddedToManagedEnv", func(t *testing.T) { - // Setup mocks - mocks := setupSafeWindsorEnvMocks() - - // Set up test environment variables - cachedVarKey := "CACHED_VAR" - cachedVarValue := "value with ${{ secrets.cachedSecret }}" - secretVarKey := "SECRET_VAR" - secretVarValue := "value with ${{ secrets.mySecret }}" - - // Save original environment values - originalEnvContext := os.Getenv("WINDSOR_CONTEXT") - originalEnvToken := os.Getenv("WINDSOR_SESSION_TOKEN") - originalCachedVar := os.Getenv(cachedVarKey) - originalSecretVar := os.Getenv(secretVarKey) - originalNoCache := os.Getenv("NO_CACHE") - - // Set up environment with cached variable - t.Setenv("NO_CACHE", "false") - t.Setenv("WINDSOR_CONTEXT", "") - t.Setenv("WINDSOR_SESSION_TOKEN", "") - t.Setenv(cachedVarKey, "cached-value") - - defer func() { - os.Setenv("WINDSOR_CONTEXT", originalEnvContext) - os.Setenv("WINDSOR_SESSION_TOKEN", originalEnvToken) - os.Setenv(cachedVarKey, originalCachedVar) - os.Setenv(secretVarKey, originalSecretVar) - os.Setenv("NO_CACHE", originalNoCache) - }() - - // Configure mock config handler - mocks.ConfigHandler.GetStringMapFunc = func(key string, defaultValue ...map[string]string) map[string]string { - if key == "environment" { - return map[string]string{ - cachedVarKey: cachedVarValue, - secretVarKey: secretVarValue, - } - } - return map[string]string{} - } + t.Run("SecretsProviderError", func(t *testing.T) { + printer, mocks := setup(t) - // Mock secrets provider + // Given a mock secrets provider that fails to parse secrets mockSecretsProvider := secrets.NewMockSecretsProvider(mocks.Injector) mockSecretsProvider.ParseSecretsFunc = func(input string) (string, error) { - if input == secretVarValue { - return "resolved-secret", nil - } - if input == cachedVarValue { - return "resolved-cached", nil - } - return input, nil - } - - // Create WindsorEnvPrinter with mock injector - mockInjector := mocks.Injector - mockInjector.Register("secretsProvider", mockSecretsProvider) - windsorEnvPrinter := NewWindsorEnvPrinter(mockInjector) - err := windsorEnvPrinter.Initialize() - if err != nil { - t.Fatalf("Initialize returned error: %v", err) - } - - // Make secretsProviders accessible to the test - windsorEnvPrinter.secretsProviders = []secrets.SecretsProvider{mockSecretsProvider} - - // Get environment variables - envVars, err := windsorEnvPrinter.GetEnvVars() - if err != nil { - t.Fatalf("GetEnvVars returned error: %v", err) + return "", fmt.Errorf("error parsing secrets") } + printer.secretsProviders = []secrets.SecretsProvider{mockSecretsProvider} - // Verify cached variable is not in returned environment variables - if _, exists := envVars[cachedVarKey]; exists { - t.Errorf("Cached variable %q should not be in returned environment variables", cachedVarKey) - } + // When parseAndCheckSecrets is called + result := printer.parseAndCheckSecrets("value with ${{ secrets.mySecret }}") - // Verify secret variable is in returned environment variables - if envVars[secretVarKey] != "resolved-secret" { - t.Errorf("Secret variable should be resolved and returned") + // Then an error message should be returned + if !strings.Contains(result, "', got %q", result) } }) - t.Run("ExistingVariableInManagedEnv", func(t *testing.T) { - // Reset managed environment variables - windsorManagedMu.Lock() - windsorManagedEnv = []string{} - windsorManagedAlias = []string{} - windsorManagedMu.Unlock() - - // Setup mocks - mocks := setupSafeWindsorEnvMocks() - - // Set up test environment variables - envVarKey := "EXISTING_VAR" - envVarValue := "regular value" - - // Save original environment values - originalEnvContext := os.Getenv("WINDSOR_CONTEXT") - originalEnvToken := os.Getenv("WINDSOR_SESSION_TOKEN") - originalEnvVar := os.Getenv(envVarKey) - originalManagedEnv := os.Getenv("WINDSOR_MANAGED_ENV") - originalNoCache := os.Getenv("NO_CACHE") - - // Set up environment with variable already in managed env - t.Setenv("NO_CACHE", "false") - t.Setenv("WINDSOR_CONTEXT", "") - t.Setenv("WINDSOR_SESSION_TOKEN", "") - t.Setenv(envVarKey, "existing-value") - t.Setenv("WINDSOR_MANAGED_ENV", envVarKey) - - defer func() { - os.Setenv("WINDSOR_CONTEXT", originalEnvContext) - os.Setenv("WINDSOR_SESSION_TOKEN", originalEnvToken) - os.Setenv(envVarKey, originalEnvVar) - os.Setenv("WINDSOR_MANAGED_ENV", originalManagedEnv) - os.Setenv("NO_CACHE", originalNoCache) - }() - - // Configure mock config handler - mocks.ConfigHandler.GetStringMapFunc = func(key string, defaultValue ...map[string]string) map[string]string { - if key == "environment" { - return map[string]string{ - envVarKey: envVarValue, - } - } - return map[string]string{} - } + t.Run("UnparsedSecrets", func(t *testing.T) { + printer, mocks := setup(t) - // Create WindsorEnvPrinter with mock injector - windsorEnvPrinter := NewWindsorEnvPrinter(mocks.Injector) - err := windsorEnvPrinter.Initialize() - if err != nil { - t.Fatalf("Initialize returned error: %v", err) + // Given a mock secrets provider that doesn't recognize secrets + mockSecretsProvider := secrets.NewMockSecretsProvider(mocks.Injector) + mockSecretsProvider.ParseSecretsFunc = func(input string) (string, error) { + return input, nil } + printer.secretsProviders = []secrets.SecretsProvider{mockSecretsProvider} - // Set managed environment variable - windsorEnvPrinter.SetManagedEnv(envVarKey) - - // Get environment variables - envVars, err := windsorEnvPrinter.GetEnvVars() - if err != nil { - t.Fatalf("GetEnvVars returned error: %v", err) - } + // When parseAndCheckSecrets is called + result := printer.parseAndCheckSecrets("value with ${{ secrets.mySecret }}") - // Verify variable is in returned environment variables - if envVars[envVarKey] != envVarValue { - t.Errorf("Variable should be in returned environment variables with value %q, got %q", envVarValue, envVars[envVarKey]) + // Then an error message should be returned + if !strings.Contains(result, "', got %q", result) + // Then it should return true + if !shouldCache { + t.Error("Expected shouldUseCache to return true for NO_CACHE=0") } }) - t.Run("UnparsedSecrets", func(t *testing.T) { - // Setup - mockInjector := di.NewMockInjector() - mockSecretsProvider := secrets.NewMockSecretsProvider(mockInjector) - // This provider doesn't recognize the secret pattern - mockSecretsProvider.ParseSecretsFunc = func(input string) (string, error) { - return input, nil + t.Run("NoCacheFalse", func(t *testing.T) { + // Given NO_CACHE environment variable is set to "false" + printer, mocks := setup(t) + mocks.Shims.LookupEnv = func(key string) (string, bool) { + if key == "NO_CACHE" { + return "false", true + } + return "", false } - windsorEnv := NewWindsorEnvPrinter(mockInjector) - windsorEnv.secretsProviders = []secrets.SecretsProvider{mockSecretsProvider} - - // Call the method with a string containing a secret - result := windsorEnv.parseAndCheckSecrets("value with ${{ secrets.mySecret }}") + // When shouldUseCache is called + shouldCache := printer.shouldUseCache() - // Verify result - if !strings.Contains(result, " 0 { s.UnsetEnvs(managedEnvs) } - if len(managedAliases) > 0 { s.UnsetAlias(managedAliases) } } -// generateRandomString creates a secure random string of the given length using a predefined charset. -func (s *DefaultShell) generateRandomString(length int) (string, error) { - const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - randomBytes := make([]byte, length) - - _, err := randRead(randomBytes) - if err != nil { - return "", err - } - - // Map random bytes to charset - for i, b := range randomBytes { - randomBytes[i] = charset[b%byte(len(charset))] - } - - return string(randomBytes), nil -} - // CheckResetFlags checks if a reset signal file exists for the current session token. // It returns true if the specific session token file exists and always removes all .session.* files. func (s *DefaultShell) CheckResetFlags() (bool, error) { // Get current session token from environment - envToken := osGetenv("WINDSOR_SESSION_TOKEN") + envToken := s.shims.Getenv("WINDSOR_SESSION_TOKEN") if envToken == "" { return false, nil } @@ -561,18 +559,17 @@ func (s *DefaultShell) CheckResetFlags() (bool, error) { // Check for the specific session token file tokenFileExists := false - if _, err := osStat(tokenFilePath); err == nil { + if _, err := s.shims.Stat(tokenFilePath); err == nil { tokenFileExists = true } - // Remove all .session.* files - sessionFiles, err := filepathGlob(filepath.Join(windsorDir, SessionTokenPrefix+"*")) + sessionFiles, err := s.shims.Glob(filepath.Join(windsorDir, SessionTokenPrefix+"*")) if err != nil { return false, fmt.Errorf("error finding session files: %w", err) } for _, file := range sessionFiles { - if err := osRemoveAll(file); err != nil { + if err := s.shims.RemoveAll(file); err != nil { return false, fmt.Errorf("error removing session file %s: %w", file, err) } } @@ -580,5 +577,31 @@ func (s *DefaultShell) CheckResetFlags() (bool, error) { return tokenFileExists, nil } +// ============================================================================= +// Private Methods +// ============================================================================= + +// generateRandomString creates a secure random string of the given length using a predefined charset. +func (s *DefaultShell) generateRandomString(length int) (string, error) { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, length) + if _, err := s.shims.RandRead(b); err != nil { + return "", err + } + for i := range b { + b[i] = charset[int(b[i])%len(charset)] + } + return string(b), nil +} + +// ============================================================================= +// Public Functions +// ============================================================================= + +// ResetSessionToken resets the session token - used primarily for testing +func (s *DefaultShell) ResetSessionToken() { + s.sessionToken = "" +} + // Ensure DefaultShell implements the Shell interface var _ Shell = (*DefaultShell)(nil) diff --git a/pkg/shell/shell_test.go b/pkg/shell/shell_test.go index 18508c79d..bf7155a74 100644 --- a/pkg/shell/shell_test.go +++ b/pkg/shell/shell_test.go @@ -3,1019 +3,1368 @@ package shell import ( "bufio" "bytes" - "errors" "fmt" "io" "os" "os/exec" "path/filepath" - "runtime" "strings" - "sync" "testing" "text/template" "github.com/windsorcli/cli/pkg/di" ) -// Test utilities for shell tests +// The ShellTest is a test suite for the Shell interface and its implementations. +// It provides comprehensive test coverage for shell operations, command execution, +// project root detection, and environment management. +// The ShellTest acts as a validation framework for shell functionality, +// ensuring reliable command execution, proper error handling, and environment isolation. + +// ============================================================================= +// Test Setup +// ============================================================================= + +// Mock functions for testing +var ( + Command func(name string, args ...string) *exec.Cmd + CmdStart func(cmd *exec.Cmd) error + CmdWait func(cmd *exec.Cmd) error + CmdRun func(cmd *exec.Cmd) error + StdoutPipe func(cmd *exec.Cmd) (io.ReadCloser, error) + StderrPipe func(cmd *exec.Cmd) (io.ReadCloser, error) + Getwd func() (string, error) + Stat func(name string) (os.FileInfo, error) +) -// setupShellTest creates a new DefaultShell for testing -func setupShellTest(t *testing.T) *DefaultShell { - // Create a new injector - injector := di.NewInjector() - // Create a new default shell - shell := NewDefaultShell(injector) - // Initialize the shell - err := shell.Initialize() - if err != nil { - t.Fatalf("Failed to initialize shell: %v", err) - } - return shell +type Mocks struct { + Injector di.Injector + Shims *Shims + TmpDir string +} + +type SetupOptions struct { + Injector di.Injector } -// Helper function to test random string generation -func testRandomStringGeneration(t *testing.T, shell *DefaultShell, length int) { +// setupMocks creates a new set of mocks for testing +func setupMocks(t *testing.T) *Mocks { t.Helper() - // Generate a string of the specified length - token, err := shell.generateRandomString(length) + // Create temp dir + tmpDir := t.TempDir() - // Verify no errors - if err != nil { - t.Fatalf("generateRandomString() error: %v", err) + // Create shims with mock implementations + shims := NewShims() + + // Mock command execution with proper cleanup + shims.Command = func(name string, args ...string) *exec.Cmd { + cmd := exec.Command("echo", "test") + cmd.Stdout = new(bytes.Buffer) + cmd.Stderr = new(bytes.Buffer) + return cmd + } + + // Mock command execution methods with proper cleanup + shims.CmdStart = func(cmd *exec.Cmd) error { + if cmd.Stdout != nil { + if w, ok := cmd.Stdout.(io.Writer); ok { + if _, err := w.Write([]byte("test\n")); err != nil { + return fmt.Errorf("failed to write to stdout: %v", err) + } + } + } + return nil + } + + shims.CmdWait = func(cmd *exec.Cmd) error { + return nil + } + + shims.CmdRun = func(cmd *exec.Cmd) error { + if cmd.Stdout != nil { + if w, ok := cmd.Stdout.(io.Writer); ok { + if _, err := w.Write([]byte("test\n")); err != nil { + return fmt.Errorf("failed to write to stdout: %v", err) + } + } + } + return nil + } + + // Mock pipes with proper cleanup + shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + if _, err := w.Write([]byte("test output\n")); err != nil { + t.Errorf("Failed to write to stdout pipe: %v", err) + } + w.Close() + }() + return r, nil + } + + shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + if _, err := w.Write([]byte("error\n")); err != nil { + t.Errorf("Failed to write to stderr pipe: %v", err) + } + w.Close() + }() + return r, nil + } + + // Mock file operations + shims.Getwd = func() (string, error) { + return "/test/dir", nil + } + + shims.Stat = func(name string) (os.FileInfo, error) { + if name == "trusted_dirs" { + return nil, os.ErrNotExist + } + return nil, nil + } + + // Mock file operations with proper cleanup + shims.OpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { + return os.NewFile(0, "test"), nil + } + + shims.WriteFile = func(name string, data []byte, perm os.FileMode) error { + return nil + } + + shims.ReadFile = func(name string) ([]byte, error) { + return []byte("test\n"), nil + } + + shims.MkdirAll = func(path string, perm os.FileMode) error { + return nil + } + + shims.Remove = func(name string) error { + return nil + } + + shims.RemoveAll = func(path string) error { + return nil + } + + shims.Chdir = func(dir string) error { + return nil + } + + shims.Setenv = func(key, value string) error { + return nil } - // Verify correct length - if len(token) != length { - t.Errorf("Expected token to have length %d, got %d", length, len(token)) + shims.Getenv = func(key string) string { + return "" } - // Check that token only contains expected characters - validChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - for _, char := range token { - if !strings.ContainsRune(validChars, char) { - t.Errorf("Token contains unexpected character: %c", char) + shims.UserHomeDir = func() (string, error) { + return "/home/test", nil + } + + // Mock random operations with proper cleanup + shims.RandRead = func(b []byte) (n int, err error) { + for i := range b { + b[i] = byte(i % 62) } + return len(b), nil + } + + // Mock template operations with proper cleanup + shims.NewTemplate = func(name string) *template.Template { + return template.New(name) + } + + shims.TemplateParse = func(tmpl *template.Template, text string) (*template.Template, error) { + return tmpl.Parse(text) + } + + shims.TemplateExecute = func(tmpl *template.Template, wr io.Writer, data any) error { + return nil + } + + shims.ExecuteTemplate = func(tmpl *template.Template, data interface{}) error { + return nil + } + + // Mock bufio operations with proper cleanup + shims.ScannerErr = func(scanner *bufio.Scanner) error { + return nil + } + + shims.NewWriter = func(w io.Writer) *bufio.Writer { + return bufio.NewWriter(w) + } + + // Mock filepath operations + shims.Glob = func(pattern string) ([]string, error) { + return []string{"/test/dir/test"}, nil + } + + shims.Join = func(elem ...string) string { + return filepath.Join(elem...) + } + + shims.ScannerText = func(scanner *bufio.Scanner) string { + return scanner.Text() + } + + return &Mocks{ + Injector: di.NewMockInjector(), + Shims: shims, + TmpDir: tmpDir, } } +// ============================================================================= +// Test Public Methods +// ============================================================================= + func TestShell_Initialize(t *testing.T) { t.Run("Success", func(t *testing.T) { - injector := di.NewInjector() - - // Given a DefaultShell instance - shell := NewDefaultShell(injector) + // Given a shell + shell := NewDefaultShell(nil) - // When calling Initialize + // When initializing the shell err := shell.Initialize() - // Then no error should be returned + // Then it should succeed if err != nil { - t.Errorf("Initialize() error = %v, wantErr %v", err, false) + t.Errorf("Expected no error, got %v", err) } }) } func TestShell_SetVerbosity(t *testing.T) { - t.Run("Set to True", func(t *testing.T) { + t.Run("Success", func(t *testing.T) { + // Given a shell shell := NewDefaultShell(nil) + + // When setting verbosity to true shell.SetVerbosity(true) + + // Then verbosity should be true if !shell.verbose { t.Fatalf("Expected verbosity to be true, got false") } }) - t.Run("Set to False", func(t *testing.T) { + t.Run("DisableVerbosity", func(t *testing.T) { + // Given a shell shell := NewDefaultShell(nil) + + // When setting verbosity to false shell.SetVerbosity(false) + + // Then verbosity should be false if shell.verbose { t.Fatalf("Expected verbosity to be false, got true") } }) } -func TestShell_GetProjectRoot(t *testing.T) { - t.Run("Cached", func(t *testing.T) { - injector := di.NewInjector() +// ============================================================================= +// Test Private Methods +// ============================================================================= - // Given a temporary directory with a cached project root - rootDir := createTempDir(t, "project-root") - subDir := filepath.Join(rootDir, "subdir") - if err := os.Mkdir(subDir, 0755); err != nil { - t.Fatalf("Failed to create subdir: %v", err) - } +func TestShell_GetProjectRoot(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } - changeDir(t, subDir) + t.Run("Success", func(t *testing.T) { + // Given a shell with project root set + shell, _ := setup(t) + shell.projectRoot = "/test/root" - // When calling GetProjectRoot - shell := NewDefaultShell(injector) - shell.projectRoot = rootDir // Simulate cached project root - cachedProjectRoot, err := shell.GetProjectRoot() + // When getting the project root + root, err := shell.GetProjectRoot() - // Then the cached project root should be returned without error + // Then it should return the expected root if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Errorf("Expected no error, got %v", err) } - - // Normalize paths for Windows compatibility - expectedRootDir := normalizePath(rootDir) - cachedProjectRoot = normalizePath(cachedProjectRoot) - - if expectedRootDir != cachedProjectRoot { - t.Errorf("Expected cached project root %q, got %q", expectedRootDir, cachedProjectRoot) + if root != "/test/root" { + t.Errorf("Expected /test/root, got %s", root) } }) - t.Run("MaxDepthExceeded", func(t *testing.T) { - injector := di.NewInjector() - - // Mock the getwd function to simulate directory structure - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return "/mock/deep/directory/structure/level1/level2/level3/level4/level5/level6/level7/level8/level9/level10/level11", nil + t.Run("FindsProjectRoot", func(t *testing.T) { + // Given a shell in a directory with windsor.yaml + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/current", nil } - - // Mock the osStat function to simulate file existence - originalOsStat := osStat - defer func() { osStat = originalOsStat }() - osStat = func(name string) (os.FileInfo, error) { + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/current/windsor.yaml" { + return nil, nil + } return nil, os.ErrNotExist } - // When calling GetProjectRoot - shell := NewDefaultShell(injector) - projectRoot, err := shell.GetProjectRoot() + // When getting the project root + root, err := shell.GetProjectRoot() - // Then the project root should be the original directory due to max depth exceeded + // Then it should find the root directory if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Errorf("Expected no error, got %v", err) } - expectedProjectRoot := "/mock/deep/directory/structure/level1/level2/level3/level4/level5/level6/level7/level8/level9/level10/level11" - if projectRoot != expectedProjectRoot { - t.Errorf("Expected project root to be %q, got %q", expectedProjectRoot, projectRoot) + if root != "/test/current" { + t.Errorf("Expected /test/current, got %s", root) } }) - t.Run("NoGitNoYaml", func(t *testing.T) { - injector := di.NewInjector() - - // Mock the getwd function to simulate directory structure - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return "/mock/current/dir/subdir", nil + t.Run("FindsProjectRootWithYml", func(t *testing.T) { + // Given a shell in a directory with windsor.yml + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/current", nil } - - // Mock the osStat function to simulate file existence - originalOsStat := osStat - defer func() { osStat = originalOsStat }() - osStat = func(name string) (os.FileInfo, error) { + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/current/windsor.yml" { + return nil, nil + } return nil, os.ErrNotExist } - // When calling GetProjectRoot - shell := NewDefaultShell(injector) - projectRoot, err := shell.GetProjectRoot() + // When getting the project root + root, err := shell.GetProjectRoot() - // Then the project root should be the original directory + // Then it should find the root directory if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Errorf("Expected no error, got %v", err) } - if projectRoot != "/mock/current/dir/subdir" { - t.Errorf("Expected project root to be %q, got %q", "/mock/current/dir/subdir", projectRoot) + if root != "/test/current" { + t.Errorf("Expected /test/current, got %s", root) } }) - t.Run("GetwdFails", func(t *testing.T) { - injector := di.NewInjector() - - // Given a simulated error in getwd - originalGetwd := getwd - getwd = func() (string, error) { - return "", errors.New("simulated error") + t.Run("ErrorOnGetwdFailure", func(t *testing.T) { + // Given a shell with failing Getwd + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "", fmt.Errorf("getwd failed") } - defer func() { getwd = originalGetwd }() - // When calling GetProjectRoot - shell := NewDefaultShell(injector) + // When getting project root _, err := shell.GetProjectRoot() - // Then an error should be returned + // Then it should return an error if err == nil { - t.Fatalf("Expected an error, got nil") + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "getwd failed") { + t.Errorf("Expected error to contain 'getwd failed', got %v", err) + } + }) + + t.Run("MaxDepthExceeded", func(t *testing.T) { + // Given a shell in a deep directory structure without windsor.yaml/yml + shell, mocks := setup(t) + originalDir := "/test/very/deep/directory/structure/without/config/file" + mocks.Shims.Getwd = func() (string, error) { + return originalDir, nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + return nil, os.ErrNotExist + } + + // When getting the project root + root, err := shell.GetProjectRoot() + + // Then it should return the original directory without error + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if root != originalDir { + t.Errorf("Expected %s, got %s", originalDir, root) } }) } func TestShell_Exec(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + t.Run("Success", func(t *testing.T) { - expectedOutput := "hello\n" - command := "echo" - args := []string{"hello"} - - // Mock execCommand to simulate command execution - originalExecCommand := execCommand - execCommand = func(name string, arg ...string) *exec.Cmd { - cmd := exec.Command("echo", "hello") - cmd.Stdout = &bytes.Buffer{} + // Given a shell with mocked command execution + shell, mocks := setup(t) + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + cmd := exec.Command("echo", "test") return cmd } - defer func() { execCommand = originalExecCommand }() - - // Mock cmdStart to simulate successful command start - originalCmdStart := cmdStart - cmdStart = func(cmd *exec.Cmd) error { + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + if w, ok := cmd.Stdout.(io.Writer); ok { + if _, err := w.Write([]byte("test\n")); err != nil { + return fmt.Errorf("failed to write to stdout: %v", err) + } + } return nil } - defer func() { cmdStart = originalCmdStart }() - - // Mock cmdWait to simulate successful command execution - originalCmdWait := cmdWait - cmdWait = func(cmd *exec.Cmd) error { - cmd.Stdout.Write([]byte("hello\n")) + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { return nil } - defer func() { cmdWait = originalCmdWait }() - injector := di.NewInjector() - shell := NewDefaultShell(injector) + // Capture stdout during test + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + // When executing a command + out, err := shell.Exec("test", "arg") + + // Restore stdout + w.Close() + os.Stdout = oldStdout + io.ReadAll(r) - output, err := shell.Exec(command, args...) + // Then it should succeed and return output if err != nil { - t.Fatalf("Failed to execute command: %v", err) + t.Errorf("Expected no error, got %v", err) } - if output != expectedOutput { - t.Fatalf("Expected output %q, got %q", expectedOutput, output) + if out != "test\n" { + t.Errorf("Expected output 'test\n', got '%s'", out) } }) - t.Run("ErrorRunningCommand", func(t *testing.T) { - command := "nonexistentcommand" - args := []string{} - - // Mock cmdStart to simulate command execution failure - originalCmdStart := cmdStart - cmdStart = func(cmd *exec.Cmd) error { - return fmt.Errorf("command start failed: exec: \"%s\": executable file not found in $PATH", command) + t.Run("Error", func(t *testing.T) { + // Given a shell with failing command + shell, mocks := setup(t) + mocks.Shims.Command = func(name string, arg ...string) *exec.Cmd { + cmd := exec.Command("echo", "test") + return cmd + } + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + return nil + } + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { + return fmt.Errorf("command failed") } - defer func() { cmdStart = originalCmdStart }() - shell := NewDefaultShell(nil) + // When executing a command + out, err := shell.Exec("test", "arg") - _, err := shell.Exec(command, args...) + // Then it should return an error if err == nil { - t.Fatalf("Expected error when executing nonexistent command, got nil") + t.Error("Expected error, got nil") } - expectedError := fmt.Sprintf("command start failed: exec: \"%s\": executable file not found in $PATH", command) - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + if !strings.Contains(err.Error(), "command failed") { + t.Errorf("Expected error to contain 'command failed', got %v", err) + } + if out != "" { + t.Errorf("Expected empty output, got '%s'", out) } }) - t.Run("ErrorWaitingForCommand", func(t *testing.T) { - command := "echo" - args := []string{"hello"} - - // Mock execCommand to simulate command execution - originalExecCommand := execCommand - execCommand = mockExecCommandError - defer func() { execCommand = originalExecCommand }() - - // Mock cmdStart to simulate successful command start - originalCmdStart := cmdStart - cmdStart = func(cmd *exec.Cmd) error { - return nil + t.Run("ErrorOnStart", func(t *testing.T) { + // Given a shell with failing command start + shell, mocks := setup(t) + mocks.Shims.Command = func(name string, arg ...string) *exec.Cmd { + cmd := exec.Command("echo", "test") + return cmd } - defer func() { cmdStart = originalCmdStart }() - - // Mock cmdWait to simulate an error when waiting for the command - originalCmdWait := cmdWait - cmdWait = func(cmd *exec.Cmd) error { - return fmt.Errorf("failed to wait for command") + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + return fmt.Errorf("failed to start command") } - defer func() { cmdWait = originalCmdWait }() - shell := NewDefaultShell(nil) - _, err := shell.Exec(command, args...) + // When executing a command + out, err := shell.Exec("test", "arg") + + // Then it should return an error if err == nil { - t.Fatalf("Expected error, got nil") + t.Error("Expected error, got nil") } - expectedError := "failed to wait for command" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + if !strings.Contains(err.Error(), "failed to start command") { + t.Errorf("Expected error to contain 'failed to start command', got %v", err) + } + if out != "" { + t.Errorf("Expected empty output, got '%s'", out) } }) } func TestShell_ExecSudo(t *testing.T) { - // Mock cmdRun, cmdStart, cmdWait, and osOpenFile to simulate command execution - originalCmdRun := cmdRun - originalCmdStart := cmdStart - originalCmdWait := cmdWait - originalOsOpenFile := osOpenFile - - defer func() { - cmdRun = originalCmdRun - cmdStart = originalCmdStart - cmdWait = originalCmdWait - osOpenFile = originalOsOpenFile - }() - - cmdRun = func(cmd *exec.Cmd) error { - _, _ = cmd.Stdout.Write([]byte("hello\n")) - return nil - } - cmdStart = func(cmd *exec.Cmd) error { - _, _ = cmd.Stdout.Write([]byte("hello\n")) - return nil - } - cmdWait = func(_ *exec.Cmd) error { - return nil - } - osOpenFile = func(_ string, _ int, _ os.FileMode) (*os.File, error) { - return &os.File{}, nil + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks } t.Run("Success", func(t *testing.T) { - command := "echo" - args := []string{"hello"} + shell, mocks := setup(t) - shell := NewDefaultShell(nil) + // Mock command to return test output + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + cmd := exec.Command("echo", "test") + return cmd + } + + // Mock successful command execution + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + if cmd.Stdout != nil { + if w, ok := cmd.Stdout.(io.Writer); ok { + if _, err := w.Write([]byte("test output")); err != nil { + return fmt.Errorf("failed to write to stdout: %v", err) + } + } + } + return nil + } + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { + return nil + } + mocks.Shims.OpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { + return os.NewFile(0, "test"), nil + } - output, err := shell.ExecSudo("Test Sudo Command", command, args...) + output, err := shell.ExecSudo("Running test", "test", "arg") if err != nil { - t.Fatalf("Expected no error, got %v", err) + t.Errorf("Expected no error, got: %v", err) } - expectedOutput := "hello\n" - if output != expectedOutput { - t.Fatalf("Expected output %q, got %q", expectedOutput, output) + if output != "test output" { + t.Errorf("Expected output 'test output', got: %q", output) } }) - t.Run("ErrorOpeningTTY", func(t *testing.T) { - // Mock osOpenFile to simulate an error when opening /dev/tty - originalOsOpenFile := osOpenFile - osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { + t.Run("SuccessWithVerbose", func(t *testing.T) { + // Given a shell with verbose mode enabled + shell, mocks := setup(t) + shell.SetVerbosity(true) + + // Mock command execution + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + cmd := &exec.Cmd{ + Path: name, + Args: append([]string{name}, args...), + Stdout: new(bytes.Buffer), + Stderr: new(bytes.Buffer), + } + return cmd + } + + // Mock successful command execution + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + if cmd.Stdout != nil { + if w, ok := cmd.Stdout.(io.Writer); ok { + if _, err := w.Write([]byte("test\n")); err != nil { + return fmt.Errorf("failed to write to stdout: %v", err) + } + } + } + return nil + } + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { + return nil + } + + // Mock TTY handling + mocks.Shims.OpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { if name == "/dev/tty" { - return nil, fmt.Errorf("failed to open /dev/tty") + return os.NewFile(0, "/dev/tty"), nil } - return originalOsOpenFile(name, flag, perm) + return nil, fmt.Errorf("unexpected file: %s", name) } - defer func() { osOpenFile = originalOsOpenFile }() // Restore original function after test - shell := NewDefaultShell(nil) - _, err := shell.ExecSudo("Test Sudo Command", "echo", "hello") + // When executing a sudo command + output, err := shell.ExecSudo("test", "command") + + // Then it should succeed and return the expected output + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + if output != "test\n" { + t.Errorf("Expected 'test\\n' output, got: %q", output) + } + }) + + t.Run("ErrorOnOpenTTY", func(t *testing.T) { + shell, mocks := setup(t) + mocks.Shims.OpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { + return nil, fmt.Errorf("failed to open tty") + } + + output, err := shell.ExecSudo("Running test", "test", "arg") if err == nil { - t.Fatalf("Expected error, got nil") + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to open /dev/tty") { + t.Errorf("Expected error about tty, got: %v", err) } - expectedError := "failed to open /dev/tty" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + if output != "" { + t.Errorf("Expected empty output, got: %s", output) } }) - t.Run("ErrorStartingCommand", func(t *testing.T) { - // Mock cmdStart to simulate an error when starting the command - originalCmdStart := cmdStart - cmdStart = func(cmd *exec.Cmd) error { + t.Run("ErrorOnStart", func(t *testing.T) { + shell, mocks := setup(t) + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { return fmt.Errorf("failed to start command") } - defer func() { - cmdStart = originalCmdStart - }() - command := "echo" - args := []string{"hello"} - shell := NewDefaultShell(nil) - _, err := shell.ExecSudo("Test Sudo Command", command, args...) + output, err := shell.ExecSudo("Running test", "test", "arg") if err == nil { - t.Fatalf("Expected error, got nil") + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to start command") { + t.Errorf("Expected error about command start, got: %v", err) } - expectedError := "failed to start command" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + if output != "" { + t.Errorf("Expected empty output, got: %s", output) } }) - t.Run("ErrorWaitingForCommand", func(t *testing.T) { - // Mock cmdWait to simulate an error when waiting for the command - cmdWait = func(cmd *exec.Cmd) error { - return fmt.Errorf("failed to wait for command") + t.Run("ErrorOnWait", func(t *testing.T) { + // Setup + shims := NewShims() + sh := NewDefaultShell(nil) + sh.shims = shims + + expectedOutput := "test output" + expectedErr := fmt.Errorf("wait error") + + // Mock command execution + shims.Command = func(name string, args ...string) *exec.Cmd { + cmd := exec.Command("echo", "test") + return cmd + } + shims.CmdStart = func(cmd *exec.Cmd) error { + if cmd.Stdout != nil { + if w, ok := cmd.Stdout.(*bytes.Buffer); ok { + w.WriteString(expectedOutput) + } + } + return nil + } + shims.CmdWait = func(cmd *exec.Cmd) error { + return expectedErr + } + shims.OpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { + return os.NewFile(0, "test"), nil } - defer func() { cmdWait = func(cmd *exec.Cmd) error { return cmd.Wait() } }() // Restore original function after test - command := "echo" - args := []string{"hello"} - shell := NewDefaultShell(nil) - _, err := shell.ExecSudo("Test Sudo Command", command, args...) + // Execute + output, err := sh.ExecSudo("test", "test", "arg") + + // Assert if err == nil { - t.Fatalf("Expected error, got nil") + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), expectedErr.Error()) { + t.Errorf("Expected error containing %v, got %v", expectedErr, err) } - expectedError := "failed to wait for command" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + if output != expectedOutput { + t.Errorf("Expected output %q, got %q", expectedOutput, output) } }) - t.Run("VerboseOutput", func(t *testing.T) { - command := "echo" - args := []string{"hello"} - - shell := NewDefaultShell(nil) - shell.SetVerbosity(true) + t.Run("SudoCommand", func(t *testing.T) { + // Given a shell with verbose mode disabled + shell, mocks := setup(t) - // Mock execCommand to simulate command execution - originalExecCommand := execCommand - execCommand = func(name string, arg ...string) *exec.Cmd { + // Mock command execution + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { cmd := &exec.Cmd{ - Stdout: &bytes.Buffer{}, - Stderr: &bytes.Buffer{}, + Path: name, + Args: append([]string{name}, args...), + Stdout: new(bytes.Buffer), + Stderr: new(bytes.Buffer), } return cmd } - defer func() { execCommand = originalExecCommand }() - - // Mock cmdStart to simulate successful command start - originalCmdStart := cmdStart - cmdStart = func(cmd *exec.Cmd) error { - _, _ = cmd.Stdout.Write([]byte("hello\n")) + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { return nil } - defer func() { cmdStart = originalCmdStart }() - - // Mock cmdWait to simulate successful command completion - originalCmdWait := cmdWait - cmdWait = func(cmd *exec.Cmd) error { + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { return nil } - defer func() { cmdWait = originalCmdWait }() + mocks.Shims.OpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { + return os.NewFile(0, "test"), nil + } - stdout, stderr := captureStdoutAndStderr(t, func() { - output, err := shell.ExecSudo("Test Sudo Command", command, args...) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - expectedOutput := "hello\n" - if output != expectedOutput { - t.Fatalf("Expected output %q, got %q", expectedOutput, output) - } - }) + // When executing a sudo command + output, err := shell.ExecSudo("test", "sudo", "command") + + // Then it should succeed and return empty output + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + if output != "" { + t.Errorf("Expected empty output, got: %q", output) + } + }) - // Validate stdout and stderr - expectedStdout := "hello\n" - if stdout != expectedStdout { - t.Fatalf("Expected stdout %q, got %q", expectedStdout, stdout) + t.Run("CommandNil", func(t *testing.T) { + // Given a shell with Command returning nil + shell, mocks := setup(t) + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return nil } - expectedVerboseOutput := "Test Sudo Command\n" - if !strings.Contains(stderr, expectedVerboseOutput) { - t.Fatalf("Expected verbose output %q, got stderr: %q", expectedVerboseOutput, stderr) + // When executing a sudo command + output, err := shell.ExecSudo("test", "command") + + // Then it should fail with the expected error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to create command") { + t.Errorf("Expected error about failed command creation, got: %v", err) + } + if output != "" { + t.Errorf("Expected empty output, got: %q", output) } }) } func TestShell_ExecSilent(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + t.Run("Success", func(t *testing.T) { - command := "go" - args := []string{"version"} + // Given a shell with mocked command execution + shell, _ := setup(t) - shell := NewDefaultShell(nil) - output, err := shell.ExecSilent(command, args...) + // When executing a command silently + out, err := shell.ExecSilent("test", "arg") + + // Then it should succeed and return output if err != nil { - t.Fatalf("Expected no error, got %v", err) + t.Errorf("Expected no error, got %v", err) } - expectedOutputPrefix := "go version" - if !strings.HasPrefix(output, expectedOutputPrefix) { - t.Fatalf("Expected output to start with %q, got %q", expectedOutputPrefix, output) + if out != "test\n" { + t.Errorf("Expected output 'test\n', got '%s'", out) } }) - t.Run("ErrorRunningCommand", func(t *testing.T) { - // Mock cmdRun to simulate an error when running the command - cmdRun = func(cmd *exec.Cmd) error { - return fmt.Errorf("failed to run command") - } - defer func() { cmdRun = func(cmd *exec.Cmd) error { return cmd.Run() } }() // Restore original function after test - - command := "nonexistentcommand" - args := []string{} - shell := NewDefaultShell(nil) - _, err := shell.ExecSilent(command, args...) - if err == nil { - t.Fatalf("Expected error, got nil") + t.Run("Error", func(t *testing.T) { + // Given a shell with failing command + shell, mocks := setup(t) + mocks.Shims.Command = func(name string, arg ...string) *exec.Cmd { + cmd := exec.Command("echo", "test") + return cmd } - expectedError := "command execution failed" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + mocks.Shims.CmdRun = func(cmd *exec.Cmd) error { + return fmt.Errorf("command failed") } - }) - t.Run("VerboseOutput", func(t *testing.T) { - command := "go" - args := []string{"version"} + // When executing a command silently + out, err := shell.ExecSilent("test", "arg") - // Mock execCommand to simulate command execution - originalExecCommand := execCommand - execCommand = func(name string, arg ...string) *exec.Cmd { - cmd := &exec.Cmd{ - Stdout: &bytes.Buffer{}, - Stderr: &bytes.Buffer{}, - } - cmd.Stdout.Write([]byte("go version go1.16.3\n")) - return cmd + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") } - defer func() { execCommand = originalExecCommand }() - - // Mock cmdStart and cmdWait to simulate command execution without hanging - originalCmdStart := cmdStart - cmdStart = func(cmd *exec.Cmd) error { - cmd.Stdout.Write([]byte("go version go1.16.3\n")) - return nil + if !strings.Contains(err.Error(), "command failed") { + t.Errorf("Expected error to contain 'command failed', got %v", err) + } + if out != "" { + t.Errorf("Expected empty output, got '%s'", out) } - defer func() { cmdStart = originalCmdStart }() + }) - originalCmdWait := cmdWait - cmdWait = func(cmd *exec.Cmd) error { + t.Run("CommandNil", func(t *testing.T) { + // Given a shell with Command returning nil + shell, mocks := setup(t) + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { return nil } - defer func() { cmdWait = originalCmdWait }() - shell := NewDefaultShell(nil) - shell.SetVerbosity(true) + // When executing a command silently + output, err := shell.ExecSilent("test", "arg") - stdout, _ := captureStdoutAndStderr(t, func() { - output, err := shell.ExecSilent(command, args...) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - expectedOutputPrefix := "go version" - if !strings.HasPrefix(output, expectedOutputPrefix) { - t.Fatalf("Expected output to start with %q, got %q", expectedOutputPrefix, output) - } - }) - - expectedVerboseOutput := "go version" - if !strings.Contains(stdout, expectedVerboseOutput) { - t.Fatalf("Expected verbose output to contain %q, got %q", expectedVerboseOutput, stdout) + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to create command") { + t.Errorf("Expected error about command creation, got: %v", err) + } + if output != "" { + t.Errorf("Expected empty output, got %q", output) } }) -} -func TestShell_ExecProgress(t *testing.T) { - // Helper function to mock a command execution - mockCommandExecution := func() { - execCommand = func(command string, args ...string) *exec.Cmd { - return exec.Command("go", "version") + t.Run("StdoutPipeError", func(t *testing.T) { + // Given a shell with failing StdoutPipe + shell, mocks := setup(t) + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("echo", "test") + } + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + return nil, fmt.Errorf("stdout pipe error") } - } - // Helper function to mock stdout pipe - mockStdoutPipe := func() { - cmdStdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { - r, w := io.Pipe() - go func() { - defer w.Close() - w.Write([]byte("go version go1.16.3\n")) - }() - return r, nil + // When executing a command with progress + output, err := shell.ExecProgress("test message", "test", "arg") + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") } - } + if !strings.Contains(err.Error(), "stdout pipe error") { + t.Errorf("Expected error about stdout pipe, got: %v", err) + } + if output != "" { + t.Errorf("Expected empty output, got %q", output) + } + }) - // Helper function to mock stderr pipe - mockStderrPipe := func() { - cmdStderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + t.Run("StderrPipeError", func(t *testing.T) { + // Given a shell with failing StderrPipe + shell, mocks := setup(t) + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("echo", "test") + } + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { r, w := io.Pipe() go func() { - defer w.Close() + if _, err := w.Write([]byte("test output\n")); err != nil { + t.Errorf("Failed to write to stdout pipe: %v", err) + } + w.Close() }() return r, nil } - } - - // Save original functions - originalExecCommand := execCommand - originalCmdStdoutPipe := cmdStdoutPipe - originalCmdStderrPipe := cmdStderrPipe - - // Mock functions - mockCommandExecution() - mockStdoutPipe() - mockStderrPipe() - - // Restore original functions after test - defer func() { - execCommand = originalExecCommand - cmdStdoutPipe = originalCmdStdoutPipe - cmdStderrPipe = originalCmdStderrPipe - }() + mocks.Shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + return nil, fmt.Errorf("stderr pipe error") + } - t.Run("Success", func(t *testing.T) { - command := "go" - args := []string{"version"} + // When executing a command with progress + output, err := shell.ExecProgress("test message", "test", "arg") - shell := NewDefaultShell(nil) - output, err := shell.ExecProgress("Test Progress Command", command, args...) - if err != nil { - t.Fatalf("Expected no error, got %v", err) + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") } - expectedOutput := "go version go1.16.3\n" - if output != expectedOutput { - t.Fatalf("Expected output %q, got %q", expectedOutput, output) + if !strings.Contains(err.Error(), "stderr pipe error") { + t.Errorf("Expected error about stderr pipe, got: %v", err) + } + if output != "" { + t.Errorf("Expected empty output, got %q", output) } }) - t.Run("ErrStdoutPipe", func(t *testing.T) { - command := "go" - args := []string{"version"} - - // Mock cmdStdoutPipe to simulate an error - originalCmdStdoutPipe := cmdStdoutPipe - cmdStdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { - return nil, fmt.Errorf("failed to create stdout pipe") + t.Run("CmdStartError", func(t *testing.T) { + // Given a shell with failing CmdStart + shell, mocks := setup(t) + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("echo", "test") + } + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + if _, err := w.Write([]byte("test output\n")); err != nil { + t.Errorf("Failed to write to stdout pipe: %v", err) + } + w.Close() + }() + return r, nil + } + mocks.Shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + if _, err := w.Write([]byte("error\n")); err != nil { + t.Errorf("Failed to write to stderr pipe: %v", err) + } + w.Close() + }() + return r, nil + } + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + return fmt.Errorf("command start error") } - defer func() { cmdStdoutPipe = originalCmdStdoutPipe }() // Restore original function after test - shell := NewDefaultShell(nil) - _, err := shell.ExecProgress("Test Progress Command", command, args...) + // When executing a command with progress + output, err := shell.ExecProgress("test message", "test", "arg") + + // Then it should return an error if err == nil { - t.Fatalf("Expected error, got nil") + t.Error("Expected error, got nil") } - expectedError := "failed to create stdout pipe" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + if !strings.Contains(err.Error(), "command start error") { + t.Errorf("Expected error about command start, got: %v", err) + } + if output != "" { + t.Errorf("Expected empty output, got %q", output) } }) +} - t.Run("ErrStderrPipe", func(t *testing.T) { - command := "go" - args := []string{"version"} +func TestShell_GetSessionToken(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } - // Mock cmdStderrPipe to simulate an error - originalCmdStderrPipe := cmdStderrPipe - cmdStderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { - return nil, fmt.Errorf("failed to create stderr pipe") + t.Run("Success", func(t *testing.T) { + // Given a shell with session token + shell, mocks := setup(t) + expectedToken := "test-token" + mocks.Shims.Getenv = func(key string) string { + if key == "WINDSOR_SESSION_TOKEN" { + return expectedToken + } + return "" } - defer func() { cmdStderrPipe = originalCmdStderrPipe }() // Restore original function after test - shell := NewDefaultShell(nil) - _, err := shell.ExecProgress("Test Progress Command", command, args...) - if err == nil { - t.Fatalf("Expected error, got nil") + // When getting session token + token, err := shell.GetSessionToken() + + // Then it should return the expected token + if err != nil { + t.Errorf("Expected no error, got %v", err) } - expectedError := "failed to create stderr pipe" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + if token != expectedToken { + t.Errorf("Expected token %q, got %q", expectedToken, token) } }) - t.Run("ErrStartCommand", func(t *testing.T) { - command := "go" - args := []string{"version"} + t.Run("NoToken", func(t *testing.T) { + // Given a shell with mocked operations + shell, mocks := setup(t) - // Mock cmdStart to simulate an error - originalCmdStart := cmdStart - cmdStart = func(cmd *exec.Cmd) error { - return fmt.Errorf("failed to start command") + // Mock environment to return no token + mocks.Shims.Getenv = func(key string) string { + return "" } - defer func() { cmdStart = originalCmdStart }() // Restore original function after test - shell := NewDefaultShell(nil) - _, err := shell.ExecProgress("Test Progress Command", command, args...) - if err == nil { - t.Fatalf("Expected error, got nil") + // Mock random generation + mocks.Shims.RandRead = func(b []byte) (n int, err error) { + for i := range b { + b[i] = byte('a') + } + return len(b), nil } - expectedError := "failed to start command" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + + // When getting session token + token, err := shell.GetSessionToken() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + // Then it should return a token of length 7 + if len(token) != 7 { + t.Errorf("Expected token length 7, got %d", len(token)) } }) +} - t.Run("ErrBufioScannerScan", func(t *testing.T) { - command := "go" - args := []string{"version"} +func TestShell_CheckResetFlags(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } - // Mock bufioScannerScan to simulate an error - originalBufioScannerScan := bufioScannerScan - bufioScannerScan = func(scanner *bufio.Scanner) bool { - return false + t.Run("Success", func(t *testing.T) { + shell, mocks := setup(t) + expectedToken := "test-token" + mocks.Shims.Getenv = func(key string) string { + if key == "WINDSOR_SESSION_TOKEN" { + return expectedToken + } + return "" } - defer func() { bufioScannerScan = originalBufioScannerScan }() // Restore original function after test - - // Mock bufioScannerErr to return an error - originalBufioScannerErr := bufioScannerErr - bufioScannerErr = func(scanner *bufio.Scanner) error { - return fmt.Errorf("error reading stdout") + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if strings.HasSuffix(name, expectedToken) { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.Glob = func(pattern string) ([]string, error) { + return []string{"/test/project/.windsor/session-test-token"}, nil + } + mocks.Shims.RemoveAll = func(path string) error { + return nil } - defer func() { bufioScannerErr = originalBufioScannerErr }() // Restore original function after test - shell := NewDefaultShell(nil) - _, err := shell.ExecProgress("Test Progress Command", command, args...) - if err == nil { - t.Fatalf("Expected error, got nil") + shouldReset, err := shell.CheckResetFlags() + if err != nil { + t.Errorf("Expected no error, got %v", err) } - expectedError := "error reading stdout" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + if !shouldReset { + t.Error("Expected shouldReset to be true") } }) - t.Run("ErrBufioScannerErr", func(t *testing.T) { - command := "go" - args := []string{"version"} - - // Mock cmdStdoutPipe and cmdStderrPipe to return a pipe that can be scanned - originalCmdStdoutPipe := cmdStdoutPipe - cmdStdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { - r, w := io.Pipe() - go func() { - defer w.Close() - w.Write([]byte("stdout line\n")) - }() - return r, nil + t.Run("NoResetToken", func(t *testing.T) { + shell, mocks := setup(t) + mocks.Shims.Getenv = func(key string) string { + return "" } - defer func() { cmdStdoutPipe = originalCmdStdoutPipe }() - originalCmdStderrPipe := cmdStderrPipe - cmdStderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { - r, w := io.Pipe() - go func() { - defer w.Close() - w.Write([]byte("stderr line\n")) - }() - return r, nil + shouldReset, err := shell.CheckResetFlags() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if shouldReset { + t.Error("Expected shouldReset to be false") } - defer func() { cmdStderrPipe = originalCmdStderrPipe }() + }) - // Mock bufioScannerErr to return an error - originalBufioScannerErr := bufioScannerErr - bufioScannerErr = func(scanner *bufio.Scanner) error { - return fmt.Errorf("error reading stderr") + t.Run("TokenMismatch", func(t *testing.T) { + shell, mocks := setup(t) + mocks.Shims.Getenv = func(key string) string { + if key == "WINDSOR_RESET_TOKEN" { + return "test-token" + } + return "" + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + return nil, os.ErrNotExist + } + mocks.Shims.Glob = func(pattern string) ([]string, error) { + return []string{}, nil } - defer func() { bufioScannerErr = originalBufioScannerErr }() // Restore original function after test - shell := NewDefaultShell(nil) - _, err := shell.ExecProgress("Test Progress Command", command, args...) - if err == nil { - t.Fatalf("Expected error, got nil") + shouldReset, err := shell.CheckResetFlags() + if err != nil { + t.Errorf("Expected no error, got %v", err) } - expectedError := "error reading stderr" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + if shouldReset { + t.Error("Expected shouldReset to be false") } }) - t.Run("ErrCmdWait", func(t *testing.T) { - command := "go" - args := []string{"version"} - - // Mock cmdWait to return an error - originalCmdWait := cmdWait - cmdWait = func(cmd *exec.Cmd) error { - return fmt.Errorf("error waiting for command") + t.Run("ErrorOnGlob", func(t *testing.T) { + // Given a shell with failing Glob + shell, mocks := setup(t) + mocks.Shims.Getenv = func(key string) string { + if key == "WINDSOR_SESSION_TOKEN" { + return "test-token" + } + return "" + } + mocks.Shims.Glob = func(pattern string) ([]string, error) { + return nil, fmt.Errorf("glob error") } - defer func() { cmdWait = originalCmdWait }() // Restore original function after test - shell := NewDefaultShell(nil) - _, err := shell.ExecProgress("Test Progress Command", command, args...) + // When checking reset flags + _, err := shell.CheckResetFlags() + + // Then it should return an error if err == nil { - t.Fatalf("Expected error, got nil") + t.Error("Expected error, got nil") } - expectedError := "error waiting for command" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error to contain %q, got %q", expectedError, err.Error()) + if !strings.Contains(err.Error(), "glob error") { + t.Errorf("Expected error to contain 'glob error', got %v", err) } }) - t.Run("VerboseOutput", func(t *testing.T) { - command := "go" - args := []string{"version"} - - shell := NewDefaultShell(nil) - shell.SetVerbosity(true) - - // Mock execCommand to simulate command execution - originalExecCommand := execCommand - execCommand = func(name string, arg ...string) *exec.Cmd { - cmd := &exec.Cmd{ - Stdout: &bytes.Buffer{}, - Stderr: &bytes.Buffer{}, + t.Run("ErrorOnRemoveAll", func(t *testing.T) { + // Given a shell with failing RemoveAll + shell, mocks := setup(t) + mocks.Shims.Getenv = func(key string) string { + if key == "WINDSOR_SESSION_TOKEN" { + return "test-token" } - return cmd + return "" } - defer func() { execCommand = originalExecCommand }() // Restore original function after test - - // Mock cmdStart and cmdWait to simulate command execution without hanging - originalCmdStart := cmdStart - cmdStart = func(cmd *exec.Cmd) error { - cmd.Stdout.Write([]byte("go version go1.16.3 darwin/amd64\n")) - return nil + mocks.Shims.Glob = func(pattern string) ([]string, error) { + return []string{"/test/project/.windsor/session-test-token"}, nil } - defer func() { cmdStart = originalCmdStart }() // Restore original function after test - - originalCmdWait := cmdWait - cmdWait = func(cmd *exec.Cmd) error { - return nil + mocks.Shims.RemoveAll = func(path string) error { + return fmt.Errorf("remove error") } - defer func() { cmdWait = originalCmdWait }() // Restore original function after test - stdout, stderr := captureStdoutAndStderr(t, func() { - output, err := shell.ExecProgress("Test Progress Command", command, args...) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - expectedOutputPrefix := "go version" - if !strings.HasPrefix(output, expectedOutputPrefix) { - t.Fatalf("Expected output to start with %q, got %q", expectedOutputPrefix, output) - } - }) + // When checking reset flags + _, err := shell.CheckResetFlags() - expectedVerboseOutput := "Test Progress Command\n" - if !strings.Contains(stderr, expectedVerboseOutput) { - t.Fatalf("Expected verbose output %q, got %q", expectedVerboseOutput, stderr) + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") } - - // Check the stdout value - expectedStdoutPrefix := "go version" - if !strings.HasPrefix(stdout, expectedStdoutPrefix) { - t.Fatalf("Expected stdout to start with %q, got %q", expectedStdoutPrefix, stdout) + if !strings.Contains(err.Error(), "remove error") { + t.Errorf("Expected error to contain 'remove error', got %v", err) } }) } -func TestShell_InstallHook(t *testing.T) { - t.Run("Success", func(t *testing.T) { - shell := NewDefaultShell(nil) +func TestShell_GenerateRandomString(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } - // Capture stdout to validate the output - output := captureStdout(t, func() { - if err := shell.InstallHook("bash"); err != nil { - t.Fatalf("Expected no error, got %v", err) + t.Run("Success", func(t *testing.T) { + shell, mocks := setup(t) + mocks.Shims.RandRead = func(b []byte) (n int, err error) { + for i := range b { + b[i] = byte(i % 62) // This will map to characters in the charset } - }) - - // Validate the output contains expected content - expectedOutput := "_windsor_hook" // Replace with actual expected output - if !strings.Contains(output, expectedOutput) { - t.Fatalf("Expected output to contain %q, got %q", expectedOutput, output) + return len(b), nil + } + result, err := shell.generateRandomString(10) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if len(result) != 10 { + t.Errorf("Expected length 10, got %d", len(result)) } }) +} - t.Run("PowerShellSuccess", func(t *testing.T) { - shell := NewDefaultShell(nil) +func TestShell_InstallHook(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } - // Capture stdout to validate the output - output := captureStdout(t, func() { - if err := shell.InstallHook("powershell"); err != nil { - t.Fatalf("Expected no error, got %v", err) - } - }) + t.Run("UnsupportedShell", func(t *testing.T) { + // Given a shell with an unsupported shell type + shell, _ := setup(t) - // Validate the output contains expected content - expectedOutput := "function prompt" // Replace with actual expected output for PowerShell - if !strings.Contains(output, expectedOutput) { - t.Fatalf("Expected output to contain %q, got %q", expectedOutput, output) - } - }) + // When installing a hook for an unsupported shell + err := shell.InstallHook("unsupported") - t.Run("UnsupportedShell", func(t *testing.T) { - shell := NewDefaultShell(nil) - err := shell.InstallHook("unsupported-shell") + // Then it should return an error if err == nil { - t.Fatalf("Expected an error for unsupported shell, but got nil") - } else { - expectedError := "Unsupported shell: unsupported-shell" - if err.Error() != expectedError { - t.Fatalf("Expected error message %q, but got %q", expectedError, err.Error()) - } + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "Unsupported shell: unsupported") { + t.Errorf("Expected error about unsupported shell, got: %v", err) } }) - t.Run("ErrorGettingSelfPath", func(t *testing.T) { - shell := NewDefaultShell(nil) - - // Mock osExecutable to simulate an error - originalOsExecutable := osExecutable - osExecutable = func() (string, error) { - return "", fmt.Errorf("executable file not found") + t.Run("ExecutableError", func(t *testing.T) { + // Given a shell with failing Executable + shell, mocks := setup(t) + mocks.Shims.Executable = func() (string, error) { + return "", fmt.Errorf("executable error") } - defer func() { osExecutable = originalOsExecutable }() // Restore original function after test - err := shell.InstallHook("bash") + // When installing a hook + err := shell.InstallHook("zsh") + + // Then it should return an error if err == nil { - t.Fatalf("Expected error due to self path retrieval failure, but got nil") - } else { - expectedError := "executable file not found" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error message to contain %q, but got %q", expectedError, err.Error()) - } + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "executable error") { + t.Errorf("Expected error about executable, got: %v", err) } }) - t.Run("ErrorCreatingNewTemplate", func(t *testing.T) { - shell := NewDefaultShell(nil) - - // Mock hookTemplateNew to simulate an error - originalHookTemplateNew := hookTemplateNew - hookTemplateNew = func(name string) *template.Template { + t.Run("TemplateCreateError", func(t *testing.T) { + // Given a shell with failing template creation + shell, mocks := setup(t) + mocks.Shims.Executable = func() (string, error) { + return "/usr/bin/windsor", nil + } + mocks.Shims.NewTemplate = func(name string) *template.Template { return nil } - defer func() { hookTemplateNew = originalHookTemplateNew }() // Restore original function after test - err := shell.InstallHook("bash") + // When installing a hook + err := shell.InstallHook("zsh") + + // Then it should return an error if err == nil { - t.Fatalf("Expected error due to hook template creation failure, but got nil") - } else { - expectedError := "failed to create new template" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error message to contain %q, but got %q", expectedError, err.Error()) - } + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to create new template") { + t.Errorf("Expected error about template creation, got: %v", err) } }) - t.Run("ErrorParsingHookTemplate", func(t *testing.T) { - shell := NewDefaultShell(nil) + t.Run("TemplateParseError", func(t *testing.T) { + // Given a shell with failing template parse + shell, mocks := setup(t) + mocks.Shims.Executable = func() (string, error) { + return "/usr/bin/windsor", nil + } + mocks.Shims.TemplateParse = func(tmpl *template.Template, text string) (*template.Template, error) { + return nil, fmt.Errorf("parse error") + } + + // When installing a hook + err := shell.InstallHook("zsh") + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to parse hook template") { + t.Errorf("Expected error about template parsing, got: %v", err) + } + }) - // Mock hookTemplateParse to simulate a parsing error - originalHookTemplateParse := hookTemplateParse - hookTemplateParse = func(tmpl *template.Template, text string) (*template.Template, error) { - return nil, fmt.Errorf("template parsing error") + t.Run("TemplateExecuteError", func(t *testing.T) { + // Given a shell with failing template execution + shell, mocks := setup(t) + mocks.Shims.Executable = func() (string, error) { + return "/usr/bin/windsor", nil + } + mocks.Shims.TemplateExecute = func(tmpl *template.Template, wr io.Writer, data any) error { + return fmt.Errorf("execute error") } - defer func() { hookTemplateParse = originalHookTemplateParse }() // Restore original function after test - err := shell.InstallHook("bash") + // When installing a hook + err := shell.InstallHook("zsh") + + // Then it should return an error if err == nil { - t.Fatalf("Expected error due to hook template parsing failure, but got nil") - } else { - expectedError := "template parsing error" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error message to contain %q, but got %q", expectedError, err.Error()) - } + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to execute hook template") { + t.Errorf("Expected error about template execution, got: %v", err) } }) - t.Run("ErrorParsingHookTemplate", func(t *testing.T) { - shell := NewDefaultShell(nil) + t.Run("Success", func(t *testing.T) { + // Given a shell with working template operations + shell, mocks := setup(t) + mocks.Shims.Executable = func() (string, error) { + return "/usr/bin/windsor", nil + } + mocks.Shims.TemplateExecute = func(tmpl *template.Template, wr io.Writer, data any) error { + _, err := wr.Write([]byte(` + _windsor_hook() { + trap -- '' SIGINT; + eval "$(/usr/bin/windsor env --decrypt)"; + trap - SIGINT; + }; + `)) + return err + } + + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + // When installing a hook + err := shell.InstallHook("zsh") + + // Close writer and restore stdout + w.Close() + os.Stdout = oldStdout + + // Read captured output + output, _ := io.ReadAll(r) - // Mock shellHooks to provide an invalid template command - originalShellHooks := shellHooks - shellHooks = map[string]string{ - "bash": "{{ .InvalidField }}", // Invalid template field to cause parsing error + // Then it should succeed and output the hook script + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if !strings.Contains(string(output), "_windsor_hook") { + t.Error("Expected output to contain _windsor_hook") } - defer func() { shellHooks = originalShellHooks }() // Restore original shellHooks after test + }) - err := shell.InstallHook("bash") - if err == nil { - t.Fatalf("Expected error due to hook template parsing failure, but got nil") - } else { - expectedError := "can't evaluate field InvalidField" - if !strings.Contains(err.Error(), expectedError) { - t.Fatalf("Expected error message to contain %q, but got %q", expectedError, err.Error()) - } + t.Run("PowerShellSuccess", func(t *testing.T) { + // Given a shell with working template operations + shell, mocks := setup(t) + mocks.Shims.Executable = func() (string, error) { + return "/usr/bin/windsor", nil + } + mocks.Shims.TemplateExecute = func(tmpl *template.Template, wr io.Writer, data any) error { + _, err := wr.Write([]byte(` + function prompt { + $windsorEnvScript = & "/usr/bin/windsor" env --decrypt | Out-String + if ($windsorEnvScript) { + Invoke-Expression $windsorEnvScript + } + } + `)) + return err + } + + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + // When installing a hook for PowerShell + err := shell.InstallHook("powershell") + + // Close writer and restore stdout + w.Close() + os.Stdout = oldStdout + + // Read captured output + output, _ := io.ReadAll(r) + + // Then it should succeed and output the hook script without empty lines + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if strings.Contains(string(output), "\n\n") { + t.Error("Expected no empty lines in PowerShell output") + } + if !strings.Contains(string(output), "function prompt") { + t.Error("Expected output to contain function prompt") } }) } -var tempDirs []string - -// Helper function to create a temporary directory -func createTempDir(t *testing.T, name string) string { - dir, err := os.MkdirTemp("", name) - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - tempDirs = append(tempDirs, dir) - return dir +// mockReadCloser implements io.ReadCloser for testing +type mockReadCloser struct { + io.Reader + closeFunc func() error } -// Helper function to create a file with specified content -func createFile(t *testing.T, dir, name, content string) { - filePath := filepath.Join(dir, name) - if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { - t.Fatalf("Failed to create file %s: %v", filePath, err) +func (m *mockReadCloser) Close() error { + if m.closeFunc != nil { + return m.closeFunc() } + return nil } -// Helper function to change the working directory -func changeDir(t *testing.T, dir string) { - originalDir, err := os.Getwd() +// ============================================================================= +// Test Helpers +// ============================================================================= + +// Helper function to create a temporary directory for testing +func createTempDir(t *testing.T, prefix string) string { + t.Helper() + dir, err := os.MkdirTemp("", prefix) if err != nil { - t.Fatalf("Failed to get current directory: %v", err) - } - if err := os.Chdir(dir); err != nil { - t.Fatalf("Failed to change directory: %v", err) + t.Fatalf("Failed to create temp dir: %v", err) } - t.Cleanup(func() { - if err := os.Chdir(originalDir); err != nil { - t.Fatalf("Failed to revert to original directory: %v", err) - } - }) + return dir } -// Helper function to normalize a path -func normalizePath(path string) string { - return strings.ReplaceAll(filepath.Clean(path), "\\", "/") +// Helper function to create a file in a directory +func createFile(t *testing.T, dir, name, content string) { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create file %s: %v", path, err) + } } // Helper function to capture stdout @@ -1041,1437 +1390,1414 @@ func captureStdout(t *testing.T, f func()) string { return output.String() } -// Updated helper function to mock exec.Command for failed execution using PowerShell -func mockExecCommandError(command string, args ...string) *exec.Cmd { - if runtime.GOOS == "windows" { - // Use PowerShell to simulate a failing command - fullCommand := fmt.Sprintf("exit 1; Write-Error 'mock error for: %s %s'", command, strings.Join(args, " ")) - cmdArgs := []string{"-Command", fullCommand} - return exec.Command("powershell.exe", cmdArgs...) - } else { - // Use 'false' command on Unix-like systems - return exec.Command("false") +func TestShell_AddCurrentDirToTrustedFile(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks } -} -// captureStdoutAndStderr captures output sent to os.Stdout and os.Stderr during the execution of f() -func captureStdoutAndStderr(t *testing.T, f func()) (string, string) { - // Save the original os.Stdout and os.Stderr - originalStdout := os.Stdout - originalStderr := os.Stderr + t.Run("Success", func(t *testing.T) { + // Given a shell with mocked operations + shell, mocks := setup(t) + projectRoot := "/test/project" + homeDir := "/home/test" + trustedFilePath := "/home/test/.config/windsor/.trusted" + + // Mock GetProjectRoot + mocks.Shims.Getwd = func() (string, error) { + return projectRoot, nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == filepath.Join(projectRoot, "windsor.yaml") { + return nil, nil + } + return nil, os.ErrNotExist + } - // Create pipes for os.Stdout and os.Stderr - rOut, wOut, _ := os.Pipe() - rErr, wErr, _ := os.Pipe() - os.Stdout = wOut - os.Stderr = wErr + // Mock UserHomeDir + mocks.Shims.UserHomeDir = func() (string, error) { + return homeDir, nil + } - // Channel to signal completion - done := make(chan struct{}) - go func() { - defer close(done) - f() - wOut.Close() - wErr.Close() - }() + // Mock file operations + mocks.Shims.MkdirAll = func(path string, perm os.FileMode) error { + if path != "/home/test/.config/windsor" { + t.Errorf("Expected path /home/test/.config/windsor, got %s", path) + } + return nil + } - // Read from the pipes - var stdoutBuf, stderrBuf bytes.Buffer - var wg sync.WaitGroup - wg.Add(2) - readFromPipe := func(pipe *os.File, buf *bytes.Buffer, pipeName string) { - defer wg.Done() - if _, err := buf.ReadFrom(pipe); err != nil { - t.Errorf("Failed to read from %s pipe: %v", pipeName, err) + mocks.Shims.ReadFile = func(name string) ([]byte, error) { + if name != trustedFilePath { + t.Errorf("Expected path %s, got %s", trustedFilePath, name) + } + return []byte("/other/project\n"), nil } - } - go readFromPipe(rOut, &stdoutBuf, "stdout") - go readFromPipe(rErr, &stderrBuf, "stderr") - // Wait for reading to complete - wg.Wait() - <-done + mocks.Shims.WriteFile = func(name string, data []byte, perm os.FileMode) error { + if name != trustedFilePath { + t.Errorf("Expected path %s, got %s", trustedFilePath, name) + } + expectedData := []byte("/other/project\n" + projectRoot + "\n") + if string(data) != string(expectedData) { + t.Errorf("Expected data %s, got %s", string(expectedData), string(data)) + } + if perm != 0600 { + t.Errorf("Expected perm 0600, got %o", perm) + } + return nil + } - // Restore os.Stdout and os.Stderr - os.Stdout = originalStdout - os.Stderr = originalStderr + // When adding current dir to trusted file + err := shell.AddCurrentDirToTrustedFile() + + // Then it should succeed + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + t.Run("ErrorOnGetProjectRoot", func(t *testing.T) { + // Given a shell with failing GetProjectRoot + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "", fmt.Errorf("getwd failed") + } - return stdoutBuf.String(), stderrBuf.String() + // When adding current dir to trusted file + err := shell.AddCurrentDirToTrustedFile() + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "Error getting project root directory") { + t.Errorf("Expected error about project root, got: %v", err) + } + }) + + t.Run("ErrorOnUserHomeDir", func(t *testing.T) { + // Given a shell with failing UserHomeDir + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.UserHomeDir = func() (string, error) { + return "", fmt.Errorf("user home dir failed") + } + + // When adding current dir to trusted file + err := shell.AddCurrentDirToTrustedFile() + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "Error getting user home directory") { + t.Errorf("Expected error about user home dir, got: %v", err) + } + }) + + t.Run("ErrorOnMkdirAll", func(t *testing.T) { + // Given a shell with failing MkdirAll + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.UserHomeDir = func() (string, error) { + return "/home/test", nil + } + mocks.Shims.MkdirAll = func(path string, perm os.FileMode) error { + return fmt.Errorf("mkdir failed") + } + + // When adding current dir to trusted file + err := shell.AddCurrentDirToTrustedFile() + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "Error creating directories for trusted file") { + t.Errorf("Expected error about mkdir, got: %v", err) + } + }) + + t.Run("ErrorOnReadFile", func(t *testing.T) { + // Given a shell with failing ReadFile + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.UserHomeDir = func() (string, error) { + return "/home/test", nil + } + mocks.Shims.MkdirAll = func(path string, perm os.FileMode) error { + return nil + } + mocks.Shims.ReadFile = func(name string) ([]byte, error) { + return nil, fmt.Errorf("read failed") + } + + // When adding current dir to trusted file + err := shell.AddCurrentDirToTrustedFile() + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "Error reading trusted file") { + t.Errorf("Expected error about read file, got: %v", err) + } + }) + + t.Run("ErrorOnWriteFile", func(t *testing.T) { + // Given a shell with failing WriteFile + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.UserHomeDir = func() (string, error) { + return "/home/test", nil + } + mocks.Shims.MkdirAll = func(path string, perm os.FileMode) error { + return nil + } + mocks.Shims.ReadFile = func(name string) ([]byte, error) { + return []byte("/other/project\n"), nil + } + mocks.Shims.WriteFile = func(name string, data []byte, perm os.FileMode) error { + return fmt.Errorf("write failed") + } + + // When adding current dir to trusted file + err := shell.AddCurrentDirToTrustedFile() + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "Error writing to trusted file") { + t.Errorf("Expected error about write file, got: %v", err) + } + }) + + t.Run("AlreadyTrusted", func(t *testing.T) { + // Given a shell with current dir already trusted + shell, mocks := setup(t) + projectRoot := "/test/project" + mocks.Shims.Getwd = func() (string, error) { + return projectRoot, nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == filepath.Join(projectRoot, "windsor.yaml") { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.UserHomeDir = func() (string, error) { + return "/home/test", nil + } + mocks.Shims.MkdirAll = func(path string, perm os.FileMode) error { + return nil + } + mocks.Shims.ReadFile = func(name string) ([]byte, error) { + return []byte(projectRoot + "\n"), nil + } + + // When adding current dir to trusted file + err := shell.AddCurrentDirToTrustedFile() + + // Then it should succeed without writing + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) } -func TestEnv_CheckTrustedDirectory(t *testing.T) { - // Mock the getwd function - originalGetwd := getwd - originalOsUserHomeDir := osUserHomeDir - originalReadFile := osReadFile - - defer func() { - getwd = originalGetwd - osUserHomeDir = originalOsUserHomeDir - osReadFile = originalReadFile - }() - - getwd = func() (string, error) { - return "/mock/current/dir", nil +func TestShell_CheckTrustedDirectory(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks } - osUserHomeDir = func() (string, error) { - return "/mock/home/dir", nil - } + t.Run("Success", func(t *testing.T) { + // Given a shell with mocked operations + shell, mocks := setup(t) + projectRoot := "/test/project" + homeDir := "/home/test" + trustedFilePath := "/home/test/.config/windsor/.trusted" + + // Mock GetProjectRoot + mocks.Shims.Getwd = func() (string, error) { + return projectRoot, nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == filepath.Join(projectRoot, "windsor.yaml") { + return nil, nil + } + return nil, os.ErrNotExist + } - osReadFile = func(filename string) ([]byte, error) { - return []byte("/mock/current/dir\n"), nil - } + // Mock UserHomeDir + mocks.Shims.UserHomeDir = func() (string, error) { + return homeDir, nil + } - t.Run("Success", func(t *testing.T) { - shell := NewDefaultShell(di.NewInjector()) - shell.Initialize() + // Mock ReadFile + mocks.Shims.ReadFile = func(name string) ([]byte, error) { + if name != trustedFilePath { + t.Errorf("Expected path %s, got %s", trustedFilePath, name) + } + return []byte("/test\n"), nil + } - // Call CheckTrustedDirectory and check for errors + // When checking trusted directory err := shell.CheckTrustedDirectory() + + // Then it should succeed if err != nil { - t.Errorf("unexpected error: %v", err) + t.Errorf("Expected no error, got %v", err) } }) - t.Run("ErrorGettingCurrentDir", func(t *testing.T) { - // Save the original getwd function - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - - // Override the getwd function locally to simulate an error - getwd = func() (string, error) { - return "", fmt.Errorf("Error getting current directory: error getting current directory") + t.Run("ErrorOnGetProjectRoot", func(t *testing.T) { + // Given a shell with failing GetProjectRoot + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "", fmt.Errorf("getwd failed") } - // Call CheckTrustedDirectory and expect an error - shell := &DefaultShell{} + // When checking trusted directory err := shell.CheckTrustedDirectory() - if err == nil || !strings.Contains(err.Error(), "error getting current directory") { - t.Errorf("expected error containing 'error getting current directory', got %v", err) + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "Error getting project root directory") { + t.Errorf("Expected error about project root, got: %v", err) } }) - t.Run("ErrorGettingUserHomeDir", func(t *testing.T) { - // Save the original osUserHomeDir function - originalOsUserHomeDir := osUserHomeDir - defer func() { osUserHomeDir = originalOsUserHomeDir }() - - // Override the osUserHomeDir function locally to simulate an error - osUserHomeDir = func() (string, error) { - return "", fmt.Errorf("Error getting user home directory: error getting user home directory") + t.Run("ErrorOnUserHomeDir", func(t *testing.T) { + // Given a shell with failing UserHomeDir + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.UserHomeDir = func() (string, error) { + return "", fmt.Errorf("user home dir failed") } - // Call CheckTrustedDirectory and expect an error - shell := &DefaultShell{} + // When checking trusted directory err := shell.CheckTrustedDirectory() - if err == nil || !strings.Contains(err.Error(), "Error getting user home directory") { - t.Errorf("expected error containing 'Error getting user home directory', got %v", err) + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "Error getting user home directory") { + t.Errorf("Expected error about user home dir, got: %v", err) } }) - t.Run("ErrorReadingTrustedFile", func(t *testing.T) { - // Save the original osReadFile function - originalReadFile := osReadFile - defer func() { osReadFile = originalReadFile }() - - // Override the osReadFile function locally to simulate an error - osReadFile = func(filename string) ([]byte, error) { - return nil, fmt.Errorf("error reading trusted file") + t.Run("ErrorOnReadFile", func(t *testing.T) { + // Given a shell with failing ReadFile + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.UserHomeDir = func() (string, error) { + return "/home/test", nil + } + mocks.Shims.ReadFile = func(name string) ([]byte, error) { + return nil, fmt.Errorf("read failed") } - // Call CheckTrustedDirectory and expect an error - shell := &DefaultShell{} + // When checking trusted directory err := shell.CheckTrustedDirectory() - if err == nil || !strings.Contains(err.Error(), "error reading trusted file") { - t.Errorf("expected error containing 'error reading trusted file', got %v", err) + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "Error reading trusted file") { + t.Errorf("Expected error about read file, got: %v", err) } }) - t.Run("TrustedFileDoesNotExist", func(t *testing.T) { - // Save the original osReadFile function - originalReadFile := osReadFile - defer func() { osReadFile = originalReadFile }() - - // Override the osReadFile function locally to simulate a non-existent trusted file - osReadFile = func(filename string) ([]byte, error) { + t.Run("TrustedFileNotExist", func(t *testing.T) { + // Given a shell with non-existent trusted file + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.UserHomeDir = func() (string, error) { + return "/home/test", nil + } + mocks.Shims.ReadFile = func(name string) ([]byte, error) { return nil, os.ErrNotExist } - // Call CheckTrustedDirectory and expect an error - shell := &DefaultShell{} + // When checking trusted directory err := shell.CheckTrustedDirectory() - if err == nil || err.Error() != "Trusted file does not exist" { - t.Errorf("expected error 'Trusted file does not exist', got %v", err) + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "Trusted file does not exist") { + t.Errorf("Expected error about file not existing, got: %v", err) } }) - t.Run("CurrentDirNotInTrustedList", func(t *testing.T) { - // Mock the getwd function to return a specific current directory - getwd = func() (string, error) { - return "/mock/current/dir", nil + t.Run("NotTrusted", func(t *testing.T) { + // Given a shell with directory not in trusted list + shell, mocks := setup(t) + projectRoot := "/test/project" + mocks.Shims.Getwd = func() (string, error) { + return projectRoot, nil } - - // Mock the osReadFile function to simulate a trusted file without the current directory - osReadFile = func(filename string) ([]byte, error) { - return []byte("/mock/other/dir\n"), nil + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == filepath.Join(projectRoot, "windsor.yaml") { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.UserHomeDir = func() (string, error) { + return "/home/test", nil + } + mocks.Shims.ReadFile = func(name string) ([]byte, error) { + return []byte("/other/project\n"), nil } - // Execute CheckTrustedDirectory and verify it returns the expected error - shell := &DefaultShell{} + // When checking trusted directory err := shell.CheckTrustedDirectory() - if err == nil || !strings.Contains(err.Error(), "Current directory not in the trusted list") { - t.Errorf("expected error 'Current directory not in the trusted list', got %v", err) + + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "Current directory not in the trusted list") { + t.Errorf("Expected error about directory not trusted, got: %v", err) } }) } -func TestDefaultShell_GetSessionToken(t *testing.T) { - t.Run("GenerateNewToken", func(t *testing.T) { - // Given - ResetSessionToken() - shell := setupShellTest(t) - - // Save original functions to restore later - originalRandRead := randRead - originalOsGetenv := osGetenv - defer func() { - randRead = originalRandRead - osGetenv = originalOsGetenv - }() +func TestShell_WriteResetToken(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } - // Mock osGetenv to return empty string (no env token) - osGetenv = func(key string) string { - return "" + t.Run("Success", func(t *testing.T) { + // Given a shell with mocked operations + shell, mocks := setup(t) + projectRoot := "/test/project" + sessionToken := "test-token" + sessionFilePath := filepath.Join(projectRoot, ".windsor", ".session."+sessionToken) + + // Mock GetProjectRoot + mocks.Shims.Getwd = func() (string, error) { + return projectRoot, nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == filepath.Join(projectRoot, "windsor.yaml") { + return nil, nil + } + return nil, os.ErrNotExist } - // Create a deterministic token generator - randRead = func(b []byte) (n int, err error) { - for i := range b { - b[i] = byte(i % 62) // Use a deterministic pattern + // Mock environment + mocks.Shims.Getenv = func(key string) string { + if key == "WINDSOR_SESSION_TOKEN" { + return sessionToken } - return len(b), nil + return "" } - // When - token, err := shell.GetSessionToken() - - // Then - if err != nil { - t.Errorf("GetSessionToken() error = %v, want nil", err) - } - if len(token) != 7 { - t.Errorf("GetSessionToken() token length = %d, want 7", len(token)) + // Mock file operations + mocks.Shims.MkdirAll = func(path string, perm os.FileMode) error { + expectedPath := filepath.Join(projectRoot, ".windsor") + if path != expectedPath { + t.Errorf("Expected path %s, got %s", expectedPath, path) + } + if perm != 0750 { + t.Errorf("Expected perm 0750, got %o", perm) + } + return nil } - }) - - t.Run("ReuseExistingToken", func(t *testing.T) { - // Given - ResetSessionToken() - shell := setupShellTest(t) - // Save original functions - originalRandRead := randRead - originalOsGetenv := osGetenv - defer func() { - randRead = originalRandRead - osGetenv = originalOsGetenv - }() - - // Mock rand.Read to generate a predictable token - randRead = func(b []byte) (n int, err error) { - for i := range b { - b[i] = byte(i % 62) // Use a deterministic pattern + mocks.Shims.WriteFile = func(name string, data []byte, perm os.FileMode) error { + if name != sessionFilePath { + t.Errorf("Expected path %s, got %s", sessionFilePath, name) } - return len(b), nil + if len(data) != 0 { + t.Errorf("Expected empty data, got %v", data) + } + if perm != 0600 { + t.Errorf("Expected perm 0600, got %o", perm) + } + return nil } - // Generate a first token to cache it - firstToken, _ := shell.GetSessionToken() - - // When getting a second token - secondToken, err := shell.GetSessionToken() + // When writing reset token + path, err := shell.WriteResetToken() - // Then + // Then it should succeed if err != nil { - t.Errorf("GetSessionToken() error = %v, want nil", err) + t.Errorf("Expected no error, got %v", err) } - if firstToken != secondToken { - t.Errorf("GetSessionToken() token = %s, want %s", secondToken, firstToken) + if path != sessionFilePath { + t.Errorf("Expected path %s, got %s", sessionFilePath, path) } }) - t.Run("UseEnvironmentToken", func(t *testing.T) { - // Given - ResetSessionToken() - shell := setupShellTest(t) - - // Save original functions - originalOsGetenv := osGetenv - defer func() { - osGetenv = originalOsGetenv - }() - - // Mock osGetenv to return a specific token - osGetenv = func(key string) string { - if key == "WINDSOR_SESSION_TOKEN" { - return "testtoken" - } + t.Run("NoSessionToken", func(t *testing.T) { + // Given a shell with no session token + shell, mocks := setup(t) + mocks.Shims.Getenv = func(key string) string { return "" } - // When - token, err := shell.GetSessionToken() + // When writing reset token + path, err := shell.WriteResetToken() - // Then + // Then it should return empty path without error if err != nil { - t.Errorf("GetSessionToken() error = %v, want nil", err) + t.Errorf("Expected no error, got %v", err) } - if token != "testtoken" { - t.Errorf("GetSessionToken() token = %s, want testtoken", token) + if path != "" { + t.Errorf("Expected empty path, got %s", path) } }) - t.Run("ErrorGeneratingRandomString", func(t *testing.T) { - // Given - ResetSessionToken() - shell := setupShellTest(t) - - // Save original functions - originalRandRead := randRead - originalOsGetenv := osGetenv - defer func() { - randRead = originalRandRead - osGetenv = originalOsGetenv - }() - - // Mock osGetenv to return empty string (no env token) - osGetenv = func(key string) string { + t.Run("ErrorOnGetProjectRoot", func(t *testing.T) { + // Given a shell with failing GetProjectRoot + shell, mocks := setup(t) + mocks.Shims.Getenv = func(key string) string { + if key == "WINDSOR_SESSION_TOKEN" { + return "test-token" + } return "" } - - // Mock random generation to fail - randRead = func(b []byte) (n int, err error) { - return 0, fmt.Errorf("mock random generation error") + mocks.Shims.Getwd = func() (string, error) { + return "", fmt.Errorf("getwd failed") } - // When - token, err := shell.GetSessionToken() + // When writing reset token + path, err := shell.WriteResetToken() - // Then + // Then it should return an error if err == nil { - t.Error("GetSessionToken() expected error, got nil") - return + t.Error("Expected error, got nil") } - if token != "" { - t.Errorf("GetSessionToken() token = %s, want empty string", token) + if !strings.Contains(err.Error(), "error getting project root") { + t.Errorf("Expected error about project root, got: %v", err) } - expectedErr := "error generating session token: mock random generation error" - if err.Error() != expectedErr { - t.Errorf("GetSessionToken() error = %v, want %v", err, expectedErr) + if path != "" { + t.Errorf("Expected empty path, got %s", path) } }) - t.Run("GenerateRandomString", func(t *testing.T) { - // This test checks that generateRandomString properly generates strings of the right length - shell := setupShellTest(t) - - // Save original randRead - originalRandRead := randRead - defer func() { randRead = originalRandRead }() - - // Make randRead produce deterministic output for testing - randRead = func(b []byte) (n int, err error) { - for i := range b { - b[i] = byte(i % 62) // Use a deterministic pattern + t.Run("ErrorOnMkdirAll", func(t *testing.T) { + // Given a shell with failing MkdirAll + shell, mocks := setup(t) + mocks.Shims.Getenv = func(key string) string { + if key == "WINDSOR_SESSION_TOKEN" { + return "test-token" } - return len(b), nil + return "" } - - testRandomStringGeneration(t, shell, 7) - testRandomStringGeneration(t, shell, 10) - testRandomStringGeneration(t, shell, 15) - }) -} - -// TestDefaultShell_WriteResetToken tests the WriteResetToken method -func TestDefaultShell_WriteResetToken(t *testing.T) { - // Save original functions and environment - originalOsMkdirAll := osMkdirAll - originalOsWriteFile := osWriteFile - originalEnvValue := os.Getenv("WINDSOR_SESSION_TOKEN") - - // Restore original functions and environment after all tests - defer func() { - osMkdirAll = originalOsMkdirAll - osWriteFile = originalOsWriteFile - if originalEnvValue != "" { - os.Setenv("WINDSOR_SESSION_TOKEN", originalEnvValue) - } else { - os.Unsetenv("WINDSOR_SESSION_TOKEN") + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.MkdirAll = func(path string, perm os.FileMode) error { + return fmt.Errorf("mkdir failed") } - }() - - t.Run("NoSessionToken", func(t *testing.T) { - // Given a default shell with no session token in environment - shell := setupShellTest(t) - - // Ensure the environment variable is not set - os.Unsetenv("WINDSOR_SESSION_TOKEN") - // When calling WriteResetToken + // When writing reset token path, err := shell.WriteResetToken() - // Then no error should be returned and path should be empty - if err != nil { - t.Errorf("WriteResetToken() error = %v, want nil", err) + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "error creating .windsor directory") { + t.Errorf("Expected error about mkdir, got: %v", err) } if path != "" { - t.Errorf("WriteResetToken() path = %v, want empty string", path) - } - }) - - t.Run("SuccessfulTokenWrite", func(t *testing.T) { - // Given a default shell with a session token - shell := setupShellTest(t) - - // Set up test data using platform-specific path functions - testProjectRoot := filepath.FromSlash("/test/project/root") - testToken := "test-token-123" - expectedDirPath := filepath.Join(testProjectRoot, ".windsor") - expectedFilePath := filepath.Join(expectedDirPath, SessionTokenPrefix+testToken) - - // For comparison in errors, we'll use ToSlash to show normalized paths - expectedDirPathNormalized := filepath.ToSlash(expectedDirPath) - expectedFilePathNormalized := filepath.ToSlash(expectedFilePath) - - // Track function calls - var mkdirAllCalled bool - var writeFileCalled bool - var mkdirAllPath string - var mkdirAllPerm os.FileMode - var writeFilePath string - var writeFileData []byte - var writeFilePerm os.FileMode - - // Mock OS functions - osMkdirAll = func(path string, perm os.FileMode) error { - mkdirAllCalled = true - mkdirAllPath = path - mkdirAllPerm = perm - return nil + t.Errorf("Expected empty path, got %s", path) } + }) - osWriteFile = func(name string, data []byte, perm os.FileMode) error { - writeFileCalled = true - writeFilePath = name - writeFileData = data - writeFilePerm = perm + t.Run("ErrorOnWriteFile", func(t *testing.T) { + // Given a shell with failing WriteFile + shell, mocks := setup(t) + mocks.Shims.Getenv = func(key string) string { + if key == "WINDSOR_SESSION_TOKEN" { + return "test-token" + } + return "" + } + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.MkdirAll = func(path string, perm os.FileMode) error { return nil } - - // Mock getwd to return our test project root - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return testProjectRoot, nil + mocks.Shims.WriteFile = func(name string, data []byte, perm os.FileMode) error { + return fmt.Errorf("write failed") } - // Set the environment variable - os.Setenv("WINDSOR_SESSION_TOKEN", testToken) - - // When calling WriteResetToken + // When writing reset token path, err := shell.WriteResetToken() - // Then no error should be returned and the path should match expected - if err != nil { - t.Errorf("WriteResetToken() error = %v, want nil", err) + // Then it should return an error + if err == nil { + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "error writing reset token file") { + t.Errorf("Expected error about write file, got: %v", err) } + if path != "" { + t.Errorf("Expected empty path, got %s", path) + } + }) +} - // Use ToSlash to normalize paths for comparison - if filepath.ToSlash(path) != expectedFilePathNormalized { - t.Errorf("WriteResetToken() path = %v, want %v", - filepath.ToSlash(path), expectedFilePathNormalized) +func TestShell_Reset(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + t.Run("Success", func(t *testing.T) { + // Given a shell with mocked operations + shell, mocks := setup(t) + projectRoot := "/test/project" + sessionFiles := []string{ + filepath.Join(projectRoot, ".windsor", ".session.1"), + filepath.Join(projectRoot, ".windsor", ".session.2"), } - // Verify that MkdirAll was called with correct parameters - if !mkdirAllCalled { - t.Error("Expected MkdirAll to be called, but it wasn't") - } else { - if filepath.ToSlash(mkdirAllPath) != expectedDirPathNormalized { - t.Errorf("Expected MkdirAll path %s, got %s", - expectedDirPathNormalized, filepath.ToSlash(mkdirAllPath)) - } - if mkdirAllPerm != 0750 { - t.Errorf("Expected MkdirAll permissions 0750, got %v", mkdirAllPerm) + // Mock GetProjectRoot + mocks.Shims.Getwd = func() (string, error) { + return projectRoot, nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == filepath.Join(projectRoot, "windsor.yaml") { + return nil, nil } + return nil, os.ErrNotExist } - // Verify that WriteFile was called with correct parameters - if !writeFileCalled { - t.Error("Expected WriteFile to be called, but it wasn't") - } else { - if filepath.ToSlash(writeFilePath) != expectedFilePathNormalized { - t.Errorf("Expected WriteFile path %s, got %s", - expectedFilePathNormalized, filepath.ToSlash(writeFilePath)) + // Mock file operations + mocks.Shims.Glob = func(pattern string) ([]string, error) { + expectedPattern := filepath.Join(projectRoot, ".windsor", ".session.*") + if pattern != expectedPattern { + t.Errorf("Expected pattern %s, got %s", expectedPattern, pattern) } - if len(writeFileData) != 0 { - t.Errorf("Expected empty file, got %v bytes", len(writeFileData)) - } - if writeFilePerm != 0600 { - t.Errorf("Expected WriteFile permissions 0600, got %v", writeFilePerm) + return sessionFiles, nil + } + + mocks.Shims.RemoveAll = func(path string) error { + if path != filepath.Join(projectRoot, ".windsor") { + t.Errorf("Expected path %s, got %s", filepath.Join(projectRoot, ".windsor"), path) } + return nil } - }) - t.Run("ErrorGettingProjectRoot", func(t *testing.T) { - // Given a default shell with a session token - shell := setupShellTest(t) + // When resetting + shell.Reset() + }) - // Mock getwd to return an error - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return "", fmt.Errorf("error getting project root") + t.Run("ErrorOnGetProjectRoot", func(t *testing.T) { + // Given a shell with failing GetProjectRoot + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "", fmt.Errorf("getwd failed") } - // Set the environment variable - os.Setenv("WINDSOR_SESSION_TOKEN", "test-token") - - // When calling WriteResetToken - path, err := shell.WriteResetToken() + // When resetting + shell.Reset() + // No error expected since Reset() doesn't return error + }) - // Then an error should be returned and the path should be empty - if err == nil { - t.Error("WriteResetToken() expected error, got nil") + t.Run("ErrorOnGlob", func(t *testing.T) { + // Given a shell with failing Glob + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil } - if !strings.Contains(err.Error(), "error getting project root") { - t.Errorf("WriteResetToken() error = %v, want error containing 'error getting project root'", err) + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist } - if path != "" { - t.Errorf("WriteResetToken() path = %v, want empty string", path) + mocks.Shims.Glob = func(pattern string) ([]string, error) { + return nil, fmt.Errorf("glob failed") } - }) - t.Run("ErrorCreatingDirectory", func(t *testing.T) { - // Given a default shell with a session token - shell := setupShellTest(t) + // When resetting + shell.Reset() + // No error expected since Reset() doesn't return error + }) - // Mock getwd to return a test path - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return "/test/project/root", nil + t.Run("ErrorOnRemoveAll", func(t *testing.T) { + // Given a shell with failing RemoveAll + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil } - - // Mock MkdirAll to return an error - expectedError := fmt.Errorf("error creating directory") - osMkdirAll = func(path string, perm os.FileMode) error { - return expectedError + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist + } + mocks.Shims.Glob = func(pattern string) ([]string, error) { + return []string{"/test/project/.windsor/.session.1"}, nil + } + mocks.Shims.RemoveAll = func(path string) error { + return fmt.Errorf("remove failed") } - // Set the environment variable - os.Setenv("WINDSOR_SESSION_TOKEN", "test-token") - - // When calling WriteResetToken - path, err := shell.WriteResetToken() + // When resetting + shell.Reset() + // No error expected since Reset() doesn't return error + }) - // Then an error should be returned and the path should be empty - if err == nil { - t.Error("WriteResetToken() expected error, got nil") + t.Run("NoSessionFiles", func(t *testing.T) { + // Given a shell with no session files + shell, mocks := setup(t) + mocks.Shims.Getwd = func() (string, error) { + return "/test/project", nil + } + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == "/test/project/windsor.yaml" { + return nil, nil + } + return nil, os.ErrNotExist } - if !strings.Contains(err.Error(), expectedError.Error()) { - t.Errorf("WriteResetToken() error = %v, want error containing %v", err, expectedError) + mocks.Shims.Glob = func(pattern string) ([]string, error) { + return []string{}, nil } - if path != "" { - t.Errorf("WriteResetToken() path = %v, want empty string", path) + + // When resetting + shell.Reset() + // No error expected since Reset() doesn't return error + }) + + t.Run("EnvironmentAndAliasReset", func(t *testing.T) { + // Given a shell with managed environment variables and aliases + shell, mocks := setup(t) + + // Mock environment variables + mocks.Shims.Getenv = func(key string) string { + switch key { + case "WINDSOR_MANAGED_ENV": + return "ENV1, ENV2, ENV3" + case "WINDSOR_MANAGED_ALIAS": + return "ALIAS1, ALIAS2, ALIAS3" + default: + return "" + } } + + // When resetting + shell.Reset() + + // Then environment variables and aliases should be unset + // Note: We can't directly verify the unset operations since they're system calls + // The test coverage will show that these code paths were executed }) - t.Run("ErrorWritingFile", func(t *testing.T) { - // Given a default shell with a session token - shell := setupShellTest(t) + t.Run("EmptyEnvironmentAndAlias", func(t *testing.T) { + // Given a shell with empty managed environment variables and aliases + shell, mocks := setup(t) - // Mock getwd to return a test path - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - getwd = func() (string, error) { - return "/test/project/root", nil + // Mock empty environment variables + mocks.Shims.Getenv = func(key string) string { + return "" } - // Mock MkdirAll to succeed but WriteFile to fail - osMkdirAll = func(path string, perm os.FileMode) error { - return nil - } + // When resetting + shell.Reset() + + // Then no environment variables or aliases should be unset + // Note: We can't directly verify the unset operations since they're system calls + // The test coverage will show that these code paths were executed + }) +} + +func TestShell_ResetSessionToken(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + t.Run("Success", func(t *testing.T) { + // Given a shell with mocked operations + shell, mocks := setup(t) + expectedToken := "test-token" - expectedError := fmt.Errorf("error writing file") - osWriteFile = func(name string, data []byte, perm os.FileMode) error { - return expectedError + // Mock environment variable + mocks.Shims.Getenv = func(key string) string { + if key == "WINDSOR_SESSION_TOKEN" { + return expectedToken + } + return "" } - // Set the environment variable - os.Setenv("WINDSOR_SESSION_TOKEN", "test-token") + // Mock random generation to return predictable bytes + mocks.Shims.RandRead = func(b []byte) (n int, err error) { + // Fill with bytes that will map to "new-test-token" + for i := range b { + b[i] = byte(i % 62) + } + return len(b), nil + } - // When calling WriteResetToken - path, err := shell.WriteResetToken() + // When getting session token + token, err := shell.GetSessionToken() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if token != expectedToken { + t.Errorf("Expected token %q, got %q", expectedToken, token) + } - // Then an error should be returned and the path should be empty - if err == nil { - t.Error("WriteResetToken() expected error, got nil") + // When resetting session token + shell.ResetSessionToken() + mocks.Shims.Getenv = func(key string) string { + return "" // Simulate environment variable being unset } - if !strings.Contains(err.Error(), expectedError.Error()) { - t.Errorf("WriteResetToken() error = %v, want error containing %v", err, expectedError) + + // Then getting session token should return a new token + newToken, err := shell.GetSessionToken() + if err != nil { + t.Errorf("Expected no error, got %v", err) } - if path != "" { - t.Errorf("WriteResetToken() path = %v, want empty string", path) + if len(newToken) != 7 { + t.Errorf("Expected new token length to be 7, got %d", len(newToken)) } }) } -// TestDefaultShell_AddCurrentDirToTrustedFile tests the AddCurrentDirToTrustedFile method -func TestDefaultShell_AddCurrentDirToTrustedFile(t *testing.T) { - // Save original functions and environment - originalGetwd := getwd - originalOsUserHomeDir := osUserHomeDir - originalOsReadFile := osReadFile - originalOsMkdirAll := osMkdirAll - originalOsWriteFile := osWriteFile - - // Restore original functions after all tests - defer func() { - getwd = originalGetwd - osUserHomeDir = originalOsUserHomeDir - osReadFile = originalOsReadFile - osMkdirAll = originalOsMkdirAll - osWriteFile = originalOsWriteFile - }() +func TestShell_ExecProgress(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } t.Run("Success", func(t *testing.T) { - // Given a default shell - shell := setupShellTest(t) + // Given a shell with mocked operations + shell, mocks := setup(t) + + // Set expected output + expectedOutput := "test output\n" + message := "Test Progress" - // Mock required functions - getwd = func() (string, error) { - return "/mock/current/dir", nil + // Mock command execution + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("test") } - osUserHomeDir = func() (string, error) { - return "/mock/home/dir", nil + // Mock stdout pipe to write expected output + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Write([]byte(expectedOutput)) + w.Close() + }() + return r, nil } - osReadFile = func(filename string) ([]byte, error) { - return []byte{}, nil // Empty trusted file + // Mock stderr pipe + mocks.Shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + w.Close() + return r, nil } - osMkdirAll = func(path string, perm os.FileMode) error { - expectedPath := "/mock/home/dir/.config/windsor" - if filepath.ToSlash(path) != expectedPath { - t.Errorf("Expected MkdirAll path %s, got %s", expectedPath, path) - } - if perm != 0750 { - t.Errorf("Expected MkdirAll permissions 0750, got %v", perm) - } + // Mock command start and wait + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { return nil } - - var capturedData []byte - osWriteFile = func(filename string, data []byte, perm os.FileMode) error { - expectedPath := "/mock/home/dir/.config/windsor/.trusted" - if filepath.ToSlash(filename) != expectedPath { - t.Errorf("Expected WriteFile path %s, got %s", expectedPath, filename) - } - capturedData = data - if perm != 0600 { - t.Errorf("Expected WriteFile permissions 0600, got %v", perm) - } + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { return nil } - // When adding the current directory to the trusted file - err := shell.AddCurrentDirToTrustedFile() + // When executing command with progress + output, err := shell.ExecProgress(message, "test") - // Then no error should be returned + // Then it should succeed and return expected output if err != nil { - t.Errorf("AddCurrentDirToTrustedFile() error = %v, want nil", err) + t.Errorf("Expected no error, got %v", err) } - - // Verify that the current directory was added to the trusted file - expectedData := "/mock/current/dir\n" - if string(capturedData) != expectedData { - t.Errorf("Expected data %q, got %q", expectedData, string(capturedData)) + if output != expectedOutput { + t.Errorf("Expected output %q, got %q", expectedOutput, output) } }) - t.Run("SuccessAlreadyTrusted", func(t *testing.T) { - // Given a default shell - shell := setupShellTest(t) + t.Run("VerboseMode", func(t *testing.T) { + // Given a shell with verbose mode enabled + shell, mocks := setup(t) + shell.SetVerbosity(true) - // Mock required functions - getwd = func() (string, error) { - return "/mock/current/dir", nil - } + // Set expected output and message + expectedOutput := "test output\n" + message := "Test Progress" - osUserHomeDir = func() (string, error) { - return "/mock/home/dir", nil + // Mock command execution + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + cmd := exec.Command("test") + cmd.Stdout = new(bytes.Buffer) + cmd.Stderr = new(bytes.Buffer) + return cmd } - osReadFile = func(filename string) ([]byte, error) { - return []byte("/mock/current/dir\n"), nil // Directory already in trusted file + // Mock command start to write output + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + if w, ok := cmd.Stdout.(io.Writer); ok { + w.Write([]byte(expectedOutput)) + } + return nil } - // Track if WriteFile is called (it shouldn't be) - writeFileCalled := false - osWriteFile = func(filename string, data []byte, perm os.FileMode) error { - writeFileCalled = true + // Mock command wait + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { return nil } - // When adding the current directory to the trusted file - err := shell.AddCurrentDirToTrustedFile() + // When executing command with progress + output, err := shell.ExecProgress(message, "test") - // Then no error should be returned + // Then it should succeed and return expected output if err != nil { - t.Errorf("AddCurrentDirToTrustedFile() error = %v, want nil", err) + t.Errorf("Expected no error, got %v", err) } - - // Verify that WriteFile was not called - if writeFileCalled { - t.Error("Expected WriteFile not to be called, but it was") + if output != expectedOutput { + t.Errorf("Expected output %q, got %q", expectedOutput, output) } }) - t.Run("ErrorGettingProjectRoot", func(t *testing.T) { - // Given a default shell - shell := setupShellTest(t) + t.Run("CommandCreationFailure", func(t *testing.T) { + // Given a shell with failing command creation + shell, mocks := setup(t) - // Mock getwd to return an error - getwd = func() (string, error) { - return "", fmt.Errorf("error getting project root directory") + // Mock command execution to fail + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return nil } - // When adding the current directory to the trusted file - err := shell.AddCurrentDirToTrustedFile() + // When executing command with progress + output, err := shell.ExecProgress("Test Progress", "test") - // Then an error should be returned + // Then it should return an error if err == nil { - t.Error("AddCurrentDirToTrustedFile() expected error, got nil") + t.Error("Expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to create command") { + t.Errorf("Expected error to contain 'failed to create command', got %v", err) } - expectedError := "Error getting project root directory: error getting project root directory" - if err.Error() != expectedError { - t.Errorf("AddCurrentDirToTrustedFile() error = %q, want %q", err.Error(), expectedError) + if output != "" { + t.Errorf("Expected empty output, got %q", output) } }) - t.Run("ErrorGettingUserHomeDir", func(t *testing.T) { - // Given a default shell - shell := setupShellTest(t) + t.Run("StdoutPipeFailure", func(t *testing.T) { + // Given a shell with failing stdout pipe + shell, mocks := setup(t) - // Mock getwd to succeed but osUserHomeDir to fail - getwd = func() (string, error) { - return "/mock/current/dir", nil + // Mock command execution + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("test") } - osUserHomeDir = func() (string, error) { - return "", fmt.Errorf("error getting user home directory") + // Mock stdout pipe to fail + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + return nil, fmt.Errorf("stdout pipe error") } - // When adding the current directory to the trusted file - err := shell.AddCurrentDirToTrustedFile() + // When executing command with progress + output, err := shell.ExecProgress("Test Progress", "test") - // Then an error should be returned + // Then it should return an error if err == nil { - t.Error("AddCurrentDirToTrustedFile() expected error, got nil") + t.Error("Expected error, got nil") } - expectedError := "Error getting user home directory: error getting user home directory" - if err.Error() != expectedError { - t.Errorf("AddCurrentDirToTrustedFile() error = %q, want %q", err.Error(), expectedError) + if !strings.Contains(err.Error(), "stdout pipe error") { + t.Errorf("Expected error to contain 'stdout pipe error', got %v", err) + } + if output != "" { + t.Errorf("Expected empty output, got %q", output) } }) - t.Run("ErrorCreatingDirectories", func(t *testing.T) { - // Given a default shell - shell := setupShellTest(t) + t.Run("StderrPipeFailure", func(t *testing.T) { + // Given a shell with failing stderr pipe + shell, mocks := setup(t) - // Mock getwd and osUserHomeDir to succeed but osMkdirAll to fail - getwd = func() (string, error) { - return "/mock/current/dir", nil + // Mock command execution + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("test") } - osUserHomeDir = func() (string, error) { - return "/mock/home/dir", nil + // Mock stdout pipe to succeed + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Write([]byte("test output\n")) + w.Close() + }() + return r, nil } - expectedError := fmt.Errorf("error creating directories") - osMkdirAll = func(path string, perm os.FileMode) error { - return expectedError + // Mock stderr pipe to fail + mocks.Shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + return nil, fmt.Errorf("stderr pipe error") } - // When adding the current directory to the trusted file - err := shell.AddCurrentDirToTrustedFile() + // When executing command with progress + output, err := shell.ExecProgress("Test Progress", "test") - // Then an error should be returned + // Then it should return an error if err == nil { - t.Error("AddCurrentDirToTrustedFile() expected error, got nil") + t.Error("Expected error, got nil") } - - expectedErrorMsg := "Error creating directories for trusted file: error creating directories" - if err.Error() != expectedErrorMsg { - t.Errorf("AddCurrentDirToTrustedFile() error = %q, want %q", err.Error(), expectedErrorMsg) + if !strings.Contains(err.Error(), "stderr pipe error") { + t.Errorf("Expected error to contain 'stderr pipe error', got %v", err) + } + if output != "" { + t.Errorf("Expected empty output, got %q", output) } }) - t.Run("ErrorReadingTrustedFile", func(t *testing.T) { - // Given a default shell - shell := setupShellTest(t) + t.Run("CommandStartFailure", func(t *testing.T) { + // Given a shell with failing command start + shell, mocks := setup(t) - // Mock getwd, osUserHomeDir, and osMkdirAll to succeed but osReadFile to fail - getwd = func() (string, error) { - return "/mock/current/dir", nil + // Mock command execution + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("test") } - osUserHomeDir = func() (string, error) { - return "/mock/home/dir", nil + // Mock stdout pipe + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Write([]byte("test output\n")) + w.Close() + }() + return r, nil } - osMkdirAll = func(path string, perm os.FileMode) error { - return nil + // Mock stderr pipe + mocks.Shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + w.Close() + return r, nil } - expectedError := fmt.Errorf("error reading trusted file") - osReadFile = func(filename string) ([]byte, error) { - return nil, expectedError + // Mock command start to fail + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + return fmt.Errorf("command start error") } - // When adding the current directory to the trusted file - err := shell.AddCurrentDirToTrustedFile() + // When executing command with progress + output, err := shell.ExecProgress("Test Progress", "test") - // Then an error should be returned + // Then it should return an error if err == nil { - t.Error("AddCurrentDirToTrustedFile() expected error, got nil") + t.Error("Expected error, got nil") } - - expectedErrorMsg := "Error reading trusted file: error reading trusted file" - if err.Error() != expectedErrorMsg { - t.Errorf("AddCurrentDirToTrustedFile() error = %q, want %q", err.Error(), expectedErrorMsg) + if !strings.Contains(err.Error(), "command start error") { + t.Errorf("Expected error to contain 'command start error', got %v", err) + } + if output != "" { + t.Errorf("Expected empty output, got %q", output) } }) - t.Run("ErrorReadingNonExistentTrustedFile", func(t *testing.T) { - // Given a default shell - shell := setupShellTest(t) + t.Run("CommandExecutionFailure", func(t *testing.T) { + shell, mocks := setup(t) - // Mock getwd, osUserHomeDir, and osMkdirAll to succeed but osReadFile to return file not exist error - getwd = func() (string, error) { - return "/mock/current/dir", nil - } + expectedOutput := "" + message := "Test Progress" - osUserHomeDir = func() (string, error) { - return "/mock/home/dir", nil + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("test") } - osMkdirAll = func(path string, perm os.FileMode) error { - return nil + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Close() + }() + return r, nil } - osReadFile = func(filename string) ([]byte, error) { - return nil, os.ErrNotExist + mocks.Shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Close() + }() + return r, nil } - var capturedData []byte - osWriteFile = func(filename string, data []byte, perm os.FileMode) error { - capturedData = data + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { return nil } - // When adding the current directory to the trusted file - err := shell.AddCurrentDirToTrustedFile() - - // Then no error should be returned - if err != nil { - t.Errorf("AddCurrentDirToTrustedFile() error = %v, want nil", err) + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { + return fmt.Errorf("command execution error") } - // Verify that the current directory was added to the trusted file - expectedData := "/mock/current/dir\n" - if string(capturedData) != expectedData { - t.Errorf("Expected data %q, got %q", expectedData, string(capturedData)) + output, err := shell.ExecProgress(message, "test") + if err == nil { + t.Error("Expected error, got nil") } - }) - - t.Run("ErrorWritingToTrustedFile", func(t *testing.T) { - // Given a default shell - shell := setupShellTest(t) - - // Mock getwd, osUserHomeDir, osMkdirAll, and osReadFile to succeed but osWriteFile to fail - getwd = func() (string, error) { - return "/mock/current/dir", nil + if !strings.Contains(err.Error(), "command execution error") { + t.Errorf("Expected error to contain 'command execution error', got %v", err) } - - osUserHomeDir = func() (string, error) { - return "/mock/home/dir", nil + if output != expectedOutput { + t.Errorf("Expected output %q, got %q", expectedOutput, output) } + }) - osMkdirAll = func(path string, perm os.FileMode) error { - return nil - } + t.Run("StdoutScannerError", func(t *testing.T) { + shell, mocks := setup(t) + message := "Test Progress" - osReadFile = func(filename string) ([]byte, error) { - return []byte{}, nil + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("test") } - expectedError := fmt.Errorf("error writing to trusted file") - osWriteFile = func(filename string, data []byte, perm os.FileMode) error { - return expectedError + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Close() + }() + return r, nil } - // When adding the current directory to the trusted file - err := shell.AddCurrentDirToTrustedFile() - - // Then an error should be returned - if err == nil { - t.Error("AddCurrentDirToTrustedFile() expected error, got nil") + mocks.Shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Close() + }() + return r, nil } - expectedErrorMsg := "Error writing to trusted file: error writing to trusted file" - if err.Error() != expectedErrorMsg { - t.Errorf("AddCurrentDirToTrustedFile() error = %q, want %q", err.Error(), expectedErrorMsg) + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + return nil } - }) -} -// TestMockShell_GetSessionToken tests the MockShell's GetSessionToken method -func TestMockShell_GetSessionToken(t *testing.T) { - t.Run("Success", func(t *testing.T) { - // Given - injector := di.NewInjector() - mockShell := NewMockShell(injector) - - expectedToken := "mock-token" - mockShell.GetSessionTokenFunc = func() (string, error) { - return expectedToken, nil + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { + return nil } - // When - token, err := mockShell.GetSessionToken() - - // Then - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - if token != expectedToken { - t.Errorf("Expected token %s, got %s", expectedToken, token) + mocks.Shims.ScannerErr = func(scanner *bufio.Scanner) error { + return fmt.Errorf("stdout scanner error") } - }) - - t.Run("Error", func(t *testing.T) { - // Given - injector := di.NewInjector() - mockShell := NewMockShell(injector) - expectedError := "custom error" - mockShell.GetSessionTokenFunc = func() (string, error) { - return "", fmt.Errorf(expectedError) + output, err := shell.ExecProgress(message, "test") + if output != "" { + t.Errorf("Expected empty output, got %q", output) } - - // When - token, err := mockShell.GetSessionToken() - - // Then if err == nil { t.Error("Expected error, got nil") } - if err.Error() != expectedError { - t.Errorf("Expected error %s, got %s", expectedError, err.Error()) - } - if token != "" { - t.Errorf("Expected empty token, got %s", token) + if !strings.Contains(err.Error(), "stdout scanner error") { + t.Errorf("Expected error to contain 'stdout scanner error', got %v", err) } }) - t.Run("NotImplemented", func(t *testing.T) { - // Given - injector := di.NewInjector() - mockShell := NewMockShell(injector) - - // Don't set GetSessionTokenFunc - - // When - token, err := mockShell.GetSessionToken() + t.Run("StderrScannerError", func(t *testing.T) { + shell, mocks := setup(t) + message := "Test Progress" - // Then - if err == nil { - t.Error("Expected error, got nil") - } - expectedError := "GetSessionToken not implemented" - if err.Error() != expectedError { - t.Errorf("Expected error %s, got %s", expectedError, err.Error()) - } - if token != "" { - t.Errorf("Expected empty token, got %s", token) + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("test") } - }) -} - -// TestDefaultShell_CheckResetFlags tests the CheckResetFlags method of DefaultShell -func TestDefaultShell_CheckResetFlags(t *testing.T) { - // Save original environment variable and restore it after all tests - origEnv := os.Getenv("WINDSOR_SESSION_TOKEN") - defer func() { os.Setenv("WINDSOR_SESSION_TOKEN", origEnv) }() - - // Save original session token and restore it after all tests - origSessionToken := sessionToken - defer func() { sessionToken = origSessionToken }() - t.Run("NoSessionToken", func(t *testing.T) { - // Given - shell := setupShellTest(t) - ResetSessionToken() - - // When no session token is set in the environment - osSetenv("WINDSOR_SESSION_TOKEN", "") - result, err := shell.CheckResetFlags() - - // Then - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - if result { - t.Errorf("Expected result to be false when no session token exists") + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Close() + }() + return r, nil } - }) - - t.Run("ErrorGettingProjectRoot", func(t *testing.T) { - // Given - shell := setupShellTest(t) - ResetSessionToken() - // Save original getwd function - originalGetwd := getwd - defer func() { getwd = originalGetwd }() - - // Mock the getwd function to return an error - getwd = func() (string, error) { - return "", fmt.Errorf("error getting working directory") + mocks.Shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Close() + }() + return r, nil } - // Set a test session token - osSetenv("WINDSOR_SESSION_TOKEN", "test-token") - - // When - result, err := shell.CheckResetFlags() - - // Then - if err == nil { - t.Errorf("Expected error, got nil") - } - if !strings.Contains(err.Error(), "error getting project root") { - t.Errorf("Expected error to contain 'error getting project root', got: %v", err) - } - if result { - t.Errorf("Expected result to be false when error occurs") + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + return nil } - }) - - t.Run("WindsorDirectoryDoesNotExist", func(t *testing.T) { - // Given - shell := setupShellTest(t) - ResetSessionToken() - - // Save original functions - originalGetwd := getwd - originalOsStat := osStat - defer func() { - getwd = originalGetwd - osStat = originalOsStat - }() - // Mock the getwd function - getwd = func() (string, error) { - return "/test/project", nil + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { + return nil } - // Mock the osStat function to simulate .windsor directory not existing - osStat = func(name string) (os.FileInfo, error) { - if strings.Contains(name, "windsor.yaml") || strings.Contains(name, "windsor.yml") { - return nil, os.ErrNotExist - } - if strings.Contains(name, ".windsor") { - return nil, os.ErrNotExist - } - return nil, os.ErrNotExist + mocks.Shims.ScannerErr = func(scanner *bufio.Scanner) error { + return fmt.Errorf("stderr scanner error") } - // Set a test session token - osSetenv("WINDSOR_SESSION_TOKEN", "test-token") - - // When - result, err := shell.CheckResetFlags() - - // Then - if err != nil { - t.Errorf("Expected no error, got %v", err) + output, err := shell.ExecProgress(message, "test") + if output != "" { + t.Errorf("Expected empty output, got %q", output) + } + if err == nil { + t.Error("Expected error, got nil") } - if result { - t.Errorf("Expected result to be false when .windsor directory doesn't exist") + if !strings.Contains(err.Error(), "stderr scanner error") { + t.Errorf("Expected error to contain 'stderr scanner error', got %v", err) } }) - t.Run("ResetFileExists", func(t *testing.T) { - // Given - shell := setupShellTest(t) - ResetSessionToken() + t.Run("EmptyOutput", func(t *testing.T) { + shell, mocks := setup(t) - // Save original functions - originalGetwd := getwd - originalOsStat := osStat - defer func() { - getwd = originalGetwd - osStat = originalOsStat - }() + expectedOutput := "" + message := "Test Progress" - // Mock the getwd function - getwd = func() (string, error) { - return "/test/project", nil + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("test") } - // Mock the osStat function to simulate reset file existing - osStat = func(name string) (os.FileInfo, error) { - if strings.Contains(name, "windsor.yaml") { - return nil, nil // windsor.yaml exists - } - if strings.Contains(name, ".windsor") { - return nil, nil // .windsor directory exists - } - if strings.Contains(name, ".session.test-token") { - return nil, nil // Reset file exists - } - return nil, os.ErrNotExist + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Close() + }() + return r, nil } - // Set a test session token - osSetenv("WINDSOR_SESSION_TOKEN", "test-token") - - // When - result, err := shell.CheckResetFlags() - - // Then - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - if !result { - t.Errorf("Expected result to be true when reset file exists") + mocks.Shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Close() + }() + return r, nil } - }) - t.Run("ResetFileDoesNotExist", func(t *testing.T) { - // Given - shell := setupShellTest(t) - ResetSessionToken() - - // Save original functions - originalGetwd := getwd - originalOsStat := osStat - defer func() { - getwd = originalGetwd - osStat = originalOsStat - }() - - // Mock the getwd function - getwd = func() (string, error) { - return "/test/project", nil + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + return nil } - // Mock the osStat function to simulate reset file not existing - osStat = func(name string) (os.FileInfo, error) { - if strings.Contains(name, "windsor.yaml") { - return nil, nil // windsor.yaml exists - } - if strings.Contains(name, ".windsor") && !strings.Contains(name, ".session.") { - return nil, nil // .windsor directory exists - } - // Reset file does not exist - return nil, os.ErrNotExist + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { + return nil } - // Set a test session token - osSetenv("WINDSOR_SESSION_TOKEN", "test-token") - - // When - result, err := shell.CheckResetFlags() - - // Then + output, err := shell.ExecProgress(message, "test") if err != nil { t.Errorf("Expected no error, got %v", err) } - if result { - t.Errorf("Expected result to be false when reset file doesn't exist") + if output != expectedOutput { + t.Errorf("Expected output %q, got %q", expectedOutput, output) } }) - t.Run("ErrorFindingSessionFiles", func(t *testing.T) { - // Given - shell := setupShellTest(t) - ResetSessionToken() + t.Run("MultiLineOutput", func(t *testing.T) { + shell, mocks := setup(t) - // Save original functions - originalGetwd := getwd - originalOsStat := osStat - originalFilepathGlob := filepathGlob - defer func() { - getwd = originalGetwd - osStat = originalOsStat - filepathGlob = originalFilepathGlob - }() + expectedOutput := "line 1\nline 2\nline 3\n" + message := "Test Progress" - // Mock the getwd function - getwd = func() (string, error) { - return "/test/project", nil + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("test") } - // Mock osStat to simulate .windsor dir exists - osStat = func(name string) (os.FileInfo, error) { - if strings.Contains(name, "windsor.yaml") { - return nil, nil // windsor.yaml exists - } - return nil, os.ErrNotExist + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Write([]byte(expectedOutput)) + w.Close() + }() + return r, nil } - // Mock filepath.Glob to return an error - filepathGlob = func(pattern string) ([]string, error) { - return nil, fmt.Errorf("mock error finding session files") + mocks.Shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + r, w := io.Pipe() + go func() { + w.Close() + }() + return r, nil } - // Set a test session token - osSetenv("WINDSOR_SESSION_TOKEN", "test-token") - - // When - result, err := shell.CheckResetFlags() + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + return nil + } - // Then - if err == nil { - t.Errorf("Expected error, got nil") + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { + return nil } - if !strings.Contains(err.Error(), "error finding session files") { - t.Errorf("Expected error to contain 'error finding session files', got: %v", err) + + output, err := shell.ExecProgress(message, "test") + if err != nil { + t.Errorf("Expected no error, got %v", err) } - if result { - t.Errorf("Expected result to be false when error occurs") + if output != expectedOutput { + t.Errorf("Expected output %q, got %q", expectedOutput, output) } }) - t.Run("ErrorRemovingSessionFiles", func(t *testing.T) { - // Given - shell := setupShellTest(t) - ResetSessionToken() - - // Save original functions - originalGetwd := getwd - originalOsStat := osStat - originalFilepathGlob := filepathGlob - originalOsRemoveAll := osRemoveAll - defer func() { - getwd = originalGetwd - osStat = originalOsStat - filepathGlob = originalFilepathGlob - osRemoveAll = originalOsRemoveAll - }() + t.Run("ScannerBehavior", func(t *testing.T) { + shell, mocks := setup(t) + message := "Test Progress" + expectedOutput := "test output\n" - // Mock the getwd function - getwd = func() (string, error) { - return "/test/project", nil + // Mock command execution + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + cmd := exec.Command("test") + cmd.Stdout = new(bytes.Buffer) + cmd.Stderr = new(bytes.Buffer) + return cmd } - // Mock osStat to simulate .windsor dir exists - osStat = func(name string) (os.FileInfo, error) { - if strings.Contains(name, "windsor.yaml") || strings.Contains(name, ".windsor") { - return nil, nil // both config file and directory exist + // Mock command start to write output + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + if w, ok := cmd.Stdout.(io.Writer); ok { + w.Write([]byte(expectedOutput)) } - return nil, os.ErrNotExist - } - - // Mock filepath.Glob to return some session files - filepathGlob = func(pattern string) ([]string, error) { - return []string{"/test/project/.windsor/.session.test-token"}, nil - } - - // Mock osRemoveAll to return an error - osRemoveAll = func(path string) error { - return fmt.Errorf("mock error removing session file") - } - - // Set a test session token - osSetenv("WINDSOR_SESSION_TOKEN", "test-token") - - // When - result, err := shell.CheckResetFlags() - - // Then - if err == nil { - t.Errorf("Expected error, got nil") - } - if !strings.Contains(err.Error(), "error removing session file") { - t.Errorf("Expected error to contain 'error removing session file', got: %v", err) - } - if result { - t.Errorf("Expected result to be false when error occurs") + return nil } - }) -} -// TestMockShell_CheckReset tests the MockShell's CheckReset method -func TestMockShell_CheckReset(t *testing.T) { - t.Run("Success", func(t *testing.T) { - // Given - injector := di.NewInjector() - mockShell := NewMockShell(injector) - - // Configure the mock to return a success response - mockShell.CheckResetFlagsFunc = func() (bool, error) { - return true, nil + // Mock command wait + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { + return nil } - // When - result, err := mockShell.CheckResetFlags() + // When executing command with progress + output, err := shell.ExecProgress(message, "test") - // Then + // Then it should succeed and return expected output if err != nil { t.Errorf("Expected no error, got %v", err) } - if !result { - t.Errorf("Expected result to be true, got false") + if output != expectedOutput { + t.Errorf("Expected output %q, got %q", expectedOutput, output) } }) - t.Run("Error", func(t *testing.T) { - // Given - injector := di.NewInjector() - mockShell := NewMockShell(injector) + t.Run("EmptyStderrOnFailure", func(t *testing.T) { + // Given a shell with mocked operations + shell, mocks := setup(t) - // Configure the mock to return an error - expectedError := fmt.Errorf("custom error") - mockShell.CheckResetFlagsFunc = func() (bool, error) { - return false, expectedError + // Mock command execution + mocks.Shims.Command = func(name string, args ...string) *exec.Cmd { + return exec.Command("test-command", "arg1", "arg2") } - // When - result, err := mockShell.CheckResetFlags() - - // Then - if err == nil || err.Error() != expectedError.Error() { - t.Errorf("Expected error %v, got %v", expectedError, err) + // Mock stdout pipe to return test output + mocks.Shims.StdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader("test output\n")), nil } - if result { - t.Errorf("Expected result to be false, got true") - } - }) - t.Run("DefaultImplementation", func(t *testing.T) { - // Given - injector := di.NewInjector() - mockShell := NewMockShell(injector) - - // When CheckResetFunc isn't set, the default implementation should be used - result, err := mockShell.CheckResetFlags() + // Mock stderr pipe to return empty string + mocks.Shims.StderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader("")), nil + } - // Then - if err != nil { - t.Errorf("Expected no error, got %v", err) + // Mock command start and wait + mocks.Shims.CmdStart = func(cmd *exec.Cmd) error { + return nil } - if result { - t.Errorf("Expected result to be false by default, got true") + mocks.Shims.CmdWait = func(cmd *exec.Cmd) error { + return fmt.Errorf("command failed") } - }) -} - -// TestDefaultShell_Reset tests the Reset method of the DefaultShell struct -func TestDefaultShell_Reset(t *testing.T) { - t.Run("ResetWithNoEnvVars", func(t *testing.T) { - // Given a default shell - shell := setupShellTest(t) - - // Make sure environment variables are not set - os.Unsetenv("WINDSOR_MANAGED_ENV") - os.Unsetenv("WINDSOR_MANAGED_ALIAS") - - // Set up the test - origStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - // When calling Reset - shell.Reset() - - // Capture and restore stdout - w.Close() - var buf bytes.Buffer - io.Copy(&buf, r) - os.Stdout = origStdout + // When executing command with progress + output, err := shell.ExecProgress("test progress", "test-command", "arg1", "arg2") - // Then no unset commands should be issued - output := buf.String() - if strings.Contains(output, "unset") { - t.Errorf("Expected no unset commands, but got: %s", output) + // Then it should return error and output + if err == nil { + t.Error("Expected error, got nil") } - }) - - t.Run("ResetWithEnvironmentVariables", func(t *testing.T) { - // Given a default shell - shell := setupShellTest(t) - - // Set environment variables - os.Setenv("WINDSOR_MANAGED_ENV", "ENV1,ENV2, ENV3") - os.Setenv("WINDSOR_MANAGED_ALIAS", "alias1,alias2, alias3") - defer func() { - os.Unsetenv("WINDSOR_MANAGED_ENV") - os.Unsetenv("WINDSOR_MANAGED_ALIAS") - }() - - // Set up the test - origStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - // When calling Reset - shell.Reset() - - // Capture and restore stdout - w.Close() - var buf bytes.Buffer - io.Copy(&buf, r) - os.Stdout = origStdout - - // Then unset commands should be issued - output := buf.String() - // Check for unset ENV1 ENV2 ENV3 (on Unix) or Remove-Item ENV:ENV1, etc on Windows - if runtime.GOOS == "windows" { - if !strings.Contains(output, "Remove-Item Env:ENV1") { - t.Errorf("Expected Remove-Item Env:ENV1 command, but got: %s", output) - } - if !strings.Contains(output, "Remove-Item Env:ENV2") { - t.Errorf("Expected Remove-Item Env:ENV2 command, but got: %s", output) - } - if !strings.Contains(output, "Remove-Item Env:ENV3") { - t.Errorf("Expected Remove-Item Env:ENV3 command, but got: %s", output) - } - // And unalias for aliases - if !strings.Contains(output, "Remove-Item Alias:alias1") { - t.Errorf("Expected Remove-Item Alias:alias1 command, but got: %s", output) - } - } else { - // For Unix - if !strings.Contains(output, "unset ENV1 ENV2 ENV3") { - t.Errorf("Expected unset ENV1 ENV2 ENV3 command, but got: %s", output) - } - // And unalias for aliases - if !strings.Contains(output, "unalias alias1") { - t.Errorf("Expected unalias alias1 command, but got: %s", output) - } - if !strings.Contains(output, "unalias alias2") { - t.Errorf("Expected unalias alias2 command, but got: %s", output) - } - if !strings.Contains(output, "unalias alias3") { - t.Errorf("Expected unalias alias3 command, but got: %s", output) - } + if output != "test output\n" { + t.Errorf("Expected output 'test output', got '%s'", output) } }) } diff --git a/pkg/shell/shims.go b/pkg/shell/shims.go index f72e059aa..437f105ea 100644 --- a/pkg/shell/shims.go +++ b/pkg/shell/shims.go @@ -1,108 +1,171 @@ +// The shims package is a system call abstraction layer +// It provides mockable wrappers around system and runtime functions +// It serves as a testing aid by allowing system calls to be intercepted +// It enables dependency injection and test isolation for system-level operations + package shell import ( "bufio" + "crypto/rand" "io" - "math/rand" "os" "os/exec" "path/filepath" "text/template" ) -// getwd is a variable that points to os.Getwd, allowing it to be overridden in tests -var getwd = os.Getwd - -// execCommand is a variable that points to exec.Command, allowing it to be overridden in tests -var execCommand = osExecCommand - -// osExecCommand is a wrapper around exec.Command to allow it to be overridden in tests -func osExecCommand(name string, arg ...string) *exec.Cmd { - return exec.Command(name, arg...) -} - -// cmdRun is a variable that points to cmd.Run, allowing it to be overridden in tests -var cmdRun = func(cmd *exec.Cmd) error { - return cmd.Run() -} - -// cmdStart is a variable that points to cmd.Start, allowing it to be overridden in tests -var cmdStart = func(cmd *exec.Cmd) error { - return cmd.Start() -} - -// osUserHomeDir is a variable that points to os.UserHomeDir, allowing it to be overridden in tests -var osUserHomeDir = os.UserHomeDir - -// osStat is a variable that points to os.Stat, allowing it to be overridden in tests -var osStat = os.Stat - -// osOpenFile is a variable that points to os.OpenFile, allowing it to be overridden in tests -var osOpenFile = os.OpenFile - -// osReadFile is a variable that points to os.ReadFile, allowing it to be overridden in tests -var osReadFile = os.ReadFile - -// osWriteFile is a variable that points to os.WriteFile, allowing it to be overridden in tests -var osWriteFile = os.WriteFile - -// osMkdirAll is a variable that points to os.MkdirAll, allowing it to be overridden in tests -var osMkdirAll = os.MkdirAll - -// cmdWait is a variable that points to cmd.Wait, allowing it to be overridden in tests -var cmdWait = func(cmd *exec.Cmd) error { - return cmd.Wait() -} - -// cmdStdoutPipe is a variable that points to cmd.StdoutPipe, allowing it to be overridden in tests -var cmdStdoutPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { - return cmd.StdoutPipe() +// ============================================================================= +// Types +// ============================================================================= + +// Shims provides mockable wrappers around system and runtime functions +type Shims struct { + // OS operations + Getwd func() (string, error) + Stat func(name string) (os.FileInfo, error) + Executable func() (string, error) + + // Standard I/O operations + Stderr func() io.Writer + SetStderr func(w io.Writer) + Stdout func() io.Writer + SetStdout func(w io.Writer) + Pipe func() (*os.File, *os.File, error) + + // Shell operations + UnsetEnvs func(envVars []string) + UnsetAlias func(aliases []string) + + // Exec operations + Command func(name string, arg ...string) *exec.Cmd + LookPath func(file string) (string, error) + OpenFile func(name string, flag int, perm os.FileMode) (*os.File, error) + WriteFile func(name string, data []byte, perm os.FileMode) error + ReadFile func(name string) ([]byte, error) + MkdirAll func(path string, perm os.FileMode) error + Remove func(name string) error + RemoveAll func(path string) error + Chdir func(dir string) error + Setenv func(key, value string) error + Getenv func(key string) string + UserHomeDir func() (string, error) + + // Exec operations + CmdRun func(cmd *exec.Cmd) error + CmdStart func(cmd *exec.Cmd) error + CmdWait func(cmd *exec.Cmd) error + StdoutPipe func(cmd *exec.Cmd) (io.ReadCloser, error) + StderrPipe func(cmd *exec.Cmd) (io.ReadCloser, error) + StdinPipe func(cmd *exec.Cmd) (io.WriteCloser, error) + + // Template operations + NewTemplate func(name string) *template.Template + TemplateParse func(tmpl *template.Template, text string) (*template.Template, error) + TemplateExecute func(tmpl *template.Template, wr io.Writer, data any) error + ExecuteTemplate func(tmpl *template.Template, data interface{}) error + + // Bufio operations + NewScanner func(r io.Reader) *bufio.Scanner + ScannerScan func(scanner *bufio.Scanner) bool + ScannerErr func(scanner *bufio.Scanner) error + ScannerText func(scanner *bufio.Scanner) string + NewWriter func(w io.Writer) *bufio.Writer + + // Filepath operations + Glob func(pattern string) ([]string, error) + Join func(elem ...string) string + + // Random operations + RandRead func(b []byte) (n int, err error) } -// cmdStderrPipe is a variable that points to cmd.StderrPipe, allowing it to be overridden in tests -var cmdStderrPipe = func(cmd *exec.Cmd) (io.ReadCloser, error) { - return cmd.StderrPipe() +// ============================================================================= +// Constructor +// ============================================================================= + +// NewShims creates a new Shims instance with default implementations +func NewShims() *Shims { + s := &Shims{ + // OS operations + Getwd: os.Getwd, + Stat: os.Stat, + Executable: os.Executable, + + // Standard I/O operations + Stderr: func() io.Writer { + return os.Stderr + }, + SetStderr: func(w io.Writer) { + if f, ok := w.(*os.File); ok { + os.Stderr = f + } + }, + Stdout: func() io.Writer { + return os.Stdout + }, + SetStdout: func(w io.Writer) { + if f, ok := w.(*os.File); ok { + os.Stdout = f + } + }, + Pipe: os.Pipe, + + // Shell operations + UnsetEnvs: func(envVars []string) {}, + UnsetAlias: func(aliases []string) {}, + + // Exec operations + Command: exec.Command, + LookPath: exec.LookPath, + OpenFile: os.OpenFile, + WriteFile: os.WriteFile, + ReadFile: os.ReadFile, + MkdirAll: os.MkdirAll, + Remove: os.Remove, + RemoveAll: os.RemoveAll, + Chdir: os.Chdir, + Setenv: os.Setenv, + Getenv: os.Getenv, + UserHomeDir: os.UserHomeDir, + + // Exec operations + CmdRun: (*exec.Cmd).Run, + CmdStart: (*exec.Cmd).Start, + CmdWait: (*exec.Cmd).Wait, + StdoutPipe: (*exec.Cmd).StdoutPipe, + StderrPipe: (*exec.Cmd).StderrPipe, + StdinPipe: (*exec.Cmd).StdinPipe, + + // Template operations + NewTemplate: template.New, + TemplateParse: (*template.Template).Parse, + TemplateExecute: (*template.Template).Execute, + ExecuteTemplate: func(tmpl *template.Template, data interface{}) error { + return tmpl.Execute(os.Stdout, data) + }, + + // Bufio operations + NewScanner: bufio.NewScanner, + ScannerScan: func(scanner *bufio.Scanner) bool { + return scanner.Scan() + }, + ScannerErr: func(scanner *bufio.Scanner) error { + return scanner.Err() + }, + ScannerText: func(scanner *bufio.Scanner) string { + return scanner.Text() + }, + NewWriter: bufio.NewWriter, + + // Filepath operations + Glob: filepath.Glob, + Join: filepath.Join, + + // Random operations + RandRead: func(b []byte) (n int, err error) { + return rand.Read(b) + }, + } + return s } - -// bufioScannerScan is a variable that points to bufio.Scanner.Scan, allowing it to be overridden in tests -var bufioScannerScan = func(scanner *bufio.Scanner) bool { - return scanner.Scan() -} - -// bufioScannerErr is a variable that points to bufio.Scanner.Err, allowing it to be overridden in tests -var bufioScannerErr = func(scanner *bufio.Scanner) error { - return scanner.Err() -} - -// osExecutable is a variable that points to os.Executable, allowing it to be overridden in tests -var osExecutable = os.Executable - -// hookTemplateNew is a variable that points to template.New, allowing it to be overridden in tests -var hookTemplateNew = func(name string) *template.Template { - return template.New(name) -} - -// hookTemplateParse is a variable that points to template.Template.Parse, allowing it to be overridden in tests -var hookTemplateParse = func(tmpl *template.Template, text string) (*template.Template, error) { - return tmpl.Parse(text) -} - -// hookTemplateExecute is a variable that points to template.Template.Execute, allowing it to be overridden in tests -var hookTemplateExecute = func(tmpl *template.Template, wr io.Writer, data interface{}) error { - return tmpl.Execute(wr, data) -} - -// randRead is a variable that points to rand.Read, allowing it to be overridden in tests -var randRead = rand.Read - -// osGetenv is a variable that points to os.Getenv, allowing it to be overridden in tests -var osGetenv = os.Getenv - -// filepathGlob is a variable that points to filepath.Glob, allowing it to be overridden in tests -var filepathGlob = filepath.Glob - -// osRemoveAll is a variable that points to os.RemoveAll, allowing it to be overridden in tests -var osRemoveAll = os.RemoveAll - -// osSetenv is a variable that points to os.Setenv, allowing it to be overridden in tests -var osSetenv = os.Setenv diff --git a/pkg/shell/shims_test.go b/pkg/shell/shims_test.go new file mode 100644 index 000000000..9b37d97a7 --- /dev/null +++ b/pkg/shell/shims_test.go @@ -0,0 +1,440 @@ +package shell + +import ( + "bufio" + "os" + "os/exec" + "strings" + "testing" +) + +func TestShell_NewShims(t *testing.T) { + t.Run("InitializesAllShims", func(t *testing.T) { + // When we create new shims + shims := NewShims() + + // Then all shims should be non-nil + if shims == nil { + t.Fatal("Expected non-nil shims") + } + + // We don't test the actual implementations since they are real system calls + // Instead we just verify that all fields are initialized + if shims.Getwd == nil { + t.Error("Expected Getwd to be initialized") + } + if shims.Stat == nil { + t.Error("Expected Stat to be initialized") + } + if shims.Executable == nil { + t.Error("Expected Executable to be initialized") + } + if shims.Stderr == nil { + t.Error("Expected Stderr to be initialized") + } + if shims.SetStderr == nil { + t.Error("Expected SetStderr to be initialized") + } + if shims.Stdout == nil { + t.Error("Expected Stdout to be initialized") + } + if shims.SetStdout == nil { + t.Error("Expected SetStdout to be initialized") + } + if shims.Pipe == nil { + t.Error("Expected Pipe to be initialized") + } + if shims.Command == nil { + t.Error("Expected Command to be initialized") + } + if shims.LookPath == nil { + t.Error("Expected LookPath to be initialized") + } + if shims.OpenFile == nil { + t.Error("Expected OpenFile to be initialized") + } + if shims.WriteFile == nil { + t.Error("Expected WriteFile to be initialized") + } + if shims.ReadFile == nil { + t.Error("Expected ReadFile to be initialized") + } + if shims.MkdirAll == nil { + t.Error("Expected MkdirAll to be initialized") + } + if shims.Remove == nil { + t.Error("Expected Remove to be initialized") + } + if shims.RemoveAll == nil { + t.Error("Expected RemoveAll to be initialized") + } + if shims.Chdir == nil { + t.Error("Expected Chdir to be initialized") + } + if shims.Setenv == nil { + t.Error("Expected Setenv to be initialized") + } + if shims.Getenv == nil { + t.Error("Expected Getenv to be initialized") + } + if shims.UserHomeDir == nil { + t.Error("Expected UserHomeDir to be initialized") + } + if shims.CmdRun == nil { + t.Error("Expected CmdRun to be initialized") + } + if shims.CmdStart == nil { + t.Error("Expected CmdStart to be initialized") + } + if shims.CmdWait == nil { + t.Error("Expected CmdWait to be initialized") + } + if shims.StdoutPipe == nil { + t.Error("Expected StdoutPipe to be initialized") + } + if shims.StderrPipe == nil { + t.Error("Expected StderrPipe to be initialized") + } + if shims.StdinPipe == nil { + t.Error("Expected StdinPipe to be initialized") + } + if shims.NewTemplate == nil { + t.Error("Expected NewTemplate to be initialized") + } + if shims.TemplateParse == nil { + t.Error("Expected TemplateParse to be initialized") + } + if shims.TemplateExecute == nil { + t.Error("Expected TemplateExecute to be initialized") + } + if shims.ExecuteTemplate == nil { + t.Error("Expected ExecuteTemplate to be initialized") + } + if shims.ScannerScan == nil { + t.Error("Expected ScannerScan to be initialized") + } + if shims.ScannerErr == nil { + t.Error("Expected ScannerErr to be initialized") + } + if shims.ScannerText == nil { + t.Error("Expected ScannerText to be initialized") + } + if shims.NewWriter == nil { + t.Error("Expected NewWriter to be initialized") + } + if shims.Glob == nil { + t.Error("Expected Glob to be initialized") + } + if shims.Join == nil { + t.Error("Expected Join to be initialized") + } + if shims.RandRead == nil { + t.Error("Expected RandRead to be initialized") + } + // Int63n is allowed to be nil as it's not initialized in NewShims + }) +} + +func TestNewShims(t *testing.T) { + t.Run("InitializesAllShims", func(t *testing.T) { + // When we create new shims + shims := NewShims() + + // Then all shims should be initialized + if shims.Getwd == nil { + t.Error("Expected Getwd to be initialized") + } + if shims.Stat == nil { + t.Error("Expected Stat to be initialized") + } + if shims.Executable == nil { + t.Error("Expected Executable to be initialized") + } + if shims.Stderr == nil { + t.Error("Expected Stderr to be initialized") + } + if shims.Stdout == nil { + t.Error("Expected Stdout to be initialized") + } + if shims.Command == nil { + t.Error("Expected Command to be initialized") + } + if shims.UserHomeDir == nil { + t.Error("Expected UserHomeDir to be initialized") + } + + // Verify a shim actually works + stderr := shims.Stderr() + if stderr != os.Stderr { + t.Error("Expected Stderr to return os.Stderr") + } + }) + + t.Run("SetStderr", func(t *testing.T) { + // Given + shims := NewShims() + origStderr := os.Stderr + defer func() { os.Stderr = origStderr }() + + // When setting a non-file writer + shims.SetStderr(&mockWriter{}) + // Then stderr should not change + if os.Stderr != origStderr { + t.Error("Expected Stderr to remain unchanged for non-file writer") + } + + // When setting a file writer + tmpFile, err := os.CreateTemp("", "stderr") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + shims.SetStderr(tmpFile) + // Then stderr should change + if os.Stderr == origStderr { + t.Error("Expected Stderr to change for file writer") + } + }) + + t.Run("SetStdout", func(t *testing.T) { + // Given + shims := NewShims() + origStdout := os.Stdout + defer func() { os.Stdout = origStdout }() + + // When setting a non-file writer + shims.SetStdout(&mockWriter{}) + // Then stdout should not change + if os.Stdout != origStdout { + t.Error("Expected Stdout to remain unchanged for non-file writer") + } + + // When setting a file writer + tmpFile, err := os.CreateTemp("", "stdout") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + shims.SetStdout(tmpFile) + // Then stdout should change + if os.Stdout == origStdout { + t.Error("Expected Stdout to change for file writer") + } + }) + + t.Run("Pipe", func(t *testing.T) { + // Given + shims := NewShims() + + // When creating a pipe + r, w, err := shims.Pipe() + if err != nil { + t.Fatal(err) + } + defer r.Close() + defer w.Close() + + // Then we should be able to write and read + testData := []byte("test data") + if _, err := w.Write(testData); err != nil { + t.Errorf("Failed to write test data: %v", err) + } + + buf := make([]byte, len(testData)) + n, err := r.Read(buf) + if err != nil { + t.Fatal(err) + } + if n != len(testData) { + t.Errorf("Expected to read %d bytes, got %d", len(testData), n) + } + if string(buf) != string(testData) { + t.Errorf("Expected %q, got %q", testData, buf) + } + }) + + t.Run("RandRead", func(t *testing.T) { + // Given + shims := NewShims() + buf := make([]byte, 32) + + // When reading random bytes + n, err := shims.RandRead(buf) + if err != nil { + t.Fatal(err) + } + if n != len(buf) { + t.Errorf("Expected to read %d bytes, got %d", len(buf), n) + } + + // Then we should get random data + zeroCount := 0 + for _, b := range buf { + if b == 0 { + zeroCount++ + } + } + if zeroCount == len(buf) { + t.Error("Expected random data, got all zeros") + } + }) +} + +type mockWriter struct{} + +func (m *mockWriter) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func TestShims_ExecOperations(t *testing.T) { + t.Run("CmdOperations", func(t *testing.T) { + // Given + shims := NewShims() + + // Use Go's built-in test binary as a reliable cross-platform command + cmd := shims.Command(os.Args[0], "-test.run=TestShims_ExecOperations/ThisTestDoesNotExist") + + // Test CmdRun + if err := shims.CmdRun(cmd); err != nil { + // We expect an error since the test doesn't exist, but it should run + if e, ok := err.(*exec.ExitError); !ok || e.ExitCode() != 1 { + t.Errorf("Expected exit code 1 for non-existent test, got: %v", err) + } + } + + // Test CmdStart and CmdWait + cmd = shims.Command(os.Args[0], "-test.run=TestShims_ExecOperations/ThisTestDoesNotExist") + if err := shims.CmdStart(cmd); err != nil { + t.Errorf("Expected CmdStart to succeed, got error: %v", err) + } + if err := shims.CmdWait(cmd); err != nil { + // We expect an error since the test doesn't exist, but it should complete + if e, ok := err.(*exec.ExitError); !ok || e.ExitCode() != 1 { + t.Errorf("Expected exit code 1 for non-existent test, got: %v", err) + } + } + }) + + t.Run("PipeOperations", func(t *testing.T) { + // Given + shims := NewShims() + cmd := shims.Command(os.Args[0], "-test.run=TestShims_ExecOperations/ThisTestDoesNotExist") + + // Test StdinPipe + stdin, err := shims.StdinPipe(cmd) + if err != nil { + t.Errorf("Expected StdinPipe to succeed, got error: %v", err) + } + defer stdin.Close() + + // Test StdoutPipe + stdout, err := shims.StdoutPipe(cmd) + if err != nil { + t.Errorf("Expected StdoutPipe to succeed, got error: %v", err) + } + defer stdout.Close() + + // Test StderrPipe + stderr, err := shims.StderrPipe(cmd) + if err != nil { + t.Errorf("Expected StderrPipe to succeed, got error: %v", err) + } + defer stderr.Close() + }) +} + +func TestShims_TemplateOperations(t *testing.T) { + t.Run("TemplateOperations", func(t *testing.T) { + // Given + shims := NewShims() + tmpl := shims.NewTemplate("test") + + // Test TemplateParse + parsed, err := shims.TemplateParse(tmpl, "Hello {{.}}") + if err != nil { + t.Errorf("Expected TemplateParse to succeed, got error: %v", err) + } + + // Test TemplateExecute + var buf mockWriter + if err := shims.TemplateExecute(parsed, &buf, "World"); err != nil { + t.Errorf("Expected TemplateExecute to succeed, got error: %v", err) + } + + // Test ExecuteTemplate + if err := shims.ExecuteTemplate(parsed, "World"); err != nil { + t.Errorf("Expected ExecuteTemplate to succeed, got error: %v", err) + } + }) + + t.Run("TemplateParseError", func(t *testing.T) { + // Given + shims := NewShims() + tmpl := shims.NewTemplate("test") + + // When parsing an invalid template + _, err := shims.TemplateParse(tmpl, "{{.Invalid}") + if err == nil { + t.Error("Expected error for invalid template") + } + }) + + t.Run("TemplateExecuteError", func(t *testing.T) { + // Given + shims := NewShims() + tmpl := shims.NewTemplate("test") + tmpl, err := shims.TemplateParse(tmpl, "{{.MissingField}}") + if err != nil { + t.Fatal(err) + } + + // When executing with invalid data + err = shims.TemplateExecute(tmpl, &strings.Builder{}, struct{}{}) + if err == nil { + t.Error("Expected error for missing field") + } + }) +} + +func TestShims_ScannerOperations(t *testing.T) { + t.Run("ScannerOperations", func(t *testing.T) { + // Given + shims := NewShims() + scanner := bufio.NewScanner(strings.NewReader("test\n")) + + // Test ScannerScan + if !shims.ScannerScan(scanner) { + t.Error("Expected ScannerScan to return true") + } + + // Test ScannerText + if text := shims.ScannerText(scanner); text != "test" { + t.Errorf("Expected ScannerText to return 'test', got %q", text) + } + + // Test ScannerErr + if err := shims.ScannerErr(scanner); err != nil { + t.Errorf("Expected ScannerErr to return nil, got %v", err) + } + }) + + t.Run("ScannerError", func(t *testing.T) { + // Given + shims := NewShims() + r := strings.NewReader(strings.Repeat("x", bufio.MaxScanTokenSize+1)) + scanner := bufio.NewScanner(r) + scanner.Split(bufio.ScanWords) + + // When scanning a token that's too large + ok := shims.ScannerScan(scanner) + if ok { + t.Error("Expected scan to fail") + } + + // Then we should get an error + err := shims.ScannerErr(scanner) + if err == nil { + t.Error("Expected error for token too large") + } + }) +} diff --git a/pkg/shell/unix_shell.go b/pkg/shell/unix_shell.go index 77750da5c..b6f2965be 100644 --- a/pkg/shell/unix_shell.go +++ b/pkg/shell/unix_shell.go @@ -9,6 +9,15 @@ import ( "strings" ) +// The UnixShell is a platform-specific implementation of shell operations for Unix-like systems. +// It provides Unix-specific implementations of environment variable and alias management. +// It handles the differences between Unix shells and Windows PowerShell. +// Key features include Unix-specific command generation for environment variables and aliases. + +// ============================================================================= +// Public Methods +// ============================================================================= + // PrintEnvVars prints the provided environment variables in a sorted order. // If the value of an environment variable is an empty string, it will print an unset command. func (s *DefaultShell) PrintEnvVars(envVars map[string]string) { diff --git a/pkg/shell/unix_shell_test.go b/pkg/shell/unix_shell_test.go new file mode 100644 index 000000000..2bf026b97 --- /dev/null +++ b/pkg/shell/unix_shell_test.go @@ -0,0 +1,252 @@ +//go:build darwin || linux +// +build darwin linux + +package shell + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" +) + +// The UnixShellTest is a test suite for Unix-specific shell operations. +// It provides comprehensive test coverage for environment variable management, +// project root detection, and alias handling on Unix-like systems. +// The UnixShellTest ensures reliable shell operations on macOS and Linux platforms. + +// ============================================================================= +// Test Public Methods +// ============================================================================= + +// TestDefaultShell_PrintEnvVars tests the PrintEnvVars method on Unix systems +func TestDefaultShell_PrintEnvVars(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + t.Run("PrintEnvVars", func(t *testing.T) { + // Given a shell with environment variables + shell, _ := setup(t) + envVars := map[string]string{ + "VAR2": "value2", + "VAR1": "value1", + "VAR3": "", + } + expectedOutput := "export VAR1=\"value1\"\nexport VAR2=\"value2\"\nunset VAR3\n" + + // When capturing the output of PrintEnvVars + output := captureStdout(t, func() { + shell.PrintEnvVars(envVars) + }) + + // Then the output should match the expected output + if output != expectedOutput { + t.Errorf("PrintEnvVars() output = %v, want %v", output, expectedOutput) + } + }) +} + +// TestDefaultShell_GetProjectRoot tests the GetProjectRoot method on Unix systems +func TestDefaultShell_GetProjectRoot(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + testCases := []struct { + name string + fileName string + }{ + {"WindsorYaml", "windsor.yaml"}, + {"WindsorYml", "windsor.yml"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Given a shell with mocked file system + shell, mocks := setup(t) + + // Mock the file system behavior + rootDir := mocks.TmpDir + subDir := filepath.Join(rootDir, "subdir") + + // Override Getwd to return the subdirectory + shell.shims.Getwd = func() (string, error) { + return subDir, nil + } + + // Override Stat to return nil for the windsor file + shell.shims.Stat = func(name string) (os.FileInfo, error) { + if name == filepath.Join(rootDir, tc.fileName) { + return nil, nil + } + return nil, os.ErrNotExist + } + + // When finding the project root using the specified file + projectRoot, err := shell.GetProjectRoot() + if err != nil { + t.Fatalf("GetProjectRoot returned an error: %v", err) + } + + // Then the project root should match the expected root directory + if projectRoot != rootDir { + t.Errorf("Expected project root to be %s, got %s", rootDir, projectRoot) + } + }) + } +} + +// TestDefaultShell_PrintAlias tests the PrintAlias method on Unix systems +func TestDefaultShell_PrintAlias(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + aliasVars := map[string]string{ + "ALIAS1": "command1", + "ALIAS2": "command2", + } + + t.Run("PrintAlias", func(t *testing.T) { + // Given a default shell + shell, _ := setup(t) + + // When capturing the output of PrintAlias + output := captureStdout(t, func() { + shell.PrintAlias(aliasVars) + }) + + // Then the output should contain all expected alias variables + for key, value := range aliasVars { + expectedLine := fmt.Sprintf("alias %s=\"%s\"\n", key, value) + if !strings.Contains(output, expectedLine) { + t.Errorf("PrintAlias() output missing expected line: %v", expectedLine) + } + } + }) + + t.Run("PrintAliasWithEmptyValue", func(t *testing.T) { + // Given a default shell with an alias having an empty value + shell, _ := setup(t) + aliasVarsWithEmpty := map[string]string{ + "ALIAS1": "command1", + "ALIAS2": "", + } + + // When capturing the output of PrintAlias + output := captureStdout(t, func() { + shell.PrintAlias(aliasVarsWithEmpty) + }) + + // Then the output should contain the expected alias and unalias commands + expectedAliasLine := fmt.Sprintf("alias ALIAS1=\"command1\"\n") + expectedUnaliasLine := fmt.Sprintf("unalias ALIAS2\n") + + if !strings.Contains(output, expectedAliasLine) { + t.Errorf("PrintAlias() output missing expected line: %v", expectedAliasLine) + } + if !strings.Contains(output, expectedUnaliasLine) { + t.Errorf("PrintAlias() output missing expected line: %v", expectedUnaliasLine) + } + }) +} + +// TestDefaultShell_UnsetEnvs tests the UnsetEnvs method on Unix systems +func TestDefaultShell_UnsetEnvs(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + t.Run("UnsetEnvs", func(t *testing.T) { + // Given a set of environment variables to unset + shell, _ := setup(t) + envVars := []string{"VAR1", "VAR2", "VAR3"} + expectedOutput := "unset VAR1 VAR2 VAR3\n" + + // When capturing the output of UnsetEnvs + output := captureStdout(t, func() { + shell.UnsetEnvs(envVars) + }) + + // Then the output should match the expected output + if output != expectedOutput { + t.Errorf("UnsetEnvs() output = %v, want %v", output, expectedOutput) + } + }) + + t.Run("UnsetEnvsWithEmptyList", func(t *testing.T) { + // Given an empty list of environment variables + shell, _ := setup(t) + + // When capturing the output of UnsetEnvs + output := captureStdout(t, func() { + shell.UnsetEnvs([]string{}) + }) + + // Then the output should be empty + if output != "" { + t.Errorf("UnsetEnvs() with empty list should produce no output, got %v", output) + } + }) +} + +// TestDefaultShell_UnsetAlias tests the UnsetAlias method on Unix systems +func TestDefaultShell_UnsetAlias(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + t.Run("UnsetAlias", func(t *testing.T) { + // Given a set of aliases to unset + shell, _ := setup(t) + aliases := []string{"ALIAS1", "ALIAS2", "ALIAS3"} + expectedOutput := "unalias ALIAS1\nunalias ALIAS2\nunalias ALIAS3\n" + + // When capturing the output of UnsetAlias + output := captureStdout(t, func() { + shell.UnsetAlias(aliases) + }) + + // Then the output should match the expected output + if output != expectedOutput { + t.Errorf("UnsetAlias() output = %v, want %v", output, expectedOutput) + } + }) + + t.Run("UnsetAliasWithEmptyList", func(t *testing.T) { + // Given an empty list of aliases + shell, _ := setup(t) + + // When capturing the output of UnsetAlias + output := captureStdout(t, func() { + shell.UnsetAlias([]string{}) + }) + + // Then the output should be empty + if output != "" { + t.Errorf("UnsetAlias() with empty list should produce no output, got %v", output) + } + }) +} diff --git a/pkg/shell/unix_test.go b/pkg/shell/unix_test.go deleted file mode 100644 index eb2438078..000000000 --- a/pkg/shell/unix_test.go +++ /dev/null @@ -1,202 +0,0 @@ -//go:build darwin || linux -// +build darwin linux - -package shell - -import ( - "fmt" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/windsorcli/cli/pkg/di" -) - -func TestDefaultShell_PrintEnvVars(t *testing.T) { - injector := di.NewInjector() - - // Given a set of environment variables - shell := NewDefaultShell(injector) - envVars := map[string]string{ - "VAR2": "value2", - "VAR1": "value1", - "VAR3": "", - } - expectedOutput := "export VAR1=\"value1\"\nexport VAR2=\"value2\"\nunset VAR3\n" - - // When capturing the output of PrintEnvVars - output := captureStdout(t, func() { - shell.PrintEnvVars(envVars) - }) - - // Then the output should match the expected output - if output != expectedOutput { - t.Errorf("PrintEnvVars() output = %v, want %v", output, expectedOutput) - } -} - -func TestDefaultShell_GetProjectRoot(t *testing.T) { - testCases := []struct { - name string - fileName string - }{ - {"WindsorYaml", "windsor.yaml"}, - {"WindsorYml", "windsor.yml"}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - injector := di.NewInjector() - - // Given a temporary directory structure with the specified file - rootDir := createTempDir(t, "project-root") - defer os.RemoveAll(rootDir) - - subDir := filepath.Join(rootDir, "subdir") - if err := os.Mkdir(subDir, 0755); err != nil { - t.Fatalf("Failed to create subdir: %v", err) - } - - // When creating the specified file in the root directory - createFile(t, rootDir, tc.fileName, "") - - // And changing the working directory to subDir - changeDir(t, subDir) - - shell := NewDefaultShell(injector) - - // Then GetProjectRoot should find the project root using the specified file - projectRoot, err := shell.GetProjectRoot() - if err != nil { - t.Fatalf("GetProjectRoot returned an error: %v", err) - } - - // Resolve symlinks to handle macOS /private prefix - expectedRootDir, err := filepath.EvalSymlinks(rootDir) - if err != nil { - t.Fatalf("Failed to evaluate symlinks for rootDir: %v", err) - } - - // Normalize paths for comparison - expectedRootDir = normalizePath(expectedRootDir) - projectRoot = normalizePath(projectRoot) - - if projectRoot != expectedRootDir { - t.Errorf("Expected project root to be %s, got %s", expectedRootDir, projectRoot) - } - }) - } -} - -func TestDefaultShell_PrintAlias(t *testing.T) { - aliasVars := map[string]string{ - "ALIAS1": "command1", - "ALIAS2": "command2", - } - - t.Run("PrintAlias", func(t *testing.T) { - injector := di.NewInjector() - - // Given a default shell - shell := NewDefaultShell(injector) - - // Capture the output of PrintAlias - output := captureStdout(t, func() { - shell.PrintAlias(aliasVars) - }) - - // Then the output should contain all expected alias variables - for key, value := range aliasVars { - expectedLine := fmt.Sprintf("alias %s=\"%s\"\n", key, value) - if !strings.Contains(output, expectedLine) { - t.Errorf("PrintAlias() output missing expected line: %v", expectedLine) - } - } - }) - - t.Run("PrintAliasWithEmptyValue", func(t *testing.T) { - injector := di.NewInjector() - - // Given a default shell with an alias having an empty value - shell := NewDefaultShell(injector) - aliasVarsWithEmpty := map[string]string{ - "ALIAS1": "command1", - "ALIAS2": "", - } - - // Capture the output of PrintAlias - output := captureStdout(t, func() { - shell.PrintAlias(aliasVarsWithEmpty) - }) - - // Then the output should contain the expected alias and unalias commands - expectedAliasLine := fmt.Sprintf("alias ALIAS1=\"command1\"\n") - expectedUnaliasLine := fmt.Sprintf("unalias ALIAS2\n") - - if !strings.Contains(output, expectedAliasLine) { - t.Errorf("PrintAlias() output missing expected line: %v", expectedAliasLine) - } - if !strings.Contains(output, expectedUnaliasLine) { - t.Errorf("PrintAlias() output missing expected line: %v", expectedUnaliasLine) - } - }) -} - -func TestDefaultShell_UnsetEnvs(t *testing.T) { - injector := di.NewInjector() - - // Given a set of environment variables to unset - shell := NewDefaultShell(injector) - envVars := []string{"VAR1", "VAR2", "VAR3"} - expectedOutput := "unset VAR1 VAR2 VAR3\n" - - // When capturing the output of UnsetEnvs - output := captureStdout(t, func() { - shell.UnsetEnvs(envVars) - }) - - // Then the output should match the expected output - if output != expectedOutput { - t.Errorf("UnsetEnvs() output = %v, want %v", output, expectedOutput) - } - - // Test with empty list - emptyOutput := captureStdout(t, func() { - shell.UnsetEnvs([]string{}) - }) - - // Then the output should be empty - if emptyOutput != "" { - t.Errorf("UnsetEnvs() with empty list should produce no output, got %v", emptyOutput) - } -} - -func TestDefaultShell_UnsetAlias(t *testing.T) { - injector := di.NewInjector() - - // Given a set of aliases to unset - shell := NewDefaultShell(injector) - aliases := []string{"ALIAS1", "ALIAS2", "ALIAS3"} - expectedOutput := "unalias ALIAS1\nunalias ALIAS2\nunalias ALIAS3\n" - - // When capturing the output of UnsetAlias - output := captureStdout(t, func() { - shell.UnsetAlias(aliases) - }) - - // Then the output should match the expected output - if output != expectedOutput { - t.Errorf("UnsetAlias() output = %v, want %v", output, expectedOutput) - } - - // Test with empty list - emptyOutput := captureStdout(t, func() { - shell.UnsetAlias([]string{}) - }) - - // Then the output should be empty - if emptyOutput != "" { - t.Errorf("UnsetAlias() with empty list should produce no output, got %v", emptyOutput) - } -} diff --git a/pkg/shell/windows_shell.go b/pkg/shell/windows_shell.go index e3520136b..6a117cb10 100644 --- a/pkg/shell/windows_shell.go +++ b/pkg/shell/windows_shell.go @@ -8,6 +8,15 @@ import ( "sort" ) +// The WindowsShell is a platform-specific implementation of shell operations for Windows systems. +// It provides Windows PowerShell-specific implementations of environment variable and alias management. +// It handles the differences between Windows PowerShell and Unix shells. +// Key features include PowerShell-specific command generation for environment variables and aliases. + +// ============================================================================= +// Public Methods +// ============================================================================= + // PrintEnvVars sorts and prints environment variables. Empty values trigger a removal command. func (s *DefaultShell) PrintEnvVars(envVars map[string]string) { keys := make([]string, 0, len(envVars)) diff --git a/pkg/shell/windows_shell_test.go b/pkg/shell/windows_shell_test.go new file mode 100644 index 000000000..a6ba43e34 --- /dev/null +++ b/pkg/shell/windows_shell_test.go @@ -0,0 +1,329 @@ +//go:build windows +// +build windows + +package shell + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "golang.org/x/sys/windows" +) + +// The WindowsShellTest is a test suite for Windows-specific shell operations. +// It provides comprehensive test coverage for PowerShell environment management, +// project root detection, and alias handling on Windows systems. +// The WindowsShellTest ensures reliable shell operations on Windows platforms. + +// ============================================================================= +// Test Setup +// ============================================================================= + +// Helper function to get the long path name on Windows +// This function converts a short path to its long form +func getLongPathName(shortPath string) (string, error) { + p, err := windows.UTF16PtrFromString(shortPath) + if err != nil { + return "", err + } + b := make([]uint16, windows.MAX_LONG_PATH) + r, err := windows.GetLongPathName(p, &b[0], uint32(len(b))) + if r == 0 { + return "", err + } + return windows.UTF16ToString(b), nil +} + +// Helper function to normalize a Windows path +// This function ensures the path is in its long form and normalized +func normalizeWindowsPath(path string) string { + longPath, err := getLongPathName(path) + if err != nil { + return normalizePath(path) + } + return normalizePath(longPath) +} + +// ============================================================================= +// Test Public Methods +// ============================================================================= + +// TestDefaultShell_PrintEnvVars tests the PrintEnvVars method on Windows systems +func TestDefaultShell_PrintEnvVars(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + t.Run("PrintEnvVars", func(t *testing.T) { + // Given a shell with environment variables + shell, _ := setup(t) + envVars := map[string]string{ + "VAR2": "value2", + "VAR1": "value1", + "VAR3": "", + } + expectedOutput := "$env:VAR1='value1'\n$env:VAR2='value2'\nRemove-Item Env:VAR3\n" + + // When capturing the output of PrintEnvVars + output := captureStdout(t, func() { + shell.PrintEnvVars(envVars) + }) + + // Then the output should match the expected output + if output != expectedOutput { + t.Errorf("PrintEnvVars() output = %v, want %v", output, expectedOutput) + } + }) +} + +// TestDefaultShell_GetProjectRoot tests the GetProjectRoot method on Windows systems +func TestDefaultShell_GetProjectRoot(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + testCases := []struct { + name string + fileName string + }{ + {"WindsorYaml", "windsor.yaml"}, + {"WindsorYml", "windsor.yml"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Given a shell with mocked file system + shell, mocks := setup(t) + rootDir := "C:\\test\\project" + subDir := filepath.Join(rootDir, "subdir") + + // Mock Getwd to return the subdirectory + mocks.Shims.Getwd = func() (string, error) { + return subDir, nil + } + + // Mock Stat to return nil for the windsor file + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + if name == filepath.Join(rootDir, tc.fileName) { + return nil, nil + } + return nil, os.ErrNotExist + } + + // When finding the project root + projectRoot, err := shell.GetProjectRoot() + if err != nil { + t.Fatalf("GetProjectRoot returned an error: %v", err) + } + + // Then the project root should match the expected root directory + if projectRoot != rootDir { + t.Errorf("Expected project root to be %s, got %s", rootDir, projectRoot) + } + }) + } +} + +// TestDefaultShell_PrintAlias tests the PrintAlias method on Windows systems +func TestDefaultShell_PrintAlias(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + aliasVars := map[string]string{ + "ALIAS1": "command1", + "ALIAS2": "command2", + } + + t.Run("PrintAlias", func(t *testing.T) { + // Given a default shell + shell, _ := setup(t) + + // When capturing the output of PrintAlias + output := captureStdout(t, func() { + shell.PrintAlias(aliasVars) + }) + + // Then the output should contain all expected alias variables + for key, value := range aliasVars { + expectedLine := fmt.Sprintf("Set-Alias -Name %s -Value \"%s\"\n", key, value) + if !strings.Contains(output, expectedLine) { + t.Errorf("PrintAlias() output missing expected line: %v", expectedLine) + } + } + }) + + t.Run("PrintAliasWithEmptyValue", func(t *testing.T) { + // Given a default shell with an alias having an empty value + shell, _ := setup(t) + aliasVarsWithEmpty := map[string]string{ + "ALIAS1": "command1", + "ALIAS2": "", + } + + // When capturing the output of PrintAlias + output := captureStdout(t, func() { + shell.PrintAlias(aliasVarsWithEmpty) + }) + + // Then the output should contain the expected alias and remove alias commands + expectedAliasLine := fmt.Sprintf("Set-Alias -Name ALIAS1 -Value \"command1\"\n") + expectedRemoveAliasLine := fmt.Sprintf("Remove-Item Alias:ALIAS2\n") + + if !strings.Contains(output, expectedAliasLine) { + t.Errorf("PrintAlias() output missing expected line: %v", expectedAliasLine) + } + if !strings.Contains(output, expectedRemoveAliasLine) { + t.Errorf("PrintAlias() output missing expected line: %v", expectedRemoveAliasLine) + } + }) +} + +// TestDefaultShell_UnsetEnvs tests the UnsetEnvs method on Windows systems +func TestDefaultShell_UnsetEnvs(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + t.Run("UnsetEnvs", func(t *testing.T) { + // Given a set of environment variables to unset + shell, _ := setup(t) + envVars := []string{"VAR1", "VAR2", "VAR3"} + expectedOutput := "Remove-Item Env:VAR1\nRemove-Item Env:VAR2\nRemove-Item Env:VAR3\n" + + // When capturing the output of UnsetEnvs + output := captureStdout(t, func() { + shell.UnsetEnvs(envVars) + }) + + // Then the output should match the expected output + if output != expectedOutput { + t.Errorf("UnsetEnvs() output = %v, want %v", output, expectedOutput) + } + }) + + t.Run("UnsetEnvsWithEmptyList", func(t *testing.T) { + // Given an empty list of environment variables + shell, _ := setup(t) + + // When capturing the output of UnsetEnvs + output := captureStdout(t, func() { + shell.UnsetEnvs([]string{}) + }) + + // Then the output should be empty + if output != "" { + t.Errorf("UnsetEnvs() with empty list should produce no output, got %v", output) + } + }) +} + +// TestDefaultShell_UnsetAlias tests the UnsetAlias method on Windows systems +func TestDefaultShell_UnsetAlias(t *testing.T) { + setup := func(t *testing.T) (*DefaultShell, *Mocks) { + t.Helper() + mocks := setupMocks(t) + shell := NewDefaultShell(mocks.Injector) + shell.shims = mocks.Shims + return shell, mocks + } + + t.Run("UnsetAlias", func(t *testing.T) { + // Given a set of aliases to unset + shell, _ := setup(t) + aliases := []string{"ALIAS1", "ALIAS2", "ALIAS3"} + expectedOutput := "Remove-Item Alias:ALIAS1\nRemove-Item Alias:ALIAS2\nRemove-Item Alias:ALIAS3\n" + + // When capturing the output of UnsetAlias + output := captureStdout(t, func() { + shell.UnsetAlias(aliases) + }) + + // Then the output should match the expected output + if output != expectedOutput { + t.Errorf("UnsetAlias() output = %v, want %v", output, expectedOutput) + } + }) + + t.Run("UnsetAliasWithEmptyList", func(t *testing.T) { + // Given an empty list of aliases + shell, _ := setup(t) + + // When capturing the output of UnsetAlias + output := captureStdout(t, func() { + shell.UnsetAlias([]string{}) + }) + + // Then the output should be empty + if output != "" { + t.Errorf("UnsetAlias() with empty list should produce no output, got %v", output) + } + }) +} + +// ============================================================================= +// Helper Functions +// ============================================================================= + +// Helper function to change the current directory +func changeDir(t *testing.T, dir string) { + t.Helper() + if err := os.Chdir(dir); err != nil { + t.Fatalf("Failed to change directory to %s: %v", dir, err) + } +} + +// Helper function to normalize a path for comparison +func normalizePath(path string) string { + return filepath.Clean(path) +} + +// Helper function to capture stdout from a function +func captureStdoutFromFunc(t *testing.T, fn func()) string { + t.Helper() + + // Create a pipe to capture stdout + oldStdout := os.Stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("Failed to create pipe: %v", err) + } + os.Stdout = w + + // Run the function + fn() + + // Close the writer + w.Close() + + // Restore stdout + os.Stdout = oldStdout + + // Read the output + buf, err := io.ReadAll(r) + if err != nil { + t.Fatalf("Failed to read from pipe: %v", err) + } + + return string(buf) +} diff --git a/pkg/shell/windows_test.go b/pkg/shell/windows_test.go deleted file mode 100644 index cdd9b8637..000000000 --- a/pkg/shell/windows_test.go +++ /dev/null @@ -1,238 +0,0 @@ -//go:build windows -// +build windows - -package shell - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/windsorcli/cli/pkg/di" - "golang.org/x/sys/windows" -) - -// Helper function to get the long path name on Windows -// This function converts a short path to its long form -func getLongPathName(shortPath string) (string, error) { - p, err := windows.UTF16PtrFromString(shortPath) - if err != nil { - return "", err - } - b := make([]uint16, windows.MAX_LONG_PATH) - r, err := windows.GetLongPathName(p, &b[0], uint32(len(b))) - if r == 0 { - return "", err - } - return windows.UTF16ToString(b), nil -} - -// Helper function to normalize a Windows path -// This function ensures the path is in its long form and normalized -func normalizeWindowsPath(path string) string { - longPath, err := getLongPathName(path) - if err != nil { - return normalizePath(path) - } - return normalizePath(longPath) -} - -func TestDefaultShell_PrintEnvVars(t *testing.T) { - injector := di.NewInjector() - - // Given a default shell and a set of environment variables - shell := NewDefaultShell(injector) - envVars := map[string]string{ - "VAR2": "value2", - "VAR1": "value1", - "VAR3": "", - } - - // Expected output for PowerShell - expectedOutputPowerShell := "$env:VAR1='value1'\n$env:VAR2='value2'\nRemove-Item Env:VAR3\n" - - // Capture the output - var output bytes.Buffer - originalStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - // Run PrintEnvVars in a goroutine and capture its output - go func() { - shell.PrintEnvVars(envVars) - w.Close() - }() - - output.ReadFrom(r) - os.Stdout = originalStdout - - // Then the output should match the expected PowerShell format - if output.String() != expectedOutputPowerShell { - t.Errorf("PrintEnvVars() output = %v, want %v", output.String(), expectedOutputPowerShell) - } -} - -func TestDefaultShell_GetProjectRoot(t *testing.T) { - injector := di.NewInjector() - - testCases := []struct { - name string - fileName string - }{ - {"WindsorYaml", "windsor.yaml"}, - {"WindsorYml", "windsor.yml"}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Given a temporary directory structure with the specified file - rootDir := createTempDir(t, "project-root") - subDir := filepath.Join(rootDir, "subdir") - if err := os.Mkdir(subDir, 0755); err != nil { - t.Fatalf("Failed to create subdir: %v", err) - } - - // When creating the specified file in the root directory - createFile(t, rootDir, tc.fileName, "") - - // And changing the working directory to subDir - changeDir(t, subDir) - - shell := NewDefaultShell(injector) - - // When finding the project root using the specified file - projectRoot, err := shell.GetProjectRoot() - if err != nil { - t.Fatalf("GetProjectRoot returned an error: %v", err) - } - - // Resolve symlinks to handle macOS /private prefix - expectedRootDir, err := filepath.EvalSymlinks(rootDir) - if err != nil { - t.Fatalf("Failed to evaluate symlinks for rootDir: %v", err) - } - - // Normalize paths for comparison - expectedRootDir = normalizeWindowsPath(expectedRootDir) - projectRoot = normalizeWindowsPath(projectRoot) - - // Then the project root should match the expected root directory - if projectRoot != expectedRootDir { - t.Errorf("Expected project root to be %s, got %s", expectedRootDir, projectRoot) - } - }) - } -} - -func TestDefaultShell_PrintAlias(t *testing.T) { - aliasVars := map[string]string{ - "ALIAS1": "command1", - "ALIAS2": "command2", - } - - t.Run("PrintAlias", func(t *testing.T) { - // Given a default shell - injector := di.NewInjector() - shell := NewDefaultShell(injector) - - // Capture the output of PrintAlias - output := captureStdout(t, func() { - shell.PrintAlias(aliasVars) - }) - - // Then the output should contain all expected alias variables - for key, value := range aliasVars { - expectedLine := fmt.Sprintf("Set-Alias -Name %s -Value \"%s\"\n", key, value) - if !strings.Contains(output, expectedLine) { - t.Errorf("PrintAlias() output missing expected line: %v", expectedLine) - } - } - }) - - t.Run("PrintAliasWithEmptyValue", func(t *testing.T) { - // Given a default shell with an alias having an empty value - injector := di.NewInjector() - shell := NewDefaultShell(injector) - aliasVarsWithEmpty := map[string]string{ - "ALIAS1": "command1", - "ALIAS2": "", - } - - // Capture the output of PrintAlias - output := captureStdout(t, func() { - shell.PrintAlias(aliasVarsWithEmpty) - }) - - // Then the output should contain the expected alias and remove alias commands - expectedAliasLine := fmt.Sprintf("Set-Alias -Name ALIAS1 -Value \"command1\"\n") - expectedRemoveAliasLine := fmt.Sprintf("Remove-Item Alias:ALIAS2\n") - - if !strings.Contains(output, expectedAliasLine) { - t.Errorf("PrintAlias() output missing expected line: %v", expectedAliasLine) - } - if !strings.Contains(output, expectedRemoveAliasLine) { - t.Errorf("PrintAlias() output missing expected line: %v", expectedRemoveAliasLine) - } - }) -} - -func TestDefaultShell_UnsetEnvs(t *testing.T) { - injector := di.NewInjector() - - // Given a set of environment variables to unset - shell := NewDefaultShell(injector) - envVars := []string{"VAR1", "VAR2", "VAR3"} - expectedOutput := "Remove-Item Env:VAR1\nRemove-Item Env:VAR2\nRemove-Item Env:VAR3\n" - - // When capturing the output of UnsetEnvs - output := captureStdout(t, func() { - shell.UnsetEnvs(envVars) - }) - - // Then the output should match the expected output - if output != expectedOutput { - t.Errorf("UnsetEnvs() output = %v, want %v", output, expectedOutput) - } - - // Test with empty list - emptyOutput := captureStdout(t, func() { - shell.UnsetEnvs([]string{}) - }) - - // Then the output should be empty - if emptyOutput != "" { - t.Errorf("UnsetEnvs() with empty list should produce no output, got %v", emptyOutput) - } -} - -func TestDefaultShell_UnsetAlias(t *testing.T) { - injector := di.NewInjector() - - // Given a set of aliases to unset - shell := NewDefaultShell(injector) - aliases := []string{"ALIAS1", "ALIAS2", "ALIAS3"} - expectedOutput := "Remove-Item Alias:ALIAS1\nRemove-Item Alias:ALIAS2\nRemove-Item Alias:ALIAS3\n" - - // When capturing the output of UnsetAlias - output := captureStdout(t, func() { - shell.UnsetAlias(aliases) - }) - - // Then the output should match the expected output - if output != expectedOutput { - t.Errorf("UnsetAlias() output = %v, want %v", output, expectedOutput) - } - - // Test with empty list - emptyOutput := captureStdout(t, func() { - shell.UnsetAlias([]string{}) - }) - - // Then the output should be empty - if emptyOutput != "" { - t.Errorf("UnsetAlias() with empty list should produce no output, got %v", emptyOutput) - } -}