diff --git a/cmd/pinecone/main.go b/cmd/pinecone/main.go index 182691b9..c426618c 100644 --- a/cmd/pinecone/main.go +++ b/cmd/pinecone/main.go @@ -47,12 +47,9 @@ func main() { } dialer := net.Dialer{ - Timeout: time.Second * 5, - KeepAlive: time.Second * 2, - } - listener := net.ListenConfig{ - KeepAlive: time.Second * 2, + Timeout: time.Second * 5, } + listener := net.ListenConfig{} pineconeRouter := router.NewRouter(logger, "router", sk, pk, nil) _ = sessions.NewSessions(logger, pineconeRouter) diff --git a/cmd/pineconeip/main.go b/cmd/pineconeip/main.go index f807fe74..0ca09b9a 100644 --- a/cmd/pineconeip/main.go +++ b/cmd/pineconeip/main.go @@ -21,6 +21,10 @@ import ( "log" "net" "os" + "os/signal" + "runtime/pprof" + "syscall" + "time" "net/http" _ "net/http/pprof" @@ -52,9 +56,15 @@ func main() { logger := log.New(os.Stdout, "", 0) if hostPort := os.Getenv("PPROFLISTEN"); hostPort != "" { - logger.Println("Starting pprof on", hostPort) go func() { - _ = http.ListenAndServe(hostPort, nil) + listener, err := net.Listen("tcp", hostPort) + if err != nil { + panic(err) + } + logger.Println("Starting pprof on", listener.Addr()) + if err := http.Serve(listener, nil); err != nil { + panic(err) + } }() } @@ -72,14 +82,6 @@ func main() { if err := conn.SetNoDelay(true); err != nil { return fmt.Errorf("conn.SetNoDelay: %w", err) } - /* - if err := conn.SetKeepAlive(true); err != nil { - return fmt.Errorf("conn.SetKeepAlive: %w", err) - } - if err := conn.SetKeepAlivePeriod(time.Second); err != nil { - return fmt.Errorf("conn.SetKeepAlivePeriod: %w", err) - } - */ if err := conn.SetLinger(0); err != nil { return fmt.Errorf("conn.SetLinger: %w", err) } @@ -133,5 +135,27 @@ func main() { } }() - select {} + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGUSR1) + for { + switch <-sigs { + case syscall.SIGUSR1: + fn := fmt.Sprintf("/tmp/profile.%d", os.Getpid()) + logger.Println("Requested profile:", fn) + fp, err := os.Create(fn) + if err != nil { + logger.Println("Failed to create profile:", err) + return + } + defer fp.Close() + if err := pprof.StartCPUProfile(fp); err != nil { + logger.Println("Failed to start profiling:", err) + return + } + time.AfterFunc(time.Second*10, func() { + pprof.StopCPUProfile() + logger.Println("Profile written:", fn) + }) + } + } } diff --git a/cmd/pineconesim/page.html b/cmd/pineconesim/page.html index 2483df31..8f878cb9 100644 --- a/cmd/pineconesim/page.html +++ b/cmd/pineconesim/page.html @@ -125,7 +125,7 @@

Node Summary

{{range .Nodes}} {{if not .IsExternal}} - {{.Name}} + {{.Name}} {{.Port}} {{.Coords}} {{.Root}} @@ -208,7 +208,7 @@

Peers

