diff --git a/go.mod b/go.mod index 2106ce130..d87a05273 100644 --- a/go.mod +++ b/go.mod @@ -226,7 +226,6 @@ require ( go.uber.org/zap v1.21.0 // indirect golang.org/x/exp/typeparams v0.0.0-20220218215828-6cf2b201936e // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/sys v0.28.0 // indirect golang.org/x/term v0.27.0 // indirect golang.org/x/text v0.21.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect @@ -254,5 +253,6 @@ require ( go.opentelemetry.io/otel v1.9.0 go.opentelemetry.io/otel/sdk v1.9.0 go.opentelemetry.io/otel/trace v1.9.0 + golang.org/x/sys v0.28.0 google.golang.org/protobuf v1.28.0 ) diff --git a/internal/blocksync/reactor_test.go b/internal/blocksync/reactor_test.go index 1816ad925..f81dc69c7 100644 --- a/internal/blocksync/reactor_test.go +++ b/internal/blocksync/reactor_test.go @@ -73,7 +73,7 @@ func setup( } chDesc := &p2p.ChannelDescriptor{ID: BlockSyncChannel, MessageType: new(bcproto.Message)} - rts.blockSyncChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc) + rts.blockSyncChannels = rts.network.MakeChannelsNoCleanup(t, chDesc) i := 0 for nodeID := range rts.network.Nodes { @@ -101,10 +101,7 @@ func setup( func makeReactor( ctx context.Context, t *testing.T, - nodeID types.NodeID, genDoc *types.GenesisDoc, - privVal types.PrivValidator, - channelCreator p2p.ChannelCreator, peerEvents p2p.PeerEventSubscriber, peerManager *p2p.PeerManager, restartChan chan struct{}, @@ -188,10 +185,6 @@ func (rts *reactorTestSuite) addNode( rts.peerUpdates[nodeID] = p2p.NewPeerUpdates(rts.peerChans[nodeID], 1) rts.network.Nodes[nodeID].PeerManager.Register(ctx, rts.peerUpdates[nodeID]) - chCreator := func(ctx context.Context, chdesc *p2p.ChannelDescriptor) (*p2p.Channel, error) { - return rts.blockSyncChannels[nodeID], nil - } - peerEvents := func(ctx context.Context) *p2p.PeerUpdates { return rts.peerUpdates[nodeID] } restartChan := make(chan struct{}) remediationConfig := 
config.DefaultSelfRemediationConfig() @@ -200,10 +193,7 @@ func (rts *reactorTestSuite) addNode( reactor := makeReactor( ctx, t, - nodeID, genDoc, - privVal, - chCreator, peerEvents, rts.network.Nodes[nodeID].PeerManager, restartChan, diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index 88ad7ed39..9ed50116e 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math" "math/rand" "os" "sync" @@ -57,7 +58,7 @@ func chDesc(chID p2p.ChannelID, size int) *p2p.ChannelDescriptor { return &p2p.ChannelDescriptor{ ID: chID, MessageType: new(tmcons.Message), - RecvBufferCapacity: size, + RecvBufferCapacity: int(math.Sqrt(float64(size)) + 1), } } @@ -78,10 +79,10 @@ func setup( blocksyncSubs: make(map[types.NodeID]eventbus.Subscription, numNodes), } - rts.stateChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(StateChannel, size)) - rts.dataChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(DataChannel, size)) - rts.voteChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(VoteChannel, size)) - rts.voteSetBitsChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(VoteSetBitsChannel, size)) + rts.stateChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(StateChannel, size)) + rts.dataChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(DataChannel, size)) + rts.voteChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(VoteChannel, size)) + rts.voteSetBitsChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(VoteSetBitsChannel, size)) ctx, cancel := context.WithCancel(ctx) t.Cleanup(cancel) diff --git a/internal/evidence/reactor_test.go b/internal/evidence/reactor_test.go index 3d898dc59..a5084a30e 100644 --- a/internal/evidence/reactor_test.go +++ b/internal/evidence/reactor_test.go @@ -63,8 +63,12 @@ func setup(ctx context.Context, t *testing.T, stateStores []sm.Store) *reactorTe peerChans: make(map[types.NodeID]chan 
p2p.PeerUpdate, numStateStores), } - chDesc := &p2p.ChannelDescriptor{ID: evidence.EvidenceChannel, MessageType: new(tmproto.Evidence)} - rts.evidenceChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc) + chDesc := &p2p.ChannelDescriptor{ + ID: evidence.EvidenceChannel, + MessageType: new(tmproto.Evidence), + RecvBufferCapacity: 10, + } + rts.evidenceChannels = rts.network.MakeChannelsNoCleanup(t, chDesc) require.Len(t, rts.network.RandomNode().PeerManager.Peers(), 0) idx := 0 diff --git a/internal/mempool/reactor_test.go b/internal/mempool/reactor_test.go index d26f2bec3..2de190e8e 100644 --- a/internal/mempool/reactor_test.go +++ b/internal/mempool/reactor_test.go @@ -59,7 +59,7 @@ func setupReactors(ctx context.Context, t *testing.T, logger log.Logger, numNode } chDesc := GetChannelDescriptor(cfg.Mempool) - rts.mempoolChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc) + rts.mempoolChannels = rts.network.MakeChannelsNoCleanup(t, chDesc) for nodeID := range rts.network.Nodes { rts.kvstores[nodeID] = kvstore.NewApplication() @@ -174,7 +174,7 @@ func TestReactorBroadcastDoesNotPanic(t *testing.T) { go primaryReactor.broadcastTxRoutine(ctx, secondary, rts.mempoolChannels[primary]) wg := &sync.WaitGroup{} - for i := 0; i < 50; i++ { + for range 50 { next := &WrappedTx{} wg.Add(1) go func() { diff --git a/internal/p2p/address.go b/internal/p2p/address.go index 0f4066faf..18d95100b 100644 --- a/internal/p2p/address.go +++ b/internal/p2p/address.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "net/netip" "net/url" "regexp" "strconv" @@ -97,7 +98,7 @@ func ParseNodeAddress(urlString string) (NodeAddress, error) { // Resolve resolves a NodeAddress into a set of Endpoints, by expanding // out a DNS hostname to IP addresses. 
-func (a NodeAddress) Resolve(ctx context.Context) ([]*Endpoint, error) { +func (a NodeAddress) Resolve(ctx context.Context) ([]Endpoint, error) { if a.Protocol == "" { return nil, errors.New("address has no protocol") } @@ -109,7 +110,7 @@ func (a NodeAddress) Resolve(ctx context.Context) ([]*Endpoint, error) { if a.NodeID == "" { return nil, errors.New("local address has no node ID") } - return []*Endpoint{{ + return []Endpoint{{ Protocol: a.Protocol, Path: string(a.NodeID), }}, nil @@ -119,12 +120,15 @@ func (a NodeAddress) Resolve(ctx context.Context) ([]*Endpoint, error) { if err != nil { return nil, err } - endpoints := make([]*Endpoint, len(ips)) + endpoints := make([]Endpoint, len(ips)) for i, ip := range ips { - endpoints[i] = &Endpoint{ + ip, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("LookupIP returned invalid IP %q", ip) + } + endpoints[i] = Endpoint{ Protocol: a.Protocol, - IP: ip, - Port: a.Port, + Addr: netip.AddrPortFrom(ip, a.Port), Path: a.Path, } } diff --git a/internal/p2p/address_test.go b/internal/p2p/address_test.go index 7c6fdb9bc..6b660aac1 100644 --- a/internal/p2p/address_test.go +++ b/internal/p2p/address_test.go @@ -1,14 +1,14 @@ package p2p_test import ( - "net" + "net/netip" "strings" "testing" - "github.com/stretchr/testify/require" - "github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/internal/p2p" + "github.com/tendermint/tendermint/libs/utils/require" + "github.com/tendermint/tendermint/libs/utils/tcp" "github.com/tendermint/tendermint/types" ) @@ -202,61 +202,61 @@ func TestNodeAddress_Resolve(t *testing.T) { testcases := []struct { address p2p.NodeAddress - expect *p2p.Endpoint + expect p2p.Endpoint ok bool }{ // Valid networked addresses (with hostname). 
{ p2p.NodeAddress{Protocol: "tcp", Hostname: "127.0.0.1", Port: 80, Path: "/path"}, - &p2p.Endpoint{Protocol: "tcp", IP: net.IPv4(127, 0, 0, 1), Port: 80, Path: "/path"}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(tcp.IPv4Loopback(), 80), Path: "/path"}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "127.0.0.1"}, - &p2p.Endpoint{Protocol: "tcp", IP: net.IPv4(127, 0, 0, 1)}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(tcp.IPv4Loopback(), 0)}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "::1"}, - &p2p.Endpoint{Protocol: "tcp", IP: net.IPv6loopback}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.IPv6Loopback(), 0)}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "8.8.8.8"}, - &p2p.Endpoint{Protocol: "tcp", IP: net.IPv4(8, 8, 8, 8)}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{8, 8, 8, 8}), 0)}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "2001:0db8::ff00:0042:8329"}, - &p2p.Endpoint{Protocol: "tcp", IP: []byte{ - 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x42, 0x83, 0x29}}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(netip.AddrFrom16([16]byte{ + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x42, 0x83, 0x29}), 0)}, true, }, { p2p.NodeAddress{Protocol: "tcp", Hostname: "some.missing.host.tendermint.com"}, - &p2p.Endpoint{}, + p2p.Endpoint{}, false, }, // Valid non-networked addresses. { p2p.NodeAddress{Protocol: "memory", NodeID: id}, - &p2p.Endpoint{Protocol: "memory", Path: string(id)}, + p2p.Endpoint{Protocol: "memory", Path: string(id)}, true, }, { p2p.NodeAddress{Protocol: "memory", NodeID: id, Path: string(id)}, - &p2p.Endpoint{Protocol: "memory", Path: string(id)}, + p2p.Endpoint{Protocol: "memory", Path: string(id)}, true, }, // Invalid addresses. 
- {p2p.NodeAddress{}, &p2p.Endpoint{}, false}, - {p2p.NodeAddress{Hostname: "127.0.0.1"}, &p2p.Endpoint{}, false}, - {p2p.NodeAddress{Protocol: "tcp", Hostname: "127.0.0.1:80"}, &p2p.Endpoint{}, false}, - {p2p.NodeAddress{Protocol: "memory"}, &p2p.Endpoint{}, false}, - {p2p.NodeAddress{Protocol: "memory", Path: string(id)}, &p2p.Endpoint{}, false}, - {p2p.NodeAddress{Protocol: "tcp", Hostname: "💥"}, &p2p.Endpoint{}, false}, + {p2p.NodeAddress{}, p2p.Endpoint{}, false}, + {p2p.NodeAddress{Hostname: "127.0.0.1"}, p2p.Endpoint{}, false}, + {p2p.NodeAddress{Protocol: "tcp", Hostname: "127.0.0.1:80"}, p2p.Endpoint{}, false}, + {p2p.NodeAddress{Protocol: "memory"}, p2p.Endpoint{}, false}, + {p2p.NodeAddress{Protocol: "memory", Path: string(id)}, p2p.Endpoint{}, false}, + {p2p.NodeAddress{Protocol: "tcp", Hostname: "💥"}, p2p.Endpoint{}, false}, } for _, tc := range testcases { t.Run(tc.address.String(), func(t *testing.T) { @@ -265,39 +265,26 @@ func TestNodeAddress_Resolve(t *testing.T) { require.Error(t, err) return } - - // Special handling for localhost tests - accept either IPv4 or IPv6 - if tc.address.Hostname == "localhost" && tc.address.Port == 80 && tc.address.Path == "/path" { - hasIPv4 := false - hasIPv6 := false - for _, ep := range endpoints { - if ep.Protocol == "tcp" && ep.Port == 80 && ep.Path == "/path" { - if ep.IP.Equal(net.IPv4(127, 0, 0, 1)) { - hasIPv4 = true - } - if ep.IP.Equal(net.IPv6loopback) { - hasIPv6 = true - } - } - } - require.True(t, hasIPv4 || hasIPv6, "localhost should resolve to either IPv4 or IPv6") - return + ok := false + tc.expect.Addr = tcp.Norm(tc.expect.Addr) + for _, e := range endpoints { + e.Addr = tcp.Norm(e.Addr) + ok = ok || e == tc.expect + } + if !ok { + t.Fatalf("%v not in %v", tc.expect, endpoints) } - - require.Contains(t, endpoints, tc.expect) }) } t.Run("Resolve localhost", func(t *testing.T) { addr := p2p.NodeAddress{Protocol: "tcp", Hostname: "localhost", Port: 80, Path: "/path"} endpoints, err := 
addr.Resolve(t.Context()) require.NoError(t, err) - - want := &p2p.Endpoint{Protocol: "tcp", Port: 80, Path: "/path"} require.True(t, len(endpoints) > 0) for _, got := range endpoints { - require.True(t, got.IP.IsLoopback()) + require.True(t, got.Addr.Addr().IsLoopback()) // Any loopback address is acceptable, so ignore it in comparison. - want.IP = got.IP + want := p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(got.Addr.Addr(), 80), Path: "/path"} require.Equal(t, want, got) } }) diff --git a/internal/p2p/channel.go b/internal/p2p/channel.go index 82d7a6b98..e44445f93 100644 --- a/internal/p2p/channel.go +++ b/internal/p2p/channel.go @@ -7,6 +7,7 @@ import ( "github.com/gogo/protobuf/proto" + "github.com/tendermint/tendermint/libs/utils" "github.com/tendermint/tendermint/types" ) @@ -60,7 +61,7 @@ func (pe PeerError) Unwrap() error { return pe.Err } // Each message is wrapped in an Envelope to specify its sender and receiver. type Channel struct { ID ChannelID - inCh <-chan Envelope // inbound messages (peers to reactors) + inCh *Queue // inbound messages (peers to reactors) outCh chan<- Envelope // outbound messages (reactors to peers) errCh chan<- PeerError // peer error reporting @@ -69,7 +70,7 @@ type Channel struct { // NewChannel creates a new channel. It is primarily for internal and test // use, reactors should use Router.OpenChannel(). -func NewChannel(id ChannelID, inCh <-chan Envelope, outCh chan<- Envelope, errCh chan<- PeerError) *Channel { +func NewChannel(id ChannelID, inCh *Queue, outCh chan<- Envelope, errCh chan<- PeerError) *Channel { return &Channel{ ID: id, inCh: inCh, @@ -102,6 +103,8 @@ func (ch *Channel) SendError(ctx context.Context, pe PeerError) error { func (ch *Channel) String() string { return fmt.Sprintf("p2p.Channel<%d:%s>", ch.ID, ch.name) } +func (ch *Channel) ReceiveLen() int { return ch.inCh.Len() } + // Receive returns a new unbuffered iterator to receive messages from ch. // The iterator runs until ctx ends. 
func (ch *Channel) Receive(ctx context.Context) *ChannelIterator { @@ -128,15 +131,12 @@ type ChannelIterator struct { func iteratorWorker(ctx context.Context, ch *Channel, pipe chan Envelope) { for { - select { - case <-ctx.Done(): + e, err := ch.inCh.Recv(ctx) + if err != nil { + return + } + if err := utils.Send(ctx, pipe, e); err != nil { return - case envelope := <-ch.inCh: - select { - case <-ctx.Done(): - return - case pipe <- envelope: - } } } } diff --git a/internal/p2p/channel_test.go b/internal/p2p/channel_test.go index 4bbe178ac..f889f3f0c 100644 --- a/internal/p2p/channel_test.go +++ b/internal/p2p/channel_test.go @@ -11,14 +11,14 @@ import ( ) type channelInternal struct { - In chan Envelope + In *Queue Out chan Envelope Error chan PeerError } func testChannel(size int) (*channelInternal, *Channel) { in := &channelInternal{ - In: make(chan Envelope, size), + In: NewQueue(size), Out: make(chan Envelope, size), Error: make(chan PeerError, size), } @@ -112,7 +112,7 @@ func TestChannel(t *testing.T) { Case: func(t *testing.T) { ctx := t.Context() ins, ch := testChannel(1) - ins.In <- Envelope{From: "kip", To: "merlin"} + ins.In.Send(Envelope{From: "kip", To: "merlin"}, 0) iter := ch.Receive(ctx) require.NotNil(t, iter) require.True(t, iter.Next(ctx)) @@ -157,7 +157,7 @@ func TestChannel(t *testing.T) { ctx := t.Context() ins, ch := testChannel(1) - ins.In <- Envelope{From: "kip", To: "merlin"} + ins.In.Send(Envelope{From: "kip", To: "merlin"}, 0) iter := ch.Receive(ctx) require.NotNil(t, iter) @@ -180,7 +180,7 @@ func TestChannel(t *testing.T) { ctx := t.Context() ins, ch := testChannel(1) - ins.In <- Envelope{From: "kip", To: "merlin"} + ins.In.Send(Envelope{From: "kip", To: "merlin"}, 0) iter := ch.Receive(ctx) require.NotNil(t, iter) @@ -204,7 +204,7 @@ func TestChannel(t *testing.T) { require.NotNil(t, iter) require.Nil(t, iter.Envelope()) - ins.In <- Envelope{From: "kip", To: "merlin"} + ins.In.Send(Envelope{From: "kip", To: "merlin"}, 0) 
require.NotNil(t, iter) require.True(t, iter.Next(ctx)) diff --git a/internal/p2p/conn/connection.go b/internal/p2p/conn/connection.go index 4bd85d6e2..83054ecb4 100644 --- a/internal/p2p/conn/connection.go +++ b/internal/p2p/conn/connection.go @@ -22,6 +22,7 @@ import ( "github.com/tendermint/tendermint/libs/log" tmmath "github.com/tendermint/tendermint/libs/math" "github.com/tendermint/tendermint/libs/service" + "github.com/tendermint/tendermint/libs/utils" tmp2p "github.com/tendermint/tendermint/proto/tendermint/p2p" ) @@ -302,9 +303,9 @@ func (c *MConnection) stopForError(ctx context.Context, r interface{}) { } // Queues a message to be sent to channel. -func (c *MConnection) Send(chID ChannelID, msgBytes []byte) bool { +func (c *MConnection) Send(ctx context.Context, chID ChannelID, msgBytes []byte) error { if !c.IsRunning() { - return false + return errors.New("not running") } c.logger.Debug("Send", "channel", chID, "conn", c, "msgBytes", msgBytes) @@ -312,21 +313,18 @@ func (c *MConnection) Send(chID ChannelID, msgBytes []byte) bool { // Send message to channel. channel, ok := c.channelsIdx[chID] if !ok { - c.logger.Error(fmt.Sprintf("Cannot send bytes, unknown channel %X", chID)) - return false + return fmt.Errorf("Cannot send bytes, unknown channel %X", chID) } - success := channel.sendBytes(msgBytes) - if success { - // Wake up sendRoutine if necessary - select { - case c.send <- struct{}{}: - default: - } - } else { - c.logger.Debug("Send failed", "channel", chID, "conn", c, "msgBytes", msgBytes) + if err := channel.sendBytes(ctx, msgBytes); err != nil { + return fmt.Errorf("channel.sendBytes(): %v", err) } - return success + // Wake up sendRoutine if necessary + select { + case c.send <- struct{}{}: + default: + } + return nil } // sendRoutine polls for packets to send from channels. @@ -645,12 +643,11 @@ type channel struct { // See https://github.com/tendermint/tendermint/issues/7000. 
recentlySent int64 - conn *MConnection - desc ChannelDescriptor - sendQueue chan []byte - sendQueueSize int32 // atomic. - recving []byte - sending []byte + conn *MConnection + desc ChannelDescriptor + sendQueue chan []byte + recving []byte + sending []byte maxPacketMsgPayloadSize int @@ -675,16 +672,10 @@ func newChannel(conn *MConnection, desc ChannelDescriptor) *channel { // Queues message to send to this channel. // Goroutine-safe // Times out (and returns false) after defaultSendTimeout -func (ch *channel) sendBytes(bytes []byte) bool { - timer := time.NewTimer(defaultSendTimeout) - defer timer.Stop() - select { - case ch.sendQueue <- bytes: - atomic.AddInt32(&ch.sendQueueSize, 1) - return true - case <-timer.C: - return false - } +func (ch *channel) sendBytes(ctx context.Context, bytes []byte) error { + ctx, cancel := context.WithTimeout(ctx, defaultSendTimeout) + defer cancel() + return utils.Send(ctx, ch.sendQueue, bytes) } // Returns true if any PacketMsgs are pending to be sent. 
@@ -709,7 +700,6 @@ func (ch *channel) nextPacketMsg() tmp2p.PacketMsg { if len(ch.sending) <= maxSize { packet.EOF = true ch.sending = nil - atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize } else { packet.EOF = false ch.sending = ch.sending[tmmath.MinInt(maxSize, len(ch.sending)):] diff --git a/internal/p2p/conn/connection_test.go b/internal/p2p/conn/connection_test.go index 72e65a1a4..3e148a708 100644 --- a/internal/p2p/conn/connection_test.go +++ b/internal/p2p/conn/connection_test.go @@ -29,7 +29,7 @@ func createTestMConnection(logger log.Logger, conn net.Conn) *MConnection { func(ctx context.Context, chID ChannelID, msgBytes []byte) { }, // onError - func(ctx context.Context, r interface{}) { + func(ctx context.Context, r any) { }) } @@ -37,7 +37,7 @@ func createMConnectionWithCallbacks( logger log.Logger, conn net.Conn, onReceive func(ctx context.Context, chID ChannelID, msgBytes []byte), - onError func(ctx context.Context, r interface{}), + onError func(ctx context.Context, r any), ) *MConnection { cfg := DefaultMConnConfig() cfg.PingInterval = 250 * time.Millisecond @@ -59,7 +59,7 @@ func TestMConnectionSendFlushStop(t *testing.T) { t.Cleanup(waitAll(clientConn)) msg := []byte("abc") - assert.True(t, clientConn.Send(0x01, msg)) + assert.NoError(t, clientConn.Send(ctx, 0x01, msg)) msgLength := 14 @@ -95,7 +95,7 @@ func TestMConnectionSend(t *testing.T) { t.Cleanup(waitAll(mconn)) msg := []byte("Ant-Man") - assert.True(t, mconn.Send(0x01, msg)) + assert.NoError(t, mconn.Send(ctx, 0x01, msg)) // Note: subsequent Send/TrySend calls could pass because we are reading from // the send queue in a separate goroutine. 
_, err = server.Read(make([]byte, len(msg))) @@ -104,13 +104,13 @@ func TestMConnectionSend(t *testing.T) { } msg = []byte("Spider-Man") - assert.True(t, mconn.Send(0x01, msg)) + assert.NoError(t, mconn.Send(ctx, 0x01, msg)) _, err = server.Read(make([]byte, len(msg))) if err != nil { t.Error(err) } - assert.False(t, mconn.Send(0x05, []byte("Absorbing Man")), "Send should return false because channel is unknown") + assert.Error(t, mconn.Send(ctx, 0x05, []byte("Absorbing Man")), "Send should fail because channel is unknown") } func TestMConnectionReceive(t *testing.T) { @@ -118,14 +118,14 @@ func TestMConnectionReceive(t *testing.T) { t.Cleanup(closeAll(t, client, server)) receivedCh := make(chan []byte) - errorsCh := make(chan interface{}) + errorsCh := make(chan any) onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { select { case receivedCh <- msgBytes: case <-ctx.Done(): } } - onError := func(ctx context.Context, r interface{}) { + onError := func(ctx context.Context, r any) { select { case errorsCh <- r: case <-ctx.Done(): @@ -146,7 +146,7 @@ func TestMConnectionReceive(t *testing.T) { t.Cleanup(waitAll(mconn2)) msg := []byte("Cyclops") - assert.True(t, mconn2.Send(0x01, msg)) + assert.NoError(t, mconn2.Send(ctx, 0x01, msg)) select { case receivedBytes := <-receivedCh: @@ -203,14 +203,14 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { t.Cleanup(closeAll(t, client, server)) receivedCh := make(chan []byte) - errorsCh := make(chan interface{}) + errorsCh := make(chan any) onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { select { case receivedCh <- msgBytes: case <-ctx.Done(): } } - onError := func(ctx context.Context, r interface{}) { + onError := func(ctx context.Context, r any) { select { case errorsCh <- r: case <-ctx.Done(): @@ -261,14 +261,14 @@ func TestMConnectionMultiplePings(t *testing.T) { t.Cleanup(closeAll(t, client, server)) receivedCh := make(chan []byte) - errorsCh := make(chan 
interface{}) + errorsCh := make(chan any) onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { select { case receivedCh <- msgBytes: case <-ctx.Done(): } } - onError := func(ctx context.Context, r interface{}) { + onError := func(ctx context.Context, r any) { select { case errorsCh <- r: case <-ctx.Done(): @@ -316,14 +316,14 @@ func TestMConnectionPingPongs(t *testing.T) { t.Cleanup(closeAll(t, client, server)) receivedCh := make(chan []byte) - errorsCh := make(chan interface{}) + errorsCh := make(chan any) onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { select { case receivedCh <- msgBytes: case <-ctx.Done(): } } - onError := func(ctx context.Context, r interface{}) { + onError := func(ctx context.Context, r any) { select { case errorsCh <- r: case <-ctx.Done(): @@ -375,14 +375,14 @@ func TestMConnectionStopsAndReturnsError(t *testing.T) { t.Cleanup(closeAll(t, client, server)) receivedCh := make(chan []byte) - errorsCh := make(chan interface{}) + errorsCh := make(chan any) onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { select { case receivedCh <- msgBytes: case <-ctx.Done(): } } - onError := func(ctx context.Context, r interface{}) { + onError := func(ctx context.Context, r any) { select { case errorsCh <- r: case <-ctx.Done(): @@ -418,7 +418,7 @@ func newClientAndServerConnsForReadErrors( server, client := net.Pipe() onReceive := func(context.Context, ChannelID, []byte) {} - onError := func(context.Context, interface{}) {} + onError := func(context.Context, any) {} // create client conn with two channels chDescs := []*ChannelDescriptor{ @@ -434,7 +434,7 @@ func newClientAndServerConnsForReadErrors( // create server conn with 1 channel // it fires on chOnErr when there's an error serverLogger := logger.With("module", "server") - onError = func(ctx context.Context, r interface{}) { + onError = func(ctx context.Context, r any) { select { case <-ctx.Done(): case chOnErr <- struct{}{}: @@ 
-481,11 +481,11 @@ func TestMConnectionReadErrorUnknownChannel(t *testing.T) { msg := []byte("Ant-Man") // fail to send msg on channel unknown by client - assert.False(t, mconnClient.Send(0x03, msg)) + assert.Error(t, mconnClient.Send(ctx, 0x03, msg)) // send msg on channel unknown by the server. // should cause an error - assert.True(t, mconnClient.Send(0x02, msg)) + assert.NoError(t, mconnClient.Send(ctx, 0x02, msg)) assert.True(t, expectSend(chOnErr), "unknown channel") t.Cleanup(waitAll(mconnClient, mconnServer)) } @@ -557,15 +557,15 @@ func TestMConnectionTrySend(t *testing.T) { msg := []byte("Semicolon-Woman") resultCh := make(chan string, 2) - assert.True(t, mconn.Send(0x01, msg)) + assert.NoError(t, mconn.Send(ctx, 0x01, msg)) _, err = server.Read(make([]byte, len(msg))) require.NoError(t, err) - assert.True(t, mconn.Send(0x01, msg)) + assert.NoError(t, mconn.Send(ctx, 0x01, msg)) go func() { - mconn.Send(0x01, msg) + mconn.Send(ctx, 0x01, msg) resultCh <- "TrySend" }() - assert.False(t, mconn.Send(0x01, msg)) + assert.Error(t, mconn.Send(ctx, 0x01, msg)) assert.Equal(t, "TrySend", <-resultCh) } diff --git a/internal/p2p/conn_tracker.go b/internal/p2p/conn_tracker.go index 54f9c8980..385e734e9 100644 --- a/internal/p2p/conn_tracker.go +++ b/internal/p2p/conn_tracker.go @@ -2,20 +2,20 @@ package p2p import ( "fmt" - "net" + "net/netip" "sync" "time" ) type connectionTracker interface { - AddConn(net.IP) error - RemoveConn(net.IP) + AddConn(netip.AddrPort) error + RemoveConn(netip.AddrPort) Len() int } type connTrackerImpl struct { - cache map[string]uint - lastConnect map[string]time.Time + cache map[netip.Addr]uint + lastConnect map[netip.Addr]time.Time mutex sync.RWMutex max uint window time.Duration @@ -23,8 +23,8 @@ type connTrackerImpl struct { func newConnTracker(max uint, window time.Duration) connectionTracker { return &connTrackerImpl{ - cache: make(map[string]uint), - lastConnect: make(map[string]time.Time), + cache: map[netip.Addr]uint{}, + 
lastConnect: map[netip.Addr]time.Time{}, max: max, window: window, } @@ -36,8 +36,8 @@ func (rat *connTrackerImpl) Len() int { return len(rat.cache) } -func (rat *connTrackerImpl) AddConn(addr net.IP) error { - address := addr.String() +func (rat *connTrackerImpl) AddConn(addrPort netip.AddrPort) error { + address := addrPort.Addr() rat.mutex.Lock() defer rat.mutex.Unlock() @@ -58,8 +58,8 @@ func (rat *connTrackerImpl) AddConn(addr net.IP) error { return nil } -func (rat *connTrackerImpl) RemoveConn(addr net.IP) { - address := addr.String() +func (rat *connTrackerImpl) RemoveConn(addrPort netip.AddrPort) { + address := addrPort.Addr() rat.mutex.Lock() defer rat.mutex.Unlock() diff --git a/internal/p2p/conn_tracker_test.go b/internal/p2p/conn_tracker_test.go index daa3351f2..93216bdcd 100644 --- a/internal/p2p/conn_tracker_test.go +++ b/internal/p2p/conn_tracker_test.go @@ -3,7 +3,7 @@ package p2p import ( "math" "math/rand" - "net" + "net/netip" "testing" "time" @@ -14,8 +14,15 @@ func randByte() byte { return byte(rand.Intn(math.MaxUint8)) } -func randLocalIPv4() net.IP { - return net.IPv4(127, randByte(), randByte(), randByte()) +func randPort() uint16 { + return uint16(rand.Intn(math.MaxUint16)) +} + +func randLocalAddr() netip.AddrPort { + return netip.AddrPortFrom( + netip.AddrFrom4([4]byte{127, randByte(), randByte(), randByte()}), + randPort(), + ) } func TestConnTracker(t *testing.T) { @@ -35,7 +42,7 @@ func TestConnTracker(t *testing.T) { }) t.Run("RepeatedAdding", func(t *testing.T) { ct := factory() - ip := randLocalIPv4() + ip := randLocalAddr() require.NoError(t, ct.AddConn(ip)) for i := 0; i < 100; i++ { _ = ct.AddConn(ip) @@ -45,14 +52,14 @@ func TestConnTracker(t *testing.T) { t.Run("AddingMany", func(t *testing.T) { ct := factory() for i := 0; i < 100; i++ { - _ = ct.AddConn(randLocalIPv4()) + _ = ct.AddConn(randLocalAddr()) } require.Equal(t, 100, ct.Len()) }) t.Run("Cycle", func(t *testing.T) { ct := factory() for i := 0; i < 100; i++ { - ip := 
randLocalIPv4() + ip := randLocalAddr() require.NoError(t, ct.AddConn(ip)) ct.RemoveConn(ip) } @@ -63,7 +70,7 @@ func TestConnTracker(t *testing.T) { t.Run("VeryShort", func(t *testing.T) { ct := newConnTracker(10, time.Microsecond) for i := 0; i < 10; i++ { - ip := randLocalIPv4() + ip := randLocalAddr() require.NoError(t, ct.AddConn(ip)) time.Sleep(2 * time.Microsecond) require.NoError(t, ct.AddConn(ip)) @@ -73,7 +80,7 @@ func TestConnTracker(t *testing.T) { t.Run("Window", func(t *testing.T) { const window = 100 * time.Millisecond ct := newConnTracker(10, window) - ip := randLocalIPv4() + ip := randLocalAddr() require.NoError(t, ct.AddConn(ip)) ct.RemoveConn(ip) require.Error(t, ct.AddConn(ip)) diff --git a/internal/p2p/metrics.gen.go b/internal/p2p/metrics.gen.go index 101e652f7..d07febed5 100644 --- a/internal/p2p/metrics.gen.go +++ b/internal/p2p/metrics.gen.go @@ -44,6 +44,12 @@ func PrometheusMetrics(namespace string, labelsAndValues ...string) *Metrics { Name: "peer_pending_send_bytes", Help: "Number of bytes pending being sent to a given peer.", }, append(labels, "peer_id")).With(labelsAndValues...), + NewConnections: prometheus.NewCounterFrom(stdprometheus.CounterOpts{ + Namespace: namespace, + Subsystem: MetricsSubsystem, + Name: "new_connections", + Help: "Number of newly established connections.", + }, append(labels, "direction")).With(labelsAndValues...), RouterPeerQueueRecv: prometheus.NewHistogramFrom(stdprometheus.HistogramOpts{ Namespace: namespace, Subsystem: MetricsSubsystem, @@ -62,18 +68,12 @@ func PrometheusMetrics(namespace string, labelsAndValues ...string) *Metrics { Name: "router_channel_queue_send", Help: "The time taken to send on a p2p channel's queue which will later be consued by the corresponding reactor/service.", }, labels).With(labelsAndValues...), - PeerQueueDroppedMsgs: prometheus.NewCounterFrom(stdprometheus.CounterOpts{ - Namespace: namespace, - Subsystem: MetricsSubsystem, - Name: "router_channel_queue_dropped_msgs", - 
Help: "The number of messages dropped from a peer's queue for a specific p2p Channel.", - }, append(labels, "ch_id")).With(labelsAndValues...), - PeerQueueMsgSize: prometheus.NewGaugeFrom(stdprometheus.GaugeOpts{ + QueueDroppedMsgs: prometheus.NewCounterFrom(stdprometheus.CounterOpts{ Namespace: namespace, Subsystem: MetricsSubsystem, - Name: "peer_queue_msg_size", - Help: "The size of messages sent over a peer's queue for a specific p2p Channel.", - }, append(labels, "ch_id")).With(labelsAndValues...), + Name: "queue_dropped_msgs", + Help: "The number of messages dropped from router's queues.", + }, append(labels, "ch_id", "direction")).With(labelsAndValues...), } } @@ -84,10 +84,10 @@ func NopMetrics() *Metrics { PeerReceiveBytesTotal: discard.NewCounter(), PeerSendBytesTotal: discard.NewCounter(), PeerPendingSendBytes: discard.NewGauge(), + NewConnections: discard.NewCounter(), RouterPeerQueueRecv: discard.NewHistogram(), RouterPeerQueueSend: discard.NewHistogram(), RouterChannelQueueSend: discard.NewHistogram(), - PeerQueueDroppedMsgs: discard.NewCounter(), - PeerQueueMsgSize: discard.NewGauge(), + QueueDroppedMsgs: discard.NewCounter(), } } diff --git a/internal/p2p/metrics.go b/internal/p2p/metrics.go index bc9678414..52fe5b1a3 100644 --- a/internal/p2p/metrics.go +++ b/internal/p2p/metrics.go @@ -36,6 +36,8 @@ type Metrics struct { PeerSendBytesTotal metrics.Counter `metrics_labels:"peer_id, chID, message_type"` // Number of bytes pending being sent to a given peer. PeerPendingSendBytes metrics.Gauge `metrics_labels:"peer_id"` + // Number of newly established connections. + NewConnections metrics.Counter `metrics_labels:"direction"` // RouterPeerQueueRecv defines the time taken to read off of a peer's queue // before sending on the connection. @@ -52,15 +54,9 @@ type Metrics struct { //metrics:The time taken to send on a p2p channel's queue which will later be consued by the corresponding reactor/service. 
RouterChannelQueueSend metrics.Histogram - // PeerQueueDroppedMsgs defines the number of messages dropped from a peer's - // queue for a specific flow (i.e. Channel). - //metrics:The number of messages dropped from a peer's queue for a specific p2p Channel. - PeerQueueDroppedMsgs metrics.Counter `metrics_labels:"ch_id" metrics_name:"router_channel_queue_dropped_msgs"` - - // PeerQueueMsgSize defines the average size of messages sent over a peer's - // queue for a specific flow (i.e. Channel). - //metrics:The size of messages sent over a peer's queue for a specific p2p Channel. - PeerQueueMsgSize metrics.Gauge `metrics_labels:"ch_id" metric_name:"router_channel_queue_msg_size"` + // QueueDroppedMsgs counts the messages dropped from the router's queues. + //metrics:The number of messages dropped from router's queues. + QueueDroppedMsgs metrics.Counter `metrics_labels:"ch_id, direction"` } type metricsLabelCache struct { diff --git a/internal/p2p/mocks/transport.go b/internal/p2p/mocks/transport.go index cd9b7ae8c..e2ee3c913 100644 --- a/internal/p2p/mocks/transport.go +++ b/internal/p2p/mocks/transport.go @@ -52,26 +52,8 @@ func (_m *Transport) AddChannelDescriptors(_a0 []*conn.ChannelDescriptor) { _m.Called(_a0) } -// Close provides a mock function with no fields -func (_m *Transport) Close() error { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Close") - } - - var r0 error - if rf, ok := ret.Get(0).(func() error); ok { - r0 = rf() - } else { - r0 = ret.Error(0) - } - - return r0 -} - // Dial provides a mock function with given fields: _a0, _a1 -func (_m *Transport) Dial(_a0 context.Context, _a1 *p2p.Endpoint) (p2p.Connection, error) { +func (_m *Transport) Dial(_a0 context.Context, _a1 p2p.Endpoint) (p2p.Connection, error) { ret := _m.Called(_a0, _a1) if len(ret) == 0 { @@ -80,10 +62,10 @@ func (_m *Transport) Dial(_a0 context.Context, _a1 *p2p.Endpoint) (p2p.Connectio var r0 p2p.Connection var r1 error - if rf, ok := 
ret.Get(0).(func(context.Context, *p2p.Endpoint) (p2p.Connection, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, p2p.Endpoint) (p2p.Connection, error)); ok { return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context, *p2p.Endpoint) p2p.Connection); ok { + if rf, ok := ret.Get(0).(func(context.Context, p2p.Endpoint) p2p.Connection); ok { r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { @@ -91,7 +73,7 @@ func (_m *Transport) Dial(_a0 context.Context, _a1 *p2p.Endpoint) (p2p.Connectio } } - if rf, ok := ret.Get(1).(func(context.Context, *p2p.Endpoint) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, p2p.Endpoint) error); ok { r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) @@ -101,48 +83,18 @@ func (_m *Transport) Dial(_a0 context.Context, _a1 *p2p.Endpoint) (p2p.Connectio } // Endpoint provides a mock function with no fields -func (_m *Transport) Endpoint() (*p2p.Endpoint, error) { +func (_m *Transport) Endpoint() p2p.Endpoint { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for Endpoint") } - var r0 *p2p.Endpoint - var r1 error - if rf, ok := ret.Get(0).(func() (*p2p.Endpoint, error)); ok { - return rf() - } - if rf, ok := ret.Get(0).(func() *p2p.Endpoint); ok { + var r0 p2p.Endpoint + if rf, ok := ret.Get(0).(func() p2p.Endpoint); ok { r0 = rf() } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*p2p.Endpoint) - } - } - - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Listen provides a mock function with given fields: _a0 -func (_m *Transport) Listen(_a0 *p2p.Endpoint) error { - ret := _m.Called(_a0) - - if len(ret) == 0 { - panic("no return value specified for Listen") - } - - var r0 error - if rf, ok := ret.Get(0).(func(*p2p.Endpoint) error); ok { - r0 = rf(_a0) - } else { - r0 = ret.Error(0) + r0 = ret.Get(0).(p2p.Endpoint) } return r0 @@ -168,6 +120,24 @@ func (_m *Transport) Protocols() []p2p.Protocol { return r0 } +// Run 
provides a mock function with given fields: ctx +func (_m *Transport) Run(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Run") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // String provides a mock function with no fields func (_m *Transport) String() string { ret := _m.Called() diff --git a/internal/p2p/p2p_test.go b/internal/p2p/p2p_test.go index d8657b774..1fcd46c94 100644 --- a/internal/p2p/p2p_test.go +++ b/internal/p2p/p2p_test.go @@ -17,6 +17,7 @@ var ( MessageType: &p2ptest.Message{}, Priority: 5, SendQueueCapacity: 10, + RecvBufferCapacity: 10, RecvMessageCapacity: 10, } diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index 7069ac850..cea53defc 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -51,7 +51,7 @@ func (opts *NetworkOptions) setDefaults() { // connects them to each other. func MakeNetwork(ctx context.Context, t *testing.T, opts NetworkOptions) *Network { opts.setDefaults() - logger := log.NewNopLogger() + logger, _ := log.NewDefaultLogger("plain", "info") network := &Network{ Nodes: map[types.NodeID]*Node{}, logger: logger, @@ -144,13 +144,12 @@ func (n *Network) NodeIDs() []types.NodeID { // MakeChannels makes a channel on all nodes and returns them, automatically // doing error checks and cleanups. func (n *Network) MakeChannels( - ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) map[types.NodeID]*p2p.Channel { channels := map[types.NodeID]*p2p.Channel{} for _, node := range n.Nodes { - channels[node.NodeID] = node.MakeChannel(ctx, t, chDesc) + channels[node.NodeID] = node.MakeChannel(t, chDesc) } return channels } @@ -159,13 +158,12 @@ func (n *Network) MakeChannels( // automatically doing error checks. The caller must ensure proper cleanup of // all the channels. 
func (n *Network) MakeChannelsNoCleanup( - ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) map[types.NodeID]*p2p.Channel { channels := map[types.NodeID]*p2p.Channel{} for _, node := range n.Nodes { - channels[node.NodeID] = node.MakeChannelNoCleanup(ctx, t, chDesc) + channels[node.NodeID] = node.MakeChannelNoCleanup(t, chDesc) } return channels } @@ -205,7 +203,6 @@ func (n *Network) Remove(ctx context.Context, t *testing.T, id types.NodeID) { subs = append(subs, sub) } - require.NoError(t, node.Transport.Close()) node.cancel() if node.Router.IsRunning() { node.Router.Stop() @@ -222,6 +219,7 @@ func (n *Network) Remove(ctx context.Context, t *testing.T, id types.NodeID) { // Node is a node in a Network, with a Router and a PeerManager. type Node struct { + Logger log.Logger NodeID types.NodeID NodeInfo types.NodeInfo NodeAddress p2p.NodeAddress @@ -248,16 +246,13 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) } transport := n.memoryNetwork.CreateTransport(nodeID) - ep, err := transport.Endpoint() - require.NoError(t, err) - require.NotNil(t, ep, "transport not listening an endpoint") - maxRetryTime := 1000 * time.Millisecond if opts.MaxRetryTime > 0 { maxRetryTime = opts.MaxRetryTime } - peerManager, err := p2p.NewPeerManager(n.logger, nodeID, dbm.NewMemDB(), p2p.PeerManagerOptions{ + logger := n.logger.With("node", nodeID[:5]) + peerManager, err := p2p.NewPeerManager(logger, nodeID, dbm.NewMemDB(), p2p.PeerManagerOptions{ MinRetryTime: 10 * time.Millisecond, MaxRetryTime: maxRetryTime, RetryTimeJitter: time.Millisecond, @@ -267,15 +262,14 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) require.NoError(t, err) router, err := p2p.NewRouter( - n.logger, + logger, p2p.NopMetrics(), privKey, peerManager, func() *types.NodeInfo { return &nodeInfo }, transport, - ep, nil, - p2p.RouterOptions{DialSleep: func(_ context.Context) {}}, + p2p.RouterOptions{DialSleep: func(_ 
context.Context) error { return nil }}, ) require.NoError(t, err) @@ -286,14 +280,14 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) router.Stop() router.Wait() } - require.NoError(t, transport.Close()) cancel() }) return &Node{ + Logger: logger, NodeID: nodeID, NodeInfo: nodeInfo, - NodeAddress: ep.NodeAddress(nodeID), + NodeAddress: transport.Endpoint().NodeAddress(nodeID), PrivKey: privKey, Router: router, PeerManager: peerManager, @@ -306,16 +300,13 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) // test cleanup, it also checks that the channel is empty, to make sure // all expected messages have been asserted. func (n *Node) MakeChannel( - ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) *p2p.Channel { - ctx, cancel := context.WithCancel(ctx) - channel, err := n.Router.OpenChannel(ctx, chDesc) + channel, err := n.Router.OpenChannel(chDesc) require.NoError(t, err) t.Cleanup(func() { - RequireEmpty(ctx, t, channel) - cancel() + RequireEmpty(t, channel) }) return channel } @@ -323,11 +314,10 @@ func (n *Node) MakeChannel( // MakeChannelNoCleanup opens a channel, with automatic error handling. The // caller must ensure proper cleanup of the channel. 
func (n *Node) MakeChannelNoCleanup( - ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) *p2p.Channel { - channel, err := n.Router.OpenChannel(ctx, chDesc) + channel, err := n.Router.OpenChannel(chDesc) require.NoError(t, err) return channel } diff --git a/internal/p2p/p2ptest/require.go b/internal/p2p/p2ptest/require.go index 885e080d4..31042eda3 100644 --- a/internal/p2p/p2ptest/require.go +++ b/internal/p2p/p2ptest/require.go @@ -2,64 +2,38 @@ package p2ptest import ( "context" - "errors" "testing" - "time" - "github.com/gogo/protobuf/proto" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/internal/p2p" - "github.com/tendermint/tendermint/types" + "github.com/tendermint/tendermint/libs/utils" ) // RequireEmpty requires that the given channel is empty. -func RequireEmpty(ctx context.Context, t *testing.T, channels ...*p2p.Channel) { +func RequireEmpty(t *testing.T, channels ...*p2p.Channel) { t.Helper() - - ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) - defer cancel() - - iter := p2p.MergedChannelIterator(ctx, channels...) - count := 0 - for iter.Next(ctx) { - count++ - require.Nil(t, iter.Envelope()) + for _, ch := range channels { + if ch.ReceiveLen() != 0 { + t.Errorf("nonempty channel %v", ch) + } } - require.Zero(t, count) - require.Error(t, ctx.Err()) } // RequireReceive requires that the given envelope is received on the channel. 
-func RequireReceive(ctx context.Context, t *testing.T, channel *p2p.Channel, expect p2p.Envelope) { +func RequireReceive(t *testing.T, channel *p2p.Channel, expect p2p.Envelope) { t.Helper() - - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - - iter := channel.Receive(ctx) - count := 0 - for iter.Next(ctx) { - count++ - envelope := iter.Envelope() - require.Equal(t, expect.From, envelope.From) - require.Equal(t, expect.Message, envelope.Message) - } - - if !assert.True(t, count >= 1) { - require.NoError(t, ctx.Err(), "timed out waiting for message %v", expect) - } + RequireReceiveUnordered(t, channel, utils.Slice(&expect)) } // RequireReceiveUnordered requires that the given envelopes are all received on // the channel, ignoring order. -func RequireReceiveUnordered(ctx context.Context, t *testing.T, channel *p2p.Channel, expect []*p2p.Envelope) { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - +func RequireReceiveUnordered(t *testing.T, channel *p2p.Channel, expect []*p2p.Envelope) { + t.Helper() + t.Logf("awaiting %d messages", len(expect)) actual := []*p2p.Envelope{} - + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() iter := channel.Receive(ctx) for iter.Next(ctx) { actual = append(actual, iter.Envelope()) @@ -68,103 +42,55 @@ func RequireReceiveUnordered(ctx context.Context, t *testing.T, channel *p2p.Cha return } } - - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - require.ElementsMatch(t, expect, actual) - } + require.FailNow(t, "not enough messages") } // RequireSend requires that the given envelope is sent on the channel. 
-func RequireSend(ctx context.Context, t *testing.T, channel *p2p.Channel, envelope p2p.Envelope) { - tctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - - err := channel.Send(tctx, envelope) - switch { - case errors.Is(err, context.DeadlineExceeded): - require.Fail(t, "timed out sending message to %q", envelope.To) - default: - require.NoError(t, err, "unexpected error") - } -} - -// RequireSendReceive requires that a given Protobuf message is sent to the -// given peer, and then that the given response is received back. -func RequireSendReceive( - ctx context.Context, - t *testing.T, - channel *p2p.Channel, - peerID types.NodeID, - send proto.Message, - receive proto.Message, -) { - RequireSend(ctx, t, channel, p2p.Envelope{To: peerID, Message: send}) - RequireReceive(ctx, t, channel, p2p.Envelope{From: peerID, Message: send}) +func RequireSend(t *testing.T, channel *p2p.Channel, envelope p2p.Envelope) { + t.Logf("sending message %v", envelope) + require.NoError(t, channel.Send(t.Context(), envelope)) } // RequireNoUpdates requires that a PeerUpdates subscription is empty. func RequireNoUpdates(ctx context.Context, t *testing.T, peerUpdates *p2p.PeerUpdates) { t.Helper() - select { - case update := <-peerUpdates.Updates(): - if ctx.Err() == nil { - require.Fail(t, "unexpected peer updates", "got %v", update) - } - case <-ctx.Done(): - default: + if len(peerUpdates.Updates()) != 0 { + require.FailNow(t, "unexpected peer updates") } } // RequireError requires that the given peer error is submitted for a peer. 
-func RequireError(ctx context.Context, t *testing.T, channel *p2p.Channel, peerError p2p.PeerError) { - tctx, tcancel := context.WithTimeout(ctx, time.Second) - defer tcancel() - - err := channel.SendError(tctx, peerError) - switch { - case errors.Is(err, context.DeadlineExceeded): - require.Fail(t, "timed out reporting error", "%v for %q", peerError, channel.String()) - default: - require.NoError(t, err, "unexpected error") - } +func RequireSendError(t *testing.T, channel *p2p.Channel, peerError p2p.PeerError) { + require.NoError(t, channel.SendError(t.Context(), peerError)) } // RequireUpdate requires that a PeerUpdates subscription yields the given update. func RequireUpdate(t *testing.T, peerUpdates *p2p.PeerUpdates, expect p2p.PeerUpdate) { - timer := time.NewTimer(time.Second) // not time.After due to goroutine leaks - defer timer.Stop() - - select { - case update := <-peerUpdates.Updates(): - require.Equal(t, expect.NodeID, update.NodeID, "node id did not match") - require.Equal(t, expect.Status, update.Status, "statuses did not match") - case <-timer.C: - require.Fail(t, "timed out waiting for peer update", "expected %v", expect) + t.Logf("awaiting update %v", expect) + update, err := utils.Recv(t.Context(), peerUpdates.Updates()) + if err != nil { + require.FailNow(t, "utils.Recv(): %w", err) } + require.Equal(t, expect.NodeID, update.NodeID, "node id did not match") + require.Equal(t, expect.Status, update.Status, "statuses did not match") } // RequireUpdates requires that a PeerUpdates subscription yields the given updates // in the given order. 
func RequireUpdates(t *testing.T, peerUpdates *p2p.PeerUpdates, expect []p2p.PeerUpdate) { - timer := time.NewTimer(time.Second) // not time.After due to goroutine leaks - defer timer.Stop() - + t.Logf("awaiting %d updates", len(expect)) actual := []p2p.PeerUpdate{} for { - select { - case update := <-peerUpdates.Updates(): - actual = append(actual, update) - if len(actual) == len(expect) { - for idx := range expect { - require.Equal(t, expect[idx].NodeID, actual[idx].NodeID) - require.Equal(t, expect[idx].Status, actual[idx].Status) - } - - return + update, err := utils.Recv(t.Context(), peerUpdates.Updates()) + if err != nil { + require.FailNow(t, "utils.Recv(): %v", err) + } + actual = append(actual, update) + if len(actual) == len(expect) { + for idx := range expect { + require.Equal(t, expect[idx].NodeID, actual[idx].NodeID) + require.Equal(t, expect[idx].Status, actual[idx].Status) } - - case <-timer.C: - require.Equal(t, expect, actual, "did not receive expected peer updates") return } } diff --git a/internal/p2p/peermanager.go b/internal/p2p/peermanager.go index b75f960f5..f1a79cddc 100644 --- a/internal/p2p/peermanager.go +++ b/internal/p2p/peermanager.go @@ -18,6 +18,7 @@ import ( dbm "github.com/tendermint/tm-db" tmsync "github.com/tendermint/tendermint/internal/libs/sync" + "github.com/tendermint/tendermint/libs/utils" p2pproto "github.com/tendermint/tendermint/proto/tendermint/p2p" "github.com/tendermint/tendermint/types" ) @@ -27,6 +28,15 @@ const ( retryNever time.Duration = math.MaxInt64 ) +type DialFailuresError struct { + Failures uint32 + Address types.NodeID +} + +func (e DialFailuresError) Error() string { + return fmt.Sprintf("dialing failed %d times will not retry for address=%s, deleting peer", e.Failures, e.Address) +} + // PeerStatus is a peer status. // // The peer manager has many more internal states for a peer (e.g. 
dialing, @@ -315,7 +325,7 @@ type PeerManager struct { upgrading map[types.NodeID]types.NodeID // peers claimed for upgrade (DialNext → Dialed/DialFail) connected map[types.NodeID]bool // connected peers (Dialed/Accepted → Disconnected) ready map[types.NodeID]bool // ready peers (Ready → Disconnected) - evict map[types.NodeID]bool // peers scheduled for eviction (Connected → EvictNext) + evict map[types.NodeID]error // peers scheduled for eviction (Connected → EvictNext) evicting map[types.NodeID]bool // peers being evicted (EvictNext → Disconnected) metrics *Metrics } @@ -355,7 +365,7 @@ func NewPeerManager( upgrading: map[types.NodeID]types.NodeID{}, connected: map[types.NodeID]bool{}, ready: map[types.NodeID]bool{}, - evict: map[types.NodeID]bool{}, + evict: map[types.NodeID]error{}, evicting: map[types.NodeID]bool{}, subscriptions: map[*PeerUpdates]*PeerUpdates{}, metrics: metrics, @@ -485,6 +495,12 @@ func (m *PeerManager) Add(address NodeAddress) (bool, error) { return true, nil } +func (m *PeerManager) Delete(id types.NodeID) error { + m.mtx.Lock() + defer m.mtx.Unlock() + return m.store.Delete(id) +} + func (m *PeerManager) GetBlockSyncPeers() map[types.NodeID]bool { return m.options.blocksyncPeers } @@ -613,15 +629,11 @@ func (m *PeerManager) DialFailed(ctx context.Context, address NodeAddress) error if err := m.store.Delete(address.NodeID); err != nil { return err } - return fmt.Errorf("dialing failed %d times will not retry for address=%s, deleting peer", addressInfo.DialFailures, address.NodeID) + return DialFailuresError{addressInfo.DialFailures, address.NodeID} } go func() { - // Use an explicit timer with deferred cleanup instead of - // time.After(), to avoid leaking goroutines on PeerManager.Close(). 
- timer := time.NewTimer(d) - defer timer.Stop() select { - case <-timer.C: + case <-time.After(d): m.dialWaker.Wake() case <-ctx.Done(): } @@ -654,8 +666,7 @@ func (m *PeerManager) Dialed(address NodeAddress) error { return fmt.Errorf("rejecting connection to self (%v)", address.NodeID) } if m.connected[address.NodeID] { - dupeConnectionErr := fmt.Errorf("cant dial, peer=%q is already connected", address.NodeID) - return dupeConnectionErr + return fmt.Errorf("cant dial, peer=%q is already connected", address.NodeID) } if m.options.MaxConnected > 0 && m.NumConnected() >= int(m.options.MaxConnected) { if upgradeFromPeer == "" || m.NumConnected() >= @@ -689,7 +700,7 @@ func (m *PeerManager) Dialed(address NodeAddress) error { upgradeFromPeer = u } } - m.evict[upgradeFromPeer] = true + m.evict[upgradeFromPeer] = errors.New("too many peers") } m.connected[peer.ID] = true m.evictWaker.Wake() @@ -722,8 +733,7 @@ func (m *PeerManager) Accepted(peerID types.NodeID) error { return fmt.Errorf("rejecting connection from self (%v)", peerID) } if m.connected[peerID] { - dupeConnectionErr := fmt.Errorf("can't accept, peer=%q is already connected", peerID) - return dupeConnectionErr + return fmt.Errorf("can't accept, peer=%q is already connected", peerID) } if !m.options.isUnconditional(peerID) && m.options.MaxConnected > 0 && m.NumConnected() >= int(m.options.MaxConnected)+int(m.options.MaxConnectedUpgrade) { @@ -758,7 +768,7 @@ func (m *PeerManager) Accepted(peerID types.NodeID) error { m.connected[peerID] = true if upgradeFromPeer != "" { - m.evict[upgradeFromPeer] = true + m.evict[upgradeFromPeer] = errors.New("found better peer") } m.evictWaker.Wake() return nil @@ -787,40 +797,48 @@ func (m *PeerManager) Ready(ctx context.Context, peerID types.NodeID, channels C // EvictNext returns the next peer to evict (i.e. disconnect). If no evictable // peers are found, the call will block until one becomes available. 
-func (m *PeerManager) EvictNext(ctx context.Context) (types.NodeID, error) { +func (m *PeerManager) EvictNext(ctx context.Context) (Eviction, error) { for { - id, err := m.TryEvictNext() - if err != nil || id != "" { - return id, err + ev, err := m.TryEvictNext() + if err != nil { + return Eviction{}, err + } + if ev, ok := ev.Get(); ok { + return ev, nil } select { case <-m.evictWaker.Sleep(): case <-ctx.Done(): - return "", ctx.Err() + return Eviction{}, ctx.Err() } } } +type Eviction struct { + ID types.NodeID + Cause error +} + // TryEvictNext is equivalent to EvictNext, but immediately returns an empty // node ID if no evictable peers are found. -func (m *PeerManager) TryEvictNext() (types.NodeID, error) { +func (m *PeerManager) TryEvictNext() (utils.Option[Eviction], error) { m.mtx.Lock() defer m.mtx.Unlock() // If any connected peers are explicitly scheduled for eviction, we return a // random one. - for peerID := range m.evict { + for peerID, cause := range m.evict { delete(m.evict, peerID) if m.connected[peerID] && !m.evicting[peerID] { m.evicting[peerID] = true - return peerID, nil + return utils.Some(Eviction{peerID, cause}), nil } } // If we're below capacity, we don't need to evict anything. 
if m.options.MaxConnected == 0 || m.NumConnected()-len(m.evicting) <= int(m.options.MaxConnected) { - return "", nil + return utils.None[Eviction](), nil } // If we're above capacity (shouldn't really happen), just pick the @@ -830,11 +848,11 @@ func (m *PeerManager) TryEvictNext() (types.NodeID, error) { peer := ranked[i] if m.connected[peer.ID] && !m.evicting[peer.ID] { m.evicting[peer.ID] = true - return peer.ID, nil + return utils.Some(Eviction{peer.ID, errors.New("too many peers")}), nil } } - return "", nil + return utils.None[Eviction](), nil } // Disconnected unmarks a peer as connected, allowing it to be dialed or @@ -888,7 +906,7 @@ func (m *PeerManager) Errored(peerID types.NodeID, err error) { defer m.mtx.Unlock() if m.connected[peerID] { - m.evict[peerID] = true + m.evict[peerID] = err } m.evictWaker.Wake() @@ -1144,7 +1162,7 @@ func (m *PeerManager) findUpgradeCandidate(id types.NodeID, score PeerScore) typ case candidate.Score() >= score: return "" // no further peers can be scored lower, due to sorting case !m.connected[candidate.ID]: - case m.evict[candidate.ID]: + case m.evict[candidate.ID] != nil: case m.evicting[candidate.ID]: case m.upgrading[candidate.ID] != "": default: @@ -1335,7 +1353,12 @@ func (s *peerStore) Ranked() []*peerInfo { sort.Slice(s.ranked, func(i, j int) bool { // FIXME: If necessary, consider precomputing scores before sorting, // to reduce the number of Score() calls. - return s.ranked[i].Score() > s.ranked[j].Score() + if a, b := s.ranked[i].Score(), s.ranked[j].Score(); a != b { + return a > b + } + // TODO(gprusak): we don't allow ties because tests require deterministic order. + // If not necessary in prod, then fix the tests instead.
+ return s.ranked[i].ID < s.ranked[j].ID }) for _, peer := range s.ranked { s.metrics.PeerScore.With("peer_id", string(peer.ID)).Set(float64(int(peer.Score()))) diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index 6aaa1ddff..04df86f82 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -12,10 +12,10 @@ import ( "github.com/fortytw2/leaktest" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" dbm "github.com/tendermint/tm-db" "github.com/tendermint/tendermint/internal/p2p" + "github.com/tendermint/tendermint/libs/utils/require" "github.com/tendermint/tendermint/types" ) @@ -432,7 +432,10 @@ func TestPeerManagerDeleteOnMaxRetries(t *testing.T) { require.GreaterOrEqual(t, elapsed, time.Duration(math.Pow(2, float64(i)))*options.MinRetryTime) } if i == 3 { - require.ErrorContains(t, peerManager.DialFailed(ctx, a), "dialing failed 4 times") + if got, err := (p2p.DialFailuresError{}), peerManager.DialFailed(ctx, a); !errors.As(err, &got) || got.Failures != 4 { + t.Errorf("expected 4 failures, got error %v", err) + } + continue } require.NoError(t, peerManager.DialFailed(ctx, a)) @@ -834,7 +837,8 @@ func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} c := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("c", 40))} - peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{ + logger, _ := log.NewDefaultLogger("plain", "debug") + peerManager, err := p2p.NewPeerManager(logger, selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{ PeerScores: map[types.NodeID]p2p.PeerScore{ a.NodeID: p2p.DefaultMutableScore - 1, // Set lower score for a to make it upgradeable b.NodeID: p2p.DefaultMutableScore + 1, // Higher score for b to attempt upgrade @@ -845,7 +849,7 @@ func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { }, 
p2p.NopMetrics()) require.NoError(t, err) - // Add and connect to peer a (lower scored) + t.Logf("Add and connect to peer a (lower scored)") added, err := peerManager.Add(a) require.NoError(t, err) require.True(t, added) @@ -854,7 +858,7 @@ func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { require.Equal(t, a, dial) require.NoError(t, peerManager.Dialed(a)) - // Add both higher scored peers b and c + t.Logf("Add both higher scored peers b and c") added, err = peerManager.Add(b) require.NoError(t, err) require.True(t, added) @@ -1044,7 +1048,9 @@ func TestPeerManager_Dialed_Upgrade(t *testing.T) { // a should now be evicted. evict, err := peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } } func TestPeerManager_Dialed_UpgradeEvenLower(t *testing.T) { @@ -1100,7 +1106,9 @@ func TestPeerManager_Dialed_UpgradeEvenLower(t *testing.T) { require.NoError(t, peerManager.Dialed(c)) evict, err := peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, d.NodeID, evict) + if ev, ok := evict.Get(); !ok || ev.ID != d.NodeID { + t.Fatalf("evict = %v, expected %s", evict, d.NodeID) + } } func TestPeerManager_Dialed_UpgradeNoEvict(t *testing.T) { @@ -1311,7 +1319,9 @@ func TestPeerManager_Accepted_Upgrade(t *testing.T) { // This should cause a to get evicted. evict, err := peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } peerManager.Disconnected(ctx, a.NodeID) // c still cannot get accepted, since it's not scored above b. @@ -1361,7 +1371,9 @@ func TestPeerManager_Accepted_UpgradeDialing(t *testing.T) { // This should cause a to get evicted, and the dial upgrade to fail. 
evict, err := peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } require.Error(t, peerManager.Dialed(b)) } @@ -1449,7 +1461,7 @@ func TestPeerManager_EvictNext(t *testing.T) { peerManager.Errored(a.NodeID, errors.New("foo")) evict, err := peerManager.EvictNext(timeoutCtx) require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + require.Equal(t, a.NodeID, evict.ID) // Since there are no more peers to evict, the next call should block. timeoutCtx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) @@ -1484,7 +1496,7 @@ func TestPeerManager_EvictNext_WakeOnError(t *testing.T) { defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + require.Equal(t, a.NodeID, evict.ID) } func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { @@ -1524,7 +1536,7 @@ func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + require.Equal(t, a.NodeID, evict.ID) } func TestPeerManager_EvictNext_WakeOnUpgradeAccepted(t *testing.T) { @@ -1558,7 +1570,7 @@ func TestPeerManager_EvictNext_WakeOnUpgradeAccepted(t *testing.T) { defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + require.Equal(t, a.NodeID, evict.ID) } func TestPeerManager_TryEvictNext(t *testing.T) { ctx := t.Context() @@ -1585,7 +1597,9 @@ func TestPeerManager_TryEvictNext(t *testing.T) { peerManager.Errored(a.NodeID, errors.New("foo")) evict, err = peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } // While a is being evicted (before disconnect), it shouldn't get 
evicted again. evict, err = peerManager.TryEvictNext() @@ -1688,7 +1702,9 @@ func TestPeerManager_Errored(t *testing.T) { peerManager.Errored(a.NodeID, errors.New("foo")) evict, err = peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } } func TestPeerManager_Subscribe(t *testing.T) { @@ -1737,7 +1753,9 @@ func TestPeerManager_Subscribe(t *testing.T) { evict, err := peerManager.TryEvictNext() require.NoError(t, err) - require.Equal(t, a.NodeID, evict) + if ev, ok := evict.Get(); !ok || ev.ID != a.NodeID { + t.Fatalf("evict = %v, expected %s", evict, a.NodeID) + } peerManager.Disconnected(ctx, a.NodeID) require.NotEmpty(t, sub.Updates()) diff --git a/internal/p2p/pex/reactor.go b/internal/p2p/pex/reactor.go index 972ed0499..f0bb7f1d4 100644 --- a/internal/p2p/pex/reactor.go +++ b/internal/p2p/pex/reactor.go @@ -2,6 +2,7 @@ package pex import ( "context" + "errors" "fmt" "sync" "time" @@ -11,6 +12,7 @@ import ( "github.com/tendermint/tendermint/internal/p2p/conn" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/libs/service" + "github.com/tendermint/tendermint/libs/utils" protop2p "github.com/tendermint/tendermint/proto/tendermint/p2p" "github.com/tendermint/tendermint/types" ) @@ -52,13 +54,7 @@ const ( fullCapacityInterval = 10 * time.Minute ) -type NoPeersAvailableError struct { - error -} - -func (e *NoPeersAvailableError) Error() string { - return fmt.Sprintf("no available peers to send a PEX request to (retrying)") -} +var NoPeersAvailableError = errors.New("no available peers to send a PEX request to (retrying)") // TODO: We should decide whether we want channel descriptors to be housed // within each reactor (as they are now) or, considering that the reactor doesn't @@ -114,7 +110,7 @@ type Reactor struct { channel *p2p.Channel // Used to signal a restart the node on the application 
level - restartCh chan struct{} + restartCh chan<- struct{} restartNoAvailablePeersWindow time.Duration } @@ -123,7 +119,7 @@ func NewReactor( logger log.Logger, peerManager *p2p.PeerManager, peerEvents p2p.PeerEventSubscriber, - restartCh chan struct{}, + restartCh chan<- struct{}, selfRemediationConfig *config.SelfRemediationConfig, ) *Reactor { r := &Reactor{ @@ -152,8 +148,8 @@ func (r *Reactor) SetChannel(ch *p2p.Channel) { // OnStop to ensure the outbound p2p Channels are closed. func (r *Reactor) OnStart(ctx context.Context) error { peerUpdates := r.peerEvents(ctx) - go r.processPexCh(ctx, r.channel) - go r.processPeerUpdates(ctx, peerUpdates) + r.Spawn("processPexCh", func(ctx context.Context) error { return r.processPexCh(ctx) }) + r.Spawn("processPeerUpdates", func(ctx context.Context) error { return r.processPeerUpdates(ctx, peerUpdates) }) return nil } @@ -163,16 +159,14 @@ func (r *Reactor) OnStop() {} // processPexCh implements a blocking event loop where we listen for p2p // Envelope messages from the pexCh. -func (r *Reactor) processPexCh(ctx context.Context, pexCh *p2p.Channel) { +func (r *Reactor) processPexCh(ctx context.Context) error { incoming := make(chan *p2p.Envelope) go func() { defer close(incoming) - iter := pexCh.Receive(ctx) + iter := r.channel.Receive(ctx) for iter.Next(ctx) { - select { - case <-ctx.Done(): + if err := utils.Send(ctx, incoming, iter.Envelope()); err != nil { return - case incoming <- iter.Envelope(): } } }() @@ -184,52 +178,48 @@ func (r *Reactor) processPexCh(ctx context.Context, pexCh *p2p.Channel) { lastNoAvailablePeersTime := time.Now() timer := time.NewTimer(0) - defer timer.Stop() - for { timer.Reset(nextPeerRequest) select { case <-ctx.Done(): - return + return nil case <-timer.C: // back off sending peer requests if there's none available. 
// Let the loop continue to handle incoming pex messages - if noAvailablePeerFailCounter > 0 { - waitPeriod := float64(noAvailablePeersWaitPeriod) * float64(noAvailablePeerFailCounter) - if time.Since(lastNoAvailablePeersTime).Seconds() < time.Duration(waitPeriod).Seconds() { - r.logger.Debug(fmt.Sprintf("waiting for more peers to become available still in the waitPeriod=%f\n", time.Duration(waitPeriod).Seconds())) - continue - } + waitPeriod := noAvailablePeersWaitPeriod * time.Duration(noAvailablePeerFailCounter) + if time.Since(lastNoAvailablePeersTime) < waitPeriod { + r.logger.Debug(fmt.Sprintf("waiting for more peers to become available still in the waitPeriod=%v\n", waitPeriod)) + continue } // Send a request for more peer addresses. - if err := r.sendRequestForPeers(ctx, pexCh); err != nil { + if err := r.sendRequestForPeers(ctx); err != nil { r.logger.Error("failed to send request for peers", "err", err) - if _, ok := err.(*NoPeersAvailableError); ok { + if errors.Is(err, NoPeersAvailableError) { noAvailablePeerFailCounter++ lastNoAvailablePeersTime = time.Now() continue } - return + return err } noAvailablePeerFailCounter = 0 case envelope, ok := <-incoming: if !ok { - return // channel closed + return nil // channel closed } // A request from another peer, or a response to one of our requests. - dur, err := r.handlePexMessage(ctx, envelope, pexCh) + dur, err := r.handlePexMessage(ctx, envelope) if err != nil { r.logger.Error("failed to process message", "ch_id", envelope.ChannelID, "envelope", envelope, "err", err) - if serr := pexCh.SendError(ctx, p2p.PeerError{ + if serr := r.channel.SendError(ctx, p2p.PeerError{ NodeID: envelope.From, Err: err, }); serr != nil { - return + return serr } } else if dur != 0 { // We got a useful result; update the poll timer. 
@@ -244,29 +234,27 @@ func (r *Reactor) processPexCh(ctx context.Context, pexCh *p2p.Channel) { // processPeerUpdates initiates a blocking process where we listen for and handle // PeerUpdate messages. When the reactor is stopped, we will catch the signal and // close the p2p PeerUpdatesCh gracefully. -func (r *Reactor) processPeerUpdates(ctx context.Context, peerUpdates *p2p.PeerUpdates) { +func (r *Reactor) processPeerUpdates(ctx context.Context, peerUpdates *p2p.PeerUpdates) error { for { - select { - case <-ctx.Done(): - return - case peerUpdate := <-peerUpdates.Updates(): - r.processPeerUpdate(peerUpdate) + peerUpdate, err := utils.Recv(ctx, peerUpdates.Updates()) + if err != nil { + return err } + r.processPeerUpdate(peerUpdate) } } // handlePexMessage handles envelopes sent from peers on the PexChannel. // If an update was received, a new polling interval is returned; otherwise the // duration is 0. -func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope, pexCh *p2p.Channel) (time.Duration, error) { +func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope) (time.Duration, error) { logger := r.logger.With("peer", envelope.From) switch msg := envelope.Message.(type) { case *protop2p.PexRequest: // Verify that this peer hasn't sent us another request too recently. 
if err := r.markPeerRequest(envelope.From); err != nil { - r.logger.Error(fmt.Sprintf("PEX mark peer req from %s error %s", envelope.From, err)) - return 0, err + return 0, fmt.Errorf("PEX mark peer req from %s: %w", envelope.From, err) } // Fetch peers from the peer manager, convert NodeAddresses into URL @@ -278,7 +266,7 @@ func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope, URL: addr.String(), } } - return 0, pexCh.Send(ctx, p2p.Envelope{ + return 0, r.channel.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &protop2p.PexResponse{Addresses: pexAddresses}, }) @@ -286,14 +274,11 @@ func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope, case *protop2p.PexResponse: // Verify that this response corresponds to one of our pending requests. if err := r.markPeerResponse(envelope.From); err != nil { - r.logger.Error(fmt.Sprintf("PEX mark peer resp from %s error %s", envelope.From, err)) - return 0, err + return 0, fmt.Errorf("PEX mark peer resp from %s: %w", envelope.From, err) } // Verify that the response does not exceed the safety limit. if len(msg.Addresses) > maxAddresses { - r.logger.Error(fmt.Sprintf("peer %s sent too many addresses (%d > maxiumum %d)", - envelope.From, len(msg.Addresses), maxAddresses)) return 0, fmt.Errorf("peer sent too many addresses (%d > maxiumum %d)", len(msg.Addresses), maxAddresses) } @@ -302,11 +287,11 @@ func (r *Reactor) handlePexMessage(ctx context.Context, envelope *p2p.Envelope, for _, pexAddress := range msg.Addresses { peerAddress, err := p2p.ParseNodeAddress(pexAddress.URL) if err != nil { - r.logger.Error(fmt.Sprintf("PEX parse node address error %s", err)) - continue + return 0, fmt.Errorf("PEX parse node address error %s", err) } added, err := r.peerManager.Add(peerAddress) if err != nil { + // TODO(gprusak): This does not distinguish between bad messages (should drop peer) and internal errors (ignore/abort). 
logger.Error("failed to add PEX address", "address", peerAddress, "err", err) continue } @@ -357,11 +342,11 @@ func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { // that peer a request for more peer addresses. The chosen peer is moved into // the requestsSent bucket so that we will not attempt to contact them again // until they've replied or updated. -func (r *Reactor) sendRequestForPeers(ctx context.Context, pexCh *p2p.Channel) error { +func (r *Reactor) sendRequestForPeers(ctx context.Context) error { r.mtx.Lock() defer r.mtx.Unlock() if len(r.availablePeers) == 0 { - return &NoPeersAvailableError{} + return NoPeersAvailableError } // Select an arbitrary peer from the available set. @@ -369,19 +354,15 @@ func (r *Reactor) sendRequestForPeers(ctx context.Context, pexCh *p2p.Channel) e for peerID = range r.availablePeers { break } - - if err := pexCh.Send(ctx, p2p.Envelope{ - To: peerID, - Message: &protop2p.PexRequest{}, - }); err != nil { - return err - } - // Move the peer from available to pending. delete(r.availablePeers, peerID) r.requestsSent[peerID] = struct{}{} - return nil + // TODO(gprusak): blocking send while holding a mutex. 
+ return r.channel.Send(ctx, p2p.Envelope{ + To: peerID, + Message: &protop2p.PexRequest{}, + }) } // calculateNextRequestTime selects how long we should wait before attempting diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index 49e860130..89e6a8788 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -60,10 +60,10 @@ func TestReactorConnectFullNetwork(t *testing.T) { // make every node be only connected with one other node (it actually ends up // being two because of two way connections but oh well) - testNet.connectN(ctx, t, 1) + testNet.seedAddrs(t) testNet.start(ctx, t) - // assert that all nodes add each other in the network + t.Logf("assert that all nodes add each other in the network") for idx := 0; idx < len(testNet.nodes); idx++ { testNet.requireNumberOfPeers(t, idx, len(testNet.nodes)-1, longWait) } @@ -76,20 +76,20 @@ func TestReactorSendsRequestsTooOften(t *testing.T) { badNode := newNodeID(t, "b") - r.pexInCh <- p2p.Envelope{ + r.pexInCh.Send(p2p.Envelope{ From: badNode, Message: &p2pproto.PexRequest{}, - } + }, 0) resp := <-r.pexOutCh msg, ok := resp.Message.(*p2pproto.PexResponse) require.True(t, ok) require.Empty(t, msg.Addresses) - r.pexInCh <- p2p.Envelope{ + r.pexInCh.Send(p2p.Envelope{ From: badNode, Message: &p2pproto.PexRequest{}, - } + }, 0) peerErr := <-r.pexErrCh require.Error(t, peerErr.Err) @@ -132,7 +132,7 @@ func TestReactorNeverSendsTooManyPeers(t *testing.T) { testNet.addNodes(ctx, t, 110) nodes := make([]int, 110) - for i := 0; i < len(nodes); i++ { + for i := range nodes { nodes[i] = i + 2 } testNet.addAddresses(t, secondNode, nodes) @@ -152,7 +152,7 @@ func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { require.True(t, added) addresses := make([]p2pproto.PexAddress, 101) - for i := 0; i < len(addresses); i++ { + for i := range addresses { nodeAddress := p2p.NodeAddress{Protocol: p2p.MemoryProtocol, NodeID: randomNodeID()} addresses[i] = 
p2pproto.PexAddress{ URL: nodeAddress.String(), @@ -170,12 +170,12 @@ func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { if _, ok := req.Message.(*p2pproto.PexRequest); !ok { t.Fatal("expected v2 pex request") } - r.pexInCh <- p2p.Envelope{ + r.pexInCh.Send(p2p.Envelope{ From: peer.NodeID, Message: &p2pproto.PexResponse{ Addresses: addresses, }, - } + }, 0) case <-time.After(10 * time.Second): t.Fatal("pex failed to send a request within 10 seconds") @@ -193,20 +193,20 @@ func TestReactorSmallPeerStoreInALargeNetwork(t *testing.T) { testNet := setupNetwork(ctx, t, testOptions{ TotalNodes: 8, - MaxPeers: 4, - MaxConnected: 3, - BufferSize: 8, + MaxPeers: 7, // total-1, because PeerManager doesn't count self + MaxConnected: 2, // enough capacity to establish a connected graph + BufferSize: 8, // reactor deadlocks if peer updates' subscribers are full (which is stupid) MaxRetryTime: 5 * time.Minute, }) - testNet.connectN(ctx, t, 1) + testNet.connectCycle(ctx, t) // Saturate capacity by connecting nodes in a cycle. 
testNet.start(ctx, t) - // test that all nodes reach full capacity + t.Logf("test that peers are gossiped even if connection cap is reached") for _, nodeID := range testNet.nodes { require.Eventually(t, func() bool { // nolint:scopelint return testNet.network.Nodes[nodeID].PeerManager.PeerRatio() >= 0.9 - }, longWait, checkFrequency, + }, time.Minute, checkFrequency, "peer ratio is: %f", testNet.network.Nodes[nodeID].PeerManager.PeerRatio()) } } @@ -221,7 +221,7 @@ func TestReactorLargePeerStoreInASmallNetwork(t *testing.T) { BufferSize: 5, MaxRetryTime: 5 * time.Minute, }) - testNet.connectN(ctx, t, 1) + testNet.seedAddrs(t) testNet.start(ctx, t) // assert that all nodes add each other in the network @@ -266,7 +266,7 @@ func TestReactorWithNetworkGrowth(t *testing.T) { type singleTestReactor struct { reactor *pex.Reactor - pexInCh chan p2p.Envelope + pexInCh *p2p.Queue pexOutCh chan p2p.Envelope pexErrCh chan p2p.PeerError pexCh *p2p.Channel @@ -278,7 +278,7 @@ func setupSingle(ctx context.Context, t *testing.T) *singleTestReactor { t.Helper() nodeID := newNodeID(t, "a") chBuf := 2 - pexInCh := make(chan p2p.Envelope, chBuf) + pexInCh := p2p.NewQueue(chBuf) pexOutCh := make(chan p2p.Envelope, chBuf) pexErrCh := make(chan p2p.PeerError, chBuf) pexCh := p2p.NewChannel( @@ -318,13 +318,11 @@ func setupSingle(ctx context.Context, t *testing.T) *singleTestReactor { type reactorTestSuite struct { network *p2ptest.Network - logger log.Logger reactors map[types.NodeID]*pex.Reactor pexChannels map[types.NodeID]*p2p.Channel - peerChans map[types.NodeID]chan p2p.PeerUpdate - peerUpdates map[types.NodeID]*p2p.PeerUpdates + peerChans map[types.NodeID]chan p2p.PeerUpdate nodes []types.NodeID mocks []types.NodeID @@ -363,38 +361,32 @@ func setupNetwork(ctx context.Context, t *testing.T, opts testOptions) *reactorT realNodes := opts.TotalNodes - opts.MockNodes rts := &reactorTestSuite{ - logger: log.NewNopLogger().With("testCase", t.Name()), network: p2ptest.MakeNetwork(ctx, t, 
networkOpts), reactors: make(map[types.NodeID]*pex.Reactor, realNodes), pexChannels: make(map[types.NodeID]*p2p.Channel, opts.TotalNodes), peerChans: make(map[types.NodeID]chan p2p.PeerUpdate, opts.TotalNodes), - peerUpdates: make(map[types.NodeID]*p2p.PeerUpdates, opts.TotalNodes), total: opts.TotalNodes, opts: opts, } // NOTE: we don't assert that the channels get drained after stopping the // reactor - rts.pexChannels = rts.network.MakeChannelsNoCleanup(ctx, t, pex.ChannelDescriptor()) + rts.pexChannels = rts.network.MakeChannelsNoCleanup(t, pex.ChannelDescriptor()) idx := 0 for nodeID := range rts.network.Nodes { - // make a copy to avoid getting hit by the range ref - // confusion: - nodeID := nodeID - rts.peerChans[nodeID] = make(chan p2p.PeerUpdate, chBuf) - rts.peerUpdates[nodeID] = p2p.NewPeerUpdates(rts.peerChans[nodeID], chBuf) - rts.network.Nodes[nodeID].PeerManager.Register(ctx, rts.peerUpdates[nodeID]) + peerUpdates := p2p.NewPeerUpdates(rts.peerChans[nodeID], chBuf) + rts.network.Nodes[nodeID].PeerManager.Register(ctx, peerUpdates) // the first nodes in the array are always mock nodes if idx < opts.MockNodes { rts.mocks = append(rts.mocks, nodeID) } else { rts.reactors[nodeID] = pex.NewReactor( - rts.logger.With("nodeID", nodeID), + rts.network.Nodes[nodeID].Logger, rts.network.Nodes[nodeID].PeerManager, - func(_ context.Context) *p2p.PeerUpdates { return rts.peerUpdates[nodeID] }, + func(_ context.Context) *p2p.PeerUpdates { return peerUpdates }, make(chan struct{}), config.DefaultSelfRemediationConfig(), ) @@ -433,7 +425,7 @@ func (r *reactorTestSuite) start(ctx context.Context, t *testing.T) { func (r *reactorTestSuite) addNodes(ctx context.Context, t *testing.T, nodes int) { t.Helper() - for i := 0; i < nodes; i++ { + for range nodes { node := r.network.MakeNode(ctx, t, p2ptest.NodeOptions{ MaxPeers: r.opts.MaxPeers, MaxConnected: r.opts.MaxConnected, @@ -441,15 +433,15 @@ func (r *reactorTestSuite) addNodes(ctx context.Context, t *testing.T, 
nodes int }) r.network.Nodes[node.NodeID] = node nodeID := node.NodeID - r.pexChannels[nodeID] = node.MakeChannelNoCleanup(ctx, t, pex.ChannelDescriptor()) + r.pexChannels[nodeID] = node.MakeChannelNoCleanup(t, pex.ChannelDescriptor()) r.peerChans[nodeID] = make(chan p2p.PeerUpdate, r.opts.BufferSize) - r.peerUpdates[nodeID] = p2p.NewPeerUpdates(r.peerChans[nodeID], r.opts.BufferSize) - r.network.Nodes[nodeID].PeerManager.Register(ctx, r.peerUpdates[nodeID]) + peerUpdates := p2p.NewPeerUpdates(r.peerChans[nodeID], r.opts.BufferSize) + r.network.Nodes[nodeID].PeerManager.Register(ctx, peerUpdates) r.reactors[nodeID] = pex.NewReactor( - r.logger.With("nodeID", nodeID), + r.network.Nodes[nodeID].Logger, r.network.Nodes[nodeID].PeerManager, - func(_ context.Context) *p2p.PeerUpdates { return r.peerUpdates[nodeID] }, + func(_ context.Context) *p2p.PeerUpdates { return peerUpdates }, make(chan struct{}), config.DefaultSelfRemediationConfig(), ) @@ -631,23 +623,34 @@ func (r *reactorTestSuite) requireNumberOfPeers( ) } -func (r *reactorTestSuite) connectAll(ctx context.Context, t *testing.T) { - r.connectN(ctx, t, r.total-1) -} - -// connects all nodes with n other nodes -func (r *reactorTestSuite) connectN(ctx context.Context, t *testing.T, n int) { - if n >= r.total { - require.Fail(t, "connectN: n must be less than the size of the network - 1") +func (r *reactorTestSuite) connectCycle(ctx context.Context, t *testing.T) { + if r.total == 0 { + return } + for i := range r.total { + r.connectPeers(ctx, t, i, (i+1)%r.total) + } +} - for i := 0; i < r.total; i++ { - for j := 0; j < n; j++ { +func (r *reactorTestSuite) connectAll(ctx context.Context, t *testing.T) { + for i := range r.total { + for j := range r.total - 1 { r.connectPeers(ctx, t, i, (i+j+1)%r.total) } } } +// Adds enough addresses to peerManagers, so that all nodes are discoverable. 
+func (r *reactorTestSuite) seedAddrs(t *testing.T) { + t.Helper() + for i := range r.total - 1 { + n1 := r.network.Nodes[r.nodes[i]] + n2 := r.network.Nodes[r.nodes[i+1]] + _, err := n1.PeerManager.Add(n2.NodeAddress) + require.NoError(t, err) + } +} + // connects node1 to node2 func (r *reactorTestSuite) connectPeers(ctx context.Context, t *testing.T, sourceNode, targetNode int) { t.Helper() @@ -665,6 +668,9 @@ func (r *reactorTestSuite) connectPeers(ctx context.Context, t *testing.T, sourc return } + // Subscription is for the ctx lifetime. + ctx, cancel := context.WithCancel(ctx) + defer cancel() sourceSub := n1.PeerManager.Subscribe(ctx) targetSub := n2.PeerManager.Subscribe(ctx) @@ -678,22 +684,12 @@ func (r *reactorTestSuite) connectPeers(ctx context.Context, t *testing.T, sourc return } - select { - case peerUpdate := <-targetSub.Updates(): - require.Equal(t, peerUpdate.NodeID, node1) - require.Equal(t, peerUpdate.Status, p2p.PeerStatusUp) - case <-time.After(2 * time.Second): - require.Fail(t, "timed out waiting for peer", "%v accepting %v", - targetNode, sourceNode) - } - select { - case peerUpdate := <-sourceSub.Updates(): - require.Equal(t, peerUpdate.NodeID, node2) - require.Equal(t, peerUpdate.Status, p2p.PeerStatusUp) - case <-time.After(2 * time.Second): - require.Fail(t, "timed out waiting for peer", "%v dialing %v", - sourceNode, targetNode) - } + peerUpdate := <-targetSub.Updates() + require.Equal(t, peerUpdate.NodeID, node1) + require.Equal(t, peerUpdate.Status, p2p.PeerStatusUp) + peerUpdate = <-sourceSub.Updates() + require.Equal(t, peerUpdate.NodeID, node2) + require.Equal(t, peerUpdate.Status, p2p.PeerStatusUp) added, err = n2.PeerManager.Add(sourceAddress) require.NoError(t, err) diff --git a/internal/p2p/pqueue.go b/internal/p2p/pqueue.go deleted file mode 100644 index 3cd1c897a..000000000 --- a/internal/p2p/pqueue.go +++ /dev/null @@ -1,293 +0,0 @@ -package p2p - -import ( - "container/heap" - "context" - "sort" - "strconv" - "sync" - 
"time" - - "github.com/gogo/protobuf/proto" - - "github.com/tendermint/tendermint/libs/log" -) - -// pqEnvelope defines a wrapper around an Envelope with priority to be inserted -// into a priority queue used for Envelope scheduling. -type pqEnvelope struct { - envelope Envelope - priority uint - size uint - timestamp time.Time - - index int -} - -// priorityQueue defines a type alias for a priority queue implementation. -type priorityQueue []*pqEnvelope - -func (pq priorityQueue) get(i int) *pqEnvelope { return pq[i] } -func (pq priorityQueue) Len() int { return len(pq) } - -func (pq priorityQueue) Less(i, j int) bool { - // if both elements have the same priority, prioritize based - // on most recent and largest - if pq[i].priority == pq[j].priority { - diff := pq[i].timestamp.Sub(pq[j].timestamp) - if diff < 0 { - diff *= -1 - } - if diff < 10*time.Millisecond { - return pq[i].size > pq[j].size - } - return pq[i].timestamp.After(pq[j].timestamp) - } - - // otherwise, pick the pqEnvelope with the higher priority - return pq[i].priority > pq[j].priority -} - -func (pq priorityQueue) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] - pq[i].index = i - pq[j].index = j -} - -func (pq *priorityQueue) Push(x interface{}) { - n := len(*pq) - pqEnv := x.(*pqEnvelope) - pqEnv.index = n - *pq = append(*pq, pqEnv) -} - -func (pq *priorityQueue) Pop() interface{} { - old := *pq - n := len(old) - pqEnv := old[n-1] - old[n-1] = nil - pqEnv.index = -1 - *pq = old[:n-1] - return pqEnv -} - -// Assert the priority queue scheduler implements the queue interface at -// compile-time. 
-var _ queue = (*pqScheduler)(nil) - -type pqScheduler struct { - logger log.Logger - metrics *Metrics - lc *metricsLabelCache - size uint - sizes map[uint]uint // cumulative priority sizes - pq *priorityQueue - chDescs []*ChannelDescriptor - capacity uint - chPriorities map[ChannelID]uint - - enqueueCh chan Envelope - dequeueCh chan Envelope - - closeFn func() - closeCh <-chan struct{} - done chan struct{} -} - -func newPQScheduler( - logger log.Logger, - m *Metrics, - lc *metricsLabelCache, - chDescs []*ChannelDescriptor, - enqueueBuf, dequeueBuf, capacity uint, -) *pqScheduler { - - // copy each ChannelDescriptor and sort them by ascending channel priority - chDescsCopy := make([]*ChannelDescriptor, len(chDescs)) - copy(chDescsCopy, chDescs) - sort.Slice(chDescsCopy, func(i, j int) bool { return chDescsCopy[i].Priority < chDescsCopy[j].Priority }) - - var ( - chPriorities = make(map[ChannelID]uint) - sizes = make(map[uint]uint) - ) - - for _, chDesc := range chDescsCopy { - chID := chDesc.ID - chPriorities[chID] = uint(chDesc.Priority) - sizes[uint(chDesc.Priority)] = 0 - } - - pq := make(priorityQueue, 0) - heap.Init(&pq) - - closeCh := make(chan struct{}) - once := &sync.Once{} - - return &pqScheduler{ - logger: logger.With("router", "scheduler"), - metrics: m, - lc: lc, - chDescs: chDescsCopy, - capacity: capacity, - chPriorities: chPriorities, - pq: &pq, - sizes: sizes, - enqueueCh: make(chan Envelope, enqueueBuf), - dequeueCh: make(chan Envelope, dequeueBuf), - closeFn: func() { once.Do(func() { close(closeCh) }) }, - closeCh: closeCh, - done: make(chan struct{}), - } -} - -// start starts non-blocking process that starts the priority queue scheduler. 
-func (s *pqScheduler) start(ctx context.Context) { go s.process(ctx) } -func (s *pqScheduler) enqueue() chan<- Envelope { return s.enqueueCh } -func (s *pqScheduler) dequeue() <-chan Envelope { return s.dequeueCh } -func (s *pqScheduler) close() { s.closeFn() } -func (s *pqScheduler) closed() <-chan struct{} { return s.done } - -// process starts a block process where we listen for Envelopes to enqueue. If -// there is sufficient capacity, it will be enqueued into the priority queue, -// otherwise, we attempt to dequeue enough elements from the priority queue to -// make room for the incoming Envelope by dropping lower priority elements. If -// there isn't sufficient capacity at lower priorities for the incoming Envelope, -// it is dropped. -// -// After we attempt to enqueue the incoming Envelope, if the priority queue is -// non-empty, we pop the top Envelope and send it on the dequeueCh. -func (s *pqScheduler) process(ctx context.Context) { - defer close(s.done) - - for { - select { - case e := <-s.enqueueCh: - chIDStr := strconv.Itoa(int(e.ChannelID)) - pqEnv := &pqEnvelope{ - envelope: e, - size: uint(proto.Size(e.Message)), - priority: s.chPriorities[e.ChannelID], - timestamp: time.Now().UTC(), - } - - // enqueue - - // Check if we have sufficient capacity to simply enqueue the incoming - // Envelope. - if s.size+pqEnv.size <= s.capacity { - s.metrics.PeerPendingSendBytes.With("peer_id", string(pqEnv.envelope.To)).Add(float64(pqEnv.size)) - // enqueue the incoming Envelope - s.push(pqEnv) - } else { - // There is not sufficient capacity to simply enqueue the incoming - // Envelope. So we have to attempt to make room for it by dropping lower - // priority Envelopes or drop the incoming Envelope otherwise. - - // The cumulative size of all enqueue envelopes at the incoming envelope's - // priority or lower. 
- total := s.sizes[pqEnv.priority] - - if total >= pqEnv.size { - // There is room for the incoming Envelope, so we drop as many lower - // priority Envelopes as we need to. - var ( - canEnqueue bool - tmpSize = s.size - i = s.pq.Len() - 1 - ) - - // Drop lower priority Envelopes until sufficient capacity exists for - // the incoming Envelope - for i >= 0 && !canEnqueue { - pqEnvTmp := s.pq.get(i) - - if pqEnvTmp.priority < pqEnv.priority { - if tmpSize+pqEnv.size <= s.capacity { - canEnqueue = true - } else { - pqEnvTmpChIDStr := strconv.Itoa(int(pqEnvTmp.envelope.ChannelID)) - s.metrics.PeerQueueDroppedMsgs.With("ch_id", pqEnvTmpChIDStr).Add(1) - s.logger.Debug( - "dropped envelope", - "ch_id", pqEnvTmpChIDStr, - "priority", pqEnvTmp.priority, - "msg_size", pqEnvTmp.size, - "capacity", s.capacity, - ) - - s.metrics.PeerPendingSendBytes.With("peer_id", string(pqEnvTmp.envelope.To)).Add(float64(-pqEnvTmp.size)) - - // dequeue/drop from the priority queue - heap.Remove(s.pq, pqEnvTmp.index) - - // update the size tracker - tmpSize -= pqEnvTmp.size - - // start from the end again - i = s.pq.Len() - 1 - } - } else { - i-- - } - } - - // enqueue the incoming Envelope - s.push(pqEnv) - } else { - // There is not sufficient capacity to drop lower priority Envelopes, - // so we drop the incoming Envelope. 
- s.metrics.PeerQueueDroppedMsgs.With("ch_id", chIDStr).Add(1) - s.logger.Debug( - "dropped envelope", - "ch_id", chIDStr, - "priority", pqEnv.priority, - "msg_size", pqEnv.size, - "capacity", s.capacity, - ) - } - } - - // dequeue - - for s.pq.Len() > 0 { - pqEnv = heap.Pop(s.pq).(*pqEnvelope) - s.size -= pqEnv.size - - // deduct the Envelope size from all the relevant cumulative sizes - for i := 0; i < len(s.chDescs) && pqEnv.priority <= uint(s.chDescs[i].Priority); i++ { - s.sizes[uint(s.chDescs[i].Priority)] -= pqEnv.size - } - - s.metrics.PeerSendBytesTotal.With( - "chID", chIDStr, - "peer_id", string(pqEnv.envelope.To), - "message_type", s.lc.ValueToMetricLabel(pqEnv.envelope.Message)).Add(float64(pqEnv.size)) - s.metrics.PeerPendingSendBytes.With( - "peer_id", string(pqEnv.envelope.To)).Add(float64(-pqEnv.size)) - select { - case s.dequeueCh <- pqEnv.envelope: - case <-s.closeCh: - return - } - } - case <-ctx.Done(): - return - case <-s.closeCh: - return - } - } -} - -func (s *pqScheduler) push(pqEnv *pqEnvelope) { - // enqueue the incoming Envelope - heap.Push(s.pq, pqEnv) - s.size += pqEnv.size - s.metrics.PeerQueueMsgSize.With("ch_id", strconv.Itoa(int(pqEnv.envelope.ChannelID))).Add(float64(pqEnv.size)) - - // Update the cumulative sizes by adding the Envelope's size to every - // priority less than or equal to it. 
- for i := 0; i < len(s.chDescs) && pqEnv.priority <= uint(s.chDescs[i].Priority); i++ { - s.sizes[uint(s.chDescs[i].Priority)] += pqEnv.size - } -} diff --git a/internal/p2p/pqueue_test.go b/internal/p2p/pqueue_test.go deleted file mode 100644 index 614954589..000000000 --- a/internal/p2p/pqueue_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package p2p - -import ( - "testing" - "time" - - gogotypes "github.com/gogo/protobuf/types" - - "github.com/tendermint/tendermint/libs/log" -) - -type testMessage = gogotypes.StringValue - -func TestCloseWhileDequeueFull(t *testing.T) { - enqueueLength := 5 - chDescs := []*ChannelDescriptor{ - {ID: 0x01, Priority: 1}, - } - pqueue := newPQScheduler(log.NewNopLogger(), NopMetrics(), newMetricsLabelCache(), chDescs, uint(enqueueLength), 1, 120) - - for i := 0; i < enqueueLength; i++ { - pqueue.enqueue() <- Envelope{ - ChannelID: 0x01, - Message: &testMessage{Value: "foo"}, // 5 bytes - } - } - - ctx := t.Context() - - go pqueue.process(ctx) - - // sleep to allow context switch for process() to run - time.Sleep(10 * time.Millisecond) - doneCh := make(chan struct{}) - go func() { - pqueue.close() - close(doneCh) - }() - - select { - case <-doneCh: - case <-time.After(2 * time.Second): - t.Fatal("pqueue failed to close") - } -} diff --git a/internal/p2p/queue.go b/internal/p2p/queue.go deleted file mode 100644 index 2ce2f23fe..000000000 --- a/internal/p2p/queue.go +++ /dev/null @@ -1,53 +0,0 @@ -package p2p - -import ( - "sync" -) - -// default capacity for the size of a queue -const defaultCapacity uint = 16e6 // ~16MB - -// queue does QoS scheduling for Envelopes, enqueueing and dequeueing according -// to some policy. Queues are used at contention points, i.e.: -// -// - Receiving inbound messages to a single channel from all peers. -// - Sending outbound messages to a single peer from all channels. -type queue interface { - // enqueue returns a channel for submitting envelopes. 
- enqueue() chan<- Envelope - - // dequeue returns a channel ordered according to some queueing policy. - dequeue() <-chan Envelope - - // close closes the queue. After this call enqueue() will block, so the - // caller must select on closed() as well to avoid blocking forever. The - // enqueue() and dequeue() channels will not be closed. - close() - - // closed returns a channel that's closed when the scheduler is closed. - closed() <-chan struct{} -} - -// fifoQueue is a simple unbuffered lossless queue that passes messages through -// in the order they were received, and blocks until message is received. -type fifoQueue struct { - queueCh chan Envelope - closeFn func() - closeCh <-chan struct{} -} - -func newFIFOQueue(size int) queue { - closeCh := make(chan struct{}) - once := &sync.Once{} - - return &fifoQueue{ - queueCh: make(chan Envelope, size), - closeFn: func() { once.Do(func() { close(closeCh) }) }, - closeCh: closeCh, - } -} - -func (q *fifoQueue) enqueue() chan<- Envelope { return q.queueCh } -func (q *fifoQueue) dequeue() <-chan Envelope { return q.queueCh } -func (q *fifoQueue) close() { q.closeFn() } -func (q *fifoQueue) closed() <-chan struct{} { return q.closeCh } diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 978f4812e..3f40d4cb6 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -6,9 +6,8 @@ import ( "fmt" "io" "math/rand" - "net" + "net/netip" "runtime" - "strings" "sync" "time" @@ -17,10 +16,12 @@ import ( "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/libs/service" + "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/scope" "github.com/tendermint/tendermint/types" ) -const queueBufferDefault = 32 +const queueBufferDefault = 1024 // RouterOptions specifies options for a Router. type RouterOptions struct { @@ -35,10 +36,6 @@ type RouterOptions struct { // no timeout. 
HandshakeTimeout time.Duration - // QueueType must be, "priority", or "fifo". Defaults to - // "fifo". - QueueType string - // MaxIncomingConnectionAttempts rate limits the number of incoming connection // attempts per IP address. Defaults to 100. MaxIncomingConnectionAttempts uint @@ -53,7 +50,7 @@ type RouterOptions struct { // the remote IP of the incoming connection the port number as // arguments. Functions should return an error to reject the // peer. - FilterPeerByIP func(context.Context, net.IP, uint16) error + FilterPeerByIP func(context.Context, netip.AddrPort) error // FilterPeerByID is used by the router to inject filtering // behavior for new incoming connections. The router passes @@ -67,7 +64,7 @@ type RouterOptions struct { // sleeps between dialing peers. If not set, a default value // is used that sleeps for a (random) amount of time up to 3 // seconds between submitting each peer to be dialed. - DialSleep func(context.Context) + DialSleep func(context.Context) error // NumConcrruentDials controls how many parallel go routines // are used to dial peers. This defaults to the value of @@ -75,23 +72,8 @@ type RouterOptions struct { NumConcurrentDials func() int } -const ( - queueTypeFifo = "fifo" - queueTypePriority = "priority" - queueTypeSimplePriority = "simple-priority" -) - // Validate validates router options. 
func (o *RouterOptions) Validate() error { - switch o.QueueType { - case "": - o.QueueType = queueTypeFifo - case queueTypeFifo, queueTypePriority, queueTypeSimplePriority: - // pass - default: - return fmt.Errorf("queue type %q is not supported", o.QueueType) - } - switch { case o.IncomingConnectionWindow == 0: o.IncomingConnectionWindow = 100 * time.Millisecond @@ -107,6 +89,12 @@ func (o *RouterOptions) Validate() error { return nil } +type peerState struct { + cancel context.CancelFunc + queue *Queue // outbound messages per peer for all channels + channels ChannelIDSet // the channels that the peer queue has open +} + // Router manages peer connections and routes messages between peers and reactor // channels. It takes a PeerManager for peer lifecycle management (e.g. which // peers to dial and when) and a set of Transports for connecting and @@ -158,21 +146,16 @@ type Router struct { peerManager *PeerManager chDescs []*ChannelDescriptor transport Transport - endpoint *Endpoint connTracker connectionTracker - peerMtx sync.RWMutex - peerQueues map[types.NodeID]queue // outbound messages per peer for all channels - // the channels that the peer queue has open - peerChannels map[types.NodeID]ChannelIDSet - queueFactory func(int) queue + peerStates utils.RWMutex[map[types.NodeID]*peerState] nodeInfoProducer func() *types.NodeInfo // FIXME: We don't strictly need to use a mutex for this if we seal the // channels on router start. This depends on whether we want to allow // dynamic channels in the future. 
channelMtx sync.RWMutex - channelQueues map[ChannelID]queue // inbound messages from all peers to a single channel + channelQueues map[ChannelID]*Queue // inbound messages from all peers to a single channel channelMessages map[ChannelID]proto.Message chDescsToBeAdded []chDescAdderWithCallback @@ -195,7 +178,6 @@ func NewRouter( peerManager *PeerManager, nodeInfoProducer func() *types.NodeInfo, transport Transport, - endpoint *Endpoint, dynamicIDFilterer func(context.Context, types.NodeID) error, options RouterOptions, ) (*Router, error) { @@ -216,13 +198,11 @@ func NewRouter( ), chDescs: make([]*ChannelDescriptor, 0), transport: transport, - endpoint: endpoint, peerManager: peerManager, options: options, - channelQueues: map[ChannelID]queue{}, + channelQueues: map[ChannelID]*Queue{}, channelMessages: map[ChannelID]proto.Message{}, - peerQueues: map[types.NodeID]queue{}, - peerChannels: make(map[types.NodeID]ChannelIDSet), + peerStates: utils.NewRWMutex(map[types.NodeID]*peerState{}), dynamicIDFilterer: dynamicIDFilterer, } @@ -231,42 +211,13 @@ func NewRouter( return router, nil } -func (r *Router) createQueueFactory(ctx context.Context) (func(int) queue, error) { - switch r.options.QueueType { - case queueTypeFifo: - return newFIFOQueue, nil - - case queueTypePriority: - return func(size int) queue { - if size%2 != 0 { - size++ - } - - q := newPQScheduler(r.logger, r.metrics, r.lc, r.chDescs, uint(size)/2, uint(size)/2, defaultCapacity) - q.start(ctx) - return q - }, nil - - case queueTypeSimplePriority: - return func(size int) queue { return newSimplePriorityQueue(ctx, size, r.chDescs) }, nil - - default: - return nil, fmt.Errorf("cannot construct queue of type %q", r.options.QueueType) - } -} - // ChannelCreator allows routers to construct their own channels, // either by receiving a reference to Router.OpenChannel or using some // kind shim for testing purposes. 
type ChannelCreator func(context.Context, *ChannelDescriptor) (*Channel, error) -// OpenChannel opens a new channel for the given message type. The caller must -// close the channel when done, before stopping the Router. messageType is the -// type of message passed through the channel (used for unmarshaling), which can -// implement Wrapper to automatically (un)wrap multiple message types in a -// wrapper message. The caller may provide a size to make the channel buffered, -// which internally makes the inbound, outbound, and error channel buffered. -func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*Channel, error) { +// OpenChannel opens a new channel for the given message type. +func (r *Router) OpenChannel(chDesc *ChannelDescriptor) (*Channel, error) { r.channelMtx.Lock() defer r.channelMtx.Unlock() @@ -278,10 +229,12 @@ func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*C messageType := chDesc.MessageType - queue := r.queueFactory(chDesc.RecvBufferCapacity) + // TODO(gprusak): get rid of this random cap*cap value once we understand + // what the sizes per channel really should be. 
+ queue := NewQueue(chDesc.RecvBufferCapacity * chDesc.RecvBufferCapacity) outCh := make(chan Envelope, chDesc.RecvBufferCapacity) errCh := make(chan PeerError, chDesc.RecvBufferCapacity) - channel := NewChannel(id, queue.dequeue(), outCh, errCh) + channel := NewChannel(id, queue, outCh, errCh) channel.name = chDesc.Name var wrapper Wrapper @@ -297,18 +250,31 @@ func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*C r.transport.AddChannelDescriptors([]*ChannelDescriptor{chDesc}) - go func() { - defer func() { - r.channelMtx.Lock() - delete(r.channelQueues, id) - delete(r.channelMessages, id) - r.channelMtx.Unlock() - queue.close() - }() - - r.routeChannel(ctx, id, outCh, errCh, wrapper) - }() - + r.Spawn("channel", func(ctx context.Context) error { + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.Spawn(func() error { return r.routeChannel(ctx, chDesc, outCh, wrapper) }) + for { + peerError, err := utils.Recv(ctx, errCh) + if err != nil { + return err + } + shouldEvict := peerError.Fatal || r.peerManager.HasMaxPeerCapacity() + r.logger.Error("peer error", + "peer", peerError.NodeID, + "err", peerError.Err, + "evicting", shouldEvict, + ) + if shouldEvict { + r.peerManager.Errored(peerError.NodeID, peerError.Err) + } else { + r.peerManager.processPeerEvent(ctx, PeerUpdate{ + NodeID: peerError.NodeID, + Status: PeerStatusBad, + }) + } + } + }) + }) return channel, nil } @@ -319,119 +285,69 @@ func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*C // for messages, see Wrapper for details. func (r *Router) routeChannel( ctx context.Context, - chID ChannelID, + chDesc *ChannelDescriptor, outCh <-chan Envelope, - errCh <-chan PeerError, wrapper Wrapper, -) { +) error { for { - select { - case envelope, ok := <-outCh: - if !ok { - return - } - if envelope.IsZero() { - continue - } - - // Mark the envelope with the channel ID to allow sendPeer() to pass - // it on to Transport.SendMessage(). 
- envelope.ChannelID = chID + envelope, err := utils.Recv(ctx, outCh) + if err != nil { + return err + } + if envelope.IsZero() { + continue + } - // wrap the message in a wrapper message, if requested - if wrapper != nil { - msg := proto.Clone(wrapper) - if err := msg.(Wrapper).Wrap(envelope.Message); err != nil { - r.logger.Error("failed to wrap message", "channel", chID, "err", err) - continue - } + // Mark the envelope with the channel ID to allow sendPeer() to pass + // it on to Transport.SendMessage(). + envelope.ChannelID = chDesc.ID - envelope.Message = msg + // wrap the message in a wrapper message, if requested + if wrapper != nil { + msg := utils.ProtoClone(wrapper) + if err := msg.Wrap(envelope.Message); err != nil { + r.logger.Error("failed to wrap message", "channel", chDesc.ID, "err", err) + continue } - // collect peer queues to pass the message via - var queues []queue - if envelope.Broadcast { - r.peerMtx.RLock() - - queues = make([]queue, 0, len(r.peerQueues)) - for nodeID, q := range r.peerQueues { - peerChs := r.peerChannels[nodeID] + envelope.Message = msg + } - // check whether the peer is receiving on that channel - if _, ok := peerChs[chID]; ok { - queues = append(queues, q) + // collect peer queues to pass the message via + var queues []*Queue + if envelope.Broadcast { + for states := range r.peerStates.RLock() { + queues = make([]*Queue, 0, len(states)) + for _, s := range states { + if _, ok := s.channels[chDesc.ID]; ok { + queues = append(queues, s.queue) } } - - r.peerMtx.RUnlock() - } else { - r.peerMtx.RLock() - - q, ok := r.peerQueues[envelope.To] - contains := false - if ok { - peerChs := r.peerChannels[envelope.To] - - // check whether the peer is receiving on that channel - _, contains = peerChs[chID] - } - r.peerMtx.RUnlock() - - if !ok { - r.logger.Debug("dropping message for unconnected peer", "peer", envelope.To, "channel", chID) - continue - } - - if !contains { - // reactor tried to send a message across a channel that the 
- // peer doesn't have available. This is a known issue due to - // how peer subscriptions work: - // https://github.com/tendermint/tendermint/issues/6598 - continue - } - - queues = []queue{q} } - - // send message to peers - for _, q := range queues { - start := time.Now().UTC() - - select { - case q.enqueue() <- envelope: - r.metrics.RouterPeerQueueSend.Observe(time.Since(start).Seconds()) - - case <-q.closed(): - r.logger.Debug("dropping message for unconnected peer", "peer", envelope.To, "channel", chID) - - case <-ctx.Done(): - return - } + } else { + ok := false + var s *peerState + for states := range r.peerStates.RLock() { + s, ok = states[envelope.To] } - - case peerError, ok := <-errCh: if !ok { - return + r.logger.Debug("dropping message for unconnected peer", "peer", envelope.To, "channel", chDesc.ID) + continue } - - shouldEvict := peerError.Fatal || r.peerManager.HasMaxPeerCapacity() - r.logger.Error("peer error", - "peer", peerError.NodeID, - "err", peerError.Err, - "evicting", shouldEvict, - ) - if shouldEvict { - r.peerManager.Errored(peerError.NodeID, peerError.Err) - } else { - r.peerManager.processPeerEvent(ctx, PeerUpdate{ - NodeID: peerError.NodeID, - Status: PeerStatusBad, - }) + if _, contains := s.channels[chDesc.ID]; !contains { + // reactor tried to send a message across a channel that the + // peer doesn't have available. 
This is a known issue due to + // how peer subscriptions work: + // https://github.com/tendermint/tendermint/issues/6598 + continue + } + queues = []*Queue{s.queue} + } + // send message to peers + for _, q := range queues { + if pruned, ok := q.Send(envelope, chDesc.Priority).Get(); ok { + r.metrics.QueueDroppedMsgs.With("ch_id", fmt.Sprint(pruned.ChannelID), "direction", "out").Add(float64(1)) } - - case <-ctx.Done(): - return } } } @@ -444,12 +360,12 @@ func (r *Router) numConccurentDials() int { return r.options.NumConcurrentDials() } -func (r *Router) filterPeersIP(ctx context.Context, ip net.IP, port uint16) error { +func (r *Router) filterPeersIP(ctx context.Context, addrPort netip.AddrPort) error { if r.options.FilterPeerByIP == nil { return nil } - return r.options.FilterPeerByIP(ctx, ip, port) + return r.options.FilterPeerByIP(ctx, addrPort) } func (r *Router) filterPeersID(ctx context.Context, id types.NodeID) error { @@ -467,54 +383,35 @@ func (r *Router) filterPeersID(ctx context.Context, id types.NodeID) error { return r.options.FilterPeerByID(ctx, id) } -func (r *Router) dialSleep(ctx context.Context) { - if r.options.DialSleep == nil { - const ( - maxDialerInterval = 3000 - minDialerInterval = 250 - ) - - // nolint:gosec // G404: Use of weak random number generator - dur := time.Duration(rand.Int63n(maxDialerInterval-minDialerInterval+1) + minDialerInterval) - - timer := time.NewTimer(dur * time.Millisecond) - defer timer.Stop() - - select { - case <-ctx.Done(): - case <-timer.C: - } - - return +func (r *Router) dialSleep(ctx context.Context) error { + if r.options.DialSleep != nil { + return r.options.DialSleep(ctx) } + const ( + maxDialerInterval = 3000 + minDialerInterval = 250 + ) - r.options.DialSleep(ctx) + // nolint:gosec // G404: Use of weak random number generator + dur := time.Duration(rand.Int63n(maxDialerInterval-minDialerInterval+1) + minDialerInterval) + return utils.Sleep(ctx, dur*time.Millisecond) } // acceptPeers accepts inbound 
connections from peers on the given transport, // and spawns goroutines that route messages to/from them. -func (r *Router) acceptPeers(ctx context.Context, transport Transport) { +func (r *Router) acceptPeers(ctx context.Context, transport Transport) error { for { conn, err := transport.Accept(ctx) - switch { - case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): - r.logger.Debug("stopping accept routine", "transport", transport, "err", "context canceled") - return - case errors.Is(err, io.EOF): - r.logger.Debug("stopping accept routine", "transport", transport, "err", "EOF") - return - case err != nil: - // in this case we got an error from the net.Listener. - r.logger.Error("failed to accept connection", "transport", transport, "err", err) - continue + if err != nil { + return fmt.Errorf("failed to accept connection: %w", err) } - - incomingIP := conn.RemoteEndpoint().IP - if err := r.connTracker.AddConn(incomingIP); err != nil { + r.metrics.NewConnections.With("direction", "in").Add(1) + incomingAddr := conn.RemoteEndpoint().Addr + if err := r.connTracker.AddConn(incomingAddr); err != nil { closeErr := conn.Close() r.logger.Error("rate limiting incoming peer", "err", err, - "ip", incomingIP.String(), + "addr", incomingAddr.String(), "close_err", closeErr, ) @@ -522,21 +419,18 @@ func (r *Router) acceptPeers(ctx context.Context, transport Transport) { } // Spawn a goroutine for the handshake, to avoid head-of-line blocking. 
- go r.openConnection(ctx, conn) - + r.Spawn("openConnection", func(ctx context.Context) error { return r.openConnection(ctx, conn) }) } } -func (r *Router) openConnection(ctx context.Context, conn Connection) { +func (r *Router) openConnection(ctx context.Context, conn Connection) error { defer conn.Close() - defer r.connTracker.RemoveConn(conn.RemoteEndpoint().IP) - - re := conn.RemoteEndpoint() - incomingIP := re.IP + incomingAddr := conn.RemoteEndpoint().Addr + defer r.connTracker.RemoveConn(incomingAddr) - if err := r.filterPeersIP(ctx, incomingIP, re.Port); err != nil { - r.logger.Debug("peer filtered by IP", "ip", incomingIP.String(), "err", err) - return + if err := r.filterPeersIP(ctx, incomingAddr); err != nil { + r.logger.Debug("peer filtered by IP", "ip", incomingAddr, "err", err) + return nil } // FIXME: The peer manager may reject the peer during Accepted() @@ -555,85 +449,58 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) { // message to make sure both ends have accepted the connection, such // that it can be coordinated with the peer manager. 
peerInfo, err := r.handshakePeer(ctx, conn, "") - switch { - case errors.Is(err, context.Canceled): - return - case err != nil: - r.logger.Error("peer handshake failed", "endpoint", conn, "err", err) - return + if err != nil { + return fmt.Errorf("peer handshake failed: endpoint=%v: %w", conn, err) } if err := r.filterPeersID(ctx, peerInfo.NodeID); err != nil { r.logger.Debug("peer filtered by node ID", "node", peerInfo.NodeID, "err", err) - return + return nil } - - if err := r.runWithPeerMutex(func() error { return r.peerManager.Accepted(peerInfo.NodeID) }); err != nil { - // If peer is trying to reconnect, error and let it reconnect - if strings.Contains(err.Error(), "is already connected") { - r.peerManager.Errored(peerInfo.NodeID, err) - } - r.logger.Error("failed to accept connection", - "op", "incoming/accepted", "peer", peerInfo.NodeID, "err", err) - return + if err := r.peerManager.Accepted(peerInfo.NodeID); err != nil { + return fmt.Errorf("failed to accept connection: op=incoming/accepted, peer=%v: %w", peerInfo.NodeID, err) } - - r.routePeer(ctx, peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels)) + return r.routePeer(ctx, peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels)) } // dialPeers maintains outbound connections to peers by dialing them. -func (r *Router) dialPeers(ctx context.Context) { - addresses := make(chan NodeAddress) - wg := &sync.WaitGroup{} - - // Start a limited number of goroutines to dial peers in - // parallel. the goal is to avoid starting an unbounded number - // of goroutines thereby spamming the network, but also being - // able to add peers at a reasonable pace, though the number - // is somewhat arbitrary. The action is further throttled by a - // sleep after sending to the addresses channel. 
- for i := 0; i < r.numConccurentDials(); i++ { - wg.Add(1) - go func() { - defer wg.Done() - - for { - select { - case <-ctx.Done(): - return - case address := <-addresses: +func (r *Router) dialPeers(ctx context.Context) error { + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + addresses := make(chan NodeAddress) + // Start a limited number of goroutines to dial peers in + // parallel. the goal is to avoid starting an unbounded number + // of goroutines thereby spamming the network, but also being + // able to add peers at a reasonable pace, though the number + // is somewhat arbitrary. The action is further throttled by a + // sleep after sending to the addresses channel. + for range r.numConccurentDials() { + s.Spawn(func() error { + for { + address, err := utils.Recv(ctx, addresses) + if err != nil { + return err + } r.logger.Debug(fmt.Sprintf("Going to dial next peer %s", address.NodeID)) r.connectPeer(ctx, address) } - } - }() - } - -LOOP: - for { - address, err := r.peerManager.DialNext(ctx) - switch { - case errors.Is(err, context.Canceled): - break LOOP - case err != nil: - r.logger.Error("failed to find next peer to dial", "err", err) - break LOOP + }) } - select { - case addresses <- address: + for { + address, err := r.peerManager.DialNext(ctx) + if err != nil { + return fmt.Errorf("failed to find next peer to dial: %w", err) + } + if err := utils.Send(ctx, addresses, address); err != nil { + return err + } // this jitters the frequency that we call // DialNext and prevents us from attempting to // create connections too quickly. 
- - r.dialSleep(ctx) - continue - case <-ctx.Done(): - close(addresses) - break LOOP + if err := r.dialSleep(ctx); err != nil { + return err + } } - } - - wg.Wait() + }) } func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { @@ -656,42 +523,28 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { return case err != nil: r.logger.Debug("failed to handshake with peer", "peer", address, "err", err) - if err = r.peerManager.DialFailed(ctx, address); err != nil { + if err := r.peerManager.DialFailed(ctx, address); err != nil { r.logger.Error("failed to report dial failure", "peer", address, "err", err) } conn.Close() return } - if err := r.runWithPeerMutex(func() error { return r.peerManager.Dialed(address) }); err != nil { - // If peer is trying to reconnect, fail it and let it reconnect - if strings.Contains(err.Error(), "is already connected") { - r.logger.Error(fmt.Sprintf("Disconnecting %s because of %s", address.NodeID, err)) - r.peerManager.Disconnected(ctx, address.NodeID) - } - - r.logger.Debug("failed to dial peer", - "op", "outgoing/dialing", "peer", address.NodeID, "err", err) + // TODO(gprusak): this symmetric logic for handling duplicate connections is a source of race conditions: + // if 2 nodes try to establish a connection to each other at the same time, both connections will be dropped. 
+ // Instead either: + // * break the symmetry by favoring incoming connection iff my.NodeID > peer.NodeID + // * keep incoming and outcoming connection pools separate to avoid the collision (recommended) + if err := r.peerManager.Dialed(address); err != nil { + r.logger.Info("failed to dial peer", "op", "outgoing/dialing", "peer", address.NodeID, "err", err) conn.Close() return } - // routePeer (also) calls connection close - go r.routePeer(ctx, address.NodeID, conn, toChannelIDs(peerInfo.Channels)) -} - -func (r *Router) getOrMakeQueue(peerID types.NodeID, channels ChannelIDSet) queue { - r.peerMtx.Lock() - defer r.peerMtx.Unlock() - - if peerQueue, ok := r.peerQueues[peerID]; ok { - return peerQueue - } - - peerQueue := r.queueFactory(queueBufferDefault) - r.peerQueues[peerID] = peerQueue - r.peerChannels[peerID] = channels - return peerQueue + r.Spawn("routePeer", func(ctx context.Context) error { + defer conn.Close() + return r.routePeer(ctx, address.NodeID, conn, toChannelIDs(peerInfo.Channels)) + }) } // dialPeer connects to a peer by dialing it. @@ -736,6 +589,7 @@ func (r *Router) dialPeer(ctx context.Context, address NodeAddress) (Connection, if err != nil { r.logger.Debug("failed to dial endpoint", "peer", address.NodeID, "endpoint", endpoint, "err", err) } else { + r.metrics.NewConnections.With("direction", "out").Add(1) r.logger.Debug("dialed peer", "peer", address.NodeID, "endpoint", endpoint) return conn, nil } @@ -760,31 +614,32 @@ func (r *Router) handshakePeer( nodeInfo := r.nodeInfoProducer() peerInfo, peerKey, err := conn.Handshake(ctx, *nodeInfo, r.privKey) if err != nil { - return peerInfo, err + return types.NodeInfo{}, err } - if err = peerInfo.Validate(); err != nil { - return peerInfo, fmt.Errorf("invalid handshake NodeInfo: %w", err) + // Authenticate the peer first. 
+ if types.NodeIDFromPubKey(peerKey) != peerInfo.NodeID { + return types.NodeInfo{}, fmt.Errorf("peer's public key did not match its node ID %q (expected %q)", + peerInfo.NodeID, types.NodeIDFromPubKey(peerKey)) } + if err = peerInfo.Validate(); err != nil { + return types.NodeInfo{}, fmt.Errorf("invalid handshake NodeInfo: %w", err) + } if peerInfo.Network != nodeInfo.Network { - if err := r.peerManager.store.Delete(peerInfo.NodeID); err != nil { - return peerInfo, fmt.Errorf("problem removing peer from store from incorrect network [%s]: %w", peerInfo.Network, err) + if err := r.peerManager.Delete(peerInfo.NodeID); err != nil { + return types.NodeInfo{}, fmt.Errorf("problem removing peer from store from incorrect network [%s]: %w", peerInfo.Network, err) } - return peerInfo, fmt.Errorf("connected to peer from wrong network, %q, removed from peer store", peerInfo.Network) + return types.NodeInfo{}, fmt.Errorf("connected to peer from wrong network, %q, removed from peer store", peerInfo.Network) } - if types.NodeIDFromPubKey(peerKey) != peerInfo.NodeID { - return peerInfo, fmt.Errorf("peer's public key did not match its node ID %q (expected %q)", - peerInfo.NodeID, types.NodeIDFromPubKey(peerKey)) - } if expectID != "" && expectID != peerInfo.NodeID { - return peerInfo, fmt.Errorf("expected to connect with peer %q, got %q", + return types.NodeInfo{}, fmt.Errorf("expected to connect with peer %q, got %q", expectID, peerInfo.NodeID) } if err := nodeInfo.CompatibleWith(peerInfo); err != nil { - return peerInfo, ErrRejected{ + return types.NodeInfo{}, ErrRejected{ err: err, id: peerInfo.ID(), isIncompatible: true, @@ -793,80 +648,48 @@ func (r *Router) handshakePeer( return peerInfo, nil } -func (r *Router) runWithPeerMutex(fn func() error) error { - r.peerMtx.Lock() - defer r.peerMtx.Unlock() - return fn() -} - // routePeer routes inbound and outbound messages between a peer and the reactor // channels. 
It will close the given connection and send queue when done, or if // they are closed elsewhere it will cause this method to shut down and return. -func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connection, channels ChannelIDSet) { +func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connection, channels ChannelIDSet) error { r.metrics.Peers.Add(1) r.peerManager.Ready(ctx, peerID, channels) - - sendQueue := r.getOrMakeQueue(peerID, channels) - defer func() { - r.peerMtx.Lock() - delete(r.peerQueues, peerID) - delete(r.peerChannels, peerID) - r.peerMtx.Unlock() - - sendQueue.close() - - r.peerManager.Disconnected(ctx, peerID) - r.metrics.Peers.Add(-1) - }() - - r.logger.Debug("peer connected", "peer", peerID, "endpoint", conn) - - errCh := make(chan error, 2) - - go func() { - select { - case errCh <- r.receivePeer(ctx, peerID, conn): - case <-ctx.Done(): - } - }() - - go func() { - select { - case errCh <- r.sendPeer(ctx, peerID, conn, sendQueue): - case <-ctx.Done(): + peerCtx, cancel := context.WithCancel(ctx) + state := &peerState{ + cancel: cancel, + queue: NewQueue(queueBufferDefault), + channels: channels, + } + for states := range r.peerStates.Lock() { + if old, ok := states[peerID]; ok { + old.cancel() } - }() - - var err error - select { - case err = <-errCh: - case <-ctx.Done(): + states[peerID] = state } - - _ = conn.Close() - sendQueue.close() - - select { - case <-ctx.Done(): - case e := <-errCh: - // The first err was nil, so we update it with the second err, which may - // or may not be nil. 
- if err == nil { - err = e + r.logger.Debug("peer connected", "peer", peerID, "endpoint", conn) + err := scope.Run(peerCtx, func(ctx context.Context, s scope.Scope) error { + s.Spawn(func() error { return r.receivePeer(ctx, peerID, conn) }) + s.Spawn(func() error { return r.sendPeer(ctx, peerID, conn, state.queue) }) + <-ctx.Done() + // TODO(gprusak): we need to close the connection here, because + // the mock connection used in tests does not respect the context. + // Get rid of these stupid mocks. + _ = conn.Close() + return nil + }) + r.logger.Info("peer disconnected", "peer", peerID, "endpoint", conn, "err", err) + for states := range r.peerStates.Lock() { + if states[peerID] == state { + delete(states, peerID) } } - - // if the context was canceled - if e := ctx.Err(); err == nil && e != nil { - err = e - } - - switch err { - case nil, io.EOF: - r.logger.Debug("peer disconnected", "peer", peerID, "endpoint", conn) - default: - r.logger.Error("peer failure", "peer", peerID, "endpoint", conn, "err", err) + // TODO(gprusak): investigate if peerManager handles overlapping connetions correctly + r.peerManager.Disconnected(ctx, peerID) + r.metrics.Peers.Add(-1) + if errors.Is(err, io.EOF) { + return nil } + return err } // receivePeer receives inbound messages from a peer, deserializes them and @@ -884,114 +707,80 @@ func (r *Router) receivePeer(ctx context.Context, peerID types.NodeID, conn Conn r.channelMtx.RUnlock() if !ok { + // TODO(gprusak): verify if this is a misbehavior, and drop the peer if it is. 
r.logger.Debug("dropping message for unknown channel", "peer", peerID, "channel", chID) continue } msg := proto.Clone(messageType) if err := proto.Unmarshal(bz, msg); err != nil { - r.logger.Error("message decoding failed, dropping message", "peer", peerID, "err", err) - continue + return fmt.Errorf("message decoding failed, dropping message: [peer=%v] %w", peerID, err) } if wrapper, ok := msg.(Wrapper); ok { msg, err = wrapper.Unwrap() if err != nil { - r.logger.Error("failed to unwrap message", "err", err) - continue + return fmt.Errorf("failed to unwrap message: %w", err) } } - start := time.Now().UTC() - - select { - case queue.enqueue() <- Envelope{From: peerID, Message: msg, ChannelID: chID}: - r.metrics.PeerReceiveBytesTotal.With( - "chID", fmt.Sprint(chID), - "peer_id", string(peerID), - "message_type", r.lc.ValueToMetricLabel(msg)).Add(float64(proto.Size(msg))) - r.metrics.RouterChannelQueueSend.Observe(time.Since(start).Seconds()) - r.logger.Debug("received message", "peer", peerID, "message", msg) - - case <-queue.closed(): - r.logger.Debug("channel closed, dropping message", "peer", peerID, "channel", chID) - - case <-ctx.Done(): - return nil + // Priority is not used since all messages in this queue are from the same channel. + if pruned, ok := queue.Send(Envelope{From: peerID, Message: msg, ChannelID: chID}, 0).Get(); ok { + r.metrics.QueueDroppedMsgs.With("ch_id", fmt.Sprint(pruned.ChannelID), "direction", "in").Add(float64(1)) } + r.metrics.PeerReceiveBytesTotal.With( + "chID", fmt.Sprint(chID), + "peer_id", string(peerID), + "message_type", r.lc.ValueToMetricLabel(msg)).Add(float64(proto.Size(msg))) + r.logger.Debug("received message", "peer", peerID, "message", msg) } } // sendPeer sends queued messages to a peer. 
-func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connection, peerQueue queue) error { +func (r *Router) sendPeer(ctx context.Context, peerID types.NodeID, conn Connection, peerQueue *Queue) error { for { start := time.Now().UTC() + envelope, err := peerQueue.Recv(ctx) + if err != nil { + return err + } + r.metrics.RouterPeerQueueRecv.Observe(time.Since(start).Seconds()) + if envelope.Message == nil { + r.logger.Error("dropping nil message", "peer", peerID) + continue + } - select { - case envelope := <-peerQueue.dequeue(): - r.metrics.RouterPeerQueueRecv.Observe(time.Since(start).Seconds()) - if envelope.Message == nil { - r.logger.Error("dropping nil message", "peer", peerID) - continue - } - - bz, err := proto.Marshal(envelope.Message) - if err != nil { - r.logger.Error("failed to marshal message", "peer", peerID, "err", err) - continue - } - - if err = conn.SendMessage(ctx, envelope.ChannelID, bz); err != nil { - r.logger.Error("failed to send message", "peer", peerID, "err", err) - return err - } - - r.logger.Debug("sent message", "peer", envelope.To, "message", envelope.Message) - - case <-peerQueue.closed(): - return nil + bz, err := proto.Marshal(envelope.Message) + if err != nil { + r.logger.Error("failed to marshal message", "peer", peerID, "err", err) + continue + } - case <-ctx.Done(): - return nil + if err = conn.SendMessage(ctx, envelope.ChannelID, bz); err != nil { + r.logger.Error("failed to send message", "peer", peerID, "err", err) + return err } + + r.logger.Debug("sent message", "peer", envelope.To, "message", envelope.Message) } } // evictPeers evicts connected peers as requested by the peer manager. 
-func (r *Router) evictPeers(ctx context.Context) { +func (r *Router) evictPeers(ctx context.Context) error { for { - peerID, err := r.peerManager.EvictNext(ctx) - - switch { - case errors.Is(err, context.Canceled): - return - case err != nil: - r.logger.Error("failed to find next peer to evict", "err", err) - return + ev, err := r.peerManager.EvictNext(ctx) + if err != nil { + return fmt.Errorf("failed to find next peer to evict: %w", err) } - - r.logger.Info("evicting peer", "peer", peerID) - - r.peerMtx.RLock() - queue, ok := r.peerQueues[peerID] - r.peerMtx.RUnlock() - - if ok { - queue.close() + for states := range r.peerStates.Lock() { + if s, ok := states[ev.ID]; ok { + r.logger.Info("evicting peer", "peer", ev.ID, "cause", ev.Cause) + s.cancel() + } } } } -func (r *Router) setupQueueFactory(ctx context.Context) error { - qf, err := r.createQueueFactory(ctx) - if err != nil { - return err - } - - r.queueFactory = qf - return nil -} - func (r *Router) AddChDescToBeAdded(chDesc *ChannelDescriptor, callback func(*Channel)) { r.chDescsToBeAdded = append(r.chDescsToBeAdded, chDescAdderWithCallback{ chDesc: chDesc, @@ -1001,26 +790,20 @@ func (r *Router) AddChDescToBeAdded(chDesc *ChannelDescriptor, callback func(*Ch // OnStart implements service.Service. 
func (r *Router) OnStart(ctx context.Context) error { - if err := r.setupQueueFactory(ctx); err != nil { - return err - } - - if err := r.transport.Listen(r.endpoint); err != nil { - return err - } - for _, chDescWithCb := range r.chDescsToBeAdded { - if ch, err := r.OpenChannel(ctx, chDescWithCb.chDesc); err != nil { + if ch, err := r.OpenChannel(chDescWithCb.chDesc); err != nil { return err } else { chDescWithCb.cb(ch) } } - go r.dialPeers(ctx) - go r.evictPeers(ctx) - go r.acceptPeers(ctx, r.transport) - + r.SpawnCritical("transport.Run", func(ctx context.Context) error { + return r.transport.Run(ctx) + }) + r.SpawnCritical("dialPeers", func(ctx context.Context) error { return r.dialPeers(ctx) }) + r.SpawnCritical("evictPeers", func(ctx context.Context) error { return r.evictPeers(ctx) }) + r.SpawnCritical("acceptPeers", func(ctx context.Context) error { return r.acceptPeers(ctx, r.transport) }) return nil } @@ -1030,32 +813,7 @@ func (r *Router) OnStart(ctx context.Context) error { // router, to prevent blocked channel sends in reactors. Channels are not closed // here, since that would cause any reactor senders to panic, so it is the // sender's responsibility. -func (r *Router) OnStop() { - // Close transport listeners (unblocks Accept calls). - if err := r.transport.Close(); err != nil { - r.logger.Error("failed to close transport", "err", err) - } - - // Collect all remaining queues, and wait for them to close. 
- queues := []queue{} - - r.channelMtx.RLock() - for _, q := range r.channelQueues { - queues = append(queues, q) - } - r.channelMtx.RUnlock() - - r.peerMtx.RLock() - for _, q := range r.peerQueues { - queues = append(queues, q) - } - r.peerMtx.RUnlock() - - for _, q := range queues { - q.close() - <-q.closed() - } -} +func (r *Router) OnStop() {} type ChannelIDSet map[ChannelID]struct{} diff --git a/internal/p2p/router_filter_test.go b/internal/p2p/router_filter_test.go index 5b1d7219a..afd9879bd 100644 --- a/internal/p2p/router_filter_test.go +++ b/internal/p2p/router_filter_test.go @@ -3,7 +3,7 @@ package p2p import ( "context" "errors" - "net" + "net/netip" "testing" "time" @@ -21,7 +21,7 @@ func TestConnectionFiltering(t *testing.T) { logger: logger, connTracker: newConnTracker(1, time.Second), options: RouterOptions{ - FilterPeerByIP: func(ctx context.Context, ip net.IP, port uint16) error { + FilterPeerByIP: func(ctx context.Context, addr netip.AddrPort) error { filterByIPCount++ return errors.New("mock") }, diff --git a/internal/p2p/router_init_test.go b/internal/p2p/router_init_test.go index 31d06338f..f2750b153 100644 --- a/internal/p2p/router_init_test.go +++ b/internal/p2p/router_init_test.go @@ -1,64 +1,14 @@ package p2p import ( - "os" "testing" "github.com/stretchr/testify/require" - - "github.com/tendermint/tendermint/libs/log" - "github.com/tendermint/tendermint/types" ) func TestRouter_ConstructQueueFactory(t *testing.T) { - ctx := t.Context() - t.Run("ValidateOptionsPopulatesDefaultQueue", func(t *testing.T) { opts := RouterOptions{} require.NoError(t, opts.Validate()) - require.Equal(t, "fifo", opts.QueueType) - }) - t.Run("Default", func(t *testing.T) { - require.Zero(t, os.Getenv("TM_P2P_QUEUE")) - opts := RouterOptions{} - r, err := NewRouter(log.NewNopLogger(), nil, nil, nil, func() *types.NodeInfo { return &types.NodeInfo{} }, nil, nil, nil, opts) - require.NoError(t, err) - require.NoError(t, r.setupQueueFactory(ctx)) - - _, ok := 
r.queueFactory(1).(*fifoQueue) - require.True(t, ok) - }) - t.Run("Fifo", func(t *testing.T) { - opts := RouterOptions{QueueType: queueTypeFifo} - r, err := NewRouter(log.NewNopLogger(), nil, nil, nil, func() *types.NodeInfo { return &types.NodeInfo{} }, nil, nil, nil, opts) - require.NoError(t, err) - require.NoError(t, r.setupQueueFactory(ctx)) - - _, ok := r.queueFactory(1).(*fifoQueue) - require.True(t, ok) - }) - t.Run("Priority", func(t *testing.T) { - opts := RouterOptions{QueueType: queueTypePriority} - r, err := NewRouter(log.NewNopLogger(), nil, nil, nil, func() *types.NodeInfo { return &types.NodeInfo{} }, nil, nil, nil, opts) - require.NoError(t, err) - require.NoError(t, r.setupQueueFactory(ctx)) - - q, ok := r.queueFactory(1).(*pqScheduler) - require.True(t, ok) - defer q.close() - }) - t.Run("NonExistant", func(t *testing.T) { - opts := RouterOptions{QueueType: "fast"} - _, err := NewRouter(log.NewNopLogger(), nil, nil, nil, func() *types.NodeInfo { return &types.NodeInfo{} }, nil, nil, nil, opts) - require.Error(t, err) - require.Contains(t, err.Error(), "fast") - }) - t.Run("InternalsSafeWhenUnspecified", func(t *testing.T) { - r := &Router{} - require.Zero(t, r.options.QueueType) - - fn, err := r.createQueueFactory(ctx) - require.Error(t, err) - require.Nil(t, fn) }) } diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 0172bb114..36fe90254 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + slog "log" "runtime" "strings" "sync" @@ -15,7 +16,6 @@ import ( "github.com/gogo/protobuf/proto" gogotypes "github.com/gogo/protobuf/types" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" dbm "github.com/tendermint/tm-db" "github.com/tendermint/tendermint/crypto" @@ -23,6 +23,7 @@ import ( "github.com/tendermint/tendermint/internal/p2p/mocks" "github.com/tendermint/tendermint/internal/p2p/p2ptest" 
"github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/libs/utils/require" "github.com/tendermint/tendermint/types" ) @@ -31,6 +32,7 @@ func echoReactor(ctx context.Context, channel *p2p.Channel) { for iter.Next(ctx) { envelope := iter.Envelope() value := envelope.Message.(*p2ptest.Message).Value + slog.Printf("sending back %v", value) if err := channel.Send(ctx, p2p.Envelope{ To: envelope.From, Message: &p2ptest.Message{Value: value}, @@ -38,6 +40,7 @@ func echoReactor(ctx context.Context, channel *p2p.Channel) { return } } + slog.Printf("echoReactor done") } func TestRouter_Network(t *testing.T) { @@ -45,11 +48,11 @@ func TestRouter_Network(t *testing.T) { t.Cleanup(leaktest.Check(t)) - // Create a test network and open a channel where all peers run echoReactor. + t.Logf("Create a test network and open a channel where all peers run echoReactor.") network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 8}) local := network.RandomNode() peers := network.Peers(local.NodeID) - channels := network.MakeChannels(ctx, t, chDesc) + channels := network.MakeChannels(t, chDesc) network.Start(ctx, t) @@ -58,16 +61,15 @@ func TestRouter_Network(t *testing.T) { go echoReactor(ctx, channels[peer.NodeID]) } - // Sending a message to each peer should work. + t.Logf("Sending a message to each peer should work.") for _, peer := range peers { - p2ptest.RequireSendReceive(ctx, t, channel, peer.NodeID, - &p2ptest.Message{Value: "foo"}, - &p2ptest.Message{Value: "foo"}, - ) + msg := &p2ptest.Message{Value: "foo"} + p2ptest.RequireSend(t, channel, p2p.Envelope{To: peer.NodeID, Message: msg, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, channel, p2p.Envelope{From: peer.NodeID, Message: msg, ChannelID: chDesc.ID}) } - // Sending a broadcast should return back a message from all peers. 
- p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{ + t.Logf("Sending a broadcast should return back a message from all peers.") + p2ptest.RequireSend(t, channel, p2p.Envelope{ Broadcast: true, Message: &p2ptest.Message{Value: "bar"}, }) @@ -79,10 +81,10 @@ func TestRouter_Network(t *testing.T) { Message: &p2ptest.Message{Value: "bar"}, }) } - p2ptest.RequireReceiveUnordered(ctx, t, channel, expect) + p2ptest.RequireReceiveUnordered(t, channel, expect) - // We then submit an error for a peer, and watch it get disconnected and - // then reconnected as the router retries it. + t.Logf("We then submit an error for a peer, and watch it get disconnected and") + t.Logf("then reconnected as the router retries it.") peerUpdates := local.MakePeerUpdatesNoRequireEmpty(ctx, t) require.NoError(t, channel.SendError(ctx, p2p.PeerError{ NodeID: peers[0].NodeID, @@ -96,23 +98,23 @@ func TestRouter_Network(t *testing.T) { func TestRouter_Channel_Basic(t *testing.T) { t.Cleanup(leaktest.Check(t)) + logger, _ := log.NewDefaultLogger("plain", "debug") ctx := t.Context() // Set up a router with no transports (so no peers). - peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) + peerManager, err := p2p.NewPeerManager(logger, selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) require.NoError(t, err) testnet := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 1}) router, err := p2p.NewRouter( - log.NewNopLogger(), + logger, p2p.NopMetrics(), selfKey, peerManager, func() *types.NodeInfo { return &selfInfo }, testnet.RandomNode().Transport, - &p2p.Endpoint{}, nil, p2p.RouterOptions{}, ) @@ -121,42 +123,32 @@ func TestRouter_Channel_Basic(t *testing.T) { require.NoError(t, router.Start(ctx)) t.Cleanup(router.Wait) - // Opening a channel should work. 
- chctx, chcancel := context.WithCancel(ctx) - defer chcancel() - - channel, err := router.OpenChannel(chctx, chDesc) + t.Logf("Opening a channel should work.") + channel, err := router.OpenChannel(chDesc) require.NoError(t, err) require.NotNil(t, channel) - // Opening the same channel again should fail. - _, err = router.OpenChannel(ctx, chDesc) + t.Logf("Opening the same channel again should fail.") + _, err = router.OpenChannel(chDesc) require.Error(t, err) - // Opening a different channel should work. + t.Logf("Opening a different channel should work.") chDesc2 := &p2p.ChannelDescriptor{ID: 2, MessageType: &p2ptest.Message{}} - _, err = router.OpenChannel(ctx, chDesc2) - require.NoError(t, err) - - // Closing the channel, then opening it again should be fine. - chcancel() - time.Sleep(200 * time.Millisecond) // yes yes, but Close() is async... - - channel, err = router.OpenChannel(ctx, chDesc) + _, err = router.OpenChannel(chDesc2) require.NoError(t, err) - // We should be able to send on the channel, even though there are no peers. - p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{ + t.Logf("We should be able to send on the channel, even though there are no peers.") + p2ptest.RequireSend(t, channel, p2p.Envelope{ To: types.NodeID(strings.Repeat("a", 40)), Message: &p2ptest.Message{Value: "foo"}, }) - // A message to ourselves should be dropped. - p2ptest.RequireSend(ctx, t, channel, p2p.Envelope{ + t.Logf("A message to ourselves should be dropped.") + p2ptest.RequireSend(t, channel, p2p.Envelope{ To: selfID, Message: &p2ptest.Message{Value: "self"}, }) - p2ptest.RequireEmpty(ctx, t, channel) + p2ptest.RequireEmpty(t, channel) } // Channel tests are hairy to mock, so we use an in-memory network instead. @@ -165,59 +157,59 @@ func TestRouter_Channel_SendReceive(t *testing.T) { t.Cleanup(leaktest.Check(t)) - // Create a test network and open a channel on all nodes. 
+ t.Logf("Create a test network and open a channel on all nodes.") network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 3}) ids := network.NodeIDs() aID, bID, cID := ids[0], ids[1], ids[2] - channels := network.MakeChannels(ctx, t, chDesc) + channels := network.MakeChannels(t, chDesc) a, b, c := channels[aID], channels[bID], channels[cID] - otherChannels := network.MakeChannels(ctx, t, p2ptest.MakeChannelDesc(9)) + otherChannels := network.MakeChannels(t, p2ptest.MakeChannelDesc(9)) network.Start(ctx, t) - // Sending a message a->b should work, and not send anything - // further to a, b, or c. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(ctx, t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + t.Logf("Sending a message a->b should work, and not send anything further to a, b, or c.") + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireEmpty(t, a, b, c) - // Sending a nil message a->b should be dropped. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: nil}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + t.Logf("Sending a nil message a->b should be dropped.") + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: nil, ChannelID: chDesc.ID}) + p2ptest.RequireEmpty(t, a, b, c) - // Sending a different message type should be dropped. 
- p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + t.Logf("Sending a different message type should be dropped.") + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}, ChannelID: chDesc.ID}) + p2ptest.RequireEmpty(t, a, b, c) - // Sending to an unknown peer should be dropped. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{ - To: types.NodeID(strings.Repeat("a", 40)), - Message: &p2ptest.Message{Value: "a"}, + t.Logf("Sending to an unknown peer should be dropped.") + p2ptest.RequireSend(t, a, p2p.Envelope{ + To: types.NodeID(strings.Repeat("a", 40)), + Message: &p2ptest.Message{Value: "a"}, + ChannelID: chDesc.ID, }) - p2ptest.RequireEmpty(ctx, t, a, b, c) + p2ptest.RequireEmpty(t, a, b, c) - // Sending without a recipient should be dropped. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{Message: &p2ptest.Message{Value: "noto"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + t.Logf("Sending without a recipient should be dropped.") + p2ptest.RequireSend(t, a, p2p.Envelope{Message: &p2ptest.Message{Value: "noto"}, ChannelID: chDesc.ID}) + p2ptest.RequireEmpty(t, a, b, c) - // Sending to self should be dropped. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "self"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + t.Logf("Sending to self should be dropped.") + p2ptest.RequireSend(t, a, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "self"}, ChannelID: chDesc.ID}) + p2ptest.RequireEmpty(t, a, b, c) - // Removing b and sending to it should be dropped. 
+ t.Logf("Removing b and sending to it should be dropped.") network.Remove(ctx, t, bID) - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "nob"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "nob"}, ChannelID: chDesc.ID}) + p2ptest.RequireEmpty(t, a, b, c) - // After all this, sending a message c->a should work. - p2ptest.RequireSend(ctx, t, c, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireReceive(ctx, t, a, p2p.Envelope{From: cID, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c) + t.Logf("After all this, sending a message c->a should work.") + p2ptest.RequireSend(t, c, p2p.Envelope{To: aID, Message: &p2ptest.Message{Value: "bar"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, a, p2p.Envelope{From: cID, Message: &p2ptest.Message{Value: "bar"}, ChannelID: chDesc.ID}) + p2ptest.RequireEmpty(t, a, b, c) - // None of these messages should have made it onto the other channels. + t.Logf("None of these messages should have made it onto the other channels.") for _, other := range otherChannels { - p2ptest.RequireEmpty(ctx, t, other) + p2ptest.RequireEmpty(t, other) } } @@ -226,29 +218,29 @@ func TestRouter_Channel_Broadcast(t *testing.T) { ctx := t.Context() - // Create a test network and open a channel on all nodes. + t.Logf("Create a test network and open a channel on all nodes.") network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 4}) ids := network.NodeIDs() aID, bID, cID, dID := ids[0], ids[1], ids[2], ids[3] - channels := network.MakeChannels(ctx, t, chDesc) + channels := network.MakeChannels(t, chDesc) a, b, c, d := channels[aID], channels[bID], channels[cID], channels[dID] network.Start(ctx, t) - // Sending a broadcast from b should work. 
- p2ptest.RequireSend(ctx, t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(ctx, t, a, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(ctx, t, c, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(ctx, t, d, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c, d) + t.Logf("Sending a broadcast from b should work.") + p2ptest.RequireSend(t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, a, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, c, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, d, p2p.Envelope{From: bID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireEmpty(t, a, b, c, d) - // Removing one node from the network shouldn't prevent broadcasts from working. 
+ t.Logf("Removing one node from the network shouldn't prevent broadcasts from working.") network.Remove(ctx, t, dID) - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireReceive(ctx, t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireReceive(ctx, t, c, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) - p2ptest.RequireEmpty(ctx, t, a, b, c, d) + p2ptest.RequireSend(t, a, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "bar"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, c, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}, ChannelID: chDesc.ID}) + p2ptest.RequireEmpty(t, a, b, c, d) } func TestRouter_Channel_Wrapper(t *testing.T) { @@ -256,7 +248,7 @@ func TestRouter_Channel_Wrapper(t *testing.T) { ctx := t.Context() - // Create a test network and open a channel on all nodes. + t.Logf("Create a test network and open a channel on all nodes.") network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 2}) ids := network.NodeIDs() @@ -266,10 +258,11 @@ func TestRouter_Channel_Wrapper(t *testing.T) { MessageType: &wrapperMessage{}, Priority: 5, SendQueueCapacity: 10, + RecvBufferCapacity: 10, RecvMessageCapacity: 10, } - channels := network.MakeChannels(ctx, t, chDesc) + channels := network.MakeChannels(t, chDesc) a, b := channels[aID], channels[bID] network.Start(ctx, t) @@ -277,22 +270,24 @@ func TestRouter_Channel_Wrapper(t *testing.T) { // Since wrapperMessage implements p2p.Wrapper and handles Message, it // should automatically wrap and unwrap sent messages -- we prepend the // wrapper actions to the message value to signal this. 
- p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}}) - p2ptest.RequireReceive(ctx, t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "unwrap:wrap:foo"}}) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "foo"}, ChannelID: chDesc.ID}) + p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "unwrap:wrap:foo"}, ChannelID: chDesc.ID}) // If we send a different message that can't be wrapped, it should be dropped. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}}) - p2ptest.RequireEmpty(ctx, t, b) + p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &gogotypes.BoolValue{Value: true}, ChannelID: chDesc.ID}) + p2ptest.RequireEmpty(t, b) // If we send the wrapper message itself, it should also be passed through // since WrapperMessage supports it, and should only be unwrapped at the receiver. - p2ptest.RequireSend(ctx, t, a, p2p.Envelope{ - To: bID, - Message: &wrapperMessage{Message: p2ptest.Message{Value: "foo"}}, + p2ptest.RequireSend(t, a, p2p.Envelope{ + To: bID, + Message: &wrapperMessage{Message: p2ptest.Message{Value: "foo"}}, + ChannelID: chDesc.ID, }) - p2ptest.RequireReceive(ctx, t, b, p2p.Envelope{ - From: aID, - Message: &p2ptest.Message{Value: "unwrap:foo"}, + p2ptest.RequireReceive(t, b, p2p.Envelope{ + From: aID, + Message: &p2ptest.Message{Value: "unwrap:foo"}, + ChannelID: chDesc.ID, }) } @@ -325,18 +320,18 @@ func TestRouter_Channel_Error(t *testing.T) { ctx := t.Context() - // Create a test network and open a channel on all nodes. 
+ t.Logf("Create a test network and open a channel on all nodes.") network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 3}) network.Start(ctx, t) ids := network.NodeIDs() aID, bID := ids[0], ids[1] - channels := network.MakeChannels(ctx, t, chDesc) + channels := network.MakeChannels(t, chDesc) a := channels[aID] - // Erroring b should cause it to be disconnected. It will reconnect shortly after. + t.Logf("Erroring b should cause it to be disconnected. It will reconnect shortly after.") sub := network.Nodes[aID].MakePeerUpdates(ctx, t) - p2ptest.RequireError(ctx, t, a, p2p.PeerError{NodeID: bID, Err: errors.New("boom")}) + p2ptest.RequireSendError(t, a, p2p.PeerError{NodeID: bID, Err: errors.New("boom")}) p2ptest.RequireUpdates(t, sub, []p2p.PeerUpdate{ {NodeID: bID, Status: p2p.PeerStatusDown}, {NodeID: bID, Status: p2p.PeerStatusUp}, @@ -350,7 +345,7 @@ func TestRouter_AcceptPeers(t *testing.T) { ok bool }{ "valid handshake": {peerInfo, peerKey.PubKey(), true}, - "empty handshake": {types.NodeInfo{}, nil, false}, + "empty handshake": {types.NodeInfo{}, peerKey.PubKey(), false}, "invalid key": {peerInfo, selfKey.PubKey(), false}, "self handshake": {selfInfo, selfKey.PubKey(), false}, "incompatible peer": { @@ -384,11 +379,9 @@ func TestRouter_AcceptPeers(t *testing.T) { } mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil).Maybe() mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) // Set up and start the router. 
peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -404,7 +397,6 @@ func TestRouter_AcceptPeers(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -444,10 +436,8 @@ func TestRouter_AcceptPeers_Errors(t *testing.T) { // Set up a mock transport that returns io.EOF once, which should prevent // the router from calling Accept again. mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) - mockTransport.On("Close").Return(nil) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) // Set up and start the router. peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -461,7 +451,6 @@ func TestRouter_AcceptPeers_Errors(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -496,13 +485,11 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) mockTransport.On("Accept", mock.Anything).Times(3).Run(func(_ mock.Arguments) { acceptCh <- true }).Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) // Set up and start the router. 
peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -516,7 +503,6 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -542,7 +528,7 @@ func TestRouter_DialPeers(t *testing.T) { ok bool }{ "valid dial": {peerInfo.NodeID, peerInfo, peerKey.PubKey(), nil, true}, - "empty handshake": {peerInfo.NodeID, types.NodeInfo{}, nil, nil, false}, + "empty handshake": {peerInfo.NodeID, types.NodeInfo{}, peerKey.PubKey(), nil, false}, "invalid key": {peerInfo.NodeID, peerInfo, selfKey.PubKey(), nil, false}, "unexpected node ID": {peerInfo.NodeID, selfInfo, selfKey.PubKey(), nil, false}, "dial error": {peerInfo.NodeID, peerInfo, peerKey.PubKey(), errors.New("boom"), false}, @@ -566,7 +552,7 @@ func TestRouter_DialPeers(t *testing.T) { ctx := t.Context() address := p2p.NodeAddress{Protocol: "mock", NodeID: tc.dialID} - endpoint := &p2p.Endpoint{Protocol: "mock", Path: string(tc.dialID)} + endpoint := p2p.Endpoint{Protocol: "mock", Path: string(tc.dialID)} // Set up a mock transport that handshakes. 
connCtx, connCancel := context.WithCancel(ctx) @@ -583,10 +569,8 @@ func TestRouter_DialPeers(t *testing.T) { } mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil).Maybe() - mockTransport.On("Listen", mock.Anything).Return(nil) - mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) + mockTransport.On("Run", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, context.Canceled) if tc.dialErr == nil { mockTransport.On("Dial", mock.Anything, endpoint).Once().Return(mockConnection, nil) // This handles the retry when a dialed connection gets closed after ReceiveMessage @@ -615,7 +599,6 @@ func TestRouter_DialPeers(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -665,12 +648,10 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { mockConnection.On("Close").Return(nil) mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) - mockTransport.On("Listen", mock.Anything).Return(nil) - mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) + mockTransport.On("Run", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, context.Canceled) for _, address := range []p2p.NodeAddress{a, b, c} { - endpoint := &p2p.Endpoint{Protocol: address.Protocol, Path: string(address.NodeID)} + endpoint := p2p.Endpoint{Protocol: address.Protocol, Path: string(address.NodeID)} mockTransport.On("Dial", mock.Anything, endpoint).Run(func(_ mock.Arguments) { dialCh <- true }).Return(mockConnection, nil) @@ -700,9 +681,8 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{ - DialSleep: func(_ context.Context) {}, + DialSleep: func(_ context.Context) error { return nil }, 
NumConcurrentDials: func() int { ncpu := runtime.NumCPU() if ncpu <= 3 { @@ -734,6 +714,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { func TestRouter_EvictPeers(t *testing.T) { t.Cleanup(leaktest.Check(t)) + logger, _ := log.NewDefaultLogger("plain", "debug") ctx := t.Context() @@ -754,27 +735,24 @@ func TestRouter_EvictPeers(t *testing.T) { }).Return(nil) mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) // Set up and start the router. - peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) + peerManager, err := p2p.NewPeerManager(logger, selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) require.NoError(t, err) sub := peerManager.Subscribe(ctx) router, err := p2p.NewRouter( - log.NewNopLogger(), + logger, p2p.NopMetrics(), selfKey, peerManager, func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -785,7 +763,7 @@ func TestRouter_EvictPeers(t *testing.T) { NodeID: peerInfo.NodeID, Status: p2p.PeerStatusUp, }) - + t.Logf("node is up") peerManager.Errored(peerInfo.NodeID, errors.New("boom")) p2ptest.RequireUpdate(t, sub, p2p.PeerUpdate{ @@ -818,11 +796,9 @@ func TestRouter_ChannelCompatability(t *testing.T) { mockConnection.On("Close").Return(nil) mockTransport := &mocks.Transport{} - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) + mockTransport.On("Run", mock.Anything).Return(nil) mockTransport.On("Accept", 
mock.Anything).Once().Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Once().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Once().Return(nil, context.Canceled) // Set up and start the router. peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -836,7 +812,6 @@ func TestRouter_ChannelCompatability(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -871,11 +846,9 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { mockTransport := &mocks.Transport{} mockTransport.On("AddChannelDescriptors", mock.Anything).Return() - mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) // Set up and start the router. 
peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) @@ -891,7 +864,6 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { func() *types.NodeInfo { return &selfInfo }, mockTransport, nil, - nil, p2p.RouterOptions{}, ) require.NoError(t, err) @@ -902,7 +874,7 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { Status: p2p.PeerStatusUp, }) - channel, err := router.OpenChannel(ctx, chDesc) + channel, err := router.OpenChannel(chDesc) require.NoError(t, err) require.NoError(t, channel.Send(ctx, p2p.Envelope{ @@ -939,10 +911,9 @@ func TestRouter_Channel_FilterByID(t *testing.T) { mockTransport := &mocks.Transport{} mockTransport.On("AddChannelDescriptors", mock.Anything).Return() mockTransport.On("String").Maybe().Return("mock") - mockTransport.On("Close").Return(nil) mockTransport.On("Accept", mock.Anything).Once().Return(mockConnection, nil) - mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, io.EOF) - mockTransport.On("Listen", mock.Anything).Return(nil) + mockTransport.On("Accept", mock.Anything).Maybe().Return(nil, context.Canceled) + mockTransport.On("Run", mock.Anything).Return(nil) peerManager, err := p2p.NewPeerManager(log.NewNopLogger(), selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}, p2p.NopMetrics()) require.NoError(t, err) @@ -955,7 +926,6 @@ func TestRouter_Channel_FilterByID(t *testing.T) { peerManager, func() *types.NodeInfo { return &selfInfo }, mockTransport, - &p2p.Endpoint{}, nil, p2p.RouterOptions{}, ) @@ -979,7 +949,6 @@ func TestRouter_Channel_FilterByID(t *testing.T) { peerManager, func() *types.NodeInfo { return &selfInfo }, mockTransport, - &p2p.Endpoint{}, func(_ context.Context, _ types.NodeID) error { return errors.New("should filter") }, p2p.RouterOptions{}, ) diff --git a/internal/p2p/rqueue.go b/internal/p2p/rqueue.go index afeed3a65..00cefc4d8 100644 --- a/internal/p2p/rqueue.go +++ b/internal/p2p/rqueue.go @@ -3,111 +3,173 @@ package 
p2p import ( "container/heap" "context" - "sort" "time" "github.com/gogo/protobuf/proto" + "github.com/tendermint/tendermint/libs/utils" ) -type simpleQueue struct { - input chan Envelope - output chan Envelope - closeFn func() - closeCh <-chan struct{} +type ord[T any] interface { + Less(T) bool +} + +type withIdx[T any] struct { + v T + minIdx int // index in byMin + maxIdx int // index in byMax +} + +func newWithIdx[T any](v T) *withIdx[T] { + return &withIdx[T]{v: v} +} + +// Heap returning minimal elements. +type byMin[T ord[T]] struct{ a []*withIdx[T] } + +func newByMin[T ord[T]](capacity int) byMin[T] { return byMin[T]{make([]*withIdx[T], 0, capacity)} } +func (x *byMin[T]) Less(i, j int) bool { return x.a[i].v.Less(x.a[j].v) } +func (x *byMin[T]) Len() int { return len(x.a) } +func (x *byMin[T]) Swap(i, j int) { + x.a[i], x.a[j] = x.a[j], x.a[i] + x.a[i].minIdx = i + x.a[j].minIdx = j +} +func (x *byMin[T]) Push(v any) { + w := v.(*withIdx[T]) + w.minIdx = len(x.a) + x.a = append(x.a, w) +} +func (x *byMin[T]) Pop() any { + n := len(x.a) - 1 + w := x.a[n] + x.a = x.a[:n] + return w +} + +// Heap returning maximal elements. +type byMax[T ord[T]] struct{ a []*withIdx[T] } + +func newByMax[T ord[T]](capacity int) byMax[T] { return byMax[T]{make([]*withIdx[T], 0, capacity)} } +func (x *byMax[T]) Less(i, j int) bool { return x.a[j].v.Less(x.a[i].v) } +func (x *byMax[T]) Len() int { return len(x.a) } +func (x *byMax[T]) Swap(i, j int) { + x.a[i], x.a[j] = x.a[j], x.a[i] + x.a[i].maxIdx = i + x.a[j].maxIdx = j +} +func (x *byMax[T]) Push(v any) { + w := v.(*withIdx[T]) + w.maxIdx = len(x.a) + x.a = append(x.a, w) +} +func (x *byMax[T]) Pop() any { + n := len(x.a) - 1 + w := x.a[n] + x.a = x.a[:n] + return w +} - maxSize int - chDescs []*ChannelDescriptor +// pqEnvelope defines a wrapper around an Envelope with priority to be inserted +// into a priority Queue used for Envelope scheduling. 
+type pqEnvelope struct { + envelope Envelope + priority int + size int + timestamp time.Time } -func newSimplePriorityQueue(ctx context.Context, size int, chDescs []*ChannelDescriptor) *simpleQueue { - if size%2 != 0 { - size++ +// true <=> a has higher priority than b +func (a *pqEnvelope) Less(b *pqEnvelope) bool { + // higher base priority wins + if a, b := a.priority, b.priority; a != b { + return a > b + } + // newer timestamp wins + if a, b := a.timestamp, b.timestamp; a.Sub(b).Abs() >= 10*time.Millisecond { + return a.After(b) } + // larger first + return a.size > b.size +} + +type inner struct { + capacity int + byMin byMin[*pqEnvelope] + byMax byMax[*pqEnvelope] +} - ctx, cancel := context.WithCancel(ctx) - q := &simpleQueue{ - input: make(chan Envelope, size*2), - output: make(chan Envelope, size/2), - maxSize: size * size, - closeCh: ctx.Done(), - closeFn: cancel, +func newInner(capacity int) *inner { + return &inner{ + capacity: capacity, + // We prune the maximal elements whenever capacity is exceeded. + // Therefore to avoid reallocation we need the heaps to have capacity+1. 
+ byMin: newByMin[*pqEnvelope](capacity + 1), + byMax: newByMax[*pqEnvelope](capacity + 1), } +} - go q.run(ctx) - return q +func (i *inner) Len() int { return i.byMin.Len() } + +func (i *inner) Push(e *pqEnvelope) utils.Option[Envelope] { + w := newWithIdx(e) + heap.Push(&i.byMin, w) + heap.Push(&i.byMax, w) + if i.byMin.Len() > i.capacity { + w := heap.Pop(&i.byMax).(*withIdx[*pqEnvelope]) + heap.Remove(&i.byMin, w.minIdx) + return utils.Some(w.v.envelope) + } + return utils.None[Envelope]() } -func (q *simpleQueue) enqueue() chan<- Envelope { return q.input } -func (q *simpleQueue) dequeue() <-chan Envelope { return q.output } -func (q *simpleQueue) close() { q.closeFn() } -func (q *simpleQueue) closed() <-chan struct{} { return q.closeCh } +func (i *inner) Pop() *pqEnvelope { + w := heap.Pop(&i.byMin).(*withIdx[*pqEnvelope]) + heap.Remove(&i.byMax, w.maxIdx) + return w.v +} -func (q *simpleQueue) run(ctx context.Context) { - defer q.closeFn() +type Queue struct{ inner utils.Watch[*inner] } - var chPriorities = make(map[ChannelID]uint, len(q.chDescs)) - for _, chDesc := range q.chDescs { - chID := chDesc.ID - chPriorities[chID] = uint(chDesc.Priority) +func NewQueue(size int) *Queue { + if size <= 0 { + // prevent caller from shooting self in the foot. + size = 1 } + return &Queue{inner: utils.NewWatch(newInner(size))} +} + +func (q *Queue) Len() int { + for inner := range q.inner.Lock() { + return inner.Len() + } + panic("unreachable") +} + +// Non-blocking send. +// Returns the pruned message if any. +func (q *Queue) Send(e Envelope, priority int) utils.Option[Envelope] { + // We construct the pqEnvelope without holding the lock to avoid contention. 
+ pqe := &pqEnvelope{ + envelope: e, + size: proto.Size(e.Message), + priority: priority, + timestamp: time.Now().UTC(), + } + for inner, ctrl := range q.inner.Lock() { + pruned := inner.Push(pqe) + ctrl.Updated() + return pruned + } + panic("unreachable") +} - pq := make(priorityQueue, 0, q.maxSize) - heap.Init(&pq) - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - // must have a buffer of exactly one because both sides of - // this channel are used in this loop, and simply signals adds - // to the heap - signal := make(chan struct{}, 1) - for { - select { - case <-ctx.Done(): - return - case <-q.closeCh: - return - case e := <-q.input: - // enqueue the incoming Envelope - heap.Push(&pq, &pqEnvelope{ - envelope: e, - size: uint(proto.Size(e.Message)), - priority: chPriorities[e.ChannelID], - timestamp: time.Now().UTC(), - }) - - select { - case signal <- struct{}{}: - default: - if len(pq) > q.maxSize { - sort.Sort(pq) - pq = pq[:q.maxSize] - } - } - - case <-ticker.C: - if len(pq) > q.maxSize { - sort.Sort(pq) - pq = pq[:q.maxSize] - } - if len(pq) > 0 { - select { - case signal <- struct{}{}: - default: - } - } - case <-signal: - SEND: - for len(pq) > 0 { - select { - case <-ctx.Done(): - return - case <-q.closeCh: - return - case q.output <- heap.Pop(&pq).(*pqEnvelope).envelope: - continue SEND - default: - break SEND - } - } +// Blocking recv. 
+func (q *Queue) Recv(ctx context.Context) (Envelope, error) { + for inner, ctrl := range q.inner.Lock() { + if err := ctrl.WaitUntil(ctx, func() bool { return inner.Len() > 0 }); err != nil { + return Envelope{}, err } + return inner.Pop().envelope, nil } + panic("unreachable") } diff --git a/internal/p2p/rqueue_test.go b/internal/p2p/rqueue_test.go index 4e5e4c8bf..2b40d8d1b 100644 --- a/internal/p2p/rqueue_test.go +++ b/internal/p2p/rqueue_test.go @@ -1,45 +1,80 @@ package p2p import ( + "context" + "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/scope" + "slices" "testing" - "time" ) -func TestSimpleQueue(t *testing.T) { +func TestQueuePruning(t *testing.T) { ctx := t.Context() + rng := utils.TestRng() + n := 20 + var want []int + sq := NewQueue(n) + for range 100 { + // Send a bunch of messages. + for range 30 { + // priority is not part of the envelope currently, + // so we hack it by encoding it as a ChannelID. + v := ChannelID(rng.Int()) + sq.Send(Envelope{From: "merlin", ChannelID: v}, int(v)) + want = append(want, int(v)) + } + + // Low priority messages should be dropped. + slices.Sort(want) + l := len(want) + want = want[l-n:] + if len(want) != sq.Len() { + t.Fatalf("expected len %d, got %d", len(want), sq.Len()) + } - // set up a small queue with very small buffers so we can - // watch it shed load, then send a bunch of messages to the - // queue, most of which we'll watch it drop. - sq := newSimplePriorityQueue(ctx, 1, nil) - for i := 0; i < 100; i++ { - sq.enqueue() <- Envelope{From: "merlin"} + // Receive a bunch of messages. 
+ for range 5 { + got, err := sq.Recv(ctx) + if err != nil { + t.Fatal(err) + } + l := len(want) + if got, want := int(got.ChannelID), want[l-1]; got != want { + t.Fatalf("sq.Recv() = %d, want %d", got, want) + } + want = want[:l-1] + } + if len(want) != sq.Len() { + t.Fatalf("expected len %d, got %d", len(want), sq.Len()) + } } +} - seen := 0 +// Test that receivers are notified when a message is available. +func TestQueueConcurrency(t *testing.T) { + ctx := t.Context() + q1, q2 := NewQueue(1), NewQueue(1) -RETRY: - for seen <= 2 { - select { - case e := <-sq.dequeue(): - if e.From != "merlin" { - continue + if err := utils.IgnoreCancel(scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.SpawnBg(func() error { + // Echo task. + for { + msg, err := q1.Recv(ctx) + if err != nil { + return err + } + q2.Send(msg, 0) + } + }) + // Send and receive a bunch of messages. + for range 100 { + q1.Send(Envelope{From: "merlin"}, 0) + if _, err := q2.Recv(ctx); err != nil { + return err } - seen++ - case <-time.After(10 * time.Millisecond): - break RETRY } + return nil + })); err != nil { + t.Fatal(err) } - // if we don't see any messages, then it's just broken. - if seen == 0 { - t.Errorf("seen %d messages, should have seen more than one", seen) - } - // ensure that load shedding happens: there can be at most 3 - // messages that we get out of this, one that was buffered - // plus 2 that were under the cap, everything else gets - // dropped. 
- if seen > 3 { - t.Errorf("saw %d messages, should have seen 5 or fewer", seen) - } - } diff --git a/internal/p2p/transport.go b/internal/p2p/transport.go index 7a965260a..540c165af 100644 --- a/internal/p2p/transport.go +++ b/internal/p2p/transport.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "net" + "net/netip" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/types" @@ -23,18 +23,14 @@ type Protocol string // Transport is a connection-oriented mechanism for exchanging data with a peer. type Transport interface { - // Listen starts the transport on the specified endpoint. - Listen(*Endpoint) error - + // Run executes the background tasks of transport. + Run(ctx context.Context) error // Protocols returns the protocols supported by the transport. The Router // uses this to pick a transport for an Endpoint. Protocols() []Protocol - // Endpoints returns the local endpoints the transport is listening on, if any. - // - // How to listen is transport-dependent, e.g. MConnTransport uses Listen() while - // MemoryTransport starts listening via MemoryNetwork.CreateTransport(). - Endpoint() (*Endpoint, error) + // Endpoints returns the local endpoints the transport is listening on. + Endpoint() Endpoint // Accept waits for the next inbound connection on a listening endpoint, blocking // until either a connection is available or the transport is closed. On closure, @@ -42,10 +38,7 @@ type Transport interface { Accept(context.Context) (Connection, error) // Dial creates an outbound connection to an endpoint. - Dial(context.Context, *Endpoint) (Connection, error) - - // Close stops accepting new connections, but does not close active connections. - Close() error + Dial(context.Context, Endpoint) (Connection, error) // AddChannelDescriptors is only part of this interface // temporarily @@ -115,30 +108,23 @@ type Connection interface { type Endpoint struct { // Protocol specifies the transport protocol. 
Protocol Protocol - - // IP is an IP address (v4 or v6) to connect to. If set, this defines the - // endpoint as a networked endpoint. - IP net.IP - - // Port is a network port (either TCP or UDP). If 0, a default port may be - // used depending on the protocol. - Port uint16 + // TCP endpoint address. + Addr netip.AddrPort // Path is an optional transport-specific path or identifier. Path string } // NewEndpoint constructs an Endpoint from a types.NetAddress structure. -func NewEndpoint(addr string) (*Endpoint, error) { - ip, port, err := types.ParseAddressString(addr) +func NewEndpoint(addr string) (Endpoint, error) { + addrPort, err := types.ParseAddressString(addr) if err != nil { - return nil, err + return Endpoint{}, err } - return &Endpoint{ + return Endpoint{ Protocol: MConnProtocol, - IP: ip, - Port: port, + Addr: addrPort, }, nil } @@ -149,9 +135,9 @@ func (e Endpoint) NodeAddress(nodeID types.NodeID) NodeAddress { Protocol: e.Protocol, Path: e.Path, } - if len(e.IP) > 0 { - address.Hostname = e.IP.String() - address.Port = e.Port + if e.Addr != (netip.AddrPort{}) { + address.Hostname = e.Addr.Addr().String() + address.Port = e.Addr.Port() } return address } @@ -161,7 +147,7 @@ func (e Endpoint) String() string { // If this is a non-networked endpoint with a valid node ID as a path, // assume that path is a node ID (to handle opaque URLs of the form // scheme:id). - if e.IP == nil { + if e.Addr == (netip.AddrPort{}) { if nodeID, err := types.NewNodeID(e.Path); err == nil { return e.NodeAddress(nodeID).String() } @@ -171,20 +157,16 @@ func (e Endpoint) String() string { // Validate validates the endpoint. 
func (e Endpoint) Validate() error { - switch { - case e.Protocol == "": + if e.Protocol == "" { return errors.New("endpoint has no protocol") - - case len(e.IP) > 0 && e.IP.To16() == nil: - return fmt.Errorf("invalid IP address %v", e.IP) - - case e.Port > 0 && len(e.IP) == 0: - return fmt.Errorf("endpoint has port %v but no IP", e.Port) - - case len(e.IP) == 0 && e.Path == "": + } + if (e.Addr == netip.AddrPort{}) && (e.Path == "") { return errors.New("endpoint has neither path nor IP") - - default: - return nil } + if e.Addr != (netip.AddrPort{}) { + if !e.Addr.IsValid() { + return fmt.Errorf("endpoint has invalid address %q", e.Addr.String()) + } + } + return nil } diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index 3709fb58f..0b13916e3 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -7,7 +7,7 @@ import ( "io" "math" "net" - "strconv" + "net/netip" "sync" "golang.org/x/net/netutil" @@ -16,6 +16,9 @@ import ( "github.com/tendermint/tendermint/internal/libs/protoio" "github.com/tendermint/tendermint/internal/p2p/conn" "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/scope" + "github.com/tendermint/tendermint/libs/utils/tcp" p2pproto "github.com/tendermint/tendermint/proto/tendermint/p2p" "github.com/tendermint/tendermint/types" ) @@ -41,13 +44,12 @@ type MConnTransportOptions struct { // Tendermint protocol ("MConn"). type MConnTransport struct { logger log.Logger + endpoint Endpoint options MConnTransportOptions mConnConfig conn.MConnConfig channelDescs []*ChannelDescriptor - - closeOnce sync.Once - doneCh chan struct{} - listener net.Listener + started chan struct{} + listener chan *mConnConnection } // NewMConnTransport sets up a new MConnection transport. This uses the @@ -55,70 +57,43 @@ type MConnTransport struct { // conn.MConnection. 
func NewMConnTransport( logger log.Logger, + endpoint Endpoint, mConnConfig conn.MConnConfig, channelDescs []*ChannelDescriptor, options MConnTransportOptions, ) *MConnTransport { return &MConnTransport{ logger: logger, + endpoint: endpoint, options: options, mConnConfig: mConnConfig, - doneCh: make(chan struct{}), channelDescs: channelDescs, + // This is rendezvous channel, so that no unclosed connections get stuck inside + // when transport is closing. + started: make(chan struct{}), + listener: make(chan *mConnConnection), } } -// String implements Transport. -func (m *MConnTransport) String() string { - return string(MConnProtocol) +// WaitForStart waits until transport starts listening for incoming connections. +func (m *MConnTransport) WaitForStart(ctx context.Context) error { + _, _, err := utils.RecvOrClosed(ctx, m.started) + return err } -// Protocols implements Transport. We support tcp for backwards-compatibility. -func (m *MConnTransport) Protocols() []Protocol { - return []Protocol{MConnProtocol, TCPProtocol} +func (m *MConnTransport) Endpoint() Endpoint { + return m.endpoint } -// Endpoint implements Transport. -func (m *MConnTransport) Endpoint() (*Endpoint, error) { - if m.listener == nil { - return nil, errors.New("listenter not defined") - } - select { - case <-m.doneCh: - return nil, errors.New("transport closed") - default: - } - - endpoint := &Endpoint{ - Protocol: MConnProtocol, - } - if addr, ok := m.listener.Addr().(*net.TCPAddr); ok { - endpoint.IP = addr.IP - endpoint.Port = uint16(addr.Port) - } - return endpoint, nil -} - -// Listen asynchronously listens for inbound connections on the given endpoint. -// It must be called exactly once before calling Accept(), and the caller must -// call Close() to shut down the listener. -// -// FIXME: Listen currently only supports listening on a single endpoint, it -// might be useful to support listening on multiple addresses (e.g. 
IPv4 and -// IPv6, or a private and public address) via multiple Listen() calls. -func (m *MConnTransport) Listen(endpoint *Endpoint) error { - if m.listener != nil { - return errors.New("transport is already listening") - } - if err := m.validateEndpoint(endpoint); err != nil { +func (m *MConnTransport) Run(ctx context.Context) error { + if err := m.validateEndpoint(m.endpoint); err != nil { return err } - - listener, err := net.Listen("tcp", net.JoinHostPort( - endpoint.IP.String(), strconv.Itoa(int(endpoint.Port)))) + listener, err := tcp.Listen(m.endpoint.Addr) if err != nil { - return err + return fmt.Errorf("net.Listen(): %w", err) } + close(m.started) // signal that we are listening if m.options.MaxAcceptedConnections > 0 { // FIXME: This will establish the inbound connection but simply hang it // until another connection is released. It would probably be better to @@ -127,84 +102,60 @@ func (m *MConnTransport) Listen(endpoint *Endpoint) error { // This was just carried over from the legacy P2P stack. listener = netutil.LimitListener(listener, int(m.options.MaxAcceptedConnections)) } - m.listener = listener + return scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.Spawn(func() error { + <-ctx.Done() + listener.Close() + return nil + }) + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + mconn := newMConnConnection(m.logger, conn, m.mConnConfig, m.channelDescs) + if err := utils.Send(ctx, m.listener, mconn); err != nil { + mconn.Close() + return err + } + } + }) +} - return nil +// String implements Transport. +func (m *MConnTransport) String() string { + return string(MConnProtocol) +} + +// Protocols implements Transport. We support tcp for backwards-compatibility. +func (m *MConnTransport) Protocols() []Protocol { + return []Protocol{MConnProtocol, TCPProtocol} } // Accept implements Transport. 
func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) { - if m.listener == nil { - return nil, errors.New("transport is not listening") - } - - conCh := make(chan net.Conn) - errCh := make(chan error) - go func() { - tcpConn, err := m.listener.Accept() - if err != nil { - select { - case errCh <- err: - case <-ctx.Done(): - } - } - select { - case conCh <- tcpConn: - case <-ctx.Done(): - } - }() - - select { - case <-ctx.Done(): - m.listener.Close() - return nil, io.EOF - case <-m.doneCh: - m.listener.Close() - return nil, io.EOF - case err := <-errCh: - return nil, err - case tcpConn := <-conCh: - return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil - } - + return utils.Recv(ctx, m.listener) } // Dial implements Transport. -func (m *MConnTransport) Dial(ctx context.Context, endpoint *Endpoint) (Connection, error) { +func (m *MConnTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) { if err := m.validateEndpoint(endpoint); err != nil { return nil, err } - if endpoint.Port == 0 { - endpoint.Port = 26657 + if endpoint.Addr.Port() == 0 { + endpoint.Addr = netip.AddrPortFrom(endpoint.Addr.Addr(), 26657) } - dialer := net.Dialer{} - tcpConn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort( - endpoint.IP.String(), strconv.Itoa(int(endpoint.Port)))) + tcpConn, err := dialer.DialContext(ctx, "tcp", endpoint.Addr.String()) if err != nil { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - return nil, err - } + return nil, fmt.Errorf("dialer.DialContext(%v): %w", endpoint.Addr, err) } - return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil } -// Close implements Transport. -func (m *MConnTransport) Close() error { - var err error - m.closeOnce.Do(func() { - close(m.doneCh) - if m.listener != nil { - err = m.listener.Close() - } - }) - return err -} - // SetChannels sets the channel descriptors to be used when // establishing a connection. 
// @@ -216,19 +167,21 @@ func (m *MConnTransport) AddChannelDescriptors(channelDesc []*ChannelDescriptor) m.channelDescs = append(m.channelDescs, channelDesc...) } +type InvalidEndpointErr struct{ error } + // validateEndpoint validates an endpoint. -func (m *MConnTransport) validateEndpoint(endpoint *Endpoint) error { +func (m *MConnTransport) validateEndpoint(endpoint Endpoint) error { if err := endpoint.Validate(); err != nil { - return err + return InvalidEndpointErr{err} } if endpoint.Protocol != MConnProtocol && endpoint.Protocol != TCPProtocol { - return fmt.Errorf("unsupported protocol %q", endpoint.Protocol) + return InvalidEndpointErr{fmt.Errorf("unsupported protocol %q", endpoint.Protocol)} } - if len(endpoint.IP) == 0 { - return errors.New("endpoint has no IP address") + if !endpoint.Addr.IsValid() { + return InvalidEndpointErr{errors.New("endpoint has invalid address")} } if endpoint.Path != "" { - return fmt.Errorf("endpoints with path not supported (got %q)", endpoint.Path) + return InvalidEndpointErr{fmt.Errorf("endpoints with path not supported (got %q)", endpoint.Path)} } return nil } @@ -400,7 +353,7 @@ func (c *mConnConnection) onReceive(ctx context.Context, chID ChannelID, payload // onError is a callback for MConnection errors. The error is passed via errorCh // to ReceiveMessage (but not SendMessage, for legacy P2P stack behavior). 
-func (c *mConnConnection) onError(ctx context.Context, e interface{}) { +func (c *mConnConnection) onError(ctx context.Context, e any) { err, ok := e.(error) if !ok { err = fmt.Errorf("%v", err) @@ -428,13 +381,10 @@ func (c *mConnConnection) SendMessage(ctx context.Context, chID ChannelID, msg [ select { case err := <-c.errorCh: return err - case <-ctx.Done(): - return io.EOF default: - if ok := c.mconn.Send(chID, msg); !ok { - return errors.New("sending message timed out") + if err := c.mconn.Send(ctx, chID, msg); err != nil { + return fmt.Errorf("m.mconn.Send(%v): %w", chID, err) } - return nil } } @@ -459,8 +409,7 @@ func (c *mConnConnection) LocalEndpoint() Endpoint { Protocol: MConnProtocol, } if addr, ok := c.conn.LocalAddr().(*net.TCPAddr); ok { - endpoint.IP = addr.IP - endpoint.Port = uint16(addr.Port) + endpoint.Addr = addr.AddrPort() } return endpoint } @@ -471,8 +420,7 @@ func (c *mConnConnection) RemoteEndpoint() Endpoint { Protocol: MConnProtocol, } if addr, ok := c.conn.RemoteAddr().(*net.TCPAddr); ok { - endpoint.IP = addr.IP - endpoint.Port = uint16(addr.Port) + endpoint.Addr = addr.AddrPort() } return endpoint } diff --git a/internal/p2p/transport_mconn_test.go b/internal/p2p/transport_mconn_test.go index 18d7f4fb3..64121809c 100644 --- a/internal/p2p/transport_mconn_test.go +++ b/internal/p2p/transport_mconn_test.go @@ -1,13 +1,19 @@ package p2p_test import ( + "context" + "errors" + "fmt" "io" - "net" + "net/netip" "testing" "time" "github.com/fortytw2/leaktest" "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/libs/utils" + "github.com/tendermint/tendermint/libs/utils/scope" + "github.com/tendermint/tendermint/libs/utils/tcp" "github.com/tendermint/tendermint/internal/p2p" "github.com/tendermint/tendermint/internal/p2p/conn" @@ -17,207 +23,211 @@ import ( // Transports are mainly tested by common tests in transport_test.go, we // register a transport factory here to get included in those tests. 
func init() { - testTransports["mconn"] = func(t *testing.T) p2p.Transport { - transport := p2p.NewMConnTransport( - log.NewNopLogger(), - conn.DefaultMConnConfig(), - []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, - p2p.MConnTransportOptions{}, - ) - err := transport.Listen(&p2p.Endpoint{ - Protocol: p2p.MConnProtocol, - IP: net.IPv4(127, 0, 0, 1), - Port: 0, // assign a random port - }) - require.NoError(t, err) - - t.Cleanup(func() { _ = transport.Close() }) - - return transport + testTransports["mconn"] = func() func(context.Context) p2p.Transport { + return func(ctx context.Context) p2p.Transport { + transport := p2p.NewMConnTransport( + log.NewNopLogger(), + p2p.Endpoint{ + Protocol: p2p.MConnProtocol, + Addr: tcp.TestReserveAddr(), + }, + conn.DefaultMConnConfig(), + []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, + p2p.MConnTransportOptions{}, + ) + go func() { + if err := transport.Run(ctx); err != nil { + panic(err) + } + }() + if err := transport.WaitForStart(ctx); err != nil { + panic(err) + } + return transport + } } } -func TestMConnTransport_AcceptBeforeListen(t *testing.T) { - transport := p2p.NewMConnTransport( - log.NewNopLogger(), - conn.DefaultMConnConfig(), - []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, - p2p.MConnTransportOptions{ - MaxAcceptedConnections: 2, - }, - ) - t.Cleanup(func() { - _ = transport.Close() - }) - ctx := t.Context() - - _, err := transport.Accept(ctx) - require.Error(t, err) - require.NotEqual(t, io.EOF, err) // io.EOF should be returned after Close() +// Establishes a connection to the transport. +// Returns both ends of the connection. +func connect(ctx context.Context, tr *p2p.MConnTransport) (c1 p2p.Connection, c2 p2p.Connection, err error) { + defer func() { + if err != nil { + if c1 != nil { + c1.Close() + } + if c2 != nil { + c2.Close() + } + } + }() + // Here we are utilizing the fact that MConnTransport accepts connection proactively + // before Accept is called. 
+ c1, err = tr.Dial(ctx, tr.Endpoint()) + if err != nil { + return nil, nil, fmt.Errorf("Dial(): %w", err) + } + c2, err = tr.Accept(ctx) + if err != nil { + return nil, nil, fmt.Errorf("Accept(): %w", err) + } + if got, want := c1.LocalEndpoint(), c2.RemoteEndpoint(); got != want { + return nil, nil, fmt.Errorf("c1.LocalEndpoint() = %v, want %v", got, want) + } + if got, want := c1.RemoteEndpoint(), c2.LocalEndpoint(); got != want { + return nil, nil, fmt.Errorf("c1.RemoteEndpoint() = %v, want %v", got, want) + } + return c1, c2, nil } func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { ctx := t.Context() - transport := p2p.NewMConnTransport( log.NewNopLogger(), + p2p.Endpoint{ + Protocol: p2p.MConnProtocol, + Addr: tcp.TestReserveAddr(), + }, conn.DefaultMConnConfig(), []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, p2p.MConnTransportOptions{ MaxAcceptedConnections: 2, }, ) - t.Cleanup(func() { - _ = transport.Close() - }) - err := transport.Listen(&p2p.Endpoint{ - Protocol: p2p.MConnProtocol, - IP: net.IPv4(127, 0, 0, 1), - }) - require.NoError(t, err) - endpoint, err := transport.Endpoint() - require.NoError(t, err) - require.NotNil(t, endpoint) - - // Start a goroutine to just accept any connections. 
- acceptCh := make(chan p2p.Connection, 10) - go func() { - for { - conn, err := transport.Accept(ctx) - if err != nil { - return + + err := utils.IgnoreCancel(scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.SpawnBgNamed("transport", func() error { return transport.Run(ctx) }) + if err := transport.WaitForStart(ctx); err != nil { + return err + } + t.Logf("The first two connections should be accepted just fine.") + + a1, a2, err := connect(ctx, transport) + if err != nil { + return fmt.Errorf("1st connect(): %w", err) + } + defer a1.Close() + defer a2.Close() + + b1, b2, err := connect(ctx, transport) + if err != nil { + return fmt.Errorf("2nd connect(): %w", err) + } + defer b1.Close() + defer b2.Close() + + t.Logf("The third connection will be dialed successfully, but the accept should not go through.") + c1, err := transport.Dial(ctx, transport.Endpoint()) + if err != nil { + return fmt.Errorf("3rd Dial(): %w", err) + } + defer c1.Close() + if err := utils.WithTimeout(ctx, time.Second, func(ctx context.Context) error { + c2, err := transport.Accept(ctx) + if err == nil { + c2.Close() } - acceptCh <- conn + return err + }); !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("Accept() over cap: %v, want %v", err, context.DeadlineExceeded) } - }() - // The first two connections should be accepted just fine. - dial1, err := transport.Dial(ctx, endpoint) - require.NoError(t, err) - defer dial1.Close() - accept1 := <-acceptCh - defer accept1.Close() - require.Equal(t, dial1.LocalEndpoint(), accept1.RemoteEndpoint()) - - dial2, err := transport.Dial(ctx, endpoint) - require.NoError(t, err) - defer dial2.Close() - accept2 := <-acceptCh - defer accept2.Close() - require.Equal(t, dial2.LocalEndpoint(), accept2.RemoteEndpoint()) - - // The third connection will be dialed successfully, but the accept should - // not go through. 
- dial3, err := transport.Dial(ctx, endpoint) - require.NoError(t, err) - defer dial3.Close() - select { - case <-acceptCh: - require.Fail(t, "unexpected accept") - case <-time.After(time.Second): + t.Logf("once either of the other connections are closed, the accept goes through.") + a1.Close() + a2.Close() // we close both a1 and a2 to make sure the connection count drops below the limit. + c2, err := transport.Accept(ctx) + if err != nil { + return fmt.Errorf("3rd Accept(): %w", err) + } + defer c2.Close() + return nil + })) + if err != nil { + t.Fatal(err) } - - // However, once either of the other connections are closed, the accept - // goes through. - require.NoError(t, accept1.Close()) - accept3 := <-acceptCh - defer accept3.Close() - require.Equal(t, dial3.LocalEndpoint(), accept3.RemoteEndpoint()) } func TestMConnTransport_Listen(t *testing.T) { - ctx := t.Context() + reservePort := func(ip netip.Addr) netip.AddrPort { + addr := tcp.TestReserveAddr() + return netip.AddrPortFrom(ip, addr.Port()) + } testcases := []struct { - endpoint *p2p.Endpoint + endpoint p2p.Endpoint ok bool }{ // Valid v4 and v6 addresses, with mconn and tcp protocols. 
- {&p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4zero}, true}, - {&p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4(127, 0, 0, 1)}, true}, - {&p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv6zero}, true}, - {&p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv6loopback}, true}, - {&p2p.Endpoint{Protocol: p2p.TCPProtocol, IP: net.IPv4zero}, true}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Addr: reservePort(netip.IPv4Unspecified())}, true}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Addr: reservePort(tcp.IPv4Loopback())}, true}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Addr: reservePort(netip.IPv6Unspecified())}, true}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Addr: reservePort(netip.IPv6Loopback())}, true}, + {p2p.Endpoint{Protocol: p2p.TCPProtocol, Addr: reservePort(netip.IPv4Unspecified())}, true}, // Invalid endpoints. - {&p2p.Endpoint{}, false}, - {&p2p.Endpoint{Protocol: p2p.MConnProtocol, Path: "foo"}, false}, - {&p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4zero, Path: "foo"}, false}, + {p2p.Endpoint{}, false}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Path: "foo"}, false}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Addr: reservePort(netip.IPv4Unspecified()), Path: "foo"}, false}, } for _, tc := range testcases { t.Run(tc.endpoint.String(), func(t *testing.T) { + ctx := t.Context() t.Cleanup(leaktest.Check(t)) transport := p2p.NewMConnTransport( log.NewNopLogger(), + tc.endpoint, conn.DefaultMConnConfig(), []*p2p.ChannelDescriptor{{ID: chID, Priority: 1}}, p2p.MConnTransportOptions{}, ) + if got, want := transport.Endpoint(), tc.endpoint; got != want { + t.Fatalf("transport.Endpoint() = %v, want %v", got, want) + } - // Transport should not listen on any endpoints yet. - endpoint, err := transport.Endpoint() - require.Error(t, err) - require.Nil(t, endpoint) - - // Start listening, and check any expected errors. 
- err = transport.Listen(tc.endpoint) + err := utils.IgnoreCancel(scope.Run(ctx, func(ctx context.Context, s scope.Scope) error { + s.SpawnBgNamed("transport", func() error { return transport.Run(ctx) }) + if err := transport.WaitForStart(ctx); err != nil { + return err + } + s.SpawnNamed("dial", func() error { + conn, err := transport.Dial(ctx, tc.endpoint) + if err != nil { + return fmt.Errorf("transport.Dial(): %w", err) + } + if err := conn.Close(); err != nil { + return fmt.Errorf("conn.Close(): %w", err) + } + if _, _, err := conn.ReceiveMessage(ctx); !errors.Is(err, io.EOF) { + return fmt.Errorf("conn.ReceiveMessage() = %v, want %v", err, io.EOF) + } + return nil + }) + s.SpawnNamed("accept", func() error { + conn, err := transport.Accept(ctx) + if err != nil { + return fmt.Errorf("transport.Accept(): %w", err) + } + if err := conn.Close(); err != nil { + return fmt.Errorf("conn.Close(): %w", err) + } + if _, _, err := conn.ReceiveMessage(ctx); !errors.Is(err, io.EOF) { + return fmt.Errorf("conn.ReceiveMessage() = %v, want %v", err, io.EOF) + } + return nil + }) + return nil + })) if !tc.ok { - require.Error(t, err) - return + var want p2p.InvalidEndpointErr + if !errors.As(err, &want) { + t.Fatalf("error = %v, want %T", err, want) + } + } else if err != nil { + t.Fatal(err) } - require.NoError(t, err) - - // Check the endpoint. - endpoint, err = transport.Endpoint() - require.NoError(t, err) - require.NotNil(t, endpoint) - - require.Equal(t, p2p.MConnProtocol, endpoint.Protocol) - if tc.endpoint.IP.IsUnspecified() { - require.True(t, endpoint.IP.IsUnspecified(), - "expected unspecified IP, got %v", endpoint.IP) - } else { - require.True(t, tc.endpoint.IP.Equal(endpoint.IP), - "expected %v, got %v", tc.endpoint.IP, endpoint.IP) - } - require.NotZero(t, endpoint.Port) - require.Empty(t, endpoint.Path) - - dialedChan := make(chan struct{}) - - var peerConn p2p.Connection - go func() { - // Dialing the endpoint should work. 
- var err error - ctx := t.Context() - - peerConn, err = transport.Dial(ctx, endpoint) - require.NoError(t, err) - close(dialedChan) - }() - - conn, err := transport.Accept(ctx) - require.NoError(t, err) - _ = conn.Close() - <-dialedChan - - // closing the connection should not error - require.NoError(t, peerConn.Close()) - - // try to read from the connection should error - _, _, err = peerConn.ReceiveMessage(ctx) - require.Error(t, err) - - // Trying to listen again should error. - err = transport.Listen(tc.endpoint) - require.Error(t, err) - - // close the transport - _ = transport.Close() - // Dialing the closed endpoint should error - _, err = transport.Dial(ctx, endpoint) + _, err = transport.Dial(ctx, tc.endpoint) require.Error(t, err) }) } diff --git a/internal/p2p/transport_memory.go b/internal/p2p/transport_memory.go index 3eb4c5b51..f4fa08bf5 100644 --- a/internal/p2p/transport_memory.go +++ b/internal/p2p/transport_memory.go @@ -5,11 +5,12 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "sync" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/libs/utils" "github.com/tendermint/tendermint/types" ) @@ -61,22 +62,6 @@ func (n *MemoryNetwork) GetTransport(id types.NodeID) *MemoryTransport { return n.transports[id] } -// RemoveTransport removes a transport from the network and closes it. -func (n *MemoryNetwork) RemoveTransport(id types.NodeID) { - n.mtx.Lock() - t, ok := n.transports[id] - delete(n.transports, id) - n.mtx.Unlock() - - if ok { - // Close may recursively call RemoveTransport() again, but this is safe - // because we've already removed the transport from the map above. - if err := t.Close(); err != nil { - n.logger.Error("failed to close memory transport", "id", id, "err", err) - } - } -} - // Size returns the number of transports in the network. 
func (n *MemoryNetwork) Size() int { return len(n.transports) @@ -101,16 +86,12 @@ type MemoryTransport struct { // newMemoryTransport creates a new MemoryTransport. This is for internal use by // MemoryNetwork, use MemoryNetwork.CreateTransport() instead. func newMemoryTransport(network *MemoryNetwork, nodeID types.NodeID) *MemoryTransport { - once := &sync.Once{} - closeCh := make(chan struct{}) return &MemoryTransport{ logger: network.logger.With("local", nodeID), network: network, nodeID: nodeID, bufferSize: network.bufferSize, acceptCh: make(chan *MemoryConnection), - closeCh: closeCh, - closeFn: func() { once.Do(func() { close(closeCh) }) }, } } @@ -119,7 +100,13 @@ func (t *MemoryTransport) String() string { return string(MemoryProtocol) } -func (*MemoryTransport) Listen(*Endpoint) error { return nil } +func (t *MemoryTransport) Run(ctx context.Context) error { + <-ctx.Done() + t.network.mtx.Lock() + delete(t.network.transports, t.nodeID) + t.network.mtx.Unlock() + return nil +} func (t *MemoryTransport) AddChannelDescriptors([]*ChannelDescriptor) {} @@ -129,36 +116,23 @@ func (t *MemoryTransport) Protocols() []Protocol { } // Endpoints implements Transport. -func (t *MemoryTransport) Endpoint() (*Endpoint, error) { - if n := t.network.GetTransport(t.nodeID); n == nil { - return nil, errors.New("node not defined") - } - - return &Endpoint{ +func (t *MemoryTransport) Endpoint() Endpoint { + return Endpoint{ Protocol: MemoryProtocol, Path: string(t.nodeID), // An arbitrary IP and port is used in order for the pex // reactor to be able to send addresses to one another. - IP: net.IPv4zero, - Port: 0, - }, nil + Addr: netip.AddrPort{}, + } } // Accept implements Transport. 
func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) { - select { - case <-t.closeCh: - return nil, io.EOF - case conn := <-t.acceptCh: - t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path) - return conn, nil - case <-ctx.Done(): - return nil, io.EOF - } + return utils.Recv(ctx, t.acceptCh) } // Dial implements Transport. -func (t *MemoryTransport) Dial(ctx context.Context, endpoint *Endpoint) (Connection, error) { +func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) { if endpoint.Protocol != MemoryProtocol { return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol) } @@ -194,19 +168,10 @@ func (t *MemoryTransport) Dial(ctx context.Context, endpoint *Endpoint) (Connect inConn.closeCh = closeCh inConn.closeFn = closeFn - select { - case peer.acceptCh <- inConn: - return outConn, nil - case <-ctx.Done(): - return nil, io.EOF + if err := utils.Send(ctx, peer.acceptCh, inConn); err != nil { + return nil, err } -} - -// Close implements Transport. -func (t *MemoryTransport) Close() error { - t.network.RemoveTransport(t.nodeID) - t.closeFn() - return nil + return outConn, nil } // MemoryConnection is an in-memory connection between two transport endpoints. diff --git a/internal/p2p/transport_memory_test.go b/internal/p2p/transport_memory_test.go index 33d96cdb8..334101bd8 100644 --- a/internal/p2p/transport_memory_test.go +++ b/internal/p2p/transport_memory_test.go @@ -2,10 +2,8 @@ package p2p_test import ( "bytes" + "context" "encoding/hex" - "testing" - - "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/internal/p2p" "github.com/tendermint/tendermint/libs/log" @@ -15,22 +13,21 @@ import ( // Transports are mainly tested by common tests in transport_test.go, we // register a transport factory here to get included in those tests. 
func init() { - var network *p2p.MemoryNetwork // shared by transports in the same test - - testTransports["memory"] = func(t *testing.T) p2p.Transport { - if network == nil { - network = p2p.NewMemoryNetwork(log.NewNopLogger(), 1) + testTransports["memory"] = func() func(context.Context) p2p.Transport { + network := p2p.NewMemoryNetwork(log.NewNopLogger(), 1) + return func(ctx context.Context) p2p.Transport { + i := byte(network.Size()) + nodeID, err := types.NewNodeID(hex.EncodeToString(bytes.Repeat([]byte{i<<4 + i}, 20))) + if err != nil { + panic(err) + } + t := network.CreateTransport(nodeID) + go func() { + if err := t.Run(ctx); err != nil { + panic(err) + } + }() + return t } - i := byte(network.Size()) - nodeID, err := types.NewNodeID(hex.EncodeToString(bytes.Repeat([]byte{i<<4 + i}, 20))) - require.NoError(t, err) - transport := network.CreateTransport(nodeID) - - t.Cleanup(func() { - require.NoError(t, transport.Close()) - network = nil // set up a new memory network for the next test - }) - - return transport } } diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index ccb783f1d..a088d0a4a 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -3,7 +3,7 @@ package p2p_test import ( "context" "io" - "net" + "net/netip" "testing" "time" @@ -18,71 +18,39 @@ import ( ) // transportFactory is used to set up transports for tests. -type transportFactory func(t *testing.T) p2p.Transport +type transportFactory = func(ctx context.Context) p2p.Transport // testTransports is a registry of transport factories for withTransports(). -var testTransports = map[string]transportFactory{} +var testTransports = map[string](func() transportFactory){} // withTransports is a test helper that runs a test against all transports // registered in testTransports. 
func withTransports(t *testing.T, tester func(*testing.T, transportFactory)) { t.Helper() for name, transportFactory := range testTransports { - transportFactory := transportFactory t.Run(name, func(t *testing.T) { t.Cleanup(leaktest.Check(t)) - tester(t, transportFactory) + tester(t, transportFactory()) }) } } -func TestTransport_AcceptClose(t *testing.T) { - // Just test accept unblock on close, happy path is tested widely elsewhere. - withTransports(t, func(t *testing.T, makeTransport transportFactory) { - ctx := t.Context() - a := makeTransport(t) - opctx, opcancel := context.WithTimeout(ctx, 200*time.Millisecond) - defer opcancel() - - _, err := a.Accept(opctx) - require.Error(t, err) - require.Equal(t, io.EOF, err) - - <-opctx.Done() - _ = a.Close() - - // Closed transport should return error immediately, - // because the transport is closed. We use the base - // context (ctx) rather than the operation context - // (opctx) because using the later would mean this - // could error because the context was canceled. - _, err = a.Accept(ctx) - require.Error(t, err) - require.Equal(t, io.EOF, err) - }) -} - func TestTransport_DialEndpoints(t *testing.T) { ipTestCases := []struct { - ip net.IP + ip netip.Addr ok bool }{ - {net.IPv4zero, true}, - {net.IPv6zero, true}, - - {nil, false}, - {net.IPv4bcast, false}, - {net.IPv4allsys, false}, - {[]byte{1, 2, 3}, false}, - {[]byte{1, 2, 3, 4, 5}, false}, + {netip.IPv4Unspecified(), true}, + {netip.IPv6Unspecified(), true}, + + {netip.AddrFrom4([4]byte{255, 255, 255, 255}), false}, + {netip.AddrFrom4([4]byte{224, 0, 0, 1}), false}, } withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - endpoint, err := a.Endpoint() - require.NoError(t, err) - require.NotNil(t, endpoint) + a := makeTransport(ctx) + endpoint := a.Endpoint() // Spawn a goroutine to simply accept any connections until closed. 
go func() { @@ -101,28 +69,27 @@ func TestTransport_DialEndpoints(t *testing.T) { require.NoError(t, conn.Close()) // Dialing empty endpoint should error. - _, err = a.Dial(ctx, &p2p.Endpoint{}) + _, err = a.Dial(ctx, p2p.Endpoint{}) require.Error(t, err) // Dialing without protocol should error. - noProtocol := *endpoint + noProtocol := endpoint noProtocol.Protocol = "" - _, err = a.Dial(ctx, &noProtocol) + _, err = a.Dial(ctx, noProtocol) require.Error(t, err) // Dialing with invalid protocol should error. - fooProtocol := *endpoint + fooProtocol := endpoint fooProtocol.Protocol = "foo" - _, err = a.Dial(ctx, &fooProtocol) + _, err = a.Dial(ctx, fooProtocol) require.Error(t, err) // Tests for networked endpoints (with IP). - if len(endpoint.IP) > 0 && endpoint.Protocol != p2p.MemoryProtocol { + if endpoint.Addr != (netip.AddrPort{}) && endpoint.Protocol != p2p.MemoryProtocol { for _, tc := range ipTestCases { t.Run(tc.ip.String(), func(t *testing.T) { e := endpoint - require.NotNil(t, e) - e.IP = tc.ip + e.Addr = netip.AddrPortFrom(tc.ip, endpoint.Addr.Port()) conn, err := a.Dial(ctx, e) if tc.ok { require.NoError(t, err) @@ -135,8 +102,7 @@ func TestTransport_DialEndpoints(t *testing.T) { // Non-networked endpoints should error. noIP := endpoint - noIP.IP = nil - noIP.Port = 0 + noIP.Addr = netip.AddrPort{} noIP.Path = "foo" _, err := a.Dial(ctx, noIP) require.Error(t, err) @@ -151,95 +117,37 @@ func TestTransport_DialEndpoints(t *testing.T) { }) } -func TestTransport_Dial(t *testing.T) { - // Most just tests dial failures, happy path is tested widely elsewhere. - withTransports(t, func(t *testing.T, makeTransport transportFactory) { - ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) - - aEndpoint, err := a.Endpoint() - require.NoError(t, err) - require.NotNil(t, aEndpoint) - bEndpoint, err := b.Endpoint() - require.NoError(t, err) - require.NotNil(t, bEndpoint) - - // Context cancellation should error. 
We can't test timeouts since we'd - // need a non-responsive endpoint. - cancelCtx, cancel := context.WithCancel(ctx) - cancel() - _, err = a.Dial(cancelCtx, bEndpoint) - require.Error(t, err) - - // Unavailable endpoint should error. - err = b.Close() - require.NoError(t, err) - _, err = a.Dial(ctx, bEndpoint) - require.Error(t, err) - - // Dialing from a closed transport should still work. - errCh := make(chan error, 1) - go func() { - conn, err := a.Accept(ctx) - if err == nil { - _ = conn.Close() - } - errCh <- err - }() - conn, err := b.Dial(ctx, aEndpoint) - require.NoError(t, err) - require.NoError(t, conn.Close()) - require.NoError(t, <-errCh) - }) -} - func TestTransport_Endpoints(t *testing.T) { - withTransports(t, func(t *testing.T, makeTransport transportFactory) { - a := makeTransport(t) - b := makeTransport(t) + ctx := t.Context() + a := makeTransport(ctx) + b := makeTransport(ctx) // Both transports return valid and different endpoints. - aEndpoint, err := a.Endpoint() - require.NoError(t, err) - require.NotNil(t, aEndpoint) - bEndpoint, err := b.Endpoint() - require.NoError(t, err) - require.NotNil(t, bEndpoint) + aEndpoint := a.Endpoint() + bEndpoint := b.Endpoint() require.NotEqual(t, aEndpoint, bEndpoint) - for _, endpoint := range []*p2p.Endpoint{aEndpoint, bEndpoint} { + for _, endpoint := range []p2p.Endpoint{aEndpoint, bEndpoint} { err := endpoint.Validate() require.NoError(t, err, "invalid endpoint %q", endpoint) } - - // When closed, the transport should no longer return any endpoints. 
- require.NoError(t, a.Close()) - aEndpoint, err = a.Endpoint() - require.Error(t, err) - require.Nil(t, aEndpoint) - bEndpoint, err = b.Endpoint() - require.NoError(t, err) - require.NotNil(t, bEndpoint) }) } func TestTransport_Protocols(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { - a := makeTransport(t) + ctx := t.Context() + a := makeTransport(ctx) protocols := a.Protocols() - endpoint, err := a.Endpoint() - require.NoError(t, err) + endpoint := a.Endpoint() require.NotEmpty(t, protocols) - require.NotNil(t, endpoint) - require.Contains(t, protocols, endpoint.Protocol) }) } func TestTransport_String(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { - a := makeTransport(t) + a := makeTransport(t.Context()) require.NotEmpty(t, a.String()) }) } @@ -247,8 +155,8 @@ func TestTransport_String(t *testing.T) { func TestConnection_Handshake(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) ab, ba := dialAccept(ctx, t, a, b) // A handshake should pass the given keys and NodeInfo. @@ -299,8 +207,8 @@ func TestConnection_Handshake(t *testing.T) { func TestConnection_HandshakeCancel(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) // Handshake should error on context cancellation. 
ab, ba := dialAccept(ctx, t, a, b) @@ -327,8 +235,8 @@ func TestConnection_HandshakeCancel(t *testing.T) { func TestConnection_FlushClose(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) ab, _ := dialAcceptHandshake(ctx, t, a, b) err := ab.Close() @@ -346,8 +254,8 @@ func TestConnection_FlushClose(t *testing.T) { func TestConnection_LocalRemoteEndpoint(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) ab, ba := dialAcceptHandshake(ctx, t, a, b) // Local and remote connection endpoints correspond to each other. @@ -362,8 +270,8 @@ func TestConnection_SendReceive(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) ab, ba := dialAcceptHandshake(ctx, t, a, b) // Can send and receive a to b. @@ -383,19 +291,6 @@ func TestConnection_SendReceive(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte("bar"), msg) - // Connections should still be active after closing the transports. - err = a.Close() - require.NoError(t, err) - err = b.Close() - require.NoError(t, err) - - err = ab.SendMessage(ctx, chID, []byte("still here")) - require.NoError(t, err) - ch, msg, err = ba.ReceiveMessage(ctx) - require.NoError(t, err) - require.Equal(t, chID, ch) - require.Equal(t, []byte("still here"), msg) - // Close one side of the connection. Both sides should then error // with io.EOF when trying to send or receive. 
err = ba.Close() @@ -421,8 +316,8 @@ func TestConnection_SendReceive(t *testing.T) { func TestConnection_String(t *testing.T) { withTransports(t, func(t *testing.T, makeTransport transportFactory) { ctx := t.Context() - a := makeTransport(t) - b := makeTransport(t) + a := makeTransport(ctx) + b := makeTransport(ctx) ab, _ := dialAccept(ctx, t, a, b) require.NotEmpty(t, ab.String()) }) @@ -430,10 +325,9 @@ func TestConnection_String(t *testing.T) { func TestEndpoint_NodeAddress(t *testing.T) { var ( - ip4 = []byte{1, 2, 3, 4} - ip4in6 = net.IPv4(1, 2, 3, 4) - ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01} - id = types.NodeID("00112233445566778899aabbccddeeff00112233") + ip4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) + ip6 = netip.AddrFrom16([16]byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) + id = types.NodeID("00112233445566778899aabbccddeeff00112233") ) testcases := []struct { @@ -442,15 +336,11 @@ func TestEndpoint_NodeAddress(t *testing.T) { }{ // Valid endpoints. { - p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"}, - p2p.NodeAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"}, - }, - { - p2p.Endpoint{Protocol: "tcp", IP: ip4in6, Port: 8080, Path: "path"}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 8080), Path: "path"}, p2p.NodeAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"}, }, { - p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "path"}, + p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 8080), Path: "path"}, p2p.NodeAddress{Protocol: "tcp", Hostname: "b10c::1", Port: 8080, Path: "path"}, }, { @@ -465,8 +355,7 @@ func TestEndpoint_NodeAddress(t *testing.T) { // Partial (invalid) endpoints. 
{p2p.Endpoint{}, p2p.NodeAddress{}}, {p2p.Endpoint{Protocol: "tcp"}, p2p.NodeAddress{Protocol: "tcp"}}, - {p2p.Endpoint{IP: net.IPv4(1, 2, 3, 4)}, p2p.NodeAddress{Hostname: "1.2.3.4"}}, - {p2p.Endpoint{Port: 8080}, p2p.NodeAddress{}}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4, 0)}, p2p.NodeAddress{Hostname: "1.2.3.4"}}, {p2p.Endpoint{Path: "path"}, p2p.NodeAddress{Path: "path"}}, } for _, tc := range testcases { @@ -484,9 +373,8 @@ func TestEndpoint_NodeAddress(t *testing.T) { func TestEndpoint_String(t *testing.T) { var ( - ip4 = []byte{1, 2, 3, 4} - ip4in6 = net.IPv4(1, 2, 3, 4) - ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01} + ip4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) + ip6 = netip.AddrFrom16([16]byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) nodeID = types.NodeID("00112233445566778899aabbccddeeff00112233") ) @@ -500,24 +388,23 @@ func TestEndpoint_String(t *testing.T) { {p2p.Endpoint{Protocol: "file", Path: "👋"}, "file:///%F0%9F%91%8B"}, // IPv4 endpoints. - {p2p.Endpoint{Protocol: "tcp", IP: ip4}, "tcp://1.2.3.4"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, "tcp://1.2.3.4"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080}, "tcp://1.2.3.4:8080"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "/path"}, "tcp://1.2.3.4:8080/path"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4, Path: "path/👋"}, "tcp://1.2.3.4/path/%F0%9F%91%8B"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 0)}, "tcp://1.2.3.4"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 8080)}, "tcp://1.2.3.4:8080"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 8080), Path: "/path"}, "tcp://1.2.3.4:8080/path"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 0), Path: "path/👋"}, "tcp://1.2.3.4/path/%F0%9F%91%8B"}, // IPv6 endpoints. 
- {p2p.Endpoint{Protocol: "tcp", IP: ip6}, "tcp://b10c::1"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080}, "tcp://[b10c::1]:8080"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "/path"}, "tcp://[b10c::1]:8080/path"}, - {p2p.Endpoint{Protocol: "tcp", IP: ip6, Path: "path/👋"}, "tcp://b10c::1/path/%F0%9F%91%8B"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 0)}, "tcp://b10c::1"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 8080)}, "tcp://[b10c::1]:8080"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 8080), Path: "/path"}, "tcp://[b10c::1]:8080/path"}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 0), Path: "path/👋"}, "tcp://b10c::1/path/%F0%9F%91%8B"}, // Partial (invalid) endpoints. {p2p.Endpoint{}, ""}, {p2p.Endpoint{Protocol: "tcp"}, "tcp:"}, - {p2p.Endpoint{IP: []byte{1, 2, 3, 4}}, "1.2.3.4"}, - {p2p.Endpoint{IP: []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}}, "b10c::1"}, - {p2p.Endpoint{Port: 8080}, ""}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4, 0)}, "1.2.3.4"}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip6, 0)}, "b10c::1"}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(netip.IPv4Unspecified(), 8080)}, "0.0.0.0:8080"}, {p2p.Endpoint{Path: "foo"}, "/foo"}, } for _, tc := range testcases { @@ -528,30 +415,24 @@ func TestEndpoint_String(t *testing.T) { } func TestEndpoint_Validate(t *testing.T) { - var ( - ip4 = []byte{1, 2, 3, 4} - ip4in6 = net.IPv4(1, 2, 3, 4) - ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01} - ) + ip4 := netip.AddrFrom4([4]byte{1, 2, 3, 4}) + ip6 := netip.AddrFrom16([16]byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) testcases := []struct { endpoint p2p.Endpoint expectValid bool }{ // Valid endpoints. 
- {p2p.Endpoint{Protocol: "tcp", IP: ip4}, true}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, true}, - {p2p.Endpoint{Protocol: "tcp", IP: ip6}, true}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8008}, true}, - {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 0)}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip6, 0)}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 8008)}, true}, + {p2p.Endpoint{Protocol: "tcp", Addr: netip.AddrPortFrom(ip4, 8080), Path: "path"}, true}, {p2p.Endpoint{Protocol: "memory", Path: "path"}, true}, // Invalid endpoints. {p2p.Endpoint{}, false}, - {p2p.Endpoint{IP: ip4}, false}, + {p2p.Endpoint{Addr: netip.AddrPortFrom(ip4, 0)}, false}, {p2p.Endpoint{Protocol: "tcp"}, false}, - {p2p.Endpoint{Protocol: "tcp", IP: []byte{1, 2, 3}}, false}, - {p2p.Endpoint{Protocol: "tcp", Port: 8080, Path: "path"}, false}, } for _, tc := range testcases { t.Run(tc.endpoint.String(), func(t *testing.T) { @@ -570,9 +451,7 @@ func TestEndpoint_Validate(t *testing.T) { func dialAccept(ctx context.Context, t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) { t.Helper() - endpoint, err := b.Endpoint() - require.NoError(t, err) - require.NotNil(t, endpoint, "peer not listening on any endpoints") + endpoint := b.Endpoint() ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index a27fdac89..b61cb31ac 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -42,22 +42,22 @@ type reactorTestSuite struct { stateProvider *mocks.StateProvider snapshotChannel *p2p.Channel - snapshotInCh chan p2p.Envelope + snapshotInCh *p2p.Queue snapshotOutCh chan p2p.Envelope snapshotPeerErrCh chan p2p.PeerError chunkChannel *p2p.Channel - chunkInCh chan p2p.Envelope + chunkInCh *p2p.Queue chunkOutCh 
chan p2p.Envelope chunkPeerErrCh chan p2p.PeerError blockChannel *p2p.Channel - blockInCh chan p2p.Envelope + blockInCh *p2p.Queue blockOutCh chan p2p.Envelope blockPeerErrCh chan p2p.PeerError paramsChannel *p2p.Channel - paramsInCh chan p2p.Envelope + paramsInCh *p2p.Queue paramsOutCh chan p2p.Envelope paramsPeerErrCh chan p2p.PeerError @@ -73,7 +73,7 @@ func setup( t *testing.T, conn *clientmocks.Client, stateProvider *mocks.StateProvider, - chBuf uint, + chBuf int, ) *reactorTestSuite { t.Helper() @@ -82,16 +82,16 @@ func setup( } rts := &reactorTestSuite{ - snapshotInCh: make(chan p2p.Envelope, chBuf), + snapshotInCh: p2p.NewQueue(chBuf), snapshotOutCh: make(chan p2p.Envelope, chBuf), snapshotPeerErrCh: make(chan p2p.PeerError, chBuf), - chunkInCh: make(chan p2p.Envelope, chBuf), + chunkInCh: p2p.NewQueue(chBuf), chunkOutCh: make(chan p2p.Envelope, chBuf), chunkPeerErrCh: make(chan p2p.PeerError, chBuf), - blockInCh: make(chan p2p.Envelope, chBuf), + blockInCh: p2p.NewQueue(chBuf), blockOutCh: make(chan p2p.Envelope, chBuf), blockPeerErrCh: make(chan p2p.PeerError, chBuf), - paramsInCh: make(chan p2p.Envelope, chBuf), + paramsInCh: p2p.NewQueue(chBuf), paramsOutCh: make(chan p2p.Envelope, chBuf), paramsPeerErrCh: make(chan p2p.PeerError, chBuf), conn: conn, @@ -242,11 +242,11 @@ func TestReactor_ChunkRequest_InvalidRequest(t *testing.T) { rts := setup(ctx, t, nil, nil, 2) - rts.chunkInCh <- p2p.Envelope{ + rts.chunkInCh.Send(p2p.Envelope{ From: types.NodeID("aa"), ChannelID: ChunkChannel, Message: &ssproto.SnapshotsRequest{}, - } + }, 0) response := <-rts.chunkPeerErrCh require.Error(t, response.Err) @@ -297,11 +297,11 @@ func TestReactor_ChunkRequest(t *testing.T) { rts := setup(ctx, t, conn, nil, 2) - rts.chunkInCh <- p2p.Envelope{ + rts.chunkInCh.Send(p2p.Envelope{ From: types.NodeID("aa"), ChannelID: ChunkChannel, Message: tc.request, - } + }, 0) response := <-rts.chunkOutCh require.Equal(t, tc.expectResponse, response.Message) @@ -317,11 +317,11 @@ func 
TestReactor_SnapshotsRequest_InvalidRequest(t *testing.T) { rts := setup(ctx, t, nil, nil, 2) - rts.snapshotInCh <- p2p.Envelope{ + rts.snapshotInCh.Send(p2p.Envelope{ From: types.NodeID("aa"), ChannelID: SnapshotChannel, Message: &ssproto.ChunkRequest{}, - } + }, 0) response := <-rts.snapshotPeerErrCh require.Error(t, response.Err) @@ -377,11 +377,11 @@ func TestReactor_SnapshotsRequest(t *testing.T) { rts := setup(ctx, t, conn, nil, 100) - rts.snapshotInCh <- p2p.Envelope{ + rts.snapshotInCh.Send(p2p.Envelope{ From: types.NodeID("aa"), ChannelID: SnapshotChannel, Message: &ssproto.SnapshotsRequest{}, - } + }, 0) if len(tc.expectResponses) > 0 { retryUntil(ctx, t, func() bool { return len(rts.snapshotOutCh) == len(tc.expectResponses) }, time.Second) @@ -434,13 +434,13 @@ func TestReactor_LightBlockResponse(t *testing.T) { rts.stateStore.On("LoadValidators", height).Return(vals, nil) - rts.blockInCh <- p2p.Envelope{ + rts.blockInCh.Send(p2p.Envelope{ From: types.NodeID("aa"), ChannelID: LightBlockChannel, Message: &ssproto.LightBlockRequest{ Height: 10, }, - } + }, 0) require.Empty(t, rts.blockPeerErrCh) select { @@ -622,7 +622,6 @@ func TestReactor_Backfill(t *testing.T) { // test backfill algorithm with varying failure rates [0, 10] failureRates := []int{0, 2, 9} for _, failureRate := range failureRates { - failureRate := failureRate t.Run(fmt.Sprintf("failure rate: %d", failureRate), func(t *testing.T) { ctx := t.Context() t.Cleanup(leaktest.CheckTimeout(t, 1*time.Minute)) @@ -718,7 +717,7 @@ func handleLightBlockRequests( t *testing.T, chain map[int64]*types.LightBlock, receiving chan p2p.Envelope, - sending chan p2p.Envelope, + sending *p2p.Queue, close chan struct{}, failureRate int) { requests := 0 @@ -732,17 +731,13 @@ func handleLightBlockRequests( if requests%10 >= failureRate { lb, err := chain[int64(msg.Height)].ToProto() require.NoError(t, err) - select { - case sending <- p2p.Envelope{ + sending.Send(p2p.Envelope{ From: envelope.To, ChannelID: 
LightBlockChannel, Message: &ssproto.LightBlockResponse{ LightBlock: lb, }, - }: - case <-ctx.Done(): - return - } + }, 0) } else { switch errorCount % 3 { case 0: // send a different block @@ -750,29 +745,21 @@ func handleLightBlockRequests( _, _, lb := mockLB(ctx, t, int64(msg.Height), factory.DefaultTestTime, factory.MakeBlockID(), vals, pv) differntLB, err := lb.ToProto() require.NoError(t, err) - select { - case sending <- p2p.Envelope{ + sending.Send(p2p.Envelope{ From: envelope.To, ChannelID: LightBlockChannel, Message: &ssproto.LightBlockResponse{ LightBlock: differntLB, }, - }: - case <-ctx.Done(): - return - } + }, 0) case 1: // send nil block i.e. pretend we don't have it - select { - case sending <- p2p.Envelope{ + sending.Send(p2p.Envelope{ From: envelope.To, ChannelID: LightBlockChannel, Message: &ssproto.LightBlockResponse{ LightBlock: nil, }, - }: - case <-ctx.Done(): - return - } + }, 0) case 2: // don't do anything } errorCount++ @@ -788,7 +775,8 @@ func handleLightBlockRequests( func handleConsensusParamsRequest( ctx context.Context, t *testing.T, - receiving, sending chan p2p.Envelope, + receiving chan p2p.Envelope, + sending *p2p.Queue, closeCh chan struct{}, ) { t.Helper() @@ -804,21 +792,14 @@ func handleConsensusParamsRequest( t.Errorf("message was %T which is not a params request", envelope.Message) return } - select { - case sending <- p2p.Envelope{ + sending.Send(p2p.Envelope{ From: envelope.To, ChannelID: ParamsChannel, Message: &ssproto.ParamsResponse{ Height: msg.Height, ConsensusParams: paramsProto, }, - }: - case <-ctx.Done(): - return - case <-closeCh: - return - } - + }, 0) case <-closeCh: return } @@ -902,7 +883,7 @@ func handleSnapshotRequests( ctx context.Context, t *testing.T, receivingCh chan p2p.Envelope, - sendingCh chan p2p.Envelope, + sendingCh *p2p.Queue, closeCh chan struct{}, snapshots []snapshot, ) { @@ -917,7 +898,7 @@ func handleSnapshotRequests( _, ok := envelope.Message.(*ssproto.SnapshotsRequest) require.True(t, 
ok) for _, snapshot := range snapshots { - sendingCh <- p2p.Envelope{ + sendingCh.Send(p2p.Envelope{ From: envelope.To, ChannelID: SnapshotChannel, Message: &ssproto.SnapshotsResponse{ @@ -927,7 +908,7 @@ func handleSnapshotRequests( Hash: snapshot.Hash, Metadata: snapshot.Metadata, }, - } + }, 0) } } } @@ -937,7 +918,7 @@ func handleChunkRequests( ctx context.Context, t *testing.T, receivingCh chan p2p.Envelope, - sendingCh chan p2p.Envelope, + sendingCh *p2p.Queue, closeCh chan struct{}, chunk []byte, ) { @@ -951,7 +932,7 @@ func handleChunkRequests( case envelope := <-receivingCh: msg, ok := envelope.Message.(*ssproto.ChunkRequest) require.True(t, ok) - sendingCh <- p2p.Envelope{ + sendingCh.Send(p2p.Envelope{ From: envelope.To, ChannelID: ChunkChannel, Message: &ssproto.ChunkResponse{ @@ -961,7 +942,7 @@ func handleChunkRequests( Chunk: chunk, Missing: false, }, - } + }, 0) } } diff --git a/libs/service/service.go b/libs/service/service.go index 685b267c1..c96b32046 100644 --- a/libs/service/service.go +++ b/libs/service/service.go @@ -2,6 +2,7 @@ package service import ( "context" + "fmt" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/libs/utils" "sync" @@ -164,6 +165,21 @@ func (bs *BaseService) Spawn(name string, task func(ctx context.Context) error) }() } +func (bs *BaseService) SpawnCritical(name string, task func(ctx context.Context) error) { + inner := bs.inner.Load() + if inner == nil { + panic("service is not started yet") + } + + inner.wg.Add(1) + go func() { + defer inner.wg.Done() + if err := utils.IgnoreCancel(task(inner.ctx)); err != nil { + panic(fmt.Sprintf("critical task failed: name=%v, service=%v: %v", name, bs.name, err)) + } + }() +} + // IsRunning implements Service by returning true or false depending on the // service's state. 
func (bs *BaseService) IsRunning() bool { diff --git a/libs/utils/mutex.go b/libs/utils/mutex.go index b6f4a9a58..0bd9bb722 100644 --- a/libs/utils/mutex.go +++ b/libs/utils/mutex.go @@ -33,6 +33,41 @@ func (m *Mutex[T]) Lock() iter.Seq[T] { } } +// RWMutex guards access to object of type T. +type RWMutex[T any] struct { + mu sync.RWMutex + value T +} + +// NewRWMutex creates a new RWMutex with given object. +func NewRWMutex[T any](value T) (m RWMutex[T]) { + m.value = value + // nolint:nakedret + return +} + +// Lock returns an iterator which locks the mutex and yields the guarded object. +// The mutex is unlocked when the iterator is done. +// The receiver must be non-nil: the mutex is dereferenced without a nil check. +func (m *RWMutex[T]) Lock() iter.Seq[T] { + return func(yield func(val T) bool) { + m.mu.Lock() + defer m.mu.Unlock() + _ = yield(m.value) + } +} + +// RLock returns an iterator which locks the mutex FOR READ and yields the guarded object. +// The mutex is unlocked when the iterator is done. +// The receiver must be non-nil: the mutex is dereferenced without a nil check. +func (m *RWMutex[T]) RLock() iter.Seq[T] { + return func(yield func(val T) bool) { + m.mu.RLock() + defer m.mu.RUnlock() + _ = yield(m.value) + } +} + // version of the value stored in an atomic watch. type version[T any] struct { updated chan struct{} diff --git a/libs/utils/proto.go b/libs/utils/proto.go index 5f5ad7a41..4593c9634 100644 --- a/libs/utils/proto.go +++ b/libs/utils/proto.go @@ -6,7 +6,7 @@ import ( "fmt" "sync" - "google.golang.org/protobuf/proto" + "github.com/gogo/protobuf/proto" ) // Hash is a SHA-256 hash. diff --git a/libs/utils/require/require.go b/libs/utils/require/require.go index 66bb750d3..438df3dfd 100644 --- a/libs/utils/require/require.go +++ b/libs/utils/require/require.go @@ -17,9 +17,22 @@ var False = require.False // True . var True = require.True +// Zero . +var Zero = require.Zero + +// NotZero . +var NotZero = require.NotZero + // Contains . 
var Contains = require.Contains +func ElementsMatch[T any](t TestingT, a []T, b []T, msgAndArgs ...any) { + require.ElementsMatch(t, a, b, msgAndArgs...) +} + +// Eventually . +var Eventually = require.Eventually + // EqualError . // TODO: get rid of comparing errors by strings, // use concrete error types instead. @@ -65,11 +78,21 @@ func Less[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { require.Less(t, e1, e2, msgAndArgs...) } +// LessOrEqual . +func LessOrEqual[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.LessOrEqual(t, e1, e2, msgAndArgs...) +} + // Greater . func Greater[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { require.Greater(t, e1, e2, msgAndArgs...) } +// GreaterOrEqual . +func GreaterOrEqual[T cmp.Ordered](t TestingT, e1, e2 T, msgAndArgs ...any) { + require.GreaterOrEqual(t, e1, e2, msgAndArgs...) +} + // Equal . func Equal[T any](t TestingT, expected, actual T, msgAndArgs ...any) { require.Equal(t, expected, actual, msgAndArgs...) diff --git a/libs/utils/tcp/tcp.go b/libs/utils/tcp/tcp.go new file mode 100644 index 000000000..4baef6cb6 --- /dev/null +++ b/libs/utils/tcp/tcp.go @@ -0,0 +1,84 @@ +package tcp + +import ( + "context" + "errors" + "net" + "net/netip" + "syscall" + + "golang.org/x/sys/unix" + + "github.com/tendermint/tendermint/libs/utils" +) + +var reservedAddrs = utils.NewMutex(map[netip.AddrPort]struct{}{}) + +// IPv4Loopback returns the IPv4 loopback address. +func IPv4Loopback() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } + +// Norm normalizes address by unmapping IPv4 -> IPv6 embedding. +func Norm(addr netip.AddrPort) netip.AddrPort { + return netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port()) +} + +// Listen opens a TCP listener on the given address. +// It takes into account the reserved addresses (in tests) and sets the SO_REUSEPORT. 
+// nolint: contextcheck +func Listen(addr netip.AddrPort) (net.Listener, error) { + if addr.Port() == 0 { + return nil, errors.New("listening on anyport (i.e. 0) is not allowed. If you are implementing a test use TestReserveAddr() instead") // nolint:lll + } + cfg := net.ListenConfig{} + for addrs := range reservedAddrs.Lock() { + if _, ok := addrs[addr]; ok { + cfg.Control = func(network, address string, c syscall.RawConn) error { + var errInner error + if err := c.Control(func(fd uintptr) { + errInner = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + }); err != nil { + return err + } + return errInner + } + } + } + // Passing the background context is ok, because Listen is + // non-blocking if it doesn't need to resolve the address + // against a DNS server. + return cfg.Listen(context.Background(), "tcp", addr.String()) +} + +// TestReserveAddr (testonly) reserves a port in the ephemeral range to open a TCP listener on it. +// Reservation prevents race conditions with other processes. +func TestReserveAddr() netip.AddrPort { + // Bind a new socket to reserve a port. + // Don't mark it as listening, to prevent the kernel from queueing up connections + // on that socket. 
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0) + if err != nil { + panic(err) + } + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil { + panic(err) + } + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + panic(err) + } + ip := IPv4Loopback() + addrAny := &unix.SockaddrInet4{Port: 0, Addr: ip.As4()} + if err := unix.Bind(fd, addrAny); err != nil { + panic(err) + } + + addrRaw, err := unix.Getsockname(fd) + if err != nil { + panic(err) + } + port := uint16(addrRaw.(*unix.SockaddrInet4).Port) + addr := netip.AddrPortFrom(ip, port) + for addrs := range reservedAddrs.Lock() { + addrs[addr] = struct{}{} + } + return addr +} diff --git a/libs/utils/testonly.go b/libs/utils/testonly.go index afd6b8aa8..4ef001ead 100644 --- a/libs/utils/testonly.go +++ b/libs/utils/testonly.go @@ -7,9 +7,9 @@ import ( "reflect" "time" + "github.com/gogo/protobuf/proto" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" ) diff --git a/light/dispatcher_test.go b/light/dispatcher_test.go index 57fc042f8..f73c768dd 100644 --- a/light/dispatcher_test.go +++ b/light/dispatcher_test.go @@ -21,14 +21,14 @@ import ( ) type channelInternal struct { - In chan p2p.Envelope + In *p2p.Queue Out chan p2p.Envelope Error chan p2p.PeerError } func testChannel(size int) (*channelInternal, *p2p.Channel) { in := &channelInternal{ - In: make(chan p2p.Envelope, size), + In: p2p.NewQueue(size), Out: make(chan p2p.Envelope, size), Error: make(chan p2p.PeerError, size), } @@ -55,7 +55,7 @@ func TestDispatcherBasic(t *testing.T) { // make a bunch of async requests and require that the correct responses are // given - for i := 0; i < numPeers; i++ { + for i := range numPeers { wg.Add(1) go func(height int64) { defer wg.Done() @@ -175,7 +175,7 @@ func TestPeerListBasic(t *testing.T) { assert.Equal(t, numPeers, peerList.Len()) half := 
numPeers / 2 - for i := 0; i < half; i++ { + for i := range half { assert.Equal(t, peerSet[i], peerList.Pop(ctx)) } assert.Equal(t, half, peerList.Len()) @@ -330,7 +330,7 @@ func handleRequests(ctx context.Context, t *testing.T, d *Dispatcher, ch chan p2 func createPeerSet(num int) []types.NodeID { peers := make([]types.NodeID, num) - for i := 0; i < num; i++ { + for i := range num { peers[i], _ = types.NewNodeID(strings.Repeat(fmt.Sprintf("%d", i), 2*types.NodeIDByteLength)) } return peers diff --git a/node/node.go b/node/node.go index d6e43987a..6e175eddd 100644 --- a/node/node.go +++ b/node/node.go @@ -6,7 +6,7 @@ import ( "fmt" "net" "net/http" - "strconv" + "net/netip" "strings" "time" @@ -809,9 +809,7 @@ func LoadStateFromDBOrGenesisDocProvider(stateStore sm.Store, genDoc *types.Gene } func getRouterConfig(conf *config.Config, appClient abciclient.Client) p2p.RouterOptions { - opts := p2p.RouterOptions{ - QueueType: conf.P2P.QueueType, - } + opts := p2p.RouterOptions{} if conf.FilterPeers && appClient != nil { opts.FilterPeerByID = func(ctx context.Context, id types.NodeID) error { @@ -828,9 +826,9 @@ func getRouterConfig(conf *config.Config, appClient abciclient.Client) p2p.Route return nil } - opts.FilterPeerByIP = func(ctx context.Context, ip net.IP, port uint16) error { + opts.FilterPeerByIP = func(ctx context.Context, addrPort netip.AddrPort) error { res, err := appClient.Query(ctx, &abci.RequestQuery{ - Path: fmt.Sprintf("/p2p/filter/addr/%s", net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))), + Path: fmt.Sprintf("/p2p/filter/addr/%v", addrPort), }) if err != nil { return err diff --git a/node/setup.go b/node/setup.go index 526382440..205c205c9 100644 --- a/node/setup.go +++ b/node/setup.go @@ -300,23 +300,21 @@ func createRouter( p2pLogger := logger.With("module", "p2p") + ep, err := p2p.NewEndpoint(nodeKey.ID.AddressString(cfg.P2P.ListenAddress)) + if err != nil { + return nil, err + } transportConf := conn.DefaultMConnConfig() 
transportConf.FlushThrottle = cfg.P2P.FlushThrottleTimeout transportConf.SendRate = cfg.P2P.SendRate transportConf.RecvRate = cfg.P2P.RecvRate transportConf.MaxPacketMsgPayloadSize = cfg.P2P.MaxPacketMsgPayloadSize transport := p2p.NewMConnTransport( - p2pLogger, transportConf, []*p2p.ChannelDescriptor{}, + p2pLogger, ep, transportConf, []*p2p.ChannelDescriptor{}, p2p.MConnTransportOptions{ MaxAcceptedConnections: uint32(cfg.P2P.MaxConnections), }, ) - - ep, err := p2p.NewEndpoint(nodeKey.ID.AddressString(cfg.P2P.ListenAddress)) - if err != nil { - return nil, err - } - return p2p.NewRouter( p2pLogger, p2pMetrics, @@ -324,7 +322,6 @@ func createRouter( peerManager, nodeInfoProducer, transport, - ep, nil, // TODO: replace with mempool CheckTx failure based filterer getRouterConfig(cfg, appClient), ) diff --git a/privval/socket_listeners_test.go b/privval/socket_listeners_test.go index e91d111d0..ff9c24cab 100644 --- a/privval/socket_listeners_test.go +++ b/privval/socket_listeners_test.go @@ -6,8 +6,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/tendermint/tendermint/crypto/ed25519" ) @@ -111,8 +109,14 @@ func TestListenerTimeoutReadWrite(t *testing.T) { ) for _, tc := range listenerTestCases(t, timeoutAccept, timeoutReadWrite) { go func(dialer SocketDialer) { - _, err := dialer() - require.NoError(t, err) + conn, err := dialer() + if err != nil { + panic(err) // this is not the main goroutine, so "require.NoError" won't work + } + // If we don't close this properly, the test gets flaky because connection + // closes at random. 
+ defer conn.Close() + <-t.Context().Done() }(tc.dialer) c, err := tc.listener.Accept() diff --git a/types/node_info.go b/types/node_info.go index fd47816e2..3c7758b67 100644 --- a/types/node_info.go +++ b/types/node_info.go @@ -3,8 +3,7 @@ package types import ( "errors" "fmt" - "net" - "strconv" + "net/netip" "strings" "github.com/tendermint/tendermint/libs/bytes" @@ -77,7 +76,7 @@ func (info NodeInfo) ID() NodeID { // url-encoding), and we just need to be careful with how we handle that in our // clients. (e.g. off by default). func (info NodeInfo) Validate() error { - if _, _, err := ParseAddressString(info.ID().AddressString(info.ListenAddr)); err != nil { + if _, err := ParseAddressString(info.ID().AddressString(info.ListenAddr)); err != nil { return err } @@ -236,48 +235,22 @@ func NodeInfoFromProto(pb *tmp2p.NodeInfo) (NodeInfo, error) { // ParseAddressString reads an address string, and returns the IP // address and port information, returning an error for any validation // errors. 
-func ParseAddressString(addr string) (net.IP, uint16, error) { +func ParseAddressString(addr string) (netip.AddrPort, error) { addrWithoutProtocol := removeProtocolIfDefined(addr) spl := strings.Split(addrWithoutProtocol, "@") if len(spl) != 2 { - return nil, 0, errors.New("invalid address") + return netip.AddrPort{}, errors.New("invalid address") } id, err := NewNodeID(spl[0]) if err != nil { - return nil, 0, err + return netip.AddrPort{}, err } if err := id.Validate(); err != nil { - return nil, 0, err + return netip.AddrPort{}, err } - - addrWithoutProtocol = spl[1] - - // get host and port - host, portStr, err := net.SplitHostPort(addrWithoutProtocol) - if err != nil { - return nil, 0, err - } - if len(host) == 0 { - return nil, 0, err - } - - ip := net.ParseIP(host) - if ip == nil { - ips, err := net.LookupIP(host) - if err != nil { - return nil, 0, err - } - ip = ips[0] - } - - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return nil, 0, err - } - - return ip, uint16(port), nil + return netip.ParseAddrPort(spl[1]) } func removeProtocolIfDefined(addr string) string { diff --git a/types/node_info_test.go b/types/node_info_test.go index 1f8480d02..16a47fdb7 100644 --- a/types/node_info_test.go +++ b/types/node_info_test.go @@ -242,11 +242,10 @@ func TestParseAddressString(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - addr, port, err := ParseAddressString(tc.addr) + addr, err := ParseAddressString(tc.addr) if tc.correct { require.NoError(t, err, tc.addr) assert.Contains(t, tc.expected, addr.String()) - assert.Contains(t, tc.expected, fmt.Sprint(port)) } else { assert.Error(t, err, "%v", tc.addr) }