From 480c2cbff7e17a1b48936e010d1baec941a16cec Mon Sep 17 00:00:00 2001 From: Ryan VanGundy Date: Thu, 1 May 2025 07:55:55 -0400 Subject: [PATCH] Remove split DNS --- pkg/services/dns_service.go | 5 +- pkg/services/dns_service_test.go | 119 +++++------------------------ pkg/services/talos_service.go | 22 ++---- pkg/services/talos_service_test.go | 61 +++++++++------ 4 files changed, 62 insertions(+), 145 deletions(-) diff --git a/pkg/services/dns_service.go b/pkg/services/dns_service.go index 2f43af35b..769a9746d 100644 --- a/pkg/services/dns_service.go +++ b/pkg/services/dns_service.go @@ -119,7 +119,6 @@ func (s *DNSService) WriteConfig() error { } tld := s.configHandler.GetString("dns.domain", "test") - networkCIDR := s.configHandler.GetString("network.cidr_block") var ( hostEntries string @@ -199,9 +198,7 @@ func (s *DNSService) WriteConfig() error { var corefileContent string if s.isLocalhostMode() { - internalView := fmt.Sprintf(" view internal {\n expr incidr(client_ip(), '%s')\n }\n", networkCIDR) - corefileContent = fmt.Sprintf(serverBlockTemplate, tld, internalView, hostEntries, wildcardEntries, forwardAddressesStr) - corefileContent += fmt.Sprintf(serverBlockTemplate, tld, "", localhostHostEntries, localhostWildcardEntries, forwardAddressesStr) + corefileContent = fmt.Sprintf(serverBlockTemplate, tld, "", localhostHostEntries, localhostWildcardEntries, forwardAddressesStr) } else { corefileContent = fmt.Sprintf(serverBlockTemplate, tld, "", hostEntries, wildcardEntries, forwardAddressesStr) } diff --git a/pkg/services/dns_service_test.go b/pkg/services/dns_service_test.go index 8e604610b..f86248550 100644 --- a/pkg/services/dns_service_test.go +++ b/pkg/services/dns_service_test.go @@ -360,70 +360,30 @@ func TestDNSService_WriteConfig(t *testing.T) { }) t.Run("SuccessLocalhostMode", func(t *testing.T) { - // Given a DNSService with mock components + // Setup service, mocks := setup(t) - - // Set vm.driver to docker-desktop to simulate localhost mode mocks.ConfigHandler.SetContextValue("vm.driver", "docker-desktop") - mocks.ConfigHandler.SetContextValue("dns.domain", "test") - mocks.ConfigHandler.SetContextValue("network.cidr_block", "192.168.1.0/24") - // Create a mock service with a hostname - mockService := NewMockService() - mockService.GetNameFunc = func() string { - return "test-service" - } - mockService.GetAddressFunc = func() string { - return "192.168.1.2" - } - mockService.GetComposeConfigFunc = func() (*types.Config, error) { - return &types.Config{ - Services: []types.ServiceConfig{ - {Name: "test-service"}, - }, - }, nil - } - mockService.GetHostnameFunc = func() string { - return "test-service.test" - } - mockService.SupportsWildcardFunc = func() bool { - return false - } - - // Register the mock service - mocks.Injector.Register("test-service", mockService) - service.services = []Service{mockService} - - // Mock the writeFile function to capture the content written var writtenContent []byte mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { writtenContent = data return nil } - // When WriteConfig is called + service.SetName("test") + service.SetAddress("192.168.1.1") + + // Execute err := service.WriteConfig() - // Then no error should be returned + // Assert if err != nil { - t.Fatalf("WriteConfig() error = %v", err) + t.Errorf("expected no error, got %v", err) } - // Verify that the Corefile content includes both regular and localhost entries content := string(writtenContent) - expectedEntries := []string{ - "192.168.1.2 test-service.test", - "127.0.0.1 test-service.test", - } - for _, entry := range expectedEntries { - if !strings.Contains(content, entry) { - t.Errorf("Expected Corefile to contain entry %q, got:\n%s", entry, content) - } - } - - // Verify that the internal view is present - if !strings.Contains(content, "view internal") { - t.Errorf("Expected Corefile to contain internal view, got:\n%s", content) + if !strings.Contains(content, "127.0.0.1 test") { + t.Errorf("Expected Corefile to contain entry \"127.0.0.1 test\", got:\n%s", content) } }) @@ -659,73 +619,30 @@ func TestDNSService_WriteConfig(t *testing.T) { }) t.Run("SuccessLocalhostModeWithWildcard", func(t *testing.T) { - // Given a DNSService with mock components + // Setup service, mocks := setup(t) - - // Set vm.driver to docker-desktop to simulate localhost mode mocks.ConfigHandler.SetContextValue("vm.driver", "docker-desktop") - mocks.ConfigHandler.SetContextValue("dns.domain", "test") - mocks.ConfigHandler.SetContextValue("network.cidr_block", "192.168.1.0/24") - - // Create a mock service with wildcard support - mockService := NewMockService() - mockService.GetNameFunc = func() string { - return "test-service" - } - mockService.GetAddressFunc = func() string { - return "192.168.1.2" - } - mockService.GetComposeConfigFunc = func() (*types.Config, error) { - return &types.Config{ - Services: []types.ServiceConfig{ - {Name: "test-service"}, - }, - }, nil - } - mockService.GetHostnameFunc = func() string { - return "test-service.test" - } - mockService.SupportsWildcardFunc = func() bool { - return true - } - // Register the mock service - mocks.Injector.Register("test-service", mockService) - service.services = []Service{mockService} - - // Mock the writeFile function to capture the content written var writtenContent []byte mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error { writtenContent = data return nil } - // When WriteConfig is called + service.SetName("test") + service.SetAddress("192.168.1.1") + + // Execute err := service.WriteConfig() - // Then no error should be returned + // Assert if err != nil { - t.Fatalf("WriteConfig() error = %v", err) + t.Errorf("expected no error, got %v", err) } - // Verify that the Corefile content includes both regular and localhost wildcard entries content := string(writtenContent) - expectedWildcardMatches := []string{ - "template IN A", - "match ^(.*)\\.test-service\\.test\\.$", - `answer "{{ .Name }} 60 IN A 192.168.1.2"`, - "fallthrough", - `answer "{{ .Name }} 60 IN A 127.0.0.1"`, - } - for _, expectedMatch := range expectedWildcardMatches { - if !strings.Contains(content, expectedMatch) { - t.Errorf("Expected Corefile to contain %q, got:\n%s", expectedMatch, content) - } - } - - // Verify that the internal view is present - if !strings.Contains(content, "view internal") { - t.Errorf("Expected Corefile to contain internal view, got:\n%s", content) + if !strings.Contains(content, "127.0.0.1 test") { + t.Errorf("Expected Corefile to contain \"127.0.0.1 test\", got:\n%s", content) } }) diff --git a/pkg/services/talos_service.go b/pkg/services/talos_service.go index c528a1bf1..e6bb359ff 100644 --- a/pkg/services/talos_service.go +++ b/pkg/services/talos_service.go @@ -4,7 +4,6 @@ import ( "fmt" "math" "os" - "slices" "strconv" "strings" "sync" @@ -78,7 +77,6 @@ func (s *TalosService) SetAddress(address string) error { return err } - tld := s.configHandler.GetString("dns.domain", "test") nodeType := "workers" if s.mode == "controlplane" { nodeType = "controlplanes" @@ -102,7 +100,12 @@ func (s *TalosService) SetAddress(address string) error { nextAPIPort++ } - if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.endpoint", nodeType, s.name), fmt.Sprintf("%s.%s:%d", s.name, tld, port)); err != nil { + endpointAddress := address + if s.isLocalhostMode() { + endpointAddress = "127.0.0.1" + } + + if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.endpoint", nodeType, s.name), fmt.Sprintf("%s:%d", endpointAddress, port)); err != nil { return err } @@ -288,19 +291,6 @@ func (s *TalosService) GetComposeConfig() (*types.Config, error) { serviceConfig.Ports = ports - dnsAddress := s.configHandler.GetString("dns.address") - if dnsAddress != "" { - if serviceConfig.DNS == nil { - serviceConfig.DNS = []string{} - } - - dnsExists := slices.Contains(serviceConfig.DNS, dnsAddress) - - if !dnsExists { - serviceConfig.DNS = append(serviceConfig.DNS, dnsAddress) - } - } - volumesMap := map[string]types.VolumeConfig{ strings.ReplaceAll(nodeName+"_system_state", "-", "_"): {}, strings.ReplaceAll(nodeName+"_var", "-", "_"): {}, diff --git a/pkg/services/talos_service_test.go b/pkg/services/talos_service_test.go index 38614e01b..b4735721c 100644 --- a/pkg/services/talos_service_test.go +++ b/pkg/services/talos_service_test.go @@ -162,7 +162,7 @@ func TestTalosService_SetAddress(t *testing.T) { } // And the endpoint should be set correctly - expectedEndpoint := fmt.Sprintf("controlplane1.test:%d", constants.DEFAULT_TALOS_API_PORT) + expectedEndpoint := fmt.Sprintf("127.0.0.1:%d", constants.DEFAULT_TALOS_API_PORT) actualEndpoint := mocks.ConfigHandler.GetString("cluster.controlplanes.nodes.controlplane1.endpoint", "") if actualEndpoint != expectedEndpoint { t.Errorf("expected endpoint %s, got %s", expectedEndpoint, actualEndpoint) @@ -211,7 +211,7 @@ func TestTalosService_SetAddress(t *testing.T) { } // And the endpoint should be set correctly with an incremented port - expectedEndpoint := fmt.Sprintf("controlplane2.test:%d", constants.DEFAULT_TALOS_API_PORT+1) + expectedEndpoint := fmt.Sprintf("127.0.0.1:%d", constants.DEFAULT_TALOS_API_PORT+1) actualEndpoint := mocks.ConfigHandler.GetString("cluster.controlplanes.nodes.controlplane2.endpoint", "") if actualEndpoint != expectedEndpoint { t.Errorf("expected endpoint %s, got %s", expectedEndpoint, actualEndpoint) @@ -249,8 +249,8 @@ func TestTalosService_SetAddress(t *testing.T) { t.Fatalf("expected no error, got %v", err) } - // And the endpoint should be set correctly with an incremented port - expectedEndpoint := fmt.Sprintf("worker1.test:%d", constants.DEFAULT_TALOS_API_PORT+1) + // And the endpoint should be set correctly + expectedEndpoint := fmt.Sprintf("127.0.0.1:%d", constants.DEFAULT_TALOS_API_PORT+1) actualEndpoint := mocks.ConfigHandler.GetString("cluster.workers.nodes.worker1.endpoint", "") if actualEndpoint != expectedEndpoint { t.Errorf("expected endpoint %s, got %s", expectedEndpoint, actualEndpoint) @@ -471,7 +471,7 @@ func TestTalosService_SetAddress(t *testing.T) { } // And the endpoint should use the custom TLD - expectedEndpoint := fmt.Sprintf("worker1.custom.local:%d", constants.DEFAULT_TALOS_API_PORT+1) + expectedEndpoint := fmt.Sprintf("127.0.0.1:%d", constants.DEFAULT_TALOS_API_PORT+1) actualEndpoint := mocks.ConfigHandler.GetString("cluster.workers.nodes.worker1.endpoint", "") if actualEndpoint != expectedEndpoint { t.Errorf("expected endpoint %s, got %s", expectedEndpoint, actualEndpoint) @@ -1271,28 +1271,45 @@ contexts: t.Run("SuccessWithDNS", func(t *testing.T) { // Given a TalosService with mock components - service, mocks := setup(t) + mocks := setupTalosServiceMocks(t) - // And DNS configuration - if err := mocks.ConfigHandler.SetContextValue("dns.address", "8.8.8.8"); err != nil { - t.Fatalf("Failed to set DNS address: %v", err) + // And DNS configuration is set + mocks.ConfigHandler.SetContextValue("dns.domain", "test") + mocks.ConfigHandler.SetContextValue("dns.address", "192.168.1.1") + + // Create a worker node + service := NewTalosService(mocks.Injector, "worker") + service.SetName("worker1") + if err := service.Initialize(); err != nil { + t.Fatalf("Failed to initialize service: %v", err) + } + + // Mock MkdirAll to always succeed + service.shims.MkdirAll = func(path string, perm os.FileMode) error { + return nil } // When GetComposeConfig is called - config, err := service.GetComposeConfig() + cfg, err := service.GetComposeConfig() - // Then there should be no error + // Then no error should be returned if err != nil { t.Fatalf("expected no error, got %v", err) } - // And the DNS configuration should be set correctly - serviceConfig := config.Services[0] - if len(serviceConfig.DNS) != 1 { - t.Errorf("expected 1 DNS server, got %d", len(serviceConfig.DNS)) + // And the config should be valid + if cfg == nil { + t.Fatalf("expected non-nil config, got nil") } - if serviceConfig.DNS[0] != "8.8.8.8" { - t.Errorf("expected DNS server 8.8.8.8, got %s", serviceConfig.DNS[0]) + + // And the service should be configured correctly + if len(cfg.Services) != 1 { + t.Fatalf("expected 1 service, got %d", len(cfg.Services)) + } + + // And the service should have the correct name + if cfg.Services[0].Name != "worker1" { + t.Errorf("expected service name 'worker1', got '%s'", cfg.Services[0].Name) } }) @@ -1655,13 +1672,9 @@ contexts: t.Fatalf("expected no error, got %v", err) } - // And the DNS configuration should have the address only once - serviceConfig := config.Services[0] - if len(serviceConfig.DNS) != 1 { - t.Errorf("expected 1 DNS server, got %d", len(serviceConfig.DNS)) - } - if serviceConfig.DNS[0] != "8.8.8.8" { - t.Errorf("expected DNS server 8.8.8.8, got %s", serviceConfig.DNS[0]) + // And the service should be configured correctly + if len(config.Services) != 1 { + t.Fatalf("expected 1 service, got %d", len(config.Services)) } })