From 09cef9dee158ef8237adfbca6573c6b56d5cecda Mon Sep 17 00:00:00 2001 From: Andreas Linde Date: Sun, 28 Dec 2025 15:20:42 +0100 Subject: [PATCH 1/3] Bump module to v3 Breaking change release targeting: - Proper interfaces for testability - Dependency injection - Generated mocks - Improved test coverage - Removal of global state --- examples/proxyservice/server.go | 2 +- examples/register/server.go | 2 +- examples/resolv/client.go | 2 +- go.mod | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/proxyservice/server.go b/examples/proxyservice/server.go index a601f52..4ad27a1 100644 --- a/examples/proxyservice/server.go +++ b/examples/proxyservice/server.go @@ -9,7 +9,7 @@ import ( "time" - "github.com/enbility/zeroconf/v2" + "github.com/enbility/zeroconf/v3" ) var ( diff --git a/examples/register/server.go b/examples/register/server.go index eefb72e..e37755f 100644 --- a/examples/register/server.go +++ b/examples/register/server.go @@ -9,7 +9,7 @@ import ( "time" - "github.com/enbility/zeroconf/v2" + "github.com/enbility/zeroconf/v3" ) var ( diff --git a/examples/resolv/client.go b/examples/resolv/client.go index f435ac5..6a3b3d3 100644 --- a/examples/resolv/client.go +++ b/examples/resolv/client.go @@ -6,7 +6,7 @@ import ( "log" "time" - "github.com/enbility/zeroconf/v2" + "github.com/enbility/zeroconf/v3" ) var ( diff --git a/go.mod b/go.mod index 599b792..9173cd8 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/enbility/zeroconf/v2 +module github.com/enbility/zeroconf/v3 go 1.22.0 From 63258a6bd561b42fcef3315d44c8a4eb21e4caa2 Mon Sep 17 00:00:00 2001 From: Andreas Linde Date: Sun, 28 Dec 2025 19:27:11 +0100 Subject: [PATCH 2/3] feat: v3 refactoring with interface-based testability - Add api/ package with PacketConn, ConnectionFactory, InterfaceProvider interfaces - Add mocks/ package with mockery-generated mocks for testing - Split connection.go into conn_factory.go, conn_ipv4.go, conn_ipv6.go, conn_provider.go - Rename connection.go to mdns.go (contains only mDNS constants) - Export Client type and add NewClient() constructor - Add WithClientConnFactory and WithServerConnFactory options for mock injection - Remove deprecated Server.TTL() method - Add comprehensive unit tests (87.6% coverage) - Update README with v3 examples and testing documentation - Bump version to v3.0.0 Breaking changes: - Browse() now requires a 'removed' channel parameter - Module path is github.com/enbility/zeroconf/v3 --- .gitignore | 3 + .mockery.yml | 20 + README.md | 63 ++- V3_REFACTORING_PLAN.md | 248 +++++++++++ api/interfaces.go | 57 +++ client.go | 122 +++--- client_unit_test.go | 588 +++++++++++++++++++++++++ conn_factory.go | 72 +++ conn_ipv4.go | 80 ++++ conn_ipv6.go | 80 ++++ conn_provider.go | 37 ++ connection.go | 119 ----- go.mod | 5 + go.sum | 12 + mdns.go | 30 ++ mocks/mock_connection_factory.go | 163 +++++++ mocks/mock_interface_provider.go | 84 ++++ mocks/mock_packet_conn.go | 495 +++++++++++++++++++++ server.go | 178 +++----- server_unit_test.go | 728 +++++++++++++++++++++++++++++++ version.json | 2 +- 21 files changed, 2867 insertions(+), 319 deletions(-) create mode 100644 .mockery.yml create mode 100644 V3_REFACTORING_PLAN.md create mode 100644 api/interfaces.go create mode 100644 client_unit_test.go create mode 100644 conn_factory.go create mode 100644 conn_ipv4.go create mode 100644 conn_ipv6.go create mode 100644 conn_provider.go delete mode 100644 connection.go create mode 100644 mdns.go create mode 100644 mocks/mock_connection_factory.go create mode 100644 mocks/mock_interface_provider.go create mode 100644 mocks/mock_packet_conn.go create mode 100644 server_unit_test.go diff --git a/.gitignore b/.gitignore index daf913b..966dc06 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ _testmain.go *.exe *.test *.prof + +# Coverage +coverage.out diff --git a/.mockery.yml b/.mockery.yml new file mode 100644 index 0000000..5b97245 --- /dev/null +++ b/.mockery.yml @@ -0,0 +1,20 @@ +all: false +dir: 'mocks' +filename: 'mock_{{.InterfaceName | snakecase}}.go' +force-file-write: true +formatter: goimports +generate: true +include-auto-generated: false +log-level: info +structname: 'Mock{{.InterfaceName}}' +pkgname: 'mocks' +recursive: false +require-template-schema-exists: true +template: testify +template-schema: '{{.Template}}.schema.json' +packages: + github.com/enbility/zeroconf/v3/api: + interfaces: + PacketConn: + ConnectionFactory: + InterfaceProvider: diff --git a/README.md b/README.md index eb36503..a223bb6 100644 --- a/README.md +++ b/README.md @@ -22,24 +22,30 @@ Target environments: private LAN/Wifi, small or isolated networks. ## Install Nothing is as easy as that: ```bash -$ go get -u github.com/enbility/zeroconf/v2 +$ go get -u github.com/enbility/zeroconf/v3 ``` ## Browse for services in your local network ```go entries := make(chan *zeroconf.ServiceEntry) -go func(results <-chan *zeroconf.ServiceEntry) { - for entry := range results { - log.Println(entry) +removed := make(chan *zeroconf.ServiceEntry) + +go func() { + for { + select { + case entry := <-entries: + log.Println("Found:", entry) + case entry := <-removed: + log.Println("Removed:", entry) + } } - log.Println("No more entries.") -}(entries) +}() ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() // Discover all services on the network (e.g. _workstation._tcp) -err = zeroconf.Browse(ctx, "_workstation._tcp", "local.", entries) +err := zeroconf.Browse(ctx, "_workstation._tcp", "local.", entries, removed) if err != nil { log.Fatalln("Failed to browse:", err.Error()) } @@ -53,7 +59,23 @@ See https://github.com/enbility/zeroconf/blob/master/examples/resolv/client.go. ## Lookup a specific service instance ```go -// Example filled soon. +entries := make(chan *zeroconf.ServiceEntry) + +go func() { + for entry := range entries { + log.Println("Found:", entry) + } +}() + +ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) +defer cancel() +// Lookup a specific service instance by name +err := zeroconf.Lookup(ctx, "MyService", "_workstation._tcp", "local.", entries) +if err != nil { + log.Fatalln("Failed to lookup:", err.Error()) +} + +<-ctx.Done() ``` ## Register a service @@ -81,6 +103,29 @@ Multiple subtypes may be added to service name, separated by commas. E.g `_works See https://github.com/enbility/zeroconf/blob/master/examples/register/server.go. +## Testing Support (v3) + +Version 3 introduces interface-based abstractions for improved testability. You can inject mock connections for unit testing without requiring real network access: + +```go +// Create mock connections using the provided interfaces +mockFactory := &MyMockConnectionFactory{} + +// Client with mock connections +client, err := zeroconf.NewClient(zeroconf.WithClientConnFactory(mockFactory)) + +// Server with mock connections +server, err := zeroconf.RegisterProxy( + "MyService", "_http._tcp", "local.", 8080, + "myhost.local.", []string{"192.168.1.100"}, + []string{"txtvers=1"}, + nil, // interfaces + zeroconf.WithServerConnFactory(mockFactory), +) +``` + +See the `api/` package for interface definitions and `mocks/` for mockery-generated mocks. + ## Features and ToDo's This list gives a quick impression about the state of this library. See what needs to be done and submit a pull request :) @@ -89,6 +134,8 @@ See what needs to be done and submit a pull request :) * [x] Multiple IPv6 / IPv4 addresses support * [x] Send multiple probes (exp. back-off) if no service answers (*) * [x] Timestamp entries for TTL checks +* [x] Service removal notifications via `removed` channel +* [x] Interface-based abstractions for testability (v3) * [ ] Compare new multicasts with already received services _Notes:_ diff --git a/V3_REFACTORING_PLAN.md b/V3_REFACTORING_PLAN.md new file mode 100644 index 0000000..99884df --- /dev/null +++ b/V3_REFACTORING_PLAN.md @@ -0,0 +1,248 @@ +# ZeroConf v3 Refactoring Plan + +## Goals + +1. **Testability**: Enable unit testing without real network access +2. **Interfaces**: Define clear abstractions at network boundaries +3. **Dependency Injection**: Allow mock injection for testing +4. **Test Coverage**: Target 85%+ coverage with meaningful unit tests +5. **Generated Mocks**: Use mockery for maintainable mocks +6. **Remove Global State**: Move package-level vars into config structs + +## Key Insight: ControlMessage Simplification + +Analysis of the codebase shows that only `IfIndex` is ever used from `ipv4.ControlMessage` and `ipv6.ControlMessage`. This allows us to create a unified `PacketConn` interface that works for both IPv4 and IPv6: + +```go +// Instead of exposing ControlMessage, we just expose ifIndex +ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) +WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) +``` + +## Architecture Overview + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Public API │ +│ Browse() / Lookup() / Register() / RegisterProxy() │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Client / Server │ +│ - Use api.PacketConn interface (not concrete types) │ +│ - Accept ConnectionFactory via options │ +│ - Use InterfaceProvider internally for default interfaces │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ api/ Package │ +│ PacketConn / ConnectionFactory / InterfaceProvider │ +└─────────────────────────────────────────────────────────────┘ + │ + ┌───────────────┴───────────────┐ + ▼ ▼ +┌──────────────────────────┐ ┌──────────────────────────┐ +│ Real Implementations │ │ mocks/ Package │ +│ - ipv4PacketConn │ │ - MockPacketConn │ +│ - ipv6PacketConn │ │ - MockConnectionFactory │ +│ - defaultConnFactory │ │ - MockInterfaceProvider │ +│ - defaultIfaceProvider │ │ (generated by mockery) │ +└──────────────────────────┘ └──────────────────────────┘ +``` + +## Package Structure + +``` +zeroconf/v3/ +├── api/ # Pure interfaces (no internal deps) +│ └── interfaces.go # PacketConn, ConnectionFactory, InterfaceProvider +├── mocks/ # Generated mocks (mockery) +│ ├── mock_packet_conn.go +│ ├── mock_connection_factory.go +│ └── mock_interface_provider.go +├── .mockery.yml # Mockery configuration +├── client.go # Client implementation +├── server.go # Server implementation +├── conn_ipv4.go # ipv4PacketConn wrapper +├── conn_ipv6.go # ipv6PacketConn wrapper +├── conn_factory.go # defaultConnectionFactory +├── conn_provider.go # defaultInterfaceProvider +├── mdns.go # Network constants (mDNS addresses) +├── service.go # ServiceEntry, ServiceRecord +├── utils.go # Helper functions +├── doc.go # Package documentation +├── *_test.go # Tests +└── examples/ # Example applications +``` + +## Interface Definitions (in api/interfaces.go) + +### PacketConn + +```go +type PacketConn interface { + ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) + WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) + Close() error + JoinGroup(ifi *net.Interface, group net.Addr) error + LeaveGroup(ifi *net.Interface, group net.Addr) error + SetMulticastTTL(ttl int) error + SetMulticastHopLimit(hopLimit int) error + SetMulticastInterface(ifi *net.Interface) error +} +``` + +### ConnectionFactory + +```go +type ConnectionFactory interface { + CreateIPv4Conn(ifaces []net.Interface) (PacketConn, error) + CreateIPv6Conn(ifaces []net.Interface) (PacketConn, error) +} +``` + +### InterfaceProvider + +```go +type InterfaceProvider interface { + MulticastInterfaces() []net.Interface +} +``` + +## Implementation Phases + +### Phase 1: Package Structure & Interfaces ✓ + +**Completed:** +- [x] Create `api/` package with interface definitions +- [x] Configure mockery (`.mockery.yml`) +- [x] Generate mocks in `mocks/` package +- [x] Update implementation to import `api/` +- [x] Update tests to use `mocks/` + +--- + +### Phase 2: Connection Wrappers (File Split) ✓ + +**Completed:** +- [x] `conn_ipv4.go` - ipv4PacketConn wrapper +- [x] `conn_ipv6.go` - ipv6PacketConn wrapper +- [x] `conn_factory.go` - defaultConnectionFactory +- [x] `conn_provider.go` - defaultInterfaceProvider with MulticastInterfaces() +- [x] Removed `conn_wrapper.go` (split into above files) + +--- + +### Phase 3: InterfaceProvider Implementation ✓ + +**Completed:** +- [x] Created `defaultInterfaceProvider` implementing `api.InterfaceProvider` +- [x] Moved `listMulticastInterfaces()` into `defaultInterfaceProvider.MulticastInterfaces()` +- [x] Used internally via `NewInterfaceProvider().MulticastInterfaces()` in client/server + +**Design Decision:** Removed `WithIfaceProvider` options after review - they added complexity without clear benefit since: +- `Register()` already accepts `ifaces []net.Interface` directly +- `SelectIfaces()` option exists for client +- Interface selection is simpler as a direct parameter than an injected provider + +--- + +### Phase 4: Server Improvements ✓ + +**Changes to `server.go`:** +- [x] Change `Server.ipv4conn` to use `api.PacketConn` +- [x] Change `Server.ipv6conn` to use `api.PacketConn` +- [x] Add `WithServerConnFactory()` option +- [x] Remove deprecated `Server.TTL()` method + +--- + +### Phase 5: Client Improvements ✓ + +**Changes to `client.go`:** +- [x] Change connection fields to use `api.PacketConn` +- [x] Add `WithClientConnFactory()` option +- [x] Export `Client` type (renamed `client` -> `Client`) +- [x] Add `NewClient()` constructor + +--- + +### Phase 6: Coverage & Cleanup + +1. Run coverage report, identify gaps +2. Add tests for untested functions +3. Update doc.go for v3 +4. Final integration test pass + +--- + +## File Changes Summary + +| File | Action | Description | +|------|--------|-------------| +| `api/interfaces.go` | DONE | Interface definitions | +| `mocks/*.go` | DONE | Generated mocks (mockery) | +| `.mockery.yml` | DONE | Mockery configuration | +| `conn_ipv4.go` | NEW | IPv4 PacketConn wrapper | +| `conn_ipv6.go` | NEW | IPv6 PacketConn wrapper | +| `conn_factory.go` | NEW | defaultConnectionFactory | +| `conn_provider.go` | NEW | defaultInterfaceProvider + listMulticastInterfaces | +| `conn_wrapper.go` | DELETE | Split into above files | +| `mdns.go` | RENAME | Network constants (was connection.go) | +| `server.go` | MODIFY | Add WithServerConnFactory, remove TTL() | +| `client.go` | MODIFY | Export Client, add NewClient, add WithClientConnFactory | +| `server_unit_test.go` | DONE | Unit tests with mocks | +| `client_unit_test.go` | DONE | Unit tests with mocks | + +--- + +## Breaking Changes + +1. **Module path**: `github.com/enbility/zeroconf/v3` +2. **Exported `Client` type**: New public API +3. **`NewClient()` function**: New constructor +4. **Removed**: `Server.TTL()` method (was deprecated) + +## Backward Compatibility + +Main API functions remain compatible: +- `Browse(ctx, service, domain, entries, removed, opts...)` - unchanged +- `Lookup(ctx, instance, service, domain, entries, opts...)` - unchanged +- `Register(instance, service, domain, port, text, ifaces, opts...)` - unchanged +- `RegisterProxy(...)` - unchanged + +New optional features via options: +- `WithClientConnFactory(factory)` / `WithServerConnFactory(factory)` - for injecting mock connections in tests + +--- + +## Mock Generation + +Using mockery v3. Configuration in `.mockery.yml`: + +```yaml +packages: + github.com/enbility/zeroconf/v3/api: + interfaces: + PacketConn: + ConnectionFactory: + InterfaceProvider: +``` + +Regenerate mocks: +```bash +mockery +``` + +--- + +## Success Criteria + +- [x] All existing tests pass +- [ ] Test coverage > 85% (currently 72.7%) +- [x] All network I/O behind interfaces +- [x] Unit tests run without network access +- [x] Mocks generated automatically +- [ ] Documentation updated diff --git a/api/interfaces.go b/api/interfaces.go new file mode 100644 index 0000000..141c392 --- /dev/null +++ b/api/interfaces.go @@ -0,0 +1,57 @@ +// Package api defines the core interfaces for the zeroconf library. +// These interfaces enable dependency injection and testing. +package api + +import "net" + +//go:generate mockery + +// PacketConn abstracts IPv4/IPv6 multicast packet connections. +// This interface unifies ipv4.PacketConn and ipv6.PacketConn by extracting +// only the IfIndex from ControlMessage, which is the only field used. +type PacketConn interface { + // ReadFrom reads a packet from the connection. + // Returns the number of bytes read, the interface index the packet arrived on, + // the source address, and any error. + ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) + + // WriteTo writes a packet to the destination address. + // The ifIndex specifies which interface to send from (0 for default/all). + WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) + + // Close closes the connection. + Close() error + + // JoinGroup joins the multicast group on the specified interface. + JoinGroup(ifi *net.Interface, group net.Addr) error + + // LeaveGroup leaves the multicast group on the specified interface. + LeaveGroup(ifi *net.Interface, group net.Addr) error + + // SetMulticastTTL sets the TTL for outgoing multicast packets (IPv4). + SetMulticastTTL(ttl int) error + + // SetMulticastHopLimit sets the hop limit for outgoing multicast packets (IPv6). + SetMulticastHopLimit(hopLimit int) error + + // SetMulticastInterface sets the default interface for outgoing multicast. + // Used as fallback on platforms where ControlMessage is not supported (Windows). + SetMulticastInterface(ifi *net.Interface) error +} + +// ConnectionFactory creates multicast connections. +// This abstraction allows injecting mock connections for testing. +type ConnectionFactory interface { + // CreateIPv4Conn creates an IPv4 multicast connection joined to the mDNS group. + CreateIPv4Conn(ifaces []net.Interface) (PacketConn, error) + + // CreateIPv6Conn creates an IPv6 multicast connection joined to the mDNS group. + CreateIPv6Conn(ifaces []net.Interface) (PacketConn, error) +} + +// InterfaceProvider lists network interfaces. +// This abstraction allows injecting mock interface lists for testing. +type InterfaceProvider interface { + // MulticastInterfaces returns all network interfaces capable of multicast. + MulticastInterfaces() []net.Interface +} diff --git a/client.go b/client.go index 5d84541..bbd5703 100644 --- a/client.go +++ b/client.go @@ -3,17 +3,14 @@ package zeroconf import ( "context" "fmt" - "log" "math/rand" "net" "reflect" - "runtime" "strings" "time" + "github.com/enbility/zeroconf/v3/api" "github.com/miekg/dns" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" ) // IPType specifies the IP traffic the client listens for. @@ -32,15 +29,16 @@ const ( var initialQueryInterval = 4 * time.Second // Client structure encapsulates both IPv4/IPv6 UDP connections. -type client struct { - ipv4conn *ipv4.PacketConn - ipv6conn *ipv6.PacketConn +type Client struct { + ipv4conn api.PacketConn + ipv6conn api.PacketConn ifaces []net.Interface } type clientOpts struct { - listenOn IPType - ifaces []net.Interface + listenOn IPType + ifaces []net.Interface + connFactory api.ConnectionFactory } // ClientOption fills the option struct to configure intefaces, etc. @@ -64,6 +62,14 @@ func SelectIfaces(ifaces []net.Interface) ClientOption { } } +// WithClientConnFactory sets a custom connection factory for the client. +// This is primarily useful for testing with mock connections. +func WithClientConnFactory(factory api.ConnectionFactory) ClientOption { + return func(o *clientOpts) { + o.connFactory = factory + } +} + // Browse for all services of a given type in a given domain. // Received entries are sent on the entries channel. // It blocks until the context is canceled (or an error occurs). @@ -112,7 +118,7 @@ func applyOpts(options ...ClientOption) clientOpts { return conf } -func (c *client) run(ctx context.Context, params *lookupParams) error { +func (c *Client) run(ctx context.Context, params *lookupParams) error { ctx, cancel := context.WithCancel(ctx) done := make(chan struct{}) go func() { @@ -133,32 +139,44 @@ func defaultParams(service string) *lookupParams { return newLookupParams("", service, "local", false, make(chan *ServiceEntry), make(chan *ServiceEntry)) } -// Client structure constructor -func newClient(opts clientOpts) (*client, error) { +// NewClient creates a new mDNS client with the given options. +// This is the low-level constructor. For most use cases, prefer Browse() or Lookup(). +func NewClient(opts ...ClientOption) (*Client, error) { + return newClient(applyOpts(opts...)) +} + +// newClient is the internal constructor that takes pre-applied options. +func newClient(opts clientOpts) (*Client, error) { ifaces := opts.ifaces if len(ifaces) == 0 { - ifaces = listMulticastInterfaces() + ifaces = NewInterfaceProvider().MulticastInterfaces() + } + + factory := opts.connFactory + if factory == nil { + factory = NewConnectionFactory() } + // IPv4 interfaces - var ipv4conn *ipv4.PacketConn + var ipv4conn api.PacketConn if (opts.listenOn & IPv4) > 0 { var err error - ipv4conn, err = joinUdp4Multicast(ifaces) + ipv4conn, err = factory.CreateIPv4Conn(ifaces) if err != nil { return nil, err } } // IPv6 interfaces - var ipv6conn *ipv6.PacketConn + var ipv6conn api.PacketConn if (opts.listenOn & IPv6) > 0 { var err error - ipv6conn, err = joinUdp6Multicast(ifaces) + ipv6conn, err = factory.CreateIPv6Conn(ifaces) if err != nil { return nil, err } } - return &client{ + return &Client{ ipv4conn: ipv4conn, ipv6conn: ipv6conn, ifaces: ifaces, @@ -168,7 +186,7 @@ func newClient(opts clientOpts) (*client, error) { var cleanupFreq = 5 * time.Second // Start listeners and waits for the shutdown signal from exit channel -func (c *client) mainloop(ctx context.Context, params *lookupParams) { +func (c *Client) mainloop(ctx context.Context, params *lookupParams) { // start listening for responses msgCh := make(chan *dns.Msg, 32) if c.ipv4conn != nil { @@ -319,7 +337,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { } // Shutdown client will close currently open connections and channel implicitly. -func (c *client) shutdown() { +func (c *Client) shutdown() { if c.ipv4conn != nil { c.ipv4conn.Close() } @@ -330,22 +348,8 @@ func (c *client) shutdown() { // Data receiving routine reads from connection, unpacks packets into dns.Msg // structures and sends them to a given msgCh channel -func (c *client) recv(ctx context.Context, l interface{}, msgCh chan *dns.Msg) { - var readFrom func([]byte) (n int, src net.Addr, err error) - - switch pConn := l.(type) { - case *ipv6.PacketConn: - readFrom = func(b []byte) (n int, src net.Addr, err error) { - n, _, src, err = pConn.ReadFrom(b) - return - } - case *ipv4.PacketConn: - readFrom = func(b []byte) (n int, src net.Addr, err error) { - n, _, src, err = pConn.ReadFrom(b) - return - } - - default: +func (c *Client) recv(ctx context.Context, conn api.PacketConn, msgCh chan *dns.Msg) { + if conn == nil { return } @@ -355,12 +359,11 @@ func (c *client) recv(ctx context.Context, l interface{}, msgCh chan *dns.Msg) { // Handles the following cases: // - ReadFrom aborts with error due to closed UDP connection -> causes ctx cancel // - ReadFrom aborts otherwise. - // TODO: the context check can be removed. Verify! if ctx.Err() != nil || fatalErr != nil { return } - n, _, err := readFrom(buf) + n, _, _, err := conn.ReadFrom(buf) if err != nil { fatalErr = err continue @@ -384,7 +387,7 @@ func (c *client) recv(ctx context.Context, l interface{}, msgCh chan *dns.Msg) { // the main processing loop or some timeout/cancel fires. // TODO: move error reporting to shutdown function as periodicQuery is called from // go routine context. -func (c *client) periodicQuery(ctx context.Context, params *lookupParams) error { +func (c *Client) periodicQuery(ctx context.Context, params *lookupParams) error { // Do the first query immediately. if err := c.query(params); err != nil { return err @@ -426,7 +429,7 @@ func (c *client) periodicQuery(ctx context.Context, params *lookupParams) error // Performs the actual query by service name (browse) or service instance name (lookup), // start response listeners goroutines and loops over the entries channel. -func (c *client) query(params *lookupParams) error { +func (c *Client) query(params *lookupParams) error { var serviceName, serviceInstanceName string serviceName = fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) @@ -448,44 +451,25 @@ func (c *client) query(params *lookupParams) error { } // Pack the dns.Msg and write to available connections (multicast) -func (c *client) sendQuery(msg *dns.Msg) error { +func (c *Client) sendQuery(msg *dns.Msg) error { buf, err := msg.Pack() if err != nil { return err } + + // Send to all interfaces via IPv4 if c.ipv4conn != nil { - // See https://pkg.go.dev/golang.org/x/net/ipv4#pkg-note-BUG - // As of Golang 1.18.4 - // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. - var wcm ipv4.ControlMessage - for ifi := range c.ifaces { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = c.ifaces[ifi].Index - default: - if err := c.ipv4conn.SetMulticastInterface(&c.ifaces[ifi]); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } - } - _, _ = c.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + for _, iface := range c.ifaces { + _, _ = c.ipv4conn.WriteTo(buf, iface.Index, ipv4Addr) } } + + // Send to all interfaces via IPv6 if c.ipv6conn != nil { - // See https://pkg.go.dev/golang.org/x/net/ipv6#pkg-note-BUG - // As of Golang 1.18.4 - // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. - var wcm ipv6.ControlMessage - for ifi := range c.ifaces { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = c.ifaces[ifi].Index - default: - if err := c.ipv6conn.SetMulticastInterface(&c.ifaces[ifi]); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } - } - _, _ = c.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + for _, iface := range c.ifaces { + _, _ = c.ipv6conn.WriteTo(buf, iface.Index, ipv6Addr) } } + return nil } diff --git a/client_unit_test.go b/client_unit_test.go new file mode 100644 index 0000000..b0a231f --- /dev/null +++ b/client_unit_test.go @@ -0,0 +1,588 @@ +package zeroconf + +import ( + "context" + "errors" + "net" + "sync" + "testing" + "time" + + "github.com/enbility/zeroconf/v3/mocks" + "github.com/miekg/dns" + "github.com/stretchr/testify/mock" +) + +// TestClient_SendQuery_WritesToConnections verifies sendQuery writes to both connections +func TestClient_SendQuery_WritesToConnections(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + iface := net.Interface{Index: 1, Name: "eth0"} + + // Expect WriteTo to be called on both connections + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + + c := &Client{ + ipv4conn: mockIPv4, + ipv6conn: mockIPv6, + ifaces: []net.Interface{iface}, + } + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + err := c.sendQuery(msg) + if err != nil { + t.Fatalf("sendQuery failed: %v", err) + } +} + +// TestClient_SendQuery_MultipleInterfaces verifies sendQuery writes to all interfaces +func TestClient_SendQuery_MultipleInterfaces(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + ifaces := []net.Interface{ + {Index: 1, Name: "eth0"}, + {Index: 2, Name: "wlan0"}, + {Index: 3, Name: "lo0"}, + } + + // Expect WriteTo to be called 3 times on each connection (once per interface) + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + mockIPv4.EXPECT().WriteTo(mock.Anything, 2, mock.Anything).Return(0, nil).Once() + mockIPv4.EXPECT().WriteTo(mock.Anything, 3, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 2, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 3, mock.Anything).Return(0, nil).Once() + + c := &Client{ + ipv4conn: mockIPv4, + ipv6conn: mockIPv6, + ifaces: ifaces, + } + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + err := c.sendQuery(msg) + if err != nil { + t.Fatalf("sendQuery failed: %v", err) + } +} + +// TestClient_SendQuery_IPv4Only verifies sendQuery handles IPv4-only client +func TestClient_SendQuery_IPv4Only(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + + c := &Client{ + ipv4conn: mockIPv4, + ipv6conn: nil, + ifaces: []net.Interface{{Index: 1, Name: "eth0"}}, + } + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + err := c.sendQuery(msg) + if err != nil { + t.Fatalf("sendQuery failed: %v", err) + } +} + +// TestClient_SendQuery_IPv6Only verifies sendQuery handles IPv6-only client +func TestClient_SendQuery_IPv6Only(t *testing.T) { + mockIPv6 := mocks.NewMockPacketConn(t) + + mockIPv6.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + + c := &Client{ + ipv4conn: nil, + ipv6conn: mockIPv6, + ifaces: []net.Interface{{Index: 1, Name: "eth0"}}, + } + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + err := c.sendQuery(msg) + if err != nil { + t.Fatalf("sendQuery failed: %v", err) + } +} + +// TestClient_Shutdown_ClosesConnections verifies shutdown properly closes connections +func TestClient_Shutdown_ClosesConnections(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + mockIPv4.EXPECT().Close().Return(nil).Once() + mockIPv6.EXPECT().Close().Return(nil).Once() + + c := &Client{ + ipv4conn: mockIPv4, + ipv6conn: mockIPv6, + ifaces: []net.Interface{{Index: 1, Name: "eth0"}}, + } + + c.shutdown() +} + +// TestClientConfig verifies client configuration options +func TestClientConfig(t *testing.T) { + t.Run("default options", func(t *testing.T) { + opts := applyOpts() + if opts.listenOn != IPv4AndIPv6 { + t.Errorf("Expected default listenOn IPv4AndIPv6, got %d", opts.listenOn) + } + }) + + t.Run("IPv4 only", func(t *testing.T) { + opts := applyOpts(SelectIPTraffic(IPv4)) + if opts.listenOn != IPv4 { + t.Errorf("Expected listenOn IPv4, got %d", opts.listenOn) + } + }) + + t.Run("IPv6 only", func(t *testing.T) { + opts := applyOpts(SelectIPTraffic(IPv6)) + if opts.listenOn != IPv6 { + t.Errorf("Expected listenOn IPv6, got %d", opts.listenOn) + } + }) + + t.Run("custom interfaces", func(t *testing.T) { + ifaces := []net.Interface{{Index: 1, Name: "eth0"}} + opts := applyOpts(SelectIfaces(ifaces)) + if len(opts.ifaces) != 1 { + t.Errorf("Expected 1 interface, got %d", len(opts.ifaces)) + } + }) +} + +// TestNewClient_WithMockFactory verifies newClient uses the connection factory +func TestNewClient_WithMockFactory(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + factory := mocks.NewMockConnectionFactory(t) + + factory.EXPECT().CreateIPv4Conn(mock.Anything).Return(mockIPv4, nil).Once() + factory.EXPECT().CreateIPv6Conn(mock.Anything).Return(mockIPv6, nil).Once() + + opts := clientOpts{ + listenOn: IPv4AndIPv6, + connFactory: factory, + } + + c, err := newClient(opts) + if err != nil { + t.Fatalf("newClient failed: %v", err) + } + + if c.ipv4conn != mockIPv4 { + t.Error("Expected mock IPv4 connection to be used") + } + if c.ipv6conn != mockIPv6 { + t.Error("Expected mock IPv6 connection to be used") + } +} + +// TestNewClient_ExportedConstructor verifies the exported NewClient constructor +func TestNewClient_ExportedConstructor(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + factory := mocks.NewMockConnectionFactory(t) + + factory.EXPECT().CreateIPv4Conn(mock.Anything).Return(mockIPv4, nil).Once() + factory.EXPECT().CreateIPv6Conn(mock.Anything).Return(mockIPv6, nil).Once() + + c, err := NewClient(WithClientConnFactory(factory)) + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + if c.ipv4conn != mockIPv4 { + t.Error("Expected mock IPv4 connection to be used") + } + if c.ipv6conn != mockIPv6 { + t.Error("Expected mock IPv6 connection to be used") + } +} + +// TestWithClientConnFactory verifies the WithClientConnFactory option +func TestWithClientConnFactory(t *testing.T) { + factory := mocks.NewMockConnectionFactory(t) + + opts := applyOpts(WithClientConnFactory(factory)) + + if opts.connFactory != factory { + t.Error("Expected connection factory to be set") + } +} + +// TestClient_Query_WithInstance verifies query builds correct message for Lookup +func TestClient_Query_WithInstance(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + + // Capture the DNS message to verify it contains SRV and TXT questions + var capturedMsg []byte + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + capturedMsg = make([]byte, len(b)) + copy(capturedMsg, b) + return len(b), nil + }).Once() + + c := &Client{ + ipv4conn: mockIPv4, + ipv6conn: nil, + ifaces: []net.Interface{{Index: 1, Name: "eth0"}}, + } + + params := newLookupParams("myservice", "_http._tcp", "local", false, + make(chan *ServiceEntry), make(chan *ServiceEntry)) + + err := c.query(params) + if err != nil { + t.Fatalf("query failed: %v", err) + } + + // Parse the captured message + msg := new(dns.Msg) + if err := msg.Unpack(capturedMsg); err != nil { + t.Fatalf("Failed to unpack captured message: %v", err) + } + + // For instance lookup, we expect SRV and TXT questions + if len(msg.Question) != 2 { + t.Fatalf("Expected 2 questions for instance lookup, got %d", len(msg.Question)) + } + + // Check question types + hasSRV := false + hasTXT := false + for _, q := range msg.Question { + if q.Qtype == dns.TypeSRV { + hasSRV = true + } + if q.Qtype == dns.TypeTXT { + hasTXT = true + } + } + + if !hasSRV { + t.Error("Expected SRV question for instance lookup") + } + if !hasTXT { + t.Error("Expected TXT question for instance lookup") + } +} + +// TestClient_Query_Browse verifies query builds correct message for Browse +func TestClient_Query_Browse(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + + var capturedMsg []byte + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + capturedMsg = make([]byte, len(b)) + copy(capturedMsg, b) + return len(b), nil + }).Once() + + c := &Client{ + ipv4conn: mockIPv4, + ipv6conn: nil, + ifaces: []net.Interface{{Index: 1, Name: "eth0"}}, + } + + // No instance = browse mode + params := newLookupParams("", "_http._tcp", "local", true, + make(chan *ServiceEntry), make(chan *ServiceEntry)) + + err := c.query(params) + if err != nil { + t.Fatalf("query failed: %v", err) + } + + msg := new(dns.Msg) + if err := msg.Unpack(capturedMsg); err != nil { + t.Fatalf("Failed to unpack captured message: %v", err) + } + + // For browse, we expect a single PTR question + if len(msg.Question) != 1 { + t.Fatalf("Expected 1 question for browse, got %d", len(msg.Question)) + } + + if msg.Question[0].Qtype != dns.TypePTR { + t.Errorf("Expected PTR question for browse, got %d", msg.Question[0].Qtype) + } +} + +// createMockDNSResponse creates a complete DNS response for testing Lookup +func createMockDNSResponse(instanceName, hostName string, port uint16, ip net.IP) []byte { + msg := new(dns.Msg) + msg.Response = true + + // SRV record + msg.Answer = append(msg.Answer, &dns.SRV{ + Hdr: dns.RR_Header{ + Name: instanceName, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 120, + }, + Priority: 0, + Weight: 0, + Port: port, + Target: hostName, + }) + + // TXT record + msg.Answer = append(msg.Answer, &dns.TXT{ + Hdr: dns.RR_Header{ + Name: instanceName, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 120, + }, + Txt: []string{"key=value"}, + }) + + // A record + msg.Extra = append(msg.Extra, &dns.A{ + Hdr: dns.RR_Header{ + Name: hostName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 120, + }, + A: ip, + }) + + data, _ := msg.Pack() + return data +} + +// TestBrowse_WithMockConnections tests the full Browse flow with mocked connections +func TestBrowse_WithMockConnections(t *testing.T) { + // Reduce query interval for faster test + oldInterval := initialQueryInterval + initialQueryInterval = 50 * time.Millisecond + defer func() { initialQueryInterval = oldInterval }() + + mockIPv4 := mocks.NewMockPacketConn(t) + factory := mocks.NewMockConnectionFactory(t) + + factory.EXPECT().CreateIPv4Conn(mock.Anything).Return(mockIPv4, nil).Once() + + // Create a DNS response with PTR record (for browse) + instanceName := "myservice._http._tcp.local." + serviceName := "_http._tcp.local." + hostName := "myhost.local." + + msg := new(dns.Msg) + msg.Response = true + + // PTR record pointing to the instance + msg.Answer = append(msg.Answer, &dns.PTR{ + Hdr: dns.RR_Header{ + Name: serviceName, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 120, + }, + Ptr: instanceName, + }) + + // SRV record + msg.Answer = append(msg.Answer, &dns.SRV{ + Hdr: dns.RR_Header{ + Name: instanceName, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 120, + }, + Port: 8080, + Target: hostName, + }) + + // TXT record + msg.Answer = append(msg.Answer, &dns.TXT{ + Hdr: dns.RR_Header{ + Name: instanceName, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 120, + }, + Txt: []string{"version=1.0"}, + }) + + // A record + msg.Extra = append(msg.Extra, &dns.A{ + Hdr: dns.RR_Header{ + Name: hostName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 120, + }, + A: net.ParseIP("192.168.1.100"), + }) + + responseData, _ := msg.Pack() + + var readCount int + var mu sync.Mutex + + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + mockIPv4.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, int, net.Addr, error) { + mu.Lock() + readCount++ + count := readCount + mu.Unlock() + + if count == 1 { + copy(b, responseData) + return len(responseData), 1, &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 5353}, nil + } + time.Sleep(100 * time.Millisecond) + return 0, 0, nil, errors.New("context cancelled") + }).Maybe() + mockIPv4.EXPECT().Close().Return(nil).Maybe() + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + entries := make(chan *ServiceEntry, 1) + removed := make(chan *ServiceEntry, 1) + + var browseErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + browseErr = Browse(ctx, "_http._tcp", "local", entries, removed, + WithClientConnFactory(factory), + SelectIPTraffic(IPv4)) + }() + + select { + case entry := <-entries: + if entry.Instance != "myservice" { + t.Errorf("Expected instance 'myservice', got '%s'", entry.Instance) + } + if entry.Port != 8080 { + t.Errorf("Expected port 8080, got %d", entry.Port) + } + if len(entry.Text) == 0 || entry.Text[0] != "version=1.0" { + t.Errorf("Expected text 'version=1.0', got %v", entry.Text) + } + cancel() + case <-ctx.Done(): + t.Log("Context done before receiving entry") + } + + wg.Wait() + + if browseErr != nil && browseErr != context.DeadlineExceeded && browseErr != context.Canceled { + t.Errorf("Browse returned unexpected error: %v", browseErr) + } +} + +// TestLookup_WithMockConnections tests the full Lookup flow with mocked connections +func TestLookup_WithMockConnections(t *testing.T) { + // Reduce query interval for faster test + oldInterval := initialQueryInterval + initialQueryInterval = 50 * time.Millisecond + defer func() { initialQueryInterval = oldInterval }() + + mockIPv4 := mocks.NewMockPacketConn(t) + factory := mocks.NewMockConnectionFactory(t) + + // Factory returns our mock connection (IPv4 only since we use SelectIPTraffic(IPv4)) + factory.EXPECT().CreateIPv4Conn(mock.Anything).Return(mockIPv4, nil).Once() + + // Create the DNS response + instanceName := "myservice._http._tcp.local." + hostName := "myhost.local." + responseData := createMockDNSResponse(instanceName, hostName, 8080, net.ParseIP("192.168.1.100")) + + // Track ReadFrom calls + var readCount int + var mu sync.Mutex + + // WriteTo for queries - just accept them + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + + // ReadFrom returns the response once, then blocks + mockIPv4.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, int, net.Addr, error) { + mu.Lock() + readCount++ + count := readCount + mu.Unlock() + + if count == 1 { + // First call: return the DNS response + copy(b, responseData) + return len(responseData), 1, &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 5353}, nil + } + // Subsequent calls: block until test ends (simulates waiting for more data) + time.Sleep(100 * time.Millisecond) + return 0, 0, nil, errors.New("context cancelled") + }).Maybe() + + // Close when shutdown + mockIPv4.EXPECT().Close().Return(nil).Maybe() + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + entries := make(chan *ServiceEntry, 1) + + // Run Lookup in background + var lookupErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + lookupErr = Lookup(ctx, "myservice", "_http._tcp", "local", entries, + WithClientConnFactory(factory), + SelectIPTraffic(IPv4)) + }() + + // Wait for entry or timeout + select { + case entry := <-entries: + if entry.Instance != "myservice" { + t.Errorf("Expected instance 'myservice', got '%s'", entry.Instance) + } + if entry.Port != 8080 { + t.Errorf("Expected port 8080, got %d", entry.Port) + } + if entry.HostName != hostName { + t.Errorf("Expected hostname '%s', got '%s'", hostName, entry.HostName) + } + if len(entry.AddrIPv4) == 0 { + t.Error("Expected IPv4 address") + } else if !entry.AddrIPv4[0].Equal(net.ParseIP("192.168.1.100")) { + t.Errorf("Expected IP 192.168.1.100, got %s", entry.AddrIPv4[0]) + } + // Success - cancel to clean up + cancel() + case <-ctx.Done(): + t.Log("Context done before receiving entry (may be timing issue)") + } + + wg.Wait() + + // Context cancellation is expected, not an error for Lookup + if lookupErr != nil && lookupErr != context.DeadlineExceeded && lookupErr != context.Canceled { + t.Errorf("Lookup returned unexpected error: %v", lookupErr) + } +} diff --git a/conn_factory.go b/conn_factory.go new file mode 100644 index 0000000..a706b60 --- /dev/null +++ b/conn_factory.go @@ -0,0 +1,72 @@ +package zeroconf + +import ( + "fmt" + "net" + + "github.com/enbility/zeroconf/v3/api" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// defaultConnectionFactory is the production implementation of api.ConnectionFactory. +// It creates real UDP multicast connections for mDNS communication. +type defaultConnectionFactory struct{} + +// Compile-time interface check +var _ api.ConnectionFactory = (*defaultConnectionFactory)(nil) + +// NewConnectionFactory creates a new default connection factory. +func NewConnectionFactory() api.ConnectionFactory { + return &defaultConnectionFactory{} +} + +func (f *defaultConnectionFactory) CreateIPv4Conn(ifaces []net.Interface) (api.PacketConn, error) { + udpConn, err := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) + if err != nil { + return nil, err + } + + pkConn := ipv4.NewPacketConn(udpConn) + _ = pkConn.SetControlMessage(ipv4.FlagInterface, true) + + var failedJoins int + for _, iface := range ifaces { + if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + failedJoins++ + } + } + if failedJoins == len(ifaces) { + pkConn.Close() + return nil, fmt.Errorf("udp4: failed to join any of these interfaces: %v", ifaces) + } + + _ = pkConn.SetMulticastTTL(255) + + return newIPv4PacketConn(pkConn), nil +} + +func (f *defaultConnectionFactory) CreateIPv6Conn(ifaces []net.Interface) (api.PacketConn, error) { + udpConn, err := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) + if err != nil { + return nil, err + } + + pkConn := ipv6.NewPacketConn(udpConn) + _ = pkConn.SetControlMessage(ipv6.FlagInterface, true) + + var failedJoins int + for _, iface := range ifaces { + if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + failedJoins++ + } + } + if failedJoins == len(ifaces) { + pkConn.Close() + return nil, fmt.Errorf("udp6: failed to join any of these interfaces: %v", ifaces) + } + + _ = pkConn.SetMulticastHopLimit(255) + + return newIPv6PacketConn(pkConn), nil +} diff --git a/conn_ipv4.go b/conn_ipv4.go new file mode 100644 index 0000000..f874ff6 --- /dev/null +++ b/conn_ipv4.go @@ -0,0 +1,80 @@ +package zeroconf + +import ( + "log" + "net" + "runtime" + + "github.com/enbility/zeroconf/v3/api" + "golang.org/x/net/ipv4" +) + +// ipv4PacketConn wraps ipv4.PacketConn to implement api.PacketConn interface. +// This adapter is needed because ipv4.PacketConn uses ControlMessage for +// interface selection, but we only need the IfIndex field. +type ipv4PacketConn struct { + conn *ipv4.PacketConn +} + +// Compile-time interface check +var _ api.PacketConn = (*ipv4PacketConn)(nil) + +// newIPv4PacketConn creates a new IPv4 PacketConn wrapper. +func newIPv4PacketConn(conn *ipv4.PacketConn) *ipv4PacketConn { + return &ipv4PacketConn{conn: conn} +} + +func (c *ipv4PacketConn) ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) { + n, cm, src, err := c.conn.ReadFrom(b) + if cm != nil { + ifIndex = cm.IfIndex + } + return +} + +func (c *ipv4PacketConn) WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) { + // See https://pkg.go.dev/golang.org/x/net/ipv4#pkg-note-BUG + // On Windows, the ControlMessage for WriteTo is not implemented. + // Use SetMulticastInterface as fallback. + var cm *ipv4.ControlMessage + if ifIndex != 0 { + switch runtime.GOOS { + case "darwin", "ios", "linux": + cm = &ipv4.ControlMessage{IfIndex: ifIndex} + default: + // Windows and other platforms: use SetMulticastInterface + iface, _ := net.InterfaceByIndex(ifIndex) + if iface != nil { + if err := c.conn.SetMulticastInterface(iface); err != nil { + log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + } + } + } + } + return c.conn.WriteTo(b, cm, dst) +} + +func (c *ipv4PacketConn) Close() error { + return c.conn.Close() +} + +func (c *ipv4PacketConn) JoinGroup(ifi *net.Interface, group net.Addr) error { + return c.conn.JoinGroup(ifi, group) +} + +func (c *ipv4PacketConn) LeaveGroup(ifi *net.Interface, group net.Addr) error { + return c.conn.LeaveGroup(ifi, group) +} + +func (c *ipv4PacketConn) SetMulticastTTL(ttl int) error { + return c.conn.SetMulticastTTL(ttl) +} + +func (c *ipv4PacketConn) SetMulticastHopLimit(hopLimit int) error { + // IPv4 doesn't have hop limit, this is a no-op + return nil +} + +func (c *ipv4PacketConn) SetMulticastInterface(ifi *net.Interface) error { + return c.conn.SetMulticastInterface(ifi) +} diff --git a/conn_ipv6.go b/conn_ipv6.go new file mode 100644 index 0000000..0bdb8eb --- /dev/null +++ b/conn_ipv6.go @@ -0,0 +1,80 @@ +package zeroconf + +import ( + "log" + "net" + "runtime" + + "github.com/enbility/zeroconf/v3/api" + "golang.org/x/net/ipv6" +) + +// ipv6PacketConn wraps ipv6.PacketConn to implement api.PacketConn interface. +// This adapter is needed because ipv6.PacketConn uses ControlMessage for +// interface selection, but we only need the IfIndex field. +type ipv6PacketConn struct { + conn *ipv6.PacketConn +} + +// Compile-time interface check +var _ api.PacketConn = (*ipv6PacketConn)(nil) + +// newIPv6PacketConn creates a new IPv6 PacketConn wrapper. +func newIPv6PacketConn(conn *ipv6.PacketConn) *ipv6PacketConn { + return &ipv6PacketConn{conn: conn} +} + +func (c *ipv6PacketConn) ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) { + n, cm, src, err := c.conn.ReadFrom(b) + if cm != nil { + ifIndex = cm.IfIndex + } + return +} + +func (c *ipv6PacketConn) WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) { + // See https://pkg.go.dev/golang.org/x/net/ipv6#pkg-note-BUG + // On Windows, the ControlMessage for WriteTo is not implemented. + // Use SetMulticastInterface as fallback. + var cm *ipv6.ControlMessage + if ifIndex != 0 { + switch runtime.GOOS { + case "darwin", "ios", "linux": + cm = &ipv6.ControlMessage{IfIndex: ifIndex} + default: + // Windows and other platforms: use SetMulticastInterface + iface, _ := net.InterfaceByIndex(ifIndex) + if iface != nil { + if err := c.conn.SetMulticastInterface(iface); err != nil { + log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + } + } + } + } + return c.conn.WriteTo(b, cm, dst) +} + +func (c *ipv6PacketConn) Close() error { + return c.conn.Close() +} + +func (c *ipv6PacketConn) JoinGroup(ifi *net.Interface, group net.Addr) error { + return c.conn.JoinGroup(ifi, group) +} + +func (c *ipv6PacketConn) LeaveGroup(ifi *net.Interface, group net.Addr) error { + return c.conn.LeaveGroup(ifi, group) +} + +func (c *ipv6PacketConn) SetMulticastTTL(ttl int) error { + // IPv6 doesn't have TTL, this is a no-op + return nil +} + +func (c *ipv6PacketConn) SetMulticastHopLimit(hopLimit int) error { + return c.conn.SetMulticastHopLimit(hopLimit) +} + +func (c *ipv6PacketConn) SetMulticastInterface(ifi *net.Interface) error { + return c.conn.SetMulticastInterface(ifi) +} diff --git a/conn_provider.go b/conn_provider.go new file mode 100644 index 0000000..611e853 --- /dev/null +++ b/conn_provider.go @@ -0,0 +1,37 @@ +package zeroconf + +import ( + "net" + + "github.com/enbility/zeroconf/v3/api" +) + +// defaultInterfaceProvider is the production implementation of api.InterfaceProvider. +// It lists network interfaces capable of multicast communication. +type defaultInterfaceProvider struct{} + +// Compile-time interface check +var _ api.InterfaceProvider = (*defaultInterfaceProvider)(nil) + +// NewInterfaceProvider creates a new default interface provider. +func NewInterfaceProvider() api.InterfaceProvider { + return &defaultInterfaceProvider{} +} + +// MulticastInterfaces returns all network interfaces that are up and support multicast. +func (p *defaultInterfaceProvider) MulticastInterfaces() []net.Interface { + var interfaces []net.Interface + ifaces, err := net.Interfaces() + if err != nil { + return nil + } + for _, ifi := range ifaces { + if (ifi.Flags & net.FlagUp) == 0 { + continue + } + if (ifi.Flags & net.FlagMulticast) > 0 { + interfaces = append(interfaces, ifi) + } + } + return interfaces +} diff --git a/connection.go b/connection.go deleted file mode 100644 index 0efbac1..0000000 --- a/connection.go +++ /dev/null @@ -1,119 +0,0 @@ -package zeroconf - -import ( - "fmt" - "net" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" -) - -var ( - // Multicast groups used by mDNS - mdnsGroupIPv4 = net.IPv4(224, 0, 0, 251) - mdnsGroupIPv6 = net.ParseIP("ff02::fb") - - // mDNS wildcard addresses - mdnsWildcardAddrIPv4 = &net.UDPAddr{ - IP: net.ParseIP("224.0.0.0"), - Port: 5353, - } - mdnsWildcardAddrIPv6 = &net.UDPAddr{ - IP: net.ParseIP("ff02::"), - // IP: net.ParseIP("fd00::12d3:26e7:48db:e7d"), - Port: 5353, - } - - // mDNS endpoint addresses - ipv4Addr = &net.UDPAddr{ - IP: mdnsGroupIPv4, - Port: 5353, - } - ipv6Addr = &net.UDPAddr{ - IP: mdnsGroupIPv6, - Port: 5353, - } -) - -func joinUdp6Multicast(interfaces []net.Interface) (*ipv6.PacketConn, error) { - udpConn, err := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) - if err != nil { - return nil, err - } - - // Join multicast groups to receive announcements - pkConn := ipv6.NewPacketConn(udpConn) - _ = pkConn.SetControlMessage(ipv6.FlagInterface, true) - - if len(interfaces) == 0 { - interfaces = listMulticastInterfaces() - } - // log.Println("Using multicast interfaces: ", interfaces) - - var failedJoins int - for _, iface := range interfaces { - if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { - // log.Println("Udp6 JoinGroup failed for iface ", iface) - failedJoins++ - } - } - if failedJoins == len(interfaces) { - pkConn.Close() - return nil, fmt.Errorf("udp6: failed to join any of these interfaces: %v", interfaces) - } - - _ = pkConn.SetMulticastHopLimit(255) - - return pkConn, nil -} - -func joinUdp4Multicast(interfaces []net.Interface) (*ipv4.PacketConn, error) { - udpConn, err := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) - if err != nil { - // log.Printf("[ERR] bonjour: Failed to bind to udp4 mutlicast: %v", err) - return nil, err - } - - // Join multicast groups to receive announcements - pkConn := ipv4.NewPacketConn(udpConn) - _ = pkConn.SetControlMessage(ipv4.FlagInterface, true) - - if len(interfaces) == 0 { - interfaces = listMulticastInterfaces() - } - // log.Println("Using multicast interfaces: ", interfaces) - - var failedJoins int - for _, iface := range interfaces { - if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { - // log.Println("Udp4 JoinGroup failed for iface ", iface) - failedJoins++ - } - } - if failedJoins == len(interfaces) { - pkConn.Close() - return nil, fmt.Errorf("udp4: failed to join any of these interfaces: %v", interfaces) - } - - _ = pkConn.SetMulticastTTL(255) - - return pkConn, nil -} - -func listMulticastInterfaces() []net.Interface { - var interfaces []net.Interface - ifaces, err := net.Interfaces() - if err != nil { - return nil - } - for _, ifi := range ifaces { - if (ifi.Flags & net.FlagUp) == 0 { - continue - } - if (ifi.Flags & net.FlagMulticast) > 0 { - interfaces = append(interfaces, ifi) - } - } - - return interfaces -} diff --git a/go.mod b/go.mod index 9173cd8..8e6d8f8 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,17 @@ go 1.22.0 require ( github.com/miekg/dns v1.1.62 + github.com/stretchr/testify v1.11.1 golang.org/x/net v0.29.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/mod v0.21.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.25.0 // indirect golang.org/x/tools v0.25.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3207eb1..bc8d728 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,13 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= @@ -10,3 +18,7 @@ golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mdns.go b/mdns.go new file mode 100644 index 0000000..03c0a1d --- /dev/null +++ b/mdns.go @@ -0,0 +1,30 @@ +package zeroconf + +import "net" + +// mDNS network constants per RFC 6762 +var ( + // Multicast groups used by mDNS + mdnsGroupIPv4 = net.IPv4(224, 0, 0, 251) + mdnsGroupIPv6 = net.ParseIP("ff02::fb") + + // mDNS wildcard addresses for listening + mdnsWildcardAddrIPv4 = &net.UDPAddr{ + IP: net.ParseIP("224.0.0.0"), + Port: 5353, + } + mdnsWildcardAddrIPv6 = &net.UDPAddr{ + IP: net.ParseIP("ff02::"), + Port: 5353, + } + + // mDNS endpoint addresses for sending + ipv4Addr = &net.UDPAddr{ + IP: mdnsGroupIPv4, + Port: 5353, + } + ipv6Addr = &net.UDPAddr{ + IP: mdnsGroupIPv6, + Port: 5353, + } +) diff --git a/mocks/mock_connection_factory.go b/mocks/mock_connection_factory.go new file mode 100644 index 0000000..3f2e1b4 --- /dev/null +++ b/mocks/mock_connection_factory.go @@ -0,0 +1,163 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "net" + + "github.com/enbility/zeroconf/v3/api" + mock "github.com/stretchr/testify/mock" +) + +// NewMockConnectionFactory creates a new instance of MockConnectionFactory. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockConnectionFactory(t interface { + mock.TestingT + Cleanup(func()) +}) *MockConnectionFactory { + mock := &MockConnectionFactory{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockConnectionFactory is an autogenerated mock type for the ConnectionFactory type +type MockConnectionFactory struct { + mock.Mock +} + +type MockConnectionFactory_Expecter struct { + mock *mock.Mock +} + +func (_m *MockConnectionFactory) EXPECT() *MockConnectionFactory_Expecter { + return &MockConnectionFactory_Expecter{mock: &_m.Mock} +} + +// CreateIPv4Conn provides a mock function for the type MockConnectionFactory +func (_mock *MockConnectionFactory) CreateIPv4Conn(ifaces []net.Interface) (api.PacketConn, error) { + ret := _mock.Called(ifaces) + + if len(ret) == 0 { + panic("no return value specified for CreateIPv4Conn") + } + + var r0 api.PacketConn + var r1 error + if returnFunc, ok := ret.Get(0).(func([]net.Interface) (api.PacketConn, error)); ok { + return returnFunc(ifaces) + } + if returnFunc, ok := ret.Get(0).(func([]net.Interface) api.PacketConn); ok { + r0 = returnFunc(ifaces) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(api.PacketConn) + } + } + if returnFunc, ok := ret.Get(1).(func([]net.Interface) error); ok { + r1 = returnFunc(ifaces) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockConnectionFactory_CreateIPv4Conn_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateIPv4Conn' +type MockConnectionFactory_CreateIPv4Conn_Call struct { + *mock.Call +} + +// CreateIPv4Conn is a helper method to define mock.On call +// - ifaces []net.Interface +func (_e *MockConnectionFactory_Expecter) CreateIPv4Conn(ifaces interface{}) *MockConnectionFactory_CreateIPv4Conn_Call { + return &MockConnectionFactory_CreateIPv4Conn_Call{Call: _e.mock.On("CreateIPv4Conn", ifaces)} +} + +func (_c *MockConnectionFactory_CreateIPv4Conn_Call) Run(run func(ifaces []net.Interface)) *MockConnectionFactory_CreateIPv4Conn_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []net.Interface + if args[0] != nil { + arg0 = args[0].([]net.Interface) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockConnectionFactory_CreateIPv4Conn_Call) Return(packetConn api.PacketConn, err error) *MockConnectionFactory_CreateIPv4Conn_Call { + _c.Call.Return(packetConn, err) + return _c +} + +func (_c *MockConnectionFactory_CreateIPv4Conn_Call) RunAndReturn(run func(ifaces []net.Interface) (api.PacketConn, error)) *MockConnectionFactory_CreateIPv4Conn_Call { + _c.Call.Return(run) + return _c +} + +// CreateIPv6Conn provides a mock function for the type MockConnectionFactory +func (_mock *MockConnectionFactory) CreateIPv6Conn(ifaces []net.Interface) (api.PacketConn, error) { + ret := _mock.Called(ifaces) + + if len(ret) == 0 { + panic("no return value specified for CreateIPv6Conn") + } + + var r0 api.PacketConn + var r1 error + if returnFunc, ok := ret.Get(0).(func([]net.Interface) (api.PacketConn, error)); ok { + return returnFunc(ifaces) + } + if returnFunc, ok := ret.Get(0).(func([]net.Interface) api.PacketConn); ok { + r0 = returnFunc(ifaces) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(api.PacketConn) + } + } + if returnFunc, ok := ret.Get(1).(func([]net.Interface) error); ok { + r1 = returnFunc(ifaces) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockConnectionFactory_CreateIPv6Conn_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateIPv6Conn' +type MockConnectionFactory_CreateIPv6Conn_Call struct { + *mock.Call +} + +// CreateIPv6Conn is a helper method to define mock.On call +// - ifaces []net.Interface +func (_e *MockConnectionFactory_Expecter) CreateIPv6Conn(ifaces interface{}) *MockConnectionFactory_CreateIPv6Conn_Call { + return &MockConnectionFactory_CreateIPv6Conn_Call{Call: _e.mock.On("CreateIPv6Conn", ifaces)} +} + +func (_c *MockConnectionFactory_CreateIPv6Conn_Call) Run(run func(ifaces []net.Interface)) *MockConnectionFactory_CreateIPv6Conn_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []net.Interface + if args[0] != nil { + arg0 = args[0].([]net.Interface) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockConnectionFactory_CreateIPv6Conn_Call) Return(packetConn api.PacketConn, err error) *MockConnectionFactory_CreateIPv6Conn_Call { + _c.Call.Return(packetConn, err) + return _c +} + +func (_c *MockConnectionFactory_CreateIPv6Conn_Call) RunAndReturn(run func(ifaces []net.Interface) (api.PacketConn, error)) *MockConnectionFactory_CreateIPv6Conn_Call { + _c.Call.Return(run) + return _c +} diff --git a/mocks/mock_interface_provider.go b/mocks/mock_interface_provider.go new file mode 100644 index 0000000..60f7f27 --- /dev/null +++ b/mocks/mock_interface_provider.go @@ -0,0 +1,84 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "net" + + mock "github.com/stretchr/testify/mock" +) + +// NewMockInterfaceProvider creates a new instance of MockInterfaceProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockInterfaceProvider(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInterfaceProvider { + mock := &MockInterfaceProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockInterfaceProvider is an autogenerated mock type for the InterfaceProvider type +type MockInterfaceProvider struct { + mock.Mock +} + +type MockInterfaceProvider_Expecter struct { + mock *mock.Mock +} + +func (_m *MockInterfaceProvider) EXPECT() *MockInterfaceProvider_Expecter { + return &MockInterfaceProvider_Expecter{mock: &_m.Mock} +} + +// MulticastInterfaces provides a mock function for the type MockInterfaceProvider +func (_mock *MockInterfaceProvider) MulticastInterfaces() []net.Interface { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for MulticastInterfaces") + } + + var r0 []net.Interface + if returnFunc, ok := ret.Get(0).(func() []net.Interface); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]net.Interface) + } + } + return r0 +} + +// MockInterfaceProvider_MulticastInterfaces_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MulticastInterfaces' +type MockInterfaceProvider_MulticastInterfaces_Call struct { + *mock.Call +} + +// MulticastInterfaces is a helper method to define mock.On call +func (_e *MockInterfaceProvider_Expecter) MulticastInterfaces() *MockInterfaceProvider_MulticastInterfaces_Call { + return &MockInterfaceProvider_MulticastInterfaces_Call{Call: _e.mock.On("MulticastInterfaces")} +} + +func (_c *MockInterfaceProvider_MulticastInterfaces_Call) Run(run func()) *MockInterfaceProvider_MulticastInterfaces_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockInterfaceProvider_MulticastInterfaces_Call) Return(interfaces []net.Interface) *MockInterfaceProvider_MulticastInterfaces_Call { + _c.Call.Return(interfaces) + return _c +} + +func (_c *MockInterfaceProvider_MulticastInterfaces_Call) RunAndReturn(run func() []net.Interface) *MockInterfaceProvider_MulticastInterfaces_Call { + _c.Call.Return(run) + return _c +} diff --git a/mocks/mock_packet_conn.go b/mocks/mock_packet_conn.go new file mode 100644 index 0000000..c4c9da6 --- /dev/null +++ b/mocks/mock_packet_conn.go @@ -0,0 +1,495 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "net" + + mock "github.com/stretchr/testify/mock" +) + +// NewMockPacketConn creates a new instance of MockPacketConn. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockPacketConn(t interface { + mock.TestingT + Cleanup(func()) +}) *MockPacketConn { + mock := &MockPacketConn{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockPacketConn is an autogenerated mock type for the PacketConn type +type MockPacketConn struct { + mock.Mock +} + +type MockPacketConn_Expecter struct { + mock *mock.Mock +} + +func (_m *MockPacketConn) EXPECT() *MockPacketConn_Expecter { + return &MockPacketConn_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) Close() error { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func() error); ok { + r0 = returnFunc() + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockPacketConn_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockPacketConn_Expecter) Close() *MockPacketConn_Close_Call { + return &MockPacketConn_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockPacketConn_Close_Call) Run(run func()) *MockPacketConn_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockPacketConn_Close_Call) Return(err error) *MockPacketConn_Close_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_Close_Call) RunAndReturn(run func() error) *MockPacketConn_Close_Call { + _c.Call.Return(run) + return _c +} + +// JoinGroup provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) JoinGroup(ifi *net.Interface, group net.Addr) error { + ret := _mock.Called(ifi, group) + + if len(ret) == 0 { + panic("no return value specified for JoinGroup") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(*net.Interface, net.Addr) error); ok { + r0 = returnFunc(ifi, group) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_JoinGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'JoinGroup' +type MockPacketConn_JoinGroup_Call struct { + *mock.Call +} + +// JoinGroup is a helper method to define mock.On call +// - ifi *net.Interface +// - group net.Addr +func (_e *MockPacketConn_Expecter) JoinGroup(ifi interface{}, group interface{}) *MockPacketConn_JoinGroup_Call { + return &MockPacketConn_JoinGroup_Call{Call: _e.mock.On("JoinGroup", ifi, group)} +} + +func (_c *MockPacketConn_JoinGroup_Call) Run(run func(ifi *net.Interface, group net.Addr)) *MockPacketConn_JoinGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 *net.Interface + if args[0] != nil { + arg0 = args[0].(*net.Interface) + } + var arg1 net.Addr + if args[1] != nil { + arg1 = args[1].(net.Addr) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockPacketConn_JoinGroup_Call) Return(err error) *MockPacketConn_JoinGroup_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_JoinGroup_Call) RunAndReturn(run func(ifi *net.Interface, group net.Addr) error) *MockPacketConn_JoinGroup_Call { + _c.Call.Return(run) + return _c +} + +// LeaveGroup provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) LeaveGroup(ifi *net.Interface, group net.Addr) error { + ret := _mock.Called(ifi, group) + + if len(ret) == 0 { + panic("no return value specified for LeaveGroup") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(*net.Interface, net.Addr) error); ok { + r0 = returnFunc(ifi, group) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_LeaveGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LeaveGroup' +type MockPacketConn_LeaveGroup_Call struct { + *mock.Call +} + +// LeaveGroup is a helper method to define mock.On call +// - ifi *net.Interface +// - group net.Addr +func (_e *MockPacketConn_Expecter) LeaveGroup(ifi interface{}, group interface{}) *MockPacketConn_LeaveGroup_Call { + return &MockPacketConn_LeaveGroup_Call{Call: _e.mock.On("LeaveGroup", ifi, group)} +} + +func (_c *MockPacketConn_LeaveGroup_Call) Run(run func(ifi *net.Interface, group net.Addr)) *MockPacketConn_LeaveGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 *net.Interface + if args[0] != nil { + arg0 = args[0].(*net.Interface) + } + var arg1 net.Addr + if args[1] != nil { + arg1 = args[1].(net.Addr) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockPacketConn_LeaveGroup_Call) Return(err error) *MockPacketConn_LeaveGroup_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_LeaveGroup_Call) RunAndReturn(run func(ifi *net.Interface, group net.Addr) error) *MockPacketConn_LeaveGroup_Call { + _c.Call.Return(run) + return _c +} + +// ReadFrom provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) ReadFrom(b []byte) (int, int, net.Addr, error) { + ret := _mock.Called(b) + + if len(ret) == 0 { + panic("no return value specified for ReadFrom") + } + + var r0 int + var r1 int + var r2 net.Addr + var r3 error + if returnFunc, ok := ret.Get(0).(func([]byte) (int, int, net.Addr, error)); ok { + return returnFunc(b) + } + if returnFunc, ok := ret.Get(0).(func([]byte) int); ok { + r0 = returnFunc(b) + } else { + r0 = ret.Get(0).(int) + } + if returnFunc, ok := ret.Get(1).(func([]byte) int); ok { + r1 = returnFunc(b) + } else { + r1 = ret.Get(1).(int) + } + if returnFunc, ok := ret.Get(2).(func([]byte) net.Addr); ok { + r2 = returnFunc(b) + } else { + if ret.Get(2) != nil { + r2 = ret.Get(2).(net.Addr) + } + } + if returnFunc, ok := ret.Get(3).(func([]byte) error); ok { + r3 = returnFunc(b) + } else { + r3 = ret.Error(3) + } + return r0, r1, r2, r3 +} + +// MockPacketConn_ReadFrom_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadFrom' +type MockPacketConn_ReadFrom_Call struct { + *mock.Call +} + +// ReadFrom is a helper method to define mock.On call +// - b []byte +func (_e *MockPacketConn_Expecter) ReadFrom(b interface{}) *MockPacketConn_ReadFrom_Call { + return &MockPacketConn_ReadFrom_Call{Call: _e.mock.On("ReadFrom", b)} +} + +func (_c *MockPacketConn_ReadFrom_Call) Run(run func(b []byte)) *MockPacketConn_ReadFrom_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []byte + if args[0] != nil { + arg0 = args[0].([]byte) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockPacketConn_ReadFrom_Call) Return(n int, ifIndex int, src net.Addr, err error) *MockPacketConn_ReadFrom_Call { + _c.Call.Return(n, ifIndex, src, err) + return _c +} + +func (_c *MockPacketConn_ReadFrom_Call) RunAndReturn(run func(b []byte) (int, int, net.Addr, error)) *MockPacketConn_ReadFrom_Call { + _c.Call.Return(run) + return _c +} + +// SetMulticastHopLimit provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) SetMulticastHopLimit(hopLimit int) error { + ret := _mock.Called(hopLimit) + + if len(ret) == 0 { + panic("no return value specified for SetMulticastHopLimit") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(int) error); ok { + r0 = returnFunc(hopLimit) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_SetMulticastHopLimit_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMulticastHopLimit' +type MockPacketConn_SetMulticastHopLimit_Call struct { + *mock.Call +} + +// SetMulticastHopLimit is a helper method to define mock.On call +// - hopLimit int +func (_e *MockPacketConn_Expecter) SetMulticastHopLimit(hopLimit interface{}) *MockPacketConn_SetMulticastHopLimit_Call { + return &MockPacketConn_SetMulticastHopLimit_Call{Call: _e.mock.On("SetMulticastHopLimit", hopLimit)} +} + +func (_c *MockPacketConn_SetMulticastHopLimit_Call) Run(run func(hopLimit int)) *MockPacketConn_SetMulticastHopLimit_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 int + if args[0] != nil { + arg0 = args[0].(int) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockPacketConn_SetMulticastHopLimit_Call) Return(err error) *MockPacketConn_SetMulticastHopLimit_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_SetMulticastHopLimit_Call) RunAndReturn(run func(hopLimit int) error) *MockPacketConn_SetMulticastHopLimit_Call { + _c.Call.Return(run) + return _c +} + +// SetMulticastInterface provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) SetMulticastInterface(ifi *net.Interface) error { + ret := _mock.Called(ifi) + + if len(ret) == 0 { + panic("no return value specified for SetMulticastInterface") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(*net.Interface) error); ok { + r0 = returnFunc(ifi) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_SetMulticastInterface_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMulticastInterface' +type MockPacketConn_SetMulticastInterface_Call struct { + *mock.Call +} + +// SetMulticastInterface is a helper method to define mock.On call +// - ifi *net.Interface +func (_e *MockPacketConn_Expecter) SetMulticastInterface(ifi interface{}) *MockPacketConn_SetMulticastInterface_Call { + return &MockPacketConn_SetMulticastInterface_Call{Call: _e.mock.On("SetMulticastInterface", ifi)} +} + +func (_c *MockPacketConn_SetMulticastInterface_Call) Run(run func(ifi *net.Interface)) *MockPacketConn_SetMulticastInterface_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 *net.Interface + if args[0] != nil { + arg0 = args[0].(*net.Interface) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockPacketConn_SetMulticastInterface_Call) Return(err error) *MockPacketConn_SetMulticastInterface_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_SetMulticastInterface_Call) RunAndReturn(run func(ifi *net.Interface) error) *MockPacketConn_SetMulticastInterface_Call { + _c.Call.Return(run) + return _c +} + +// SetMulticastTTL provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) SetMulticastTTL(ttl int) error { + ret := _mock.Called(ttl) + + if len(ret) == 0 { + panic("no return value specified for SetMulticastTTL") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(int) error); ok { + r0 = returnFunc(ttl) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_SetMulticastTTL_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMulticastTTL' +type MockPacketConn_SetMulticastTTL_Call struct { + *mock.Call +} + +// SetMulticastTTL is a helper method to define mock.On call +// - ttl int +func (_e *MockPacketConn_Expecter) SetMulticastTTL(ttl interface{}) *MockPacketConn_SetMulticastTTL_Call { + return &MockPacketConn_SetMulticastTTL_Call{Call: _e.mock.On("SetMulticastTTL", ttl)} +} + +func (_c *MockPacketConn_SetMulticastTTL_Call) Run(run func(ttl int)) *MockPacketConn_SetMulticastTTL_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 int + if args[0] != nil { + arg0 = args[0].(int) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockPacketConn_SetMulticastTTL_Call) Return(err error) *MockPacketConn_SetMulticastTTL_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_SetMulticastTTL_Call) RunAndReturn(run func(ttl int) error) *MockPacketConn_SetMulticastTTL_Call { + _c.Call.Return(run) + return _c +} + +// WriteTo provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) WriteTo(b []byte, ifIndex int, dst net.Addr) (int, error) { + ret := _mock.Called(b, ifIndex, dst) + + if len(ret) == 0 { + panic("no return value specified for WriteTo") + } + + var r0 int + var r1 error + if returnFunc, ok := ret.Get(0).(func([]byte, int, net.Addr) (int, error)); ok { + return returnFunc(b, ifIndex, dst) + } + if returnFunc, ok := ret.Get(0).(func([]byte, int, net.Addr) int); ok { + r0 = returnFunc(b, ifIndex, dst) + } else { + r0 = ret.Get(0).(int) + } + if returnFunc, ok := ret.Get(1).(func([]byte, int, net.Addr) error); ok { + r1 = returnFunc(b, ifIndex, dst) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockPacketConn_WriteTo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WriteTo' +type MockPacketConn_WriteTo_Call struct { + *mock.Call +} + +// WriteTo is a helper method to define mock.On call +// - b []byte +// - ifIndex int +// - dst net.Addr +func (_e *MockPacketConn_Expecter) WriteTo(b interface{}, ifIndex interface{}, dst interface{}) *MockPacketConn_WriteTo_Call { + return &MockPacketConn_WriteTo_Call{Call: _e.mock.On("WriteTo", b, ifIndex, dst)} +} + +func (_c *MockPacketConn_WriteTo_Call) Run(run func(b []byte, ifIndex int, dst net.Addr)) *MockPacketConn_WriteTo_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []byte + if args[0] != nil { + arg0 = args[0].([]byte) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + var arg2 net.Addr + if args[2] != nil { + arg2 = args[2].(net.Addr) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockPacketConn_WriteTo_Call) Return(n int, err error) *MockPacketConn_WriteTo_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *MockPacketConn_WriteTo_Call) RunAndReturn(run func(b []byte, ifIndex int, dst net.Addr) (int, error)) *MockPacketConn_WriteTo_Call { + _c.Call.Return(run) + return _c +} diff --git a/server.go b/server.go index 71b7bf8..a7a490f 100644 --- a/server.go +++ b/server.go @@ -6,14 +6,12 @@ import ( "math/rand" "net" "os" - "runtime" "strings" "sync" "time" + "github.com/enbility/zeroconf/v3/api" "github.com/miekg/dns" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" ) const ( @@ -24,7 +22,8 @@ const ( var defaultTTL uint32 = 3200 type serverOpts struct { - ttl uint32 + ttl uint32 + connFactory api.ConnectionFactory } func applyServerOpts(options ...ServerOption) serverOpts { @@ -50,6 +49,14 @@ func TTL(ttl uint32) ServerOption { } } +// WithServerConnFactory sets a custom connection factory for the server. +// This is primarily useful for testing with mock connections. +func WithServerConnFactory(factory api.ConnectionFactory) ServerOption { + return func(o *serverOpts) { + o.connFactory = factory + } +} + // Register a service by given arguments. This call will take the system's hostname // and lookup IP by that hostname. func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface, opts ...ServerOption) (*Server, error) { @@ -83,7 +90,7 @@ func Register(instance, service, domain string, port int, text []string, ifaces } if len(ifaces) == 0 { - ifaces = listMulticastInterfaces() + ifaces = NewInterfaceProvider().MulticastInterfaces() } for _, iface := range ifaces { @@ -149,7 +156,7 @@ func RegisterProxy(instance, service, domain string, port int, host string, ips } if len(ifaces) == 0 { - ifaces = listMulticastInterfaces() + ifaces = NewInterfaceProvider().MulticastInterfaces() } s, err := newServer(ifaces, applyServerOpts(opts...)) @@ -170,8 +177,8 @@ const ( // Server structure encapsulates both IPv4/IPv6 UDP connections type Server struct { service *ServiceEntry - ipv4conn *ipv4.PacketConn - ipv6conn *ipv6.PacketConn + ipv4conn api.PacketConn + ipv6conn api.PacketConn ifaces []net.Interface shouldShutdown chan struct{} @@ -183,11 +190,16 @@ type Server struct { // Constructs server structure func newServer(ifaces []net.Interface, opts serverOpts) (*Server, error) { - ipv4conn, err4 := joinUdp4Multicast(ifaces) + factory := opts.connFactory + if factory == nil { + factory = NewConnectionFactory() + } + + ipv4conn, err4 := factory.CreateIPv4Conn(ifaces) if err4 != nil { log.Printf("[zeroconf] no suitable IPv4 interface: %s", err4.Error()) } - ipv6conn, err6 := joinUdp6Multicast(ifaces) + ipv6conn, err6 := factory.CreateIPv6Conn(ifaces) if err6 != nil { log.Printf("[zeroconf] no suitable IPv6 interface: %s", err6.Error()) } @@ -210,11 +222,11 @@ func newServer(ifaces []net.Interface, opts serverOpts) (*Server, error) { func (s *Server) start() { if s.ipv4conn != nil { s.refCount.Add(1) - go s.recv4(s.ipv4conn) + go s.recvLoop(s.ipv4conn) } if s.ipv6conn != nil { s.refCount.Add(1) - go s.recv6(s.ipv6conn) + go s.recvLoop(s.ipv6conn) } s.refCount.Add(1) go s.probe() @@ -226,13 +238,6 @@ func (s *Server) SetText(text []string) { s.announceText() } -// TTL sets the TTL for DNS replies -// -// Deprecated: This method is racy. Use the TTL server option instead. -func (s *Server) TTL(ttl uint32) { - s.ttl = ttl -} - // Shutdown closes all udp connections and unregisters the service func (s *Server) Shutdown() { s.shutdownLock.Lock() @@ -259,33 +264,9 @@ func (s *Server) Shutdown() { s.isShutdown = true } -// recv4 is a long running routine to receive packets from an interface -func (s *Server) recv4(c *ipv4.PacketConn) { - defer s.refCount.Done() - if c == nil { - return - } - buf := make([]byte, 65536) - for { - select { - case <-s.shouldShutdown: - return - default: - var ifIndex int - n, cm, from, err := c.ReadFrom(buf) - if err != nil { - continue - } - if cm != nil { - ifIndex = cm.IfIndex - } - _ = s.parsePacket(buf[:n], ifIndex, from) - } - } -} - -// recv6 is a long running routine to receive packets from an interface -func (s *Server) recv6(c *ipv6.PacketConn) { +// recvLoop is a long running routine to receive packets from a connection. +// It uses the PacketConn interface, allowing for mock injection in tests. +func (s *Server) recvLoop(c api.PacketConn) { defer s.refCount.Done() if c == nil { return @@ -296,13 +277,15 @@ func (s *Server) recv6(c *ipv6.PacketConn) { case <-s.shouldShutdown: return default: - var ifIndex int - n, cm, from, err := c.ReadFrom(buf) + n, ifIndex, from, err := c.ReadFrom(buf) if err != nil { - continue - } - if cm != nil { - ifIndex = cm.IfIndex + // Backoff to prevent CPU spin on persistent errors + select { + case <-s.shouldShutdown: + return + case <-time.After(50 * time.Millisecond): + continue + } } _ = s.parsePacket(buf[:n], ifIndex, from) } @@ -738,24 +721,11 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro } addr := from.(*net.UDPAddr) if addr.IP.To4() != nil { - if ifIndex != 0 { - var wcm ipv4.ControlMessage - wcm.IfIndex = ifIndex - _, err = s.ipv4conn.WriteTo(buf, &wcm, addr) - } else { - _, err = s.ipv4conn.WriteTo(buf, nil, addr) - } - return err - } else { - if ifIndex != 0 { - var wcm ipv6.ControlMessage - wcm.IfIndex = ifIndex - _, err = s.ipv6conn.WriteTo(buf, &wcm, addr) - } else { - _, err = s.ipv6conn.WriteTo(buf, nil, addr) - } + _, err = s.ipv4conn.WriteTo(buf, ifIndex, addr) return err } + _, err = s.ipv6conn.WriteTo(buf, ifIndex, addr) + return err } // multicastResponse is used to send a multicast response packet @@ -764,67 +734,31 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { if err != nil { return fmt.Errorf("failed to pack msg %v: %w", msg, err) } + + // Determine which interfaces to send to + var ifaces []int + if ifIndex != 0 { + ifaces = []int{ifIndex} + } else { + for _, intf := range s.ifaces { + ifaces = append(ifaces, intf.Index) + } + } + + // Send to IPv4 multicast group if s.ipv4conn != nil { - // See https://pkg.go.dev/golang.org/x/net/ipv4#pkg-note-BUG - // As of Golang 1.18.4 - // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. - var wcm ipv4.ControlMessage - if ifIndex != 0 { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = ifIndex - default: - iface, _ := net.InterfaceByIndex(ifIndex) - if err := s.ipv4conn.SetMulticastInterface(iface); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } - } - _, _ = s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) - } else { - for _, intf := range s.ifaces { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = intf.Index - default: - if err := s.ipv4conn.SetMulticastInterface(&intf); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } - } - _, _ = s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) - } + for _, idx := range ifaces { + _, _ = s.ipv4conn.WriteTo(buf, idx, ipv4Addr) } } + // Send to IPv6 multicast group if s.ipv6conn != nil { - // See https://pkg.go.dev/golang.org/x/net/ipv6#pkg-note-BUG - // As of Golang 1.18.4 - // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. - var wcm ipv6.ControlMessage - if ifIndex != 0 { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = ifIndex - default: - iface, _ := net.InterfaceByIndex(ifIndex) - if err := s.ipv6conn.SetMulticastInterface(iface); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } - } - _, _ = s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) - } else { - for _, intf := range s.ifaces { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = intf.Index - default: - if err := s.ipv6conn.SetMulticastInterface(&intf); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } - } - _, _ = s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) - } + for _, idx := range ifaces { + _, _ = s.ipv6conn.WriteTo(buf, idx, ipv6Addr) } } + return nil } diff --git a/server_unit_test.go b/server_unit_test.go new file mode 100644 index 0000000..75d0697 --- /dev/null +++ b/server_unit_test.go @@ -0,0 +1,728 @@ +package zeroconf + +import ( + "errors" + "net" + "sync" + "testing" + "time" + + "github.com/enbility/zeroconf/v3/api" + "github.com/enbility/zeroconf/v3/mocks" + "github.com/miekg/dns" + "github.com/stretchr/testify/mock" +) + +// TestServer_Recv_BacksOffOnError verifies that recv backs off when ReadFrom returns errors +// This is the fix for the CPU spin bug. +func TestServer_Recv_BacksOffOnError(t *testing.T) { + mockConn := mocks.NewMockPacketConn(t) + + // Track call count + var callCount int + var mu sync.Mutex + + // Configure ReadFrom to always return an error + mockConn.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, int, net.Addr, error) { + mu.Lock() + callCount++ + mu.Unlock() + return 0, 0, nil, errors.New("mock read error") + }).Maybe() + + s := &Server{ + shouldShutdown: make(chan struct{}), + ttl: 3200, + } + + // recvLoop calls s.refCount.Done() on exit, so we need to Add first + s.refCount.Add(1) + + // Start recv in background + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + s.recvLoop(mockConn) + }() + + // Let it run briefly + time.Sleep(200 * time.Millisecond) + + // Shutdown + close(s.shouldShutdown) + wg.Wait() + + mu.Lock() + calls := callCount + mu.Unlock() + + // With 50ms backoff and 200ms runtime, we expect roughly 4 calls max + // Without backoff, we'd see thousands of calls + if calls > 10 { + t.Errorf("Expected few calls with backoff, got %d (suggests spinning)", calls) + } + t.Logf("ReadFrom called %d times in 200ms with backoff", calls) +} + +// TestServer_Recv_ProcessesPacket verifies that recv correctly processes incoming packets +func TestServer_Recv_ProcessesPacket(t *testing.T) { + // Create a valid DNS query packet + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + packetData, err := msg.Pack() + if err != nil { + t.Fatalf("Failed to pack DNS message: %v", err) + } + + // We can test the packet parsing directly + parsed := new(dns.Msg) + if err := parsed.Unpack(packetData); err != nil { + t.Fatalf("Failed to unpack: %v", err) + } + + if len(parsed.Question) != 1 { + t.Errorf("Expected 1 question, got %d", len(parsed.Question)) + } + if parsed.Question[0].Name != "_test._tcp.local." { + t.Errorf("Expected question name _test._tcp.local., got %s", parsed.Question[0].Name) + } +} + +// TestServer_MulticastResponse_WritesToConnections verifies multicast sends to both connections +func TestServer_MulticastResponse_WritesToConnections(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + iface := net.Interface{Index: 1, Name: "eth0"} + + // Expect WriteTo to be called on both connections + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + + s := &Server{ + ipv4conn: mockIPv4, + ipv6conn: mockIPv6, + ifaces: []net.Interface{iface}, + shouldShutdown: make(chan struct{}), + ttl: 3200, + } + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + err := s.multicastResponse(msg, 0) + if err != nil { + t.Fatalf("multicastResponse failed: %v", err) + } +} + +// TestServer_MulticastResponse_SpecificInterface verifies multicast to specific interface +func TestServer_MulticastResponse_SpecificInterface(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + // Expect WriteTo to be called with specific interface index 2 + mockIPv4.EXPECT().WriteTo(mock.Anything, 2, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 2, mock.Anything).Return(0, nil).Once() + + s := &Server{ + ipv4conn: mockIPv4, + ipv6conn: mockIPv6, + ifaces: []net.Interface{{Index: 1, Name: "eth0"}, {Index: 2, Name: "wlan0"}}, + shouldShutdown: make(chan struct{}), + ttl: 3200, + } + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + // Send to specific interface (index 2) + err := s.multicastResponse(msg, 2) + if err != nil { + t.Fatalf("multicastResponse failed: %v", err) + } +} + +// TestServer_Shutdown_ClosesConnections verifies shutdown properly closes connections +func TestServer_Shutdown_ClosesConnections(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + // Expect Close and WriteTo (for unregister) to be called + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.AnythingOfType("int"), mock.Anything).Return(0, nil).Maybe() + mockIPv6.EXPECT().WriteTo(mock.Anything, mock.AnythingOfType("int"), mock.Anything).Return(0, nil).Maybe() + mockIPv4.EXPECT().Close().Return(nil).Once() + mockIPv6.EXPECT().Close().Return(nil).Once() + + s := &Server{ + ipv4conn: mockIPv4, + ipv6conn: mockIPv6, + ifaces: []net.Interface{{Index: 1, Name: "eth0"}}, + shouldShutdown: make(chan struct{}), + ttl: 3200, + service: newServiceEntry("test", "_test._tcp", "local"), + } + s.service.Port = 8080 + s.service.HostName = "test.local." + + s.Shutdown() +} + +// TestServerConfig verifies server configuration options +func TestServerConfig(t *testing.T) { + t.Run("default TTL", func(t *testing.T) { + opts := applyServerOpts() + if opts.ttl != defaultTTL { + t.Errorf("Expected default TTL %d, got %d", defaultTTL, opts.ttl) + } + }) + + t.Run("custom TTL", func(t *testing.T) { + opts := applyServerOpts(TTL(1000)) + if opts.ttl != 1000 { + t.Errorf("Expected TTL 1000, got %d", opts.ttl) + } + }) +} + +// TestWithServerConnFactory verifies the WithServerConnFactory option +func TestWithServerConnFactory(t *testing.T) { + factory := mocks.NewMockConnectionFactory(t) + + opts := applyServerOpts(WithServerConnFactory(factory)) + + if opts.connFactory != factory { + t.Error("Expected connection factory to be set") + } +} + +// TestIsKnownAnswer verifies known-answer suppression logic +func TestIsKnownAnswer(t *testing.T) { + t.Run("empty response answers", func(t *testing.T) { + resp := &dns.Msg{} + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + if isKnownAnswer(resp, query) { + t.Error("Expected false when response has no answers") + } + }) + + t.Run("empty query answers", func(t *testing.T) { + resp := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + query := &dns.Msg{} + if isKnownAnswer(resp, query) { + t.Error("Expected false when query has no answers") + } + }) + + t.Run("non-PTR response", func(t *testing.T) { + resp := &dns.Msg{ + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Rrtype: dns.TypeA, Ttl: 100}, + A: net.ParseIP("192.168.1.1"), + }, + }, + } + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + if isKnownAnswer(resp, query) { + t.Error("Expected false for non-PTR response") + } + }) + + t.Run("matching known answer with sufficient TTL", func(t *testing.T) { + resp := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 60}, // >= 100/2 + Ptr: "test._http._tcp.local.", + }, + }, + } + if !isKnownAnswer(resp, query) { + t.Error("Expected true for matching known answer with sufficient TTL") + } + }) + + t.Run("matching known answer with insufficient TTL", func(t *testing.T) { + resp := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 40}, // < 100/2 + Ptr: "test._http._tcp.local.", + }, + }, + } + if isKnownAnswer(resp, query) { + t.Error("Expected false for known answer with insufficient TTL") + } + }) + + t.Run("non-matching PTR", func(t *testing.T) { + resp := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "other._http._tcp.local.", + }, + }, + } + if isKnownAnswer(resp, query) { + t.Error("Expected false for non-matching PTR") + } + }) +} + +// TestServer_HandleQuestion verifies question handling logic +func TestServer_HandleQuestion(t *testing.T) { + createTestServer := func() *Server { + s := &Server{ + ttl: 3200, + shouldShutdown: make(chan struct{}), + service: newServiceEntry("myservice", "_http._tcp", "local"), + } + s.service.Port = 8080 + s.service.HostName = "myhost.local." + s.service.Text = []string{"key=value"} + return s + } + + t.Run("nil service", func(t *testing.T) { + s := &Server{ + ttl: 3200, + shouldShutdown: make(chan struct{}), + service: nil, + } + resp := &dns.Msg{} + query := &dns.Msg{} + q := dns.Question{Name: "_http._tcp.local.", Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("Expected no error for nil service, got %v", err) + } + if len(resp.Answer) != 0 { + t.Error("Expected no answers for nil service") + } + }) + + t.Run("service type query", func(t *testing.T) { + s := createTestServer() + resp := &dns.Msg{} + query := &dns.Msg{} + q := dns.Question{Name: s.service.ServiceTypeName(), Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + if len(resp.Answer) == 0 { + t.Error("Expected answers for service type query") + } + }) + + t.Run("service name query", func(t *testing.T) { + s := createTestServer() + resp := &dns.Msg{} + query := &dns.Msg{} + q := dns.Question{Name: s.service.ServiceName(), Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + if len(resp.Answer) == 0 { + t.Error("Expected answers for service name query") + } + }) + + t.Run("service instance query", func(t *testing.T) { + s := createTestServer() + resp := &dns.Msg{} + query := &dns.Msg{} + q := dns.Question{Name: s.service.ServiceInstanceName(), Qtype: dns.TypeSRV} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + if len(resp.Answer) == 0 { + t.Error("Expected answers for service instance query") + } + }) + + t.Run("subtype query", func(t *testing.T) { + s := createTestServer() + s.service.Subtypes = []string{"_printer"} + resp := &dns.Msg{} + query := &dns.Msg{} + subtypeName := "_printer._sub." + s.service.ServiceName() + q := dns.Question{Name: subtypeName, Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + if len(resp.Answer) == 0 { + t.Error("Expected answers for subtype query") + } + }) + + t.Run("unknown query name", func(t *testing.T) { + s := createTestServer() + resp := &dns.Msg{} + query := &dns.Msg{} + q := dns.Question{Name: "_unknown._tcp.local.", Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + if len(resp.Answer) != 0 { + t.Error("Expected no answers for unknown query") + } + }) + + t.Run("known answer suppression", func(t *testing.T) { + s := createTestServer() + resp := &dns.Msg{} + // Query with known answer + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypePTR, + Ttl: 3200, // >= s.ttl/2 + }, + Ptr: s.service.ServiceInstanceName(), + }, + }, + } + q := dns.Question{Name: s.service.ServiceName(), Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + // Answer should be suppressed + if len(resp.Answer) != 0 { + t.Error("Expected answer to be suppressed due to known-answer") + } + }) +} + +// TestRegisterProxy_Validation tests RegisterProxy input validation +func TestRegisterProxy_Validation(t *testing.T) { + t.Run("missing instance name", func(t *testing.T) { + _, err := RegisterProxy("", "_http._tcp", "local", 8080, "myhost", []string{"192.168.1.1"}, nil, nil) + if err == nil { + t.Error("Expected error for missing instance name") + } + }) + + t.Run("missing service name", func(t *testing.T) { + _, err := RegisterProxy("myservice", "", "local", 8080, "myhost", []string{"192.168.1.1"}, nil, nil) + if err == nil { + t.Error("Expected error for missing service name") + } + }) + + t.Run("missing host name", func(t *testing.T) { + _, err := RegisterProxy("myservice", "_http._tcp", "local", 8080, "", []string{"192.168.1.1"}, nil, nil) + if err == nil { + t.Error("Expected error for missing host name") + } + }) + + t.Run("missing port", func(t *testing.T) { + _, err := RegisterProxy("myservice", "_http._tcp", "local", 0, "myhost", []string{"192.168.1.1"}, nil, nil) + if err == nil { + t.Error("Expected error for missing port") + } + }) + + t.Run("invalid IP address", func(t *testing.T) { + _, err := RegisterProxy("myservice", "_http._tcp", "local", 8080, "myhost", []string{"invalid-ip"}, nil, nil) + if err == nil { + t.Error("Expected error for invalid IP address") + } + }) +} + +// setupMockServerConnections creates mock connections for server tests +func setupMockServerConnections(t *testing.T) (*mocks.MockPacketConn, *mocks.MockPacketConn, api.ConnectionFactory) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + factory := mocks.NewMockConnectionFactory(t) + + factory.EXPECT().CreateIPv4Conn(mock.Anything).Return(mockIPv4, nil).Once() + factory.EXPECT().CreateIPv6Conn(mock.Anything).Return(mockIPv6, nil).Once() + + return mockIPv4, mockIPv6, factory +} + +// TestRegisterProxy_WithMockConnections tests RegisterProxy with mocked connections +func TestRegisterProxy_WithMockConnections(t *testing.T) { + mockIPv4, mockIPv6, factory := setupMockServerConnections(t) + + // Mock ReadFrom to block until shutdown + mockIPv4.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, int, net.Addr, error) { + time.Sleep(50 * time.Millisecond) + return 0, 0, nil, errors.New("shutdown") + }).Maybe() + mockIPv6.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, int, net.Addr, error) { + time.Sleep(50 * time.Millisecond) + return 0, 0, nil, errors.New("shutdown") + }).Maybe() + + // Mock WriteTo for probes and announcements + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + mockIPv6.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + + // Mock Close + mockIPv4.EXPECT().Close().Return(nil).Maybe() + mockIPv6.EXPECT().Close().Return(nil).Maybe() + + // Register the proxy service + server, err := RegisterProxy( + "myservice", + "_http._tcp", + "local", + 8080, + "myhost", + []string{"192.168.1.100", "fe80::1"}, + []string{"key=value"}, + []net.Interface{{Index: 1, Name: "eth0"}}, + WithServerConnFactory(factory), + ) + if err != nil { + t.Fatalf("RegisterProxy failed: %v", err) + } + defer server.Shutdown() + + // Verify service was set up correctly + if server.service.Instance != "myservice" { + t.Errorf("Expected instance 'myservice', got '%s'", server.service.Instance) + } + if server.service.Port != 8080 { + t.Errorf("Expected port 8080, got %d", server.service.Port) + } + if len(server.service.AddrIPv4) != 1 { + t.Errorf("Expected 1 IPv4 address, got %d", len(server.service.AddrIPv4)) + } + if len(server.service.AddrIPv6) != 1 { + t.Errorf("Expected 1 IPv6 address, got %d", len(server.service.AddrIPv6)) + } +} + +// TestServer_SetText tests the SetText method +func TestServer_SetText(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + // Track WriteTo calls to verify announcement was sent + var writeCount int + var mu sync.Mutex + + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + mu.Lock() + writeCount++ + mu.Unlock() + return len(b), nil + }).Maybe() + mockIPv6.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + + s := &Server{ + ipv4conn: mockIPv4, + ipv6conn: mockIPv6, + ifaces: []net.Interface{{Index: 1, Name: "eth0"}}, + shouldShutdown: make(chan struct{}), + ttl: 3200, + service: newServiceEntry("test", "_test._tcp", "local"), + } + s.service.Port = 8080 + s.service.HostName = "test.local." + s.service.Text = []string{"old=value"} + + // Update text + s.SetText([]string{"new=value"}) + + // Verify text was updated + if len(s.service.Text) != 1 || s.service.Text[0] != "new=value" { + t.Errorf("Expected text 'new=value', got %v", s.service.Text) + } + + // Verify announcement was sent (WriteTo was called) + mu.Lock() + if writeCount == 0 { + t.Error("Expected announcement to be sent after SetText") + } + mu.Unlock() +} + +// TestServer_HandleQuery_RespondsToQueries tests server responding to mDNS queries +func TestServer_HandleQuery_RespondsToQueries(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + // Capture responses + var capturedResponses [][]byte + var mu sync.Mutex + + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + mu.Lock() + responseCopy := make([]byte, len(b)) + copy(responseCopy, b) + capturedResponses = append(capturedResponses, responseCopy) + mu.Unlock() + return len(b), nil + }).Maybe() + mockIPv6.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + + s := &Server{ + ipv4conn: mockIPv4, + ipv6conn: mockIPv6, + ifaces: []net.Interface{{Index: 1, Name: "eth0"}}, + shouldShutdown: make(chan struct{}), + ttl: 3200, + service: newServiceEntry("myservice", "_http._tcp", "local"), + } + s.service.Port = 8080 + s.service.HostName = "myhost.local." + s.service.Text = []string{"key=value"} + s.service.AddrIPv4 = []net.IP{net.ParseIP("192.168.1.100")} + + // Create a query for our service + query := new(dns.Msg) + query.SetQuestion("_http._tcp.local.", dns.TypePTR) + + // Handle the query + err := s.handleQuery(query, 1, &net.UDPAddr{IP: net.ParseIP("192.168.1.50"), Port: 5353}) + if err != nil { + t.Fatalf("handleQuery failed: %v", err) + } + + // Verify response was sent + mu.Lock() + responseCount := len(capturedResponses) + mu.Unlock() + + if responseCount == 0 { + t.Error("Expected response to be sent for matching query") + } + + // Parse and verify the response + if responseCount > 0 { + mu.Lock() + respData := capturedResponses[0] + mu.Unlock() + + resp := new(dns.Msg) + if err := resp.Unpack(respData); err != nil { + t.Fatalf("Failed to unpack response: %v", err) + } + + if len(resp.Answer) == 0 { + t.Error("Expected answers in response") + } + } +} + +// TestServer_UnicastResponse tests unicast response handling +func TestServer_UnicastResponse(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + + // Capture the destination address to verify unicast + var capturedDst net.Addr + var mu sync.Mutex + + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + mu.Lock() + capturedDst = dst + mu.Unlock() + return len(b), nil + }).Once() + + s := &Server{ + ipv4conn: mockIPv4, + ipv6conn: nil, + ifaces: []net.Interface{{Index: 1, Name: "eth0"}}, + shouldShutdown: make(chan struct{}), + ttl: 3200, + service: newServiceEntry("myservice", "_http._tcp", "local"), + } + s.service.Port = 8080 + s.service.HostName = "myhost.local." + + // Send unicast response + msg := new(dns.Msg) + msg.SetQuestion("_http._tcp.local.", dns.TypePTR) + clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.50"), Port: 5353} + + err := s.unicastResponse(msg, 1, clientAddr) + if err != nil { + t.Fatalf("unicastResponse failed: %v", err) + } + + // Verify response was sent to the client's address + mu.Lock() + defer mu.Unlock() + if capturedDst == nil { + t.Error("Expected response to be sent") + } else { + udpAddr, ok := capturedDst.(*net.UDPAddr) + if !ok { + t.Error("Expected UDP address") + } else if !udpAddr.IP.Equal(net.ParseIP("192.168.1.50")) { + t.Errorf("Expected response to 192.168.1.50, got %s", udpAddr.IP) + } + } +} diff --git a/version.json b/version.json index 26f9a28..f6c42b9 100644 --- a/version.json +++ b/version.json @@ -1,3 +1,3 @@ { - "version": "v2.2.0" + "version": "v3.0.0" } From c4ec4d9484dc82199d58c179b23deb160d1fc9df Mon Sep 17 00:00:00 2001 From: Simon Thelen Date: Fri, 20 Feb 2026 11:12:31 +0100 Subject: [PATCH 3/3] Remove V3_REFACTORING_PLAN since refactor is finished --- V3_REFACTORING_PLAN.md | 248 ----------------------------------------- 1 file changed, 248 deletions(-) delete mode 100644 V3_REFACTORING_PLAN.md diff --git a/V3_REFACTORING_PLAN.md b/V3_REFACTORING_PLAN.md deleted file mode 100644 index 99884df..0000000 --- a/V3_REFACTORING_PLAN.md +++ /dev/null @@ -1,248 +0,0 @@ -# ZeroConf v3 Refactoring Plan - -## Goals - -1. **Testability**: Enable unit testing without real network access -2. **Interfaces**: Define clear abstractions at network boundaries -3. **Dependency Injection**: Allow mock injection for testing -4. **Test Coverage**: Target 85%+ coverage with meaningful unit tests -5. **Generated Mocks**: Use mockery for maintainable mocks -6. **Remove Global State**: Move package-level vars into config structs - -## Key Insight: ControlMessage Simplification - -Analysis of the codebase shows that only `IfIndex` is ever used from `ipv4.ControlMessage` and `ipv6.ControlMessage`. This allows us to create a unified `PacketConn` interface that works for both IPv4 and IPv6: - -```go -// Instead of exposing ControlMessage, we just expose ifIndex -ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) -WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) -``` - -## Architecture Overview - -``` -┌─────────────────────────────────────────────────────────────┐ -│ Public API │ -│ Browse() / Lookup() / Register() / RegisterProxy() │ -└─────────────────────────────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ Client / Server │ -│ - Use api.PacketConn interface (not concrete types) │ -│ - Accept ConnectionFactory via options │ -│ - Use InterfaceProvider internally for default interfaces │ -└─────────────────────────────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ api/ Package │ -│ PacketConn / ConnectionFactory / InterfaceProvider │ -└─────────────────────────────────────────────────────────────┘ - │ - ┌───────────────┴───────────────┐ - ▼ ▼ -┌──────────────────────────┐ ┌──────────────────────────┐ -│ Real Implementations │ │ mocks/ Package │ -│ - ipv4PacketConn │ │ - MockPacketConn │ -│ - ipv6PacketConn │ │ - MockConnectionFactory │ -│ - defaultConnFactory │ │ - MockInterfaceProvider │ -│ - defaultIfaceProvider │ │ (generated by mockery) │ -└──────────────────────────┘ └──────────────────────────┘ -``` - -## Package Structure - -``` -zeroconf/v3/ -├── api/ # Pure interfaces (no internal deps) -│ └── interfaces.go # PacketConn, ConnectionFactory, InterfaceProvider -├── mocks/ # Generated mocks (mockery) -│ ├── mock_packet_conn.go -│ ├── mock_connection_factory.go -│ └── mock_interface_provider.go -├── .mockery.yml # Mockery configuration -├── client.go # Client implementation -├── server.go # Server implementation -├── conn_ipv4.go # ipv4PacketConn wrapper -├── conn_ipv6.go # ipv6PacketConn wrapper -├── conn_factory.go # defaultConnectionFactory -├── conn_provider.go # defaultInterfaceProvider -├── mdns.go # Network constants (mDNS addresses) -├── service.go # ServiceEntry, ServiceRecord -├── utils.go # Helper functions -├── doc.go # Package documentation -├── *_test.go # Tests -└── examples/ # Example applications -``` - -## Interface Definitions (in api/interfaces.go) - -### PacketConn - -```go -type PacketConn interface { - ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) - WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) - Close() error - JoinGroup(ifi *net.Interface, group net.Addr) error - LeaveGroup(ifi *net.Interface, group net.Addr) error - SetMulticastTTL(ttl int) error - SetMulticastHopLimit(hopLimit int) error - SetMulticastInterface(ifi *net.Interface) error -} -``` - -### ConnectionFactory - -```go -type ConnectionFactory interface { - CreateIPv4Conn(ifaces []net.Interface) (PacketConn, error) - CreateIPv6Conn(ifaces []net.Interface) (PacketConn, error) -} -``` - -### InterfaceProvider - -```go -type InterfaceProvider interface { - MulticastInterfaces() []net.Interface -} -``` - -## Implementation Phases - -### Phase 1: Package Structure & Interfaces ✓ - -**Completed:** -- [x] Create `api/` package with interface definitions -- [x] Configure mockery (`.mockery.yml`) -- [x] Generate mocks in `mocks/` package -- [x] Update implementation to import `api/` -- [x] Update tests to use `mocks/` - ---- - -### Phase 2: Connection Wrappers (File Split) ✓ - -**Completed:** -- [x] `conn_ipv4.go` - ipv4PacketConn wrapper -- [x] `conn_ipv6.go` - ipv6PacketConn wrapper -- [x] `conn_factory.go` - defaultConnectionFactory -- [x] `conn_provider.go` - defaultInterfaceProvider with MulticastInterfaces() -- [x] Removed `conn_wrapper.go` (split into above files) - ---- - -### Phase 3: InterfaceProvider Implementation ✓ - -**Completed:** -- [x] Created `defaultInterfaceProvider` implementing `api.InterfaceProvider` -- [x] Moved `listMulticastInterfaces()` into `defaultInterfaceProvider.MulticastInterfaces()` -- [x] Used internally via `NewInterfaceProvider().MulticastInterfaces()` in client/server - -**Design Decision:** Removed `WithIfaceProvider` options after review - they added complexity without clear benefit since: -- `Register()` already accepts `ifaces []net.Interface` directly -- `SelectIfaces()` option exists for client -- Interface selection is simpler as a direct parameter than an injected provider - ---- - -### Phase 4: Server Improvements ✓ - -**Changes to `server.go`:** -- [x] Change `Server.ipv4conn` to use `api.PacketConn` -- [x] Change `Server.ipv6conn` to use `api.PacketConn` -- [x] Add `WithServerConnFactory()` option -- [x] Remove deprecated `Server.TTL()` method - ---- - -### Phase 5: Client Improvements ✓ - -**Changes to `client.go`:** -- [x] Change connection fields to use `api.PacketConn` -- [x] Add `WithClientConnFactory()` option -- [x] Export `Client` type (renamed `client` -> `Client`) -- [x] Add `NewClient()` constructor - ---- - -### Phase 6: Coverage & Cleanup - -1. Run coverage report, identify gaps -2. Add tests for untested functions -3. Update doc.go for v3 -4. Final integration test pass - ---- - -## File Changes Summary - -| File | Action | Description | -|------|--------|-------------| -| `api/interfaces.go` | DONE | Interface definitions | -| `mocks/*.go` | DONE | Generated mocks (mockery) | -| `.mockery.yml` | DONE | Mockery configuration | -| `conn_ipv4.go` | NEW | IPv4 PacketConn wrapper | -| `conn_ipv6.go` | NEW | IPv6 PacketConn wrapper | -| `conn_factory.go` | NEW | defaultConnectionFactory | -| `conn_provider.go` | NEW | defaultInterfaceProvider + listMulticastInterfaces | -| `conn_wrapper.go` | DELETE | Split into above files | -| `mdns.go` | RENAME | Network constants (was connection.go) | -| `server.go` | MODIFY | Add WithServerConnFactory, remove TTL() | -| `client.go` | MODIFY | Export Client, add NewClient, add WithClientConnFactory | -| `server_unit_test.go` | DONE | Unit tests with mocks | -| `client_unit_test.go` | DONE | Unit tests with mocks | - ---- - -## Breaking Changes - -1. **Module path**: `github.com/enbility/zeroconf/v3` -2. **Exported `Client` type**: New public API -3. **`NewClient()` function**: New constructor -4. **Removed**: `Server.TTL()` method (was deprecated) - -## Backward Compatibility - -Main API functions remain compatible: -- `Browse(ctx, service, domain, entries, removed, opts...)` - unchanged -- `Lookup(ctx, instance, service, domain, entries, opts...)` - unchanged -- `Register(instance, service, domain, port, text, ifaces, opts...)` - unchanged -- `RegisterProxy(...)` - unchanged - -New optional features via options: -- `WithClientConnFactory(factory)` / `WithServerConnFactory(factory)` - for injecting mock connections in tests - ---- - -## Mock Generation - -Using mockery v3. Configuration in `.mockery.yml`: - -```yaml -packages: - github.com/enbility/zeroconf/v3/api: - interfaces: - PacketConn: - ConnectionFactory: - InterfaceProvider: -``` - -Regenerate mocks: -```bash -mockery -``` - ---- - -## Success Criteria - -- [x] All existing tests pass -- [ ] Test coverage > 85% (currently 72.7%) -- [x] All network I/O behind interfaces -- [x] Unit tests run without network access -- [x] Mocks generated automatically -- [ ] Documentation updated