diff --git a/api/v1alpha1/cluster/cluster_config.go b/api/v1alpha1/cluster/cluster_config.go
index c5f987905..f5de0bc39 100644
--- a/api/v1alpha1/cluster/cluster_config.go
+++ b/api/v1alpha1/cluster/cluster_config.go
@@ -2,22 +2,10 @@ package cluster
 
 // ClusterConfig represents the cluster configuration
 type ClusterConfig struct {
-	Enabled       *bool   `yaml:"enabled"`
-	Driver        *string `yaml:"driver"`
-	ControlPlanes struct {
-		Count     *int                  `yaml:"count,omitempty"`
-		CPU       *int                  `yaml:"cpu,omitempty"`
-		Memory    *int                  `yaml:"memory,omitempty"`
-		Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
-		HostPorts []string              `yaml:"hostports,omitempty"`
-	} `yaml:"controlplanes,omitempty"`
-	Workers struct {
-		Count     *int                  `yaml:"count,omitempty"`
-		CPU       *int                  `yaml:"cpu,omitempty"`
-		Memory    *int                  `yaml:"memory,omitempty"`
-		Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
-		HostPorts []string              `yaml:"hostports,omitempty"`
-	} `yaml:"workers,omitempty"`
+	Enabled       *bool           `yaml:"enabled"`
+	Driver        *string         `yaml:"driver"`
+	ControlPlanes NodeGroupConfig `yaml:"controlplanes,omitempty"`
+	Workers       NodeGroupConfig `yaml:"workers,omitempty"`
 }
 
 // NodeConfig represents the node configuration
@@ -28,6 +16,16 @@ type NodeConfig struct {
 	HostPorts []string `yaml:"hostports,omitempty"`
 }
 
+// NodeGroupConfig represents the configuration for a group of nodes
+type NodeGroupConfig struct {
+	Count     *int                  `yaml:"count,omitempty"`
+	CPU       *int                  `yaml:"cpu,omitempty"`
+	Memory    *int                  `yaml:"memory,omitempty"`
+	Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
+	HostPorts []string              `yaml:"hostports,omitempty"`
+	Volumes   []string              `yaml:"volumes,omitempty"`
+}
+
 // Merge performs a deep merge of the current ClusterConfig with another ClusterConfig.
 func (base *ClusterConfig) Merge(overlay *ClusterConfig) {
 	if overlay.Enabled != nil {
@@ -51,6 +49,14 @@ func (base *ClusterConfig) Merge(overlay *ClusterConfig) {
 			base.ControlPlanes.Nodes[key] = node
 		}
 	}
+	if overlay.ControlPlanes.HostPorts != nil {
+		base.ControlPlanes.HostPorts = make([]string, len(overlay.ControlPlanes.HostPorts))
+		copy(base.ControlPlanes.HostPorts, overlay.ControlPlanes.HostPorts)
+	}
+	if overlay.ControlPlanes.Volumes != nil {
+		base.ControlPlanes.Volumes = make([]string, len(overlay.ControlPlanes.Volumes))
+		copy(base.ControlPlanes.Volumes, overlay.ControlPlanes.Volumes)
+	}
 	if overlay.Workers.Count != nil {
 		base.Workers.Count = overlay.Workers.Count
 	}
@@ -70,6 +76,10 @@ func (base *ClusterConfig) Merge(overlay *ClusterConfig) {
 		base.Workers.HostPorts = make([]string, len(overlay.Workers.HostPorts))
 		copy(base.Workers.HostPorts, overlay.Workers.HostPorts)
 	}
+	if overlay.Workers.Volumes != nil {
+		base.Workers.Volumes = make([]string, len(overlay.Workers.Volumes))
+		copy(base.Workers.Volumes, overlay.Workers.Volumes)
+	}
 }
 
 // Copy creates a deep copy of the ClusterConfig object
@@ -84,11 +94,13 @@ func (c *ClusterConfig) Copy() *ClusterConfig {
 			Hostname:  node.Hostname,
 			Node:      node.Node,
 			Endpoint:  node.Endpoint,
-			HostPorts: append([]string{}, node.HostPorts...), // Copy HostPorts for each node
+			HostPorts: append([]string{}, node.HostPorts...),
 		}
 	}
 	controlPlanesHostPortsCopy := make([]string, len(c.ControlPlanes.HostPorts))
 	copy(controlPlanesHostPortsCopy, c.ControlPlanes.HostPorts)
+	controlPlanesVolumesCopy := make([]string, len(c.ControlPlanes.Volumes))
+	copy(controlPlanesVolumesCopy, c.ControlPlanes.Volumes)
 
 	workersNodesCopy := make(map[string]NodeConfig, len(c.Workers.Nodes))
 	for key, node := range c.Workers.Nodes {
@@ -96,40 +108,32 @@ func (c *ClusterConfig) Copy() *ClusterConfig {
 			Hostname:  node.Hostname,
 			Node:      node.Node,
 			Endpoint:  node.Endpoint,
-			HostPorts: append([]string{}, node.HostPorts...), // Copy HostPorts for each node
+			HostPorts: append([]string{}, node.HostPorts...),
 		}
 	}
 	workersHostPortsCopy := make([]string, len(c.Workers.HostPorts))
 	copy(workersHostPortsCopy, c.Workers.HostPorts)
+	workersVolumesCopy := make([]string, len(c.Workers.Volumes))
+	copy(workersVolumesCopy, c.Workers.Volumes)
 
 	return &ClusterConfig{
 		Enabled: c.Enabled,
 		Driver:  c.Driver,
-		ControlPlanes: struct {
-			Count     *int                  `yaml:"count,omitempty"`
-			CPU       *int                  `yaml:"cpu,omitempty"`
-			Memory    *int                  `yaml:"memory,omitempty"`
-			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
-			HostPorts []string              `yaml:"hostports,omitempty"`
-		}{
+		ControlPlanes: NodeGroupConfig{
 			Count:     c.ControlPlanes.Count,
 			CPU:       c.ControlPlanes.CPU,
 			Memory:    c.ControlPlanes.Memory,
 			Nodes:     controlPlanesNodesCopy,
 			HostPorts: controlPlanesHostPortsCopy,
+			Volumes:   controlPlanesVolumesCopy,
 		},
-		Workers: struct {
-			Count     *int                  `yaml:"count,omitempty"`
-			CPU       *int                  `yaml:"cpu,omitempty"`
-			Memory    *int                  `yaml:"memory,omitempty"`
-			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
-			HostPorts []string              `yaml:"hostports,omitempty"`
-		}{
+		Workers: NodeGroupConfig{
			Count:     c.Workers.Count,
 			CPU:       c.Workers.CPU,
 			Memory:    c.Workers.Memory,
 			Nodes:     workersNodesCopy,
 			HostPorts: workersHostPortsCopy,
+			Volumes:   workersVolumesCopy,
 		},
 	}
 }
diff --git a/api/v1alpha1/cluster/cluster_config_test.go b/api/v1alpha1/cluster/cluster_config_test.go
index 1abfd7e60..9ed5d282a 100644
--- a/api/v1alpha1/cluster/cluster_config_test.go
+++ b/api/v1alpha1/cluster/cluster_config_test.go
@@ -28,13 +28,18 @@ func TestClusterConfig_Merge(t *testing.T) {
 			Memory    *int                  `yaml:"memory,omitempty"`
 			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
 			HostPorts []string              `yaml:"hostports,omitempty"`
+			Volumes   []string              `yaml:"volumes,omitempty"`
 		}{
 			Count:  ptrInt(3),
 			CPU:    ptrInt(4),
 			Memory: ptrInt(8192),
 			Nodes: map[string]NodeConfig{
-				"node1": {Hostname: ptrString("base-node1")},
+				"node1": {
+					Hostname: ptrString("base-node1"),
+				},
 			},
+			HostPorts: []string{"1000:1000/tcp", "2000:2000/tcp"},
+			Volumes:   []string{"${WINDSOR_PROJECT_ROOT}/base/volume1:/var/local/base1"},
 		},
 		Workers: struct {
 			Count     *int                  `yaml:"count,omitempty"`
@@ -42,14 +47,18 @@ func TestClusterConfig_Merge(t *testing.T) {
 			Memory    *int                  `yaml:"memory,omitempty"`
 			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
 			HostPorts []string              `yaml:"hostports,omitempty"`
+			Volumes   []string              `yaml:"volumes,omitempty"`
 		}{
 			Count:  ptrInt(5),
 			CPU:    ptrInt(2),
 			Memory: ptrInt(4096),
 			Nodes: map[string]NodeConfig{
-				"worker1": {Hostname: ptrString("base-worker1")},
+				"worker1": {
+					Hostname: ptrString("base-worker1"),
+				},
 			},
 			HostPorts: []string{"8080", "9090"},
+			Volumes:   []string{"${WINDSOR_PROJECT_ROOT}/base/worker/volume1:/var/local/worker1"},
 		},
 	}
 
@@ -62,13 +71,18 @@ func TestClusterConfig_Merge(t *testing.T) {
 			Memory    *int                  `yaml:"memory,omitempty"`
 			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
 			HostPorts []string              `yaml:"hostports,omitempty"`
+			Volumes   []string              `yaml:"volumes,omitempty"`
 		}{
 			Count:  ptrInt(1),
 			CPU:    ptrInt(2),
 			Memory: ptrInt(4096),
 			Nodes: map[string]NodeConfig{
-				"node2": {Hostname: ptrString("overlay-node2")},
+				"node2": {
+					Hostname: ptrString("overlay-node2"),
+				},
 			},
+			HostPorts: []string{"3000:3000/tcp", "4000:4000/tcp"},
+			Volumes:   []string{"${WINDSOR_PROJECT_ROOT}/overlay/volume2:/var/local/overlay2"},
 		},
 		Workers: struct {
 			Count     *int                  `yaml:"count,omitempty"`
@@ -76,14 +90,18 @@ func TestClusterConfig_Merge(t *testing.T) {
 			Memory    *int                  `yaml:"memory,omitempty"`
 			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
 			HostPorts []string              `yaml:"hostports,omitempty"`
+			Volumes   []string              `yaml:"volumes,omitempty"`
 		}{
 			Count:  ptrInt(3),
 			CPU:    ptrInt(1),
 			Memory: ptrInt(2048),
 			Nodes: map[string]NodeConfig{
-				"worker2": {Hostname: ptrString("overlay-worker2")},
+				"worker2": {
+					Hostname: ptrString("overlay-worker2"),
+				},
 			},
 			HostPorts: []string{"8082", "9092"},
+			Volumes:   []string{"${WINDSOR_PROJECT_ROOT}/overlay/worker/volume2:/var/local/worker2"},
 		},
 	}
 
@@ -95,8 +113,11 @@ func TestClusterConfig_Merge(t *testing.T) {
 	if base.Driver == nil || *base.Driver != "overlay-driver" {
 		t.Errorf("Driver mismatch: expected 'overlay-driver', got '%s'", *base.Driver)
 	}
+	if len(base.ControlPlanes.HostPorts) != 2 || base.ControlPlanes.HostPorts[0] != "3000:3000/tcp" || base.ControlPlanes.HostPorts[1] != "4000:4000/tcp" {
+		t.Errorf("ControlPlanes HostPorts mismatch: expected ['3000:3000/tcp', '4000:4000/tcp'], got %v", base.ControlPlanes.HostPorts)
+	}
 	if len(base.Workers.HostPorts) != 2 || base.Workers.HostPorts[0] != "8082" || base.Workers.HostPorts[1] != "9092" {
-		t.Errorf("HostPorts mismatch: expected ['8082', '9092'], got %v", base.Workers.HostPorts)
+		t.Errorf("Workers HostPorts mismatch: expected ['8082', '9092'], got %v", base.Workers.HostPorts)
 	}
 	if base.ControlPlanes.Count == nil || *base.ControlPlanes.Count != 1 {
 		t.Errorf("ControlPlanes Count mismatch: expected 1, got %v", *base.ControlPlanes.Count)
@@ -122,6 +143,9 @@ func TestClusterConfig_Merge(t *testing.T) {
 	if len(base.Workers.Nodes) != 1 || base.Workers.Nodes["worker2"].Hostname == nil || *base.Workers.Nodes["worker2"].Hostname != "overlay-worker2" {
 		t.Errorf("Workers Nodes mismatch: expected 'overlay-worker2', got %v", base.Workers.Nodes)
 	}
+	if len(base.Workers.Volumes) != 1 || base.Workers.Volumes[0] != "${WINDSOR_PROJECT_ROOT}/overlay/worker/volume2:/var/local/worker2" {
+		t.Errorf("Workers Volumes mismatch: expected ['${WINDSOR_PROJECT_ROOT}/overlay/worker/volume2:/var/local/worker2'], got %v", base.Workers.Volumes)
+	}
 	})
 
 	t.Run("MergeWithAllNils", func(t *testing.T) {
@@ -134,12 +158,14 @@ func TestClusterConfig_Merge(t *testing.T) {
 			Memory    *int                  `yaml:"memory,omitempty"`
 			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
 			HostPorts []string              `yaml:"hostports,omitempty"`
+			Volumes   []string              `yaml:"volumes,omitempty"`
 		}{
 			Count:     nil,
 			CPU:       nil,
 			Memory:    nil,
 			Nodes:     nil,
 			HostPorts: nil,
+			Volumes:   nil,
 		},
 		Workers: struct {
 			Count     *int                  `yaml:"count,omitempty"`
@@ -147,12 +173,14 @@ func TestClusterConfig_Merge(t *testing.T) {
 			Memory    *int                  `yaml:"memory,omitempty"`
 			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
 			HostPorts []string              `yaml:"hostports,omitempty"`
+			Volumes   []string              `yaml:"volumes,omitempty"`
 		}{
 			Count:     nil,
 			CPU:       nil,
 			Memory:    nil,
 			Nodes:     nil,
 			HostPorts: nil,
+			Volumes:   nil,
 		},
 	}
 
@@ -165,12 +193,14 @@ func TestClusterConfig_Merge(t *testing.T) {
 			Memory    *int                  `yaml:"memory,omitempty"`
 			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
 			HostPorts []string              `yaml:"hostports,omitempty"`
+			Volumes   []string              `yaml:"volumes,omitempty"`
 		}{
 			Count:     nil,
 			CPU:       nil,
 			Memory:    nil,
 			Nodes:     nil,
 			HostPorts: nil,
+			Volumes:   nil,
 		},
 		Workers: struct {
 			Count     *int                  `yaml:"count,omitempty"`
@@ -178,12 +208,14 @@ func TestClusterConfig_Merge(t *testing.T) {
 			Memory    *int                  `yaml:"memory,omitempty"`
 			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
 			HostPorts []string              `yaml:"hostports,omitempty"`
+			Volumes   []string              `yaml:"volumes,omitempty"`
 		}{
 			Count:     nil,
 			CPU:       nil,
 			Memory:    nil,
 			Nodes:     nil,
 			HostPorts: nil,
+			Volumes:   nil,
 		},
 	}
 
@@ -196,7 +228,10 @@ func TestClusterConfig_Merge(t *testing.T) {
 		t.Errorf("Driver mismatch: expected nil, got '%s'", *base.Driver)
 	}
 	if base.Workers.HostPorts != nil {
-		t.Errorf("HostPorts mismatch: expected nil, got %v", base.Workers.HostPorts)
+		t.Errorf("Workers HostPorts mismatch: expected nil, got %v", base.Workers.HostPorts)
+	}
+	if base.ControlPlanes.HostPorts != nil {
+		t.Errorf("ControlPlanes HostPorts mismatch: expected nil, got %v", base.ControlPlanes.HostPorts)
 	}
 	if base.ControlPlanes.Count != nil {
 		t.Errorf("ControlPlanes Count mismatch: expected nil, got %v", *base.ControlPlanes.Count)
@@ -222,6 +257,9 @@ func TestClusterConfig_Merge(t *testing.T) {
 	if base.Workers.Nodes != nil {
 		t.Errorf("Workers Nodes mismatch: expected nil, got %v", base.Workers.Nodes)
 	}
+	if base.Workers.Volumes != nil {
+		t.Errorf("Workers Volumes mismatch: expected nil, got %v", base.Workers.Volumes)
+	}
 	})
 }
 
@@ -236,14 +274,18 @@ func TestClusterConfig_Copy(t *testing.T) {
 			Memory    *int                  `yaml:"memory,omitempty"`
 			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
 			HostPorts []string              `yaml:"hostports,omitempty"`
+			Volumes   []string              `yaml:"volumes,omitempty"`
 		}{
 			Count:  ptrInt(3),
 			CPU:    ptrInt(4),
 			Memory: ptrInt(8192),
 			Nodes: map[string]NodeConfig{
-				"node1": {Hostname: ptrString("original-node1")},
+				"node1": {
+					Hostname: ptrString("original-node1"),
+				},
 			},
 			HostPorts: []string{"1000:1000/tcp", "2000:2000/tcp"},
+			Volumes:   []string{"${WINDSOR_PROJECT_ROOT}/original/volume1:/var/local/original1"},
 		},
 		Workers: struct {
 			Count     *int                  `yaml:"count,omitempty"`
@@ -251,14 +293,18 @@ func TestClusterConfig_Copy(t *testing.T) {
 			Memory    *int                  `yaml:"memory,omitempty"`
 			Nodes     map[string]NodeConfig `yaml:"nodes,omitempty"`
 			HostPorts []string              `yaml:"hostports,omitempty"`
+			Volumes   []string              `yaml:"volumes,omitempty"`
 		}{
 			Count:  ptrInt(5),
 			CPU:    ptrInt(2),
 			Memory: ptrInt(4096),
 			Nodes: map[string]NodeConfig{
-				"worker1": {Hostname: ptrString("original-worker1")},
+				"worker1": {
+					Hostname: ptrString("original-worker1"),
+				},
 			},
 			HostPorts: []string{"3000:3000/tcp", "4000:4000/tcp"},
+			Volumes:   []string{"${WINDSOR_PROJECT_ROOT}/original/worker/volume1:/var/local/worker1"},
 		},
 	}
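// Reviewer note (illustrative, not part of the patch): extracting NodeGroupConfig
// makes the overlay semantics easy to exercise in isolation. A minimal sketch,
// assuming the ptrInt helper used in the tests above:
//
//	base := cluster.ClusterConfig{
//		Workers: cluster.NodeGroupConfig{
//			Count:   ptrInt(5),
//			Volumes: []string{"${WINDSOR_PROJECT_ROOT}/.volumes:/var/local"},
//		},
//	}
//	overlay := cluster.ClusterConfig{
//		Workers: cluster.NodeGroupConfig{Count: ptrInt(3)},
//	}
//	base.Merge(&overlay)
//	// base.Workers.Count is now 3; base.Workers.Volumes is unchanged, because
//	// Merge only copies HostPorts/Volumes when the overlay slice is non-nil.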
diff --git a/cmd/down.go b/cmd/down.go
index c82fdcf2b..f8d4b59ee 100644
--- a/cmd/down.go
+++ b/cmd/down.go
@@ -3,6 +3,7 @@ package cmd
 import (
 	"fmt"
 	"os"
+	"path/filepath"
 
 	"github.com/spf13/cobra"
 	ctrl "github.com/windsorcli/cli/pkg/controller"
@@ -36,6 +37,9 @@ var downCmd = &cobra.Command{
 			return fmt.Errorf("No config handler found")
 		}
 
+		// Resolve the shell
+		shell := controller.ResolveShell()
+
 		// Determine if the container runtime is enabled
 		containerRuntimeEnabled := configHandler.GetBool("docker.enabled")
 
@@ -55,9 +59,19 @@ var downCmd = &cobra.Command{
 		// Clean up context specific artifacts if --clean flag is set
 		if cleanFlag {
-			if err := controller.ResolveConfigHandler().Clean(); err != nil {
+			if err := configHandler.Clean(); err != nil {
 				return fmt.Errorf("Error cleaning up context specific artifacts: %w", err)
 			}
+
+			// Delete everything in the .volumes folder
+			projectRoot, err := shell.GetProjectRoot()
+			if err != nil {
+				return fmt.Errorf("Error retrieving project root: %w", err)
+			}
+			volumesPath := filepath.Join(projectRoot, ".volumes")
+			if err := osRemoveAll(volumesPath); err != nil {
+				return fmt.Errorf("Error deleting .volumes folder: %w", err)
+			}
 		}
 
 		// Print success message
diff --git a/cmd/down_test.go b/cmd/down_test.go
index b085e211e..14fdb95f2 100644
--- a/cmd/down_test.go
+++ b/cmd/down_test.go
@@ -2,6 +2,7 @@ package cmd
 
 import (
 	"fmt"
+	"path/filepath"
 	"strings"
 	"testing"
@@ -210,4 +211,76 @@ func TestDownCmd(t *testing.T) {
 			t.Fatalf("Expected error containing 'Error cleaning up context specific artifacts: error cleaning context artifacts', got %v", err)
 		}
 	})
+
+	t.Run("ErrorDeletingVolumes", func(t *testing.T) {
+		mocks := setupSafeDownCmdMocks()
+		mocks.MockShell.GetProjectRootFunc = func() (string, error) {
+			return filepath.Join("mock", "project", "root"), nil
+		}
+
+		// Mock the osRemoveAll function to simulate an error when attempting to delete the .volumes folder
+		originalOsRemoveAll := osRemoveAll
+		defer func() { osRemoveAll = originalOsRemoveAll }()
+		osRemoveAll = func(path string) error {
+			if path == filepath.Join("mock", "project", "root", ".volumes") {
+				return fmt.Errorf("Error deleting .volumes folder")
+			}
+			return nil
+		}
+
+		// Given a mock osRemoveAll that returns an error when deleting the .volumes folder
+		rootCmd.SetArgs([]string{"down", "--clean"})
+		err := Execute(mocks.MockController)
+
+		// Then the error should contain the expected message
+		if err == nil || !strings.Contains(err.Error(), "Error deleting .volumes folder") {
+			t.Fatalf("Expected error containing 'Error deleting .volumes folder', got %v", err)
+		}
+	})
+
+	t.Run("SuccessDeletingVolumes", func(t *testing.T) {
+		mocks := setupSafeDownCmdMocks()
+		mocks.MockShell.GetProjectRootFunc = func() (string, error) {
+			return filepath.Join("mock", "project", "root"), nil
+		}
+
+		// Stub the shell's Exec function so any rmdir-style invocation of the .volumes
+		// path succeeds (the actual deletion goes through osRemoveAll)
+		mocks.MockShell.ExecFunc = func(command string, args ...string) (string, error) {
+			if command == "cmd" && len(args) > 0 && args[0] == "/C" && args[1] == "rmdir" && args[2] == "/S" && args[3] == "/Q" && args[4] == filepath.Join("mock", "project", "root", ".volumes") {
+				return "", nil
+			}
+			return "", fmt.Errorf("Unexpected command: %s %v", command, args)
+		}
+
+		// Given mocks under which deleting the .volumes folder succeeds
+		output := captureStderr(func() {
+			rootCmd.SetArgs([]string{"down", "--clean"})
+			if err := Execute(mocks.MockController); err != nil {
+				t.Fatalf("Execute() error = %v", err)
+			}
+		})
+
+		// Then the output should indicate success
+		expectedOutput := "Windsor environment torn down successfully.\n"
+		if output != expectedOutput {
+			t.Errorf("Expected output %q, got %q", expectedOutput, output)
+		}
+	})
+
+	t.Run("ErrorGettingProjectRoot", func(t *testing.T) {
+		mocks := setupSafeDownCmdMocks()
+		callCount := 0
+		mocks.MockShell.GetProjectRootFunc = func() (string, error) {
+			callCount++
+			if callCount == 2 {
+				return "", fmt.Errorf("Error retrieving project root")
+			}
+			return filepath.Join("mock", "project", "root"), nil
+		}
+
+		rootCmd.SetArgs([]string{"down", "--clean"})
+		err := Execute(mocks.MockController)
+		if err == nil || !strings.Contains(err.Error(), "Error retrieving project root") {
+			t.Fatalf("Expected error containing 'Error retrieving project root', got %v", err)
+		}
+	})
 }
diff --git a/cmd/shims.go b/cmd/shims.go
index d1db5e71f..7e8444211 100644
--- a/cmd/shims.go
+++ b/cmd/shims.go
@@ -16,6 +16,9 @@ var osUserHomeDir = os.UserHomeDir
 // osStat retrieves the file information
 var osStat = os.Stat
 
+// osRemoveAll removes a directory and all its contents
+var osRemoveAll = os.RemoveAll
+
 // getwd retrieves the current working directory
 var getwd = os.Getwd
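// Reviewer note (illustrative, not part of the patch): the osRemoveAll shim follows
// the existing osStat/getwd pattern, so tests can swap the function without touching
// the filesystem. One caveat in SuccessDeletingVolumes above: the ExecFunc stub
// indexes args[1]..args[4] after only checking len(args) > 0, which would panic on a
// shorter argument list; checking len(args) >= 5 first would be safer. A hedged
// sketch of the save/swap/restore pattern the tests use:
//
//	originalOsRemoveAll := osRemoveAll
//	defer func() { osRemoveAll = originalOsRemoveAll }()
//	osRemoveAll = func(path string) error { return nil } // no-op for the test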
diff --git a/pkg/blueprint/blueprint_handler.go b/pkg/blueprint/blueprint_handler.go
index 050f30f18..8143cc500 100644
--- a/pkg/blueprint/blueprint_handler.go
+++ b/pkg/blueprint/blueprint_handler.go
@@ -888,10 +888,19 @@ func (b *BaseBlueprintHandler) applyConfigMap() error {
 	lbStart := b.configHandler.GetString("network.loadbalancer_ips.start")
 	lbEnd := b.configHandler.GetString("network.loadbalancer_ips.end")
 	registryURL := b.configHandler.GetString("docker.registry_url")
+	localVolumePaths := b.configHandler.GetStringSlice("cluster.workers.volumes")
 
 	// Generate LOADBALANCER_IP_RANGE from the start and end IPs for network
 	loadBalancerIPRange := fmt.Sprintf("%s-%s", lbStart, lbEnd)
 
+	// Handle the case where localVolumePaths might not be defined
+	var localVolumePath string
+	if len(localVolumePaths) > 0 {
+		localVolumePath = strings.Split(localVolumePaths[0], ":")[1]
+	} else {
+		localVolumePath = ""
+	}
+
 	configMap := &corev1.ConfigMap{
 		TypeMeta: metav1.TypeMeta{
 			APIVersion: "v1",
@@ -908,6 +917,7 @@ func (b *BaseBlueprintHandler) applyConfigMap() error {
 			"LOADBALANCER_IP_START": lbStart,
 			"LOADBALANCER_IP_END":   lbEnd,
 			"REGISTRY_URL":          registryURL,
+			"LOCAL_VOLUME_PATH":     localVolumePath,
 		},
 	}
diff --git a/pkg/blueprint/blueprint_handler_test.go b/pkg/blueprint/blueprint_handler_test.go
index 8343f4d1c..4e86e3cb9 100644
--- a/pkg/blueprint/blueprint_handler_test.go
+++ b/pkg/blueprint/blueprint_handler_test.go
@@ -219,6 +219,17 @@ func setupSafeMocks(injector ...di.Injector) MockSafeComponents {
 		}
 	}
 
+	// Return mock volume paths
+	mockConfigHandler.GetStringSliceFunc = func(key string, defaultValue ...[]string) []string {
+		if key == "cluster.workers.volumes" {
+			return []string{"${WINDSOR_PROJECT_ROOT}/.volumes:/var/local"}
+		}
+		if len(defaultValue) > 0 {
+			return defaultValue[0]
+		}
+		return nil
+	}
+
 	mockConfigHandler.GetContextFunc = func() string {
 		return "mock-context"
 	}
@@ -1470,6 +1481,9 @@ func TestBlueprintHandler_Install(t *testing.T) {
 				if configMap.Data["REGISTRY_URL"] != "mock.registry.com" {
 					return fmt.Errorf("unexpected REGISTRY_URL value: got %s, want %s", configMap.Data["REGISTRY_URL"], "mock.registry.com")
 				}
+				if configMap.Data["LOCAL_VOLUME_PATH"] != "/var/local" {
+					return fmt.Errorf("unexpected LOCAL_VOLUME_PATH value: got %s, want %s", configMap.Data["LOCAL_VOLUME_PATH"], "/var/local")
+				}
 			}
 			return nil
 		}
diff --git a/pkg/blueprint/templates/local.jsonnet b/pkg/blueprint/templates/local.jsonnet
index 1f00e4154..cc03b6f04 100644
--- a/pkg/blueprint/templates/local.jsonnet
+++ b/pkg/blueprint/templates/local.jsonnet
@@ -131,7 +131,7 @@ local registryMirrors = std.foldl(
     }
     +
     // Conditionally add 'machine.registries' only if registryMirrors is non-empty
-    if std.length(std.objectFields(registryMirrors)) == 0 then
+    (if std.length(std.objectFields(registryMirrors)) == 0 then
       {}
     else
       {
@@ -140,7 +140,32 @@ local registryMirrors = std.foldl(
           mirrors: registryMirrors,
         },
       },
+    })
+  ),
+  worker_config_patches: std.manifestYamlDoc(
+    if std.objectHas(context.cluster.workers, "volumes") then
+      {
+        machine: {
+          kubelet: {
+            extraMounts: std.map(
+              function(volume)
+                local parts = std.split(volume, ":");
+                {
+                  destination: parts[1],
+                  type: "bind",
+                  source: parts[1],
+                  options: [
+                    "rbind",
+                    "rw",
+                  ],
+                },
+              context.cluster.workers.volumes
+            ),
+          },
+        },
       }
+    else
+      {}
   ),
 },
 variables: {
@@ -285,6 +310,19 @@ local registryMirrors = std.foldl(
         "policy-base"
       ],
     },
+    {
+      name: "csi",
+      source: "core",
+      path: "csi",
+      dependsOn: [
+        "policy-resources"
+      ],
+      force: true,
+      components: [
+        "openebs",
+        "openebs/dynamic-localpv",
+      ],
+    },
   ] + (if context.vm.driver != "docker-desktop" then [
     {
       name: "lb-base",
diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go
index 848424a5d..1019b2f1c 100644
--- a/pkg/config/defaults.go
+++ b/pkg/config/defaults.go
@@ -66,29 +66,19 @@ var commonTerraformConfig = terraform.TerraformConfig{
 var commonClusterConfig = cluster.ClusterConfig{
 	Enabled: ptrBool(true),
 	Driver:  ptrString("talos"),
-	ControlPlanes: struct {
-		Count     *int                          `yaml:"count,omitempty"`
-		CPU       *int                          `yaml:"cpu,omitempty"`
-		Memory    *int                          `yaml:"memory,omitempty"`
-		Nodes     map[string]cluster.NodeConfig `yaml:"nodes,omitempty"`
-		HostPorts []string                      `yaml:"hostports,omitempty"`
-	}{
+	ControlPlanes: cluster.NodeGroupConfig{
 		Count:  ptrInt(1),
 		CPU:    ptrInt(constants.DEFAULT_TALOS_CONTROL_PLANE_CPU),
 		Memory: ptrInt(constants.DEFAULT_TALOS_CONTROL_PLANE_RAM),
 		Nodes:  make(map[string]cluster.NodeConfig),
 	},
-	Workers: struct {
-		Count     *int                          `yaml:"count,omitempty"`
-		CPU       *int                          `yaml:"cpu,omitempty"`
-		Memory    *int                          `yaml:"memory,omitempty"`
-		Nodes     map[string]cluster.NodeConfig `yaml:"nodes,omitempty"`
-		HostPorts []string                      `yaml:"hostports,omitempty"`
-	}{
-		Count:  ptrInt(1),
-		CPU:    ptrInt(constants.DEFAULT_TALOS_WORKER_CPU),
-		Memory: ptrInt(constants.DEFAULT_TALOS_WORKER_RAM),
-		Nodes:  make(map[string]cluster.NodeConfig),
+	Workers: cluster.NodeGroupConfig{
+		Count:     ptrInt(1),
+		CPU:       ptrInt(constants.DEFAULT_TALOS_WORKER_CPU),
+		Memory:    ptrInt(constants.DEFAULT_TALOS_WORKER_RAM),
+		Nodes:     make(map[string]cluster.NodeConfig),
+		HostPorts: []string{},
+		Volumes:   []string{"${WINDSOR_PROJECT_ROOT}/.volumes:/var/local"},
 	},
 }
 
@@ -105,30 +95,19 @@ var DefaultConfig_Localhost = v1alpha1.Context{
 	Cluster: &cluster.ClusterConfig{
 		Enabled: ptrBool(true),
 		Driver:  ptrString("talos"),
-		ControlPlanes: struct {
-			Count     *int                          `yaml:"count,omitempty"`
-			CPU       *int                          `yaml:"cpu,omitempty"`
-			Memory    *int                          `yaml:"memory,omitempty"`
-			Nodes     map[string]cluster.NodeConfig `yaml:"nodes,omitempty"`
-			HostPorts []string                      `yaml:"hostports,omitempty"`
-		}{
+		ControlPlanes: cluster.NodeGroupConfig{
 			Count:  ptrInt(1),
 			CPU:    ptrInt(constants.DEFAULT_TALOS_CONTROL_PLANE_CPU),
 			Memory: ptrInt(constants.DEFAULT_TALOS_CONTROL_PLANE_RAM),
 			Nodes:  make(map[string]cluster.NodeConfig),
 		},
-		Workers: struct {
-			Count     *int                          `yaml:"count,omitempty"`
-			CPU       *int                          `yaml:"cpu,omitempty"`
-			Memory    *int                          `yaml:"memory,omitempty"`
-			Nodes     map[string]cluster.NodeConfig `yaml:"nodes,omitempty"`
-			HostPorts []string                      `yaml:"hostports,omitempty"`
-		}{
+		Workers: cluster.NodeGroupConfig{
 			Count:     ptrInt(1),
 			CPU:       ptrInt(constants.DEFAULT_TALOS_WORKER_CPU),
 			Memory:    ptrInt(constants.DEFAULT_TALOS_WORKER_RAM),
 			Nodes:     make(map[string]cluster.NodeConfig),
 			HostPorts: []string{"8080:30080/tcp", "8443:30443/tcp", "9292:30292/tcp", "8053:30053/udp"},
+			Volumes:   []string{"${WINDSOR_PROJECT_ROOT}/.volumes:/var/local"},
 		},
 	},
 	Network: &network.NetworkConfig{
diff --git a/pkg/config/yaml_config_handler_test.go b/pkg/config/yaml_config_handler_test.go
index 52b33a67c..95cc72fcd 100644
--- a/pkg/config/yaml_config_handler_test.go
+++ b/pkg/config/yaml_config_handler_test.go
@@ -547,13 +547,7 @@ func TestYamlConfigHandler_GetInt(t *testing.T) {
 		Contexts: map[string]*v1alpha1.Context{
 			"default": {
 				Cluster: &cluster.ClusterConfig{
-					ControlPlanes: struct {
-						Count     *int                          `yaml:"count,omitempty"`
-						CPU       *int                          `yaml:"cpu,omitempty"`
-						Memory    *int                          `yaml:"memory,omitempty"`
-						Nodes     map[string]cluster.NodeConfig `yaml:"nodes,omitempty"`
-						HostPorts []string                      `yaml:"hostports,omitempty"`
-					}{
+					ControlPlanes: cluster.NodeGroupConfig{
 						Count: ptrInt(3),
 					},
 				},
@@ -662,13 +656,7 @@ func TestYamlConfigHandler_GetStringSlice(t *testing.T) {
 	handler.config.Contexts = map[string]*v1alpha1.Context{
 		"default": {
 			Cluster: &cluster.ClusterConfig{
-				Workers: struct {
-					Count     *int                          `yaml:"count,omitempty"`
-					CPU       *int                          `yaml:"cpu,omitempty"`
-					Memory    *int                          `yaml:"memory,omitempty"`
-					Nodes     map[string]cluster.NodeConfig `yaml:"nodes,omitempty"`
-					HostPorts []string                      `yaml:"hostports,omitempty"`
-				}{
+				Workers: cluster.NodeGroupConfig{
 					HostPorts: []string{"50000:50002/tcp", "30080:8080/tcp", "30443:8443/tcp"},
 				},
 			},
diff --git a/pkg/env/kube_env.go b/pkg/env/kube_env.go
index a2dc28124..8fd9f0556 100644
--- a/pkg/env/kube_env.go
+++ b/pkg/env/kube_env.go
@@ -1,10 +1,18 @@
 package env
 
 import (
+	"context"
 	"fmt"
+	"os"
 	"path/filepath"
+	"regexp"
+	"strings"
 
 	"github.com/windsorcli/cli/pkg/di"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/client-go/kubernetes"
+	"k8s.io/client-go/tools/clientcmd"
 )
 
 // KubeEnvPrinter is a struct that simulates a Kubernetes environment for testing purposes.
@@ -21,23 +29,83 @@ func NewKubeEnvPrinter(injector di.Injector) *KubeEnvPrinter {
 	}
 }
 
-// GetEnvVars retrieves the environment variables for the Kubernetes environment.
+// GetEnvVars constructs a map of Kubernetes environment variables by setting
+// KUBECONFIG and KUBE_CONFIG_PATH based on the configuration root directory.
+// It checks for a project-specific volume directory and returns current variables
+// if it doesn't exist. If it does, it ensures each PVC directory has a corresponding
+// "PV_" environment variable, returning the map if all are accounted for.
 func (e *KubeEnvPrinter) GetEnvVars() (map[string]string, error) {
 	envVars := make(map[string]string)
-
-	// Determine the root directory for configuration files.
 	configRoot, err := e.configHandler.GetConfigRoot()
 	if err != nil {
 		return nil, fmt.Errorf("error retrieving configuration root directory: %w", err)
 	}
-
-	// Construct the path to the kubeconfig file.
 	kubeConfigPath := filepath.Join(configRoot, ".kube", "config")
-
-	// Populate environment variables with Kubernetes configuration data.
 	envVars["KUBECONFIG"] = kubeConfigPath
 	envVars["KUBE_CONFIG_PATH"] = kubeConfigPath
 
+	projectRoot := os.Getenv("WINDSOR_PROJECT_ROOT")
+	volumeDir := filepath.Join(projectRoot, ".volumes")
+
+	if _, err := stat(volumeDir); os.IsNotExist(err) {
+		return envVars, nil
+	}
+
+	volumeDirs, err := readDir(volumeDir)
+	if err != nil {
+		return nil, fmt.Errorf("error reading volume directories: %w", err)
+	}
+
+	existingEnvVars := make(map[string]string)
+	for _, env := range os.Environ() {
+		if strings.HasPrefix(env, "PV_") {
+			parts := strings.SplitN(env, "=", 2)
+			if len(parts) == 2 {
+				existingEnvVars[parts[0]] = parts[1]
+				envVars[parts[0]] = parts[1] // Include existing PV environment variables
+			}
+		}
+	}
+
+	allVolumesAccounted := true
+	for _, dir := range volumeDirs {
+		if strings.HasPrefix(dir.Name(), "pvc-") {
+			found := false
+			for _, envVarValue := range existingEnvVars {
+				if strings.HasSuffix(dir.Name(), filepath.Base(envVarValue)) {
+					found = true
+					break
+				}
+			}
+			if !found {
+				allVolumesAccounted = false
+				break
+			}
+		}
+	}
+
+	if allVolumesAccounted {
+		return envVars, nil
+	}
+
+	pvcs, _ := queryPersistentVolumeClaims(kubeConfigPath) // ignores error
+
+	if pvcs != nil && pvcs.Items != nil {
+		for _, dir := range volumeDirs {
+			if strings.HasPrefix(dir.Name(), "pvc-") {
+				for _, pvc := range pvcs.Items {
+					if strings.HasSuffix(dir.Name(), string(pvc.UID)) {
+						envVarName := fmt.Sprintf("PV_%s_%s", sanitizeEnvVar(pvc.Namespace), sanitizeEnvVar(pvc.Name))
+						if _, exists := existingEnvVars[envVarName]; !exists {
+							envVars[envVarName] = filepath.Join(volumeDir, dir.Name())
+						}
+						break
+					}
+				}
+			}
+		}
+	}
+
 	return envVars, nil
 }
 
@@ -48,9 +116,39 @@ func (e *KubeEnvPrinter) Print() error {
 		// Return the error if GetEnvVars fails
 		return fmt.Errorf("error getting environment variables: %w", err)
 	}
+
 	// Call the Print method of the embedded BaseEnvPrinter struct with the retrieved environment variables
 	return e.BaseEnvPrinter.Print(envVars)
 }
 
 // Ensure kubeEnv implements the EnvPrinter interface
 var _ EnvPrinter = (*KubeEnvPrinter)(nil)
+
+// sanitizeEnvVar converts a string to uppercase, trims whitespace, and replaces invalid characters with underscores.
+func sanitizeEnvVar(input string) string {
+	trimmed := strings.TrimSpace(input)
+	upper := strings.ToUpper(trimmed)
+	re := regexp.MustCompile(`[^A-Z0-9_]`)
+	sanitized := re.ReplaceAllString(upper, "_")
+	return strings.Trim(sanitized, "_")
+}
+
+// queryPersistentVolumeClaims retrieves a list of PersistentVolumeClaims (PVCs) from the Kubernetes cluster.
+// It returns the list of PVCs, or an error if there is any issue building the Kubernetes
+// configuration, creating the clientset, or listing the claims.
+var queryPersistentVolumeClaims = func(kubeConfigPath string) (*corev1.PersistentVolumeClaimList, error) {
+	config, err := clientcmd.BuildConfigFromFlags("", kubeConfigPath)
+	if err != nil {
+		return nil, err
+	}
+	clientset, err := kubernetes.NewForConfig(config)
+	if err != nil {
+		return nil, err
+	}
+
+	pvcs, err := clientset.CoreV1().PersistentVolumeClaims("").List(context.TODO(), metav1.ListOptions{})
+	if err != nil {
+		return nil, err
+	}
+
+	return pvcs, nil
+}
diff --git a/pkg/env/kube_env_test.go b/pkg/env/kube_env_test.go
index 3b1cad628..36061585a 100644
--- a/pkg/env/kube_env_test.go
+++ b/pkg/env/kube_env_test.go
@@ -7,10 +7,13 @@ import (
 	"reflect"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/windsorcli/cli/pkg/config"
 	"github.com/windsorcli/cli/pkg/di"
 	"github.com/windsorcli/cli/pkg/shell"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 )
 
 type KubeEnvPrinterMocks struct {
@@ -37,6 +40,36 @@ func setupSafeKubeEnvPrinterMocks(injector ...di.Injector) *KubeEnvPrinterMocks
 	mockInjector.Register("configHandler", mockConfigHandler)
 	mockInjector.Register("shell", mockShell)
 
+	// Mock readDir to return some valid persistent volume folders
+	originalReadDir := readDir
+	readDir = func(dirname string) ([]os.DirEntry, error) {
+		if dirname == ".volumes" {
+			return []os.DirEntry{
+				mockDirEntry{name: "pvc-1234"},
+				mockDirEntry{name: "pvc-5678"},
+			}, nil
+		}
+		return originalReadDir(dirname)
+	}
+
+	// Mock stat so the kubeconfig and .volumes paths appear to exist
+	stat = func(name string) (os.FileInfo, error) {
+		if strings.HasSuffix(name, ".kube/config") || strings.HasSuffix(name, ".volumes") {
+			return nil, nil
+		}
+		return nil, os.ErrNotExist
+	}
+
+	// Mock queryPersistentVolumeClaims to return appropriate PVC claims
+	queryPersistentVolumeClaims = func(kubeConfigPath string) (*corev1.PersistentVolumeClaimList, error) {
+		return &corev1.PersistentVolumeClaimList{
+			Items: []corev1.PersistentVolumeClaim{
+				{ObjectMeta: metav1.ObjectMeta{UID: "1234", Namespace: "default", Name: "claim1"}},
+				{ObjectMeta: metav1.ObjectMeta{UID: "5678", Namespace: "default", Name: "claim2"}},
+			},
+		}, nil
+	}
+
 	return &KubeEnvPrinterMocks{
 		Injector:      mockInjector,
 		ConfigHandler: mockConfigHandler,
@@ -44,19 +77,32 @@ func setupSafeKubeEnvPrinterMocks(injector ...di.Injector) *KubeEnvPrinterMocks
 	}
 }
 
+// mockDirEntry is a simple mock implementation of os.DirEntry
+type mockDirEntry struct {
+	name string
+}
+
+func (m mockDirEntry) Name() string               { return m.name }
+func (m mockDirEntry) IsDir() bool                { return true }
+func (m mockDirEntry) Type() os.FileMode          { return os.ModeDir }
+func (m mockDirEntry) Info() (os.FileInfo, error) { return mockFileInfo{name: m.name}, nil }
+
+// mockFileInfo is a simple mock implementation of os.FileInfo
+type mockFileInfo struct {
+	name string
+}
+
+func (m mockFileInfo) Name() string       { return m.name }
+func (m mockFileInfo) Size() int64        { return 0 }
+func (m mockFileInfo) Mode() os.FileMode  { return os.ModeDir }
+func (m mockFileInfo) ModTime() time.Time { return time.Time{} }
+func (m mockFileInfo) IsDir() bool        { return true }
+func (m mockFileInfo) Sys() interface{}   { return nil }
+
 func TestKubeEnvPrinter_GetEnvVars(t *testing.T) {
 	t.Run("Success", func(t *testing.T) {
 		mocks := setupSafeKubeEnvPrinterMocks()
 
-		originalStat := stat
-		defer func() { stat = originalStat }()
-		stat = func(name string) (os.FileInfo, error) {
-			if name == filepath.FromSlash("/mock/config/root/.kube/config") {
-				return nil, nil
-			}
-			return nil, os.ErrNotExist
-		}
-
 		kubeEnvPrinter := NewKubeEnvPrinter(mocks.Injector)
 		kubeEnvPrinter.Initialize()
 
@@ -109,6 +155,107 @@ func TestKubeEnvPrinter_GetEnvVars(t *testing.T) {
 			t.Errorf("error = %v, want %v", err, expectedError)
 		}
 	})
+
+	t.Run("ErrorReadingVolumes", func(t *testing.T) {
+		mocks := setupSafeKubeEnvPrinterMocks()
+		mocks.ConfigHandler.GetConfigRootFunc = func() (string, error) {
+			return "/mock/config/root", nil
+		}
+
+		originalReadDir := readDir
+		defer func() { readDir = originalReadDir }()
+		readDir = func(dirname string) ([]os.DirEntry, error) {
+			return nil, errors.New("mock readDir error")
+		}
+
+		kubeEnvPrinter := NewKubeEnvPrinter(mocks.Injector)
+		kubeEnvPrinter.Initialize()
+
+		_, err := kubeEnvPrinter.GetEnvVars()
+		expectedError := "error reading volume directories: mock readDir error"
+		if err == nil || err.Error() != expectedError {
+			t.Errorf("error = %v, want %v", err, expectedError)
+		}
+	})
+
+	t.Run("SuccessWithExistingPVCEnvVars", func(t *testing.T) {
+		// Use setupSafeKubeEnvPrinterMocks to create mocks
+		mocks := setupSafeKubeEnvPrinterMocks()
+		kubeEnvPrinter := NewKubeEnvPrinter(mocks.Injector)
+		kubeEnvPrinter.Initialize()
+
+		// Set up environment variables to simulate existing PVC environment variables
+		os.Setenv("PV_NAMESPACE_PVCNAME", "/mock/volume/dir/pvc-12345")
+		defer os.Unsetenv("PV_NAMESPACE_PVCNAME")
+
+		// Mock the readDir function to simulate reading the volume directory
+		readDir = func(dirname string) ([]os.DirEntry, error) {
+			return []os.DirEntry{
+				mockDirEntry{name: "pvc-12345"},
+			}, nil
+		}
+
+		// Call GetEnvVars and check for errors
+		envVars, err := kubeEnvPrinter.GetEnvVars()
+		if err != nil {
+			t.Errorf("unexpected error: %v", err)
+		}
+
+		// Verify that GetEnvVars returns the correct envVars
+		expectedEnvVars := map[string]string{
+			"KUBECONFIG":           filepath.FromSlash("/mock/config/root/.kube/config"),
+			"KUBE_CONFIG_PATH":     filepath.FromSlash("/mock/config/root/.kube/config"),
+			"PV_NAMESPACE_PVCNAME": "/mock/volume/dir/pvc-12345",
+		}
+		if !reflect.DeepEqual(envVars, expectedEnvVars) {
+			t.Errorf("envVars = %v, want %v", envVars, expectedEnvVars)
+		}
+	})
+
+	t.Run("AllVolumesAccountedFor", func(t *testing.T) {
+		mocks := setupSafeKubeEnvPrinterMocks()
+		kubeEnvPrinter := NewKubeEnvPrinter(mocks.Injector)
+		kubeEnvPrinter.Initialize()
+
+		// Set up environment variables to simulate all PVCs being accounted for
+		os.Setenv("PV_DEFAULT_CLAIM1", "/mock/volume/dir/pvc-1234")
+		os.Setenv("PV_DEFAULT_CLAIM2", "/mock/volume/dir/pvc-5678")
+		defer os.Unsetenv("PV_DEFAULT_CLAIM1")
+		defer os.Unsetenv("PV_DEFAULT_CLAIM2")
+
+		// Mock the readDir function to simulate reading the volume directory
+		readDir = func(dirname string) ([]os.DirEntry, error) {
+			return []os.DirEntry{
+				mockDirEntry{name: "pvc-1234"},
+				mockDirEntry{name: "pvc-5678"},
+			}, nil
+		}
+
+		// Mock queryPersistentVolumeClaims to verify it is not called
+		originalQueryPVCs := queryPersistentVolumeClaims
+		defer func() { queryPersistentVolumeClaims = originalQueryPVCs }()
+		queryPersistentVolumeClaims = func(kubeConfigPath string) (*corev1.PersistentVolumeClaimList, error) {
+			t.Error("queryPersistentVolumeClaims should not be called")
+			return nil, nil
+		}
+
+		// Call GetEnvVars and check for errors
+		envVars, err := kubeEnvPrinter.GetEnvVars()
+		if err != nil {
+			t.Errorf("unexpected error: %v", err)
+		}
+
+		// Verify that GetEnvVars returns the correct envVars without calling queryPersistentVolumeClaims
+		expectedEnvVars := map[string]string{
+			"KUBECONFIG":        filepath.FromSlash("/mock/config/root/.kube/config"),
+			"KUBE_CONFIG_PATH":  filepath.FromSlash("/mock/config/root/.kube/config"),
+			"PV_DEFAULT_CLAIM1": "/mock/volume/dir/pvc-1234",
+			"PV_DEFAULT_CLAIM2": "/mock/volume/dir/pvc-5678",
+		}
+		if !reflect.DeepEqual(envVars, expectedEnvVars) {
+			t.Errorf("envVars = %v, want %v", envVars, expectedEnvVars)
+		}
+	})
 }
 
 func TestKubeEnvPrinter_Print(t *testing.T) {
diff --git a/pkg/env/shims.go b/pkg/env/shims.go
index 890d277b0..ce2b310d7 100644
--- a/pkg/env/shims.go
+++ b/pkg/env/shims.go
@@ -21,6 +21,9 @@ var glob = filepath.Glob
 // Wrapper function for os.WriteFile
 var writeFile = os.WriteFile
 
+// Wrapper function for os.ReadDir
+var readDir = os.ReadDir
+
 // Wrapper function for yaml.Unmarshal
 var yamlUnmarshal = yaml.Unmarshal
diff --git a/pkg/services/talos_service.go b/pkg/services/talos_service.go
index 866ffa3a5..37c25d623 100644
--- a/pkg/services/talos_service.go
+++ b/pkg/services/talos_service.go
@@ -4,7 +4,6 @@ import (
 	"fmt"
 	"math"
 	"os"
-	"path/filepath"
 	"strconv"
 	"strings"
 	"sync"
@@ -204,21 +203,28 @@ func (s *TalosService) GetComposeConfig() (*types.Config, error) {
 		},
 	}
 
-	if s.mode != "controlplane" {
-		projectRoot, err := s.shell.GetProjectRoot()
-		if err != nil {
-			return nil, fmt.Errorf("error retrieving project root: %w", err)
+	// Use volumes from cluster configuration
+	volumesKey := fmt.Sprintf("cluster.%s.volumes", nodeType)
+	volumes := s.configHandler.GetStringSlice(volumesKey, []string{})
+	for _, volume := range volumes {
+		parts := strings.Split(volume, ":")
+		if len(parts) != 2 {
+			return nil, fmt.Errorf("invalid volume format: %s", volume)
 		}
-		volumesPath := filepath.Join(projectRoot, ".volumes")
-		if _, err := stat(volumesPath); os.IsNotExist(err) {
-			if err := mkdir(volumesPath, os.ModePerm); err != nil {
-				return nil, fmt.Errorf("error creating .volumes directory: %w", err)
-			}
+
+		// Expand environment variables in the source path for directory creation
+		expandedSourcePath := os.ExpandEnv(parts[0])
+
+		// Create the directory if it doesn't exist
+		if err := mkdirAll(expandedSourcePath, os.ModePerm); err != nil {
+			return nil, fmt.Errorf("failed to create directory %s: %v", expandedSourcePath, err)
 		}
+
+		// Use the original, pre-expanded source path in the volume configuration
 		commonConfig.Volumes = append(commonConfig.Volumes, types.ServiceVolumeConfig{
 			Type:   "bind",
-			Source: "${WINDSOR_PROJECT_ROOT}/.volumes",
-			Target: "/var/local",
+			Source: parts[0],
+			Target: parts[1],
 		})
 	}
 
@@ -287,7 +293,7 @@ func (s *TalosService) GetComposeConfig() (*types.Config, error) {
 
 	serviceConfig.Ports = ports
 
-	volumes := map[string]types.VolumeConfig{
+	volumesMap := map[string]types.VolumeConfig{
 		strings.ReplaceAll(nodeName+"_system_state", "-", "_"): {},
 		strings.ReplaceAll(nodeName+"_var", "-", "_"):          {},
 		strings.ReplaceAll(nodeName+"_etc_cni", "-", "_"):      {},
@@ -298,6 +304,6 @@ func (s *TalosService) GetComposeConfig() (*types.Config, error) {
 
 	return &types.Config{
 		Services: []types.ServiceConfig{serviceConfig},
-		Volumes:  volumes,
+		Volumes:  volumesMap,
 	}, nil
 }
diff --git a/pkg/services/talos_service_test.go b/pkg/services/talos_service_test.go
index db5e28a9c..21f60fb58 100644
--- a/pkg/services/talos_service_test.go
+++ b/pkg/services/talos_service_test.go
@@ -4,7 +4,6 @@ import (
 	"fmt"
 	"math"
 	"os"
-	"path/filepath"
 	"strings"
 	"testing"
@@ -56,6 +55,8 @@ func setupTalosServiceMocks(optionalInjector ...di.Injector) *MockComponents {
 			return "192.168.1.2:50001"
 		case "dns.domain":
 			return "test"
+		case "cluster.workers.local_volume_path":
+			return "/var/local"
 		default:
 			if len(defaultValue) > 0 {
 				return defaultValue[0]
@@ -72,6 +73,12 @@ func setupTalosServiceMocks(optionalInjector ...di.Injector) *MockComponents {
 			return []string{"30002:30002/tcp", "30003:30003"}
 		case "cluster.workers.hostports":
 			return []string{"30000:30000", "30001:30001/udp", "30002:30002/tcp", "30003:30003"}
+		case "cluster.workers.nodes.worker1.volumes":
+			return []string{"/data/worker1:/mnt/data", "/logs/worker1:/mnt/logs"}
+		case "cluster.workers.nodes.worker2.volumes":
+			return []string{"/data/worker2:/mnt/data", "/logs/worker2:/mnt/logs"}
+		case "cluster.workers.volumes":
+			return []string{"/data/common:/mnt/data", "/logs/common:/mnt/logs"}
 		default:
 			if len(defaultValue) > 0 {
 				return defaultValue[0]
@@ -83,13 +90,7 @@ func setupTalosServiceMocks(optionalInjector ...di.Injector) *MockComponents {
 	mockConfigHandler.GetConfigFunc = func() *v1alpha1.Context {
 		return &v1alpha1.Context{
 			Cluster: &cluster.ClusterConfig{
-				Workers: struct {
-					Count     *int                          `yaml:"count,omitempty"`
-					CPU       *int                          `yaml:"cpu,omitempty"`
-					Memory    *int                          `yaml:"memory,omitempty"`
-					Nodes     map[string]cluster.NodeConfig `yaml:"nodes,omitempty"`
-					HostPorts []string                      `yaml:"hostports,omitempty"`
-				}{
+				Workers: cluster.NodeGroupConfig{
 					Nodes: map[string]cluster.NodeConfig{
 						"worker1": {},
 						"worker2": {},
@@ -638,14 +639,17 @@ func TestTalosService_GetComposeConfig(t *testing.T) {
 		}
 	})
 
-	t.Run("ErrorGettingProjectRoot", func(t *testing.T) {
+	t.Run("InvalidVolumeFormat", func(t *testing.T) {
 		// Setup mocks for this test
 		mocks := setupTalosServiceMocks()
 		service := NewTalosService(mocks.Injector, "worker")
 
-		// Mock the GetProjectRoot method to return an error
-		mocks.MockShell.GetProjectRootFunc = func() (string, error) {
-			return "", fmt.Errorf("mock error retrieving project root")
+		// Mock the GetStringSlice method to return an invalid volume format
+		mocks.MockConfigHandler.GetStringSliceFunc = func(key string, defaultValue ...[]string) []string {
+			if key == "cluster.workers.volumes" {
+				return []string{"invalidVolumeFormat"}
+			}
+			return nil
 		}
 
 		// Initialize the service
 		err := service.Initialize()
 		if err != nil {
@@ -655,49 +659,26 @@ func TestTalosService_GetComposeConfig(t *testing.T) {
 		}
 
 		// When the GetComposeConfig method is called
-		config, err := service.GetComposeConfig()
+		_, err = service.GetComposeConfig()
 
-		// Then an error should be returned and the config should be nil
+		// Then an error should be returned
 		if err == nil {
-			t.Fatalf("expected an error, got nil")
-		}
-		if err.Error() != "error retrieving project root: mock error retrieving project root" {
-			t.Fatalf("expected error message 'error retrieving project root: mock error retrieving project root', got %v", err)
+			t.Fatalf("expected an error due to invalid volume format, got nil")
 		}
-		if config != nil {
-			t.Fatalf("expected config to be nil, got %v", config)
+		if err.Error() != "invalid volume format: invalidVolumeFormat" {
+			t.Fatalf("expected error message 'invalid volume format: invalidVolumeFormat', got %v", err)
 		}
 	})
 
-	t.Run("ErrorCreatingVolumesDirectory", func(t *testing.T) {
+	t.Run("InvalidDefaultAPIPort", func(t *testing.T) {
 		// Setup mocks for this test
 		mocks := setupTalosServiceMocks()
 		service := NewTalosService(mocks.Injector, "worker")
 
-		// Mock the GetProjectRoot method to return a valid project root
-		mocks.MockShell.GetProjectRootFunc = func() (string, error) {
-			return filepath.FromSlash("/mock/project/root"), nil
-		}
-
-		// Mock the stat function to simulate the .volumes directory does not exist
-		originalStat := stat
-		defer func() { stat = originalStat }()
-		stat = func(name string) (os.FileInfo, error) {
-			if filepath.Clean(name) == filepath.Clean(filepath.Join("/mock/project/root", ".volumes")) {
-				return nil, os.ErrNotExist
-			}
-			return nil, nil
-		}
-
-		// Mock the mkdir function to return an error
-		originalMkdir := mkdir
-		defer func() { mkdir = originalMkdir }()
-		mkdir = func(name string, perm os.FileMode) error {
-			if filepath.Clean(name) == filepath.Clean(filepath.Join("/mock/project/root", ".volumes")) {
-				return fmt.Errorf("mock error creating .volumes directory")
-			}
-			return nil
-		}
+		// Set the defaultAPIPort to an invalid value exceeding MaxUint32
+		originalDefaultAPIPort := defaultAPIPort
+		defaultAPIPort = int(math.MaxUint32) + 1
+		defer func() { defaultAPIPort = originalDefaultAPIPort }()
 
 		// Initialize the service
 		err := service.Initialize()
 		if err != nil {
@@ -706,29 +687,28 @@ func TestTalosService_GetComposeConfig(t *testing.T) {
 		}
 
 		// When the GetComposeConfig method is called
-		config, err := service.GetComposeConfig()
+		_, err = service.GetComposeConfig()
 
-		// Then an error should be returned and the config should be nil
+		// Then an error should be returned
 		if err == nil {
-			t.Fatalf("expected an error, got nil")
-		}
-		if err.Error() != "error creating .volumes directory: mock error creating .volumes directory" {
-			t.Fatalf("expected error message 'error creating .volumes directory: mock error creating .volumes directory', got %v", err)
+			t.Fatalf("expected an error due to invalid default API port, got nil")
 		}
-		if config != nil {
-			t.Fatalf("expected config to be nil, got %v", config)
+		if err.Error() != fmt.Sprintf("defaultAPIPort value out of range: %d", defaultAPIPort) {
+			t.Fatalf("expected error message 'defaultAPIPort value out of range: %d', got %v", defaultAPIPort, err)
 		}
 	})
 
-	t.Run("InvalidDefaultAPIPort", func(t *testing.T) {
+	t.Run("ErrorMkdirAll", func(t *testing.T) {
 		// Setup mocks for this test
 		mocks := setupTalosServiceMocks()
 		service := NewTalosService(mocks.Injector, "worker")
 
-		// Set the defaultAPIPort to an invalid value exceeding MaxUint32
-		originalDefaultAPIPort := defaultAPIPort
-		defaultAPIPort = int(math.MaxUint32) + 1
-		defer func() { defaultAPIPort = originalDefaultAPIPort }()
+		// Mock the mkdirAll function to return an error
+		originalMkdirAll := mkdirAll
+		defer func() { mkdirAll = originalMkdirAll }()
+		mkdirAll = func(path string, perm os.FileMode) error {
+			return fmt.Errorf("mocked mkdirAll error")
+		}
 
 		// Initialize the service
 		err := service.Initialize()
@@ -741,19 +721,19 @@ func TestTalosService_GetComposeConfig(t *testing.T) {
 
 		// Then an error should be returned
 		if err == nil {
-			t.Fatalf("expected an error due to invalid default API port, got nil")
+			t.Fatalf("expected an error due to mkdirAll failure, got nil")
 		}
-		if err.Error() != fmt.Sprintf("defaultAPIPort value out of range: %d", defaultAPIPort) {
-			t.Fatalf("expected error message 'defaultAPIPort value out of range: %d', got %v", defaultAPIPort, err)
+		if !strings.Contains(err.Error(), "mocked mkdirAll error") {
+			t.Fatalf("expected error message containing 'mocked mkdirAll error', got %v", err)
 		}
 	})
 
-	t.Run("InvalidHostPort", func(t *testing.T) {
+	t.Run("InvalidHostPortFormat", func(t *testing.T) {
 		// Setup mocks for this test
 		mocks := setupTalosServiceMocks()
 		service := NewTalosService(mocks.Injector, "worker")
 
-		// Mock the GetStringSlice method to return an invalid host port
+		// Mock the GetStringSlice method to return an invalid host port format
 		mocks.MockConfigHandler.GetStringSliceFunc = func(key string, defaultValue ...[]string) []string {
 			if key == "cluster.workers.nodes.worker.hostports" {
 				return []string{"invalidPort:30000/tcp"}
@@ -779,12 +759,12 @@ func TestTalosService_GetComposeConfig(t *testing.T) {
 		}
 	})
 
-	t.Run("InvalidHostPort", func(t *testing.T) {
+	t.Run("InvalidHostPortValue", func(t *testing.T) {
 		// Setup mocks for this test
 		mocks := setupTalosServiceMocks()
 		service := NewTalosService(mocks.Injector, "worker")
 
-		// Mock the GetStringSlice method to return an invalid host port
+		// Mock the GetStringSlice method to return an invalid host port value
 		mocks.MockConfigHandler.GetStringSliceFunc = func(key string, defaultValue ...[]string) []string {
 			if key == "cluster.workers.nodes.worker.hostports" {
 				return []string{"30000:invalidHostPort/tcp"}
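// Reviewer note (illustrative, not part of the patch): each "source:target" entry
// from cluster.<group>.volumes becomes a compose bind mount; for the default worker
// volume the resulting entry looks like this (using the compose-go types above):
//
//	types.ServiceVolumeConfig{
//		Type:   "bind",
//		Source: "${WINDSOR_PROJECT_ROOT}/.volumes",
//		Target: "/var/local",
//	}
//
// Note that the strict two-part strings.Split also rejects Windows drive-letter
// paths such as "C:\data:/var/local", which would need dedicated parsing if they
// must be supported.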