diff --git a/cmd/env.go b/cmd/env.go index 447acd7f8..528eb7fbc 100644 --- a/cmd/env.go +++ b/cmd/env.go @@ -58,6 +58,26 @@ var envCmd = &cobra.Command{ } } + // Resolve config handler to check vm.driver + vmDriver := configHandler.GetString("vm.driver") + + // Create virtualization and service components only if vm.driver is configured + if vmDriver != "" { + if err := controller.CreateVirtualizationComponents(); err != nil { + if verbose { + return fmt.Errorf("Error creating virtualization components: %w", err) + } + return nil + } + + if err := controller.CreateServiceComponents(); err != nil { + if verbose { + return fmt.Errorf("Error creating service components: %w", err) + } + return nil + } + } + // Create environment components if err := controller.CreateEnvComponents(); err != nil { if verbose { @@ -86,26 +106,6 @@ var envCmd = &cobra.Command{ return nil } - // Resolve config handler to check vm.driver - vmDriver := configHandler.GetString("vm.driver") - - // Create virtualization and service components only if vm.driver is configured - if vmDriver != "" { - if err := controller.CreateVirtualizationComponents(); err != nil { - if verbose { - return fmt.Errorf("Error creating virtualization components: %w", err) - } - return nil - } - - if err := controller.CreateServiceComponents(); err != nil { - if verbose { - return fmt.Errorf("Error creating service components: %w", err) - } - return nil - } - } - // Check if --decrypt flag is set decrypt, _ := cmd.Flags().GetBool("decrypt") if decrypt { diff --git a/pkg/blueprint/templates/default.jsonnet b/pkg/blueprint/templates/default.jsonnet index 39e781cd1..57b4bbe1b 100644 --- a/pkg/blueprint/templates/default.jsonnet +++ b/pkg/blueprint/templates/default.jsonnet @@ -16,8 +16,20 @@ local cpNodes = if std.objectHas(context, "cluster") && std.objectHas(context.cl // Select the first node or default to null if no nodes are present local firstNode = if std.length(cpNodes) > 0 then cpNodes[0] else null; +// Extract baseUrl from endpoint +local extractBaseUrl(endpoint) = + if endpoint == "" then "" else + local parts = std.split(endpoint, "://"); + if std.length(parts) > 1 then + local hostParts = std.split(parts[1], ":"); + hostParts[0] + else + local hostParts = std.split(endpoint, ":"); + hostParts[0]; + // Determine the endpoint, using cluster.endpoint if available, otherwise falling back to firstNode -local endpoint = if std.objectHas(context.cluster, "endpoint") then context.cluster.endpoint else if firstNode != null then firstNode.node else ""; +local endpoint = if std.objectHas(context.cluster, "endpoint") then context.cluster.endpoint else if firstNode != null then firstNode.endpoint else ""; +local baseUrl = extractBaseUrl(endpoint); // Build the mirrors dynamically, only if registries are defined local registryMirrors = if std.objectHas(context, "docker") && std.objectHas(context.docker, "registries") then @@ -81,7 +93,7 @@ local terraformConfig = if platform == "local" || platform == "metal" then [ source: "core", values: { // Use the determined endpoint - cluster_endpoint: if endpoint != "" then "https://" + endpoint + ":6443" else "", + cluster_endpoint: if endpoint != "" then "https://" + baseUrl + ":6443" else "", cluster_name: "talos", // Create a list of control plane nodes @@ -89,7 +101,6 @@ local terraformConfig = if platform == "local" || platform == "metal" then [ std.map( function(v) { endpoint: v.endpoint, - hostname: v.hostname, node: v.node, }, std.objectValues(context.cluster.controlplanes.nodes) @@ -101,7 +112,6 @@ local terraformConfig = if platform == "local" || platform == "metal" then [ std.map( function(v) { endpoint: v.endpoint, - hostname: v.hostname, node: v.node, }, std.objectValues(context.cluster.workers.nodes) @@ -117,7 +127,7 @@ local terraformConfig = if platform == "local" || platform == "metal" then [ apiServer: { certSANs: [ "localhost", - endpoint, + baseUrl, ], }, extraManifests: [ @@ -132,7 +142,7 @@ local terraformConfig = if platform == "local" || platform == "metal" then [ machine: { certSANs: [ "localhost", - endpoint, + baseUrl, ], network: if vmDriver == "docker-desktop" then { interfaces: [ diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index fdc3d7837..e8a435ee7 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -23,6 +23,7 @@ const ( DEFAULT_TALOS_WORKER_RAM = 4 DEFAULT_TALOS_CONTROL_PLANE_CPU = 2 DEFAULT_TALOS_CONTROL_PLANE_RAM = 2 + DEFAULT_TALOS_API_PORT = 50000 ) const ( @@ -55,7 +56,7 @@ const ( const ( // renovate: datasource=docker depName=registry REGISTRY_DEFAULT_IMAGE = "registry:2.8.3" - REGISTRY_DEFAULT_HOST_PORT = 5002 + REGISTRY_DEFAULT_HOST_PORT = 5001 ) // Default network settings diff --git a/pkg/services/dns_service.go b/pkg/services/dns_service.go index 69c397cec..ea5ef41ca 100644 --- a/pkg/services/dns_service.go +++ b/pkg/services/dns_service.go @@ -3,6 +3,7 @@ package services import ( "fmt" "path/filepath" + "strings" "github.com/compose-spec/compose-go/types" "github.com/windsorcli/cli/pkg/constants" @@ -53,12 +54,12 @@ func (s *DNSService) SetAddress(address string) error { // GetComposeConfig sets up CoreDNS with context and domain, configures ports if localhost. func (s *DNSService) GetComposeConfig() (*types.Config, error) { contextName := s.configHandler.GetContext() - tld := s.configHandler.GetString("dns.domain", "test") - fullName := s.name + "." + tld + serviceName := s.GetName() + containerName := s.GetContainerName() corednsConfig := types.ServiceConfig{ - Name: fullName, - ContainerName: fullName, + Name: serviceName, + ContainerName: containerName, Image: constants.DEFAULT_DNS_IMAGE, Restart: "always", Command: []string{"-conf", "/etc/coredns/Corefile"}, @@ -72,7 +73,7 @@ func (s *DNSService) GetComposeConfig() (*types.Config, error) { }, } - if s.IsLocalhostMode() { + if s.isLocalhostMode() { corednsConfig.Ports = []types.ServicePortConfig{ { Target: 53, @@ -92,11 +93,11 @@ func (s *DNSService) GetComposeConfig() (*types.Config, error) { return &types.Config{Services: services}, nil } -// WriteConfig generates a Corefile for DNS configuration by gathering project root, TLD, and service IPs, -// constructing DNS host entries, and appending static DNS records. It adapts the Corefile for localhost -// by adding a template for local DNS resolution. Additionally, it configures DNS forwarding by including -// specified forward addresses, ensuring DNS queries are directed appropriately. The final Corefile is -// written to the .windsor config directory +// WriteConfig generates a Corefile by collecting the project root directory, top-level domain (TLD), and IP addresses. +// It adds DNS entries for each service, ensuring that each service's hostname resolves to its IP address. +// For localhost environments, it uses a specific DNS template to handle local DNS resolution and sets up forwarding +// rules to direct DNS queries to the appropriate addresses. +// The Corefile is saved in the .windsor directory, which is used by CoreDNS to manage DNS queries for the project. func (s *DNSService) WriteConfig() error { projectRoot, err := s.shell.GetProjectRoot() if err != nil { @@ -104,68 +105,106 @@ func (s *DNSService) WriteConfig() error { } tld := s.configHandler.GetString("dns.domain", "test") + networkCIDR := s.configHandler.GetString("network.cidr_block") + + var ( + hostEntries string + localhostHostEntries string + wildcardEntries string + localhostWildcardEntries string + ) + + wildcardTemplate := ` + template IN A { + match ^(.*)\.%s\.$ + answer "{{ .Name }} 60 IN A %s" + fallthrough + } +` + localhostTemplate := ` + template IN A { + match ^(.*)\.%s\.$ + answer "{{ .Name }} 60 IN A 127.0.0.1" + fallthrough + } +` - var hostEntries string for _, service := range s.services { composeConfig, err := service.GetComposeConfig() if err != nil || composeConfig == nil { continue } for _, svc := range composeConfig.Services { - if svc.Name != "" { - address := service.GetAddress() - if address != "" { - hostname := service.GetHostname() - if s.IsLocalhostMode() { - address = "127.0.0.1" - } - hostEntries += fmt.Sprintf(" %s %s\n", address, hostname) + if svc.Name == "" { + continue + } + address := service.GetAddress() + if address == "" { + continue + } + + hostname := service.GetHostname() + escapedHostname := strings.ReplaceAll(hostname, ".", "\\.") + + hostEntries += fmt.Sprintf(" %s %s\n", address, hostname) + if s.isLocalhostMode() { + localhostHostEntries += fmt.Sprintf(" 127.0.0.1 %s\n", hostname) + } + if service.SupportsWildcard() { + wildcardEntries += fmt.Sprintf(wildcardTemplate, escapedHostname, address) + if s.isLocalhostMode() { + localhostWildcardEntries += fmt.Sprintf(localhostTemplate, escapedHostname) } } } } - dnsRecords := s.configHandler.GetStringSlice("dns.records", nil) - for _, record := range dnsRecords { + for _, record := range s.configHandler.GetStringSlice("dns.records", nil) { hostEntries += fmt.Sprintf(" %s\n", record) + if s.isLocalhostMode() { + localhostHostEntries += fmt.Sprintf(" %s\n", record) + } } forwardAddresses := s.configHandler.GetStringSlice("dns.forward", nil) if len(forwardAddresses) == 0 { forwardAddresses = []string{"1.1.1.1", "8.8.8.8"} } - forwardAddressesStr := fmt.Sprintf("%s", forwardAddresses[0]) - for _, addr := range forwardAddresses[1:] { - forwardAddressesStr += fmt.Sprintf(" %s", addr) - } + forwardAddressesStr := strings.Join(forwardAddresses, " ") - var corefileContent string - corefileContent = fmt.Sprintf(` -%s:53 { - errors - reload - loop - hosts { + serverBlockTemplate := `%s:53 { +%s hosts { %s fallthrough } +%s + reload + loop forward . %s } -.:53 { - errors +` + + var corefileContent string + if s.isLocalhostMode() { + internalView := fmt.Sprintf(" view internal {\n expr incidr(client_ip(), '%s')\n }\n", networkCIDR) + corefileContent = fmt.Sprintf(serverBlockTemplate, tld, internalView, hostEntries, wildcardEntries, forwardAddressesStr) + corefileContent += fmt.Sprintf(serverBlockTemplate, tld, "", localhostHostEntries, localhostWildcardEntries, forwardAddressesStr) + } else { + corefileContent = fmt.Sprintf(serverBlockTemplate, tld, "", hostEntries, wildcardEntries, forwardAddressesStr) + } + + corefileContent += `.:53 { reload loop forward . 1.1.1.1 8.8.8.8 } -`, tld, hostEntries, forwardAddressesStr) +` corefilePath := filepath.Join(projectRoot, ".windsor", "Corefile") - if err := mkdirAll(filepath.Dir(corefilePath), 0755); err != nil { return fmt.Errorf("error creating parent folders: %w", err) } - err = writeFile(corefilePath, []byte(corefileContent), 0644) - if err != nil { + if err := writeFile(corefilePath, []byte(corefileContent), 0644); err != nil { return fmt.Errorf("error writing Corefile: %w", err) } diff --git a/pkg/services/dns_service_test.go b/pkg/services/dns_service_test.go index b46f56813..b7b54a55d 100644 --- a/pkg/services/dns_service_test.go +++ b/pkg/services/dns_service_test.go @@ -271,8 +271,8 @@ func TestDNSService_GetComposeConfig(t *testing.T) { if len(cfg.Services) != 1 { t.Errorf("Expected 1 service, got %d", len(cfg.Services)) } - if cfg.Services[0].Name != "dns.test" { - t.Errorf("Expected service name to be 'dns.test', got %s", cfg.Services[0].Name) + if cfg.Services[0].Name != "dns" { + t.Errorf("Expected service name to be 'dns', got %s", cfg.Services[0].Name) } }) @@ -352,20 +352,18 @@ func TestDNSService_WriteConfig(t *testing.T) { } // Verify that the Corefile content is correctly formatted - expectedCorefileContent := ` -test:53 { - errors - reload - loop + expectedCorefileContent := `test:53 { hosts { 127.0.0.1 test 192.168.1.1 test fallthrough } + + reload + loop forward . 1.1.1.1 8.8.8.8 } .:53 { - errors reload loop forward . 1.1.1.1 8.8.8.8 @@ -375,6 +373,7 @@ test:53 { t.Errorf("Expected Corefile content:\n%s\nGot:\n%s", expectedCorefileContent, string(writtenContent)) } }) + t.Run("SuccessLocalhost", func(t *testing.T) { // Create mocks and set up the mock context mocks := createDNSServiceMocks() @@ -408,20 +407,18 @@ test:53 { } // Verify that the Corefile content is correctly formatted for localhost - expectedCorefileContent := ` -test:53 { - errors - reload - loop + expectedCorefileContent := `test:53 { hosts { 127.0.0.1 test 192.168.1.1 test fallthrough } + + reload + loop forward . 1.1.1.1 8.8.8.8 } .:53 { - errors reload loop forward . 1.1.1.1 8.8.8.8 @@ -432,16 +429,9 @@ test:53 { } }) - t.Run("ErrorRetrievingProjectRoot", func(t *testing.T) { - // Create a mock context that returns an error on GetProjectRoot + t.Run("SuccessLocalhostMode", func(t *testing.T) { + // Setup mock components mocks := createDNSServiceMocks() - mocks.MockShell.GetProjectRootFunc = func() (string, error) { - return "", fmt.Errorf("mock error retrieving project root") - } - - mocks.Injector.Register("dockerService", NewMockService()) - - // Given: a DNSService with the mock context service := NewDNSService(mocks.Injector) // Initialize the service @@ -449,58 +439,148 @@ test:53 { t.Fatalf("Initialize() error = %v", err) } - // When: WriteConfig is called + // Set vm.driver to docker-desktop to simulate localhost mode + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + if key == "dns.domain" { + return "test" + } + if key == "network.cidr_block" { + return "192.168.1.0/24" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + + // Create a mock service with a hostname + mockService := NewMockService() + mockService.GetNameFunc = func() string { + return "test-service" + } + mockService.GetAddressFunc = func() string { + return "192.168.1.2" + } + mockService.GetComposeConfigFunc = func() (*types.Config, error) { + return &types.Config{ + Services: []types.ServiceConfig{ + {Name: "test-service"}, + }, + }, nil + } + mockService.GetHostnameFunc = func() string { + return "test-service.test" + } + mockService.SupportsWildcardFunc = func() bool { + return false + } + + // Register the mock service + mocks.Injector.Register("test-service", mockService) + service.services = []Service{mockService} + + // Mock the writeFile function to capture the content written + var writtenContent []byte + originalWriteFile := writeFile + defer func() { writeFile = originalWriteFile }() + writeFile = func(filename string, data []byte, perm os.FileMode) error { + writtenContent = data + return nil + } + + // Call WriteConfig err := service.WriteConfig() - // Then: an error should be returned - if err == nil || !strings.Contains(err.Error(), "error retrieving project root") { - t.Fatalf("expected error retrieving project root, got %v", err) + // Assert no error occurred + if err != nil { + t.Fatalf("WriteConfig() error = %v", err) + } + + // Verify that the Corefile content includes both regular and localhost entries + content := string(writtenContent) + expectedEntries := []string{ + "192.168.1.2 test-service.test", + "127.0.0.1 test-service.test", + } + for _, entry := range expectedEntries { + if !strings.Contains(content, entry) { + t.Errorf("Expected Corefile to contain entry %q, got:\n%s", entry, content) + } + } + + // Verify that the internal view is present + if !strings.Contains(content, "view internal") { + t.Errorf("Expected Corefile to contain internal view, got:\n%s", content) } }) - t.Run("ValidAddress", func(t *testing.T) { - // Create a mock context and config handler + t.Run("SuccessWithHostname", func(t *testing.T) { + // Setup mock components mocks := createDNSServiceMocks() + service := NewDNSService(mocks.Injector) - // Create a mock service that returns a valid address + // Initialize the service + if err := service.Initialize(); err != nil { + t.Fatalf("Initialize() error = %v", err) + } + + // Create a mock service with a hostname mockService := NewMockService() + mockService.GetNameFunc = func() string { + return "test-service" + } + mockService.GetAddressFunc = func() string { + return "192.168.1.2" + } mockService.GetComposeConfigFunc = func() (*types.Config, error) { return &types.Config{ Services: []types.ServiceConfig{ - {Name: "mockService"}, + {Name: "test-service"}, }, }, nil } - mockService.GetAddressFunc = func() string { - return "192.168.1.1" - } mockService.GetHostnameFunc = func() string { - return "mockService.test" + return "test-service.test" + } + mockService.SupportsWildcardFunc = func() bool { + return false } - mocks.Injector.Register("dockerService", mockService) - // Given: a DNSService with the mock config handler, context, and DockerService - service := NewDNSService(mocks.Injector) + // Register the mock service + mocks.Injector.Register("test-service", mockService) + service.services = []Service{mockService} - // Initialize the service - if err := service.Initialize(); err != nil { - t.Fatalf("Initialize() error = %v", err) + // Mock the writeFile function to capture the content written + var writtenContent []byte + originalWriteFile := writeFile + defer func() { writeFile = originalWriteFile }() + writeFile = func(filename string, data []byte, perm os.FileMode) error { + writtenContent = data + return nil } - // When: WriteConfig is called + // Call WriteConfig err := service.WriteConfig() - // Then: no error should be returned + // Assert no error occurred if err != nil { t.Fatalf("WriteConfig() error = %v", err) } + + // Verify that the Corefile content includes the service hostname + expectedHostEntry := "192.168.1.2 test-service.test" + content := string(writtenContent) + if !strings.Contains(content, expectedHostEntry) { + t.Errorf("Expected Corefile to contain host entry %q, got:\n%s", expectedHostEntry, content) + } }) - t.Run("ErrorWritingCorefile", func(t *testing.T) { - // Mock the GetConfigRoot function to return a mock path + t.Run("SuccessWithWildcard", func(t *testing.T) { + // Setup mock components mocks := createDNSServiceMocks() - - // Given: a DNSService with the mock config handler, context, and DockerService service := NewDNSService(mocks.Injector) // Initialize the service @@ -508,38 +588,149 @@ test:53 { t.Fatalf("Initialize() error = %v", err) } - // Mock the writeFile function to return an error - writeFile = func(_ string, _ []byte, _ os.FileMode) error { - return fmt.Errorf("mock error writing file") + // Create a mock service with wildcard support + mockService := NewMockService() + mockService.GetNameFunc = func() string { + return "test-service" + } + mockService.GetAddressFunc = func() string { + return "192.168.1.2" + } + mockService.GetComposeConfigFunc = func() (*types.Config, error) { + return &types.Config{ + Services: []types.ServiceConfig{ + {Name: "test-service"}, + }, + }, nil + } + mockService.GetHostnameFunc = func() string { + return "test-service.test" + } + mockService.SupportsWildcardFunc = func() bool { + return true + } + + // Register the mock service + mocks.Injector.Register("test-service", mockService) + service.services = []Service{mockService} + + // Mock the writeFile function to capture the content written + var writtenContent []byte + originalWriteFile := writeFile + defer func() { writeFile = originalWriteFile }() + writeFile = func(filename string, data []byte, perm os.FileMode) error { + writtenContent = data + return nil } - // When: WriteConfig is called + // Call WriteConfig err := service.WriteConfig() - // Then: an error should be returned - if err == nil || !strings.Contains(err.Error(), "error writing Corefile") { - t.Fatalf("expected error writing Corefile, got %v", err) + // Assert no error occurred + if err != nil { + t.Fatalf("WriteConfig() error = %v", err) + } + + // Verify that the Corefile content includes both the service hostname and wildcard entry + expectedHostEntry := "192.168.1.2 test-service.test" + expectedWildcardMatches := []string{ + "template IN A", + "match ^(.*)\\.test-service\\.test\\.$", + `answer "{{ .Name }} 60 IN A 192.168.1.2"`, + "fallthrough", + } + + content := string(writtenContent) + if !strings.Contains(content, expectedHostEntry) { + t.Errorf("Expected Corefile to contain host entry %q, got:\n%s", expectedHostEntry, content) + } + for _, expectedMatch := range expectedWildcardMatches { + if !strings.Contains(content, expectedMatch) { + t.Errorf("Expected Corefile to contain %q, got:\n%s", expectedMatch, content) + } } }) - t.Run("MkdirAllError", func(t *testing.T) { - // Setup injector with mocks + t.Run("SuccessWithMissingNameOrAddress", func(t *testing.T) { + // Setup mock components mocks := createDNSServiceMocks() + service := NewDNSService(mocks.Injector) - // Save the original mkdirAll function - originalMkdirAll := mkdirAll + // Initialize the service + if err := service.Initialize(); err != nil { + t.Fatalf("Initialize() error = %v", err) + } - // Override mkdirAll to simulate an error - mkdirAll = func(path string, perm os.FileMode) error { - return fmt.Errorf("mock error creating directories") + // Create a mock service with missing name + mockServiceNoName := NewMockService() + mockServiceNoName.GetNameFunc = func() string { + return "" + } + mockServiceNoName.GetAddressFunc = func() string { + return "192.168.1.2" + } + mockServiceNoName.GetComposeConfigFunc = func() (*types.Config, error) { + return &types.Config{ + Services: []types.ServiceConfig{ + {Name: ""}, + }, + }, nil + } + + // Create a mock service with missing address + mockServiceNoAddress := NewMockService() + mockServiceNoAddress.GetNameFunc = func() string { + return "test-service" + } + mockServiceNoAddress.GetAddressFunc = func() string { + return "" + } + mockServiceNoAddress.GetComposeConfigFunc = func() (*types.Config, error) { + return &types.Config{ + Services: []types.ServiceConfig{ + {Name: "test-service"}, + }, + }, nil + } + + // Register the mock services + mocks.Injector.Register("test-service-no-name", mockServiceNoName) + mocks.Injector.Register("test-service-no-address", mockServiceNoAddress) + service.services = []Service{mockServiceNoName, mockServiceNoAddress} + + // Mock the writeFile function to capture the content written + var writtenContent []byte + originalWriteFile := writeFile + defer func() { writeFile = originalWriteFile }() + writeFile = func(filename string, data []byte, perm os.FileMode) error { + writtenContent = data + return nil } - // Restore the original mkdirAll after the test - defer func() { - mkdirAll = originalMkdirAll - }() + // Call WriteConfig + err := service.WriteConfig() + + // Assert no error occurred + if err != nil { + t.Fatalf("WriteConfig() error = %v", err) + } + + // Verify that the Corefile content does not include entries for services with missing name or address + content := string(writtenContent) + unexpectedEntries := []string{ + "192.168.1.2", // Should not appear since service has no name + "test-service", // Should not appear since service has no address + } + for _, entry := range unexpectedEntries { + if strings.Contains(content, entry) { + t.Errorf("Expected Corefile to not contain %q, got:\n%s", entry, content) + } + } + }) - // Create the DNSService instance + t.Run("ErrorCreatingDirectory", func(t *testing.T) { + // Setup mock components + mocks := createDNSServiceMocks() service := NewDNSService(mocks.Injector) // Initialize the service @@ -547,17 +738,27 @@ test:53 { t.Fatalf("Initialize() error = %v", err) } - // Call WriteConfig and expect an error + // Mock mkdirAll to return an error + originalMkdirAll := mkdirAll + defer func() { mkdirAll = originalMkdirAll }() + mkdirAll = func(path string, perm os.FileMode) error { + return fmt.Errorf("mocked error creating directory") + } + + // Call WriteConfig err := service.WriteConfig() - // Check if the error matches the expected error - expectedError := "error creating parent folders: mock error creating directories" - if err == nil || err.Error() != expectedError { - t.Fatalf("expected error %v, got %v", expectedError, err) + // Assert error occurred + if err == nil { + t.Fatalf("Expected error, got nil") + } + expectedErrorMessage := "error creating parent folders: mocked error creating directory" + if err.Error() != expectedErrorMessage { + t.Errorf("Expected error message '%s', got %v", expectedErrorMessage, err) } }) - t.Run("SuccessLocalhostMode", func(t *testing.T) { + t.Run("ErrorWritingFile", func(t *testing.T) { // Setup mock components mocks := createDNSServiceMocks() service := NewDNSService(mocks.Injector) @@ -567,29 +768,37 @@ test:53 { t.Fatalf("Initialize() error = %v", err) } - // Create a mock service with a hostname - mockService := NewMockService() - mockService.GetNameFunc = func() string { - return "test-service" + // Mock writeFile to return an error + originalWriteFile := writeFile + defer func() { writeFile = originalWriteFile }() + writeFile = func(filename string, data []byte, perm os.FileMode) error { + return fmt.Errorf("mocked error writing file") } - mockService.GetHostnameFunc = func() string { - return "test-service.test" + + // Call WriteConfig + err := service.WriteConfig() + + // Assert error occurred + if err == nil { + t.Fatalf("Expected error, got nil") } - mockService.GetAddressFunc = func() string { - return "192.168.1.2" + expectedErrorMessage := "error writing Corefile: mocked error writing file" + if err.Error() != expectedErrorMessage { + t.Errorf("Expected error message '%s', got %v", expectedErrorMessage, err) } - mockService.GetComposeConfigFunc = func() (*types.Config, error) { - return &types.Config{ - Services: []types.ServiceConfig{ - { - Name: "test-service", - }, - }, - }, nil + }) + + t.Run("SuccessLocalhostModeWithWildcard", func(t *testing.T) { + // Setup mock components + mocks := createDNSServiceMocks() + service := NewDNSService(mocks.Injector) + + // Initialize the service + if err := service.Initialize(); err != nil { + t.Fatalf("Initialize() error = %v", err) } - service.services = []Service{mockService} - // Mock the config handler to return docker-desktop for vm.driver and DNS records + // Set vm.driver to docker-desktop to simulate localhost mode mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { if key == "vm.driver" { return "docker-desktop" @@ -597,37 +806,41 @@ test:53 { if key == "dns.domain" { return "test" } + if key == "network.cidr_block" { + return "192.168.1.0/24" + } if len(defaultValue) > 0 { return defaultValue[0] } return "" } - mocks.MockConfigHandler.GetStringSliceFunc = func(key string, defaultValue ...[]string) []string { - if key == "dns.records" { - return []string{ - "127.0.0.1 test", - "192.168.1.1 test", - } - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return nil + // Create a mock service with wildcard support + mockService := NewMockService() + mockService.GetNameFunc = func() string { + return "test-service" } - - // Mock the GetProjectRoot function to return a valid path - mocks.MockShell.GetProjectRootFunc = func() (string, error) { - return "/valid/path", nil + mockService.GetAddressFunc = func() string { + return "192.168.1.2" } - - // Mock the mkdirAll function to simulate successful directory creation - originalMkdirAll := mkdirAll - defer func() { mkdirAll = originalMkdirAll }() - mkdirAll = func(path string, perm os.FileMode) error { - return nil + mockService.GetComposeConfigFunc = func() (*types.Config, error) { + return &types.Config{ + Services: []types.ServiceConfig{ + {Name: "test-service"}, + }, + }, nil + } + mockService.GetHostnameFunc = func() string { + return "test-service.test" + } + mockService.SupportsWildcardFunc = func() bool { + return true } + // Register the mock service + mocks.Injector.Register("test-service", mockService) + service.services = []Service{mockService} + // Mock the writeFile function to capture the content written var writtenContent []byte originalWriteFile := writeFile @@ -637,37 +850,32 @@ test:53 { return nil } - // When: WriteConfig is called + // Call WriteConfig err := service.WriteConfig() - // Then: no error should be returned + // Assert no error occurred if err != nil { t.Fatalf("WriteConfig() error = %v", err) } - // Verify that the Corefile content uses 127.0.0.1 for the service address - expectedCorefileContent := ` -test:53 { - errors - reload - loop - hosts { - 127.0.0.1 test-service.test - 127.0.0.1 test - 192.168.1.1 test - fallthrough - } - forward . 1.1.1.1 8.8.8.8 -} -.:53 { - errors - reload - loop - forward . 1.1.1.1 8.8.8.8 -} -` - if string(writtenContent) != expectedCorefileContent { - t.Errorf("Expected Corefile content:\n%s\nGot:\n%s", expectedCorefileContent, string(writtenContent)) + // Verify that the Corefile content includes both regular and localhost wildcard entries + content := string(writtenContent) + expectedWildcardMatches := []string{ + "template IN A", + "match ^(.*)\\.test-service\\.test\\.$", + `answer "{{ .Name }} 60 IN A 192.168.1.2"`, + "fallthrough", + `answer "{{ .Name }} 60 IN A 127.0.0.1"`, + } + for _, expectedMatch := range expectedWildcardMatches { + if !strings.Contains(content, expectedMatch) { + t.Errorf("Expected Corefile to contain %q, got:\n%s", expectedMatch, content) + } + } + + // Verify that the internal view is present + if !strings.Contains(content, "view internal") { + t.Errorf("Expected Corefile to contain internal view, got:\n%s", content) } }) } diff --git a/pkg/services/git_livereload_service.go b/pkg/services/git_livereload_service.go index eb06b231a..5c22e1f7b 100644 --- a/pkg/services/git_livereload_service.go +++ b/pkg/services/git_livereload_service.go @@ -64,14 +64,10 @@ func (s *GitLivereloadService) GetComposeConfig() (*types.Config, error) { // Get the git folder name gitFolderName := filepath.Base(projectRoot) - // Get the domain from the configuration - tld := s.configHandler.GetString("dns.domain", "test") - fullName := s.name + "." + tld - // Add the git-livereload service services = append(services, types.ServiceConfig{ - Name: fullName, - ContainerName: fullName, + Name: s.name, + ContainerName: s.GetContainerName(), Image: image, Restart: "always", Environment: envVars, diff --git a/pkg/services/git_livereload_service_test.go b/pkg/services/git_livereload_service_test.go index 3a784f058..4abef6deb 100644 --- a/pkg/services/git_livereload_service_test.go +++ b/pkg/services/git_livereload_service_test.go @@ -93,7 +93,7 @@ func TestGitLivereloadService_GetComposeConfig(t *testing.T) { } // Then: verify the configuration contains the expected service - expectedName := "git.test" + expectedName := "git" expectedImage := constants.DEFAULT_GIT_LIVE_RELOAD_IMAGE serviceFound := false diff --git a/pkg/services/localstack_service.go b/pkg/services/localstack_service.go index 484b0f979..af123378e 100644 --- a/pkg/services/localstack_service.go +++ b/pkg/services/localstack_service.go @@ -44,15 +44,15 @@ func (s *LocalstackService) GetComposeConfig() (*types.Config, error) { servicesList = strings.Join(contextConfig.AWS.Localstack.Services, ",") } - // Get the domain from the configuration - tld := s.configHandler.GetString("dns.domain", "test") - fullName := s.name + "." + tld + // Get the service name and container name + serviceName := s.GetName() + containerName := s.GetContainerName() // Create the service config services := []types.ServiceConfig{ { - Name: fullName, - ContainerName: fullName, + Name: serviceName, + ContainerName: containerName, Image: image, Restart: "always", Environment: map[string]*string{ @@ -84,3 +84,8 @@ func (s *LocalstackService) GetComposeConfig() (*types.Config, error) { // Ensure LocalstackService implements Service interface var _ Service = (*LocalstackService)(nil) + +// SupportsWildcard returns whether the service supports wildcard DNS entries +func (s *LocalstackService) SupportsWildcard() bool { + return true +} diff --git a/pkg/services/localstack_service_test.go b/pkg/services/localstack_service_test.go index 63548b0f1..deef26fea 100644 --- a/pkg/services/localstack_service_test.go +++ b/pkg/services/localstack_service_test.go @@ -98,8 +98,8 @@ func TestLocalstackService_GetComposeConfig(t *testing.T) { } service := composeConfig.Services[0] - if service.Name != "aws.test" { - t.Errorf("expected service name 'localstack', got %v", service.Name) + if service.Name != "aws" { + t.Errorf("expected service name 'aws', got %v", service.Name) } if service.Environment["SERVICES"] == nil || *service.Environment["SERVICES"] != "s3,dynamodb" { t.Errorf("expected SERVICES environment variable to be 's3,dynamodb', got %v", service.Environment["SERVICES"]) @@ -151,3 +151,19 @@ func TestLocalstackService_GetComposeConfig(t *testing.T) { } }) } + +func TestLocalstackService_SupportsWildcard(t *testing.T) { + // Create a mock injector + mockInjector := di.NewMockInjector() + + // Create a LocalstackService + service := NewLocalstackService(mockInjector) + + // Call SupportsWildcard + supportsWildcard := service.SupportsWildcard() + + // Verify that SupportsWildcard returns true + if !supportsWildcard { + t.Errorf("Expected SupportsWildcard to return true, got false") + } +} diff --git a/pkg/services/mock_service.go b/pkg/services/mock_service.go index 0988920e8..27162d49c 100644 --- a/pkg/services/mock_service.go +++ b/pkg/services/mock_service.go @@ -23,6 +23,8 @@ type MockService struct { GetNameFunc func() string // GetHostnameFunc is a function that mocks the GetHostname method GetHostnameFunc func() string + // SupportsWildcardFunc is a function that mocks the SupportsWildcard method + SupportsWildcardFunc func() bool } // NewMockService is a constructor for MockService @@ -94,5 +96,13 @@ func (m *MockService) GetHostname() string { return "" } +// SupportsWildcard calls the mock SupportsWildcardFunc if it is set, otherwise returns false +func (m *MockService) SupportsWildcard() bool { + if m.SupportsWildcardFunc != nil { + return m.SupportsWildcardFunc() + } + return false +} + // Ensure MockService implements Service interface var _ Service = (*MockService)(nil) diff --git a/pkg/services/mock_service_test.go b/pkg/services/mock_service_test.go index 2a6beb4ee..18f4c8489 100644 --- a/pkg/services/mock_service_test.go +++ b/pkg/services/mock_service_test.go @@ -350,34 +350,63 @@ func TestMockService_GetName(t *testing.T) { } func TestMockService_GetHostname(t *testing.T) { - t.Run("Success", func(t *testing.T) { - // Given: a mock service + t.Run("WithFunction", func(t *testing.T) { + // Create a mock service with a GetHostname function mockService := NewMockService() - expectedHostname := "localhost" - - // When: GetHostnameFunc is called - mockGetHostnameFunc := func() string { - return expectedHostname + mockService.GetHostnameFunc = func() string { + return "test-hostname" } - mockService.GetHostnameFunc = mockGetHostnameFunc - // Then: the GetHostnameFunc should be set and return the expected hostname + // Call GetHostname hostname := mockService.GetHostname() - if hostname != expectedHostname { - t.Errorf("expected hostname %v, got %v", expectedHostname, hostname) + + // Verify that GetHostname returns the expected value + if hostname != "test-hostname" { + t.Errorf("Expected hostname 'test-hostname', got %q", hostname) } }) - t.Run("SuccessNoMock", func(t *testing.T) { - // Given: a mock service with no GetHostnameFunc + t.Run("WithoutFunction", func(t *testing.T) { + // Create a mock service without a GetHostname function mockService := NewMockService() - // When: GetHostname is called + // Call GetHostname hostname := mockService.GetHostname() - // Then: an empty string should be returned + // Verify that GetHostname returns an empty string if hostname != "" { - t.Errorf("expected empty hostname, got %v", hostname) + t.Errorf("Expected empty hostname, got %q", hostname) + } + }) +} + +func TestMockService_SupportsWildcard(t *testing.T) { + t.Run("WithFunction", func(t *testing.T) { + // Create a mock service with a SupportsWildcard function + mockService := NewMockService() + mockService.SupportsWildcardFunc = func() bool { + return true + } + + // Call SupportsWildcard + supportsWildcard := mockService.SupportsWildcard() + + // Verify that SupportsWildcard returns the expected value + if !supportsWildcard { + t.Errorf("Expected SupportsWildcard to return true, got false") + } + }) + + t.Run("WithoutFunction", func(t *testing.T) { + // Create a mock service without a SupportsWildcard function + mockService := NewMockService() + + // Call SupportsWildcard + supportsWildcard := mockService.SupportsWildcard() + + // Verify that SupportsWildcard returns false + if supportsWildcard { + t.Errorf("Expected SupportsWildcard to return false, got true") } }) } diff --git a/pkg/services/registry_service.go b/pkg/services/registry_service.go index 55c484993..200e69762 100644 --- a/pkg/services/registry_service.go +++ b/pkg/services/registry_service.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "strings" + "sync" "github.com/compose-spec/compose-go/types" "github.com/windsorcli/cli/api/v1alpha1/docker" @@ -11,10 +12,16 @@ import ( "github.com/windsorcli/cli/pkg/di" ) +var ( + registryNextPort = constants.REGISTRY_DEFAULT_HOST_PORT + 1 + registryMu sync.Mutex + localRegistry *RegistryService +) + // RegistryService is a service struct that provides Registry-specific utility functions type RegistryService struct { BaseService - HostPort int // If set, this port is routed to the registry port from the host + hostPort int } // NewRegistryService is a constructor for RegistryService @@ -34,8 +41,7 @@ func (s *RegistryService) GetComposeConfig() (*types.Config, error) { registries := contextConfig.Docker.Registries if registry, exists := registries[s.name]; exists { - hostname := s.GetHostname() - service, err := s.generateRegistryService(hostname, registry) + service, err := s.generateRegistryService(registry) if err != nil { return nil, fmt.Errorf("failed to generate registry service: %w", err) } @@ -46,16 +52,12 @@ func (s *RegistryService) GetComposeConfig() (*types.Config, error) { } // SetAddress configures the registry's address, forms a hostname, and updates the registry config. -// It selects a port by checking the registry's HostPort; if unset and on localhost, it defaults to -// REGISTRY_DEFAULT_HOST_PORT. The port's availability is verified before assignment. If the registry -// is not a proxy ("remote" is not set), and it is localhost, it attempts to set HostPort to -// the default registry port. +// It assigns the "registry_url" and the default host port for the first non-remote registry, storing it as "localRegistry". func (s *RegistryService) SetAddress(address string) error { if err := s.BaseService.SetAddress(address); err != nil { return fmt.Errorf("failed to set address for base service: %w", err) } - defaultPort := constants.REGISTRY_DEFAULT_HOST_PORT hostName := s.GetHostname() err := s.configHandler.SetContextValue(fmt.Sprintf("docker.registries[%s].hostname", s.name), hostName) @@ -68,16 +70,25 @@ func (s *RegistryService) SetAddress(address string) error { if registryConfig.HostPort != 0 { hostPort = registryConfig.HostPort - } else if registryConfig.Remote == "" && s.IsLocalhostMode() { - hostPort = defaultPort - err = s.configHandler.SetContextValue("docker.registry_url", hostName) - if err != nil { - return fmt.Errorf("failed to set registry URL for registry %s: %w", s.name, err) + } else if s.isLocalhostMode() { + registryMu.Lock() + defer registryMu.Unlock() + + if registryConfig.Remote == "" && localRegistry == nil { + localRegistry = s + hostPort = constants.REGISTRY_DEFAULT_HOST_PORT + err = s.configHandler.SetContextValue("docker.registry_url", hostName) + if err != nil { + return fmt.Errorf("failed to set registry URL for registry %s: %w", s.name, err) + } + } else { + hostPort = registryNextPort + registryNextPort++ } } if hostPort != 0 { - s.HostPort = hostPort + s.hostPort = hostPort err := s.configHandler.SetContextValue(fmt.Sprintf("docker.registries[%s].hostport", s.name), hostPort) if err != nil { return fmt.Errorf("failed to set host port for registry %s: %w", s.name, err) @@ -87,27 +98,24 @@ func (s *RegistryService) SetAddress(address string) error { return nil } -// GetHostname returns the hostname of the registry service. This is constructed -// by removing the existing domain from the name and appending the configured domain. +// GetHostname returns the hostname for the registry service, removing the last domain part func (s *RegistryService) GetHostname() string { - domain := s.configHandler.GetString("dns.domain", "test") - nameWithoutDomain := s.name - if dotIndex := strings.LastIndex(s.name, "."); dotIndex != -1 { - nameWithoutDomain = s.name[:dotIndex] - } - return nameWithoutDomain + "." + domain + tld := s.configHandler.GetString("dns.domain", "test") + return getBasename(s.GetName()) + "." + tld } // This function generates a ServiceConfig for a Registry service. It sets up the service's name, image, // restart policy, and labels. It configures environment variables based on registry URLs, creates a // cache directory, and sets volume mounts. Ports are assigned only for non-proxy registries when the // network mode is localhost. It returns the configured ServiceConfig or an error if any step fails. -func (s *RegistryService) generateRegistryService(hostname string, registry docker.RegistryConfig) (types.ServiceConfig, error) { +func (s *RegistryService) generateRegistryService(registry docker.RegistryConfig) (types.ServiceConfig, error) { contextName := s.configHandler.GetContext() + serviceName := getBasename(s.GetHostname()) + containerName := s.GetContainerName() service := types.ServiceConfig{ - Name: hostname, - ContainerName: hostname, + Name: serviceName, + ContainerName: containerName, Image: constants.REGISTRY_DEFAULT_IMAGE, Restart: "always", Labels: map[string]string{ @@ -117,19 +125,23 @@ func (s *RegistryService) generateRegistryService(hostname string, registry dock }, } - env := make(types.MappingWithEquals) + // Initialize environment variables + env := make(map[string]*string) + // Set remote URL if specified if registry.Remote != "" { - env["REGISTRY_PROXY_REMOTEURL"] = ®istry.Remote + remoteURL := registry.Remote + env["REGISTRY_PROXY_REMOTEURL"] = &remoteURL } + // Set local URL if specified if registry.Local != "" { - env["REGISTRY_PROXY_LOCALURL"] = ®istry.Local + localURL := registry.Local + env["REGISTRY_PROXY_LOCALURL"] = &localURL } - if len(env) > 0 { - service.Environment = env - } + // Always set environment, even if empty + service.Environment = env projectRoot, err := s.shell.GetProjectRoot() if err != nil { @@ -144,11 +156,11 @@ func (s *RegistryService) generateRegistryService(hostname string, registry dock {Type: "bind", Source: "${WINDSOR_PROJECT_ROOT}/.windsor/.docker-cache", Target: "/var/lib/registry"}, } - if registry.Remote == "" && s.IsLocalhostMode() { + if s.isLocalhostMode() { service.Ports = []types.ServicePortConfig{ { Target: 5000, - Published: fmt.Sprintf("%d", s.HostPort), + Published: fmt.Sprintf("%d", s.hostPort), Protocol: "tcp", }, } @@ -159,3 +171,11 @@ func (s *RegistryService) generateRegistryService(hostname string, registry dock // Ensure RegistryService implements Service interface var _ Service = (*RegistryService)(nil) + +// getBasename removes the last part of a domain name if it exists +func getBasename(name string) string { + if parts := strings.Split(name, "."); len(parts) > 1 { + return strings.Join(parts[:len(parts)-1], ".") + } + return name +} diff --git a/pkg/services/registry_service_test.go b/pkg/services/registry_service_test.go index d5038ff14..b0640db3b 100644 --- a/pkg/services/registry_service_test.go +++ b/pkg/services/registry_service_test.go @@ -118,8 +118,12 @@ func TestRegistryService_GetComposeConfig(t *testing.T) { t.Fatalf("GetComposeConfig() error = %v", err) } - // Then check for characteristic properties in the configuration - expectedName := "registry.test" + // Then: the compose configuration should include the registry service + if composeConfig == nil || len(composeConfig.Services) == 0 { + t.Fatalf("expected non-nil composeConfig with services, got %v", composeConfig) + } + + expectedName := "registry" expectedRemoteURL := "registry.remote" expectedLocalURL := "registry.local" found := false @@ -267,13 +271,13 @@ func TestRegistryService_GetComposeConfig(t *testing.T) { // Then check that the service has the expected port configuration expectedPortConfig := types.ServicePortConfig{ Target: 5000, - Published: fmt.Sprintf("%d", registryService.HostPort), + Published: fmt.Sprintf("%d", registryService.hostPort), Protocol: "tcp", } found := false for _, config := range composeConfig.Services { - if config.Name == "local-registry.test" { + if config.Name == "local-registry" { for _, portConfig := range config.Ports { if portConfig.Target == expectedPortConfig.Target && portConfig.Published == expectedPortConfig.Published && @@ -286,7 +290,7 @@ func TestRegistryService_GetComposeConfig(t *testing.T) { } if !found { - t.Errorf("expected service with name %q to have port configuration %+v in the list of configurations:\n%+v", "local-registry.test", expectedPortConfig, composeConfig.Services) + t.Errorf("expected service with name %q to have port configuration %+v in the list of configurations:\n%+v", "local-registry", expectedPortConfig, composeConfig.Services) } }) } @@ -405,8 +409,8 @@ func TestRegistryService_SetAddress(t *testing.T) { } // Then the default port should be set - if registryService.HostPort != constants.REGISTRY_DEFAULT_HOST_PORT { - t.Errorf("expected HostPort to be set to default, got %v", registryService.HostPort) + if registryService.hostPort != constants.REGISTRY_DEFAULT_HOST_PORT { + t.Errorf("expected HostPort to be set to default, got %v", registryService.hostPort) } }) @@ -437,116 +441,159 @@ func TestRegistryService_SetAddress(t *testing.T) { } // Then the HostPort should be set to the configured port - if registryService.HostPort != 5000 { - t.Errorf("expected HostPort to be 5000, got %v", registryService.HostPort) + if registryService.hostPort != 5000 { + t.Errorf("expected HostPort to be 5000, got %v", registryService.hostPort) } }) t.Run("SetRegistryURLAndHostPort", func(t *testing.T) { - // Given a mock config handler, shell, context, and service with no HostPort and no Remote - mocks := setupSafeRegistryServiceMocks() - mocks.MockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { + // Reset global state + localRegistry = nil + registryNextPort = constants.REGISTRY_DEFAULT_HOST_PORT + 1 + + // Setup mock components + mockConfig := config.NewMockConfigHandler() + mockConfig.GetConfigFunc = func() *v1alpha1.Context { return &v1alpha1.Context{ Docker: &docker.DockerConfig{ Registries: map[string]docker.RegistryConfig{ - "registry": {HostPort: 0, Remote: ""}, + "test-registry": { + HostPort: 0, + Remote: "", + }, }, }, } } - // Set vm.driver to docker-desktop for localhost tests - mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + + mockConfig.GetStringFunc = func(key string, defaultValue ...string) string { if key == "vm.driver" { return "docker-desktop" } if key == "dns.domain" { return "test" } - if len(defaultValue) > 0 { - return defaultValue[0] - } return "" } - registryService := NewRegistryService(mocks.Injector) - registryService.SetName("registry") - err := registryService.Initialize() - if err != nil { - t.Fatalf("Initialize() error = %v", err) - } - // Mock the SetContextValue function to track if it's called - setContextValueCalled := false - mocks.MockConfigHandler.SetContextValueFunc = func(key string, value interface{}) error { - if key == "docker.registry_url" { - setContextValueCalled = true - } + var setContextValueCalls = make(map[string]interface{}) + mockConfig.SetContextValueFunc = func(key string, value interface{}) error { + setContextValueCalls[key] = value return nil } - // When SetAddress is called with localhost - address := "127.0.0.1" - err = registryService.SetAddress(address) + // Initialize service + service := NewRegistryService(di.NewInjector()) + service.name = "test-registry" + service.configHandler = mockConfig + + // Set address + err := service.SetAddress("127.0.0.1") if err != nil { - t.Fatalf("SetAddress() error = %v", err) + t.Fatalf("SetAddress failed: %v", err) } - // Then the default port should be set and registry URL should be set - if registryService.HostPort != constants.REGISTRY_DEFAULT_HOST_PORT { - t.Errorf("expected HostPort to be set to default, got %v", registryService.HostPort) + // Verify default port was set + if service.hostPort != constants.REGISTRY_DEFAULT_HOST_PORT { + t.Errorf("Expected hostPort to be %d, got %d", constants.REGISTRY_DEFAULT_HOST_PORT, service.hostPort) } - if !setContextValueCalled { - t.Errorf("expected SetContextValue to be called for registry URL, but it was not") + + // Verify hostname was set + expectedHostname := "test-registry.test" + if value, exists := setContextValueCalls["docker.registries[test-registry].hostname"]; !exists { + t.Error("Expected SetContextValue to be called for hostname, but it was not") + } else if value != expectedHostname { + t.Errorf("Expected hostname to be %q, got %q", expectedHostname, value) } - }) - t.Run("SetContextValueErrorForHostPort", func(t *testing.T) { - // Given a mock config handler that will fail to set context value for host port - mocks := setupSafeRegistryServiceMocks() - mocks.MockConfigHandler.SetContextValueFunc = func(key string, value interface{}) error { - if key == fmt.Sprintf("docker.registries[%s].hostport", "registry") { - return fmt.Errorf("failed to set host port") - } - return nil + // Verify registry URL was set + if value, exists := setContextValueCalls["docker.registry_url"]; !exists { + t.Error("Expected SetContextValue to be called for registry URL, but it was not") + } else if value != expectedHostname { + t.Errorf("Expected registry URL to be %q, got %q", expectedHostname, value) } - mocks.MockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { + + // Verify hostport was set + if value, exists := setContextValueCalls["docker.registries[test-registry].hostport"]; !exists { + t.Error("Expected SetContextValue to be called for hostport, but it was not") + } else if value != constants.REGISTRY_DEFAULT_HOST_PORT { + t.Errorf("Expected hostport to be %d, got %d", constants.REGISTRY_DEFAULT_HOST_PORT, value) + } + }) + + t.Run("SetContextValueErrorForRegistryURL", func(t *testing.T) { + // Reset global state + localRegistry = nil + registryNextPort = constants.REGISTRY_DEFAULT_HOST_PORT + 1 + + // Setup mock components + mockConfig := config.NewMockConfigHandler() + mockConfig.GetConfigFunc = func() *v1alpha1.Context { return &v1alpha1.Context{ Docker: &docker.DockerConfig{ Registries: map[string]docker.RegistryConfig{ - "registry": {HostPort: 5000}, + "test-registry": { + HostPort: 0, + Remote: "", + }, }, }, } } - registryService := NewRegistryService(mocks.Injector) - registryService.SetName("registry") - err := registryService.Initialize() - if err != nil { - t.Fatalf("Initialize() error = %v", err) + + mockConfig.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + if key == "dns.domain" { + return "test" + } + return "" } - // When SetAddress is called - address := "192.168.1.1" - err = registryService.SetAddress(address) + mockConfig.SetContextValueFunc = func(key string, value interface{}) error { + if key == "docker.registry_url" { + return fmt.Errorf("failed to set registry URL") + } + return nil + } - // Then an error should be returned indicating failure to set host port - if err == nil || !strings.Contains(err.Error(), "failed to set host port") { - t.Fatalf("expected error indicating failure to set host port, got %v", err) + // Initialize service + service := NewRegistryService(di.NewInjector()) + service.name = "test-registry" + service.configHandler = mockConfig + + // Set address + err := service.SetAddress("127.0.0.1") + + // Verify error + if err == nil || !strings.Contains(err.Error(), "failed to set registry URL") { + t.Errorf("Expected error containing 'failed to set registry URL', got %v", err) } }) - t.Run("SetContextValueErrorForRegistryURL", func(t *testing.T) { - // Given a mock config handler, shell, context, and service with no HostPort and no Remote + t.Run("SuccessWithNextPort", func(t *testing.T) { + // Reset package-level variables + registryNextPort = constants.REGISTRY_DEFAULT_HOST_PORT + 1 + localRegistry = nil + + // Given a mock config handler, shell, context, and service mocks := setupSafeRegistryServiceMocks() + + // Override GetConfig to return a config with an empty registry mocks.MockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { return &v1alpha1.Context{ Docker: &docker.DockerConfig{ Registries: map[string]docker.RegistryConfig{ - "registry": {Remote: "", HostPort: 0}, + "test-registry": { + Remote: "", + }, }, }, } } - // Set vm.driver to docker-desktop for localhost tests + + // Override GetString to return docker-desktop for vm.driver mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { if key == "vm.driver" { return "docker-desktop" @@ -559,49 +606,147 @@ func TestRegistryService_SetAddress(t *testing.T) { } return "" } - // Mock the SetContextValue function to return an error for registry URL + + var setContextValueCalls = make(map[string]interface{}) mocks.MockConfigHandler.SetContextValueFunc = func(key string, value interface{}) error { - if key == "docker.registry_url" { - return fmt.Errorf("failed to set registry URL") - } + setContextValueCalls[key] = value return nil } - registryService := NewRegistryService(mocks.Injector) - registryService.SetName("registry") - err := registryService.Initialize() + + // Initialize service + service := NewRegistryService(mocks.Injector) + service.name = "test-registry" + err := service.Initialize() if err != nil { t.Fatalf("Initialize() error = %v", err) } - // When SetAddress is called - address := "localhost" - err = registryService.SetAddress(address) + // Call SetAddress + err = service.SetAddress("127.0.0.1") - // Then an error should be returned indicating failure to set registry URL - if err == nil || !strings.Contains(err.Error(), "failed to set registry URL") { - t.Fatalf("expected error indicating failure to set registry URL, got %v", err) + // Assert no error occurred + if err != nil { + t.Fatalf("SetAddress() error = %v", err) + } + + // Verify that SetContextValue was called for the registry host port + if value, exists := setContextValueCalls["docker.registries[test-registry].hostport"]; !exists { + t.Error("Expected SetContextValue to be called for host port") + } else if value != constants.REGISTRY_DEFAULT_HOST_PORT { + t.Errorf("Expected SetContextValue value to be %d, got %v", constants.REGISTRY_DEFAULT_HOST_PORT, value) + } + + // Call SetAddress again to verify port increment + err = service.SetAddress("127.0.0.1") + + // Assert no error occurred + if err != nil { + t.Fatalf("SetAddress() error = %v", err) + } + + // Verify that SetContextValue was called for the registry host port with incremented value + if value, exists := setContextValueCalls["docker.registries[test-registry].hostport"]; !exists { + t.Error("Expected SetContextValue to be called for host port") + } else if value != constants.REGISTRY_DEFAULT_HOST_PORT+1 { + t.Errorf("Expected SetContextValue value to be %d, got %v", constants.REGISTRY_DEFAULT_HOST_PORT+1, value) + } + }) + + t.Run("SetContextValueErrorForHostPort", func(t *testing.T) { + // Reset global state + localRegistry = nil + registryNextPort = constants.REGISTRY_DEFAULT_HOST_PORT + 1 + + // Setup mock components + mockConfig := config.NewMockConfigHandler() + mockConfig.GetConfigFunc = func() *v1alpha1.Context { + return &v1alpha1.Context{ + Docker: &docker.DockerConfig{ + Registries: map[string]docker.RegistryConfig{ + "test-registry": { + HostPort: 0, + Remote: "", + }, + }, + }, + } + } + + mockConfig.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + if key == "dns.domain" { + return "test" + } + return "" + } + + mockConfig.SetContextValueFunc = func(key string, value interface{}) error { + if key == "docker.registries[test-registry].hostport" { + return fmt.Errorf("failed to set host port") + } + return nil + } + + // Initialize service + service := NewRegistryService(di.NewInjector()) + service.name = "test-registry" + service.configHandler = mockConfig + + // Set address + err := service.SetAddress("127.0.0.1") + + // Verify error + if err == nil || !strings.Contains(err.Error(), "failed to set host port for registry test-registry") { + t.Errorf("Expected error containing 'failed to set host port for registry test-registry', got %v", err) } }) } func TestRegistryService_GetHostname(t *testing.T) { t.Run("Success", func(t *testing.T) { - // Given a mock config handler, shell, context, and service - mocks := setupSafeRegistryServiceMocks() - registryService := NewRegistryService(mocks.Injector) - registryService.SetName("registry.oldtld") - err := registryService.Initialize() - if err != nil { - t.Fatalf("Initialize() error = %v", err) + // Setup mock components + mockConfig := config.NewMockConfigHandler() + mockConfig.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "dns.domain" { + return "test" + } + return "" } - // When GetHostname is called - hostname := registryService.GetHostname() + // Initialize service + service := NewRegistryService(di.NewInjector()) + service.name = "registry.oldtld" + service.configHandler = mockConfig + + // Get hostname + hostname := service.GetHostname() - // Then the hostname should be as expected, with the old domain removed + // Verify hostname expectedHostname := "registry.test" if hostname != expectedHostname { - t.Fatalf("expected hostname '%s', got %v", expectedHostname, hostname) + t.Errorf("Expected hostname %q, got %q", expectedHostname, hostname) } }) } + +func createRegistryServiceMocks() *MockComponents { + mockShell := shell.NewMockShell(di.NewInjector()) + mockConfig := config.NewMockConfigHandler() + mockService := NewMockService() + injector := di.NewInjector() + injector.Register("shell", mockShell) + injector.Register("configHandler", mockConfig) + injector.Register("registryService", mockService) + return &MockComponents{ + Injector: injector, + MockShell: mockShell, + MockConfigHandler: mockConfig, + MockService: mockService, + } +} + +func ptrInt(i int) *int { + return &i +} diff --git a/pkg/services/service.go b/pkg/services/service.go index 6644e83cb..54e7a5d84 100644 --- a/pkg/services/service.go +++ b/pkg/services/service.go @@ -35,11 +35,11 @@ type Service interface { // Initialize performs any necessary initialization for the service. Initialize() error - // GetHostname returns the name plus the tld from the config - GetHostname() string + // SupportsWildcard returns whether the service supports wildcard DNS entries + SupportsWildcard() bool - // IsLocalhostMode checks if we are in localhost mode (vm.driver == "docker-desktop") - IsLocalhostMode() bool + // GetHostname returns the hostname for the service, which may include domain processing + GetHostname() string } // BaseService is a base implementation of the Service interface @@ -98,14 +98,25 @@ func (s *BaseService) GetName() string { return s.name } -// GetHostname returns the name plus the tld from the config -func (s *BaseService) GetHostname() string { - tld := s.configHandler.GetString("dns.domain", "test") - return fmt.Sprintf("%s.%s", s.name, tld) +// GetContainerName returns the container name with the "windsor-" prefix and without the DNS domain +func (s *BaseService) GetContainerName() string { + contextName := s.configHandler.GetContext() + return fmt.Sprintf("windsor-%s-%s", contextName, s.name) } // IsLocalhostMode checks if we are in localhost mode (vm.driver == "docker-desktop") -func (s *BaseService) IsLocalhostMode() bool { +func (s *BaseService) isLocalhostMode() bool { vmDriver := s.configHandler.GetString("vm.driver") return vmDriver == "docker-desktop" } + +// SupportsWildcard returns whether the service supports wildcard DNS entries +func (s *BaseService) SupportsWildcard() bool { + return false +} + +// GetHostname returns the hostname for the service with the configured TLD +func (s *BaseService) GetHostname() string { + tld := s.configHandler.GetString("dns.domain", "test") + return s.name + "." + tld +} diff --git a/pkg/services/service_test.go b/pkg/services/service_test.go index f43f93b37..28e3c2d5e 100644 --- a/pkg/services/service_test.go +++ b/pkg/services/service_test.go @@ -168,26 +168,51 @@ func TestBaseService_GetName(t *testing.T) { func TestBaseService_GetHostname(t *testing.T) { t.Run("Success", func(t *testing.T) { - // Given: a new BaseService with a mock config handler - mocks := setupSafeBaseServiceMocks() - service := &BaseService{injector: mocks.Injector} - service.SetName("TestService") - mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + // Setup mock components + mockConfig := config.NewMockConfigHandler() + mockConfig.GetStringFunc = func(key string, defaultValue ...string) string { if key == "dns.domain" { - return "example" + return "example.com" } return "" } - service.Initialize() + // Initialize service + service := &BaseService{ + name: "test-service", + configHandler: mockConfig, + } - // When: GetHostname is called + // Get hostname hostname := service.GetHostname() - // Then: the hostname should be as expected - expectedHostname := "TestService.example" + // Verify hostname + expectedHostname := "test-service.example.com" if hostname != expectedHostname { - t.Fatalf("expected hostname '%s', got %v", expectedHostname, hostname) + t.Errorf("Expected hostname %q, got %q", expectedHostname, hostname) + } + }) + + t.Run("DefaultTLD", func(t *testing.T) { + // Setup mock components + mockConfig := config.NewMockConfigHandler() + mockConfig.GetStringFunc = func(key string, defaultValue ...string) string { + return defaultValue[0] + } + + // Initialize service + service := &BaseService{ + name: "test-service", + configHandler: mockConfig, + } + + // Get hostname + hostname := service.GetHostname() + + // Verify hostname uses default TLD + expectedHostname := "test-service.test" + if hostname != expectedHostname { + t.Errorf("Expected hostname %q, got %q", expectedHostname, hostname) } }) } @@ -212,12 +237,12 @@ func TestBaseService_IsLocalhostMode(t *testing.T) { return "" } - // When: IsLocalhostMode is called - isLocal := service.IsLocalhostMode() + // When: isLocalhostMode is called + isLocal := service.isLocalhostMode() // Then: the result should be true for docker-desktop if !isLocal { - t.Fatal("expected IsLocalhostMode to be true for docker-desktop driver") + t.Fatal("expected isLocalhostMode to be true for docker-desktop driver") } }) @@ -240,12 +265,24 @@ func TestBaseService_IsLocalhostMode(t *testing.T) { return "" } - // When: IsLocalhostMode is called - isLocal := service.IsLocalhostMode() + // When: isLocalhostMode is called + isLocal := service.isLocalhostMode() // Then: the result should be false for non-docker-desktop driver if isLocal { - t.Fatal("expected IsLocalhostMode to be false for non-docker-desktop driver") + t.Fatal("expected isLocalhostMode to be false for non-docker-desktop driver") + } + }) +} + +func TestBaseService_SupportsWildcard(t *testing.T) { + t.Run("Success", func(t *testing.T) { + // Initialize service + service := &BaseService{} + + // Verify wildcard support is false by default + if service.SupportsWildcard() { + t.Error("Expected SupportsWildcard to return false") } }) } diff --git a/pkg/services/talos_service.go b/pkg/services/talos_service.go index c6bed7e07..227b71dce 100644 --- a/pkg/services/talos_service.go +++ b/pkg/services/talos_service.go @@ -4,6 +4,7 @@ import ( "fmt" "math" "os" + "slices" "strconv" "strings" "sync" @@ -68,13 +69,10 @@ func (s *TalosService) SetAddress(address string) error { nodeType = "controlplanes" } - // Always use DNS name for hostname - if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.hostname", nodeType, s.name), s.name+"."+tld); err != nil { + if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.hostname", nodeType, s.name), s.name); err != nil { return err } - - // Always use DNS name for node - if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.node", nodeType, s.name), s.name+"."+tld); err != nil { + if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.node", nodeType, s.name), s.name); err != nil { return err } @@ -82,7 +80,7 @@ func (s *TalosService) SetAddress(address string) error { defer portLock.Unlock() var port int - if s.isLeader || !s.IsLocalhostMode() { + if s.isLeader || !s.isLocalhostMode() { port = defaultAPIPort } else { port = nextAPIPort @@ -98,8 +96,8 @@ func (s *TalosService) SetAddress(address string) error { hostPortsCopy := make([]string, len(hostPorts)) copy(hostPortsCopy, hostPorts) - for i := 0; i < len(hostPortsCopy); i++ { - parts := strings.Split(hostPortsCopy[i], ":") + for i, hostPortStr := range hostPortsCopy { + parts := strings.Split(hostPortStr, ":") var hostPort, nodePort int protocol := "tcp" @@ -130,7 +128,7 @@ func (s *TalosService) SetAddress(address string) error { } } default: - return fmt.Errorf("invalid hostPort format: %s", hostPortsCopy[i]) + return fmt.Errorf("invalid hostPort format: %s", hostPortStr) } // Check for conflicts in hostPort @@ -231,16 +229,11 @@ func (s *TalosService) GetComposeConfig() (*types.Config, error) { }) } - tld := s.configHandler.GetString("dns.domain", "test") - fullName := nodeName + "." + tld - if s.name == "" { - fullName = nodeType[:len(nodeType)-1] + "." + tld - } - serviceConfig := commonConfig - serviceConfig.Name = fullName - serviceConfig.ContainerName = fullName - serviceConfig.Hostname = fullName + serviceConfig.Name = nodeName + s.SetName(nodeName) + serviceConfig.ContainerName = s.GetContainerName() + serviceConfig.Hostname = nodeName serviceConfig.Environment = map[string]*string{ "PLATFORM": ptrString("container"), "TALOSSKU": ptrString(fmt.Sprintf("%dCPU-%dRAM", cpu, ram*1024)), @@ -254,7 +247,7 @@ func (s *TalosService) GetComposeConfig() (*types.Config, error) { } defaultAPIPortUint32 := uint32(defaultAPIPort) - if s.IsLocalhostMode() { + if s.isLocalhostMode() { ports = append(ports, types.ServicePortConfig{ Target: defaultAPIPortUint32, Published: publishedPort, @@ -296,6 +289,19 @@ func (s *TalosService) GetComposeConfig() (*types.Config, error) { serviceConfig.Ports = ports + dnsAddress := s.configHandler.GetString("dns.address") + if dnsAddress != "" { + if serviceConfig.DNS == nil { + serviceConfig.DNS = []string{} + } + + dnsExists := slices.Contains(serviceConfig.DNS, dnsAddress) + + if !dnsExists { + serviceConfig.DNS = append(serviceConfig.DNS, dnsAddress) + } + } + volumesMap := map[string]types.VolumeConfig{ strings.ReplaceAll(nodeName+"_system_state", "-", "_"): {}, strings.ReplaceAll(nodeName+"_var", "-", "_"): {}, diff --git a/pkg/services/talos_service_test.go b/pkg/services/talos_service_test.go index a571e5f19..11305d325 100644 --- a/pkg/services/talos_service_test.go +++ b/pkg/services/talos_service_test.go @@ -142,6 +142,9 @@ func TestTalosService_NewTalosService(t *testing.T) { func TestTalosService_SetAddress(t *testing.T) { t.Run("SuccessWorker", func(t *testing.T) { + // Reset package-level variables + nextAPIPort = constants.DEFAULT_TALOS_API_PORT + 1 + // Setup mocks for this test mocks := setupTalosServiceMocks() service := NewTalosService(mocks.Injector, "worker") @@ -504,6 +507,82 @@ func TestTalosService_SetAddress(t *testing.T) { t.Fatalf("expected endpoint to be set without error, got %v", err) } }) + + t.Run("PortIncrement", func(t *testing.T) { + // Reset package-level variables + nextAPIPort = constants.DEFAULT_TALOS_API_PORT + 1 + + // Setup mocks for this test + mocks := setupTalosServiceMocks() + + // Mock vm.driver to enable localhost mode + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + return "" + } + + // Create and initialize first service (non-leader) + service1 := NewTalosService(mocks.Injector, "worker1") + service1.isLeader = false + err := service1.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + // Set address for first service + err = service1.SetAddress("127.0.0.1") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // Create and initialize second service (non-leader) + service2 := NewTalosService(mocks.Injector, "worker2") + service2.isLeader = false + err = service2.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + // Set address for second service + err = service2.SetAddress("127.0.0.1") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // Verify that the ports were incremented correctly + expectedPort1 := constants.DEFAULT_TALOS_API_PORT + 1 + expectedPort2 := constants.DEFAULT_TALOS_API_PORT + 2 + + // Check if the ports were set correctly in the config handler + var setContextValueCalls = make(map[string]interface{}) + mocks.MockConfigHandler.SetContextValueFunc = func(key string, value interface{}) error { + setContextValueCalls[key] = value + return nil + } + + // Set endpoints for both services + err = mocks.MockConfigHandler.SetContextValue("cluster.workers.nodes.worker1.endpoint", fmt.Sprintf("127.0.0.1:%d", expectedPort1)) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + err = mocks.MockConfigHandler.SetContextValue("cluster.workers.nodes.worker2.endpoint", fmt.Sprintf("127.0.0.1:%d", expectedPort2)) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // Verify the endpoints were set with correct ports + endpoint1 := setContextValueCalls["cluster.workers.nodes.worker1.endpoint"] + endpoint2 := setContextValueCalls["cluster.workers.nodes.worker2.endpoint"] + + if endpoint1 != fmt.Sprintf("127.0.0.1:%d", expectedPort1) { + t.Errorf("Expected endpoint1 to be 127.0.0.1:%d, got %v", expectedPort1, endpoint1) + } + if endpoint2 != fmt.Sprintf("127.0.0.1:%d", expectedPort2) { + t.Errorf("Expected endpoint2 to be 127.0.0.1:%d, got %v", expectedPort2, endpoint2) + } + }) } func TestTalosService_GetComposeConfig(t *testing.T) { @@ -939,4 +1018,198 @@ func TestTalosService_GetComposeConfig(t *testing.T) { t.Errorf("expected API port configuration, got target=%d protocol=%s", port.Target, port.Protocol) } }) + + t.Run("PortIncrementInGetComposeConfig", func(t *testing.T) { + // Reset package-level variables + nextAPIPort = constants.DEFAULT_TALOS_API_PORT + 1 + + // Setup mocks for this test + mocks := setupTalosServiceMocks() + + // Track SetContextValue calls + setContextValueCalls := make(map[string]string) + mocks.MockConfigHandler.SetContextValueFunc = func(key string, value interface{}) error { + if strValue, ok := value.(string); ok { + setContextValueCalls[key] = strValue + } + return nil + } + + // Mock GetStringSlice to return empty hostports + mocks.MockConfigHandler.GetStringSliceFunc = func(key string, defaultValue ...[]string) []string { + return []string{} + } + + // Mock GetString to return the stored endpoint values + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + if strings.HasSuffix(key, ".endpoint") { + if value, exists := setContextValueCalls[key]; exists { + return value + } + } + return "" + } + + // Create and initialize first service (non-leader) + service1 := NewTalosService(mocks.Injector, "worker1") + service1.isLeader = false + service1.SetName("worker1") + err := service1.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + // Set address for first service + err = service1.SetAddress("127.0.0.1") + if err != nil { + t.Fatalf("expected no error setting address, got %v", err) + } + + // Get compose config for first service + config1, err := service1.GetComposeConfig() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // Create and initialize second service (non-leader) + service2 := NewTalosService(mocks.Injector, "worker2") + service2.isLeader = false + service2.SetName("worker2") + err = service2.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + // Set address for second service + err = service2.SetAddress("127.0.0.1") + if err != nil { + t.Fatalf("expected no error setting address, got %v", err) + } + + // Get compose config for second service + config2, err := service2.GetComposeConfig() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // Verify port configurations + if len(config1.Services) != 1 { + t.Fatalf("expected 1 service in config1, got %d", len(config1.Services)) + } + if len(config2.Services) != 1 { + t.Fatalf("expected 1 service in config2, got %d", len(config2.Services)) + } + + // Check ports for first service + ports1 := config1.Services[0].Ports + if len(ports1) != 1 { + t.Fatalf("expected 1 port in service1, got %d", len(ports1)) + } + if ports1[0].Target != 50000 || ports1[0].Published != "50001" { + t.Errorf("expected port 50000:50001 in service1, got %d:%s", ports1[0].Target, ports1[0].Published) + } + + // Check ports for second service + ports2 := config2.Services[0].Ports + if len(ports2) != 1 { + t.Fatalf("expected 1 port in service2, got %d", len(ports2)) + } + if ports2[0].Target != 50000 || ports2[0].Published != "50002" { + t.Errorf("expected port 50000:50002 in service2, got %d:%s", ports2[0].Target, ports2[0].Published) + } + }) + + t.Run("DNSConfiguration", func(t *testing.T) { + // Setup mocks for this test + mocks := setupTalosServiceMocks() + + // Mock GetString to return DNS address + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "dns.address" { + return "10.0.0.53" + } + return "" + } + + // Create and initialize service + service := NewTalosService(mocks.Injector, "worker") + err := service.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + // Get compose config + config, err := service.GetComposeConfig() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // Verify DNS configuration + if len(config.Services) != 1 { + t.Fatalf("expected 1 service, got %d", len(config.Services)) + } + + serviceConfig := config.Services[0] + if serviceConfig.DNS == nil { + t.Fatal("expected DNS to be initialized") + } + if len(serviceConfig.DNS) != 1 { + t.Fatalf("expected 1 DNS entry, got %d", len(serviceConfig.DNS)) + } + if serviceConfig.DNS[0] != "10.0.0.53" { + t.Errorf("expected DNS address 10.0.0.53, got %s", serviceConfig.DNS[0]) + } + }) + + t.Run("DNSConfigurationDuplicate", func(t *testing.T) { + // Setup mocks for this test + mocks := setupTalosServiceMocks() + + // Mock GetString to return DNS address + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "dns.address" { + return "10.0.0.53" + } + return "" + } + + // Create and initialize service + service := NewTalosService(mocks.Injector, "worker") + err := service.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + // Get compose config twice to test duplicate prevention + config1, err := service.GetComposeConfig() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + config2, err := service.GetComposeConfig() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // Verify DNS configuration in both configs + if len(config1.Services) != 1 || len(config2.Services) != 1 { + t.Fatalf("expected 1 service in each config, got %d and %d", len(config1.Services), len(config2.Services)) + } + + serviceConfig1 := config1.Services[0] + serviceConfig2 := config2.Services[0] + + if serviceConfig1.DNS == nil || serviceConfig2.DNS == nil { + t.Fatal("expected DNS to be initialized in both configs") + } + if len(serviceConfig1.DNS) != 1 || len(serviceConfig2.DNS) != 1 { + t.Fatalf("expected 1 DNS entry in each config, got %d and %d", len(serviceConfig1.DNS), len(serviceConfig2.DNS)) + } + if serviceConfig1.DNS[0] != "10.0.0.53" || serviceConfig2.DNS[0] != "10.0.0.53" { + t.Errorf("expected DNS address 10.0.0.53 in both configs, got %s and %s", serviceConfig1.DNS[0], serviceConfig2.DNS[0]) + } + }) } diff --git a/pkg/virt/docker_virt.go b/pkg/virt/docker_virt.go index 82fdc4484..2758fa9ca 100644 --- a/pkg/virt/docker_virt.go +++ b/pkg/virt/docker_virt.go @@ -4,7 +4,6 @@ import ( "fmt" "maps" "path/filepath" - "slices" "sort" "strings" "time" @@ -309,8 +308,7 @@ func (v *DockerVirt) checkDockerDaemon() error { // settings from all services. It creates a network configuration with optional IPAM // settings based on the network CIDR, collects service configurations with their // network settings and IP addresses, and aggregates volumes and networks from all -// services into a single project configuration. When DNS is enabled, it configures -// all services to use the DNS service for name resolution. +// services into a single project configuration. func (v *DockerVirt) getFullComposeConfig() (*types.Project, error) { contextName := v.configHandler.GetContext() @@ -345,17 +343,6 @@ func (v *DockerVirt) getFullComposeConfig() (*types.Project, error) { combinedNetworks[networkName] = networkConfig - var dnsAddress string - dnsEnabled := v.configHandler.GetBool("dns.enabled") - if dnsEnabled { - dnsAddress = v.configHandler.GetString("dns.address") - if dnsAddress == "" { - if dnsService, ok := v.injector.Resolve("dns").(services.Service); ok { - dnsAddress = dnsService.GetAddress() - } - } - } - for _, service := range v.services { if serviceInstance, ok := service.(interface { GetComposeConfig() (*types.Config, error) @@ -384,18 +371,6 @@ func (v *DockerVirt) getFullComposeConfig() (*types.Project, error) { containerConfig.Networks[networkName].Ipv4Address = ipAddress } - if dnsEnabled && dnsAddress != "" { - if containerConfig.DNS == nil { - containerConfig.DNS = []string{} - } - - dnsExists := slices.Contains(containerConfig.DNS, dnsAddress) - - if !dnsExists { - containerConfig.DNS = append(containerConfig.DNS, dnsAddress) - } - } - combinedServices = append(combinedServices, containerConfig) } } diff --git a/pkg/virt/docker_virt_test.go b/pkg/virt/docker_virt_test.go index 8622ddfaf..b5c6fc67e 100644 --- a/pkg/virt/docker_virt_test.go +++ b/pkg/virt/docker_virt_test.go @@ -24,25 +24,21 @@ func setupSafeDockerContainerMocks(optionalInjector ...di.Injector) *MockCompone mockShell := shell.NewMockShell(injector) mockConfigHandler := config.NewMockConfigHandler() - mockService1 := services.NewMockService() - mockService2 := services.NewMockService() + mockService := services.NewMockService() // Register mock instances in the injector injector.Register("shell", mockShell) injector.Register("configHandler", mockConfigHandler) - injector.Register("service1", mockService1) - injector.Register("service2", mockService2) + injector.Register("service", mockService) - // Implement GetContextFunc on mock context - mockConfigHandler.GetContextFunc = func() string { - return "mock-context" - } - - // Set up the mock config handler to return specific values for Docker configuration + // Set up default mock behaviors mockConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { if key == "docker.enabled" { return true } + if key == "dns.enabled" { + return true + } if len(defaultValue) > 0 { return defaultValue[0] } @@ -51,12 +47,19 @@ func setupSafeDockerContainerMocks(optionalInjector ...di.Injector) *MockCompone mockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { if key == "network.cidr_block" { - return "192.168.1.0/24" + return "10.0.0.0/24" + } + if key == "dns.address" { + return "" } if len(defaultValue) > 0 { return defaultValue[0] } - return "default-value" + return "" + } + + mockConfigHandler.GetContextFunc = func() string { + return "mock-context" } // Mock the shell Exec function to return generic JSON structures for two containers @@ -89,7 +92,7 @@ func setupSafeDockerContainerMocks(optionalInjector ...di.Injector) *MockCompone } // Mock the service's GetComposeConfigFunc to return a default configuration for two services - mockService1.GetComposeConfigFunc = func() (*types.Config, error) { + mockService.GetComposeConfigFunc = func() (*types.Config, error) { return &types.Config{ Services: []types.ServiceConfig{ {Name: "service1", Networks: map[string]*types.ServiceNetworkConfig{"windsor-mock-context": {Ipv4Address: "192.168.1.2"}}}, @@ -111,34 +114,20 @@ func setupSafeDockerContainerMocks(optionalInjector ...di.Injector) *MockCompone } // Mock the GetAddress function to return specific IP addresses for services - mockService1.GetAddressFunc = func() string { + mockService.GetAddressFunc = func() string { return "192.168.1.2" } - mockService2.GetAddressFunc = func() string { - return "192.168.1.3" - } // Mock the GetProjectRootFunc to return a mock project root path mockShell.GetProjectRootFunc = func() (string, error) { return "/mock/project/root", nil } - // Mock the GetString function to return a specific value for network.cidr_block - mockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "network.cidr_block" { - return "192.168.1.0/24" - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "default-value" - } - return &MockComponents{ Injector: injector, MockShell: mockShell, MockConfigHandler: mockConfigHandler, - MockService: mockService1, + MockService: mockService, } } @@ -1373,7 +1362,7 @@ func TestDockerVirt_getFullComposeConfig(t *testing.T) { return "10.0.0.0/24" } if key == "dns.address" { - return "" + return "10.0.0.53" } if len(defaultValue) > 0 { return defaultValue[0] @@ -1413,7 +1402,7 @@ func TestDockerVirt_getFullComposeConfig(t *testing.T) { t.Fatalf("expected network %s to exist", networkName) } if network.Driver != "bridge" { - t.Errorf("expected driver to be bridge, got %s", network.Driver) + t.Errorf("expected network driver to be bridge, got %s", network.Driver) } if network.Ipam.Config == nil { t.Fatal("expected Ipam config to be non-nil") @@ -1460,9 +1449,9 @@ func TestDockerVirt_getFullComposeConfig(t *testing.T) { // Setup mock components mocks := setupSafeDockerContainerMocks() dockerVirt := NewDockerVirt(mocks.Injector) - dockerVirt.configHandler = mocks.MockConfigHandler // Set the config handler + dockerVirt.Initialize() - // Configure mock behavior + // Enable DNS in configuration mocks.MockConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { if key == "docker.enabled" { return true @@ -1476,23 +1465,6 @@ func TestDockerVirt_getFullComposeConfig(t *testing.T) { return false } - mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "network.cidr_block" { - return "10.0.0.0/24" - } - if key == "dns.address" { - return "" - } - if len(defaultValue) > 0 { - return defaultValue[0] - } - return "" - } - - mocks.MockConfigHandler.GetContextFunc = func() string { - return "mock-context" - } - // Create a mock service mockService := services.NewMockService() mockService.GetNameFunc = func() string { @@ -1521,6 +1493,7 @@ func TestDockerVirt_getFullComposeConfig(t *testing.T) { Source: "/host:/container", }, }, + DNS: []string{"10.0.0.53"}, }, }, }, nil @@ -1579,8 +1552,5 @@ func TestDockerVirt_getFullComposeConfig(t *testing.T) { if service.Networks["windsor-mock-context"] == nil { t.Error("expected network windsor-mock-context to exist") } - if service.DNS == nil || len(service.DNS) != 1 || service.DNS[0] != "10.0.0.53" { - t.Errorf("expected DNS to be [10.0.0.53], got %v", service.DNS) - } }) }