From d90007ae51f942922f2e80203a9e1baee18423eb Mon Sep 17 00:00:00 2001 From: EclesioMeloJunior Date: Mon, 29 May 2023 14:26:30 -0400 Subject: [PATCH 1/4] chore: remove `maxReads` limitation to read stream --- dot/network/errors.go | 2 ++ dot/network/service.go | 3 +-- dot/network/utils.go | 6 +++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/dot/network/errors.go b/dot/network/errors.go index d22d4d5d88..261f724cdc 100644 --- a/dot/network/errors.go +++ b/dot/network/errors.go @@ -16,4 +16,6 @@ var ( errInvalidStartingBlockType = errors.New("invalid StartingBlock in messsage") errInboundHanshakeExists = errors.New("an inbound handshake already exists for given peer") errInvalidRole = errors.New("invalid role") + ErrFailedToReadEntireMessage = errors.New("failed to read entire message") + ErrUnexpectedLenght = errors.New("unexpected length") ) diff --git a/dot/network/service.go b/dot/network/service.go index fd0de70f95..e917212ae7 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -39,8 +39,7 @@ const ( ) var ( - logger = log.NewFromGlobal(log.AddContext("pkg", "network")) - maxReads = 256 + logger = log.NewFromGlobal(log.AddContext("pkg", "network")) peerCountGauge = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "gossamer_network_node", diff --git a/dot/network/utils.go b/dot/network/utils.go index e5fd8da6ef..24d51831b4 100644 --- a/dot/network/utils.go +++ b/dot/network/utils.go @@ -217,7 +217,7 @@ func readStream(stream libp2pnetwork.Stream, bufPointer *[]byte, maxSize uint64) } tot = 0 - for i := 0; i < maxReads; i++ { + for { n, err := stream.Read(buf[tot:]) if err != nil { return n + tot, err @@ -227,6 +227,10 @@ func readStream(stream libp2pnetwork.Stream, bufPointer *[]byte, maxSize uint64) if tot == int(length) { break } + + if tot > int(length) { + return tot, fmt.Errorf("%w, expected %d bytes, read %d", ErrUnexpectedLenght, length, tot) + } } if tot != int(length) { From be14ef366e3f3096ca03dea0b7d1c25e1b977d1d Mon Sep 17 00:00:00 2001 From: EclesioMeloJunior Date: Tue, 30 May 2023 07:01:19 -0400 Subject: [PATCH 2/4] chore: use loop condition and remove a sentinel error --- dot/network/errors.go | 1 - dot/network/utils.go | 9 +-------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/dot/network/errors.go b/dot/network/errors.go index 261f724cdc..c52ecdc0a8 100644 --- a/dot/network/errors.go +++ b/dot/network/errors.go @@ -17,5 +17,4 @@ var ( errInboundHanshakeExists = errors.New("an inbound handshake already exists for given peer") errInvalidRole = errors.New("invalid role") ErrFailedToReadEntireMessage = errors.New("failed to read entire message") - ErrUnexpectedLenght = errors.New("unexpected length") ) diff --git a/dot/network/utils.go b/dot/network/utils.go index 24d51831b4..2e33e3a6d8 100644 --- a/dot/network/utils.go +++ b/dot/network/utils.go @@ -217,20 +217,13 @@ func readStream(stream libp2pnetwork.Stream, bufPointer *[]byte, maxSize uint64) } tot = 0 - for { + for tot < int(length) { n, err := stream.Read(buf[tot:]) if err != nil { return n + tot, err } tot += n - if tot == int(length) { - break - } - - if tot > int(length) { - return tot, fmt.Errorf("%w, expected %d bytes, read %d", ErrUnexpectedLenght, length, tot) - } } if tot != int(length) { From d19fb755a3be6953546ebdafda108e60dca424c8 Mon Sep 17 00:00:00 2001 From: EclesioMeloJunior Date: Tue, 30 May 2023 16:05:56 -0400 Subject: [PATCH 3/4] chore: increase test coverage over `readStream` func --- dot/network/errors.go | 4 +- dot/network/mock_stream_test.go | 249 +++++++++++++++++++++++++++++ dot/network/mocks_generate_test.go | 1 + dot/network/utils.go | 35 ++-- dot/network/utils_test.go | 151 ++++++++++++++++- 5 files changed, 414 insertions(+), 26 deletions(-) create mode 100644 dot/network/mock_stream_test.go diff --git a/dot/network/errors.go b/dot/network/errors.go index c52ecdc0a8..135d5e7277 100644 --- a/dot/network/errors.go +++ b/dot/network/errors.go @@ -16,5 +16,7 @@ var ( errInvalidStartingBlockType = errors.New("invalid StartingBlock in messsage") errInboundHanshakeExists = errors.New("an inbound handshake already exists for given peer") errInvalidRole = errors.New("invalid role") - ErrFailedToReadEntireMessage = errors.New("failed to read entire message") + ErrNilStream = errors.New("nil stream") + ErrInvalidLEB128EncodedData = errors.New("invalid LEB128 encoded data") + ErrGreaterThanMaxSize = errors.New("greater than maximum size") ) diff --git a/dot/network/mock_stream_test.go b/dot/network/mock_stream_test.go new file mode 100644 index 0000000000..f2029ceb8a --- /dev/null +++ b/dot/network/mock_stream_test.go @@ -0,0 +1,249 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/libp2p/go-libp2p/core/network (interfaces: Stream) + +// Package network is a generated GoMock package. +package network + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + network "github.com/libp2p/go-libp2p/core/network" + protocol "github.com/libp2p/go-libp2p/core/protocol" +) + +// MockStream is a mock of Stream interface. +type MockStream struct { + ctrl *gomock.Controller + recorder *MockStreamMockRecorder +} + +// MockStreamMockRecorder is the mock recorder for MockStream. +type MockStreamMockRecorder struct { + mock *MockStream +} + +// NewMockStream creates a new mock instance. +func NewMockStream(ctrl *gomock.Controller) *MockStream { + mock := &MockStream{ctrl: ctrl} + mock.recorder = &MockStreamMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStream) EXPECT() *MockStreamMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockStream) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockStreamMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStream)(nil).Close)) +} + +// CloseRead mocks base method. +func (m *MockStream) CloseRead() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseRead") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseRead indicates an expected call of CloseRead. +func (mr *MockStreamMockRecorder) CloseRead() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseRead", reflect.TypeOf((*MockStream)(nil).CloseRead)) +} + +// CloseWrite mocks base method. +func (m *MockStream) CloseWrite() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseWrite") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseWrite indicates an expected call of CloseWrite. +func (mr *MockStreamMockRecorder) CloseWrite() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWrite", reflect.TypeOf((*MockStream)(nil).CloseWrite)) +} + +// Conn mocks base method. +func (m *MockStream) Conn() network.Conn { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Conn") + ret0, _ := ret[0].(network.Conn) + return ret0 +} + +// Conn indicates an expected call of Conn. +func (mr *MockStreamMockRecorder) Conn() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Conn", reflect.TypeOf((*MockStream)(nil).Conn)) +} + +// ID mocks base method. +func (m *MockStream) ID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ID") + ret0, _ := ret[0].(string) + return ret0 +} + +// ID indicates an expected call of ID. +func (mr *MockStreamMockRecorder) ID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockStream)(nil).ID)) +} + +// Protocol mocks base method. +func (m *MockStream) Protocol() protocol.ID { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Protocol") + ret0, _ := ret[0].(protocol.ID) + return ret0 +} + +// Protocol indicates an expected call of Protocol. +func (mr *MockStreamMockRecorder) Protocol() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Protocol", reflect.TypeOf((*MockStream)(nil).Protocol)) +} + +// Read mocks base method. +func (m *MockStream) Read(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockStreamMockRecorder) Read(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStream)(nil).Read), arg0) +} + +// Reset mocks base method. +func (m *MockStream) Reset() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Reset") + ret0, _ := ret[0].(error) + return ret0 +} + +// Reset indicates an expected call of Reset. +func (mr *MockStreamMockRecorder) Reset() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reset", reflect.TypeOf((*MockStream)(nil).Reset)) +} + +// Scope mocks base method. +func (m *MockStream) Scope() network.StreamScope { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Scope") + ret0, _ := ret[0].(network.StreamScope) + return ret0 +} + +// Scope indicates an expected call of Scope. +func (mr *MockStreamMockRecorder) Scope() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scope", reflect.TypeOf((*MockStream)(nil).Scope)) +} + +// SetDeadline mocks base method. +func (m *MockStream) SetDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline. +func (mr *MockStreamMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStream)(nil).SetDeadline), arg0) +} + +// SetProtocol mocks base method. +func (m *MockStream) SetProtocol(arg0 protocol.ID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetProtocol", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetProtocol indicates an expected call of SetProtocol. +func (mr *MockStreamMockRecorder) SetProtocol(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetProtocol", reflect.TypeOf((*MockStream)(nil).SetProtocol), arg0) +} + +// SetReadDeadline mocks base method. +func (m *MockStream) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockStreamMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStream)(nil).SetReadDeadline), arg0) +} + +// SetWriteDeadline mocks base method. +func (m *MockStream) SetWriteDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline. +func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStream)(nil).SetWriteDeadline), arg0) +} + +// Stat mocks base method. +func (m *MockStream) Stat() network.Stats { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stat") + ret0, _ := ret[0].(network.Stats) + return ret0 +} + +// Stat indicates an expected call of Stat. +func (mr *MockStreamMockRecorder) Stat() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stat", reflect.TypeOf((*MockStream)(nil).Stat)) +} + +// Write mocks base method. +func (m *MockStream) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockStreamMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStream)(nil).Write), arg0) +} diff --git a/dot/network/mocks_generate_test.go b/dot/network/mocks_generate_test.go index 3f5632ef05..d170388470 100644 --- a/dot/network/mocks_generate_test.go +++ b/dot/network/mocks_generate_test.go @@ -7,3 +7,4 @@ package network //go:generate mockgen -destination=mock_syncer_test.go -package $GOPACKAGE . Syncer //go:generate mockgen -destination=mock_block_state_test.go -package $GOPACKAGE . BlockState //go:generate mockgen -destination=mock_transaction_handler_test.go -package $GOPACKAGE . TransactionHandler +//go:generate mockgen -destination=mock_stream_test.go -package $GOPACKAGE github.com/libp2p/go-libp2p/core/network Stream diff --git a/dot/network/utils.go b/dot/network/utils.go index 2e33e3a6d8..201873c21a 100644 --- a/dot/network/utils.go +++ b/dot/network/utils.go @@ -6,7 +6,6 @@ package network import ( crand "crypto/rand" "encoding/hex" - "errors" "fmt" "io" mrand "math/rand" @@ -150,11 +149,7 @@ func uint64ToLEB128(in uint64) []byte { return out } -func readLEB128ToUint64(r io.Reader, buf []byte) (uint64, int, error) { - if len(buf) == 0 { - return 0, 0, errors.New("buffer has length 0") - } - +func readLEB128ToUint64(r io.Reader) (uint64, int, error) { var out uint64 var shift uint @@ -162,14 +157,16 @@ func readLEB128ToUint64(r io.Reader, buf []byte) (uint64, int, error) { bytesRead := 0 for { - n, err := r.Read(buf[:1]) + // read a sinlge byte + singleByte := []byte{0} + n, err := r.Read(singleByte) if err != nil { return 0, bytesRead, err } bytesRead += n - b := buf[0] + b := singleByte[0] out |= uint64(0x7F&b) << shift if b&0x80 == 0 { break @@ -177,7 +174,7 @@ func readLEB128ToUint64(r io.Reader, buf []byte) (uint64, int, error) { maxSize-- if maxSize == 0 { - return 0, bytesRead, fmt.Errorf("invalid LEB128 encoded data") + return 0, bytesRead, ErrInvalidLEB128EncodedData } shift += 7 @@ -186,17 +183,12 @@ func readLEB128ToUint64(r io.Reader, buf []byte) (uint64, int, error) { } // readStream reads from the stream into the given buffer, returning the number of bytes read -func readStream(stream libp2pnetwork.Stream, bufPointer *[]byte, maxSize uint64) (int, error) { +func readStream(stream libp2pnetwork.Stream, bufPointer *[]byte, maxSize uint64) (tot int, err error) { if stream == nil { - return 0, errors.New("stream is nil") + return 0, ErrNilStream } - var ( - tot int - ) - - buf := *bufPointer - length, bytesRead, err := readLEB128ToUint64(stream, buf[:1]) + length, bytesRead, err := readLEB128ToUint64(stream) if err != nil { return bytesRead, fmt.Errorf("failed to read length: %w", err) } @@ -205,18 +197,19 @@ func readStream(stream libp2pnetwork.Stream, bufPointer *[]byte, maxSize uint64) return 0, nil // msg length of 0 is allowed, for example transactions handshake } + buf := *bufPointer if length > uint64(len(buf)) { - extraBytes := int(length) - len(buf) - *bufPointer = append(buf, make([]byte, extraBytes)...) // TODO #2288 use bytes.Buffer instead logger.Warnf("received message with size %d greater than allocated message buffer size %d", length, len(buf)) + extraBytes := int(length) - len(buf) + *bufPointer = append(buf, make([]byte, extraBytes)...) + buf = *bufPointer } if length > maxSize { logger.Warnf("received message with size %d greater than max size %d, closing stream", length, maxSize) - return 0, fmt.Errorf("message size greater than maximum: got %d", length) + return 0, fmt.Errorf("%w: max %d, got %d", ErrGreaterThanMaxSize, maxSize, length) } - tot = 0 for tot < int(length) { n, err := stream.Read(buf[tot:]) if err != nil { diff --git a/dot/network/utils_test.go b/dot/network/utils_test.go index 0ca2e35c0f..50a4541d96 100644 --- a/dot/network/utils_test.go +++ b/dot/network/utils_test.go @@ -7,6 +7,8 @@ import ( "bytes" "testing" + "github.com/golang/mock/gomock" + libp2pnetwork "github.com/libp2p/go-libp2p/core/network" "github.com/stretchr/testify/require" ) @@ -131,12 +133,11 @@ func TestReadLEB128ToUint64(t *testing.T) { } for _, tc := range tests { - b := make([]byte, 2) buf := new(bytes.Buffer) _, err := buf.Write(tc.input) require.NoError(t, err) - ret, _, err := readLEB128ToUint64(buf, b[:1]) + ret, _, err := readLEB128ToUint64(buf) require.NoError(t, err) require.Equal(t, tc.output, ret) } @@ -145,11 +146,153 @@ func TestReadLEB128ToUint64(t *testing.T) { func TestInvalidLeb128(t *testing.T) { input := []byte{'\xFF', '\xFF', '\xFF', '\xFF', '\xFF', '\xFF', '\xFF', '\xFF', '\xFF', '\xFF', '\x01'} - b := make([]byte, 2) buf := new(bytes.Buffer) _, err := buf.Write(input) require.NoError(t, err) - _, _, err = readLEB128ToUint64(buf, b[:1]) + _, _, err = readLEB128ToUint64(buf) require.Error(t, err) } + +func TestReadStream(t *testing.T) { + t.Parallel() + + cases := map[string]struct { + maxSize uint64 + bufPointer *[]byte + buildStreamMock func(ctrl *gomock.Controller) libp2pnetwork.Stream + wantErr error + errString string + expectedOutput int + expectedBuf []byte + }{ + "nil_stream": { + buildStreamMock: func(ctrl *gomock.Controller) libp2pnetwork.Stream { + return nil + }, + wantErr: ErrNilStream, + errString: "nil stream", + expectedOutput: 0, + }, + + "invalid_leb128": { + buildStreamMock: func(ctrl *gomock.Controller) libp2pnetwork.Stream { + input := []byte{'\xFF', '\xFF', '\xFF', '\xFF', '\xFF', + '\xFF', '\xFF', '\xFF', '\xFF', '\xFF', '\x01'} + + invalidLeb128Buf := new(bytes.Buffer) + _, err := invalidLeb128Buf.Write(input) + require.NoError(t, err) + + streamMock := NewMockStream(ctrl) + + streamMock.EXPECT().Read([]byte{0}). + DoAndReturn(func(buf any) (n, err any) { + return invalidLeb128Buf.Read(buf.([]byte)) + }).MaxTimes(10) + + return streamMock + }, + bufPointer: &[]byte{0}, + expectedOutput: 10, // read all the bytes in the invalidLeb128Buf + wantErr: ErrInvalidLEB128EncodedData, + errString: "failed to read length: invalid LEB128 encoded data", + }, + + "zero_length": { + buildStreamMock: func(ctrl *gomock.Controller) libp2pnetwork.Stream { + input := []byte{'\x00'} + + streamBuf := new(bytes.Buffer) + _, err := streamBuf.Write(input) + require.NoError(t, err) + + streamMock := NewMockStream(ctrl) + + streamMock.EXPECT().Read([]byte{0}). + DoAndReturn(func(buf any) (n, err any) { + return streamBuf.Read(buf.([]byte)) + }) + + return streamMock + }, + bufPointer: &[]byte{0}, + expectedOutput: 0, + }, + + "length_greater_than_buf_increase_buf_size": { + buildStreamMock: func(ctrl *gomock.Controller) libp2pnetwork.Stream { + input := []byte{0xa, //size 0xa == 10 + 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, // actual data + } + + streamBuf := new(bytes.Buffer) + _, err := streamBuf.Write(input) + require.NoError(t, err) + + streamMock := NewMockStream(ctrl) + + streamMock.EXPECT().Read([]byte{0}). + DoAndReturn(func(buf any) (n, err any) { + return streamBuf.Read(buf.([]byte)) + }) + + streamMock.EXPECT().Read(make([]byte, 10)). + DoAndReturn(func(buf any) (n, err any) { + return streamBuf.Read(buf.([]byte)) + }) + + return streamMock + }, + bufPointer: &[]byte{0}, // a buffer with size 1 + expectedBuf: []byte{0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1}, + expectedOutput: 10, + maxSize: 11, + }, + + "length_greater_than_max_size": { + buildStreamMock: func(ctrl *gomock.Controller) libp2pnetwork.Stream { + input := []byte{0xa, //size 0xa == 10 + 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, // actual data + } + + streamBuf := new(bytes.Buffer) + _, err := streamBuf.Write(input) + require.NoError(t, err) + + streamMock := NewMockStream(ctrl) + + streamMock.EXPECT().Read([]byte{0}). + DoAndReturn(func(buf any) (n, err any) { + return streamBuf.Read(buf.([]byte)) + }) + + return streamMock + }, + bufPointer: &[]byte{0}, // a buffer with size 1 + wantErr: ErrGreaterThanMaxSize, + errString: "greater than maximum size: max 9, got 10", + maxSize: 9, + }, + } + + for tname, tt := range cases { + tt := tt + t.Run(tname, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + stream := tt.buildStreamMock(ctrl) + + n, err := readStream(stream, tt.bufPointer, tt.maxSize) + require.Equal(t, tt.expectedOutput, n) + require.ErrorIs(t, err, tt.wantErr) + if tt.errString != "" { + require.EqualError(t, err, tt.errString) + } + + if tt.expectedBuf != nil { + require.Equal(t, tt.expectedBuf, *tt.bufPointer) + } + }) + } +} From dd8bef0467290ff4d26d3ff3a48e9169ac7761e3 Mon Sep 17 00:00:00 2001 From: EclesioMeloJunior Date: Tue, 30 May 2023 16:10:58 -0400 Subject: [PATCH 4/4] chore: use `ErrFailedToReadEntireMessage` --- dot/network/errors.go | 1 + dot/network/utils.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/dot/network/errors.go b/dot/network/errors.go index 135d5e7277..ef895a6735 100644 --- a/dot/network/errors.go +++ b/dot/network/errors.go @@ -16,6 +16,7 @@ var ( errInvalidStartingBlockType = errors.New("invalid StartingBlock in messsage") errInboundHanshakeExists = errors.New("an inbound handshake already exists for given peer") errInvalidRole = errors.New("invalid role") + ErrFailedToReadEntireMessage = errors.New("failed to read entire message") ErrNilStream = errors.New("nil stream") ErrInvalidLEB128EncodedData = errors.New("invalid LEB128 encoded data") ErrGreaterThanMaxSize = errors.New("greater than maximum size") diff --git a/dot/network/utils.go b/dot/network/utils.go index 201873c21a..61a8414839 100644 --- a/dot/network/utils.go +++ b/dot/network/utils.go @@ -220,7 +220,7 @@ func readStream(stream libp2pnetwork.Stream, bufPointer *[]byte, maxSize uint64) } if tot != int(length) { - return tot, fmt.Errorf("failed to read entire message: expected %d bytes, received %d bytes", length, tot) + return tot, fmt.Errorf("%w: expected %d bytes, received %d bytes", ErrFailedToReadEntireMessage, length, tot) } return tot, nil