Skip to content

Commit

Permalink
fix(dot/network): resize bytes slice buffer if needed (#2291)
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Feb 17, 2022
1 parent 67a9bbb commit 8db8b2a
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion dot/network/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (s *testStreamHandler) readStream(stream libp2pnetwork.Stream,
}()

for {
tot, err := readStream(stream, msgBytes)
tot, err := readStream(stream, &msgBytes)
if errors.Is(err, io.EOF) {
return
} else if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions dot/network/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder
peer := stream.Conn().RemotePeer()
buffer := s.bufPool.Get().(*[]byte)
defer s.bufPool.Put(buffer)
msgBytes := *buffer

for {
n, err := readStream(stream, msgBytes[:])
n, err := readStream(stream, buffer)
if err != nil {
logger.Tracef(
"failed to read from stream id %s of peer %s using protocol %s: %s",
Expand All @@ -32,6 +31,7 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder

// decode message based on message type
// stream should always be inbound if it passes through service.readStream
msgBytes := *buffer
msg, err := decoder(msgBytes[:n], peer, isInbound(stream))
if err != nil {
logger.Tracef("failed to decode message from stream id %s using protocol %s: %s",
Expand Down
4 changes: 2 additions & 2 deletions dot/network/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,14 @@ func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDe

buffer := s.bufPool.Get().(*[]byte)
defer s.bufPool.Put(buffer)
msgBytes := *buffer

tot, err := readStream(stream, msgBytes[:])
tot, err := readStream(stream, buffer)
if err != nil {
hsC <- &handshakeReader{hs: nil, err: err}
return
}

msgBytes := *buffer
hs, err := decoder(msgBytes[:tot])
if err != nil {
s.host.cm.peerSetHandler.ReportPeer(peerset.ReputationChange{
Expand Down
2 changes: 1 addition & 1 deletion dot/network/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const (
blockAnnounceID = "/block-announces/1"
transactionsID = "/transactions/1"

maxMessageSize = 1024 * 63 // 63kb for now
maxMessageSize = 1024 * 64 // 64kb for now
)

var (
Expand Down
2 changes: 1 addition & 1 deletion dot/network/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (s *Service) receiveBlockResponse(stream libp2pnetwork.Stream) (*BlockRespo

buf := s.blockResponseBuf

n, err := readStream(stream, buf)
n, err := readStream(stream, &buf)
if err != nil {
return nil, fmt.Errorf("read stream error: %w", err)
}
Expand Down
6 changes: 4 additions & 2 deletions dot/network/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func readLEB128ToUint64(r io.Reader, buf []byte) (uint64, int, error) {
}

// readStream reads from the stream into the given buffer, returning the number of bytes read
func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
func readStream(stream libp2pnetwork.Stream, bufPointer *[]byte) (int, error) {
if stream == nil {
return 0, errors.New("stream is nil")
}
Expand All @@ -185,6 +185,7 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
tot int
)

buf := *bufPointer
length, bytesRead, err := readLEB128ToUint64(stream, buf[:1])
if err != nil {
return bytesRead, fmt.Errorf("failed to read length: %w", err)
Expand All @@ -195,8 +196,9 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
}

if length > uint64(len(buf)) {
extraBytes := int(length) - len(buf)
*bufPointer = append(buf, make([]byte, extraBytes)...) // TODO #2288 use bytes.Buffer instead
logger.Warnf("received message with size %d greater than allocated message buffer size %d", length, len(buf))
return 0, fmt.Errorf("message size greater than allocated message buffer: got %d", length)
}

if length > maxBlockResponseSize {
Expand Down

0 comments on commit 8db8b2a

Please sign in to comment.