From d2ee47e20f239153355a3958afcd02eb1e26f7dd Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 30 May 2022 10:55:59 -0400 Subject: [PATCH] fix(lib/grandpa): capped number of tracked vote messages (#2485) - Vote messages tracker - Removes oldest vote message when tracker capacity is reached - Efficient removal of multiple messages at any place in the tracker queue (linked list) if they get processed - Efficient removal of oldest message - Uses a bit more space to store each block hash + authority ID, for each vote message - Order is not modified for the same vote message (same block hash and authority id) - Discard vote messages for more than 1 round in the future from the state round (thanks [andresilva](https://github.com/andresilva)) - Discard vote messages for more than 1 round in the past from the state round (thanks [andresilva](https://github.com/andresilva)) - Disable `addCatchUpResponse` (not implemented yet) to avoid a possible memory leak/abuse, see #1531 - Comment with issue number about the reputation change of peers for bad vote messages Co-authored-by: Timothy Wu --- lib/grandpa/message_tracker.go | 62 +++-- lib/grandpa/message_tracker_test.go | 71 +++--- lib/grandpa/vote_message.go | 76 +++--- lib/grandpa/votes_tracker.go | 149 ++++++++++++ lib/grandpa/votes_tracker_test.go | 364 ++++++++++++++++++++++++++++ 5 files changed, 625 insertions(+), 97 deletions(-) create mode 100644 lib/grandpa/votes_tracker.go create mode 100644 lib/grandpa/votes_tracker_test.go diff --git a/lib/grandpa/message_tracker.go b/lib/grandpa/message_tracker.go index 425d063051..00e7ef801a 100644 --- a/lib/grandpa/message_tracker.go +++ b/lib/grandpa/message_tracker.go @@ -9,7 +9,7 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/crypto/ed25519" + "github.com/libp2p/go-libp2p-core/peer" ) // tracker keeps track of messages that have been received, but have failed to @@ -18,8 +18,8 @@ import ( type tracker struct { blockState BlockState handler *MessageHandler - // map of vote block hash -> array of VoteMessages for that hash - voteMessages map[common.Hash]map[ed25519.PublicKeyBytes]*networkVoteMessage + votes votesTracker + // map of commit block hash to commit message commitMessages map[common.Hash]*CommitMessage mapLock sync.Mutex @@ -32,10 +32,11 @@ type tracker struct { } func newTracker(bs BlockState, handler *MessageHandler) *tracker { + const votesCapacity = 1000 return &tracker{ blockState: bs, handler: handler, - voteMessages: make(map[common.Hash]map[ed25519.PublicKeyBytes]*networkVoteMessage), + votes: newVotesTracker(votesCapacity), commitMessages: make(map[common.Hash]*CommitMessage), mapLock: sync.Mutex{}, in: bs.GetImportedBlockNotifierChannel(), @@ -53,21 +54,15 @@ func (t *tracker) stop() { t.blockState.FreeImportedBlockNotifierChannel(t.in) } -func (t *tracker) addVote(v *networkVoteMessage) { - if v.msg == nil { +func (t *tracker) addVote(peerID peer.ID, message *VoteMessage) { + if message == nil { return } t.mapLock.Lock() defer t.mapLock.Unlock() - msgs, has := t.voteMessages[v.msg.Message.BlockHash] - if !has { - msgs = make(map[ed25519.PublicKeyBytes]*networkVoteMessage) - t.voteMessages[v.msg.Message.BlockHash] = msgs - } - - msgs[v.msg.Message.AuthorityID] = v + t.votes.add(peerID, message) } func (t *tracker) addCommit(cm *CommitMessage) { @@ -76,10 +71,11 @@ func (t *tracker) addCommit(cm *CommitMessage) { t.commitMessages[cm.Vote.Hash] = cm } -func (t *tracker) addCatchUpResponse(cr *CatchUpResponse) { +func (t *tracker) addCatchUpResponse(_ *CatchUpResponse) { t.catchUpResponseMessageMutex.Lock() defer t.catchUpResponseMessageMutex.Unlock() - t.catchUpResponseMessages[cr.Round] = cr + // uncomment when usage is setup properly, see #1531 + // t.catchUpResponseMessages[cr.Round] = cr } func (t *tracker) handleBlocks() { @@ -108,18 +104,18 @@ func (t *tracker) handleBlock(b *types.Block) { defer t.mapLock.Unlock() h := b.Header.Hash() - if vms, has := t.voteMessages[h]; has { - for _, v := range vms { - // handleMessage would never error for vote message - _, err := t.handler.handleMessage(v.from, v.msg) - if err != nil { - logger.Warnf("failed to handle vote message %v: %s", v, err) - } + vms := t.votes.messages(h) + for _, v := range vms { + // handleMessage would never error for vote message + _, err := t.handler.handleMessage(v.from, v.msg) + if err != nil { + logger.Warnf("failed to handle vote message %v: %s", v, err) } - - delete(t.voteMessages, h) } + // delete block hash that may or may not be in the tracker. + t.votes.delete(h) + if cm, has := t.commitMessages[h]; has { _, err := t.handler.handleMessage("", cm) if err != nil { @@ -134,17 +130,17 @@ func (t *tracker) handleTick() { t.mapLock.Lock() defer t.mapLock.Unlock() - for _, vms := range t.voteMessages { - for _, v := range vms { + for _, networkVoteMessage := range t.votes.networkVoteMessages() { + peerID := networkVoteMessage.from + message := networkVoteMessage.msg + _, err := t.handler.handleMessage(peerID, message) + if err != nil { // handleMessage would never error for vote message - _, err := t.handler.handleMessage(v.from, v.msg) - if err != nil { - logger.Debugf("failed to handle vote message %v: %s", v, err) - } + logger.Debugf("failed to handle vote message %v from peer id %s: %s", message, peerID, err) + } - if v.msg.Round < t.handler.grandpa.state.round && v.msg.SetID == t.handler.grandpa.state.setID { - delete(t.voteMessages, v.msg.Message.BlockHash) - } + if message.Round < t.handler.grandpa.state.round && message.SetID == t.handler.grandpa.state.setID { + t.votes.delete(message.Message.BlockHash) } } diff --git a/lib/grandpa/message_tracker_test.go b/lib/grandpa/message_tracker_test.go index 9388baa54f..56a43adeef 100644 --- a/lib/grandpa/message_tracker_test.go +++ b/lib/grandpa/message_tracker_test.go @@ -16,6 +16,24 @@ import ( "github.com/stretchr/testify/require" ) +// getMessageFromVotesTracker returns the vote message +// from the votes tracker for the given block hash and authority ID. +func getMessageFromVotesTracker(votes votesTracker, + blockHash common.Hash, authorityID ed25519.PublicKeyBytes) ( + message *VoteMessage) { + authorityIDToElement, has := votes.mapping[blockHash] + if !has { + return nil + } + + element, ok := authorityIDToElement[authorityID] + if !ok { + return nil + } + + return element.Value.(networkVoteMessage).msg +} + func TestMessageTracker_ValidateMessage(t *testing.T) { kr, err := keystore.NewEd25519Keyring() require.NoError(t, err) @@ -33,13 +51,11 @@ func TestMessageTracker_ValidateMessage(t *testing.T) { require.NoError(t, err) gs.keypair = kr.Bob().(*ed25519.Keypair) - expected := &networkVoteMessage{ - msg: msg, - } - _, err = gs.validateVoteMessage("", msg) require.Equal(t, err, ErrBlockDoesNotExist) - require.Equal(t, expected, gs.tracker.voteMessages[fake.Hash()][kr.Alice().Public().(*ed25519.PublicKey).AsBytes()]) + authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes() + voteMessage := getMessageFromVotesTracker(gs.tracker.votes, fake.Hash(), authorityID) + require.Equal(t, msg, voteMessage) } func TestMessageTracker_SendMessage(t *testing.T) { @@ -72,13 +88,11 @@ func TestMessageTracker_SendMessage(t *testing.T) { require.NoError(t, err) gs.keypair = kr.Bob().(*ed25519.Keypair) - expected := &networkVoteMessage{ - msg: msg, - } - _, err = gs.validateVoteMessage("", msg) require.Equal(t, err, ErrBlockDoesNotExist) - require.Equal(t, expected, gs.tracker.voteMessages[next.Hash()][kr.Alice().Public().(*ed25519.PublicKey).AsBytes()]) + authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes() + voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID) + require.Equal(t, msg, voteMessage) err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{ Header: *next, @@ -126,13 +140,11 @@ func TestMessageTracker_ProcessMessage(t *testing.T) { require.NoError(t, err) gs.keypair = kr.Bob().(*ed25519.Keypair) - expected := &networkVoteMessage{ - msg: msg, - } - _, err = gs.validateVoteMessage("", msg) require.Equal(t, ErrBlockDoesNotExist, err) - require.Equal(t, expected, gs.tracker.voteMessages[next.Hash()][kr.Alice().Public().(*ed25519.PublicKey).AsBytes()]) + authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes() + voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID) + require.Equal(t, msg, voteMessage) err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{ Header: *next, @@ -147,7 +159,7 @@ func TestMessageTracker_ProcessMessage(t *testing.T) { } pv, has := gs.prevotes.Load(kr.Alice().Public().(*ed25519.PublicKey).AsBytes()) require.True(t, has) - require.Equal(t, expectedVote, &pv.(*SignedVote).Vote, gs.tracker.voteMessages) + require.Equal(t, expectedVote, &pv.(*SignedVote).Vote, gs.tracker.votes) } func TestMessageTracker_MapInsideMap(t *testing.T) { @@ -163,8 +175,8 @@ func TestMessageTracker_MapInsideMap(t *testing.T) { } hash := header.Hash() - _, ok := gs.tracker.voteMessages[hash] - require.False(t, ok) + messages := gs.tracker.votes.messages(hash) + require.Empty(t, messages) gs.keypair = kr.Alice().(*ed25519.Keypair) authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes() @@ -172,15 +184,10 @@ func TestMessageTracker_MapInsideMap(t *testing.T) { require.NoError(t, err) gs.keypair = kr.Bob().(*ed25519.Keypair) - gs.tracker.addVote(&networkVoteMessage{ - msg: msg, - }) - - voteMsgs, ok := gs.tracker.voteMessages[hash] - require.True(t, ok) + gs.tracker.addVote("", msg) - _, ok = voteMsgs[authorityID] - require.True(t, ok) + voteMessage := getMessageFromVotesTracker(gs.tracker.votes, hash, authorityID) + require.NotEmpty(t, voteMessage) } func TestMessageTracker_handleTick(t *testing.T) { @@ -197,9 +204,7 @@ func TestMessageTracker_handleTick(t *testing.T) { BlockHash: testHash, }, } - gs.tracker.addVote(&networkVoteMessage{ - msg: msg, - }) + gs.tracker.addVote("", msg) gs.tracker.handleTick() @@ -212,7 +217,7 @@ func TestMessageTracker_handleTick(t *testing.T) { } // shouldn't be deleted as round in message >= grandpa round - require.Equal(t, 1, len(gs.tracker.voteMessages[testHash])) + require.Len(t, gs.tracker.votes.messages(testHash), 1) gs.state.round = 1 msg = &VoteMessage{ @@ -221,9 +226,7 @@ func TestMessageTracker_handleTick(t *testing.T) { BlockHash: testHash, }, } - gs.tracker.addVote(&networkVoteMessage{ - msg: msg, - }) + gs.tracker.addVote("", msg) gs.tracker.handleTick() @@ -235,5 +238,5 @@ func TestMessageTracker_handleTick(t *testing.T) { } // should be deleted as round in message < grandpa round - require.Empty(t, len(gs.tracker.voteMessages[testHash])) + require.Empty(t, gs.tracker.votes.messages(testHash)) } diff --git a/lib/grandpa/vote_message.go b/lib/grandpa/vote_message.go index f2849ab09c..d54e605fa1 100644 --- a/lib/grandpa/vote_message.go +++ b/lib/grandpa/vote_message.go @@ -126,11 +126,15 @@ func (s *Service) validateVoteMessage(from peer.ID, m *VoteMessage) (*Vote, erro // check for message signature pk, err := ed25519.NewPublicKey(m.Message.AuthorityID[:]) if err != nil { + // TODO Affect peer reputation + // https://github.com/ChainSafe/gossamer/issues/2505 return nil, err } err = validateMessageSignature(pk, m) if err != nil { + // TODO Affect peer reputation + // https://github.com/ChainSafe/gossamer/issues/2505 return nil, err } @@ -138,39 +142,54 @@ func (s *Service) validateVoteMessage(from peer.ID, m *VoteMessage) (*Vote, erro return nil, ErrSetIDMismatch } - // check that vote is for current round - if m.Round != s.state.round { - if m.Round < s.state.round { - // peer doesn't know round was finalised, send out another commit message - header, err := s.blockState.GetFinalisedHeader(m.Round, m.SetID) - if err != nil { - return nil, err - } + const maxRoundsLag = 1 + minRoundAccepted := s.state.round - maxRoundsLag + if minRoundAccepted > s.state.round { + // we overflowed below 0 so set the minimum to 0. + minRoundAccepted = 0 + } - cm, err := s.newCommitMessage(header, m.Round) - if err != nil { - return nil, err - } + const maxRoundsAhead = 1 + maxRoundAccepted := s.state.round + maxRoundsAhead - // send finalised block from previous round to network - msg, err := cm.ToConsensusMessage() - if err != nil { - return nil, err - } + if m.Round < minRoundAccepted || m.Round > maxRoundAccepted { + // Discard message + // TODO: affect peer reputation, this is shameful impolite behaviour + // https://github.com/ChainSafe/gossamer/issues/2505 + return nil, nil //nolint:nilnil + } - if err = s.network.SendMessage(from, msg); err != nil { - logger.Warnf("failed to send CommitMessage: %s", err) - } - } else { - // round is higher than ours, perhaps we are behind. store vote in tracker for now - s.tracker.addVote(&networkVoteMessage{ - from: from, - msg: m, - }) + if m.Round < s.state.round { + // message round is lagging by 1 + // peer doesn't know round was finalised, send out another commit message + header, err := s.blockState.GetFinalisedHeader(m.Round, m.SetID) + if err != nil { + return nil, err + } + + cm, err := s.newCommitMessage(header, m.Round) + if err != nil { + return nil, err + } + + // send finalised block from previous round to network + msg, err := cm.ToConsensusMessage() + if err != nil { + return nil, err + } + + if err = s.network.SendMessage(from, msg); err != nil { + logger.Warnf("failed to send CommitMessage: %s", err) } // TODO: get justification if your round is lower, or just do catch-up? (#1815) return nil, errRoundMismatch(m.Round, s.state.round) + } else if m.Round > s.state.round { + // Message round is higher by 1 than the round of our state, + // we may be lagging behind, so store the message in the tracker + // for processing later in the coming few milliseconds. + s.tracker.addVote(from, m) + return nil, errRoundMismatch(m.Round, s.state.round) } // check for equivocation ie. multiple votes within one subround @@ -192,10 +211,7 @@ func (s *Service) validateVoteMessage(from peer.ID, m *VoteMessage) (*Vote, erro errors.Is(err, blocktree.ErrDescendantNotFound) || errors.Is(err, blocktree.ErrEndNodeNotFound) || errors.Is(err, blocktree.ErrStartNodeNotFound) { - s.tracker.addVote(&networkVoteMessage{ - from: from, - msg: m, - }) + s.tracker.addVote(from, m) } if err != nil { return nil, err diff --git a/lib/grandpa/votes_tracker.go b/lib/grandpa/votes_tracker.go new file mode 100644 index 0000000000..ed69088e5c --- /dev/null +++ b/lib/grandpa/votes_tracker.go @@ -0,0 +1,149 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "container/list" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/crypto/ed25519" + "github.com/libp2p/go-libp2p-core/peer" +) + +// votesTracker tracks vote messages that could +// not be processed, and removes the oldest ones once +// its maximum capacity is reached. +// It is NOT THREAD SAFE to use. +type votesTracker struct { + // map of vote block hash to authority ID (ed25519 public Key) + // to linked list element pointer + mapping map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element + // double linked list of voteMessageData (peer ID + Vote Message) + linkedList *list.List + capacity int +} + +// newVotesTracker creates a new vote message tracker +// with the capacity specified. +func newVotesTracker(capacity int) votesTracker { + return votesTracker{ + mapping: make(map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element, capacity), + linkedList: list.New(), + capacity: capacity, + } +} + +// add adds a vote message to the vote message tracker. +// If the vote message tracker capacity is reached, +// the oldest vote message is removed. +func (vt *votesTracker) add(peerID peer.ID, voteMessage *VoteMessage) { + signedMessage := voteMessage.Message + blockHash := signedMessage.BlockHash + authorityID := signedMessage.AuthorityID + + authorityIDToElement, blockHashExists := vt.mapping[blockHash] + if blockHashExists { + element, voteExists := authorityIDToElement[authorityID] + if voteExists { + // vote already exists so override the vote for the authority ID; + // do not move the list element in the linked list to avoid + // someone re-sending an equivocatory vote message and going at the + // front of the list, hence erasing other possible valid vote messages + // in the tracker. + element.Value = networkVoteMessage{ + from: peerID, + msg: voteMessage, + } + return + } + // continue below and add the authority ID and data to the tracker. + } else { + // add new block hash in tracker + authorityIDToElement = make(map[ed25519.PublicKeyBytes]*list.Element) + vt.mapping[blockHash] = authorityIDToElement + // continue below and add the authority ID and data to the tracker. + } + + vt.cleanup() + elementData := networkVoteMessage{ + from: peerID, + msg: voteMessage, + } + element := vt.linkedList.PushFront(elementData) + authorityIDToElement[authorityID] = element +} + +// cleanup removes the oldest vote message from the tracker +// if the number of vote messages is at the tracker capacity. +// This method is designed to be called automatically from the +// add method and should not be called elsewhere. +func (vt *votesTracker) cleanup() { + if vt.linkedList.Len() < vt.capacity { + return + } + + oldestElement := vt.linkedList.Back() + vt.linkedList.Remove(oldestElement) + + oldestData := oldestElement.Value.(networkVoteMessage) + oldestBlockHash := oldestData.msg.Message.BlockHash + oldestAuthorityID := oldestData.msg.Message.AuthorityID + + authIDToElement := vt.mapping[oldestBlockHash] + + delete(authIDToElement, oldestAuthorityID) + if len(authIDToElement) == 0 { + delete(vt.mapping, oldestBlockHash) + } +} + +// delete deletes all the vote messages for a particular +// block hash from the vote messages tracker. +func (vt *votesTracker) delete(blockHash common.Hash) { + authIDToElement, has := vt.mapping[blockHash] + if !has { + return + } + + for _, element := range authIDToElement { + vt.linkedList.Remove(element) + } + + delete(vt.mapping, blockHash) +} + +// messages returns all the vote messages +// for a particular block hash from the tracker as a slice +// of networkVoteMessage. There is no order in the slice. +// It returns nil if the block hash does not exist. +func (vt *votesTracker) messages(blockHash common.Hash) ( + messages []networkVoteMessage) { + authIDToElement, ok := vt.mapping[blockHash] + if !ok { + // Note authIDToElement cannot be empty + return nil + } + + messages = make([]networkVoteMessage, 0, len(authIDToElement)) + for _, element := range authIDToElement { + message := element.Value.(networkVoteMessage) + messages = append(messages, message) + } + return messages +} + +// networkVoteMessages returns all pairs of +// peer id + message stored in the tracker +// as a slice of networkVoteMessages. +func (vt *votesTracker) networkVoteMessages() ( + messages []networkVoteMessage) { + messages = make([]networkVoteMessage, 0, vt.linkedList.Len()) + for _, authorityIDToElement := range vt.mapping { + for _, element := range authorityIDToElement { + message := element.Value.(networkVoteMessage) + messages = append(messages, message) + } + } + return messages +} diff --git a/lib/grandpa/votes_tracker_test.go b/lib/grandpa/votes_tracker_test.go new file mode 100644 index 0000000000..a7a8c2c066 --- /dev/null +++ b/lib/grandpa/votes_tracker_test.go @@ -0,0 +1,364 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "container/list" + "sort" + "testing" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/crypto/ed25519" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// buildVoteMessage creates a test vote message using the +// given block hash and authority ID only. +func buildVoteMessage(blockHash common.Hash, + authorityID ed25519.PublicKeyBytes) *VoteMessage { + return &VoteMessage{ + Message: SignedMessage{ + BlockHash: blockHash, + AuthorityID: authorityID, + }, + } +} + +func wrapVoteMessageWithPeerID(voteMessage *VoteMessage, + peerID peer.ID) networkVoteMessage { + return networkVoteMessage{ + from: peerID, + msg: voteMessage, + } +} + +func assertVotesMapping(t *testing.T, + mapping map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element, + expected map[common.Hash]map[ed25519.PublicKeyBytes]networkVoteMessage) { + t.Helper() + + require.Len(t, mapping, len(expected), "mapping does not have the expected length") + for expectedBlockHash, expectedAuthIDToMessage := range expected { + submap, ok := mapping[expectedBlockHash] + require.Truef(t, ok, "block hash %s not found in mapping", expectedBlockHash) + require.Lenf(t, submap, len(expectedAuthIDToMessage), + "submapping for block hash %s does not have the expected length", expectedBlockHash) + for expectedAuthorityID, expectedNetworkVoteMessage := range expectedAuthIDToMessage { + element, ok := submap[expectedAuthorityID] + assert.Truef(t, ok, + "submapping for block hash %s does not have expected authority id %s", + expectedBlockHash, expectedAuthorityID) + actualNetworkVoteMessage := element.Value.(networkVoteMessage) + assert.Equalf(t, expectedNetworkVoteMessage, actualNetworkVoteMessage, + "network vote message for block hash %s and authority id %s is not as expected", + expectedBlockHash, expectedAuthorityID) + } + } +} + +func Test_newVotesTracker(t *testing.T) { + t.Parallel() + + const capacity = 1 + expected := votesTracker{ + mapping: make(map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element, capacity), + linkedList: list.New(), + capacity: capacity, + } + vt := newVotesTracker(capacity) + + assert.Equal(t, expected, vt) +} + +// We cannot really unit test each method independently +// due to the dependency on the double linked list from +// the standard package `list` which has private fields +// which cannot be set. +// For example we cannot assert the votes tracker mapping +// entirely due to the linked list elements unexported fields. + +func Test_votesTracker_cleanup(t *testing.T) { + t.Parallel() + + t.Run("in same block", func(t *testing.T) { + t.Parallel() + + const capacity = 2 + tracker := newVotesTracker(capacity) + + blockHashA := common.Hash{0xa} + + authIDA := ed25519.PublicKeyBytes{0xa} + authIDB := ed25519.PublicKeyBytes{0xb} + authIDC := ed25519.PublicKeyBytes{0xc} + + messageBlockAAuthA := buildVoteMessage(blockHashA, authIDA) + messageBlockAAuthB := buildVoteMessage(blockHashA, authIDB) + messageBlockAAuthC := buildVoteMessage(blockHashA, authIDC) + + const somePeer = peer.ID("abc") + + tracker.add(somePeer, messageBlockAAuthA) + tracker.add(somePeer, messageBlockAAuthB) + // Add third message for block A and authority id C. + // This triggers a cleanup removing the oldest message + // which is for block A and authority id A. + tracker.add(somePeer, messageBlockAAuthC) + assertVotesMapping(t, tracker.mapping, map[common.Hash]map[ed25519.PublicKeyBytes]networkVoteMessage{ + blockHashA: { + authIDB: wrapVoteMessageWithPeerID(messageBlockAAuthB, somePeer), + authIDC: wrapVoteMessageWithPeerID(messageBlockAAuthC, somePeer), + }, + }) + }) + + t.Run("remove entire block", func(t *testing.T) { + t.Parallel() + + const capacity = 2 + tracker := newVotesTracker(capacity) + + blockHashA := common.Hash{0xa} + blockHashB := common.Hash{0xb} + + authIDA := ed25519.PublicKeyBytes{0xa} + authIDB := ed25519.PublicKeyBytes{0xb} + + messageBlockAAuthA := buildVoteMessage(blockHashA, authIDA) + messageBlockBAuthA := buildVoteMessage(blockHashB, authIDA) + messageBlockBAuthB := buildVoteMessage(blockHashB, authIDB) + + const somePeer = peer.ID("abc") + + tracker.add(somePeer, messageBlockAAuthA) + tracker.add(somePeer, messageBlockBAuthA) + // Add third message for block B and authority id B. + // This triggers a cleanup removing the oldest message + // which is for block A and authority id A. The block A + // is also completely removed since it does not contain + // any authority ID (vote message) anymore. + tracker.add(somePeer, messageBlockBAuthB) + assertVotesMapping(t, tracker.mapping, map[common.Hash]map[ed25519.PublicKeyBytes]networkVoteMessage{ + blockHashB: { + authIDA: wrapVoteMessageWithPeerID(messageBlockBAuthA, somePeer), + authIDB: wrapVoteMessageWithPeerID(messageBlockBAuthB, somePeer), + }, + }) + }) +} + +// This test verifies overidding a value does not affect the +// input order for which each message was added. +func Test_votesTracker_overriding(t *testing.T) { + t.Parallel() + + t.Run("override oldest", func(t *testing.T) { + t.Parallel() + + const capacity = 2 + tracker := newVotesTracker(capacity) + + blockHashA := common.Hash{0xa} + blockHashB := common.Hash{0xb} + + authIDA := ed25519.PublicKeyBytes{0xa} + authIDB := ed25519.PublicKeyBytes{0xb} + + messageBlockAAuthA := buildVoteMessage(blockHashA, authIDA) + messageBlockBAuthA := buildVoteMessage(blockHashB, authIDA) + messageBlockBAuthB := buildVoteMessage(blockHashB, authIDB) + + const somePeer = peer.ID("abc") + + tracker.add(somePeer, messageBlockAAuthA) + tracker.add(somePeer, messageBlockBAuthA) + tracker.add(somePeer, messageBlockAAuthA) // override oldest + tracker.add(somePeer, messageBlockBAuthB) + + assertVotesMapping(t, tracker.mapping, map[common.Hash]map[ed25519.PublicKeyBytes]networkVoteMessage{ + blockHashB: { + authIDA: wrapVoteMessageWithPeerID(messageBlockBAuthA, somePeer), + authIDB: wrapVoteMessageWithPeerID(messageBlockBAuthB, somePeer), + }, + }) + }) + + t.Run("override newest", func(t *testing.T) { + t.Parallel() + + const capacity = 2 + tracker := newVotesTracker(capacity) + + blockHashA := common.Hash{0xa} + blockHashB := common.Hash{0xb} + + authIDA := ed25519.PublicKeyBytes{0xa} + authIDB := ed25519.PublicKeyBytes{0xb} + + messageBlockAAuthA := buildVoteMessage(blockHashA, authIDA) + messageBlockBAuthA := buildVoteMessage(blockHashB, authIDA) + messageBlockBAuthB := buildVoteMessage(blockHashB, authIDB) + + const somePeer = peer.ID("abc") + + tracker.add(somePeer, messageBlockAAuthA) + tracker.add(somePeer, messageBlockBAuthA) + tracker.add(somePeer, messageBlockBAuthA) // override newest + tracker.add(somePeer, messageBlockBAuthB) + + assertVotesMapping(t, tracker.mapping, map[common.Hash]map[ed25519.PublicKeyBytes]networkVoteMessage{ + blockHashB: { + authIDA: wrapVoteMessageWithPeerID(messageBlockBAuthA, somePeer), + authIDB: wrapVoteMessageWithPeerID(messageBlockBAuthB, somePeer), + }, + }) + }) +} + +func Test_votesTracker_delete(t *testing.T) { + t.Parallel() + + t.Run("non existing block hash", func(t *testing.T) { + t.Parallel() + + const capacity = 2 + tracker := newVotesTracker(capacity) + + blockHashA := common.Hash{0xa} + blockHashB := common.Hash{0xb} + + authIDA := ed25519.PublicKeyBytes{0xa} + + messageBlockAAuthA := buildVoteMessage(blockHashA, authIDA) + + const somePeer = peer.ID("abc") + + tracker.add(somePeer, messageBlockAAuthA) + tracker.delete(blockHashB) + + assertVotesMapping(t, tracker.mapping, map[common.Hash]map[ed25519.PublicKeyBytes]networkVoteMessage{ + blockHashA: { + authIDA: wrapVoteMessageWithPeerID(messageBlockAAuthA, somePeer), + }, + }) + }) + + t.Run("existing block hash", func(t *testing.T) { + t.Parallel() + + const capacity = 2 + tracker := newVotesTracker(capacity) + + blockHashA := common.Hash{0xa} + authIDA := ed25519.PublicKeyBytes{0xa} + authIDB := ed25519.PublicKeyBytes{0xb} + messageBlockAAuthA := buildVoteMessage(blockHashA, authIDA) + messageBlockAAuthB := buildVoteMessage(blockHashA, authIDB) + + const somePeer = peer.ID("abc") + + tracker.add(somePeer, messageBlockAAuthA) + tracker.add(somePeer, messageBlockAAuthB) + tracker.delete(blockHashA) + + assertVotesMapping(t, tracker.mapping, map[common.Hash]map[ed25519.PublicKeyBytes]networkVoteMessage{}) + }) +} + +func Test_votesTracker_messages(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + votesTracker *votesTracker + blockHash common.Hash + messages []networkVoteMessage + }{ + "non existing block hash": { + votesTracker: &votesTracker{ + mapping: map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element{ + {1}: {}, + }, + linkedList: list.New(), + }, + blockHash: common.Hash{2}, + }, + "existing block hash": { + votesTracker: &votesTracker{ + mapping: map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element{ + {1}: { + ed25519.PublicKeyBytes{1}: { + Value: networkVoteMessage{ + from: "a", + msg: &VoteMessage{Round: 1}, + }, + }, + ed25519.PublicKeyBytes{2}: { + Value: networkVoteMessage{ + from: "a", + msg: &VoteMessage{Round: 2}, + }, + }, + }, + }, + }, + blockHash: common.Hash{1}, + messages: []networkVoteMessage{ + {from: peer.ID("a"), msg: &VoteMessage{Round: 1}}, + {from: peer.ID("a"), msg: &VoteMessage{Round: 2}}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + vt := testCase.votesTracker + messages := vt.messages(testCase.blockHash) + + sort.Slice(messages, func(i, j int) bool { + if messages[i].from == messages[j].from { + return messages[i].msg.Round < messages[j].msg.Round + } + return messages[i].from < messages[j].from + }) + + assert.Equal(t, testCase.messages, messages) + }) + } +} + +func Test_votesTracker_networkVoteMessages(t *testing.T) { + t.Parallel() + + const capacity = 10 + vt := newVotesTracker(capacity) + + blockHashA := common.Hash{0xa} + blockHashB := common.Hash{0xb} + + authIDA := ed25519.PublicKeyBytes{0xa} + authIDB := ed25519.PublicKeyBytes{0xb} + + messageBlockAAuthA := buildVoteMessage(blockHashA, authIDA) + messageBlockAAuthB := buildVoteMessage(blockHashA, authIDB) + messageBlockBAuthA := buildVoteMessage(blockHashB, authIDA) + + vt.add("a", messageBlockAAuthA) + vt.add("b", messageBlockAAuthB) + vt.add("b", messageBlockBAuthA) + + networkVoteMessages := vt.networkVoteMessages() + + expectedNetworkVoteMessages := []networkVoteMessage{ + {from: "a", msg: messageBlockAAuthA}, + {from: "b", msg: messageBlockAAuthB}, + {from: "b", msg: messageBlockBAuthA}, + } + + assert.ElementsMatch(t, expectedNetworkVoteMessages, networkVoteMessages) +}