{{range .NodeInfo.Peers}} - {{.Name}} + {{.Name}} {{.PublicKey}} {{.Port}} {{.RootPublicKey}} diff --git a/cmd/pineconesim/simulator/interface.go b/cmd/pineconesim/simulator/interface.go index 40cfbd31..1de50601 100644 --- a/cmd/pineconesim/simulator/interface.go +++ b/cmd/pineconesim/simulator/interface.go @@ -23,6 +23,8 @@ import ( ) func (sim *Simulator) LookupCoords(target string) (types.SwitchPorts, error) { + sim.nodesMutex.RLock() + defer sim.nodesMutex.RUnlock() node, ok := sim.nodes[target] if !ok { return nil, fmt.Errorf("node %q not known", target) @@ -31,6 +33,8 @@ func (sim *Simulator) LookupCoords(target string) (types.SwitchPorts, error) { } func (sim *Simulator) LookupNodeID(target types.SwitchPorts) (string, error) { + sim.nodesMutex.RLock() + defer sim.nodesMutex.RUnlock() for id, n := range sim.nodes { if n.Coords().EqualTo(target) { return id, nil @@ -40,6 +44,8 @@ func (sim *Simulator) LookupNodeID(target types.SwitchPorts) (string, error) { } func (sim *Simulator) LookupPublicKey(target types.PublicKey) (string, error) { + sim.nodesMutex.RLock() + defer sim.nodesMutex.RUnlock() for id, n := range sim.nodes { if n.PublicKey().EqualTo(target) { return id, nil diff --git a/cmd/pineconesim/simulator/nodes.go b/cmd/pineconesim/simulator/nodes.go index d46a5057..e3f54819 100644 --- a/cmd/pineconesim/simulator/nodes.go +++ b/cmd/pineconesim/simulator/nodes.go @@ -67,6 +67,9 @@ func (sim *Simulator) CreateNode(t string) error { if err := c.SetNoDelay(true); err != nil { panic(err) } + if err := c.SetLinger(0); err != nil { + panic(err) + } if _, err = n.AuthenticatedConnect(c, "sim", router.PeerTypeRemote); err != nil { continue } diff --git a/multicast/multicast.go b/multicast/multicast.go index a0e6e834..3d291a1a 100644 --- a/multicast/multicast.go +++ b/multicast/multicast.go @@ -65,15 +65,14 @@ func NewMulticast( } m.tcpLC = net.ListenConfig{ Control: m.tcpOptions, - KeepAlive: time.Second, + KeepAlive: time.Second * 3, } m.udpLC = net.ListenConfig{ Control: m.udpOptions, } m.dialer = net.Dialer{ - Control: m.tcpOptions, - Timeout: time.Second * 5, - KeepAlive: time.Second, + Control: m.tcpOptions, + // Timeout: time.Second * 5, } return m } @@ -163,6 +162,7 @@ func (m *Multicast) accept(listener net.Listener) { if _, err := m.r.AuthenticatedConnect(conn, tcpaddr.Zone, router.PeerTypeMulticast); err != nil { //m.log.Println("m.s.AuthenticatedConnect:", err) + _ = conn.Close() continue } } @@ -309,6 +309,7 @@ func (m *Multicast) listen(intf *multicastInterface, conn net.PacketConn, srcadd if _, err := m.r.AuthenticatedConnect(peer, udpaddr.Zone, router.PeerTypeMulticast); err != nil { m.log.Println("m.s.AuthenticatedConnect:", err) + _ = peer.Close() continue } } diff --git a/multicast/platform_darwin.go b/multicast/platform_darwin.go index 138e2b09..2ee8f07e 100644 --- a/multicast/platform_darwin.go +++ b/multicast/platform_darwin.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build darwin // +build darwin package multicast diff --git a/router/dht.go b/router/dht.go index 0f46edd7..6ecd8eff 100644 --- a/router/dht.go +++ b/router/dht.go @@ -31,7 +31,7 @@ import ( type dhtEntry interface { PublicKey() types.PublicKey Coordinates() types.SwitchPorts - SeenRecently() bool + Alive() bool } type dht struct { @@ -59,7 +59,7 @@ func (d *dht) table() []dhtEntry { results := make([]dhtEntry, 0, len(d.sorted)) for _, n := range d.sorted { - if n.SeenRecently() { + if n.Alive() { results = append(results, n) } } diff --git a/router/nexthop.go b/router/nexthop.go index 62de91c2..80b199ad 100644 --- a/router/nexthop.go +++ b/router/nexthop.go @@ -31,7 +31,9 @@ func (p *Peer) getNextHops(frame *types.Frame, from types.SwitchPortID) types.Sw switch frame.Type { case types.TypeSTP: if from != 0 { - p.r.handleAnnouncement(p, frame) + if err := p.r.handleAnnouncement(p, frame); err != nil { + p.r.log.Println("Failed to handle announcement:", err) + } } case types.TypeVirtualSnakeBootstrap: diff --git a/router/nexthop_greedy.go b/router/nexthop_greedy.go index a38959a6..6d46eb39 100644 --- a/router/nexthop_greedy.go +++ b/router/nexthop_greedy.go @@ -58,12 +58,14 @@ func (r *Router) getGreedyRoutedNextHop(from *Peer, rx *types.Frame) types.Switc peerCoords := p.Coordinates() peerDist := int64(peerCoords.DistanceTo(rx.Destination)) switch { - case peerDist == 0: - return []types.SwitchPortID{p.port} - case rx.Destination.EqualTo(peerCoords): + case peerDist == 0 || rx.Destination.EqualTo(peerCoords): + // The peer is the actual destination. return []types.SwitchPortID{p.port} + case peerDist < bestDist: + // The peer is closer to the destination. bestPeer, bestDist = p.port, peerDist + default: } } diff --git a/router/nexthop_source.go b/router/nexthop_source.go index 140f5f91..a652c4f1 100644 --- a/router/nexthop_source.go +++ b/router/nexthop_source.go @@ -51,7 +51,7 @@ func (r *Router) getSourceRoutedNextHop(from *Peer, rx *types.Frame) types.Switc return nil } - if peer := r.ports[to]; !peer.started.Load() || !peer.alive.Load() { + if peer := r.ports[to]; !peer.started.Load() || !peer.Alive() { // Don't try to send packets to a port that has nothing // connected to it or isn't alive. return nil diff --git a/router/nexthop_virtualsnake.go b/router/nexthop_virtualsnake.go index 361468de..a021bcbe 100644 --- a/router/nexthop_virtualsnake.go +++ b/router/nexthop_virtualsnake.go @@ -26,7 +26,7 @@ import ( "github.com/matrix-org/pinecone/util" ) -const virtualSnakeNeighExpiryPeriod = time.Minute * 30 +const virtualSnakeNeighExpiryPeriod = time.Hour type virtualSnake struct { r *Router @@ -36,6 +36,7 @@ type virtualSnake struct { ascendingMutex sync.RWMutex _descending *virtualSnakeNeighbour descendingMutex sync.RWMutex + bootstrap *time.Timer } type virtualSnakeIndex struct { @@ -72,6 +73,8 @@ func newVirtualSnake(r *Router) *virtualSnake { r: r, table: make(virtualSnakeTable), } + snake.bootstrap = time.AfterFunc(time.Second, snake.bootstrapNow) + snake.bootstrap.Stop() go snake.maintain() return snake } @@ -105,37 +108,34 @@ func (t *virtualSnake) setDescending(desc *virtualSnakeNeighbour) { // bootstraps and setup messages as needed. func (t *virtualSnake) maintain() { for { - peerCount := t.r.PeerCount(-1) - bootstrapNow := false select { case <-t.r.context.Done(): return case <-time.After(time.Second): } - if peerCount == 0 { - // If there are no peers connected then we don't need - // to do any hard maintenance work. - continue - } + + rootKey := t.r.RootPublicKey() + canBootstrap := rootKey != t.r.public && t.r.PeerCount(-1) > 0 + willBootstrap := false if asc := t.ascending(); asc != nil { switch { case time.Since(asc.LastSeen) > virtualSnakeNeighExpiryPeriod: t.teardownPath(t.r.public, asc.PathID, asc.Port, true, fmt.Errorf("ascending neighbour expired")) - bootstrapNow = true - case asc.RootPublicKey != t.r.RootPublicKey(): + willBootstrap = canBootstrap + case asc.RootPublicKey != rootKey: t.teardownPath(t.r.public, asc.PathID, asc.Port, true, fmt.Errorf("ascending root changed")) - bootstrapNow = true + willBootstrap = canBootstrap } } else { - bootstrapNow = true + willBootstrap = canBootstrap } if desc := t.descending(); desc != nil { switch { case time.Since(desc.LastSeen) > virtualSnakeNeighExpiryPeriod: t.teardownPath(desc.PublicKey, desc.PathID, desc.Port, false, fmt.Errorf("descending neighbour expired")) - case !desc.RootPublicKey.EqualTo(t.r.RootPublicKey()): + case !desc.RootPublicKey.EqualTo(rootKey): t.teardownPath(desc.PublicKey, desc.PathID, desc.Port, false, fmt.Errorf("descending root changed")) } } @@ -145,14 +145,18 @@ func (t *virtualSnake) maintain() { // predefined interval, but for now we'll continue to send // them on a regular interval until we can derive some better // connection state. - if bootstrapNow { - t.bootstrapNow() + if willBootstrap { + t.bootstrapIn(time.Second) } } } +func (t *virtualSnake) bootstrapIn(d time.Duration) { + t.bootstrap.Reset(d) +} + func (t *virtualSnake) bootstrapNow() { - if t.r.IsRoot() { + if t.r.IsRoot() || t.r.PeerCount(-1) == 0 { return } ts, err := util.SignedTimestamp(t.r.private) @@ -180,7 +184,7 @@ func (t *virtualSnake) bootstrapNow() { func (t *virtualSnake) rootNodeChanged(root types.PublicKey) { if asc := t.ascending(); asc != nil && !asc.RootPublicKey.EqualTo(root) { t.teardownPath(t.r.public, asc.PathID, asc.Port, true, fmt.Errorf("root changed and asc no longer matches")) - defer t.bootstrapNow() + defer t.bootstrapIn(time.Second) } if desc := t.descending(); desc != nil && !desc.RootPublicKey.EqualTo(root) { t.teardownPath(desc.PublicKey, desc.PathID, desc.Port, false, fmt.Errorf("root changed and desc no longer matches")) @@ -203,10 +207,6 @@ func (t *virtualSnake) rootNodeChanged(root types.PublicKey) { } } -func (t *virtualSnake) portWasConnected(port types.SwitchPortID) { - // time.AfterFunc(time.Second, t.bootstrapNow) -} - // portWasDisconnected is called by the router when a peer disconnects // allowing us to clean up the virtual snake state. func (t *virtualSnake) portWasDisconnected(port types.SwitchPortID) { @@ -226,6 +226,7 @@ func (t *virtualSnake) portWasDisconnected(port types.SwitchPortID) { // this port change. if asc := t.ascending(); asc != nil && asc.Port == port { t.teardownPath(t.r.public, asc.PathID, asc.Port, true, fmt.Errorf("port teardown")) + defer t.bootstrapIn(0) } if desc := t.descending(); desc != nil && desc.Port == port { t.teardownPath(desc.PublicKey, desc.PathID, desc.Port, false, fmt.Errorf("port teardown")) @@ -313,8 +314,11 @@ func (t *virtualSnake) getVirtualSnakeNextHop(from *Peer, destKey types.PublicKe newCheckedCandidate(ancestor, parentPort) } + // The next section needs us to check direct peers + activePorts := t.r.activePorts() + // Check our direct peers ancestors - for _, peer := range t.r.activePorts() { + for _, peer := range activePorts { peerAnn := peer.lastAnnouncement() if peerAnn == nil { continue @@ -325,8 +329,20 @@ func (t *virtualSnake) getVirtualSnakeNextHop(from *Peer, destKey types.PublicKe } } + // Check our DHT entries + t.tableMutex.RLock() + for dhtKey, entry := range t.table { + switch { + case !entry.Valid(): + continue + default: + newCheckedCandidate(dhtKey.PublicKey, entry.SourcePort) + } + } + t.tableMutex.RUnlock() + // Check our direct peers - for _, peer := range t.r.activePorts() { + for _, peer := range activePorts { peerKey := peer.PublicKey() switch { case bestKey.EqualTo(peerKey): @@ -335,20 +351,11 @@ func (t *virtualSnake) getVirtualSnakeNextHop(from *Peer, destKey types.PublicKe // are directly peered with that node, so use the more direct // path instead newCandidate(peerKey, peer.port) + default: + newCheckedCandidate(peerKey, peer.port) } } - // Check our DHT entries - t.tableMutex.RLock() - for dhtKey, entry := range t.table { - switch { - case !entry.Valid(): - continue - } - newCheckedCandidate(dhtKey.PublicKey, entry.SourcePort) - } - t.tableMutex.RUnlock() - if bootstrap { return types.SwitchPorts{bestPort} } else { @@ -372,7 +379,7 @@ func (t *virtualSnake) getVirtualSnakeTeardownNextHop(from *Peer, rx *types.Fram } if asc := t.ascending(); asc != nil && t.r.public.EqualTo(rx.DestinationKey) && asc.PathID == teardown.PathID { t.setAscending(nil) - defer time.AfterFunc(time.Millisecond*500, t.bootstrapNow) + defer t.bootstrapIn(time.Second / 4) } t.tableMutex.Lock() defer t.tableMutex.Unlock() @@ -413,10 +420,9 @@ func (t *virtualSnake) teardownPath(pk types.PublicKey, pathID types.VirtualSnak Payload: payload[:], } _ = t.getVirtualSnakeTeardownNextHop(t.r.ports[0], &frame) - if !t.r.ports[via].started.Load() { - return + if t.r.ports[via].started.Load() { + t.r.ports[via].protoOut.push(frame.Borrow()) } - t.r.ports[via].protoOut.push(frame.Borrow()) } // handleBootstrap is called in response to an incoming bootstrap @@ -437,7 +443,7 @@ func (t *virtualSnake) handleBootstrap(from *Peer, rx *types.Frame) error { return fmt.Errorf("util.VerifySignedTimestamp") } if !bootstrap.RootPublicKey.EqualTo(t.r.RootPublicKey()) { - return fmt.Errorf("root key doesn't match") + return fmt.Errorf("bootstrap root key doesn't match") } bootstrapACK := types.VirtualSnakeBootstrapACK{ // nolint:gosimple PathID: bootstrap.PathID, @@ -474,7 +480,7 @@ func (t *virtualSnake) handleBootstrapACK(from *Peer, rx *types.Frame) error { return fmt.Errorf("util.VerifySignedTimestamp") } if !bootstrapACK.RootPublicKey.EqualTo(t.r.RootPublicKey()) { - return fmt.Errorf("root key doesn't match") + return fmt.Errorf("bootstrap ACK root key doesn't match") } update := false asc := t.ascending() @@ -579,12 +585,12 @@ func (t *virtualSnake) handleSetup(from *Peer, rx *types.Frame, nextHops types.S } if !setup.RootPublicKey.EqualTo(t.r.RootPublicKey()) { t.teardownPath(rx.SourceKey, setup.PathID, from.port, false, fmt.Errorf("rejecting setup (root key doesn't match)")) - return fmt.Errorf("root key doesn't match") + return fmt.Errorf("setup root key doesn't match") } // Did the setup hit a dead end on the way to the ascending node? if nextHops.EqualTo(types.SwitchPorts{0}) || nextHops.EqualTo(types.SwitchPorts{}) { - if !rx.DestinationKey.EqualTo(t.r.public) || !rx.Destination.EqualTo(t.r.Coords()) { + if !rx.DestinationKey.EqualTo(t.r.public) { t.teardownPath(rx.SourceKey, setup.PathID, from.port, false, fmt.Errorf("rejecting setup (hit dead end)")) return fmt.Errorf("setup for %q (%s) en route to %q %s hit dead end at %s", rx.SourceKey, hex.EncodeToString(setup.PathID[:]), rx.DestinationKey, rx.Destination, t.r.Coords()) } diff --git a/router/peer.go b/router/peer.go index 1c552bfc..602fb5f4 100644 --- a/router/peer.go +++ b/router/peer.go @@ -18,7 +18,7 @@ import ( "bytes" "context" "encoding/binary" - "errors" + "encoding/hex" "fmt" "io" "net" @@ -29,7 +29,7 @@ import ( "go.uber.org/atomic" ) -const PeerKeepaliveInterval = time.Second * 2 +const PeerKeepaliveInterval = time.Second * 3 const PeerKeepaliveTimeout = PeerKeepaliveInterval * 3 const ( @@ -41,8 +41,9 @@ const ( type Peer struct { r *Router // port types.SwitchPortID // + allocated atomic.Bool // started atomic.Bool // worker goroutines started? - alive atomic.Bool // have we received a handshake? + wg *sync.WaitGroup // wait group for worker goroutines mutex sync.RWMutex // protects everything below this line zone string // peertype int // @@ -50,14 +51,110 @@ type Peer struct { cancel context.CancelFunc // conn net.Conn // underlying connection to peer public types.PublicKey // - trafficOut queue // queue traffic message to peer - protoOut queue // queue protocol message to peer + trafficOut *lifoQueue // queue traffic message to peer + protoOut *fifoQueue // queue protocol message to peer coords types.SwitchPorts // - announce chan struct{} // - announcement *rootAnnouncementWithTime // + announcement *rootAnnouncementWithTime // last received announcement from peer statistics peerStatistics // } +func (p *Peer) start() { + if !p.started.CAS(false, true) { + return + } + + // When the peer dies, we need to clean up. + var lasterr error + var lasterrMutex sync.Mutex + defer func() { + p.reset() + p.r.snake.portWasDisconnected(p.port) + p.r.tree.portWasDisconnected(p.port) + go p.r.callbacks.onDisconnected(p.port, p.public, p.peertype, lasterr) + }() + + // Store the fact that we're connected to this public key in + // this zone, so that the multicast code can ignore nodes we + // are already connected to. + index := hex.EncodeToString(p.public[:]) + p.zone + p.r.active.Store(index, p.port) + defer p.r.active.Delete(index) + + // Push a root update to our new peer. This will notify them + // of our coordinates and that we are alive. + p.protoOut.push(p.r.tree.Root().ForPeer(p)) + + // Start the reader and writer goroutines for this peer. + p.wg.Add(2) + go func() { + if err := p.reader(p.context); err != nil && err != context.Canceled { + lasterrMutex.Lock() + lasterr = fmt.Errorf("reader error: %w", err) + lasterrMutex.Unlock() + } + }() + go func() { + if err := p.writer(p.context); err != nil && err != context.Canceled { + lasterrMutex.Lock() + lasterr = fmt.Errorf("writer error: %w", err) + lasterrMutex.Unlock() + } + }() + + // Report the new connection. + if p.port != 0 { + p.r.dht.insertNode(p) + } + if p.r.simulator != nil { + p.r.simulator.ReportNewLink(p.conn, p.r.public, p.public) + } + go p.r.callbacks.onConnected(p.port, p.public, p.peertype) + + p.r.log.Printf("Connected port %d to %s (zone %q)\n", p.port, p.conn.RemoteAddr(), p.zone) + + // Wait for the cancellation, and then for the goroutines to stop. + <-p.context.Done() + p.started.Store(false) + p.wg.Wait() + + // Report the disconnection. + if p.port != 0 { + p.r.dht.deleteNode(p.public) + } + if p.r.simulator != nil { + p.r.simulator.ReportDeadLink(p.r.public, p.public) + } + + // Make sure the connection is closed. + _ = p.conn.Close() + + // ... and finally, yell about it. + lasterrMutex.Lock() + if lasterr != nil { + p.r.log.Printf("Disconnected port %d: %s\n", p.port, lasterr) + } else { + p.r.log.Printf("Disconnected port %d\n", p.port) + } + lasterrMutex.Unlock() +} + +func (p *Peer) reset() { + p.mutex.Lock() + defer p.mutex.Unlock() + p.allocated.Store(false) + p.started.Store(false) + p.zone = "" + p.peertype = 0 + p.context, p.cancel = nil, nil + p.conn = nil + p.public = types.PublicKey{} + p.trafficOut.reset() + p.protoOut.reset() + p.statistics.reset() + p.coords = nil + p.announcement = nil +} + type peerStatistics struct { txProtoSuccessful atomic.Uint64 txProtoDropped atomic.Uint64 @@ -76,6 +173,16 @@ func (s *peerStatistics) reset() { s.rxTraffic.Store(0) } +func (p *Peer) IsParent() bool { + return p.r.tree.Parent() == p.port +} + +func (p *Peer) Alive() bool { + p.mutex.RLock() + defer p.mutex.RUnlock() + return p.announcement != nil && time.Since(p.announcement.at) < announcementTimeout +} + func (p *Peer) PublicKey() types.PublicKey { p.mutex.RLock() defer p.mutex.RUnlock() @@ -89,7 +196,7 @@ func (p *Peer) Coordinates() types.SwitchPorts { } func (p *Peer) SeenCommonRootRecently() bool { - if !p.alive.Load() { + if !p.Alive() { return false } last := p.lastAnnouncement() @@ -101,31 +208,21 @@ func (p *Peer) SeenCommonRootRecently() bool { return lpk == rpk } -func (p *Peer) SeenRecently() bool { - if last := p.lastAnnouncement(); last != nil { - return true - } - return false -} - func (p *Peer) updateAnnouncement(new *types.SwitchAnnouncement) error { p.mutex.Lock() defer p.mutex.Unlock() - coords, err := new.PeerCoords(p.public) - if err != nil { - p.alive.Store(false) - p.announcement = nil - p.coords = nil - return fmt.Errorf("new.PeerCoords: %w", err) - } - if p.alive.CAS(false, true) { - p.r.snake.portWasConnected(p.port) + if p.announcement != nil { + if new.RootPublicKey == p.announcement.RootPublicKey && new.Sequence < p.announcement.Sequence { + p.announcement = nil + p.coords = nil + return fmt.Errorf("root announcement replays sequence number") + } } p.announcement = &rootAnnouncementWithTime{ SwitchAnnouncement: *new, at: time.Now(), } - p.coords = coords + p.coords = new.PeerCoords() return nil } @@ -141,96 +238,50 @@ func (p *Peer) lastAnnouncement() *rootAnnouncementWithTime { return p.announcement } -func (p *Peer) start() error { - if !p.started.CAS(false, true) { - return errors.New("switch peer is already started") - } - p.alive.Store(false) - go p.reader(p.context) - go p.writer(p.context) - return nil -} - -func (p *Peer) stop() error { - if !p.started.CAS(true, false) { - return errors.New("switch peer is already stopped") - } - p.alive.Store(false) +func (p *Peer) stop() { + p.started.Store(false) + p.mutex.Lock() p.cancel() - _ = p.conn.Close() - return nil + p.mutex.Unlock() } -/* func (p *Peer) generateKeepalive() *types.Frame { frame := types.GetFrame() frame.Version = types.Version0 frame.Type = types.TypeKeepalive return frame } -*/ -func (p *Peer) generateAnnouncement() *types.Frame { - if p.port == 0 { - return nil - } - announcement := p.r.tree.Root() - for _, sig := range announcement.Signatures { - if p.r.public.EqualTo(sig.PublicKey) { - // For some reason the announcement that we want to send already - // includes our signature. This shouldn't really happen but if we - // did send it, other nodes would end up ignoring the announcement - // anyway since it would appear to be a routing loop. - return nil - } - } - // Sign the announcement. - if err := announcement.Sign(p.r.private[:], p.port); err != nil { - p.r.log.Println("Failed to sign switch announcement:", err) - return nil - } - var payload [MaxPayloadSize]byte - n, err := announcement.MarshalBinary(payload[:]) - if err != nil { - p.r.log.Println("Failed to marshal switch announcement:", err) - return nil - } - frame := types.GetFrame() - frame.Version = types.Version0 - frame.Type = types.TypeSTP - frame.Destination = types.SwitchPorts{} - frame.Payload = payload[:n] - return frame -} +func (p *Peer) reader(ctx context.Context) error { + defer p.wg.Done() + defer p.cancel() -func (p *Peer) reader(ctx context.Context) { buf := make([]byte, MaxFrameSize) for { select { case <-ctx.Done(): // The switch peer is shutting down. - return + return context.Canceled default: if p.port != 0 { if err := p.conn.SetReadDeadline(time.Now().Add(PeerKeepaliveTimeout)); err != nil { - _ = p.r.Disconnect(p.port, fmt.Errorf("conn.SetReadDeadline: %w", err)) - return + return fmt.Errorf("conn.SetReadDeadline: %w", err) } } if _, err := io.ReadFull(p.conn, buf[:8]); err != nil { - _ = p.r.Disconnect(p.port, fmt.Errorf("p.conn.Peek: %w", err)) - return + return fmt.Errorf("p.conn.Peek: %w", err) + } + if err := p.conn.SetReadDeadline(time.Time{}); err != nil { + return fmt.Errorf("conn.SetReadDeadline: %w", err) } if !bytes.Equal(buf[:4], types.FrameMagicBytes) { - _ = p.r.Disconnect(p.port, fmt.Errorf("missing magic bytes")) - return + return fmt.Errorf("missing magic bytes") } expecting := int(binary.BigEndian.Uint16(buf[6:8])) n, err := io.ReadFull(p.conn, buf[8:expecting]) if err != nil { - _ = p.r.Disconnect(p.port, fmt.Errorf("io.ReadFull: %w", err)) - return + return fmt.Errorf("io.ReadFull: %w", err) } if n < expecting-8 { p.r.log.Println("Expecting", expecting, "bytes but got", n, "bytes") @@ -238,16 +289,13 @@ func (p *Peer) reader(ctx context.Context) { } frame := types.GetFrame() if _, err := frame.UnmarshalBinary(buf[:n+8]); err != nil { - p.r.log.Println("Port", p.port, "error unmarshalling frame:", err) - frame.Done() - return - } - if frame.Version != types.Version0 { - p.r.log.Println("Port", p.port, "incorrect version in frame") frame.Done() - return + return fmt.Errorf("frame.UnmarshalBinary: %w", err) } - if frame.Type == types.TypeKeepalive { + switch { + case frame.Version != types.Version0: + fallthrough + case frame.Type == types.TypeKeepalive: frame.Done() continue } @@ -257,7 +305,7 @@ func (p *Peer) reader(ctx context.Context) { for _, port := range p.getNextHops(frame, p.port) { // Ignore ports that are not good candidates. dest := p.r.ports[port] - if !dest.started.Load() || (dest.port != 0 && !dest.alive.Load()) { + if !dest.started.Load() || (dest.port != 0 && !dest.Alive()) { continue } if p.port != 0 && dest.port != 0 { @@ -310,85 +358,64 @@ var bufPool = sync.Pool{ }, } -func (p *Peer) writer(ctx context.Context) { - //tick := time.NewTicker(PeerKeepaliveInterval) - //defer tick.Stop() - send := func(frame *types.Frame) { +func (p *Peer) writer(ctx context.Context) error { + defer p.wg.Done() + defer p.cancel() + + tick := time.NewTicker(PeerKeepaliveInterval) + defer tick.Stop() + + send := func(frame *types.Frame) error { if frame == nil { - return + return nil } buf := bufPool.Get().([]byte) defer bufPool.Put(buf) // nolint:staticcheck fn, err := frame.MarshalBinary(buf) frame.Done() if err != nil { - p.r.log.Println("Port", p.port, "error marshalling frame:", err) - return + return nil } if !bytes.Equal(buf[:4], types.FrameMagicBytes) { - panic("expected magic bytes") + return nil } - remaining := buf[:fn] - for len(remaining) > 0 { - n, err := p.conn.Write(remaining) - if err != nil { - _ = p.r.Disconnect(p.port, fmt.Errorf("p.conn.Write: %w", err)) - return - } - remaining = remaining[n:] + if err := p.conn.SetWriteDeadline(time.Now().Add(PeerKeepaliveTimeout)); err != nil { + return fmt.Errorf("p.conn.SetWriteDeadline: %w", err) } + if _, err = p.conn.Write(buf[:fn]); err != nil { + return fmt.Errorf("p.conn.Write: %w", err) + } + if err := p.conn.SetWriteDeadline(time.Time{}); err != nil { + return fmt.Errorf("p.conn.SetWriteDeadline: %w", err) + } + return nil } - // The very first thing we send should be a tree announcement, - // so that the remote side can work out our coords and consider - // us to be "alive". - send(p.generateAnnouncement()) - for { select { case <-ctx.Done(): - return - case <-p.announce: - send(p.generateAnnouncement()) + return context.Canceled + case frame := <-p.protoOut.pop(): + if err := send(frame); err != nil { + return fmt.Errorf("send: %w", err) + } + p.protoOut.ack() p.statistics.txProtoSuccessful.Inc() continue default: } select { case <-ctx.Done(): - return - case <-p.announce: - send(p.generateAnnouncement()) - p.statistics.txProtoSuccessful.Inc() - continue - case <-p.protoOut.wait(): - if frame, ok := p.protoOut.pop(); ok { - send(frame) - p.statistics.txProtoSuccessful.Inc() - } else { - p.statistics.txProtoDropped.Inc() + return context.Canceled + case frame := <-p.protoOut.pop(): + if err := send(frame); err != nil { + return fmt.Errorf("send: %w", err) } - continue - default: - } - select { - case <-ctx.Done(): - return - case <-p.announce: - send(p.generateAnnouncement()) + p.protoOut.ack() p.statistics.txProtoSuccessful.Inc() continue - case <-p.protoOut.wait(): - if frame, ok := p.protoOut.pop(); ok { - send(frame) - p.statistics.txProtoSuccessful.Inc() - } else { - p.statistics.txProtoDropped.Inc() - } - continue case <-p.trafficOut.wait(): - if frame, ok := p.trafficOut.pop(); ok { - send(frame) + if frame, ok := p.trafficOut.pop(); ok && send(frame) == nil { p.statistics.txTrafficSuccessful.Inc() } else { p.statistics.txTrafficDropped.Inc() @@ -398,31 +425,27 @@ func (p *Peer) writer(ctx context.Context) { } select { case <-ctx.Done(): - return - case <-p.announce: - send(p.generateAnnouncement()) - p.statistics.txProtoSuccessful.Inc() - continue - case <-p.protoOut.wait(): - if frame, ok := p.protoOut.pop(); ok { - send(frame) - p.statistics.txProtoSuccessful.Inc() - } else { - p.statistics.txProtoDropped.Inc() + return context.Canceled + case frame := <-p.protoOut.pop(): + if err := send(frame); err != nil { + return fmt.Errorf("send: %w", err) } + p.protoOut.ack() + p.statistics.txProtoSuccessful.Inc() continue case <-p.trafficOut.wait(): - if frame, ok := p.trafficOut.pop(); ok { - send(frame) + if frame, ok := p.trafficOut.pop(); ok && send(frame) == nil { p.statistics.txTrafficSuccessful.Inc() } else { p.statistics.txTrafficDropped.Inc() } continue - //case <-tick.C: - // send(p.generateKeepalive()) - // p.statistics.txProtoSuccessful.Inc() - // continue + case <-tick.C: + if err := send(p.generateKeepalive()); err != nil { + return fmt.Errorf("send: %w", err) + } + p.statistics.txProtoSuccessful.Inc() + continue } } } diff --git a/router/queue.go b/router/queue.go deleted file mode 100644 index a2219a14..00000000 --- a/router/queue.go +++ /dev/null @@ -1,12 +0,0 @@ -package router - -import "github.com/matrix-org/pinecone/types" - -type queue interface { - queuecount() int - queuesize() int - push(frame *types.Frame) bool - pop() (*types.Frame, bool) - reset() - wait() <-chan struct{} -} diff --git a/router/queuefifo.go b/router/queuefifo.go index 0bcb4381..43e8e829 100644 --- a/router/queuefifo.go +++ b/router/queuefifo.go @@ -7,85 +7,77 @@ import ( ) type fifoQueue struct { - frames []*types.Frame - count int - mutex sync.Mutex - notifs chan struct{} + entries []chan *types.Frame + mutex sync.Mutex } func newFIFOQueue() *fifoQueue { - q := &fifoQueue{ - notifs: make(chan struct{}), - } + q := &fifoQueue{} + q.reset() return q } +func (q *fifoQueue) _initialise() { + for i := range q.entries { + q.entries[i] = nil + } + q.entries = []chan *types.Frame{ + make(chan *types.Frame, 1), + } +} + func (q *fifoQueue) queuecount() int { q.mutex.Lock() defer q.mutex.Unlock() - return q.count + return len(q.entries) } func (q *fifoQueue) queuesize() int { q.mutex.Lock() defer q.mutex.Unlock() - return cap(q.frames) + return cap(q.entries) } func (q *fifoQueue) push(frame *types.Frame) bool { q.mutex.Lock() defer q.mutex.Unlock() - q.frames = append(q.frames, frame) - q.count++ - select { - case q.notifs <- struct{}{}: - default: + if len(q.entries) == 0 { + q._initialise() } + ch := q.entries[len(q.entries)-1] + ch <- frame + close(ch) + q.entries = append(q.entries, make(chan *types.Frame, 1)) return true } -func (q *fifoQueue) pop() (*types.Frame, bool) { +func (q *fifoQueue) reset() { q.mutex.Lock() defer q.mutex.Unlock() - if q.count == 0 { - return nil, false - } - frame := q.frames[0] - q.frames[0] = nil - q.frames = q.frames[1:] - q.count-- - if q.count == 0 { - // Force a GC of the underlying array, since it might have - // grown significantly if the queue was hammered for some reason - q.frames = nil + for _, ch := range q.entries { + select { + case frame := <-ch: + if frame != nil { + frame.Done() + } + default: + } } - return frame, true + q._initialise() } -func (q *fifoQueue) reset() { +func (q *fifoQueue) pop() <-chan *types.Frame { q.mutex.Lock() defer q.mutex.Unlock() - q.count = 0 - for i := range q.frames { - if q.frames[i] != nil { - q.frames[i].Done() - q.frames[i] = nil - } - } - q.frames = nil - close(q.notifs) - for range q.notifs { + if len(q.entries) == 0 { + q._initialise() } - q.notifs = make(chan struct{}) + entry := q.entries[0] + return entry } -func (q *fifoQueue) wait() <-chan struct{} { +func (q *fifoQueue) ack() { q.mutex.Lock() defer q.mutex.Unlock() - if q.count > 0 { - ch := make(chan struct{}) - close(ch) - return ch - } - return q.notifs + q.entries = q.entries[1:] } diff --git a/router/queuefifo_test.go b/router/queuefifo_test.go new file mode 100644 index 00000000..cdc91988 --- /dev/null +++ b/router/queuefifo_test.go @@ -0,0 +1,34 @@ +package router + +import ( + "testing" + + "github.com/matrix-org/pinecone/types" +) + +func TestFIFOQueue(t *testing.T) { + q := newFIFOQueue() + iterations := types.SwitchPortID(1024) + + go func() { + for i := types.SwitchPortID(0); i < iterations; i++ { + q.push(&types.Frame{ + Destination: types.SwitchPorts{i}, + }) + } + }() + + got := types.SwitchPortID(0) + + for i := types.SwitchPortID(0); i < iterations; i++ { + frame := <-q.pop() + if frame == nil { + t.Fatalf("unexpected nil frame") + } + if frame.Destination[0] < got || frame.Destination[0] > got+1 { + t.Fatalf("ordering problem") + } + got = frame.Destination[0] + q.ack() + } +} diff --git a/router/router.go b/router/router.go index 7f8d12bf..d398f36c 100644 --- a/router/router.go +++ b/router/router.go @@ -25,7 +25,6 @@ import ( "net" "sort" "sync" - "time" "github.com/matrix-org/pinecone/types" "go.uber.org/atomic" @@ -110,7 +109,7 @@ func NewRouter(log *log.Logger, id string, private ed25519.PrivateKey, public ed sw.ports[i] = &Peer{ r: sw, port: types.SwitchPortID(i), - announce: make(chan struct{}), + wg: &sync.WaitGroup{}, protoOut: newFIFOQueue(), trafficOut: newLIFOQueue(TrafficBufferSize), } @@ -148,7 +147,7 @@ func (r *Router) Close() error { r.cancel() for _, port := range r.ports { if port.started.Load() { - _ = port.stop() + port.stop() } } return nil @@ -313,10 +312,7 @@ func (r *Router) startedPorts() peers { func (r *Router) activePorts() peers { peers := make(peers, 0, PortCount) for _, p := range r.startedPorts() { - switch { - case !p.alive.Load(): - continue - default: + if p.Alive() { peers = append(peers, p) } } @@ -330,52 +326,49 @@ func (r *Router) activePorts() peers { // the node was connected to will be returned in the event // of a successful connection. func (r *Router) AuthenticatedConnect(conn net.Conn, zone string, peertype int) (types.SwitchPortID, error) { - select { - case <-time.After(time.Second * 5): - return 0, fmt.Errorf("handshake timed out") - default: - handshake := []byte{ - ourVersion, - ourCapabilities, - 0, // unused - 0, // unused - } - handshake = append(handshake, r.public[:ed25519.PublicKeySize]...) - handshake = append(handshake, ed25519.Sign(r.private[:], handshake)...) - _ = conn.SetDeadline(time.Now().Add(time.Second * 5)) - if _, err := conn.Write(handshake); err != nil { - conn.Close() - return 0, fmt.Errorf("conn.Write: %w", err) - } - if _, err := io.ReadFull(conn, handshake); err != nil { - conn.Close() - return 0, fmt.Errorf("io.ReadFull: %w", err) - } - _ = conn.SetDeadline(time.Time{}) - if theirVersion := handshake[0]; theirVersion != ourVersion { - conn.Close() - return 0, fmt.Errorf("mismatched node version") - } - if theirCapabilities := handshake[1]; theirCapabilities&ourCapabilities != ourCapabilities { - conn.Close() - return 0, fmt.Errorf("mismatched node capabilities") - } - var public types.PublicKey - var signature types.Signature - offset := 4 - offset += copy(public[:], handshake[offset:offset+ed25519.PublicKeySize]) - copy(signature[:], handshake[offset:offset+ed25519.SignatureSize]) - if !ed25519.Verify(public[:], handshake[:offset], signature[:]) { - conn.Close() - return 0, fmt.Errorf("peer sent invalid signature") - } - port, err := r.Connect(conn, public, zone, peertype) - if err != nil { - conn.Close() - return 0, err - } - return port, nil + handshake := []byte{ + ourVersion, + ourCapabilities, + 0, // unused + 0, // unused + } + handshake = append(handshake, r.public[:ed25519.PublicKeySize]...) + handshake = append(handshake, ed25519.Sign(r.private[:], handshake)...) + //_ = conn.SetWriteDeadline(time.Now().Add(time.Second * 5)) + if _, err := conn.Write(handshake); err != nil { + conn.Close() + return 0, fmt.Errorf("conn.Write: %w", err) + } + //_ = conn.SetWriteDeadline(time.Time{}) + //_ = conn.SetReadDeadline(time.Now().Add(time.Second * 5)) + if _, err := io.ReadFull(conn, handshake); err != nil { + conn.Close() + return 0, fmt.Errorf("io.ReadFull: %w", err) + } + //_ = conn.SetReadDeadline(time.Time{}) + if theirVersion := handshake[0]; theirVersion != ourVersion { + conn.Close() + return 0, fmt.Errorf("mismatched node version") + } + if theirCapabilities := handshake[1]; theirCapabilities&ourCapabilities != ourCapabilities { + conn.Close() + return 0, fmt.Errorf("mismatched node capabilities") } + var public types.PublicKey + var signature types.Signature + offset := 4 + offset += copy(public[:], handshake[offset:offset+ed25519.PublicKeySize]) + copy(signature[:], handshake[offset:offset+ed25519.SignatureSize]) + if !ed25519.Verify(public[:], handshake[:offset], signature[:]) { + conn.Close() + return 0, fmt.Errorf("peer sent invalid signature") + } + port, err := r.Connect(conn, public, zone, peertype) + if err != nil { + conn.Close() + return 0, err + } + return port, nil } // Connect initiates a peer connection using the given @@ -384,43 +377,33 @@ func (r *Router) AuthenticatedConnect(conn net.Conn, zone string, peertype int) // port number that the node was connected to will be // returned in the event of a successful connection. func (r *Router) Connect(conn net.Conn, public types.PublicKey, zone string, peertype int) (types.SwitchPortID, error) { - if p, ok := r.active.Load(hex.EncodeToString(public[:]) + zone); ok { - _ = r.Disconnect(p.(types.SwitchPortID), fmt.Errorf("another incoming connection from the same node")) + if r.IsConnected(public, zone) { + _ = conn.Close() + return 0, fmt.Errorf("already connected") } r.connections.Lock() defer r.connections.Unlock() + usePort := func(port *Peer) bool { + if !port.allocated.CAS(false, true) { + return false + } + port.mutex.Lock() + port.context, port.cancel = context.WithCancel(r.context) + port.zone = zone + port.peertype = peertype + port.conn = conn + port.public = public + port.mutex.Unlock() + go port.start() + return true + } for i := types.SwitchPortID(0); i < PortCount; i++ { if i != 0 && bytes.Equal(r.public[:], public[:]) { return 0, fmt.Errorf("loopback connection prohibited") } - if r.ports[i].started.Load() { - continue - } - r.ports[i].mutex.Lock() - r.ports[i].context, r.ports[i].cancel = context.WithCancel(r.context) - r.ports[i].zone = zone - r.ports[i].peertype = peertype - r.ports[i].conn = conn // util.NewBufferedRWCSize(conn, MaxFrameSize) - r.ports[i].public = public - r.ports[i].coords = nil - r.ports[i].protoOut.reset() - r.ports[i].trafficOut.reset() - r.ports[i].announcement = nil - r.ports[i].statistics.reset() - r.ports[i].mutex.Unlock() - if err := r.ports[i].start(); err != nil { - return 0, fmt.Errorf("port.start: %w", err) + if usePort(r.ports[i]) { + return i, nil } - r.active.Store(hex.EncodeToString(public[:])+zone, i) - if i != 0 { - r.dht.insertNode(r.ports[i]) - } - if r.simulator != nil { - r.simulator.ReportNewLink(conn, r.public, public) - } - go r.callbacks.onConnected(i, public, peertype) - r.log.Printf("Connected port %d to %s (zone %q)\n", i, conn.RemoteAddr(), zone) - return i, nil } return 0, fmt.Errorf("no free switch ports") } @@ -433,31 +416,7 @@ func (r *Router) Disconnect(i types.SwitchPortID, err error) error { if i == 0 { return fmt.Errorf("cannot disconnect port %d", i) } - r.connections.Lock() - defer r.connections.Unlock() - if stoperr := r.ports[i].stop(); stoperr != nil { - return fmt.Errorf("port.stop: %w", stoperr) - } - r.active.Delete(hex.EncodeToString(r.ports[i].public[:]) + r.ports[i].zone) - r.ports[i].mutex.Lock() - r.ports[i].peertype = 0 - r.ports[i].zone = "" - r.ports[i].public = types.PublicKey{} - r.ports[i].coords = nil - r.ports[i].announcement = nil - r.ports[i].protoOut.reset() - r.ports[i].trafficOut.reset() - r.ports[i].mutex.Unlock() - if r.ports[i].port != 0 { - r.dht.deleteNode(r.ports[i].public) - } - if r.simulator != nil { - r.simulator.ReportDeadLink(r.public, r.ports[i].public) - } - r.log.Printf("Disconnected port %d: %s\n", i, err) - r.tree.portWasDisconnected(i) - r.snake.portWasDisconnected(i) - go r.callbacks.onDisconnected(i, r.ports[i].public, r.ports[i].peertype, err) + r.ports[i].stop() return nil } diff --git a/router/tree.go b/router/tree.go index e66d1972..2328944a 100644 --- a/router/tree.go +++ b/router/tree.go @@ -20,30 +20,57 @@ import ( "fmt" "math" "sync" - "sync/atomic" "time" "github.com/matrix-org/pinecone/types" + "go.uber.org/atomic" ) // announcementInterval is the frequency at which this // node will send root announcements to other peers. -const announcementInterval = PeerKeepaliveInterval // time.Minute * 15 +const announcementInterval = time.Minute * 15 // announcementTimeout is the amount of time that must // pass without receiving a root announcement before we // will assume that the peer is dead. const announcementTimeout = announcementInterval * 2 -func (r *Router) handleAnnouncement(peer *Peer, rx *types.Frame) { - var new types.SwitchAnnouncement - if _, err := new.UnmarshalBinary(rx.Payload); err != nil { - r.log.Println("Error unmarshalling announcement:", err) - return +func (r *Router) handleAnnouncement(peer *Peer, rx *types.Frame) error { + var newUpdate types.SwitchAnnouncement + if _, err := newUpdate.UnmarshalBinary(rx.Payload); err != nil { + return fmt.Errorf("failed to unmarshal root announcement: %w", err) } - if err := r.tree.Update(peer, new); err != nil { - r.log.Println("Error handling announcement on port", peer.port, ":", err) + sigs := make(map[string]struct{}) + for index, sig := range newUpdate.Signatures { + if index == 0 && sig.PublicKey != newUpdate.RootPublicKey { + // The first signature in the announcement must be from the + // key that claims to be the root. + return fmt.Errorf("root announcement first signature is not from the root node") + } + if sig.Hop == 0 { + // None of the hops in the update should have a port number of 0 + // as this would imply that another node has sent their router + // port, which is impossible. We'll therefore reject any update + // that tries to do that. + return fmt.Errorf("root announcement contains an invalid port number") + } + if index == len(newUpdate.Signatures)-1 && peer.PublicKey() != sig.PublicKey { + // The last signature in the announcement must be from the + // direct peer. If it isn't then it sounds like someone is + // trying to replay someone else's announcement to us. + return fmt.Errorf("root announcement last signature is not from the direct peer") + } + pk := hex.EncodeToString(sig.PublicKey[:]) + if _, ok := sigs[pk]; ok { + // One of the signatures has appeared in the update more than + // once, which would suggest that there's a loop somewhere. + return fmt.Errorf("root announcement contains a routing loop") + } + sigs[pk] = struct{}{} } + + defer r.tree.UpdateParentIfNeeded(peer, newUpdate) + return peer.updateAnnouncement(&newUpdate) } type rootAnnouncementWithTime struct { @@ -51,16 +78,49 @@ type rootAnnouncementWithTime struct { at time.Time } +func (a *rootAnnouncementWithTime) ForPeer(p *Peer) *types.Frame { + if p.port == 0 { + return nil + } + announcement := a.SwitchAnnouncement + announcement.Signatures = append([]types.SignatureWithHop{}, a.Signatures...) + for _, sig := range announcement.Signatures { + if p.r.public.EqualTo(sig.PublicKey) { + // For some reason the announcement that we want to send already + // includes our signature. This shouldn't really happen but if we + // did send it, other nodes would end up ignoring the announcement + // anyway since it would appear to be a routing loop. + return nil + } + } + // Sign the announcement. + if err := announcement.Sign(p.r.private[:], p.port); err != nil { + p.r.log.Println("Failed to sign switch announcement:", err) + return nil + } + var payload [MaxPayloadSize]byte + n, err := announcement.MarshalBinary(payload[:]) + if err != nil { + p.r.log.Println("Failed to marshal switch announcement:", err) + return nil + } + frame := types.GetFrame() + frame.Version = types.Version0 + frame.Type = types.TypeSTP + frame.Destination = types.SwitchPorts{} + frame.Payload = payload[:n] + return frame +} + type spanningTree struct { - r *Router // - context context.Context // - root *rootAnnouncementWithTime // - rootMutex sync.RWMutex // - rootReset chan struct{} // - updateMutex sync.Mutex // - parent atomic.Value // types.SwitchPortID - coords atomic.Value // types.SwitchPorts - callback func(parent types.SwitchPortID, coords types.SwitchPorts) + r *Router + context context.Context + rootReset chan struct{} + mutex *sync.Mutex + parent types.SwitchPortID + reparent *time.Timer + sequence atomic.Uint64 + callback func(parent types.SwitchPortID, coords types.SwitchPorts) } func newSpanningTree(r *Router, f func(parent types.SwitchPortID, coords types.SwitchPorts)) *spanningTree { @@ -68,8 +128,11 @@ func newSpanningTree(r *Router, f func(parent types.SwitchPortID, coords types.S r: r, context: r.context, rootReset: make(chan struct{}), + mutex: &sync.Mutex{}, callback: f, } + t.reparent = time.AfterFunc(time.Second, t.selectNewParentAndAdvertise) + t.reparent.Stop() t.becomeRoot() go t.workerForRoot() go t.workerForAnnouncements() @@ -77,17 +140,16 @@ func newSpanningTree(r *Router, f func(parent types.SwitchPortID, coords types.S } func (t *spanningTree) Coords() types.SwitchPorts { - coords, ok := t.coords.Load().(types.SwitchPorts) - if ok { - return coords + ann := t.r.ports[t.Parent()].lastAnnouncement() + if ann == nil { + return types.SwitchPorts{} } - return types.SwitchPorts{} + return ann.Coords() } func (t *spanningTree) Ancestors() ([]types.PublicKey, types.SwitchPortID) { - root := t.Root() - port, ok := t.parent.Load().(types.SwitchPortID) - if !ok || port == 0 { + root, port := t.Root(), t.Parent() + if port == 0 { return nil, 0 } ancestors := make([]types.PublicKey, 0, 1+len(root.Signatures)) @@ -106,80 +168,124 @@ func (t *spanningTree) portWasDisconnected(port types.SwitchPortID) { t.becomeRoot() return } - if parent := t.parent.Load().(types.SwitchPortID); parent == port { - t.selectNewParent() + if t.Parent() == port { + t.selectNewParentAndAdvertiseIn(0) } } -func (t *spanningTree) selectNewParent() { - t.updateMutex.Lock() - defer t.updateMutex.Unlock() - t.becomeRoot() +func (t *spanningTree) selectNewParentAndAdvertiseIn(d time.Duration) { + t.reparent.Reset(d) +} + +func (t *spanningTree) selectNewParentAndAdvertise() { + lastUpdate := t.Root() + bestKey := lastUpdate.RootPublicKey + bestSeq := lastUpdate.Sequence bestDist := math.MaxInt32 - bestKey := t.r.public - var bestTime time.Time + bestTime := time.Now() var bestPort types.SwitchPortID var bestAnn *rootAnnouncementWithTime - var bestSeq types.Varu64 - portsToCheck := map[*Peer]*rootAnnouncementWithTime{} + + t.mutex.Lock() + for _, p := range t.r.activePorts() { ann := p.lastAnnouncement() if ann == nil { continue } - portsToCheck[p] = ann - } - checkWithCondition := func(f func(ann *rootAnnouncementWithTime, hops int) bool) { - for p, ann := range portsToCheck { - hops := len(ann.Signatures) - if f(ann, hops) { - bestKey = ann.RootPublicKey - bestDist = hops - bestPort = p.port - bestTime = ann.at - bestSeq = ann.Sequence - bestAnn = ann - } + accept := func() { + bestKey = ann.RootPublicKey + bestDist = len(ann.Signatures) + bestPort = p.port + bestTime = ann.at + bestSeq = ann.Sequence + bestAnn = ann + } + keyDelta := ann.RootPublicKey.CompareTo(bestKey) + annAt := ann.at.Round(time.Millisecond) + switch { + case ann.IsLoopOrChildOf(p.r.public): + // ignore our children or loopy announcements + case keyDelta > 0: + accept() + case keyDelta < 0: + // ignore weaker root keys + case ann.Sequence > bestSeq: + accept() + case ann.Sequence < bestSeq: + // ignore lower sequence numbers + case annAt.Before(bestTime): + accept() + case annAt.After(bestTime): + // ignore updates that arrived more recently + case len(ann.Signatures) < bestDist: + accept() + case len(ann.Signatures) > bestDist: + // ignore longer paths + case p.public.CompareTo(bestKey) > 0: + accept() } } - checkWithCondition(func(ann *rootAnnouncementWithTime, hops int) bool { - return ann.RootPublicKey.CompareTo(bestKey) > 0 - }) - checkWithCondition(func(ann *rootAnnouncementWithTime, hops int) bool { - return ann.RootPublicKey.CompareTo(bestKey) == 0 && ann.Sequence > bestSeq - }) - checkWithCondition(func(ann *rootAnnouncementWithTime, hops int) bool { - return ann.RootPublicKey.CompareTo(bestKey) == 0 && ann.Sequence == bestSeq && hops < bestDist - }) - checkWithCondition(func(ann *rootAnnouncementWithTime, hops int) bool { - return ann.RootPublicKey.CompareTo(bestKey) == 0 && ann.Sequence == bestSeq && hops == bestDist && ann.at.Before(bestTime) - }) + if bestAnn != nil { - t.parent.Store(bestPort) - if err := t.Update(t.r.ports[bestPort], bestAnn.SwitchAnnouncement); err != nil { - t.r.log.Println("t.Update: %w", err) + t.parent = bestPort + t.mutex.Unlock() + + newCoords := bestAnn.Coords() + coordsChanged := !lastUpdate.Coords().EqualTo(bestAnn.Coords()) + rootChanged := bestAnn.RootPublicKey != lastUpdate.RootPublicKey + + if rootChanged { + t.r.snake.rootNodeChanged(bestAnn.RootPublicKey) + } + if rootChanged || coordsChanged { + t.callback(bestPort, newCoords) } + + t.advertise() + return } + + // No suitable other peer was found, so we'll just become the root + // and hope that one of our peers corrects us if it matters. + t.mutex.Unlock() + t.becomeRoot() } func (t *spanningTree) advertise() { + t.mutex.Lock() + defer t.mutex.Unlock() + + t.r.ports[t.parent].mutex.RLock() + ann := t.r.ports[t.parent].announcement + t.r.ports[t.parent].mutex.RUnlock() + + if t.parent == 0 || ann == nil { // we are the root + ann = &rootAnnouncementWithTime{ + at: time.Now(), + SwitchAnnouncement: types.SwitchAnnouncement{ + RootPublicKey: t.r.public, + Sequence: types.Varu64(t.sequence.Inc()), + }, + } + } + for _, p := range t.r.startedPorts() { - go func(p *Peer) { - select { - case <-p.context.Done(): - case <-time.After(announcementTimeout): - case p.announce <- struct{}{}: - } - }(p) + p.protoOut.push(ann.ForPeer(p)) } } func (t *spanningTree) becomeRoot() { - t.parent.Store(types.SwitchPortID(0)) + t.mutex.Lock() + t.parent = 0 + t.mutex.Unlock() + newCoords := types.SwitchPorts{} if !t.Coords().EqualTo(newCoords) { go t.callback(0, types.SwitchPorts{}) + defer t.r.snake.rootNodeChanged(t.r.public) } + t.advertise() } @@ -210,7 +316,7 @@ func (t *spanningTree) workerForRoot() { case <-time.After(announcementTimeout): if !t.IsRoot() { - t.selectNewParent() + t.becomeRoot() } } } @@ -222,15 +328,13 @@ func (t *spanningTree) IsRoot() bool { } func (t *spanningTree) Root() *rootAnnouncementWithTime { - t.rootMutex.RLock() - root := t.root - t.rootMutex.RUnlock() - if root == nil || time.Since(root.at) > announcementTimeout { + root := t.r.ports[t.Parent()].lastAnnouncement() + if root == nil { return &rootAnnouncementWithTime{ at: time.Now(), SwitchAnnouncement: types.SwitchAnnouncement{ RootPublicKey: t.r.public, - Sequence: types.Varu64(time.Now().UnixNano()), + Sequence: types.Varu64(t.sequence.Load()), }, } } @@ -245,140 +349,53 @@ func (t *spanningTree) Root() *rootAnnouncementWithTime { } func (t *spanningTree) Parent() types.SwitchPortID { - if parent, ok := t.parent.Load().(types.SwitchPortID); ok { - return parent - } - return 0 + t.mutex.Lock() + defer t.mutex.Unlock() + return t.parent } -func (t *spanningTree) Update(p *Peer, newUpdate types.SwitchAnnouncement) error { - sigs := make(map[string]struct{}) - isLoopbackUpdate := false - for index, sig := range newUpdate.Signatures { - if index == 0 && sig.PublicKey != newUpdate.RootPublicKey { - // The first signature in the announcement must be from the - // key that claims to be the root. - return fmt.Errorf("rejecting update (first signature must be from root)") - } - if sig.Hop == 0 { - // None of the hops in the update should have a port number of 0 - // as this would imply that another node has sent their router - // port, which is impossible. We'll therefore reject any update - // that tries to do that. - return fmt.Errorf("rejecting update (invalid 0 hop)") - } - if index == len(newUpdate.Signatures)-1 && p.PublicKey() != sig.PublicKey { - // The last signature in the announcement must be from the - // direct peer. If it isn't then it sounds like someone is - // trying to replay someone else's announcement to us. - return fmt.Errorf("rejecting update (last signature must be from peer)") - } - if sig.PublicKey.EqualTo(t.r.public) { - // A child update is one that contains our public key in the - // signatures already - it's probably one of our direct peers - // sending our root announcement back to us. In this case there - // is some special behaviour: we usually will need to accept - // the update on the port, but we don't want to do anything that - // would influence root or coordinate changes. - isLoopbackUpdate = true - } - pk := hex.EncodeToString(sig.PublicKey[:]) - if _, ok := sigs[pk]; ok { - // One of the signatures has appeared in the update more than - // once, which would suggest that there's a loop somewhere. - return fmt.Errorf("rejecting update (detected routing loop)") - } - sigs[pk] = struct{}{} - } - - lastPortUpdate := p.lastAnnouncement() - if lastPortUpdate != nil && lastPortUpdate.RootPublicKey == newUpdate.RootPublicKey { - if newUpdate.Sequence < lastPortUpdate.Sequence { - // The update has a lower sequence number than our last update - // on this port from this root. This shouldn't happen, but if it - // does, then just drop the update. - return nil - } - } - - if err := p.updateAnnouncement(&newUpdate); err != nil { - return fmt.Errorf("p.updateAnnouncement: %w", err) - } - - if isLoopbackUpdate { - // The update contains our own signature already, so using it for - // a root update would create a loop - return nil - } - - t.updateMutex.Lock() - defer t.updateMutex.Unlock() - - lastGlobalUpdate := t.Root() - globalKeyDelta := newUpdate.RootPublicKey.CompareTo(lastGlobalUpdate.RootPublicKey) - globalTimeSince := time.Since(lastGlobalUpdate.at) - globalUpdate := false +func (t *spanningTree) UpdateParentIfNeeded(p *Peer, newUpdate types.SwitchAnnouncement) { + const immediately = time.Duration(0) + reparentIn, becomeRoot := time.Duration(-1), false + lastParentUpdate, isParent := t.Root(), p.IsParent() + keyDeltaSinceLastParentUpdate := newUpdate.RootPublicKey.CompareTo(lastParentUpdate.RootPublicKey) switch { - case globalTimeSince > announcementTimeout: - // We haven't seen a suitable root update recently so we'll accept - // this one instead - globalUpdate = true - - case globalKeyDelta < 0: - // The root key is weaker than our existing root, so it's no good - return nil - - case globalKeyDelta == 0: - // The update is from the same root node, let's see if it matches - // any other useful conditions - - switch { - case newUpdate.Sequence < lastGlobalUpdate.Sequence: - // This is a replay of an earlier update, therefore we should - // ignore it, even if it came from our parent node - return nil - - case len(newUpdate.Signatures) < len(lastGlobalUpdate.Signatures): - // The path to the root is shorter than our last update, so - // we'll accept it - globalUpdate = true - } - - case globalKeyDelta > 0: - // The root key is stronger than our existing root, therefore we'll - // accept it anyway, since we always want to converge on the - // strongest root key - globalUpdate = true + case keyDeltaSinceLastParentUpdate > 0: + // The peer has sent us a key that is stronger than our last update. + reparentIn = immediately + + case time.Since(lastParentUpdate.at) >= announcementTimeout: + // It's been a while since we last heard from our parent so we should + // really choose a new one if we can. + reparentIn = immediately + + case isParent && keyDeltaSinceLastParentUpdate < 0: + // Our parent sent us a weaker key than before — this implies that + // something bad happened. + reparentIn, becomeRoot = time.Second/2, true + + case isParent && keyDeltaSinceLastParentUpdate == 0 && newUpdate.Sequence < lastParentUpdate.Sequence: + // Our parent sent us a lower sequence number than before for the + // same root — this isn't good news either. + reparentIn, becomeRoot = time.Second/2, true + + case isParent && keyDeltaSinceLastParentUpdate == 0 && newUpdate.Sequence == lastParentUpdate.Sequence: + // Our parent sent us an equal sequence number, which probably means + // that their path to the root has changed. This isn't as bad news as + // a weaker key but we should still try to find a new parent. + reparentIn, becomeRoot = time.Second/4, false } - if parent := t.parent.Load(); p.port == parent || globalUpdate { - t.rootMutex.Lock() - var oldRootKey types.PublicKey - if t.root != nil { - oldRootKey = t.root.RootPublicKey - } - t.root = &rootAnnouncementWithTime{ - at: time.Now(), - SwitchAnnouncement: newUpdate, - } - t.parent.Store(p.port) - oldcoords := t.Coords() - newcoords := t.root.Coords() - t.coords.Store(newcoords) - t.rootMutex.Unlock() - - if t.callback != nil && !oldcoords.EqualTo(newcoords) { - go t.callback(p.port, newcoords) - } - - if oldRootKey != newUpdate.RootPublicKey { - defer t.r.snake.rootNodeChanged(newUpdate.RootPublicKey) + switch { + case reparentIn >= 0: + if becomeRoot { + t.becomeRoot() } + t.selectNewParentAndAdvertiseIn(reparentIn) + case isParent && keyDeltaSinceLastParentUpdate == 0 && newUpdate.Sequence > lastParentUpdate.Sequence: t.advertise() t.rootReset <- struct{}{} } - - return nil } diff --git a/router/version.go b/router/version.go index 75a5ca58..a56e3af5 100644 --- a/router/version.go +++ b/router/version.go @@ -16,12 +16,13 @@ package router const ( // reserved = 1 - capabilityVirtualSnake = 2 - capabilityHardState = 4 - capabilityPathIDs = 8 - capabilityRootInUpdates = 16 - capabilityNewHeaders = 32 + capabilityVirtualSnake = 2 + capabilityHardState = 4 + capabilityPathIDs = 8 + capabilityRootInUpdates = 16 + capabilityNewHeaders = 32 + capabilitySlowRootInterval = 64 ) const ourVersion = 0 -const ourCapabilities = capabilityVirtualSnake | capabilityHardState | capabilityPathIDs | capabilityRootInUpdates | capabilityNewHeaders +const ourCapabilities = capabilityVirtualSnake | capabilityHardState | capabilityPathIDs | capabilityRootInUpdates | capabilityNewHeaders | capabilitySlowRootInterval diff --git a/types/announcement.go b/types/announcement.go index 8382e604..cd61802b 100644 --- a/types/announcement.go +++ b/types/announcement.go @@ -97,15 +97,32 @@ func (a *SwitchAnnouncement) Coords() SwitchPorts { return coords } -func (a *SwitchAnnouncement) PeerCoords(public PublicKey) (SwitchPorts, error) { +func (a *SwitchAnnouncement) PeerCoords() SwitchPorts { sigs := a.Signatures - last := len(sigs) - 1 - if sigs[last].PublicKey != public { - return nil, fmt.Errorf("invalid last hop") - } - coords := make(SwitchPorts, 0, len(sigs)) - for _, sig := range sigs[:last] { + coords := make(SwitchPorts, 0, len(sigs)-1) + for _, sig := range sigs[:len(sigs)-1] { coords = append(coords, SwitchPortID(sig.Hop)) } - return coords, nil + return coords +} + +func (a *SwitchAnnouncement) AncestorParent() PublicKey { + if len(a.Signatures) < 2 { + return a.RootPublicKey + } + return a.Signatures[len(a.Signatures)-2].PublicKey +} + +func (a *SwitchAnnouncement) IsLoopOrChildOf(pk PublicKey) bool { + m := map[PublicKey]struct{}{} + for _, sig := range a.Signatures { + if sig.PublicKey.EqualTo(pk) { + return true + } + if _, ok := m[sig.PublicKey]; ok { + return true + } + m[sig.PublicKey] = struct{}{} + } + return false }