diff --git a/dot/sync/block_queue.go b/dot/sync/block_queue.go index 28720164502..57daed89ade 100644 --- a/dot/sync/block_queue.go +++ b/dot/sync/block_queue.go @@ -4,6 +4,7 @@ package sync import ( + "context" "sync" "github.com/ChainSafe/gossamer/dot/types" @@ -11,42 +12,44 @@ import ( ) type blockQueue struct { - sync.RWMutex - cap int - ch chan *types.BlockData - blocks map[common.Hash]*types.BlockData + queue chan *types.BlockData + hashesSet map[common.Hash]struct{} + hashesSetMutex sync.RWMutex } // newBlockQueue initialises a queue of *types.BlockData with the given capacity. -func newBlockQueue(cap int) *blockQueue { +func newBlockQueue(capacity int) *blockQueue { return &blockQueue{ - cap: cap, - ch: make(chan *types.BlockData, cap), - blocks: make(map[common.Hash]*types.BlockData), + queue: make(chan *types.BlockData, capacity), + hashesSet: make(map[common.Hash]struct{}), } } -// push pushes an item into the queue. it blocks if the queue is at capacity. -func (q *blockQueue) push(bd *types.BlockData) { - q.Lock() - q.blocks[bd.Hash] = bd - q.Unlock() +// push pushes an item into the queue. It blocks if the queue is at capacity. +func (bq *blockQueue) push(blockData *types.BlockData) { + bq.hashesSetMutex.Lock() + bq.hashesSet[blockData.Hash] = struct{}{} + bq.hashesSetMutex.Unlock() - q.ch <- bd + bq.queue <- blockData } -// pop pops an item from the queue. it blocks if the queue is empty. -func (q *blockQueue) pop() *types.BlockData { - bd := <-q.ch - q.Lock() - delete(q.blocks, bd.Hash) - q.Unlock() - return bd +// pop pops an item from the queue. It blocks if the queue is empty. +func (bq *blockQueue) pop(ctx context.Context) (blockData *types.BlockData) { + select { + case <-ctx.Done(): + return + case blockData = <-bq.queue: + } + bq.hashesSetMutex.Lock() + delete(bq.hashesSet, blockData.Hash) + bq.hashesSetMutex.Unlock() + return blockData } -func (q *blockQueue) has(hash common.Hash) bool { - q.RLock() - defer q.RUnlock() - _, has := q.blocks[hash] +func (bq *blockQueue) has(blockHash common.Hash) (has bool) { + bq.hashesSetMutex.RLock() + defer bq.hashesSetMutex.RUnlock() + _, has = bq.hashesSet[blockHash] return has } diff --git a/dot/sync/chain_processor.go b/dot/sync/chain_processor.go index a628cbb2f86..c3970131ef7 100644 --- a/dot/sync/chain_processor.go +++ b/dot/sync/chain_processor.go @@ -73,15 +73,9 @@ func (s *chainProcessor) stop() { func (s *chainProcessor) processReadyBlocks() { for { - select { - case <-s.ctx.Done(): + bd := s.readyBlocks.pop(s.ctx) + if s.ctx.Err() != nil { return - default: - } - - bd := s.readyBlocks.pop() - if bd == nil { - continue } if err := s.processBlockData(bd); err != nil { diff --git a/dot/sync/chain_sync_test.go b/dot/sync/chain_sync_test.go index 07dca3337dd..d5ff74e5c3e 100644 --- a/dot/sync/chain_sync_test.go +++ b/dot/sync/chain_sync_test.go @@ -4,6 +4,7 @@ package sync import ( + "context" "errors" "fmt" "testing" @@ -663,7 +664,7 @@ func TestChainSync_doSync(t *testing.T) { workerErr = cs.doSync(req, make(map[peer.ID]struct{})) require.Nil(t, workerErr) - bd := readyBlocks.pop() + bd := readyBlocks.pop(context.Background()) require.NotNil(t, bd) require.Equal(t, resp.BlockData[0], bd) @@ -699,11 +700,11 @@ func TestChainSync_doSync(t *testing.T) { workerErr = cs.doSync(req, make(map[peer.ID]struct{})) require.Nil(t, workerErr) - bd = readyBlocks.pop() + bd = readyBlocks.pop(context.Background()) require.NotNil(t, bd) require.Equal(t, resp.BlockData[0], bd) - bd = readyBlocks.pop() + bd = readyBlocks.pop(context.Background()) require.NotNil(t, bd) require.Equal(t, resp.BlockData[1], bd) } @@ -757,9 +758,10 @@ func TestHandleReadyBlock(t *testing.T) { require.False(t, cs.pendingBlocks.hasBlock(header3.Hash())) require.True(t, cs.pendingBlocks.hasBlock(header2NotDescendant.Hash())) - require.Equal(t, block1.ToBlockData(), readyBlocks.pop()) - require.Equal(t, block2.ToBlockData(), readyBlocks.pop()) - require.Equal(t, block3.ToBlockData(), readyBlocks.pop()) + ctx := context.Background() + require.Equal(t, block1.ToBlockData(), readyBlocks.pop(ctx)) + require.Equal(t, block2.ToBlockData(), readyBlocks.pop(ctx)) + require.Equal(t, block3.ToBlockData(), readyBlocks.pop(ctx)) } func TestChainSync_determineSyncPeers(t *testing.T) { diff --git a/dot/sync/tip_syncer_test.go b/dot/sync/tip_syncer_test.go index 79569c06cd2..f77c05526c2 100644 --- a/dot/sync/tip_syncer_test.go +++ b/dot/sync/tip_syncer_test.go @@ -4,6 +4,7 @@ package sync import ( + "context" "testing" "github.com/ChainSafe/gossamer/dot/network" @@ -233,7 +234,8 @@ func TestTipSyncer_handleTick_case3(t *testing.T) { require.NoError(t, err) require.Equal(t, []*worker(nil), w) require.False(t, s.pendingBlocks.hasBlock(header.Hash())) - require.Equal(t, block.ToBlockData(), s.readyBlocks.pop()) + readyBlockData := s.readyBlocks.pop(context.Background()) + require.Equal(t, block.ToBlockData(), readyBlockData) // add pending block w/ full block, but block is not ready as parent is unknown bs := new(syncmocks.BlockState) @@ -274,8 +276,9 @@ func TestTipSyncer_handleTick_case3(t *testing.T) { require.NoError(t, err) require.Equal(t, []*worker(nil), w) require.False(t, s.pendingBlocks.hasBlock(header.Hash())) - s.readyBlocks.pop() // first pop will remove parent - require.Equal(t, block.ToBlockData(), s.readyBlocks.pop()) + _ = s.readyBlocks.pop(context.Background()) // first pop removes the parent + readyBlockData = s.readyBlocks.pop(context.Background()) + require.Equal(t, block.ToBlockData(), readyBlockData) } func TestTipSyncer_hasCurrentWorker(t *testing.T) {