diff --git a/pkg/services/dns_service.go b/pkg/services/dns_service.go index 6cda0ae64..d18db03bb 100644 --- a/pkg/services/dns_service.go +++ b/pkg/services/dns_service.go @@ -109,11 +109,11 @@ func (s *DNSService) GetComposeConfig() (*types.Config, error) { return &types.Config{Services: services}, nil } -// WriteConfig generates a Corefile by collecting the project root directory, top-level domain (TLD), and IP addresses. -// It adds DNS entries for each service, ensuring that each service's hostname resolves to its IP address. -// For localhost environments, it uses a specific DNS template to handle local DNS resolution and sets up forwarding -// rules to direct DNS queries to the appropriate addresses. -// The Corefile is saved in the .windsor directory, which is used by CoreDNS to manage DNS queries for the project. +// WriteConfig generates and writes a CoreDNS Corefile for the Windsor project. +// It collects the project root directory, top-level domain (TLD), and service IP addresses. +// For each service, it adds DNS entries mapping hostnames to IP addresses, and includes wildcard DNS entries if supported. +// In localhost mode, it uses a template for local DNS resolution and sets up forwarding rules for DNS queries. +// The generated Corefile is saved in the .windsor directory for CoreDNS to manage project DNS queries. func (s *DNSService) WriteConfig() error { projectRoot, err := s.shell.GetProjectRoot() if err != nil { @@ -217,6 +217,12 @@ func (s *DNSService) WriteConfig() error { return fmt.Errorf("error creating parent folders: %w", err) } + if stat, err := s.shims.Stat(corefilePath); err == nil && stat.IsDir() { + if err := s.shims.RemoveAll(corefilePath); err != nil { + return fmt.Errorf("error removing Corefile directory: %w", err) + } + } + if err := s.shims.WriteFile(corefilePath, []byte(corefileContent), 0644); err != nil { return fmt.Errorf("error writing Corefile: %w", err) } diff --git a/pkg/services/dns_service_test.go b/pkg/services/dns_service_test.go index ea14829ac..a4f183ca6 100644 --- a/pkg/services/dns_service_test.go +++ b/pkg/services/dns_service_test.go @@ -681,6 +681,50 @@ func TestDNSService_WriteConfig(t *testing.T) { t.Errorf("Expected error to contain 'error retrieving project root', got: %v", err) } }) + + t.Run("SuccessRemovingCorefileDirectory", func(t *testing.T) { + // Given a DNSService with mock components + service, mocks := setup(t) + + var removedPath string + var statCalled bool + + // Mock Stat to return a directory + mocks.Shims.Stat = func(name string) (os.FileInfo, error) { + statCalled = true + if strings.Contains(name, "Corefile") { + return &mockFileInfo{isDir: true}, nil + } + return &mockFileInfo{isDir: false}, nil + } + + // Mock RemoveAll to capture the removed path + mocks.Shims.RemoveAll = func(path string) error { + removedPath = path + return nil + } + + // When WriteConfig is called + err := service.WriteConfig() + + // Then no error should be returned + if err != nil { + t.Fatalf("WriteConfig() error = %v", err) + } + + // And Stat should have been called + if !statCalled { + t.Error("Expected Stat to be called") + } + + // And RemoveAll should have been called with the Corefile path + if removedPath == "" { + t.Error("Expected RemoveAll to be called") + } + if !strings.Contains(removedPath, "Corefile") { + t.Errorf("Expected RemoveAll to be called with Corefile path, got: %s", removedPath) + } + }) } func TestDNSService_SetName(t *testing.T) { diff --git a/pkg/services/service_test.go b/pkg/services/service_test.go index aa736d554..393e7e9a8 100644 --- a/pkg/services/service_test.go +++ b/pkg/services/service_test.go @@ -3,6 +3,7 @@ package services import ( "os" "testing" + "time" "github.com/windsorcli/cli/pkg/config" "github.com/windsorcli/cli/pkg/di" @@ -18,6 +19,18 @@ import ( // Test Setup // ============================================================================= +// mockFileInfo implements os.FileInfo for testing +type mockFileInfo struct { + isDir bool +} + +func (m *mockFileInfo) Name() string { return "mockfile" } +func (m *mockFileInfo) Size() int64 { return 0 } +func (m *mockFileInfo) Mode() os.FileMode { return 0644 } +func (m *mockFileInfo) ModTime() time.Time { return time.Now() } +func (m *mockFileInfo) IsDir() bool { return m.isDir } +func (m *mockFileInfo) Sys() interface{} { return nil } + type Mocks struct { Injector di.Injector ConfigHandler config.ConfigHandler @@ -45,7 +58,8 @@ func setupShims(t *testing.T) *Shims { return nil } shims.Stat = func(name string) (os.FileInfo, error) { - return nil, nil + // Return a mock file info that indicates it's not a directory + return &mockFileInfo{isDir: false}, nil } shims.Mkdir = func(path string, perm os.FileMode) error { return nil @@ -53,6 +67,9 @@ func setupShims(t *testing.T) *Shims { shims.MkdirAll = func(path string, perm os.FileMode) error { return nil } + shims.RemoveAll = func(path string) error { + return nil + } shims.Rename = func(oldpath, newpath string) error { return nil } diff --git a/pkg/services/shims.go b/pkg/services/shims.go index 9809308aa..6ba164d25 100644 --- a/pkg/services/shims.go +++ b/pkg/services/shims.go @@ -25,6 +25,7 @@ type Shims struct { Stat func(name string) (os.FileInfo, error) Mkdir func(path string, perm os.FileMode) error MkdirAll func(path string, perm os.FileMode) error + RemoveAll func(path string) error Rename func(oldpath, newpath string) error YamlMarshal func(in any) ([]byte, error) YamlUnmarshal func(in []byte, out any) error @@ -41,6 +42,7 @@ func NewShims() *Shims { Stat: os.Stat, Mkdir: os.Mkdir, MkdirAll: os.MkdirAll, + RemoveAll: os.RemoveAll, Rename: os.Rename, YamlMarshal: yaml.Marshal, YamlUnmarshal: yaml.Unmarshal,