Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dot/network/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (s *testStreamHandler) readStream(stream libp2pnetwork.Stream,
}()

for {
tot, err := readStream(stream, msgBytes)
tot, err := readStream(stream, &msgBytes)
if errors.Is(err, io.EOF) {
return
} else if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions dot/network/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder
peer := stream.Conn().RemotePeer()
buffer := s.bufPool.Get().(*[]byte)
defer s.bufPool.Put(buffer)
msgBytes := *buffer

for {
n, err := readStream(stream, msgBytes[:])
n, err := readStream(stream, buffer)
if err != nil {
logger.Tracef(
"failed to read from stream id %s of peer %s using protocol %s: %s",
Expand All @@ -32,6 +31,7 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder

// decode message based on message type
// stream should always be inbound if it passes through service.readStream
msgBytes := *buffer
msg, err := decoder(msgBytes[:n], peer, isInbound(stream))
if err != nil {
logger.Tracef("failed to decode message from stream id %s using protocol %s: %s",
Expand Down
4 changes: 2 additions & 2 deletions dot/network/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,14 @@ func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDe

buffer := s.bufPool.Get().(*[]byte)
defer s.bufPool.Put(buffer)
msgBytes := *buffer

tot, err := readStream(stream, msgBytes[:])
tot, err := readStream(stream, buffer)
if err != nil {
hsC <- &handshakeReader{hs: nil, err: err}
return
}

msgBytes := *buffer
hs, err := decoder(msgBytes[:tot])
if err != nil {
s.host.cm.peerSetHandler.ReportPeer(peerset.ReputationChange{
Expand Down
2 changes: 1 addition & 1 deletion dot/network/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const (
blockAnnounceID = "/block-announces/1"
transactionsID = "/transactions/1"

maxMessageSize = 1024 * 63 // 63kb for now
maxMessageSize = 1024 * 64 // 64kb for now
)

var (
Expand Down
2 changes: 1 addition & 1 deletion dot/network/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (s *Service) receiveBlockResponse(stream libp2pnetwork.Stream) (*BlockRespo

buf := s.blockResponseBuf

n, err := readStream(stream, buf)
n, err := readStream(stream, &buf)
if err != nil {
return nil, fmt.Errorf("read stream error: %w", err)
}
Expand Down
6 changes: 4 additions & 2 deletions dot/network/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func readLEB128ToUint64(r io.Reader, buf []byte) (uint64, int, error) {
}

// readStream reads from the stream into the given buffer, returning the number of bytes read
func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
func readStream(stream libp2pnetwork.Stream, bufPointer *[]byte) (int, error) {
if stream == nil {
return 0, errors.New("stream is nil")
}
Expand All @@ -185,6 +185,7 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
tot int
)

buf := *bufPointer
length, bytesRead, err := readLEB128ToUint64(stream, buf[:1])
if err != nil {
return bytesRead, fmt.Errorf("failed to read length: %w", err)
Expand All @@ -195,8 +196,9 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
}

if length > uint64(len(buf)) {
extraBytes := int(length) - len(buf)
*bufPointer = append(buf, make([]byte, extraBytes)...) // TODO #2288 use bytes.Buffer instead
logger.Warnf("received message with size %d greater than allocated message buffer size %d", length, len(buf))
return 0, fmt.Errorf("message size greater than allocated message buffer: got %d", length)
}

if length > maxBlockResponseSize {
Expand Down