diff --git a/pkg/network/darwin_network.go b/pkg/network/darwin_network.go index b5dfba1f5..940a8dae8 100644 --- a/pkg/network/darwin_network.go +++ b/pkg/network/darwin_network.go @@ -70,7 +70,16 @@ func (n *BaseNetworkManager) ConfigureDNS() error { if tld == "" { return fmt.Errorf("DNS domain is not configured") } - dnsIP := n.configHandler.GetString("dns.address") + + var dnsIP string + if n.isLocalhostMode() { + dnsIP = "127.0.0.1" + } else { + dnsIP = n.configHandler.GetString("dns.address") + if dnsIP == "" { + return fmt.Errorf("DNS address is not configured") + } + } resolverDir := "/etc/resolver" resolverFile := fmt.Sprintf("%s/%s", resolverDir, tld) diff --git a/pkg/network/darwin_network_test.go b/pkg/network/darwin_network_test.go index 564cfcf4e..104202477 100644 --- a/pkg/network/darwin_network_test.go +++ b/pkg/network/darwin_network_test.go @@ -360,14 +360,22 @@ func TestDarwinNetworkManager_ConfigureDNS(t *testing.T) { } }) - t.Run("SuccessLocalhost", func(t *testing.T) { + t.Run("SuccessLocalhostMode", func(t *testing.T) { mocks := setupDarwinNetworkManagerMocks() + // Mock the config handler to return docker-desktop for vm.driver mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "vm.driver" { + switch key { + case "vm.driver": return "docker-desktop" + case "dns.domain": + return "example.com" + default: + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" } - return "some_value" } nm := NewBaseNetworkManager(mocks.Injector) @@ -377,10 +385,23 @@ func TestDarwinNetworkManager_ConfigureDNS(t *testing.T) { t.Fatalf("expected no error during initialization, got %v", err) } + // Mock the writeFile function to capture the content + var capturedContent []byte + writeFile = func(_ string, content []byte, _ os.FileMode) error { + capturedContent = content + return nil + } + err = nm.ConfigureDNS() if err != nil { t.Fatalf("expected no error, got %v", err) } + + // Verify that the resolver file contains 127.0.0.1 + expectedContent := "nameserver 127.0.0.1\n" + if string(capturedContent) != expectedContent { + t.Errorf("expected resolver file content to be %q, got %q", expectedContent, string(capturedContent)) + } }) t.Run("NoDNSDomainConfigured", func(t *testing.T) { @@ -410,6 +431,43 @@ func TestDarwinNetworkManager_ConfigureDNS(t *testing.T) { } }) + t.Run("NoDNSAddressConfigured", func(t *testing.T) { + mocks := setupDarwinNetworkManagerMocks() + + // Mock the config handler to return empty DNS address but valid domain + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + switch key { + case "dns.domain": + return "example.com" + case "dns.address": + return "" + case "vm.driver": + return "lima" // Not localhost mode + default: + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + } + + nm := NewBaseNetworkManager(mocks.Injector) + + err := nm.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + err = nm.ConfigureDNS() + if err == nil { + t.Fatalf("expected error, got nil") + } + expectedError := "DNS address is not configured" + if err.Error() != expectedError { + t.Fatalf("expected error %q, got %q", expectedError, err.Error()) + } + }) + t.Run("ResolverFileAlreadyExists", func(t *testing.T) { mocks := setupDarwinNetworkManagerMocks() diff --git a/pkg/network/linux_network.go b/pkg/network/linux_network.go index 4f7b5b132..2149b31d7 100644 --- a/pkg/network/linux_network.go +++ b/pkg/network/linux_network.go @@ -74,7 +74,16 @@ func (n *BaseNetworkManager) ConfigureDNS() error { if tld == "" { return fmt.Errorf("DNS domain is not configured") } - dnsIP := n.configHandler.GetString("dns.address") + + var dnsIP string + if n.isLocalhostMode() { + dnsIP = "127.0.0.1" + } else { + dnsIP = n.configHandler.GetString("dns.address") + if dnsIP == "" { + return fmt.Errorf("DNS address is not configured") + } + } // If DNS address is configured, use systemd-resolved resolvConf, err := readLink("/etc/resolv.conf") diff --git a/pkg/network/linux_network_test.go b/pkg/network/linux_network_test.go index 3388329a2..47e8f7ce8 100644 --- a/pkg/network/linux_network_test.go +++ b/pkg/network/linux_network_test.go @@ -6,6 +6,7 @@ package network import ( "fmt" "net" + "os" "strings" "testing" @@ -338,6 +339,83 @@ func TestLinuxNetworkManager_ConfigureDNS(t *testing.T) { } }) + t.Run("SuccessLocalhostMode", func(t *testing.T) { + mocks := setupLinuxNetworkManagerMocks() + + // Mock the config handler to return docker-desktop for vm.driver + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + switch key { + case "vm.driver": + return "docker-desktop" + case "dns.domain": + return "example.com" + default: + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + } + + // Mock the readLink function to simulate systemd-resolved being in use + originalReadLink := readLink + defer func() { readLink = originalReadLink }() + readLink = func(_ string) (string, error) { + return "../run/systemd/resolve/stub-resolv.conf", nil + } + + // Mock the readFile function to capture the content + var capturedContent []byte + originalReadFile := readFile + defer func() { readFile = originalReadFile }() + readFile = func(_ string) ([]byte, error) { + if capturedContent != nil { + return capturedContent, nil + } + return nil, os.ErrNotExist + } + + // Create a networkManager using NewBaseNetworkManager with the mock DI container + nm := NewBaseNetworkManager(mocks.Injector) + err := nm.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + // Mock the shell.ExecSudo function to capture the content + mocks.MockShell.ExecSudoFunc = func(description, command string, args ...string) (string, error) { + if command == "bash" && args[0] == "-c" { + // Extract the content from the echo command + cmdStr := args[1] + + // The command is in the format: echo '[Resolve]\nDNS=127.0.0.1\n' | sudo tee /etc/systemd/resolved.conf.d/dns-override-example.com.con + // We need to extract the content between the first and last single quote before the pipe + if strings.Contains(cmdStr, "echo '") && strings.Contains(cmdStr, "' | sudo tee") { + start := strings.Index(cmdStr, "echo '") + 6 + end := strings.Index(cmdStr, "' | sudo tee") + if start < end { + content := cmdStr[start:end] + capturedContent = []byte(content) + } + } + return "", nil + } + return "", nil + } + + // Call the ConfigureDNS method + err = nm.ConfigureDNS() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // Verify that the drop-in file contains 127.0.0.1 + expectedContent := "[Resolve]\nDNS=127.0.0.1\n" + if string(capturedContent) != expectedContent { + t.Errorf("expected drop-in file content to be %q, got %q", expectedContent, string(capturedContent)) + } + }) + t.Run("domainNotConfigured", func(t *testing.T) { mocks := setupLinuxNetworkManagerMocks() @@ -367,6 +445,51 @@ func TestLinuxNetworkManager_ConfigureDNS(t *testing.T) { } }) + t.Run("NoDNSAddressConfigured", func(t *testing.T) { + mocks := setupLinuxNetworkManagerMocks() + + // Mock the config handler to return empty DNS address but valid domain + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + switch key { + case "dns.domain": + return "example.com" + case "dns.address": + return "" + case "vm.driver": + return "lima" // Not localhost mode + default: + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + } + + // Mock the readLink function to simulate systemd-resolved being in use + originalReadLink := readLink + defer func() { readLink = originalReadLink }() + readLink = func(_ string) (string, error) { + return "../run/systemd/resolve/stub-resolv.conf", nil + } + + // Create a networkManager using NewBaseNetworkManager with the mock DI container + nm := NewBaseNetworkManager(mocks.Injector) + err := nm.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + // Call the ConfigureDNS method and expect an error due to missing DNS address + err = nm.ConfigureDNS() + if err == nil { + t.Fatalf("expected error, got nil") + } + expectedError := "DNS address is not configured" + if !strings.Contains(err.Error(), expectedError) { + t.Fatalf("expected error %q, got %q", expectedError, err.Error()) + } + }) + t.Run("SystemdResolvedNotInUse", func(t *testing.T) { mocks := setupLinuxNetworkManagerMocks() diff --git a/pkg/network/network.go b/pkg/network/network.go index b6ed08fa2..5f2df78d9 100644 --- a/pkg/network/network.go +++ b/pkg/network/network.go @@ -34,7 +34,6 @@ type BaseNetworkManager struct { configHandler config.ConfigHandler networkInterfaceProvider NetworkInterfaceProvider services []services.Service - isLocalhost bool } // NewNetworkManager creates a new NetworkManager @@ -75,27 +74,16 @@ func (n *BaseNetworkManager) Initialize() error { n.services = serviceList - vmDriver := n.configHandler.GetString("vm.driver") - n.isLocalhost = vmDriver == "docker-desktop" - - if n.isLocalhost { - for _, service := range n.services { - if err := service.SetAddress("127.0.0.1"); err != nil { - return fmt.Errorf("error setting address for service: %w", err) - } - } - } else { - networkCIDR := n.configHandler.GetString("network.cidr_block") - if networkCIDR == "" { - networkCIDR = constants.DEFAULT_NETWORK_CIDR - if err := n.configHandler.SetContextValue("network.cidr_block", networkCIDR); err != nil { - return fmt.Errorf("error setting default network CIDR: %w", err) - } - } - if err := assignIPAddresses(n.services, &networkCIDR); err != nil { - return fmt.Errorf("error assigning IP addresses: %w", err) + networkCIDR := n.configHandler.GetString("network.cidr_block") + if networkCIDR == "" { + networkCIDR = constants.DEFAULT_NETWORK_CIDR + if err := n.configHandler.SetContextValue("network.cidr_block", networkCIDR); err != nil { + return fmt.Errorf("error setting default network CIDR: %w", err) } } + if err := assignIPAddresses(n.services, &networkCIDR); err != nil { + return fmt.Errorf("error assigning IP addresses: %w", err) + } return nil } @@ -109,6 +97,11 @@ func (n *BaseNetworkManager) ConfigureGuest() error { // Ensure BaseNetworkManager implements NetworkManager var _ NetworkManager = (*BaseNetworkManager)(nil) +// isLocalhostMode checks if the system is in localhost mode +func (n *BaseNetworkManager) isLocalhostMode() bool { + return n.configHandler.GetString("vm.driver") == "docker-desktop" +} + // assignIPAddresses assigns IP addresses to services based on the network CIDR. var assignIPAddresses = func(services []services.Service, networkCIDR *string) error { if networkCIDR == nil || *networkCIDR == "" { diff --git a/pkg/network/network_test.go b/pkg/network/network_test.go index 8432c22cd..725d02f06 100644 --- a/pkg/network/network_test.go +++ b/pkg/network/network_test.go @@ -107,12 +107,8 @@ func setupNetworkManagerMocks(optionalInjector ...di.Injector) *NetworkManagerMo // Create mock services mockService1 := services.NewMockService() - mockService1.SetName("Service1") - injector.Register("service1", mockService1) - - // Create another mock service mockService2 := services.NewMockService() - mockService2.SetName("Service2") + injector.Register("service1", mockService1) injector.Register("service2", mockService2) // Return a struct containing all mocks @@ -129,47 +125,41 @@ func setupNetworkManagerMocks(optionalInjector ...di.Injector) *NetworkManagerMo func TestNetworkManager_Initialize(t *testing.T) { t.Run("Success", func(t *testing.T) { mocks := setupNetworkManagerMocks() - nm := NewBaseNetworkManager(mocks.Injector) - err := nm.Initialize() - if err != nil { - t.Fatalf("expected no error, got %v", err) + // Track IP address assignments + var setAddressCalls []string + mockService1 := services.NewMockService() + mockService1.SetAddressFunc = func(address string) error { + setAddressCalls = append(setAddressCalls, address) + return nil } - - if nm == nil { - t.Fatalf("expected a valid NetworkManager instance, got nil") + mockService2 := services.NewMockService() + mockService2.SetAddressFunc = func(address string) error { + setAddressCalls = append(setAddressCalls, address) + return nil } - }) + mocks.Injector.Register("service1", mockService1) + mocks.Injector.Register("service2", mockService2) - t.Run("SuccessLocalhost", func(t *testing.T) { - mocks := setupNetworkManagerMocks() nm := NewBaseNetworkManager(mocks.Injector) - // Set the configuration to simulate docker-desktop - mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "vm.driver" { - return "docker-desktop" - } - return "" - } - - // Capture the SetAddress calls - mockService := services.NewMockService() - mockService.SetAddressFunc = func(address string) error { - if address != "127.0.0.1" { - return fmt.Errorf("expected address to be 127.0.0.1, got %v", address) - } - return nil - } - mocks.Injector.Register("service", mockService) - err := nm.Initialize() if err != nil { - t.Fatalf("expected no error, got %v", err) + t.Fatalf("expected no error during initialization, got %v", err) } - if !nm.isLocalhost { - t.Fatalf("expected isLocalhost to be true, got false") + // Verify that services were assigned IP addresses from the CIDR range + expectedIPs := []string{"192.168.1.2", "192.168.1.3"} + if len(setAddressCalls) != len(expectedIPs) { + t.Errorf("expected %d IP assignments, got %d", len(expectedIPs), len(setAddressCalls)) + } + for i, expectedIP := range expectedIPs { + if i >= len(setAddressCalls) { + break + } + if setAddressCalls[i] != expectedIP { + t.Errorf("expected IP %s to be assigned, got %s", expectedIP, setAddressCalls[i]) + } } }) @@ -255,49 +245,6 @@ func TestNetworkManager_Initialize(t *testing.T) { } }) - t.Run("ErrorSettingLocalhostAddresses", func(t *testing.T) { - // Setup mock components - mocks := setupNetworkManagerMocks() - nm := NewBaseNetworkManager(mocks.Injector) - - // Set the configuration to simulate docker-desktop - mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { - if key == "vm.driver" { - return "docker-desktop" - } - return "" - } - - // Mock SetAddress to return an error - mockService := services.NewMockService() - mockService.SetAddressFunc = func(address string) error { - if address == "127.0.0.1" { - return fmt.Errorf("mock error setting address") - } - return nil - } - mocks.Injector.Register("service", mockService) - - // Call the Initialize method - err := nm.Initialize() - - // Assert that an error occurred - if err == nil { - t.Errorf("expected error, got none") - } - - // Verify the error message contains the expected substring - expectedErrorSubstring := "error setting address for service" - if !strings.Contains(err.Error(), expectedErrorSubstring) { - t.Errorf("expected error message to contain %q, got %q", expectedErrorSubstring, err.Error()) - } - - // Verify that isLocalhost is true - if !nm.isLocalhost { - t.Errorf("expected isLocalhost to be true, got false") - } - }) - t.Run("ErrorSettingNetworkCidr", func(t *testing.T) { // Setup mock components mocks := setupNetworkManagerMocks() diff --git a/pkg/network/windows_network.go b/pkg/network/windows_network.go index ce3d676f9..99e5ee9ed 100644 --- a/pkg/network/windows_network.go +++ b/pkg/network/windows_network.go @@ -68,10 +68,14 @@ func (n *BaseNetworkManager) ConfigureDNS() error { return fmt.Errorf("DNS domain is not configured") } - dnsIP := n.configHandler.GetString("dns.address") - if dnsIP == "" { - // If there's no DNS address to configure, we simply skip - return nil + var dnsIP string + if n.isLocalhostMode() { + dnsIP = "127.0.0.1" + } else { + dnsIP = n.configHandler.GetString("dns.address") + if dnsIP == "" { + return fmt.Errorf("DNS address is not configured") + } } // Prepend a "." to the domain for the namespace diff --git a/pkg/network/windows_network_test.go b/pkg/network/windows_network_test.go index a7e066fed..500817f36 100644 --- a/pkg/network/windows_network_test.go +++ b/pkg/network/windows_network_test.go @@ -304,6 +304,84 @@ func TestWindowsNetworkManager_ConfigureDNS(t *testing.T) { } }) + t.Run("SuccessLocalhostMode", func(t *testing.T) { + // Given setup mocks using setupWindowsNetworkManagerMocks + mocks := setupWindowsNetworkManagerMocks() + + // Mock the config handler to return valid DNS domain and set VM driver to docker-desktop + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + switch key { + case "dns.domain": + return "example.com" + case "dns.address": + return "" // Empty DNS address is fine in localhost mode + case "vm.driver": + return "docker-desktop" // This enables localhost mode + default: + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + } + + // Mock the shell to capture the namespace and nameservers + var capturedNamespace string + var capturedNameServers string + mocks.MockShell.ExecSilentFunc = func(command string, args ...string) (string, error) { + if command == "powershell" && len(args) > 1 && args[0] == "-Command" { + script := args[1] + if strings.Contains(script, "Get-DnsClientNrptRule") { + // Extract namespace from the script + namespaceMatch := strings.Split(script, "$namespace = '") + if len(namespaceMatch) > 1 { + namespaceParts := strings.Split(namespaceMatch[1], "'") + if len(namespaceParts) > 0 { + capturedNamespace = namespaceParts[0] + } + } + + // Extract nameservers from the script + nameserversMatch := strings.Split(script, "NameServers -ne \"") + if len(nameserversMatch) > 1 { + parts := strings.Split(nameserversMatch[1], "\"") + if len(parts) > 1 { + capturedNameServers = strings.Trim(parts[0], "\"") + } + } + return "", nil + } + } + return "", nil + } + + // And create a network manager using NewBaseNetworkManager with the mock injector + nm := NewBaseNetworkManager(mocks.Injector) + err := nm.Initialize() + if err != nil { + t.Errorf("expected no error during initialization, got %v", err) + } + + // When call the method under test + err = nm.ConfigureDNS() + + // Then expect no error + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + // Verify that the DNS rule is configured with 127.0.0.1 + expectedNamespace := ".example.com" + if capturedNamespace != expectedNamespace { + t.Errorf("expected namespace to be %q, got %q", expectedNamespace, capturedNamespace) + } + + expectedNameServers := "127.0.0.1" + if capturedNameServers != expectedNameServers { + t.Errorf("expected nameservers to be %q, got %q", expectedNameServers, capturedNameServers) + } + }) + t.Run("NoDNSName", func(t *testing.T) { // Given setup mocks using setupWindowsNetworkManagerMocks mocks := setupWindowsNetworkManagerMocks() @@ -348,6 +426,9 @@ func TestWindowsNetworkManager_ConfigureDNS(t *testing.T) { if key == "dns.address" { return "" } + if key == "vm.driver" { + return "hyperv" // Not localhost mode + } if len(defaultValue) > 0 { return defaultValue[0] } @@ -367,9 +448,9 @@ func TestWindowsNetworkManager_ConfigureDNS(t *testing.T) { // When call the method under test err = nm.ConfigureDNS() - // Then expect no error since DNS IP is not required for NRPT rule configuration - if err != nil { - t.Errorf("expected no error, got %v", err) + // Then expect error since DNS IP is required when not in localhost mode + if err == nil || !strings.Contains(err.Error(), "DNS address is not configured") { + t.Errorf("expected error 'DNS address is not configured', got %v", err) } }) @@ -471,4 +552,40 @@ func TestWindowsNetworkManager_ConfigureDNS(t *testing.T) { t.Errorf("expected error about failing to add or update DNS rule, got %v", err) } }) + + t.Run("NoDNSAddressConfigured", func(t *testing.T) { + mocks := setupWindowsNetworkManagerMocks() + + // Mock the config handler to return empty DNS address but valid domain + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + switch key { + case "dns.domain": + return "example.com" + case "dns.address": + return "" + case "vm.driver": + return "hyperv" // Not localhost mode + default: + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + } + + nm := NewBaseNetworkManager(mocks.Injector) + err := nm.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + err = nm.ConfigureDNS() + if err == nil { + t.Fatalf("expected error, got nil") + } + expectedError := "DNS address is not configured" + if !strings.Contains(err.Error(), expectedError) { + t.Fatalf("expected error %q, got %q", expectedError, err.Error()) + } + }) } diff --git a/pkg/services/dns_service.go b/pkg/services/dns_service.go index 42c59e4e1..69c397cec 100644 --- a/pkg/services/dns_service.go +++ b/pkg/services/dns_service.go @@ -72,7 +72,7 @@ func (s *DNSService) GetComposeConfig() (*types.Config, error) { }, } - if s.IsLocalhost() { + if s.IsLocalhostMode() { corednsConfig.Ports = []types.ServicePortConfig{ { Target: 53, @@ -116,6 +116,9 @@ func (s *DNSService) WriteConfig() error { address := service.GetAddress() if address != "" { hostname := service.GetHostname() + if s.IsLocalhostMode() { + address = "127.0.0.1" + } hostEntries += fmt.Sprintf(" %s %s\n", address, hostname) } } @@ -139,14 +142,19 @@ func (s *DNSService) WriteConfig() error { var corefileContent string corefileContent = fmt.Sprintf(` %s:53 { + errors + reload + loop hosts { %s fallthrough } - + forward . %s +} +.:53 { + errors reload loop - - forward . %s + forward . 1.1.1.1 8.8.8.8 } `, tld, hostEntries, forwardAddressesStr) diff --git a/pkg/services/dns_service_test.go b/pkg/services/dns_service_test.go index 68a5a97ae..b46f56813 100644 --- a/pkg/services/dns_service_test.go +++ b/pkg/services/dns_service_test.go @@ -277,20 +277,25 @@ func TestDNSService_GetComposeConfig(t *testing.T) { }) t.Run("LocalhostPorts", func(t *testing.T) { - // Create a mock injector with necessary mocks + // Setup mock components mocks := createDNSServiceMocks() - - // Given: a DNSService with the mock injector service := NewDNSService(mocks.Injector) + service.Initialize() - // Initialize the service - if err := service.Initialize(); err != nil { - t.Fatalf("Initialize() error = %v", err) + // Set vm.driver to docker-desktop to simulate localhost mode + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + if key == "dns.domain" { + return "test" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" } - // Set the address to localhost - service.SetAddress("127.0.0.1") - // When: GetComposeConfig is called cfg, err := service.GetComposeConfig() @@ -349,15 +354,20 @@ func TestDNSService_WriteConfig(t *testing.T) { // Verify that the Corefile content is correctly formatted expectedCorefileContent := ` test:53 { + errors + reload + loop hosts { 127.0.0.1 test 192.168.1.1 test fallthrough } - + forward . 1.1.1.1 8.8.8.8 +} +.:53 { + errors reload loop - forward . 1.1.1.1 8.8.8.8 } ` @@ -400,15 +410,20 @@ test:53 { // Verify that the Corefile content is correctly formatted for localhost expectedCorefileContent := ` test:53 { + errors + reload + loop hosts { 127.0.0.1 test 192.168.1.1 test fallthrough } - + forward . 1.1.1.1 8.8.8.8 +} +.:53 { + errors reload loop - forward . 1.1.1.1 8.8.8.8 } ` @@ -541,4 +556,118 @@ test:53 { t.Fatalf("expected error %v, got %v", expectedError, err) } }) + + t.Run("SuccessLocalhostMode", func(t *testing.T) { + // Setup mock components + mocks := createDNSServiceMocks() + service := NewDNSService(mocks.Injector) + + // Initialize the service + if err := service.Initialize(); err != nil { + t.Fatalf("Initialize() error = %v", err) + } + + // Create a mock service with a hostname + mockService := NewMockService() + mockService.GetNameFunc = func() string { + return "test-service" + } + mockService.GetHostnameFunc = func() string { + return "test-service.test" + } + mockService.GetAddressFunc = func() string { + return "192.168.1.2" + } + mockService.GetComposeConfigFunc = func() (*types.Config, error) { + return &types.Config{ + Services: []types.ServiceConfig{ + { + Name: "test-service", + }, + }, + }, nil + } + service.services = []Service{mockService} + + // Mock the config handler to return docker-desktop for vm.driver and DNS records + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + if key == "dns.domain" { + return "test" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + + mocks.MockConfigHandler.GetStringSliceFunc = func(key string, defaultValue ...[]string) []string { + if key == "dns.records" { + return []string{ + "127.0.0.1 test", + "192.168.1.1 test", + } + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return nil + } + + // Mock the GetProjectRoot function to return a valid path + mocks.MockShell.GetProjectRootFunc = func() (string, error) { + return "/valid/path", nil + } + + // Mock the mkdirAll function to simulate successful directory creation + originalMkdirAll := mkdirAll + defer func() { mkdirAll = originalMkdirAll }() + mkdirAll = func(path string, perm os.FileMode) error { + return nil + } + + // Mock the writeFile function to capture the content written + var writtenContent []byte + originalWriteFile := writeFile + defer func() { writeFile = originalWriteFile }() + writeFile = func(filename string, data []byte, perm os.FileMode) error { + writtenContent = data + return nil + } + + // When: WriteConfig is called + err := service.WriteConfig() + + // Then: no error should be returned + if err != nil { + t.Fatalf("WriteConfig() error = %v", err) + } + + // Verify that the Corefile content uses 127.0.0.1 for the service address + expectedCorefileContent := ` +test:53 { + errors + reload + loop + hosts { + 127.0.0.1 test-service.test + 127.0.0.1 test + 192.168.1.1 test + fallthrough + } + forward . 1.1.1.1 8.8.8.8 +} +.:53 { + errors + reload + loop + forward . 1.1.1.1 8.8.8.8 +} +` + if string(writtenContent) != expectedCorefileContent { + t.Errorf("Expected Corefile content:\n%s\nGot:\n%s", expectedCorefileContent, string(writtenContent)) + } + }) } diff --git a/pkg/services/registry_service.go b/pkg/services/registry_service.go index bf95fe9b4..55c484993 100644 --- a/pkg/services/registry_service.go +++ b/pkg/services/registry_service.go @@ -68,7 +68,7 @@ func (s *RegistryService) SetAddress(address string) error { if registryConfig.HostPort != 0 { hostPort = registryConfig.HostPort - } else if registryConfig.Remote == "" && s.IsLocalhost() { + } else if registryConfig.Remote == "" && s.IsLocalhostMode() { hostPort = defaultPort err = s.configHandler.SetContextValue("docker.registry_url", hostName) if err != nil { @@ -144,7 +144,7 @@ func (s *RegistryService) generateRegistryService(hostname string, registry dock {Type: "bind", Source: "${WINDSOR_PROJECT_ROOT}/.windsor/.docker-cache", Target: "/var/lib/registry"}, } - if registry.Remote == "" && s.IsLocalhost() { + if registry.Remote == "" && s.IsLocalhostMode() { service.Ports = []types.ServicePortConfig{ { Target: 5000, diff --git a/pkg/services/registry_service_test.go b/pkg/services/registry_service_test.go index 09da3fd7d..d5038ff14 100644 --- a/pkg/services/registry_service_test.go +++ b/pkg/services/registry_service_test.go @@ -213,6 +213,28 @@ func TestRegistryService_GetComposeConfig(t *testing.T) { t.Run("LocalRegistry", func(t *testing.T) { // Given a mock config handler, shell, context, and service mocks := setupSafeRegistryServiceMocks() + mocks.MockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { + return &v1alpha1.Context{ + Docker: &docker.DockerConfig{ + Registries: map[string]docker.RegistryConfig{ + "local-registry": {HostPort: 5000}, + }, + }, + } + } + // Set vm.driver to docker-desktop for localhost tests + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + if key == "dns.domain" { + return "test" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } registryService := NewRegistryService(mocks.Injector) registryService.SetName("local-registry") err := registryService.Initialize() @@ -344,7 +366,7 @@ func TestRegistryService_SetAddress(t *testing.T) { }) t.Run("NoHostPortSetAndLocalhost", func(t *testing.T) { - // Given a mock config handler, shell, context, and service with no HostPort set + // Given a mock config handler, shell, context, and service with no HostPort mocks := setupSafeRegistryServiceMocks() mocks.MockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { return &v1alpha1.Context{ @@ -355,6 +377,19 @@ func TestRegistryService_SetAddress(t *testing.T) { }, } } + // Set vm.driver to docker-desktop for localhost tests + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + if key == "dns.domain" { + return "test" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } registryService := NewRegistryService(mocks.Injector) registryService.SetName("registry") err := registryService.Initialize() @@ -419,6 +454,19 @@ func TestRegistryService_SetAddress(t *testing.T) { }, } } + // Set vm.driver to docker-desktop for localhost tests + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + if key == "dns.domain" { + return "test" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } registryService := NewRegistryService(mocks.Injector) registryService.SetName("registry") err := registryService.Initialize() @@ -487,14 +535,8 @@ func TestRegistryService_SetAddress(t *testing.T) { }) t.Run("SetContextValueErrorForRegistryURL", func(t *testing.T) { - // Given a mock config handler that will fail to set context value for registry URL + // Given a mock config handler, shell, context, and service with no HostPort and no Remote mocks := setupSafeRegistryServiceMocks() - mocks.MockConfigHandler.SetContextValueFunc = func(key string, value interface{}) error { - if key == "docker.registry_url" { - return fmt.Errorf("failed to set registry URL") - } - return nil - } mocks.MockConfigHandler.GetConfigFunc = func() *v1alpha1.Context { return &v1alpha1.Context{ Docker: &docker.DockerConfig{ @@ -504,6 +546,26 @@ func TestRegistryService_SetAddress(t *testing.T) { }, } } + // Set vm.driver to docker-desktop for localhost tests + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + if key == "dns.domain" { + return "test" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + // Mock the SetContextValue function to return an error for registry URL + mocks.MockConfigHandler.SetContextValueFunc = func(key string, value interface{}) error { + if key == "docker.registry_url" { + return fmt.Errorf("failed to set registry URL") + } + return nil + } registryService := NewRegistryService(mocks.Injector) registryService.SetName("registry") err := registryService.Initialize() diff --git a/pkg/services/service.go b/pkg/services/service.go index 7e51ad5c1..6644e83cb 100644 --- a/pkg/services/service.go +++ b/pkg/services/service.go @@ -38,8 +38,8 @@ type Service interface { // GetHostname returns the name plus the tld from the config GetHostname() string - // IsLocalhost checks if the current address is a localhost address - IsLocalhost() bool + // IsLocalhostMode checks if we are in localhost mode (vm.driver == "docker-desktop") + IsLocalhostMode() bool } // BaseService is a base implementation of the Service interface @@ -104,13 +104,8 @@ func (s *BaseService) GetHostname() string { return fmt.Sprintf("%s.%s", s.name, tld) } -// IsLocalhost checks if the current address is a localhost address -func (s *BaseService) IsLocalhost() bool { - localhostAddresses := []string{"localhost", "127.0.0.1", "::1"} - for _, localhost := range localhostAddresses { - if s.address == localhost { - return true - } - } - return false +// IsLocalhostMode checks if we are in localhost mode (vm.driver == "docker-desktop") +func (s *BaseService) IsLocalhostMode() bool { + vmDriver := s.configHandler.GetString("vm.driver") + return vmDriver == "docker-desktop" } diff --git a/pkg/services/service_test.go b/pkg/services/service_test.go index 44dee3e67..f43f93b37 100644 --- a/pkg/services/service_test.go +++ b/pkg/services/service_test.go @@ -192,33 +192,60 @@ func TestBaseService_GetHostname(t *testing.T) { }) } -func TestBaseService_IsLocalhost(t *testing.T) { - tests := []struct { - name string - address string - expectedLocal bool - }{ - {"Localhost by name", "localhost", true}, - {"Localhost by IPv4", "127.0.0.1", true}, - {"Localhost by IPv6", "::1", true}, - {"Non-localhost IPv4", "192.168.1.1", false}, - {"Non-localhost IPv6", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", false}, - {"Empty address", "", false}, - } +func TestBaseService_IsLocalhostMode(t *testing.T) { + t.Run("Success", func(t *testing.T) { + // Setup mock components + mocks := setupSafeBaseServiceMocks() + service := &BaseService{ + injector: mocks.Injector, + } + service.Initialize() - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Given: a new BaseService with a mocked IsLocalhost method - service := &BaseService{} - service.address = tt.address + // Configure mock behavior + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + + // When: IsLocalhostMode is called + isLocal := service.IsLocalhostMode() + + // Then: the result should be true for docker-desktop + if !isLocal { + t.Fatal("expected IsLocalhostMode to be true for docker-desktop driver") + } + }) - // Mocking IsLocalhost by directly setting the address - isLocal := service.IsLocalhost() + t.Run("NotDockerDesktop", func(t *testing.T) { + // Setup mock components + mocks := setupSafeBaseServiceMocks() + service := &BaseService{ + injector: mocks.Injector, + } + service.Initialize() - // Then: the result should match the expected outcome - if isLocal != tt.expectedLocal { - t.Fatalf("expected IsLocalhost to be %v for address '%s', got %v", tt.expectedLocal, tt.address, isLocal) + // Configure mock behavior + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "lima" } - }) - } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + + // When: IsLocalhostMode is called + isLocal := service.IsLocalhostMode() + + // Then: the result should be false for non-docker-desktop driver + if isLocal { + t.Fatal("expected IsLocalhostMode to be false for non-docker-desktop driver") + } + }) } diff --git a/pkg/services/talos_service.go b/pkg/services/talos_service.go index 37c25d623..c6bed7e07 100644 --- a/pkg/services/talos_service.go +++ b/pkg/services/talos_service.go @@ -68,10 +68,13 @@ func (s *TalosService) SetAddress(address string) error { nodeType = "controlplanes" } + // Always use DNS name for hostname if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.hostname", nodeType, s.name), s.name+"."+tld); err != nil { return err } - if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.node", nodeType, s.name), address); err != nil { + + // Always use DNS name for node + if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.node", nodeType, s.name), s.name+"."+tld); err != nil { return err } @@ -79,14 +82,14 @@ func (s *TalosService) SetAddress(address string) error { defer portLock.Unlock() var port int - if s.isLeader || !s.IsLocalhost() { + if s.isLeader || !s.IsLocalhostMode() { port = defaultAPIPort } else { port = nextAPIPort nextAPIPort++ } - if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.endpoint", nodeType, s.name), fmt.Sprintf("%s:%d", address, port)); err != nil { + if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.endpoint", nodeType, s.name), fmt.Sprintf("%s.%s:%d", s.name, tld, port)); err != nil { return err } @@ -251,7 +254,7 @@ func (s *TalosService) GetComposeConfig() (*types.Config, error) { } defaultAPIPortUint32 := uint32(defaultAPIPort) - if s.IsLocalhost() { + if s.IsLocalhostMode() { ports = append(ports, types.ServicePortConfig{ Target: defaultAPIPortUint32, Published: publishedPort, diff --git a/pkg/services/talos_service_test.go b/pkg/services/talos_service_test.go index 21f60fb58..a571e5f19 100644 --- a/pkg/services/talos_service_test.go +++ b/pkg/services/talos_service_test.go @@ -832,4 +832,111 @@ func TestTalosService_GetComposeConfig(t *testing.T) { t.Fatalf("expected volumes, got 0") } }) + + t.Run("LocalhostModeControlPlaneLeader", func(t *testing.T) { + // Setup mocks for this test + mocks := setupTalosServiceMocks() + service := NewTalosService(mocks.Injector, "controlplane") + + // Mock vm.driver to enable localhost mode + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + return "" + } + + // Set isLeader to true + service.isLeader = true + + // Initialize the service + err := service.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + // When the GetComposeConfig method is called + config, err := service.GetComposeConfig() + + // Then no error should be returned + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // And the config should contain both API and Kubernetes ports + if len(config.Services) != 1 { + t.Fatalf("expected 1 service, got %d", len(config.Services)) + } + + serviceConfig := config.Services[0] + if len(serviceConfig.Ports) != 2 { + t.Fatalf("expected 2 ports, got %d", len(serviceConfig.Ports)) + } + + // Verify API port + foundAPIPort := false + foundKubePort := false + for _, port := range serviceConfig.Ports { + if port.Target == 50000 && port.Protocol == "tcp" { + foundAPIPort = true + } + if port.Target == 6443 && port.Published == "6443" && port.Protocol == "tcp" { + foundKubePort = true + } + } + + if !foundAPIPort { + t.Error("expected to find API port configuration") + } + if !foundKubePort { + t.Error("expected to find Kubernetes API port configuration") + } + }) + + t.Run("LocalhostModeControlPlaneNonLeader", func(t *testing.T) { + // Setup mocks for this test + mocks := setupTalosServiceMocks() + service := NewTalosService(mocks.Injector, "controlplane") + + // Mock vm.driver to enable localhost mode + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "vm.driver" { + return "docker-desktop" + } + return "" + } + + // Set isLeader to false + service.isLeader = false + + // Initialize the service + err := service.Initialize() + if err != nil { + t.Fatalf("expected no error during initialization, got %v", err) + } + + // When the GetComposeConfig method is called + config, err := service.GetComposeConfig() + + // Then no error should be returned + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // And the config should contain only the API port + if len(config.Services) != 1 { + t.Fatalf("expected 1 service, got %d", len(config.Services)) + } + + serviceConfig := config.Services[0] + if len(serviceConfig.Ports) != 1 { + t.Fatalf("expected 1 port, got %d", len(serviceConfig.Ports)) + } + + // Verify only API port is present + port := serviceConfig.Ports[0] + if port.Target != 50000 || port.Protocol != "tcp" { + t.Errorf("expected API port configuration, got target=%d protocol=%s", port.Target, port.Protocol) + } + }) } diff --git a/pkg/virt/docker_virt.go b/pkg/virt/docker_virt.go index 17b6970ed..82fdc4484 100644 --- a/pkg/virt/docker_virt.go +++ b/pkg/virt/docker_virt.go @@ -2,7 +2,9 @@ package virt import ( "fmt" + "maps" "path/filepath" + "slices" "sort" "strings" "time" @@ -28,19 +30,19 @@ func NewDockerVirt(injector di.Injector) *DockerVirt { } } -// Initialize resolves the dependencies for DockerVirt +// Initialize resolves all dependencies for DockerVirt, including services from the DI +// container, Docker configuration status, and determines the appropriate docker compose +// command to use. It alphabetizes services and verifies Docker is enabled. func (v *DockerVirt) Initialize() error { if err := v.BaseVirt.Initialize(); err != nil { return fmt.Errorf("error initializing base: %w", err) } - // Resolve all services resolvedServices, err := v.injector.ResolveAll((*services.Service)(nil)) if err != nil { return fmt.Errorf("error resolving services: %w", err) } - // Convert the resolved services to the correct type serviceSlice := make([]services.Service, len(resolvedServices)) for i, service := range resolvedServices { if s, _ := service.(services.Service); s != nil { @@ -48,20 +50,16 @@ func (v *DockerVirt) Initialize() error { } } - // Alphabetize the services by their name sort.Slice(serviceSlice, func(i, j int) bool { return fmt.Sprintf("%T", serviceSlice[i]) < fmt.Sprintf("%T", serviceSlice[j]) }) - // Check if Docker is enabled using configHandler if !v.configHandler.GetBool("docker.enabled") { return fmt.Errorf("Docker configuration is not defined") } - // Set the services v.services = serviceSlice - // Determine the correct docker compose command if err := v.determineComposeCommand(); err != nil { return fmt.Errorf("error determining docker compose command: %w", err) } @@ -69,8 +67,9 @@ func (v *DockerVirt) Initialize() error { return nil } -// determineComposeCommand checks for available docker compose commands. If a docker-compose -// command is not available, none is set. +// determineComposeCommand checks for available docker compose commands in order of +// preference: docker-compose, docker-cli-plugin-docker-compose, and docker compose. +// It sets the first available command for later use in Docker operations. func (v *DockerVirt) determineComposeCommand() error { commands := []string{"docker-compose", "docker-cli-plugin-docker-compose", "docker compose"} for _, cmd := range commands { @@ -82,36 +81,32 @@ func (v *DockerVirt) determineComposeCommand() error { return nil } -// Up starts docker compose +// Up starts docker compose in detached mode with retry logic for reliability. It +// verifies Docker is enabled, checks the daemon is running, sets the compose file +// path, and attempts to start services with up to 3 retries if initial attempts fail. func (v *DockerVirt) Up() error { - // Check if Docker is enabled and run "docker compose up" in daemon mode if necessary if v.configHandler.GetBool("docker.enabled") { - // Ensure Docker daemon is running if err := v.checkDockerDaemon(); err != nil { return fmt.Errorf("Docker daemon is not running: %w", err) } - // Get the path to the docker-compose.yaml file projectRoot, err := v.shell.GetProjectRoot() if err != nil { return fmt.Errorf("error retrieving project root: %w", err) } composeFilePath := filepath.Join(projectRoot, ".windsor", "docker-compose.yaml") - // Set the COMPOSE_FILE environment variable and handle potential error if err := osSetenv("COMPOSE_FILE", composeFilePath); err != nil { return fmt.Errorf("failed to set COMPOSE_FILE environment variable: %w", err) } - // Retry logic for docker compose up with progress display retries := 3 var lastErr error var lastOutput string - for i := 0; i < retries; i++ { + for i := range make([]struct{}, retries) { args := []string{"up", "--detach", "--remove-orphans"} message := "📦 Running docker compose up" - // Use ExecProgress for the first attempt to show progress if i == 0 { output, err := v.shell.ExecProgress(message, v.composeCommand, args...) if err == nil { @@ -120,7 +115,6 @@ func (v *DockerVirt) Up() error { lastErr = err lastOutput = output } else { - // Use ExecSilent for retries to avoid multiple progress messages output, err := v.shell.ExecSilent(v.composeCommand, args...) if err == nil { return nil @@ -140,28 +134,25 @@ func (v *DockerVirt) Up() error { return nil } -// Down stops the Docker container +// Down stops all Docker containers managed by Windsor and removes associated volumes +// to ensure a clean shutdown. It verifies Docker is enabled, checks the daemon is +// running, and executes docker compose down with the --remove-orphans and --volumes flags. func (v *DockerVirt) Down() error { - // Check if Docker is enabled and run "docker compose down" if necessary if v.configHandler.GetBool("docker.enabled") { - // Ensure Docker daemon is running if err := v.checkDockerDaemon(); err != nil { return fmt.Errorf("Docker daemon is not running: %w", err) } - // Get the path to the docker-compose.yaml file projectRoot, err := v.shell.GetProjectRoot() if err != nil { return fmt.Errorf("error retrieving project root: %w", err) } composeFilePath := filepath.Join(projectRoot, ".windsor", "docker-compose.yaml") - // Set the COMPOSE_FILE environment variable and handle potential error if err := osSetenv("COMPOSE_FILE", composeFilePath); err != nil { return fmt.Errorf("error setting COMPOSE_FILE environment variable: %w", err) } - // Run docker compose down with clean flags using the Exec function from shell.go output, err := v.shell.ExecProgress("📦 Running docker compose down", v.composeCommand, "down", "--remove-orphans", "--volumes") if err != nil { return fmt.Errorf("Error executing command %s down: %w\n%s", v.composeCommand, err, output) @@ -170,33 +161,31 @@ func (v *DockerVirt) Down() error { return nil } -// WriteConfig writes the Docker configuration file +// WriteConfig generates and writes the Docker Compose configuration file by combining +// settings from all services. It creates the necessary directory structure, retrieves +// the full compose configuration, serializes it to YAML, and writes it to the .windsor +// directory with appropriate permissions. func (v *DockerVirt) WriteConfig() error { - // Get the project root and construct the file path projectRoot, err := v.shell.GetProjectRoot() if err != nil { return fmt.Errorf("error retrieving project root: %w", err) } composeFilePath := filepath.Join(projectRoot, ".windsor", "docker-compose.yaml") - // Ensure the parent context folder exists if err := mkdirAll(filepath.Dir(composeFilePath), 0755); err != nil { return fmt.Errorf("error creating parent context folder: %w", err) } - // Retrieve the full compose configuration project, err := v.getFullComposeConfig() if err != nil { return fmt.Errorf("error getting full compose config: %w", err) } - // Serialize the docker compose config to YAML yamlData, err := yamlMarshal(project) if err != nil { return fmt.Errorf("error marshaling docker compose config to YAML: %w", err) } - // Write the YAML data to the specified file err = writeFile(composeFilePath, yamlData, 0644) if err != nil { return fmt.Errorf("error writing docker compose file: %w", err) @@ -205,9 +194,11 @@ func (v *DockerVirt) WriteConfig() error { return nil } -// GetContainerInfo returns a list of information about the Docker containers, including their labels +// GetContainerInfo retrieves detailed information about Docker containers managed by +// Windsor, including their names, IP addresses, and labels. It filters containers +// by Windsor-managed labels and context, and optionally by service name if provided. +// For each container, it retrieves network settings to determine IP addresses. func (v *DockerVirt) GetContainerInfo(name ...string) ([]ContainerInfo, error) { - // Get the context name contextName := v.configHandler.GetContext() command := "docker" @@ -237,7 +228,6 @@ func (v *DockerVirt) GetContainerInfo(name ...string) ([]ContainerInfo, error) { serviceName, _ := labels["com.docker.compose.service"] - // If a name is provided, check if it matches the current serviceName if len(name) > 0 && serviceName != name[0] { continue } @@ -267,7 +257,6 @@ func (v *DockerVirt) GetContainerInfo(name ...string) ([]ContainerInfo, error) { Labels: labels, } - // If a name is provided and matches, return immediately with this containerInfo if len(name) > 0 && serviceName == name[0] { return []ContainerInfo{containerInfo}, nil } @@ -278,7 +267,10 @@ func (v *DockerVirt) GetContainerInfo(name ...string) ([]ContainerInfo, error) { return containerInfos, nil } -// PrintInfo prints the container information +// PrintInfo displays a formatted table of running Docker containers with their names, +// IP addresses, and roles. It retrieves container information using GetContainerInfo +// and presents it in a tabular format for easy reading. If no containers are running, +// it displays an appropriate message. func (v *DockerVirt) PrintInfo() error { containerInfos, err := v.GetContainerInfo() if err != nil { @@ -303,7 +295,9 @@ func (v *DockerVirt) PrintInfo() error { // Ensure DockerVirt implements ContainerRuntime var _ ContainerRuntime = (*DockerVirt)(nil) -// checkDockerDaemon checks if the Docker daemon is running +// checkDockerDaemon verifies that the Docker daemon is running and accessible by +// executing the 'docker info' command. It returns an error if the daemon cannot +// be contacted, which is used by other functions to ensure Docker is available. func (v *DockerVirt) checkDockerDaemon() error { command := "docker" args := []string{"info"} @@ -311,11 +305,12 @@ func (v *DockerVirt) checkDockerDaemon() error { return err } -// getFullComposeConfig builds a Docker Compose configuration for DockerVirt. It retrieves the -// context name and configuration, checks if Docker is defined, and returns nil if not. It sets up -// combined configurations for services, volumes, and networks, defining a network with IPAM if a -// NetworkCIDR is specified. It iterates over services, gathering their configurations and IPs, -// and returns a Project with these combined settings. +// getFullComposeConfig builds a complete Docker Compose configuration by combining +// settings from all services. It creates a network configuration with optional IPAM +// settings based on the network CIDR, collects service configurations with their +// network settings and IP addresses, and aggregates volumes and networks from all +// services into a single project configuration. When DNS is enabled, it configures +// all services to use the DNS service for name resolution. func (v *DockerVirt) getFullComposeConfig() (*types.Project, error) { contextName := v.configHandler.GetContext() @@ -330,7 +325,6 @@ func (v *DockerVirt) getFullComposeConfig() (*types.Project, error) { combinedVolumes = make(map[string]types.VolumeConfig) combinedNetworks = make(map[string]types.NetworkConfig) - // Configure the network networkName := fmt.Sprintf("windsor-%s", contextName) networkConfig := types.NetworkConfig{ @@ -351,7 +345,17 @@ func (v *DockerVirt) getFullComposeConfig() (*types.Project, error) { combinedNetworks[networkName] = networkConfig - // Iterate over each service and collect container configs + var dnsAddress string + dnsEnabled := v.configHandler.GetBool("dns.enabled") + if dnsEnabled { + dnsAddress = v.configHandler.GetString("dns.address") + if dnsAddress == "" { + if dnsService, ok := v.injector.Resolve("dns").(services.Service); ok { + dnsAddress = dnsService.GetAddress() + } + } + } + for _, service := range v.services { if serviceInstance, ok := service.(interface { GetComposeConfig() (*types.Config, error) @@ -376,24 +380,32 @@ func (v *DockerVirt) getFullComposeConfig() (*types.Project, error) { } networkCIDR := v.configHandler.GetString("network.cidr_block") - if networkCIDR != "" && ipAddress != "127.0.0.1" && ipAddress != "" { + if networkCIDR != "" && ipAddress != "" { containerConfig.Networks[networkName].Ipv4Address = ipAddress } + if dnsEnabled && dnsAddress != "" { + if containerConfig.DNS == nil { + containerConfig.DNS = []string{} + } + + dnsExists := slices.Contains(containerConfig.DNS, dnsAddress) + + if !dnsExists { + containerConfig.DNS = append(containerConfig.DNS, dnsAddress) + } + } + combinedServices = append(combinedServices, containerConfig) } } if containerConfigs.Volumes != nil { - for volumeName, volumeConfig := range containerConfigs.Volumes { - combinedVolumes[volumeName] = volumeConfig - } + maps.Copy(combinedVolumes, containerConfigs.Volumes) } if containerConfigs.Networks != nil { - for networkName, networkConfig := range containerConfigs.Networks { - combinedNetworks[networkName] = networkConfig - } + maps.Copy(combinedNetworks, containerConfigs.Networks) } } } diff --git a/pkg/virt/docker_virt_test.go b/pkg/virt/docker_virt_test.go index bb63061c4..8622ddfaf 100644 --- a/pkg/virt/docker_virt_test.go +++ b/pkg/virt/docker_virt_test.go @@ -1346,4 +1346,241 @@ func TestDockerVirt_getFullComposeConfig(t *testing.T) { t.Errorf("expected %d networks, got %d", expectedNetworks, len(project.Networks)) } }) + + t.Run("WithDNS", func(t *testing.T) { + // Setup mock components + mocks := setupSafeDockerContainerMocks() + dockerVirt := NewDockerVirt(mocks.Injector) + dockerVirt.services = []services.Service{} // Initialize empty services slice + dockerVirt.configHandler = mocks.MockConfigHandler // Set the config handler + + // Configure mock behavior for DNS + mocks.MockConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + if key == "docker.enabled" { + return true + } + if key == "dns.enabled" { + return true + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return false + } + + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "network.cidr_block" { + return "10.0.0.0/24" + } + if key == "dns.address" { + return "" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + + mocks.MockConfigHandler.GetContextFunc = func() string { + return "mock-context" + } + + // Setup mock DNS service + mockDNS := services.NewMockService() + mockDNS.GetAddressFunc = func() string { + return "10.0.0.53" + } + mocks.Injector.Register("dns", mockDNS) + + // Call the function + project, err := dockerVirt.getFullComposeConfig() + + // Assertions + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if project == nil { + t.Fatal("expected project to be non-nil") + } + if project.Networks == nil { + t.Fatal("expected networks to be non-nil") + } + + // Check network configuration + networkName := "windsor-mock-context" + network, exists := project.Networks[networkName] + if !exists { + t.Fatalf("expected network %s to exist", networkName) + } + if network.Driver != "bridge" { + t.Errorf("expected driver to be bridge, got %s", network.Driver) + } + if network.Ipam.Config == nil { + t.Fatal("expected Ipam config to be non-nil") + } + if len(network.Ipam.Config) != 1 { + t.Fatalf("expected 1 Ipam config, got %d", len(network.Ipam.Config)) + } + if network.Ipam.Config[0].Subnet != "10.0.0.0/24" { + t.Errorf("expected subnet to be 10.0.0.0/24, got %s", network.Ipam.Config[0].Subnet) + } + }) + + t.Run("Disabled", func(t *testing.T) { + // Setup mock components + mocks := setupSafeDockerContainerMocks() + dockerVirt := NewDockerVirt(mocks.Injector) + dockerVirt.services = []services.Service{} // Initialize empty services slice + dockerVirt.configHandler = mocks.MockConfigHandler // Set the config handler + + // Configure mock behavior + mocks.MockConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + if key == "docker.enabled" { + return false + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return false + } + + // Call the function + project, err := dockerVirt.getFullComposeConfig() + + // Assertions + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if project != nil { + t.Fatal("expected project to be nil") + } + }) + + t.Run("WithServices", func(t *testing.T) { + // Setup mock components + mocks := setupSafeDockerContainerMocks() + dockerVirt := NewDockerVirt(mocks.Injector) + dockerVirt.configHandler = mocks.MockConfigHandler // Set the config handler + + // Configure mock behavior + mocks.MockConfigHandler.GetBoolFunc = func(key string, defaultValue ...bool) bool { + if key == "docker.enabled" { + return true + } + if key == "dns.enabled" { + return true + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return false + } + + mocks.MockConfigHandler.GetStringFunc = func(key string, defaultValue ...string) string { + if key == "network.cidr_block" { + return "10.0.0.0/24" + } + if key == "dns.address" { + return "" + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + + mocks.MockConfigHandler.GetContextFunc = func() string { + return "mock-context" + } + + // Create a mock service + mockService := services.NewMockService() + mockService.GetNameFunc = func() string { + return "test-service" + } + mockService.GetAddressFunc = func() string { + return "10.0.0.2" + } + mockService.GetComposeConfigFunc = func() (*types.Config, error) { + return &types.Config{ + Services: []types.ServiceConfig{ + { + Name: "test-service", + Image: "test-image:latest", + Ports: []types.ServicePortConfig{ + { + Published: "8080", + Target: 80, + }, + }, + Environment: map[string]*string{ + "ENV": ptrString("test"), + }, + Volumes: []types.ServiceVolumeConfig{ + { + Source: "/host:/container", + }, + }, + }, + }, + }, nil + } + + // Register the mock service + mocks.Injector.Register("test-service", mockService) + dockerVirt.services = []services.Service{mockService} + + // Setup mock DNS service + mockDNS := services.NewMockService() + mockDNS.GetAddressFunc = func() string { + return "10.0.0.53" + } + mocks.Injector.Register("dns", mockDNS) + + // Call the function + project, err := dockerVirt.getFullComposeConfig() + + // Assertions + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if project == nil { + t.Fatal("expected project to be non-nil") + } + if len(project.Services) != 1 { + t.Fatalf("expected 1 service, got %d", len(project.Services)) + } + + service := project.Services[0] + if service.Name != "test-service" { + t.Errorf("expected service name to be test-service, got %s", service.Name) + } + if service.Image != "test-image:latest" { + t.Errorf("expected image to be test-image:latest, got %s", service.Image) + } + if len(service.Ports) != 1 { + t.Fatalf("expected 1 port, got %d", len(service.Ports)) + } + if service.Ports[0].Published != "8080" || service.Ports[0].Target != 80 { + t.Errorf("expected port to be 8080:80, got %s:%d", service.Ports[0].Published, service.Ports[0].Target) + } + if service.Environment["ENV"] == nil || *service.Environment["ENV"] != "test" { + t.Errorf("expected environment ENV to be test, got %v", service.Environment["ENV"]) + } + if len(service.Volumes) != 1 { + t.Fatalf("expected 1 volume, got %d", len(service.Volumes)) + } + if service.Volumes[0].Source != "/host:/container" { + t.Errorf("expected volume source to be /host:/container, got %s", service.Volumes[0].Source) + } + if len(service.Networks) != 1 { + t.Fatalf("expected 1 network, got %d", len(service.Networks)) + } + if service.Networks["windsor-mock-context"] == nil { + t.Error("expected network windsor-mock-context to exist") + } + if service.DNS == nil || len(service.DNS) != 1 || service.DNS[0] != "10.0.0.53" { + t.Errorf("expected DNS to be [10.0.0.53], got %v", service.DNS) + } + }) }