Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions pkg/services/dns_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ func (s *DNSService) WriteConfig() error {
}

tld := s.configHandler.GetString("dns.domain", "test")
networkCIDR := s.configHandler.GetString("network.cidr_block")

var (
hostEntries string
Expand Down Expand Up @@ -199,9 +198,7 @@ func (s *DNSService) WriteConfig() error {

var corefileContent string
if s.isLocalhostMode() {
internalView := fmt.Sprintf(" view internal {\n expr incidr(client_ip(), '%s')\n }\n", networkCIDR)
corefileContent = fmt.Sprintf(serverBlockTemplate, tld, internalView, hostEntries, wildcardEntries, forwardAddressesStr)
corefileContent += fmt.Sprintf(serverBlockTemplate, tld, "", localhostHostEntries, localhostWildcardEntries, forwardAddressesStr)
corefileContent = fmt.Sprintf(serverBlockTemplate, tld, "", localhostHostEntries, localhostWildcardEntries, forwardAddressesStr)
} else {
corefileContent = fmt.Sprintf(serverBlockTemplate, tld, "", hostEntries, wildcardEntries, forwardAddressesStr)
}
Expand Down
119 changes: 18 additions & 101 deletions pkg/services/dns_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,70 +360,30 @@ func TestDNSService_WriteConfig(t *testing.T) {
})

t.Run("SuccessLocalhostMode", func(t *testing.T) {
// Given a DNSService with mock components
// Setup
service, mocks := setup(t)

// Set vm.driver to docker-desktop to simulate localhost mode
mocks.ConfigHandler.SetContextValue("vm.driver", "docker-desktop")
mocks.ConfigHandler.SetContextValue("dns.domain", "test")
mocks.ConfigHandler.SetContextValue("network.cidr_block", "192.168.1.0/24")

// Create a mock service with a hostname
mockService := NewMockService()
mockService.GetNameFunc = func() string {
return "test-service"
}
mockService.GetAddressFunc = func() string {
return "192.168.1.2"
}
mockService.GetComposeConfigFunc = func() (*types.Config, error) {
return &types.Config{
Services: []types.ServiceConfig{
{Name: "test-service"},
},
}, nil
}
mockService.GetHostnameFunc = func() string {
return "test-service.test"
}
mockService.SupportsWildcardFunc = func() bool {
return false
}

// Register the mock service
mocks.Injector.Register("test-service", mockService)
service.services = []Service{mockService}

// Mock the writeFile function to capture the content written
var writtenContent []byte
mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error {
writtenContent = data
return nil
}

// When WriteConfig is called
service.SetName("test")
service.SetAddress("192.168.1.1")

// Execute
err := service.WriteConfig()

// Then no error should be returned
// Assert
if err != nil {
t.Fatalf("WriteConfig() error = %v", err)
t.Errorf("expected no error, got %v", err)
}

// Verify that the Corefile content includes both regular and localhost entries
content := string(writtenContent)
expectedEntries := []string{
"192.168.1.2 test-service.test",
"127.0.0.1 test-service.test",
}
for _, entry := range expectedEntries {
if !strings.Contains(content, entry) {
t.Errorf("Expected Corefile to contain entry %q, got:\n%s", entry, content)
}
}

// Verify that the internal view is present
if !strings.Contains(content, "view internal") {
t.Errorf("Expected Corefile to contain internal view, got:\n%s", content)
if !strings.Contains(content, "127.0.0.1 test") {
t.Errorf("Expected Corefile to contain entry \"127.0.0.1 test\", got:\n%s", content)
}
})

Expand Down Expand Up @@ -659,73 +619,30 @@ func TestDNSService_WriteConfig(t *testing.T) {
})

t.Run("SuccessLocalhostModeWithWildcard", func(t *testing.T) {
// Given a DNSService with mock components
// Setup
service, mocks := setup(t)

// Set vm.driver to docker-desktop to simulate localhost mode
mocks.ConfigHandler.SetContextValue("vm.driver", "docker-desktop")
mocks.ConfigHandler.SetContextValue("dns.domain", "test")
mocks.ConfigHandler.SetContextValue("network.cidr_block", "192.168.1.0/24")

// Create a mock service with wildcard support
mockService := NewMockService()
mockService.GetNameFunc = func() string {
return "test-service"
}
mockService.GetAddressFunc = func() string {
return "192.168.1.2"
}
mockService.GetComposeConfigFunc = func() (*types.Config, error) {
return &types.Config{
Services: []types.ServiceConfig{
{Name: "test-service"},
},
}, nil
}
mockService.GetHostnameFunc = func() string {
return "test-service.test"
}
mockService.SupportsWildcardFunc = func() bool {
return true
}

// Register the mock service
mocks.Injector.Register("test-service", mockService)
service.services = []Service{mockService}

// Mock the writeFile function to capture the content written
var writtenContent []byte
mocks.Shims.WriteFile = func(filename string, data []byte, perm os.FileMode) error {
writtenContent = data
return nil
}

// When WriteConfig is called
service.SetName("test")
service.SetAddress("192.168.1.1")

// Execute
err := service.WriteConfig()

// Then no error should be returned
// Assert
if err != nil {
t.Fatalf("WriteConfig() error = %v", err)
t.Errorf("expected no error, got %v", err)
}

// Verify that the Corefile content includes both regular and localhost wildcard entries
content := string(writtenContent)
expectedWildcardMatches := []string{
"template IN A",
"match ^(.*)\\.test-service\\.test\\.$",
`answer "{{ .Name }} 60 IN A 192.168.1.2"`,
"fallthrough",
`answer "{{ .Name }} 60 IN A 127.0.0.1"`,
}
for _, expectedMatch := range expectedWildcardMatches {
if !strings.Contains(content, expectedMatch) {
t.Errorf("Expected Corefile to contain %q, got:\n%s", expectedMatch, content)
}
}

// Verify that the internal view is present
if !strings.Contains(content, "view internal") {
t.Errorf("Expected Corefile to contain internal view, got:\n%s", content)
if !strings.Contains(content, "127.0.0.1 test") {
t.Errorf("Expected Corefile to contain \"127.0.0.1 test\", got:\n%s", content)
}
})

Expand Down
22 changes: 6 additions & 16 deletions pkg/services/talos_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"math"
"os"
"slices"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -78,7 +77,6 @@ func (s *TalosService) SetAddress(address string) error {
return err
}

tld := s.configHandler.GetString("dns.domain", "test")
nodeType := "workers"
if s.mode == "controlplane" {
nodeType = "controlplanes"
Expand All @@ -102,7 +100,12 @@ func (s *TalosService) SetAddress(address string) error {
nextAPIPort++
}

if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.endpoint", nodeType, s.name), fmt.Sprintf("%s.%s:%d", s.name, tld, port)); err != nil {
endpointAddress := address
if s.isLocalhostMode() {
endpointAddress = "127.0.0.1"
}

if err := s.configHandler.SetContextValue(fmt.Sprintf("cluster.%s.nodes.%s.endpoint", nodeType, s.name), fmt.Sprintf("%s:%d", endpointAddress, port)); err != nil {
return err
}

Expand Down Expand Up @@ -288,19 +291,6 @@ func (s *TalosService) GetComposeConfig() (*types.Config, error) {

serviceConfig.Ports = ports

dnsAddress := s.configHandler.GetString("dns.address")
if dnsAddress != "" {
if serviceConfig.DNS == nil {
serviceConfig.DNS = []string{}
}

dnsExists := slices.Contains(serviceConfig.DNS, dnsAddress)

if !dnsExists {
serviceConfig.DNS = append(serviceConfig.DNS, dnsAddress)
}
}

volumesMap := map[string]types.VolumeConfig{
strings.ReplaceAll(nodeName+"_system_state", "-", "_"): {},
strings.ReplaceAll(nodeName+"_var", "-", "_"): {},
Expand Down
61 changes: 37 additions & 24 deletions pkg/services/talos_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func TestTalosService_SetAddress(t *testing.T) {
}

// And the endpoint should be set correctly
expectedEndpoint := fmt.Sprintf("controlplane1.test:%d", constants.DEFAULT_TALOS_API_PORT)
expectedEndpoint := fmt.Sprintf("127.0.0.1:%d", constants.DEFAULT_TALOS_API_PORT)
actualEndpoint := mocks.ConfigHandler.GetString("cluster.controlplanes.nodes.controlplane1.endpoint", "")
if actualEndpoint != expectedEndpoint {
t.Errorf("expected endpoint %s, got %s", expectedEndpoint, actualEndpoint)
Expand Down Expand Up @@ -211,7 +211,7 @@ func TestTalosService_SetAddress(t *testing.T) {
}

// And the endpoint should be set correctly with an incremented port
expectedEndpoint := fmt.Sprintf("controlplane2.test:%d", constants.DEFAULT_TALOS_API_PORT+1)
expectedEndpoint := fmt.Sprintf("127.0.0.1:%d", constants.DEFAULT_TALOS_API_PORT+1)
actualEndpoint := mocks.ConfigHandler.GetString("cluster.controlplanes.nodes.controlplane2.endpoint", "")
if actualEndpoint != expectedEndpoint {
t.Errorf("expected endpoint %s, got %s", expectedEndpoint, actualEndpoint)
Expand Down Expand Up @@ -249,8 +249,8 @@ func TestTalosService_SetAddress(t *testing.T) {
t.Fatalf("expected no error, got %v", err)
}

// And the endpoint should be set correctly with an incremented port
expectedEndpoint := fmt.Sprintf("worker1.test:%d", constants.DEFAULT_TALOS_API_PORT+1)
// And the endpoint should be set correctly
expectedEndpoint := fmt.Sprintf("127.0.0.1:%d", constants.DEFAULT_TALOS_API_PORT+1)
actualEndpoint := mocks.ConfigHandler.GetString("cluster.workers.nodes.worker1.endpoint", "")
if actualEndpoint != expectedEndpoint {
t.Errorf("expected endpoint %s, got %s", expectedEndpoint, actualEndpoint)
Expand Down Expand Up @@ -471,7 +471,7 @@ func TestTalosService_SetAddress(t *testing.T) {
}

// And the endpoint should use the custom TLD
expectedEndpoint := fmt.Sprintf("worker1.custom.local:%d", constants.DEFAULT_TALOS_API_PORT+1)
expectedEndpoint := fmt.Sprintf("127.0.0.1:%d", constants.DEFAULT_TALOS_API_PORT+1)
actualEndpoint := mocks.ConfigHandler.GetString("cluster.workers.nodes.worker1.endpoint", "")
if actualEndpoint != expectedEndpoint {
t.Errorf("expected endpoint %s, got %s", expectedEndpoint, actualEndpoint)
Expand Down Expand Up @@ -1271,28 +1271,45 @@ contexts:

t.Run("SuccessWithDNS", func(t *testing.T) {
// Given a TalosService with mock components
service, mocks := setup(t)
mocks := setupTalosServiceMocks(t)

// And DNS configuration
if err := mocks.ConfigHandler.SetContextValue("dns.address", "8.8.8.8"); err != nil {
t.Fatalf("Failed to set DNS address: %v", err)
// And DNS configuration is set
mocks.ConfigHandler.SetContextValue("dns.domain", "test")
mocks.ConfigHandler.SetContextValue("dns.address", "192.168.1.1")

// Create a worker node
service := NewTalosService(mocks.Injector, "worker")
service.SetName("worker1")
if err := service.Initialize(); err != nil {
t.Fatalf("Failed to initialize service: %v", err)
}

// Mock MkdirAll to always succeed
service.shims.MkdirAll = func(path string, perm os.FileMode) error {
return nil
}

// When GetComposeConfig is called
config, err := service.GetComposeConfig()
cfg, err := service.GetComposeConfig()

// Then there should be no error
// Then no error should be returned
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

// And the DNS configuration should be set correctly
serviceConfig := config.Services[0]
if len(serviceConfig.DNS) != 1 {
t.Errorf("expected 1 DNS server, got %d", len(serviceConfig.DNS))
// And the config should be valid
if cfg == nil {
t.Fatalf("expected non-nil config, got nil")
}
if serviceConfig.DNS[0] != "8.8.8.8" {
t.Errorf("expected DNS server 8.8.8.8, got %s", serviceConfig.DNS[0])

// And the service should be configured correctly
if len(cfg.Services) != 1 {
t.Fatalf("expected 1 service, got %d", len(cfg.Services))
}

// And the service should have the correct name
if cfg.Services[0].Name != "worker1" {
t.Errorf("expected service name 'worker1', got '%s'", cfg.Services[0].Name)
}
})

Expand Down Expand Up @@ -1655,13 +1672,9 @@ contexts:
t.Fatalf("expected no error, got %v", err)
}

// And the DNS configuration should have the address only once
serviceConfig := config.Services[0]
if len(serviceConfig.DNS) != 1 {
t.Errorf("expected 1 DNS server, got %d", len(serviceConfig.DNS))
}
if serviceConfig.DNS[0] != "8.8.8.8" {
t.Errorf("expected DNS server 8.8.8.8, got %s", serviceConfig.DNS[0])
// And the service should be configured correctly
if len(config.Services) != 1 {
t.Fatalf("expected 1 service, got %d", len(config.Services))
}
})

Expand Down
Loading