diff --git a/dot/network/host_test.go b/dot/network/host_test.go index 41c50ca53a..918075937c 100644 --- a/dot/network/host_test.go +++ b/dot/network/host_test.go @@ -476,3 +476,51 @@ func Test_RemoveReservedPeers(t *testing.T) { err = nodeA.host.removeReservedPeers("failing peer ID") require.Error(t, err) } + +func TestStreamCloseEOF(t *testing.T) { + basePathA := utils.NewTestBasePath(t, "nodeA") + configA := &Config{ + BasePath: basePathA, + Port: 7001, + NoBootstrap: true, + NoMDNS: true, + } + + nodeA := createTestService(t, configA) + nodeA.noGossip = true + + basePathB := utils.NewTestBasePath(t, "nodeB") + + configB := &Config{ + BasePath: basePathB, + Port: 7002, + NoBootstrap: true, + NoMDNS: true, + } + + nodeB := createTestService(t, configB) + nodeB.noGossip = true + handler := newTestStreamHandler(testBlockRequestMessageDecoder) + nodeB.host.registerStreamHandler("", handler.handleStream) + require.False(t, handler.exit) + + addrInfoB := nodeB.host.addrInfo() + err := nodeA.host.connect(addrInfoB) + // retry connect if "failed to dial" error + if failedToDial(err) { + time.Sleep(TestBackoffTimeout) + err = nodeA.host.connect(addrInfoB) + } + require.NoError(t, err) + + stream, err := nodeA.host.send(addrInfoB.ID, nodeB.host.protocolID, testBlockRequestMessage) + require.NoError(t, err) + require.False(t, handler.exit) + + err = stream.Close() + require.NoError(t, err) + + time.Sleep(TestBackoffTimeout) + + require.True(t, handler.exit) +} diff --git a/dot/network/service.go b/dot/network/service.go index e6db0e0bd5..7a64e5a243 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -580,8 +580,8 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder for { tot, err := readStream(stream, msgBytes[:]) - if err == io.EOF { - continue + if errors.Is(err, io.EOF) { + return } else if err != nil { logger.Trace("failed to read from stream", "peer", stream.Conn().RemotePeer(), "protocol", stream.Protocol(), "error", err) _ = stream.Close() diff --git a/dot/network/test_helpers.go b/dot/network/test_helpers.go index f1912e25ff..3eb8ba85a1 100644 --- a/dot/network/test_helpers.go +++ b/dot/network/test_helpers.go @@ -1,6 +1,7 @@ package network import ( + "errors" "io" "math/big" @@ -89,6 +90,7 @@ func testBlockResponseMessage() *BlockResponseMessage { type testStreamHandler struct { messages map[peer.ID][]Message decoder messageDecoder + exit bool } func newTestStreamHandler(decoder messageDecoder) *testStreamHandler { @@ -135,10 +137,14 @@ func (s *testStreamHandler) readStream(stream libp2pnetwork.Stream, peer peer.ID msgBytes = make([]byte, maxMessageSize) ) + defer func() { + s.exit = true + }() + for { tot, err := readStream(stream, msgBytes) - if err == io.EOF { - continue + if errors.Is(err, io.EOF) { + return } else if err != nil { logger.Debug("failed to read from stream", "protocol", stream.Protocol(), "error", err) _ = stream.Close()