Skip to content

Commit

Permalink
chore(dot/network): replace sync.Map with map+mutex (#2284)
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Mar 10, 2022
1 parent 8840bb5 commit 5843324
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 149 deletions.
6 changes: 3 additions & 3 deletions dot/network/block_announce.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,10 @@ func (s *Service) validateBlockAnnounceHandshake(from peer.ID, hs Handshake) err

// don't need to lock here, since function is always called inside the func returned by
// `createNotificationsMessageHandler` which locks the map beforehand.
data, ok := np.getInboundHandshakeData(from)
if ok {
data := np.peersData.getInboundHandshakeData(from)
if data != nil {
data.handshake = hs
np.inboundHandshakeData.Store(from, data)
np.peersData.setInboundHandshakeData(from, data)
}

// if peer has higher best block than us, begin syncing
Expand Down
5 changes: 2 additions & 3 deletions dot/network/block_announce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package network

import (
"sync"
"testing"

"github.com/ChainSafe/gossamer/dot/types"
Expand Down Expand Up @@ -160,10 +159,10 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) {
nodeA := createTestService(t, configA)
nodeA.noGossip = true
nodeA.notificationsProtocols[BlockAnnounceMsgType] = &notificationsProtocol{
inboundHandshakeData: new(sync.Map),
peersData: newPeersData(),
}
testPeerID := peer.ID("noot")
nodeA.notificationsProtocols[BlockAnnounceMsgType].inboundHandshakeData.Store(testPeerID, &handshakeData{})
nodeA.notificationsProtocols[BlockAnnounceMsgType].peersData.setInboundHandshakeData(testPeerID, &handshakeData{})

err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{
BestBlockNumber: 100,
Expand Down
1 change: 0 additions & 1 deletion dot/network/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ var (
errCannotValidateHandshake = errors.New("failed to validate handshake")
errMessageTypeNotValid = errors.New("message type is not valid")
errMessageIsNotHandshake = errors.New("failed to convert message to Handshake")
errMissingHandshakeMutex = errors.New("outboundHandshakeMutex does not exist")
errInvalidHandshakeForPeer = errors.New("peer previously sent invalid handshake")
errHandshakeTimeout = errors.New("handshake timeout reached")
)
14 changes: 5 additions & 9 deletions dot/network/host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,24 +348,20 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
info := nodeA.notificationsProtocols[BlockAnnounceMsgType]

// Set handshake data to received
info.inboundHandshakeData.Store(nodeB.host.id(), &handshakeData{
info.peersData.setInboundHandshakeData(nodeB.host.id(), &handshakeData{
received: true,
validated: true,
})

// Verify that handshake data exists.
_, ok := info.getInboundHandshakeData(nodeB.host.id())
require.True(t, ok)
data := info.peersData.getInboundHandshakeData(nodeB.host.id())
require.NotNil(t, data)

time.Sleep(time.Second)
nodeB.host.close()

// Wait for cleanup
time.Sleep(time.Second)

// Verify that handshake data is cleared.
_, ok = info.getInboundHandshakeData(nodeB.host.id())
require.False(t, ok)
data = info.peersData.getInboundHandshakeData(nodeB.host.id())
require.Nil(t, data)
}

func Test_PeerSupportsProtocol(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion dot/network/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (s *Service) resetInboundStream(stream libp2pnetwork.Stream) {
continue
}

prtl.inboundHandshakeData.Delete(peerID)
prtl.peersData.deleteInboundHandshakeData(peerID)
break
}

Expand Down
105 changes: 35 additions & 70 deletions dot/network/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"sync"
"time"

"github.com/libp2p/go-libp2p-core/mux"
Expand Down Expand Up @@ -61,56 +60,24 @@ type handshakeReader struct {
}

type notificationsProtocol struct {
protocolID protocol.ID
getHandshake HandshakeGetter
handshakeDecoder HandshakeDecoder
handshakeValidator HandshakeValidator
outboundHandshakeMutexes *sync.Map //map[peer.ID]*sync.Mutex
inboundHandshakeData *sync.Map //map[peer.ID]*handshakeData
outboundHandshakeData *sync.Map //map[peer.ID]*handshakeData
protocolID protocol.ID
getHandshake HandshakeGetter
handshakeDecoder HandshakeDecoder
handshakeValidator HandshakeValidator
peersData *peersData
}

func newNotificationsProtocol(protocolID protocol.ID, handshakeGetter HandshakeGetter,
handshakeDecoder HandshakeDecoder, handshakeValidator HandshakeValidator) *notificationsProtocol {
return &notificationsProtocol{
protocolID: protocolID,
getHandshake: handshakeGetter,
handshakeValidator: handshakeValidator,
handshakeDecoder: handshakeDecoder,
outboundHandshakeMutexes: new(sync.Map),
inboundHandshakeData: new(sync.Map),
outboundHandshakeData: new(sync.Map),
protocolID: protocolID,
getHandshake: handshakeGetter,
handshakeValidator: handshakeValidator,
handshakeDecoder: handshakeDecoder,
peersData: newPeersData(),
}
}

func (n *notificationsProtocol) getInboundHandshakeData(pid peer.ID) (*handshakeData, bool) {
var (
data interface{}
has bool
)

data, has = n.inboundHandshakeData.Load(pid)
if !has {
return nil, false
}

return data.(*handshakeData), true
}

func (n *notificationsProtocol) getOutboundHandshakeData(pid peer.ID) (*handshakeData, bool) {
var (
data interface{}
has bool
)

data, has = n.outboundHandshakeData.Load(pid)
if !has {
return nil, false
}

return data.(*handshakeData), true
}

type handshakeData struct {
received bool
validated bool
Expand All @@ -131,18 +98,15 @@ func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecode
return func(in []byte, peer peer.ID, inbound bool) (Message, error) {
// if we don't have handshake data on this peer, or we haven't received the handshake from them already,
// assume we are receiving the handshake
var (
hsData *handshakeData
has bool
)

var hsData *handshakeData
if inbound {
hsData, has = info.getInboundHandshakeData(peer)
hsData = info.peersData.getInboundHandshakeData(peer)
} else {
hsData, has = info.getOutboundHandshakeData(peer)
hsData = info.peersData.getOutboundHandshakeData(peer)
}

if !has || !hsData.received {
if hsData == nil || !hsData.received {
return handshakeDecoder(in)
}

Expand Down Expand Up @@ -185,11 +149,12 @@ func (s *Service) createNotificationsMessageHandler(
// note: if this function is being called, it's being called via SetStreamHandler,
// ie it is an inbound stream and we only send the handshake over it.
// we do not send any other data over this stream, we would need to open a new outbound stream.
if _, has := info.getInboundHandshakeData(peer); !has {
hsData := info.peersData.getInboundHandshakeData(peer)
if hsData == nil {
logger.Tracef("receiver: validating handshake using protocol %s", info.protocolID)

hsData := newHandshakeData(true, false, stream)
info.inboundHandshakeData.Store(peer, hsData)
hsData = newHandshakeData(true, false, stream)
info.peersData.setInboundHandshakeData(peer, hsData)

err := info.handshakeValidator(peer, hs)
if err != nil {
Expand All @@ -200,7 +165,7 @@ func (s *Service) createNotificationsMessageHandler(
}

hsData.validated = true
info.inboundHandshakeData.Store(peer, hsData)
info.peersData.setInboundHandshakeData(peer, hsData)

// once validated, send back a handshake
resp, err := info.getHandshake()
Expand Down Expand Up @@ -263,7 +228,7 @@ func closeOutboundStream(info *notificationsProtocol, peerID peer.ID, stream lib
peerID,
)

info.outboundHandshakeData.Delete(peerID)
info.peersData.deleteOutboundHandshakeData(peerID)
_ = stream.Close()
}

Expand Down Expand Up @@ -318,25 +283,25 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc
}

func (s *Service) sendHandshake(peer peer.ID, hs Handshake, info *notificationsProtocol) (libp2pnetwork.Stream, error) {
mu, has := info.outboundHandshakeMutexes.Load(peer)
if !has {
// this should not happen
return nil, errMissingHandshakeMutex
}

// multiple processes could each call this upcoming section, opening multiple streams and
// sending multiple handshakes. thus, we need to have a per-peer and per-protocol lock
mu.(*sync.Mutex).Lock()
defer mu.(*sync.Mutex).Unlock()

hsData, has := info.getOutboundHandshakeData(peer)
// Note: we need to extract the mutex here since some sketchy test code
// sometimes deletes it from its peerid->mutex map in info.peersData
// so we cannot have a method on peersData to lock and unlock the mutex
// from the map
peerMutex := info.peersData.getMutex(peer)
peerMutex.Lock()
defer peerMutex.Unlock()

hsData := info.peersData.getOutboundHandshakeData(peer)
switch {
case has && !hsData.validated:
case hsData != nil && !hsData.validated:
// peer has sent us an invalid handshake in the past, ignore
return nil, errInvalidHandshakeForPeer
case has && hsData.validated:
case hsData != nil && hsData.validated:
return hsData.stream, nil
case !has:
case hsData == nil:
hsData = newHandshakeData(false, false, nil)
}

Expand Down Expand Up @@ -388,15 +353,15 @@ func (s *Service) sendHandshake(peer peer.ID, hs Handshake, info *notificationsP
hsData.validated = false
hsData.stream = nil
_ = stream.Reset()
info.outboundHandshakeData.Store(peer, hsData)
info.peersData.setOutboundHandshakeData(peer, hsData)
// don't delete handshake data, as we want to store that the handshake for this peer was invalid
// and not to exchange messages over this protocol with it
return nil, err
}

hsData.validated = true
hsData.handshake = resp
info.outboundHandshakeData.Store(peer, hsData)
info.peersData.setOutboundHandshakeData(peer, hsData)
logger.Tracef("sender: validated handshake from peer %s using protocol %s", peer, info.protocolID)
return hsData.stream, nil
}
Expand All @@ -419,7 +384,7 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer
continue
}

info.outboundHandshakeMutexes.Store(peer, new(sync.Mutex))
info.peersData.setMutex(peer)

go s.sendData(peer, hs, info, msg)
}
Expand Down
Loading

0 comments on commit 5843324

Please sign in to comment.