diff --git a/cmd/pinecone/main.go b/cmd/pinecone/main.go
index 182691b9..c426618c 100644
--- a/cmd/pinecone/main.go
+++ b/cmd/pinecone/main.go
@@ -47,12 +47,9 @@ func main() {
}
dialer := net.Dialer{
- Timeout: time.Second * 5,
- KeepAlive: time.Second * 2,
- }
- listener := net.ListenConfig{
- KeepAlive: time.Second * 2,
+ Timeout: time.Second * 5,
}
+ listener := net.ListenConfig{}
pineconeRouter := router.NewRouter(logger, "router", sk, pk, nil)
_ = sessions.NewSessions(logger, pineconeRouter)
diff --git a/cmd/pineconeip/main.go b/cmd/pineconeip/main.go
index f807fe74..0ca09b9a 100644
--- a/cmd/pineconeip/main.go
+++ b/cmd/pineconeip/main.go
@@ -21,6 +21,10 @@ import (
"log"
"net"
"os"
+ "os/signal"
+ "runtime/pprof"
+ "syscall"
+ "time"
"net/http"
_ "net/http/pprof"
@@ -52,9 +56,15 @@ func main() {
logger := log.New(os.Stdout, "", 0)
if hostPort := os.Getenv("PPROFLISTEN"); hostPort != "" {
- logger.Println("Starting pprof on", hostPort)
go func() {
- _ = http.ListenAndServe(hostPort, nil)
+ listener, err := net.Listen("tcp", hostPort)
+ if err != nil {
+ panic(err)
+ }
+ logger.Println("Starting pprof on", listener.Addr())
+ if err := http.Serve(listener, nil); err != nil {
+ panic(err)
+ }
}()
}
@@ -72,14 +82,6 @@ func main() {
if err := conn.SetNoDelay(true); err != nil {
return fmt.Errorf("conn.SetNoDelay: %w", err)
}
- /*
- if err := conn.SetKeepAlive(true); err != nil {
- return fmt.Errorf("conn.SetKeepAlive: %w", err)
- }
- if err := conn.SetKeepAlivePeriod(time.Second); err != nil {
- return fmt.Errorf("conn.SetKeepAlivePeriod: %w", err)
- }
- */
if err := conn.SetLinger(0); err != nil {
return fmt.Errorf("conn.SetLinger: %w", err)
}
@@ -133,5 +135,27 @@ func main() {
}
}()
- select {}
+ sigs := make(chan os.Signal, 1)
+ signal.Notify(sigs, syscall.SIGUSR1)
+ for {
+ switch <-sigs {
+ case syscall.SIGUSR1:
+ fn := fmt.Sprintf("/tmp/profile.%d", os.Getpid())
+ logger.Println("Requested profile:", fn)
+ fp, err := os.Create(fn)
+ if err != nil {
+ logger.Println("Failed to create profile:", err)
+ return
+ }
+ if err := pprof.StartCPUProfile(fp); err != nil {
+ logger.Println("Failed to start profiling:", err)
+ return
+ }
+ time.AfterFunc(time.Second*10, func() {
+ pprof.StopCPUProfile()
+ fp.Close()
+ logger.Println("Profile written:", fn)
+ })
+ }
+ }
}
diff --git a/cmd/pineconesim/page.html b/cmd/pineconesim/page.html
index 2483df31..8f878cb9 100644
--- a/cmd/pineconesim/page.html
+++ b/cmd/pineconesim/page.html
@@ -125,7 +125,7 @@
Node Summary
{{range .Nodes}}
{{if not .IsExternal}}
- | {{.Name}} |
+ {{.Name}} |
{{.Port}} |
{{.Coords}} |
{{.Root}} |
@@ -208,7 +208,7 @@ Peers
{{range .NodeInfo.Peers}}
- {{.Name}} |
+ {{.Name}} |
{{.PublicKey}} |
{{.Port}} |
{{.RootPublicKey}} |
diff --git a/cmd/pineconesim/simulator/interface.go b/cmd/pineconesim/simulator/interface.go
index 40cfbd31..1de50601 100644
--- a/cmd/pineconesim/simulator/interface.go
+++ b/cmd/pineconesim/simulator/interface.go
@@ -23,6 +23,8 @@ import (
)
func (sim *Simulator) LookupCoords(target string) (types.SwitchPorts, error) {
+ sim.nodesMutex.RLock()
+ defer sim.nodesMutex.RUnlock()
node, ok := sim.nodes[target]
if !ok {
return nil, fmt.Errorf("node %q not known", target)
@@ -31,6 +33,8 @@ func (sim *Simulator) LookupCoords(target string) (types.SwitchPorts, error) {
}
func (sim *Simulator) LookupNodeID(target types.SwitchPorts) (string, error) {
+ sim.nodesMutex.RLock()
+ defer sim.nodesMutex.RUnlock()
for id, n := range sim.nodes {
if n.Coords().EqualTo(target) {
return id, nil
@@ -40,6 +44,8 @@ func (sim *Simulator) LookupNodeID(target types.SwitchPorts) (string, error) {
}
func (sim *Simulator) LookupPublicKey(target types.PublicKey) (string, error) {
+ sim.nodesMutex.RLock()
+ defer sim.nodesMutex.RUnlock()
for id, n := range sim.nodes {
if n.PublicKey().EqualTo(target) {
return id, nil
diff --git a/cmd/pineconesim/simulator/nodes.go b/cmd/pineconesim/simulator/nodes.go
index d46a5057..e3f54819 100644
--- a/cmd/pineconesim/simulator/nodes.go
+++ b/cmd/pineconesim/simulator/nodes.go
@@ -67,6 +67,9 @@ func (sim *Simulator) CreateNode(t string) error {
if err := c.SetNoDelay(true); err != nil {
panic(err)
}
+ if err := c.SetLinger(0); err != nil {
+ panic(err)
+ }
if _, err = n.AuthenticatedConnect(c, "sim", router.PeerTypeRemote); err != nil {
continue
}
diff --git a/multicast/multicast.go b/multicast/multicast.go
index a0e6e834..3d291a1a 100644
--- a/multicast/multicast.go
+++ b/multicast/multicast.go
@@ -65,15 +65,14 @@ func NewMulticast(
}
m.tcpLC = net.ListenConfig{
Control: m.tcpOptions,
- KeepAlive: time.Second,
+ KeepAlive: time.Second * 3,
}
m.udpLC = net.ListenConfig{
Control: m.udpOptions,
}
m.dialer = net.Dialer{
- Control: m.tcpOptions,
- Timeout: time.Second * 5,
- KeepAlive: time.Second,
+ Control: m.tcpOptions,
+ // Timeout: time.Second * 5,
}
return m
}
@@ -163,6 +162,7 @@ func (m *Multicast) accept(listener net.Listener) {
if _, err := m.r.AuthenticatedConnect(conn, tcpaddr.Zone, router.PeerTypeMulticast); err != nil {
//m.log.Println("m.s.AuthenticatedConnect:", err)
+ _ = conn.Close()
continue
}
}
@@ -309,6 +309,7 @@ func (m *Multicast) listen(intf *multicastInterface, conn net.PacketConn, srcadd
if _, err := m.r.AuthenticatedConnect(peer, udpaddr.Zone, router.PeerTypeMulticast); err != nil {
m.log.Println("m.s.AuthenticatedConnect:", err)
+ _ = peer.Close()
continue
}
}
diff --git a/multicast/platform_darwin.go b/multicast/platform_darwin.go
index 138e2b09..2ee8f07e 100644
--- a/multicast/platform_darwin.go
+++ b/multicast/platform_darwin.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build darwin
// +build darwin
package multicast
diff --git a/router/dht.go b/router/dht.go
index 0f46edd7..6ecd8eff 100644
--- a/router/dht.go
+++ b/router/dht.go
@@ -31,7 +31,7 @@ import (
type dhtEntry interface {
PublicKey() types.PublicKey
Coordinates() types.SwitchPorts
- SeenRecently() bool
+ Alive() bool
}
type dht struct {
@@ -59,7 +59,7 @@ func (d *dht) table() []dhtEntry {
results := make([]dhtEntry, 0, len(d.sorted))
for _, n := range d.sorted {
- if n.SeenRecently() {
+ if n.Alive() {
results = append(results, n)
}
}
diff --git a/router/nexthop.go b/router/nexthop.go
index 62de91c2..80b199ad 100644
--- a/router/nexthop.go
+++ b/router/nexthop.go
@@ -31,7 +31,9 @@ func (p *Peer) getNextHops(frame *types.Frame, from types.SwitchPortID) types.Sw
switch frame.Type {
case types.TypeSTP:
if from != 0 {
- p.r.handleAnnouncement(p, frame)
+ if err := p.r.handleAnnouncement(p, frame); err != nil {
+ p.r.log.Println("Failed to handle announcement:", err)
+ }
}
case types.TypeVirtualSnakeBootstrap:
diff --git a/router/nexthop_greedy.go b/router/nexthop_greedy.go
index a38959a6..6d46eb39 100644
--- a/router/nexthop_greedy.go
+++ b/router/nexthop_greedy.go
@@ -58,12 +58,14 @@ func (r *Router) getGreedyRoutedNextHop(from *Peer, rx *types.Frame) types.Switc
peerCoords := p.Coordinates()
peerDist := int64(peerCoords.DistanceTo(rx.Destination))
switch {
- case peerDist == 0:
- return []types.SwitchPortID{p.port}
- case rx.Destination.EqualTo(peerCoords):
+ case peerDist == 0 || rx.Destination.EqualTo(peerCoords):
+ // The peer is the actual destination.
return []types.SwitchPortID{p.port}
+
case peerDist < bestDist:
+ // The peer is closer to the destination.
bestPeer, bestDist = p.port, peerDist
+
default:
}
}
diff --git a/router/nexthop_source.go b/router/nexthop_source.go
index 140f5f91..a652c4f1 100644
--- a/router/nexthop_source.go
+++ b/router/nexthop_source.go
@@ -51,7 +51,7 @@ func (r *Router) getSourceRoutedNextHop(from *Peer, rx *types.Frame) types.Switc
return nil
}
- if peer := r.ports[to]; !peer.started.Load() || !peer.alive.Load() {
+ if peer := r.ports[to]; !peer.started.Load() || !peer.Alive() {
// Don't try to send packets to a port that has nothing
// connected to it or isn't alive.
return nil
diff --git a/router/nexthop_virtualsnake.go b/router/nexthop_virtualsnake.go
index 361468de..a021bcbe 100644
--- a/router/nexthop_virtualsnake.go
+++ b/router/nexthop_virtualsnake.go
@@ -26,7 +26,7 @@ import (
"github.com/matrix-org/pinecone/util"
)
-const virtualSnakeNeighExpiryPeriod = time.Minute * 30
+const virtualSnakeNeighExpiryPeriod = time.Hour
type virtualSnake struct {
r *Router
@@ -36,6 +36,7 @@ type virtualSnake struct {
ascendingMutex sync.RWMutex
_descending *virtualSnakeNeighbour
descendingMutex sync.RWMutex
+ bootstrap *time.Timer
}
type virtualSnakeIndex struct {
@@ -72,6 +73,8 @@ func newVirtualSnake(r *Router) *virtualSnake {
r: r,
table: make(virtualSnakeTable),
}
+ snake.bootstrap = time.AfterFunc(time.Second, snake.bootstrapNow)
+ snake.bootstrap.Stop()
go snake.maintain()
return snake
}
@@ -105,37 +108,34 @@ func (t *virtualSnake) setDescending(desc *virtualSnakeNeighbour) {
// bootstraps and setup messages as needed.
func (t *virtualSnake) maintain() {
for {
- peerCount := t.r.PeerCount(-1)
- bootstrapNow := false
select {
case <-t.r.context.Done():
return
case <-time.After(time.Second):
}
- if peerCount == 0 {
- // If there are no peers connected then we don't need
- // to do any hard maintenance work.
- continue
- }
+
+ rootKey := t.r.RootPublicKey()
+ canBootstrap := rootKey != t.r.public && t.r.PeerCount(-1) > 0
+ willBootstrap := false
if asc := t.ascending(); asc != nil {
switch {
case time.Since(asc.LastSeen) > virtualSnakeNeighExpiryPeriod:
t.teardownPath(t.r.public, asc.PathID, asc.Port, true, fmt.Errorf("ascending neighbour expired"))
- bootstrapNow = true
- case asc.RootPublicKey != t.r.RootPublicKey():
+ willBootstrap = canBootstrap
+ case asc.RootPublicKey != rootKey:
t.teardownPath(t.r.public, asc.PathID, asc.Port, true, fmt.Errorf("ascending root changed"))
- bootstrapNow = true
+ willBootstrap = canBootstrap
}
} else {
- bootstrapNow = true
+ willBootstrap = canBootstrap
}
if desc := t.descending(); desc != nil {
switch {
case time.Since(desc.LastSeen) > virtualSnakeNeighExpiryPeriod:
t.teardownPath(desc.PublicKey, desc.PathID, desc.Port, false, fmt.Errorf("descending neighbour expired"))
- case !desc.RootPublicKey.EqualTo(t.r.RootPublicKey()):
+ case !desc.RootPublicKey.EqualTo(rootKey):
t.teardownPath(desc.PublicKey, desc.PathID, desc.Port, false, fmt.Errorf("descending root changed"))
}
}
@@ -145,14 +145,18 @@ func (t *virtualSnake) maintain() {
// predefined interval, but for now we'll continue to send
// them on a regular interval until we can derive some better
// connection state.
- if bootstrapNow {
- t.bootstrapNow()
+ if willBootstrap {
+ t.bootstrapIn(time.Second)
}
}
}
+func (t *virtualSnake) bootstrapIn(d time.Duration) {
+ t.bootstrap.Reset(d)
+}
+
func (t *virtualSnake) bootstrapNow() {
- if t.r.IsRoot() {
+ if t.r.IsRoot() || t.r.PeerCount(-1) == 0 {
return
}
ts, err := util.SignedTimestamp(t.r.private)
@@ -180,7 +184,7 @@ func (t *virtualSnake) bootstrapNow() {
func (t *virtualSnake) rootNodeChanged(root types.PublicKey) {
if asc := t.ascending(); asc != nil && !asc.RootPublicKey.EqualTo(root) {
t.teardownPath(t.r.public, asc.PathID, asc.Port, true, fmt.Errorf("root changed and asc no longer matches"))
- defer t.bootstrapNow()
+ defer t.bootstrapIn(time.Second)
}
if desc := t.descending(); desc != nil && !desc.RootPublicKey.EqualTo(root) {
t.teardownPath(desc.PublicKey, desc.PathID, desc.Port, false, fmt.Errorf("root changed and desc no longer matches"))
@@ -203,10 +207,6 @@ func (t *virtualSnake) rootNodeChanged(root types.PublicKey) {
}
}
-func (t *virtualSnake) portWasConnected(port types.SwitchPortID) {
- // time.AfterFunc(time.Second, t.bootstrapNow)
-}
-
// portWasDisconnected is called by the router when a peer disconnects
// allowing us to clean up the virtual snake state.
func (t *virtualSnake) portWasDisconnected(port types.SwitchPortID) {
@@ -226,6 +226,7 @@ func (t *virtualSnake) portWasDisconnected(port types.SwitchPortID) {
// this port change.
if asc := t.ascending(); asc != nil && asc.Port == port {
t.teardownPath(t.r.public, asc.PathID, asc.Port, true, fmt.Errorf("port teardown"))
+ defer t.bootstrapIn(0)
}
if desc := t.descending(); desc != nil && desc.Port == port {
t.teardownPath(desc.PublicKey, desc.PathID, desc.Port, false, fmt.Errorf("port teardown"))
@@ -313,8 +314,11 @@ func (t *virtualSnake) getVirtualSnakeNextHop(from *Peer, destKey types.PublicKe
newCheckedCandidate(ancestor, parentPort)
}
+ // Both peer-checking sections below need the active port list, so fetch it once.
+ activePorts := t.r.activePorts()
+
// Check our direct peers ancestors
- for _, peer := range t.r.activePorts() {
+ for _, peer := range activePorts {
peerAnn := peer.lastAnnouncement()
if peerAnn == nil {
continue
@@ -325,8 +329,20 @@ func (t *virtualSnake) getVirtualSnakeNextHop(from *Peer, destKey types.PublicKe
}
}
+ // Check our DHT entries
+ t.tableMutex.RLock()
+ for dhtKey, entry := range t.table {
+ switch {
+ case !entry.Valid():
+ continue
+ default:
+ newCheckedCandidate(dhtKey.PublicKey, entry.SourcePort)
+ }
+ }
+ t.tableMutex.RUnlock()
+
// Check our direct peers
- for _, peer := range t.r.activePorts() {
+ for _, peer := range activePorts {
peerKey := peer.PublicKey()
switch {
case bestKey.EqualTo(peerKey):
@@ -335,20 +351,11 @@ func (t *virtualSnake) getVirtualSnakeNextHop(from *Peer, destKey types.PublicKe
// are directly peered with that node, so use the more direct
// path instead
newCandidate(peerKey, peer.port)
+ default:
+ newCheckedCandidate(peerKey, peer.port)
}
}
- // Check our DHT entries
- t.tableMutex.RLock()
- for dhtKey, entry := range t.table {
- switch {
- case !entry.Valid():
- continue
- }
- newCheckedCandidate(dhtKey.PublicKey, entry.SourcePort)
- }
- t.tableMutex.RUnlock()
-
if bootstrap {
return types.SwitchPorts{bestPort}
} else {
@@ -372,7 +379,7 @@ func (t *virtualSnake) getVirtualSnakeTeardownNextHop(from *Peer, rx *types.Fram
}
if asc := t.ascending(); asc != nil && t.r.public.EqualTo(rx.DestinationKey) && asc.PathID == teardown.PathID {
t.setAscending(nil)
- defer time.AfterFunc(time.Millisecond*500, t.bootstrapNow)
+ defer t.bootstrapIn(time.Second / 4)
}
t.tableMutex.Lock()
defer t.tableMutex.Unlock()
@@ -413,10 +420,9 @@ func (t *virtualSnake) teardownPath(pk types.PublicKey, pathID types.VirtualSnak
Payload: payload[:],
}
_ = t.getVirtualSnakeTeardownNextHop(t.r.ports[0], &frame)
- if !t.r.ports[via].started.Load() {
- return
+ if t.r.ports[via].started.Load() {
+ t.r.ports[via].protoOut.push(frame.Borrow())
}
- t.r.ports[via].protoOut.push(frame.Borrow())
}
// handleBootstrap is called in response to an incoming bootstrap
@@ -437,7 +443,7 @@ func (t *virtualSnake) handleBootstrap(from *Peer, rx *types.Frame) error {
return fmt.Errorf("util.VerifySignedTimestamp")
}
if !bootstrap.RootPublicKey.EqualTo(t.r.RootPublicKey()) {
- return fmt.Errorf("root key doesn't match")
+ return fmt.Errorf("bootstrap root key doesn't match")
}
bootstrapACK := types.VirtualSnakeBootstrapACK{ // nolint:gosimple
PathID: bootstrap.PathID,
@@ -474,7 +480,7 @@ func (t *virtualSnake) handleBootstrapACK(from *Peer, rx *types.Frame) error {
return fmt.Errorf("util.VerifySignedTimestamp")
}
if !bootstrapACK.RootPublicKey.EqualTo(t.r.RootPublicKey()) {
- return fmt.Errorf("root key doesn't match")
+ return fmt.Errorf("bootstrap ACK root key doesn't match")
}
update := false
asc := t.ascending()
@@ -579,12 +585,12 @@ func (t *virtualSnake) handleSetup(from *Peer, rx *types.Frame, nextHops types.S
}
if !setup.RootPublicKey.EqualTo(t.r.RootPublicKey()) {
t.teardownPath(rx.SourceKey, setup.PathID, from.port, false, fmt.Errorf("rejecting setup (root key doesn't match)"))
- return fmt.Errorf("root key doesn't match")
+ return fmt.Errorf("setup root key doesn't match")
}
// Did the setup hit a dead end on the way to the ascending node?
if nextHops.EqualTo(types.SwitchPorts{0}) || nextHops.EqualTo(types.SwitchPorts{}) {
- if !rx.DestinationKey.EqualTo(t.r.public) || !rx.Destination.EqualTo(t.r.Coords()) {
+ if !rx.DestinationKey.EqualTo(t.r.public) {
t.teardownPath(rx.SourceKey, setup.PathID, from.port, false, fmt.Errorf("rejecting setup (hit dead end)"))
return fmt.Errorf("setup for %q (%s) en route to %q %s hit dead end at %s", rx.SourceKey, hex.EncodeToString(setup.PathID[:]), rx.DestinationKey, rx.Destination, t.r.Coords())
}
diff --git a/router/peer.go b/router/peer.go
index 1c552bfc..602fb5f4 100644
--- a/router/peer.go
+++ b/router/peer.go
@@ -18,7 +18,7 @@ import (
"bytes"
"context"
"encoding/binary"
- "errors"
+ "encoding/hex"
"fmt"
"io"
"net"
@@ -29,7 +29,7 @@ import (
"go.uber.org/atomic"
)
-const PeerKeepaliveInterval = time.Second * 2
+const PeerKeepaliveInterval = time.Second * 3
const PeerKeepaliveTimeout = PeerKeepaliveInterval * 3
const (
@@ -41,8 +41,9 @@ const (
type Peer struct {
r *Router //
port types.SwitchPortID //
+ allocated atomic.Bool //
started atomic.Bool // worker goroutines started?
- alive atomic.Bool // have we received a handshake?
+ wg *sync.WaitGroup // wait group for worker goroutines
mutex sync.RWMutex // protects everything below this line
zone string //
peertype int //
@@ -50,14 +51,110 @@ type Peer struct {
cancel context.CancelFunc //
conn net.Conn // underlying connection to peer
public types.PublicKey //
- trafficOut queue // queue traffic message to peer
- protoOut queue // queue protocol message to peer
+ trafficOut *lifoQueue // queue traffic message to peer
+ protoOut *fifoQueue // queue protocol message to peer
coords types.SwitchPorts //
- announce chan struct{} //
- announcement *rootAnnouncementWithTime //
+ announcement *rootAnnouncementWithTime // last received announcement from peer
statistics peerStatistics //
}
+func (p *Peer) start() {
+ if !p.started.CAS(false, true) {
+ return
+ }
+
+ // When the peer dies, we need to clean up.
+ var lasterr error
+ var lasterrMutex sync.Mutex
+ defer func() {
+ p.reset()
+ p.r.snake.portWasDisconnected(p.port)
+ p.r.tree.portWasDisconnected(p.port)
+ go p.r.callbacks.onDisconnected(p.port, p.public, p.peertype, lasterr)
+ }()
+
+ // Store the fact that we're connected to this public key in
+ // this zone, so that the multicast code can ignore nodes we
+ // are already connected to.
+ index := hex.EncodeToString(p.public[:]) + p.zone
+ p.r.active.Store(index, p.port)
+ defer p.r.active.Delete(index)
+
+ // Push a root update to our new peer. This will notify them
+ // of our coordinates and that we are alive.
+ p.protoOut.push(p.r.tree.Root().ForPeer(p))
+
+ // Start the reader and writer goroutines for this peer.
+ p.wg.Add(2)
+ go func() {
+ if err := p.reader(p.context); err != nil && err != context.Canceled {
+ lasterrMutex.Lock()
+ lasterr = fmt.Errorf("reader error: %w", err)
+ lasterrMutex.Unlock()
+ }
+ }()
+ go func() {
+ if err := p.writer(p.context); err != nil && err != context.Canceled {
+ lasterrMutex.Lock()
+ lasterr = fmt.Errorf("writer error: %w", err)
+ lasterrMutex.Unlock()
+ }
+ }()
+
+ // Report the new connection.
+ if p.port != 0 {
+ p.r.dht.insertNode(p)
+ }
+ if p.r.simulator != nil {
+ p.r.simulator.ReportNewLink(p.conn, p.r.public, p.public)
+ }
+ go p.r.callbacks.onConnected(p.port, p.public, p.peertype)
+
+ p.r.log.Printf("Connected port %d to %s (zone %q)\n", p.port, p.conn.RemoteAddr(), p.zone)
+
+ // Wait for the cancellation, and then for the goroutines to stop.
+ <-p.context.Done()
+ p.started.Store(false)
+ p.wg.Wait()
+
+ // Report the disconnection.
+ if p.port != 0 {
+ p.r.dht.deleteNode(p.public)
+ }
+ if p.r.simulator != nil {
+ p.r.simulator.ReportDeadLink(p.r.public, p.public)
+ }
+
+ // Make sure the connection is closed.
+ _ = p.conn.Close()
+
+ // ... and finally, yell about it.
+ lasterrMutex.Lock()
+ if lasterr != nil {
+ p.r.log.Printf("Disconnected port %d: %s\n", p.port, lasterr)
+ } else {
+ p.r.log.Printf("Disconnected port %d\n", p.port)
+ }
+ lasterrMutex.Unlock()
+}
+
+func (p *Peer) reset() {
+ p.mutex.Lock()
+ defer p.mutex.Unlock()
+ p.allocated.Store(false)
+ p.started.Store(false)
+ p.zone = ""
+ p.peertype = 0
+ p.context, p.cancel = nil, nil
+ p.conn = nil
+ p.public = types.PublicKey{}
+ p.trafficOut.reset()
+ p.protoOut.reset()
+ p.statistics.reset()
+ p.coords = nil
+ p.announcement = nil
+}
+
type peerStatistics struct {
txProtoSuccessful atomic.Uint64
txProtoDropped atomic.Uint64
@@ -76,6 +173,16 @@ func (s *peerStatistics) reset() {
s.rxTraffic.Store(0)
}
+func (p *Peer) IsParent() bool {
+ return p.r.tree.Parent() == p.port
+}
+
+func (p *Peer) Alive() bool {
+ p.mutex.RLock()
+ defer p.mutex.RUnlock()
+ return p.announcement != nil && time.Since(p.announcement.at) < announcementTimeout
+}
+
func (p *Peer) PublicKey() types.PublicKey {
p.mutex.RLock()
defer p.mutex.RUnlock()
@@ -89,7 +196,7 @@ func (p *Peer) Coordinates() types.SwitchPorts {
}
func (p *Peer) SeenCommonRootRecently() bool {
- if !p.alive.Load() {
+ if !p.Alive() {
return false
}
last := p.lastAnnouncement()
@@ -101,31 +208,21 @@ func (p *Peer) SeenCommonRootRecently() bool {
return lpk == rpk
}
-func (p *Peer) SeenRecently() bool {
- if last := p.lastAnnouncement(); last != nil {
- return true
- }
- return false
-}
-
func (p *Peer) updateAnnouncement(new *types.SwitchAnnouncement) error {
p.mutex.Lock()
defer p.mutex.Unlock()
- coords, err := new.PeerCoords(p.public)
- if err != nil {
- p.alive.Store(false)
- p.announcement = nil
- p.coords = nil
- return fmt.Errorf("new.PeerCoords: %w", err)
- }
- if p.alive.CAS(false, true) {
- p.r.snake.portWasConnected(p.port)
+ if p.announcement != nil {
+ if new.RootPublicKey == p.announcement.RootPublicKey && new.Sequence < p.announcement.Sequence {
+ p.announcement = nil
+ p.coords = nil
+ return fmt.Errorf("root announcement replays sequence number")
+ }
}
p.announcement = &rootAnnouncementWithTime{
SwitchAnnouncement: *new,
at: time.Now(),
}
- p.coords = coords
+ p.coords = new.PeerCoords()
return nil
}
@@ -141,96 +238,50 @@ func (p *Peer) lastAnnouncement() *rootAnnouncementWithTime {
return p.announcement
}
-func (p *Peer) start() error {
- if !p.started.CAS(false, true) {
- return errors.New("switch peer is already started")
- }
- p.alive.Store(false)
- go p.reader(p.context)
- go p.writer(p.context)
- return nil
-}
-
-func (p *Peer) stop() error {
- if !p.started.CAS(true, false) {
- return errors.New("switch peer is already stopped")
- }
- p.alive.Store(false)
+func (p *Peer) stop() {
+ p.started.Store(false)
+ p.mutex.Lock()
p.cancel()
- _ = p.conn.Close()
- return nil
+ p.mutex.Unlock()
}
-/*
func (p *Peer) generateKeepalive() *types.Frame {
frame := types.GetFrame()
frame.Version = types.Version0
frame.Type = types.TypeKeepalive
return frame
}
-*/
-func (p *Peer) generateAnnouncement() *types.Frame {
- if p.port == 0 {
- return nil
- }
- announcement := p.r.tree.Root()
- for _, sig := range announcement.Signatures {
- if p.r.public.EqualTo(sig.PublicKey) {
- // For some reason the announcement that we want to send already
- // includes our signature. This shouldn't really happen but if we
- // did send it, other nodes would end up ignoring the announcement
- // anyway since it would appear to be a routing loop.
- return nil
- }
- }
- // Sign the announcement.
- if err := announcement.Sign(p.r.private[:], p.port); err != nil {
- p.r.log.Println("Failed to sign switch announcement:", err)
- return nil
- }
- var payload [MaxPayloadSize]byte
- n, err := announcement.MarshalBinary(payload[:])
- if err != nil {
- p.r.log.Println("Failed to marshal switch announcement:", err)
- return nil
- }
- frame := types.GetFrame()
- frame.Version = types.Version0
- frame.Type = types.TypeSTP
- frame.Destination = types.SwitchPorts{}
- frame.Payload = payload[:n]
- return frame
-}
+func (p *Peer) reader(ctx context.Context) error {
+ defer p.wg.Done()
+ defer p.cancel()
-func (p *Peer) reader(ctx context.Context) {
buf := make([]byte, MaxFrameSize)
for {
select {
case <-ctx.Done():
// The switch peer is shutting down.
- return
+ return context.Canceled
default:
if p.port != 0 {
if err := p.conn.SetReadDeadline(time.Now().Add(PeerKeepaliveTimeout)); err != nil {
- _ = p.r.Disconnect(p.port, fmt.Errorf("conn.SetReadDeadline: %w", err))
- return
+ return fmt.Errorf("conn.SetReadDeadline: %w", err)
}
}
if _, err := io.ReadFull(p.conn, buf[:8]); err != nil {
- _ = p.r.Disconnect(p.port, fmt.Errorf("p.conn.Peek: %w", err))
- return
+ return fmt.Errorf("p.conn.Peek: %w", err)
+ }
+ if err := p.conn.SetReadDeadline(time.Time{}); err != nil {
+ return fmt.Errorf("conn.SetReadDeadline: %w", err)
}
if !bytes.Equal(buf[:4], types.FrameMagicBytes) {
- _ = p.r.Disconnect(p.port, fmt.Errorf("missing magic bytes"))
- return
+ return fmt.Errorf("missing magic bytes")
}
expecting := int(binary.BigEndian.Uint16(buf[6:8]))
n, err := io.ReadFull(p.conn, buf[8:expecting])
if err != nil {
- _ = p.r.Disconnect(p.port, fmt.Errorf("io.ReadFull: %w", err))
- return
+ return fmt.Errorf("io.ReadFull: %w", err)
}
if n < expecting-8 {
p.r.log.Println("Expecting", expecting, "bytes but got", n, "bytes")
@@ -238,16 +289,13 @@ func (p *Peer) reader(ctx context.Context) {
}
frame := types.GetFrame()
if _, err := frame.UnmarshalBinary(buf[:n+8]); err != nil {
- p.r.log.Println("Port", p.port, "error unmarshalling frame:", err)
- frame.Done()
- return
- }
- if frame.Version != types.Version0 {
- p.r.log.Println("Port", p.port, "incorrect version in frame")
frame.Done()
- return
+ return fmt.Errorf("frame.UnmarshalBinary: %w", err)
}
- if frame.Type == types.TypeKeepalive {
+ switch {
+ case frame.Version != types.Version0:
+ fallthrough
+ case frame.Type == types.TypeKeepalive:
frame.Done()
continue
}
@@ -257,7 +305,7 @@ func (p *Peer) reader(ctx context.Context) {
for _, port := range p.getNextHops(frame, p.port) {
// Ignore ports that are not good candidates.
dest := p.r.ports[port]
- if !dest.started.Load() || (dest.port != 0 && !dest.alive.Load()) {
+ if !dest.started.Load() || (dest.port != 0 && !dest.Alive()) {
continue
}
if p.port != 0 && dest.port != 0 {
@@ -310,85 +358,64 @@ var bufPool = sync.Pool{
},
}
-func (p *Peer) writer(ctx context.Context) {
- //tick := time.NewTicker(PeerKeepaliveInterval)
- //defer tick.Stop()
- send := func(frame *types.Frame) {
+func (p *Peer) writer(ctx context.Context) error {
+ defer p.wg.Done()
+ defer p.cancel()
+
+ tick := time.NewTicker(PeerKeepaliveInterval)
+ defer tick.Stop()
+
+ send := func(frame *types.Frame) error {
if frame == nil {
- return
+ return nil
}
buf := bufPool.Get().([]byte)
defer bufPool.Put(buf) // nolint:staticcheck
fn, err := frame.MarshalBinary(buf)
frame.Done()
if err != nil {
- p.r.log.Println("Port", p.port, "error marshalling frame:", err)
- return
+ return nil
}
if !bytes.Equal(buf[:4], types.FrameMagicBytes) {
- panic("expected magic bytes")
+ return nil
}
- remaining := buf[:fn]
- for len(remaining) > 0 {
- n, err := p.conn.Write(remaining)
- if err != nil {
- _ = p.r.Disconnect(p.port, fmt.Errorf("p.conn.Write: %w", err))
- return
- }
- remaining = remaining[n:]
+ if err := p.conn.SetWriteDeadline(time.Now().Add(PeerKeepaliveTimeout)); err != nil {
+ return fmt.Errorf("p.conn.SetWriteDeadline: %w", err)
}
+ if _, err = p.conn.Write(buf[:fn]); err != nil {
+ return fmt.Errorf("p.conn.Write: %w", err)
+ }
+ if err := p.conn.SetWriteDeadline(time.Time{}); err != nil {
+ return fmt.Errorf("p.conn.SetWriteDeadline: %w", err)
+ }
+ return nil
}
- // The very first thing we send should be a tree announcement,
- // so that the remote side can work out our coords and consider
- // us to be "alive".
- send(p.generateAnnouncement())
-
for {
select {
case <-ctx.Done():
- return
- case <-p.announce:
- send(p.generateAnnouncement())
+ return context.Canceled
+ case frame := <-p.protoOut.pop():
+ if err := send(frame); err != nil {
+ return fmt.Errorf("send: %w", err)
+ }
+ p.protoOut.ack()
p.statistics.txProtoSuccessful.Inc()
continue
default:
}
select {
case <-ctx.Done():
- return
- case <-p.announce:
- send(p.generateAnnouncement())
- p.statistics.txProtoSuccessful.Inc()
- continue
- case <-p.protoOut.wait():
- if frame, ok := p.protoOut.pop(); ok {
- send(frame)
- p.statistics.txProtoSuccessful.Inc()
- } else {
- p.statistics.txProtoDropped.Inc()
+ return context.Canceled
+ case frame := <-p.protoOut.pop():
+ if err := send(frame); err != nil {
+ return fmt.Errorf("send: %w", err)
}
- continue
- default:
- }
- select {
- case <-ctx.Done():
- return
- case <-p.announce:
- send(p.generateAnnouncement())
+ p.protoOut.ack()
p.statistics.txProtoSuccessful.Inc()
continue
- case <-p.protoOut.wait():
- if frame, ok := p.protoOut.pop(); ok {
- send(frame)
- p.statistics.txProtoSuccessful.Inc()
- } else {
- p.statistics.txProtoDropped.Inc()
- }
- continue
case <-p.trafficOut.wait():
- if frame, ok := p.trafficOut.pop(); ok {
- send(frame)
+ if frame, ok := p.trafficOut.pop(); ok && send(frame) == nil {
p.statistics.txTrafficSuccessful.Inc()
} else {
p.statistics.txTrafficDropped.Inc()
@@ -398,31 +425,27 @@ func (p *Peer) writer(ctx context.Context) {
}
select {
case <-ctx.Done():
- return
- case <-p.announce:
- send(p.generateAnnouncement())
- p.statistics.txProtoSuccessful.Inc()
- continue
- case <-p.protoOut.wait():
- if frame, ok := p.protoOut.pop(); ok {
- send(frame)
- p.statistics.txProtoSuccessful.Inc()
- } else {
- p.statistics.txProtoDropped.Inc()
+ return context.Canceled
+ case frame := <-p.protoOut.pop():
+ if err := send(frame); err != nil {
+ return fmt.Errorf("send: %w", err)
}
+ p.protoOut.ack()
+ p.statistics.txProtoSuccessful.Inc()
continue
case <-p.trafficOut.wait():
- if frame, ok := p.trafficOut.pop(); ok {
- send(frame)
+ if frame, ok := p.trafficOut.pop(); ok && send(frame) == nil {
p.statistics.txTrafficSuccessful.Inc()
} else {
p.statistics.txTrafficDropped.Inc()
}
continue
- //case <-tick.C:
- // send(p.generateKeepalive())
- // p.statistics.txProtoSuccessful.Inc()
- // continue
+ case <-tick.C:
+ if err := send(p.generateKeepalive()); err != nil {
+ return fmt.Errorf("send: %w", err)
+ }
+ p.statistics.txProtoSuccessful.Inc()
+ continue
}
}
}
diff --git a/router/queue.go b/router/queue.go
deleted file mode 100644
index a2219a14..00000000
--- a/router/queue.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package router
-
-import "github.com/matrix-org/pinecone/types"
-
-type queue interface {
- queuecount() int
- queuesize() int
- push(frame *types.Frame) bool
- pop() (*types.Frame, bool)
- reset()
- wait() <-chan struct{}
-}
diff --git a/router/queuefifo.go b/router/queuefifo.go
index 0bcb4381..43e8e829 100644
--- a/router/queuefifo.go
+++ b/router/queuefifo.go
@@ -7,85 +7,77 @@ import (
)
type fifoQueue struct {
- frames []*types.Frame
- count int
- mutex sync.Mutex
- notifs chan struct{}
+ entries []chan *types.Frame
+ mutex sync.Mutex
}
func newFIFOQueue() *fifoQueue {
- q := &fifoQueue{
- notifs: make(chan struct{}),
- }
+ q := &fifoQueue{}
+ q.reset()
return q
}
+func (q *fifoQueue) _initialise() {
+ for i := range q.entries {
+ q.entries[i] = nil
+ }
+ q.entries = []chan *types.Frame{
+ make(chan *types.Frame, 1),
+ }
+}
+
func (q *fifoQueue) queuecount() int {
q.mutex.Lock()
defer q.mutex.Unlock()
- return q.count
+ return len(q.entries)
}
func (q *fifoQueue) queuesize() int {
q.mutex.Lock()
defer q.mutex.Unlock()
- return cap(q.frames)
+ return cap(q.entries)
}
func (q *fifoQueue) push(frame *types.Frame) bool {
q.mutex.Lock()
defer q.mutex.Unlock()
- q.frames = append(q.frames, frame)
- q.count++
- select {
- case q.notifs <- struct{}{}:
- default:
+ if len(q.entries) == 0 {
+ q._initialise()
}
+ ch := q.entries[len(q.entries)-1]
+ ch <- frame
+ close(ch)
+ q.entries = append(q.entries, make(chan *types.Frame, 1))
return true
}
-func (q *fifoQueue) pop() (*types.Frame, bool) {
+func (q *fifoQueue) reset() {
q.mutex.Lock()
defer q.mutex.Unlock()
- if q.count == 0 {
- return nil, false
- }
- frame := q.frames[0]
- q.frames[0] = nil
- q.frames = q.frames[1:]
- q.count--
- if q.count == 0 {
- // Force a GC of the underlying array, since it might have
- // grown significantly if the queue was hammered for some reason
- q.frames = nil
+ for _, ch := range q.entries {
+ select {
+ case frame := <-ch:
+ if frame != nil {
+ frame.Done()
+ }
+ default:
+ }
}
- return frame, true
+ q._initialise()
}
-func (q *fifoQueue) reset() {
+func (q *fifoQueue) pop() <-chan *types.Frame {
q.mutex.Lock()
defer q.mutex.Unlock()
- q.count = 0
- for i := range q.frames {
- if q.frames[i] != nil {
- q.frames[i].Done()
- q.frames[i] = nil
- }
- }
- q.frames = nil
- close(q.notifs)
- for range q.notifs {
+ if len(q.entries) == 0 {
+ q._initialise()
}
- q.notifs = make(chan struct{})
+ entry := q.entries[0]
+ return entry
}
-func (q *fifoQueue) wait() <-chan struct{} {
+func (q *fifoQueue) ack() {
q.mutex.Lock()
defer q.mutex.Unlock()
- if q.count > 0 {
- ch := make(chan struct{})
- close(ch)
- return ch
- }
- return q.notifs
+ q.entries = q.entries[1:]
}
diff --git a/router/queuefifo_test.go b/router/queuefifo_test.go
new file mode 100644
index 00000000..cdc91988
--- /dev/null
+++ b/router/queuefifo_test.go
@@ -0,0 +1,34 @@
+package router
+
+import (
+ "testing"
+
+ "github.com/matrix-org/pinecone/types"
+)
+
+func TestFIFOQueue(t *testing.T) {
+ q := newFIFOQueue()
+ iterations := types.SwitchPortID(1024)
+
+ go func() {
+ for i := types.SwitchPortID(0); i < iterations; i++ {
+ q.push(&types.Frame{
+ Destination: types.SwitchPorts{i},
+ })
+ }
+ }()
+
+ got := types.SwitchPortID(0)
+
+ for i := types.SwitchPortID(0); i < iterations; i++ {
+ frame := <-q.pop()
+ if frame == nil {
+ t.Fatalf("unexpected nil frame")
+ }
+ if frame.Destination[0] < got || frame.Destination[0] > got+1 {
+ t.Fatalf("ordering problem")
+ }
+ got = frame.Destination[0]
+ q.ack()
+ }
+}
diff --git a/router/router.go b/router/router.go
index 7f8d12bf..d398f36c 100644
--- a/router/router.go
+++ b/router/router.go
@@ -25,7 +25,6 @@ import (
"net"
"sort"
"sync"
- "time"
"github.com/matrix-org/pinecone/types"
"go.uber.org/atomic"
@@ -110,7 +109,7 @@ func NewRouter(log *log.Logger, id string, private ed25519.PrivateKey, public ed
sw.ports[i] = &Peer{
r: sw,
port: types.SwitchPortID(i),
- announce: make(chan struct{}),
+ wg: &sync.WaitGroup{},
protoOut: newFIFOQueue(),
trafficOut: newLIFOQueue(TrafficBufferSize),
}
@@ -148,7 +147,7 @@ func (r *Router) Close() error {
r.cancel()
for _, port := range r.ports {
if port.started.Load() {
- _ = port.stop()
+ port.stop()
}
}
return nil
@@ -313,10 +312,7 @@ func (r *Router) startedPorts() peers {
func (r *Router) activePorts() peers {
peers := make(peers, 0, PortCount)
for _, p := range r.startedPorts() {
- switch {
- case !p.alive.Load():
- continue
- default:
+ if p.Alive() {
peers = append(peers, p)
}
}
@@ -330,52 +326,49 @@ func (r *Router) activePorts() peers {
// the node was connected to will be returned in the event
// of a successful connection.
func (r *Router) AuthenticatedConnect(conn net.Conn, zone string, peertype int) (types.SwitchPortID, error) {
- select {
- case <-time.After(time.Second * 5):
- return 0, fmt.Errorf("handshake timed out")
- default:
- handshake := []byte{
- ourVersion,
- ourCapabilities,
- 0, // unused
- 0, // unused
- }
- handshake = append(handshake, r.public[:ed25519.PublicKeySize]...)
- handshake = append(handshake, ed25519.Sign(r.private[:], handshake)...)
- _ = conn.SetDeadline(time.Now().Add(time.Second * 5))
- if _, err := conn.Write(handshake); err != nil {
- conn.Close()
- return 0, fmt.Errorf("conn.Write: %w", err)
- }
- if _, err := io.ReadFull(conn, handshake); err != nil {
- conn.Close()
- return 0, fmt.Errorf("io.ReadFull: %w", err)
- }
- _ = conn.SetDeadline(time.Time{})
- if theirVersion := handshake[0]; theirVersion != ourVersion {
- conn.Close()
- return 0, fmt.Errorf("mismatched node version")
- }
- if theirCapabilities := handshake[1]; theirCapabilities&ourCapabilities != ourCapabilities {
- conn.Close()
- return 0, fmt.Errorf("mismatched node capabilities")
- }
- var public types.PublicKey
- var signature types.Signature
- offset := 4
- offset += copy(public[:], handshake[offset:offset+ed25519.PublicKeySize])
- copy(signature[:], handshake[offset:offset+ed25519.SignatureSize])
- if !ed25519.Verify(public[:], handshake[:offset], signature[:]) {
- conn.Close()
- return 0, fmt.Errorf("peer sent invalid signature")
- }
- port, err := r.Connect(conn, public, zone, peertype)
- if err != nil {
- conn.Close()
- return 0, err
- }
- return port, nil
+ handshake := []byte{
+ ourVersion,
+ ourCapabilities,
+ 0, // unused
+ 0, // unused
+ }
+ handshake = append(handshake, r.public[:ed25519.PublicKeySize]...)
+ handshake = append(handshake, ed25519.Sign(r.private[:], handshake)...)
+ //_ = conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
+ if _, err := conn.Write(handshake); err != nil {
+ conn.Close()
+ return 0, fmt.Errorf("conn.Write: %w", err)
+ }
+ //_ = conn.SetWriteDeadline(time.Time{})
+ //_ = conn.SetReadDeadline(time.Now().Add(time.Second * 5))
+ if _, err := io.ReadFull(conn, handshake); err != nil {
+ conn.Close()
+ return 0, fmt.Errorf("io.ReadFull: %w", err)
+ }
+ //_ = conn.SetReadDeadline(time.Time{})
+ if theirVersion := handshake[0]; theirVersion != ourVersion {
+ conn.Close()
+ return 0, fmt.Errorf("mismatched node version")
+ }
+ if theirCapabilities := handshake[1]; theirCapabilities&ourCapabilities != ourCapabilities {
+ conn.Close()
+ return 0, fmt.Errorf("mismatched node capabilities")
}
+ var public types.PublicKey
+ var signature types.Signature
+ offset := 4
+ offset += copy(public[:], handshake[offset:offset+ed25519.PublicKeySize])
+ copy(signature[:], handshake[offset:offset+ed25519.SignatureSize])
+ if !ed25519.Verify(public[:], handshake[:offset], signature[:]) {
+ conn.Close()
+ return 0, fmt.Errorf("peer sent invalid signature")
+ }
+ port, err := r.Connect(conn, public, zone, peertype)
+ if err != nil {
+ conn.Close()
+ return 0, err
+ }
+ return port, nil
}
// Connect initiates a peer connection using the given
@@ -384,43 +377,33 @@ func (r *Router) AuthenticatedConnect(conn net.Conn, zone string, peertype int)
// port number that the node was connected to will be
// returned in the event of a successful connection.
func (r *Router) Connect(conn net.Conn, public types.PublicKey, zone string, peertype int) (types.SwitchPortID, error) {
- if p, ok := r.active.Load(hex.EncodeToString(public[:]) + zone); ok {
- _ = r.Disconnect(p.(types.SwitchPortID), fmt.Errorf("another incoming connection from the same node"))
+ if r.IsConnected(public, zone) {
+ _ = conn.Close()
+ return 0, fmt.Errorf("already connected")
}
r.connections.Lock()
defer r.connections.Unlock()
+ usePort := func(port *Peer) bool {
+ if !port.allocated.CAS(false, true) {
+ return false
+ }
+ port.mutex.Lock()
+ port.context, port.cancel = context.WithCancel(r.context)
+ port.zone = zone
+ port.peertype = peertype
+ port.conn = conn
+ port.public = public
+ port.mutex.Unlock()
+ go port.start()
+ return true
+ }
for i := types.SwitchPortID(0); i < PortCount; i++ {
if i != 0 && bytes.Equal(r.public[:], public[:]) {
return 0, fmt.Errorf("loopback connection prohibited")
}
- if r.ports[i].started.Load() {
- continue
- }
- r.ports[i].mutex.Lock()
- r.ports[i].context, r.ports[i].cancel = context.WithCancel(r.context)
- r.ports[i].zone = zone
- r.ports[i].peertype = peertype
- r.ports[i].conn = conn // util.NewBufferedRWCSize(conn, MaxFrameSize)
- r.ports[i].public = public
- r.ports[i].coords = nil
- r.ports[i].protoOut.reset()
- r.ports[i].trafficOut.reset()
- r.ports[i].announcement = nil
- r.ports[i].statistics.reset()
- r.ports[i].mutex.Unlock()
- if err := r.ports[i].start(); err != nil {
- return 0, fmt.Errorf("port.start: %w", err)
+ if usePort(r.ports[i]) {
+ return i, nil
}
- r.active.Store(hex.EncodeToString(public[:])+zone, i)
- if i != 0 {
- r.dht.insertNode(r.ports[i])
- }
- if r.simulator != nil {
- r.simulator.ReportNewLink(conn, r.public, public)
- }
- go r.callbacks.onConnected(i, public, peertype)
- r.log.Printf("Connected port %d to %s (zone %q)\n", i, conn.RemoteAddr(), zone)
- return i, nil
}
return 0, fmt.Errorf("no free switch ports")
}
@@ -433,31 +416,7 @@ func (r *Router) Disconnect(i types.SwitchPortID, err error) error {
if i == 0 {
return fmt.Errorf("cannot disconnect port %d", i)
}
- r.connections.Lock()
- defer r.connections.Unlock()
- if stoperr := r.ports[i].stop(); stoperr != nil {
- return fmt.Errorf("port.stop: %w", stoperr)
- }
- r.active.Delete(hex.EncodeToString(r.ports[i].public[:]) + r.ports[i].zone)
- r.ports[i].mutex.Lock()
- r.ports[i].peertype = 0
- r.ports[i].zone = ""
- r.ports[i].public = types.PublicKey{}
- r.ports[i].coords = nil
- r.ports[i].announcement = nil
- r.ports[i].protoOut.reset()
- r.ports[i].trafficOut.reset()
- r.ports[i].mutex.Unlock()
- if r.ports[i].port != 0 {
- r.dht.deleteNode(r.ports[i].public)
- }
- if r.simulator != nil {
- r.simulator.ReportDeadLink(r.public, r.ports[i].public)
- }
- r.log.Printf("Disconnected port %d: %s\n", i, err)
- r.tree.portWasDisconnected(i)
- r.snake.portWasDisconnected(i)
- go r.callbacks.onDisconnected(i, r.ports[i].public, r.ports[i].peertype, err)
+ r.ports[i].stop()
return nil
}
diff --git a/router/tree.go b/router/tree.go
index e66d1972..2328944a 100644
--- a/router/tree.go
+++ b/router/tree.go
@@ -20,30 +20,57 @@ import (
"fmt"
"math"
"sync"
- "sync/atomic"
"time"
"github.com/matrix-org/pinecone/types"
+ "go.uber.org/atomic"
)
// announcementInterval is the frequency at which this
// node will send root announcements to other peers.
-const announcementInterval = PeerKeepaliveInterval // time.Minute * 15
+const announcementInterval = time.Minute * 15
// announcementTimeout is the amount of time that must
// pass without receiving a root announcement before we
// will assume that the peer is dead.
const announcementTimeout = announcementInterval * 2
-func (r *Router) handleAnnouncement(peer *Peer, rx *types.Frame) {
- var new types.SwitchAnnouncement
- if _, err := new.UnmarshalBinary(rx.Payload); err != nil {
- r.log.Println("Error unmarshalling announcement:", err)
- return
+func (r *Router) handleAnnouncement(peer *Peer, rx *types.Frame) error {
+ var newUpdate types.SwitchAnnouncement
+ if _, err := newUpdate.UnmarshalBinary(rx.Payload); err != nil {
+ return fmt.Errorf("failed to unmarshal root announcement: %w", err)
}
- if err := r.tree.Update(peer, new); err != nil {
- r.log.Println("Error handling announcement on port", peer.port, ":", err)
+ sigs := make(map[string]struct{})
+ for index, sig := range newUpdate.Signatures {
+ if index == 0 && sig.PublicKey != newUpdate.RootPublicKey {
+ // The first signature in the announcement must be from the
+ // key that claims to be the root.
+ return fmt.Errorf("root announcement first signature is not from the root node")
+ }
+ if sig.Hop == 0 {
+ // None of the hops in the update should have a port number of 0
+ // as this would imply that another node has sent their router
+ // port, which is impossible. We'll therefore reject any update
+ // that tries to do that.
+ return fmt.Errorf("root announcement contains an invalid port number")
+ }
+ if index == len(newUpdate.Signatures)-1 && peer.PublicKey() != sig.PublicKey {
+ // The last signature in the announcement must be from the
+ // direct peer. If it isn't then it sounds like someone is
+ // trying to replay someone else's announcement to us.
+ return fmt.Errorf("root announcement last signature is not from the direct peer")
+ }
+ pk := hex.EncodeToString(sig.PublicKey[:])
+ if _, ok := sigs[pk]; ok {
+ // One of the signatures has appeared in the update more than
+ // once, which would suggest that there's a loop somewhere.
+ return fmt.Errorf("root announcement contains a routing loop")
+ }
+ sigs[pk] = struct{}{}
}
+
+ defer r.tree.UpdateParentIfNeeded(peer, newUpdate)
+ return peer.updateAnnouncement(&newUpdate)
}
type rootAnnouncementWithTime struct {
@@ -51,16 +78,49 @@ type rootAnnouncementWithTime struct {
at time.Time
}
+func (a *rootAnnouncementWithTime) ForPeer(p *Peer) *types.Frame {
+ if p.port == 0 {
+ return nil
+ }
+ announcement := a.SwitchAnnouncement
+ announcement.Signatures = append([]types.SignatureWithHop{}, a.Signatures...)
+ for _, sig := range announcement.Signatures {
+ if p.r.public.EqualTo(sig.PublicKey) {
+ // For some reason the announcement that we want to send already
+ // includes our signature. This shouldn't really happen but if we
+ // did send it, other nodes would end up ignoring the announcement
+ // anyway since it would appear to be a routing loop.
+ return nil
+ }
+ }
+ // Sign the announcement.
+ if err := announcement.Sign(p.r.private[:], p.port); err != nil {
+ p.r.log.Println("Failed to sign switch announcement:", err)
+ return nil
+ }
+ var payload [MaxPayloadSize]byte
+ n, err := announcement.MarshalBinary(payload[:])
+ if err != nil {
+ p.r.log.Println("Failed to marshal switch announcement:", err)
+ return nil
+ }
+ frame := types.GetFrame()
+ frame.Version = types.Version0
+ frame.Type = types.TypeSTP
+ frame.Destination = types.SwitchPorts{}
+ frame.Payload = payload[:n]
+ return frame
+}
+
type spanningTree struct {
- r *Router //
- context context.Context //
- root *rootAnnouncementWithTime //
- rootMutex sync.RWMutex //
- rootReset chan struct{} //
- updateMutex sync.Mutex //
- parent atomic.Value // types.SwitchPortID
- coords atomic.Value // types.SwitchPorts
- callback func(parent types.SwitchPortID, coords types.SwitchPorts)
+ r *Router
+ context context.Context
+ rootReset chan struct{}
+ mutex *sync.Mutex
+ parent types.SwitchPortID
+ reparent *time.Timer
+ sequence atomic.Uint64
+ callback func(parent types.SwitchPortID, coords types.SwitchPorts)
}
func newSpanningTree(r *Router, f func(parent types.SwitchPortID, coords types.SwitchPorts)) *spanningTree {
@@ -68,8 +128,11 @@ func newSpanningTree(r *Router, f func(parent types.SwitchPortID, coords types.S
r: r,
context: r.context,
rootReset: make(chan struct{}),
+ mutex: &sync.Mutex{},
callback: f,
}
+ t.reparent = time.AfterFunc(time.Second, t.selectNewParentAndAdvertise)
+ t.reparent.Stop()
t.becomeRoot()
go t.workerForRoot()
go t.workerForAnnouncements()
@@ -77,17 +140,16 @@ func newSpanningTree(r *Router, f func(parent types.SwitchPortID, coords types.S
}
func (t *spanningTree) Coords() types.SwitchPorts {
- coords, ok := t.coords.Load().(types.SwitchPorts)
- if ok {
- return coords
+ ann := t.r.ports[t.Parent()].lastAnnouncement()
+ if ann == nil {
+ return types.SwitchPorts{}
}
- return types.SwitchPorts{}
+ return ann.Coords()
}
func (t *spanningTree) Ancestors() ([]types.PublicKey, types.SwitchPortID) {
- root := t.Root()
- port, ok := t.parent.Load().(types.SwitchPortID)
- if !ok || port == 0 {
+ root, port := t.Root(), t.Parent()
+ if port == 0 {
return nil, 0
}
ancestors := make([]types.PublicKey, 0, 1+len(root.Signatures))
@@ -106,80 +168,124 @@ func (t *spanningTree) portWasDisconnected(port types.SwitchPortID) {
t.becomeRoot()
return
}
- if parent := t.parent.Load().(types.SwitchPortID); parent == port {
- t.selectNewParent()
+ if t.Parent() == port {
+ t.selectNewParentAndAdvertiseIn(0)
}
}
-func (t *spanningTree) selectNewParent() {
- t.updateMutex.Lock()
- defer t.updateMutex.Unlock()
- t.becomeRoot()
+func (t *spanningTree) selectNewParentAndAdvertiseIn(d time.Duration) {
+ t.reparent.Reset(d)
+}
+
+func (t *spanningTree) selectNewParentAndAdvertise() {
+ lastUpdate := t.Root()
+ bestKey := lastUpdate.RootPublicKey
+ bestSeq := lastUpdate.Sequence
bestDist := math.MaxInt32
- bestKey := t.r.public
- var bestTime time.Time
+ bestTime := time.Now()
var bestPort types.SwitchPortID
var bestAnn *rootAnnouncementWithTime
- var bestSeq types.Varu64
- portsToCheck := map[*Peer]*rootAnnouncementWithTime{}
+
+ t.mutex.Lock()
+
for _, p := range t.r.activePorts() {
ann := p.lastAnnouncement()
if ann == nil {
continue
}
- portsToCheck[p] = ann
- }
- checkWithCondition := func(f func(ann *rootAnnouncementWithTime, hops int) bool) {
- for p, ann := range portsToCheck {
- hops := len(ann.Signatures)
- if f(ann, hops) {
- bestKey = ann.RootPublicKey
- bestDist = hops
- bestPort = p.port
- bestTime = ann.at
- bestSeq = ann.Sequence
- bestAnn = ann
- }
+ accept := func() {
+ bestKey = ann.RootPublicKey
+ bestDist = len(ann.Signatures)
+ bestPort = p.port
+ bestTime = ann.at
+ bestSeq = ann.Sequence
+ bestAnn = ann
+ }
+ keyDelta := ann.RootPublicKey.CompareTo(bestKey)
+ annAt := ann.at.Round(time.Millisecond)
+ switch {
+ case ann.IsLoopOrChildOf(p.r.public):
+ // ignore our children or loopy announcements
+ case keyDelta > 0:
+ accept()
+ case keyDelta < 0:
+ // ignore weaker root keys
+ case ann.Sequence > bestSeq:
+ accept()
+ case ann.Sequence < bestSeq:
+ // ignore lower sequence numbers
+ case annAt.Before(bestTime):
+ accept()
+ case annAt.After(bestTime):
+ // ignore updates that arrived more recently
+ case len(ann.Signatures) < bestDist:
+ accept()
+ case len(ann.Signatures) > bestDist:
+ // ignore longer paths
+ case p.public.CompareTo(bestKey) > 0:
+ accept()
}
}
- checkWithCondition(func(ann *rootAnnouncementWithTime, hops int) bool {
- return ann.RootPublicKey.CompareTo(bestKey) > 0
- })
- checkWithCondition(func(ann *rootAnnouncementWithTime, hops int) bool {
- return ann.RootPublicKey.CompareTo(bestKey) == 0 && ann.Sequence > bestSeq
- })
- checkWithCondition(func(ann *rootAnnouncementWithTime, hops int) bool {
- return ann.RootPublicKey.CompareTo(bestKey) == 0 && ann.Sequence == bestSeq && hops < bestDist
- })
- checkWithCondition(func(ann *rootAnnouncementWithTime, hops int) bool {
- return ann.RootPublicKey.CompareTo(bestKey) == 0 && ann.Sequence == bestSeq && hops == bestDist && ann.at.Before(bestTime)
- })
+
if bestAnn != nil {
- t.parent.Store(bestPort)
- if err := t.Update(t.r.ports[bestPort], bestAnn.SwitchAnnouncement); err != nil {
- t.r.log.Println("t.Update: %w", err)
+ t.parent = bestPort
+ t.mutex.Unlock()
+
+ newCoords := bestAnn.Coords()
+ coordsChanged := !lastUpdate.Coords().EqualTo(bestAnn.Coords())
+ rootChanged := bestAnn.RootPublicKey != lastUpdate.RootPublicKey
+
+ if rootChanged {
+ t.r.snake.rootNodeChanged(bestAnn.RootPublicKey)
+ }
+ if rootChanged || coordsChanged {
+ t.callback(bestPort, newCoords)
}
+
+ t.advertise()
+ return
}
+
+ // No suitable other peer was found, so we'll just become the root
+ // and hope that one of our peers corrects us if it matters.
+ t.mutex.Unlock()
+ t.becomeRoot()
}
func (t *spanningTree) advertise() {
+ t.mutex.Lock()
+ defer t.mutex.Unlock()
+
+ t.r.ports[t.parent].mutex.RLock()
+ ann := t.r.ports[t.parent].announcement
+ t.r.ports[t.parent].mutex.RUnlock()
+
+ if t.parent == 0 || ann == nil { // we are the root
+ ann = &rootAnnouncementWithTime{
+ at: time.Now(),
+ SwitchAnnouncement: types.SwitchAnnouncement{
+ RootPublicKey: t.r.public,
+ Sequence: types.Varu64(t.sequence.Inc()),
+ },
+ }
+ }
+
for _, p := range t.r.startedPorts() {
- go func(p *Peer) {
- select {
- case <-p.context.Done():
- case <-time.After(announcementTimeout):
- case p.announce <- struct{}{}:
- }
- }(p)
+ p.protoOut.push(ann.ForPeer(p))
}
}
func (t *spanningTree) becomeRoot() {
- t.parent.Store(types.SwitchPortID(0))
+ t.mutex.Lock()
+ t.parent = 0
+ t.mutex.Unlock()
+
newCoords := types.SwitchPorts{}
if !t.Coords().EqualTo(newCoords) {
go t.callback(0, types.SwitchPorts{})
+ defer t.r.snake.rootNodeChanged(t.r.public)
}
+
t.advertise()
}
@@ -210,7 +316,7 @@ func (t *spanningTree) workerForRoot() {
case <-time.After(announcementTimeout):
if !t.IsRoot() {
- t.selectNewParent()
+ t.becomeRoot()
}
}
}
@@ -222,15 +328,13 @@ func (t *spanningTree) IsRoot() bool {
}
func (t *spanningTree) Root() *rootAnnouncementWithTime {
- t.rootMutex.RLock()
- root := t.root
- t.rootMutex.RUnlock()
- if root == nil || time.Since(root.at) > announcementTimeout {
+ root := t.r.ports[t.Parent()].lastAnnouncement()
+ if root == nil {
return &rootAnnouncementWithTime{
at: time.Now(),
SwitchAnnouncement: types.SwitchAnnouncement{
RootPublicKey: t.r.public,
- Sequence: types.Varu64(time.Now().UnixNano()),
+ Sequence: types.Varu64(t.sequence.Load()),
},
}
}
@@ -245,140 +349,53 @@ func (t *spanningTree) Root() *rootAnnouncementWithTime {
}
func (t *spanningTree) Parent() types.SwitchPortID {
- if parent, ok := t.parent.Load().(types.SwitchPortID); ok {
- return parent
- }
- return 0
+ t.mutex.Lock()
+ defer t.mutex.Unlock()
+ return t.parent
}
-func (t *spanningTree) Update(p *Peer, newUpdate types.SwitchAnnouncement) error {
- sigs := make(map[string]struct{})
- isLoopbackUpdate := false
- for index, sig := range newUpdate.Signatures {
- if index == 0 && sig.PublicKey != newUpdate.RootPublicKey {
- // The first signature in the announcement must be from the
- // key that claims to be the root.
- return fmt.Errorf("rejecting update (first signature must be from root)")
- }
- if sig.Hop == 0 {
- // None of the hops in the update should have a port number of 0
- // as this would imply that another node has sent their router
- // port, which is impossible. We'll therefore reject any update
- // that tries to do that.
- return fmt.Errorf("rejecting update (invalid 0 hop)")
- }
- if index == len(newUpdate.Signatures)-1 && p.PublicKey() != sig.PublicKey {
- // The last signature in the announcement must be from the
- // direct peer. If it isn't then it sounds like someone is
- // trying to replay someone else's announcement to us.
- return fmt.Errorf("rejecting update (last signature must be from peer)")
- }
- if sig.PublicKey.EqualTo(t.r.public) {
- // A child update is one that contains our public key in the
- // signatures already - it's probably one of our direct peers
- // sending our root announcement back to us. In this case there
- // is some special behaviour: we usually will need to accept
- // the update on the port, but we don't want to do anything that
- // would influence root or coordinate changes.
- isLoopbackUpdate = true
- }
- pk := hex.EncodeToString(sig.PublicKey[:])
- if _, ok := sigs[pk]; ok {
- // One of the signatures has appeared in the update more than
- // once, which would suggest that there's a loop somewhere.
- return fmt.Errorf("rejecting update (detected routing loop)")
- }
- sigs[pk] = struct{}{}
- }
-
- lastPortUpdate := p.lastAnnouncement()
- if lastPortUpdate != nil && lastPortUpdate.RootPublicKey == newUpdate.RootPublicKey {
- if newUpdate.Sequence < lastPortUpdate.Sequence {
- // The update has a lower sequence number than our last update
- // on this port from this root. This shouldn't happen, but if it
- // does, then just drop the update.
- return nil
- }
- }
-
- if err := p.updateAnnouncement(&newUpdate); err != nil {
- return fmt.Errorf("p.updateAnnouncement: %w", err)
- }
-
- if isLoopbackUpdate {
- // The update contains our own signature already, so using it for
- // a root update would create a loop
- return nil
- }
-
- t.updateMutex.Lock()
- defer t.updateMutex.Unlock()
-
- lastGlobalUpdate := t.Root()
- globalKeyDelta := newUpdate.RootPublicKey.CompareTo(lastGlobalUpdate.RootPublicKey)
- globalTimeSince := time.Since(lastGlobalUpdate.at)
- globalUpdate := false
+func (t *spanningTree) UpdateParentIfNeeded(p *Peer, newUpdate types.SwitchAnnouncement) {
+ const immediately = time.Duration(0)
+ reparentIn, becomeRoot := time.Duration(-1), false
+ lastParentUpdate, isParent := t.Root(), p.IsParent()
+ keyDeltaSinceLastParentUpdate := newUpdate.RootPublicKey.CompareTo(lastParentUpdate.RootPublicKey)
switch {
- case globalTimeSince > announcementTimeout:
- // We haven't seen a suitable root update recently so we'll accept
- // this one instead
- globalUpdate = true
-
- case globalKeyDelta < 0:
- // The root key is weaker than our existing root, so it's no good
- return nil
-
- case globalKeyDelta == 0:
- // The update is from the same root node, let's see if it matches
- // any other useful conditions
-
- switch {
- case newUpdate.Sequence < lastGlobalUpdate.Sequence:
- // This is a replay of an earlier update, therefore we should
- // ignore it, even if it came from our parent node
- return nil
-
- case len(newUpdate.Signatures) < len(lastGlobalUpdate.Signatures):
- // The path to the root is shorter than our last update, so
- // we'll accept it
- globalUpdate = true
- }
-
- case globalKeyDelta > 0:
- // The root key is stronger than our existing root, therefore we'll
- // accept it anyway, since we always want to converge on the
- // strongest root key
- globalUpdate = true
+ case keyDeltaSinceLastParentUpdate > 0:
+ // The peer has sent us a key that is stronger than our last update.
+ reparentIn = immediately
+
+ case time.Since(lastParentUpdate.at) >= announcementTimeout:
+ // It's been a while since we last heard from our parent so we should
+ // really choose a new one if we can.
+ reparentIn = immediately
+
+ case isParent && keyDeltaSinceLastParentUpdate < 0:
+ // Our parent sent us a weaker key than before — this implies that
+ // something bad happened.
+ reparentIn, becomeRoot = time.Second/2, true
+
+ case isParent && keyDeltaSinceLastParentUpdate == 0 && newUpdate.Sequence < lastParentUpdate.Sequence:
+ // Our parent sent us a lower sequence number than before for the
+ // same root — this isn't good news either.
+ reparentIn, becomeRoot = time.Second/2, true
+
+ case isParent && keyDeltaSinceLastParentUpdate == 0 && newUpdate.Sequence == lastParentUpdate.Sequence:
+ // Our parent sent us an equal sequence number, which probably means
+ // that their path to the root has changed. This isn't as bad news as
+ // a weaker key but we should still try to find a new parent.
+ reparentIn, becomeRoot = time.Second/4, false
}
- if parent := t.parent.Load(); p.port == parent || globalUpdate {
- t.rootMutex.Lock()
- var oldRootKey types.PublicKey
- if t.root != nil {
- oldRootKey = t.root.RootPublicKey
- }
- t.root = &rootAnnouncementWithTime{
- at: time.Now(),
- SwitchAnnouncement: newUpdate,
- }
- t.parent.Store(p.port)
- oldcoords := t.Coords()
- newcoords := t.root.Coords()
- t.coords.Store(newcoords)
- t.rootMutex.Unlock()
-
- if t.callback != nil && !oldcoords.EqualTo(newcoords) {
- go t.callback(p.port, newcoords)
- }
-
- if oldRootKey != newUpdate.RootPublicKey {
- defer t.r.snake.rootNodeChanged(newUpdate.RootPublicKey)
+ switch {
+ case reparentIn >= 0:
+ if becomeRoot {
+ t.becomeRoot()
}
+ t.selectNewParentAndAdvertiseIn(reparentIn)
+ case isParent && keyDeltaSinceLastParentUpdate == 0 && newUpdate.Sequence > lastParentUpdate.Sequence:
t.advertise()
t.rootReset <- struct{}{}
}
-
- return nil
}
diff --git a/router/version.go b/router/version.go
index 75a5ca58..a56e3af5 100644
--- a/router/version.go
+++ b/router/version.go
@@ -16,12 +16,13 @@ package router
const (
// reserved = 1
- capabilityVirtualSnake = 2
- capabilityHardState = 4
- capabilityPathIDs = 8
- capabilityRootInUpdates = 16
- capabilityNewHeaders = 32
+ capabilityVirtualSnake = 2
+ capabilityHardState = 4
+ capabilityPathIDs = 8
+ capabilityRootInUpdates = 16
+ capabilityNewHeaders = 32
+ capabilitySlowRootInterval = 64
)
const ourVersion = 0
-const ourCapabilities = capabilityVirtualSnake | capabilityHardState | capabilityPathIDs | capabilityRootInUpdates | capabilityNewHeaders
+const ourCapabilities = capabilityVirtualSnake | capabilityHardState | capabilityPathIDs | capabilityRootInUpdates | capabilityNewHeaders | capabilitySlowRootInterval
diff --git a/types/announcement.go b/types/announcement.go
index 8382e604..cd61802b 100644
--- a/types/announcement.go
+++ b/types/announcement.go
@@ -97,15 +97,32 @@ func (a *SwitchAnnouncement) Coords() SwitchPorts {
return coords
}
-func (a *SwitchAnnouncement) PeerCoords(public PublicKey) (SwitchPorts, error) {
+func (a *SwitchAnnouncement) PeerCoords() SwitchPorts {
sigs := a.Signatures
- last := len(sigs) - 1
- if sigs[last].PublicKey != public {
- return nil, fmt.Errorf("invalid last hop")
- }
- coords := make(SwitchPorts, 0, len(sigs))
- for _, sig := range sigs[:last] {
+ coords := make(SwitchPorts, 0, len(sigs)-1)
+ for _, sig := range sigs[:len(sigs)-1] {
coords = append(coords, SwitchPortID(sig.Hop))
}
- return coords, nil
+ return coords
+}
+
+func (a *SwitchAnnouncement) AncestorParent() PublicKey {
+ if len(a.Signatures) < 2 {
+ return a.RootPublicKey
+ }
+ return a.Signatures[len(a.Signatures)-2].PublicKey
+}
+
+func (a *SwitchAnnouncement) IsLoopOrChildOf(pk PublicKey) bool {
+ m := map[PublicKey]struct{}{}
+ for _, sig := range a.Signatures {
+ if sig.PublicKey.EqualTo(pk) {
+ return true
+ }
+ if _, ok := m[sig.PublicKey]; ok {
+ return true
+ }
+ m[sig.PublicKey] = struct{}{}
+ }
+ return false
}