diff --git a/dot/network/block_announce.go b/dot/network/block_announce.go index f6b3e91821..64338154cb 100644 --- a/dot/network/block_announce.go +++ b/dot/network/block_announce.go @@ -190,10 +190,10 @@ func (s *Service) validateBlockAnnounceHandshake(from peer.ID, hs Handshake) err // don't need to lock here, since function is always called inside the func returned by // `createNotificationsMessageHandler` which locks the map beforehand. - data, ok := np.getInboundHandshakeData(from) - if ok { + data := np.peersData.getInboundHandshakeData(from) + if data != nil { data.handshake = hs - np.inboundHandshakeData.Store(from, data) + np.peersData.setInboundHandshakeData(from, data) } // if peer has higher best block than us, begin syncing diff --git a/dot/network/block_announce_test.go b/dot/network/block_announce_test.go index 68c54280ef..3d9fd07314 100644 --- a/dot/network/block_announce_test.go +++ b/dot/network/block_announce_test.go @@ -4,7 +4,6 @@ package network import ( - "sync" "testing" "github.com/ChainSafe/gossamer/dot/types" @@ -160,10 +159,10 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) { nodeA := createTestService(t, configA) nodeA.noGossip = true nodeA.notificationsProtocols[BlockAnnounceMsgType] = ¬ificationsProtocol{ - inboundHandshakeData: new(sync.Map), + peersData: newPeersData(), } testPeerID := peer.ID("noot") - nodeA.notificationsProtocols[BlockAnnounceMsgType].inboundHandshakeData.Store(testPeerID, &handshakeData{}) + nodeA.notificationsProtocols[BlockAnnounceMsgType].peersData.setInboundHandshakeData(testPeerID, &handshakeData{}) err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{ BestBlockNumber: 100, diff --git a/dot/network/errors.go b/dot/network/errors.go index a9c2bd94f3..44199f4249 100644 --- a/dot/network/errors.go +++ b/dot/network/errors.go @@ -11,7 +11,6 @@ var ( errCannotValidateHandshake = errors.New("failed to validate handshake") errMessageTypeNotValid = errors.New("message type is not valid") errMessageIsNotHandshake = errors.New("failed to convert message to Handshake") - errMissingHandshakeMutex = errors.New("outboundHandshakeMutex does not exist") errInvalidHandshakeForPeer = errors.New("peer previously sent invalid handshake") errHandshakeTimeout = errors.New("handshake timeout reached") ) diff --git a/dot/network/host_test.go b/dot/network/host_test.go index 1e0faf2b47..1250b15181 100644 --- a/dot/network/host_test.go +++ b/dot/network/host_test.go @@ -348,24 +348,20 @@ func TestStreamCloseMetadataCleanup(t *testing.T) { info := nodeA.notificationsProtocols[BlockAnnounceMsgType] // Set handshake data to received - info.inboundHandshakeData.Store(nodeB.host.id(), &handshakeData{ + info.peersData.setInboundHandshakeData(nodeB.host.id(), &handshakeData{ received: true, validated: true, }) // Verify that handshake data exists. - _, ok := info.getInboundHandshakeData(nodeB.host.id()) - require.True(t, ok) + data := info.peersData.getInboundHandshakeData(nodeB.host.id()) + require.NotNil(t, data) - time.Sleep(time.Second) nodeB.host.close() - // Wait for cleanup - time.Sleep(time.Second) - // Verify that handshake data is cleared. - _, ok = info.getInboundHandshakeData(nodeB.host.id()) - require.False(t, ok) + data = info.peersData.getInboundHandshakeData(nodeB.host.id()) + require.Nil(t, data) } func Test_PeerSupportsProtocol(t *testing.T) { diff --git a/dot/network/inbound.go b/dot/network/inbound.go index 0437d09238..9d673a6f8e 100644 --- a/dot/network/inbound.go +++ b/dot/network/inbound.go @@ -64,7 +64,7 @@ func (s *Service) resetInboundStream(stream libp2pnetwork.Stream) { continue } - prtl.inboundHandshakeData.Delete(peerID) + prtl.peersData.deleteInboundHandshakeData(peerID) break } diff --git a/dot/network/notifications.go b/dot/network/notifications.go index 401ef4b069..e05f164c44 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "sync" "time" "github.com/libp2p/go-libp2p-core/mux" @@ -61,56 +60,24 @@ type handshakeReader struct { } type notificationsProtocol struct { - protocolID protocol.ID - getHandshake HandshakeGetter - handshakeDecoder HandshakeDecoder - handshakeValidator HandshakeValidator - outboundHandshakeMutexes *sync.Map //map[peer.ID]*sync.Mutex - inboundHandshakeData *sync.Map //map[peer.ID]*handshakeData - outboundHandshakeData *sync.Map //map[peer.ID]*handshakeData + protocolID protocol.ID + getHandshake HandshakeGetter + handshakeDecoder HandshakeDecoder + handshakeValidator HandshakeValidator + peersData *peersData } func newNotificationsProtocol(protocolID protocol.ID, handshakeGetter HandshakeGetter, handshakeDecoder HandshakeDecoder, handshakeValidator HandshakeValidator) *notificationsProtocol { return ¬ificationsProtocol{ - protocolID: protocolID, - getHandshake: handshakeGetter, - handshakeValidator: handshakeValidator, - handshakeDecoder: handshakeDecoder, - outboundHandshakeMutexes: new(sync.Map), - inboundHandshakeData: new(sync.Map), - outboundHandshakeData: new(sync.Map), + protocolID: protocolID, + getHandshake: handshakeGetter, + handshakeValidator: handshakeValidator, + handshakeDecoder: handshakeDecoder, + peersData: newPeersData(), } } -func (n *notificationsProtocol) getInboundHandshakeData(pid peer.ID) (*handshakeData, bool) { - var ( - data interface{} - has bool - ) - - data, has = n.inboundHandshakeData.Load(pid) - if !has { - return nil, false - } - - return data.(*handshakeData), true -} - -func (n *notificationsProtocol) getOutboundHandshakeData(pid peer.ID) (*handshakeData, bool) { - var ( - data interface{} - has bool - ) - - data, has = n.outboundHandshakeData.Load(pid) - if !has { - return nil, false - } - - return data.(*handshakeData), true -} - type handshakeData struct { received bool validated bool @@ -131,18 +98,15 @@ func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecode return func(in []byte, peer peer.ID, inbound bool) (Message, error) { // if we don't have handshake data on this peer, or we haven't received the handshake from them already, // assume we are receiving the handshake - var ( - hsData *handshakeData - has bool - ) + var hsData *handshakeData if inbound { - hsData, has = info.getInboundHandshakeData(peer) + hsData = info.peersData.getInboundHandshakeData(peer) } else { - hsData, has = info.getOutboundHandshakeData(peer) + hsData = info.peersData.getOutboundHandshakeData(peer) } - if !has || !hsData.received { + if hsData == nil || !hsData.received { return handshakeDecoder(in) } @@ -185,11 +149,12 @@ func (s *Service) createNotificationsMessageHandler( // note: if this function is being called, it's being called via SetStreamHandler, // ie it is an inbound stream and we only send the handshake over it. // we do not send any other data over this stream, we would need to open a new outbound stream. - if _, has := info.getInboundHandshakeData(peer); !has { + hsData := info.peersData.getInboundHandshakeData(peer) + if hsData == nil { logger.Tracef("receiver: validating handshake using protocol %s", info.protocolID) - hsData := newHandshakeData(true, false, stream) - info.inboundHandshakeData.Store(peer, hsData) + hsData = newHandshakeData(true, false, stream) + info.peersData.setInboundHandshakeData(peer, hsData) err := info.handshakeValidator(peer, hs) if err != nil { @@ -200,7 +165,7 @@ func (s *Service) createNotificationsMessageHandler( } hsData.validated = true - info.inboundHandshakeData.Store(peer, hsData) + info.peersData.setInboundHandshakeData(peer, hsData) // once validated, send back a handshake resp, err := info.getHandshake() @@ -263,7 +228,7 @@ func closeOutboundStream(info *notificationsProtocol, peerID peer.ID, stream lib peerID, ) - info.outboundHandshakeData.Delete(peerID) + info.peersData.deleteOutboundHandshakeData(peerID) _ = stream.Close() } @@ -318,25 +283,25 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc } func (s *Service) sendHandshake(peer peer.ID, hs Handshake, info *notificationsProtocol) (libp2pnetwork.Stream, error) { - mu, has := info.outboundHandshakeMutexes.Load(peer) - if !has { - // this should not happen - return nil, errMissingHandshakeMutex - } - // multiple processes could each call this upcoming section, opening multiple streams and // sending multiple handshakes. thus, we need to have a per-peer and per-protocol lock - mu.(*sync.Mutex).Lock() - defer mu.(*sync.Mutex).Unlock() - hsData, has := info.getOutboundHandshakeData(peer) + // Note: we need to extract the mutex here since some sketchy test code + // sometimes deletes it from its peerid->mutex map in info.peersData + // so we cannot have a method on peersData to lock and unlock the mutex + // from the map + peerMutex := info.peersData.getMutex(peer) + peerMutex.Lock() + defer peerMutex.Unlock() + + hsData := info.peersData.getOutboundHandshakeData(peer) switch { - case has && !hsData.validated: + case hsData != nil && !hsData.validated: // peer has sent us an invalid handshake in the past, ignore return nil, errInvalidHandshakeForPeer - case has && hsData.validated: + case hsData != nil && hsData.validated: return hsData.stream, nil - case !has: + case hsData == nil: hsData = newHandshakeData(false, false, nil) } @@ -388,7 +353,7 @@ func (s *Service) sendHandshake(peer peer.ID, hs Handshake, info *notificationsP hsData.validated = false hsData.stream = nil _ = stream.Reset() - info.outboundHandshakeData.Store(peer, hsData) + info.peersData.setOutboundHandshakeData(peer, hsData) // don't delete handshake data, as we want to store that the handshake for this peer was invalid // and not to exchange messages over this protocol with it return nil, err @@ -396,7 +361,7 @@ func (s *Service) sendHandshake(peer peer.ID, hs Handshake, info *notificationsP hsData.validated = true hsData.handshake = resp - info.outboundHandshakeData.Store(peer, hsData) + info.peersData.setOutboundHandshakeData(peer, hsData) logger.Tracef("sender: validated handshake from peer %s using protocol %s", peer, info.protocolID) return hsData.stream, nil } @@ -419,7 +384,7 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer continue } - info.outboundHandshakeMutexes.Store(peer, new(sync.Mutex)) + info.peersData.setMutex(peer) go s.sendData(peer, hs, info, msg) } diff --git a/dot/network/notifications_test.go b/dot/network/notifications_test.go index 93c6761cf3..a62f121e4b 100644 --- a/dot/network/notifications_test.go +++ b/dot/network/notifications_test.go @@ -6,7 +6,6 @@ package network import ( "errors" "reflect" - "sync" "testing" "time" "unsafe" @@ -33,17 +32,16 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) { // create info and decoder info := ¬ificationsProtocol{ - protocolID: s.host.protocolID + blockAnnounceID, - getHandshake: s.getBlockAnnounceHandshake, - handshakeValidator: s.validateBlockAnnounceHandshake, - inboundHandshakeData: new(sync.Map), - outboundHandshakeData: new(sync.Map), + protocolID: s.host.protocolID + blockAnnounceID, + getHandshake: s.getBlockAnnounceHandshake, + handshakeValidator: s.validateBlockAnnounceHandshake, + peersData: newPeersData(), } decoder := createDecoder(info, decodeBlockAnnounceHandshake, decodeBlockAnnounceMessage) // haven't received handshake from peer testPeerID := peer.ID("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ") - info.inboundHandshakeData.Store(testPeerID, &handshakeData{ + info.peersData.setInboundHandshakeData(testPeerID, &handshakeData{ received: false, }) @@ -73,9 +71,9 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) { require.NoError(t, err) // set handshake data to received - hsData, _ := info.getInboundHandshakeData(testPeerID) + hsData := info.peersData.getInboundHandshakeData(testPeerID) hsData.received = true - info.inboundHandshakeData.Store(testPeerID, hsData) + info.peersData.setInboundHandshakeData(testPeerID, hsData) msg, err = decoder(enc, testPeerID, true) require.NoError(t, err) require.Equal(t, testBlockAnnounce, msg) @@ -119,16 +117,15 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) { // create info and handler info := ¬ificationsProtocol{ - protocolID: s.host.protocolID + blockAnnounceID, - getHandshake: s.getBlockAnnounceHandshake, - handshakeValidator: s.validateBlockAnnounceHandshake, - inboundHandshakeData: new(sync.Map), - outboundHandshakeData: new(sync.Map), + protocolID: s.host.protocolID + blockAnnounceID, + getHandshake: s.getBlockAnnounceHandshake, + handshakeValidator: s.validateBlockAnnounceHandshake, + peersData: newPeersData(), } handler := s.createNotificationsMessageHandler(info, s.handleBlockAnnounceMessage, nil) // set handshake data to received - info.inboundHandshakeData.Store(testPeerID, &handshakeData{ + info.peersData.setInboundHandshakeData(testPeerID, &handshakeData{ received: true, validated: true, }) @@ -156,11 +153,10 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) // create info and handler info := ¬ificationsProtocol{ - protocolID: s.host.protocolID + blockAnnounceID, - getHandshake: s.getBlockAnnounceHandshake, - handshakeValidator: s.validateBlockAnnounceHandshake, - inboundHandshakeData: new(sync.Map), - outboundHandshakeData: new(sync.Map), + protocolID: s.host.protocolID + blockAnnounceID, + getHandshake: s.getBlockAnnounceHandshake, + handshakeValidator: s.validateBlockAnnounceHandshake, + peersData: newPeersData(), } handler := s.createNotificationsMessageHandler(info, s.handleBlockAnnounceMessage, nil) @@ -198,8 +194,8 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) err = handler(stream, testHandshake) require.Equal(t, errCannotValidateHandshake, err) - data, has := info.getInboundHandshakeData(testPeerID) - require.True(t, has) + data := info.peersData.getInboundHandshakeData(testPeerID) + require.NotNil(t, data) require.True(t, data.received) require.False(t, data.validated) @@ -211,12 +207,12 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) GenesisHash: s.blockState.GenesisHash(), } - info.inboundHandshakeData.Delete(testPeerID) + info.peersData.deleteInboundHandshakeData(testPeerID) err = handler(stream, testHandshake) require.NoError(t, err) - data, has = info.getInboundHandshakeData(testPeerID) - require.True(t, has) + data = info.peersData.getInboundHandshakeData(testPeerID) + require.NotNil(t, data) require.True(t, data.received) require.True(t, data.validated) } @@ -268,7 +264,7 @@ func Test_HandshakeTimeout(t *testing.T) { // clear handshake data from connection handler time.Sleep(time.Millisecond * 100) - info.outboundHandshakeData.Delete(nodeB.host.id()) + info.peersData.deleteOutboundHandshakeData(nodeB.host.id()) connAToB := nodeA.host.h.Network().ConnsToPeer(nodeB.host.id()) for _, stream := range connAToB[0].GetStreams() { _ = stream.Close() @@ -281,14 +277,14 @@ func Test_HandshakeTimeout(t *testing.T) { GenesisHash: common.Hash{2}, } - info.outboundHandshakeMutexes.Store(nodeB.host.id(), new(sync.Mutex)) + info.peersData.setMutex(nodeB.host.id()) go nodeA.sendData(nodeB.host.id(), testHandshakeMsg, info, nil) time.Sleep(time.Second) // handshake data shouldn't exist, as nodeB hasn't responded yet - _, ok := info.getOutboundHandshakeData(nodeB.host.id()) - require.False(t, ok) + data := info.peersData.getOutboundHandshakeData(nodeB.host.id()) + require.Nil(t, data) // a stream should be open until timeout connAToB = nodeA.host.h.Network().ConnsToPeer(nodeB.host.id()) @@ -299,8 +295,8 @@ func Test_HandshakeTimeout(t *testing.T) { time.Sleep(handshakeTimeout) // handshake data shouldn't exist still - _, ok = info.getOutboundHandshakeData(nodeB.host.id()) - require.False(t, ok) + data = info.peersData.getOutboundHandshakeData(nodeB.host.id()) + require.Nil(t, data) // stream should be closed connAToB = nodeA.host.h.Network().ConnsToPeer(nodeB.host.id()) @@ -350,16 +346,15 @@ func TestCreateNotificationsMessageHandler_HandleTransaction(t *testing.T) { // create info and handler info := ¬ificationsProtocol{ - protocolID: txnProtocolID, - getHandshake: srvc1.getTransactionHandshake, - handshakeValidator: validateTransactionHandshake, - inboundHandshakeData: new(sync.Map), - outboundHandshakeData: new(sync.Map), + protocolID: txnProtocolID, + getHandshake: srvc1.getTransactionHandshake, + handshakeValidator: validateTransactionHandshake, + peersData: newPeersData(), } handler := srvc1.createNotificationsMessageHandler(info, srvc1.handleTransactionMessage, txnBatchHandler) // set handshake data to received - info.inboundHandshakeData.Store(srvc2.host.id(), handshakeData{ + info.peersData.setInboundHandshakeData(srvc2.host.id(), &handshakeData{ received: true, validated: true, }) diff --git a/dot/network/peersdata.go b/dot/network/peersdata.go new file mode 100644 index 0000000000..5f1e747089 --- /dev/null +++ b/dot/network/peersdata.go @@ -0,0 +1,103 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package network + +import ( + "sync" + + "github.com/libp2p/go-libp2p-core/peer" +) + +type peersData struct { + mutexesMu sync.RWMutex + mutexes map[peer.ID]*sync.Mutex + inboundMu sync.RWMutex + inbound map[peer.ID]*handshakeData + outboundMu sync.RWMutex + outbound map[peer.ID]*handshakeData +} + +func newPeersData() *peersData { + return &peersData{ + mutexes: make(map[peer.ID]*sync.Mutex), + inbound: make(map[peer.ID]*handshakeData), + outbound: make(map[peer.ID]*handshakeData), + } +} + +func (p *peersData) setMutex(peerID peer.ID) { + p.mutexesMu.Lock() + defer p.mutexesMu.Unlock() + p.mutexes[peerID] = new(sync.Mutex) +} + +func (p *peersData) getMutex(peerID peer.ID) *sync.Mutex { + p.mutexesMu.RLock() + defer p.mutexesMu.RUnlock() + return p.mutexes[peerID] +} + +func (p *peersData) deleteMutex(peerID peer.ID) { + p.mutexesMu.Lock() + defer p.mutexesMu.Unlock() + delete(p.mutexes, peerID) +} + +func (p *peersData) getInboundHandshakeData(peerID peer.ID) (data *handshakeData) { + p.inboundMu.RLock() + defer p.inboundMu.RUnlock() + return p.inbound[peerID] +} + +func (p *peersData) setInboundHandshakeData(peerID peer.ID, data *handshakeData) { + p.inboundMu.Lock() + defer p.inboundMu.Unlock() + p.inbound[peerID] = data +} + +func (p *peersData) deleteInboundHandshakeData(peerID peer.ID) { + p.inboundMu.Lock() + defer p.inboundMu.Unlock() + delete(p.inbound, peerID) +} + +func (p *peersData) countInboundStreams() (count int64) { + p.inboundMu.RLock() + defer p.inboundMu.RUnlock() + for _, data := range p.inbound { + if data.stream != nil { + count++ + } + } + return count +} + +func (p *peersData) getOutboundHandshakeData(peerID peer.ID) (data *handshakeData) { + p.outboundMu.RLock() + defer p.outboundMu.RUnlock() + return p.outbound[peerID] +} + +func (p *peersData) setOutboundHandshakeData(peerID peer.ID, data *handshakeData) { + p.outboundMu.Lock() + defer p.outboundMu.Unlock() + p.outbound[peerID] = data +} + +func (p *peersData) deleteOutboundHandshakeData(peerID peer.ID) { + p.outboundMu.Lock() + defer p.outboundMu.Unlock() + delete(p.outbound, peerID) +} + +func (p *peersData) countOutboundStreams() (count int64) { + p.outboundMu.RLock() + defer p.outboundMu.RUnlock() + for _, data := range p.outbound { + if data.stream != nil { + count++ + } + } + return count +} diff --git a/dot/network/service.go b/dot/network/service.go index b0708a8e2d..2348090fce 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -283,16 +283,16 @@ func (s *Service) Start() error { // it creates a per-protocol mutex for sending outbound handshakes to the peer s.host.cm.connectHandler = func(peerID peer.ID) { for _, prtl := range s.notificationsProtocols { - prtl.outboundHandshakeMutexes.Store(peerID, new(sync.Mutex)) + prtl.peersData.setMutex(peerID) } } // when a peer gets disconnected, we should clear all handshake data we have for it. s.host.cm.disconnectHandler = func(peerID peer.ID) { for _, prtl := range s.notificationsProtocols { - prtl.outboundHandshakeMutexes.Delete(peerID) - prtl.inboundHandshakeData.Delete(peerID) - prtl.outboundHandshakeData.Delete(peerID) + prtl.peersData.deleteMutex(peerID) + prtl.peersData.deleteInboundHandshakeData(peerID) + prtl.peersData.deleteOutboundHandshakeData(peerID) } } @@ -374,26 +374,10 @@ func (s *Service) getNumStreams(protocolID byte, inbound bool) (count int64) { return 0 } - var hsData *sync.Map if inbound { - hsData = np.inboundHandshakeData - } else { - hsData = np.outboundHandshakeData + return np.peersData.countInboundStreams() } - - hsData.Range(func(_, data interface{}) bool { - if data == nil { - return true - } - - if data.(*handshakeData).stream != nil { - count++ - } - - return true - }) - - return count + return np.peersData.countOutboundStreams() } func (s *Service) logPeerCount() { @@ -632,8 +616,8 @@ func (s *Service) Peers() []common.PeerInfo { s.notificationsMu.RUnlock() for _, p := range s.host.peers() { - data, has := np.getInboundHandshakeData(p) - if !has || data.handshake == nil { + data := np.peersData.getInboundHandshakeData(p) + if data == nil || data.handshake == nil { peers = append(peers, common.PeerInfo{ PeerID: p.String(), }) diff --git a/dot/network/service_test.go b/dot/network/service_test.go index 65984ec97c..ea6e925398 100644 --- a/dot/network/service_test.go +++ b/dot/network/service_test.go @@ -274,7 +274,7 @@ func TestBroadcastDuplicateMessage(t *testing.T) { require.NotNil(t, stream) protocol := nodeA.notificationsProtocols[BlockAnnounceMsgType] - protocol.outboundHandshakeData.Store(nodeB.host.id(), &handshakeData{ + protocol.peersData.setOutboundHandshakeData(nodeB.host.id(), &handshakeData{ received: true, validated: true, stream: stream,