diff --git a/dot/sync/block_queue.go b/dot/sync/block_queue.go index 2872016450..aa4648722b 100644 --- a/dot/sync/block_queue.go +++ b/dot/sync/block_queue.go @@ -4,6 +4,7 @@ package sync import ( + "context" "sync" "github.com/ChainSafe/gossamer/dot/types" @@ -11,42 +12,44 @@ import ( ) type blockQueue struct { - sync.RWMutex - cap int - ch chan *types.BlockData - blocks map[common.Hash]*types.BlockData + queue chan *types.BlockData + hashesSet map[common.Hash]struct{} + hashesSetMutex sync.RWMutex } // newBlockQueue initialises a queue of *types.BlockData with the given capacity. -func newBlockQueue(cap int) *blockQueue { +func newBlockQueue(capacity int) *blockQueue { return &blockQueue{ - cap: cap, - ch: make(chan *types.BlockData, cap), - blocks: make(map[common.Hash]*types.BlockData), + queue: make(chan *types.BlockData, capacity), + hashesSet: make(map[common.Hash]struct{}, capacity), } } -// push pushes an item into the queue. it blocks if the queue is at capacity. -func (q *blockQueue) push(bd *types.BlockData) { - q.Lock() - q.blocks[bd.Hash] = bd - q.Unlock() +// push pushes an item into the queue. It blocks if the queue is at capacity. +func (bq *blockQueue) push(blockData *types.BlockData) { + bq.hashesSetMutex.Lock() + bq.hashesSet[blockData.Hash] = struct{}{} + bq.hashesSetMutex.Unlock() - q.ch <- bd + bq.queue <- blockData } -// pop pops an item from the queue. it blocks if the queue is empty. -func (q *blockQueue) pop() *types.BlockData { - bd := <-q.ch - q.Lock() - delete(q.blocks, bd.Hash) - q.Unlock() - return bd +// pop pops an item from the queue. It blocks if the queue is empty. +func (bq *blockQueue) pop(ctx context.Context) (blockData *types.BlockData) { + select { + case <-ctx.Done(): + return nil + case blockData = <-bq.queue: + } + bq.hashesSetMutex.Lock() + delete(bq.hashesSet, blockData.Hash) + bq.hashesSetMutex.Unlock() + return blockData } -func (q *blockQueue) has(hash common.Hash) bool { - q.RLock() - defer q.RUnlock() - _, has := q.blocks[hash] +func (bq *blockQueue) has(blockHash common.Hash) (has bool) { + bq.hashesSetMutex.RLock() + defer bq.hashesSetMutex.RUnlock() + _, has = bq.hashesSet[blockHash] return has } diff --git a/dot/sync/block_queue_test.go b/dot/sync/block_queue_test.go new file mode 100644 index 0000000000..f6796083bb --- /dev/null +++ b/dot/sync/block_queue_test.go @@ -0,0 +1,237 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package sync + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_newBlockQueue(t *testing.T) { + t.Parallel() + + const capacity = 1 + bq := newBlockQueue(capacity) + + require.NotNil(t, bq.queue) + assert.Equal(t, 1, cap(bq.queue)) + assert.Equal(t, 0, len(bq.queue)) + bq.queue = nil + + expectedBlockQueue := &blockQueue{ + hashesSet: make(map[common.Hash]struct{}, capacity), + } + assert.Equal(t, expectedBlockQueue, bq) +} + +func Test_blockQueue_push(t *testing.T) { + t.Parallel() + + const capacity = 1 + bq := newBlockQueue(capacity) + blockData := &types.BlockData{ + Hash: common.Hash{1}, + } + + bq.push(blockData) + + // cannot compare channels + require.NotNil(t, bq.queue) + assert.Len(t, bq.queue, 1) + + receivedBlockData := <-bq.queue + expectedBlockData := &types.BlockData{ + Hash: common.Hash{1}, + } + assert.Equal(t, expectedBlockData, receivedBlockData) + + bq.queue = nil + expectedBlockQueue := &blockQueue{ + hashesSet: map[common.Hash]struct{}{{1}: {}}, + } + assert.Equal(t, expectedBlockQueue, bq) +} + +func Test_blockQueue_pop(t *testing.T) { + t.Parallel() + + t.Run("context canceled", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + const capacity = 1 + bq := newBlockQueue(capacity) + + blockData := bq.pop(ctx) + assert.Nil(t, blockData) + }) + + t.Run("get block data after waiting", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + const capacity = 1 + bq := newBlockQueue(capacity) + + const afterDuration = 5 * time.Millisecond + time.AfterFunc(afterDuration, func() { + blockData := &types.BlockData{ + Hash: common.Hash{1}, + } + bq.push(blockData) + }) + + blockData := bq.pop(ctx) + + expectedBlockData := &types.BlockData{ + Hash: common.Hash{1}, + } + assert.Equal(t, expectedBlockData, blockData) + + assert.Len(t, bq.queue, 0) + bq.queue = nil + expectedBlockQueue := &blockQueue{ + hashesSet: map[common.Hash]struct{}{}, + } + assert.Equal(t, expectedBlockQueue, bq) + }) +} + +func Test_blockQueue_has(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + blockQueue *blockQueue + blockHash common.Hash + has bool + }{ + "absent": { + blockQueue: &blockQueue{ + hashesSet: map[common.Hash]struct{}{}, + }, + blockHash: common.Hash{1}, + }, + "exists": { + blockQueue: &blockQueue{ + hashesSet: map[common.Hash]struct{}{{1}: {}}, + }, + blockHash: common.Hash{1}, + has: true, + }, + } + + for name, tc := range testCases { + testCase := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + has := testCase.blockQueue.has(testCase.blockHash) + assert.Equal(t, testCase.has, has) + }) + } +} + +func Test_lockQueue_endToEnd(t *testing.T) { + t.Parallel() + + const capacity = 10 + blockQueue := newBlockQueue(capacity) + + newBlockData := func(i byte) *types.BlockData { + return &types.BlockData{ + Hash: common.Hash{i}, + } + } + + blockQueue.push(newBlockData(1)) + blockQueue.push(newBlockData(2)) + blockQueue.push(newBlockData(3)) + + blockData := blockQueue.pop(context.Background()) + assert.Equal(t, newBlockData(1), blockData) + + has := blockQueue.has(newBlockData(2).Hash) + assert.True(t, has) + has = blockQueue.has(newBlockData(3).Hash) + assert.True(t, has) + + blockQueue.push(newBlockData(4)) + + has = blockQueue.has(newBlockData(4).Hash) + assert.True(t, has) + + blockData = blockQueue.pop(context.Background()) + assert.Equal(t, newBlockData(2), blockData) + + // drain queue + for len(blockQueue.queue) > 0 { + <-blockQueue.queue + } +} + +func Test_lockQueue_threadSafety(t *testing.T) { + // This test consists in checking for concurrent access + // using the -race detector. + t.Parallel() + + var startWg, endWg sync.WaitGroup + ctx, cancel := context.WithCancel(context.Background()) + + const operations = 3 + const parallelism = 3 + const goroutines = parallelism * operations + startWg.Add(goroutines) + endWg.Add(goroutines) + + const testDuration = 50 * time.Millisecond + go func() { + timer := time.NewTimer(time.Hour) + startWg.Wait() + _ = timer.Reset(testDuration) + <-timer.C + cancel() + }() + + runInLoop := func(f func()) { + defer endWg.Done() + startWg.Done() + startWg.Wait() + for ctx.Err() == nil { + f() + } + } + + const capacity = 10 + blockQueue := newBlockQueue(capacity) + blockData := &types.BlockData{ + Hash: common.Hash{1}, + } + blockHash := common.Hash{1} + + for i := 0; i < parallelism; i++ { + go runInLoop(func() { + blockQueue.push(blockData) + }) + + go runInLoop(func() { + _ = blockQueue.pop(ctx) + }) + + go runInLoop(func() { + _ = blockQueue.has(blockHash) + }) + } + + endWg.Wait() +} diff --git a/dot/sync/chain_processor.go b/dot/sync/chain_processor.go index a628cbb2f8..c3970131ef 100644 --- a/dot/sync/chain_processor.go +++ b/dot/sync/chain_processor.go @@ -73,15 +73,9 @@ func (s *chainProcessor) stop() { func (s *chainProcessor) processReadyBlocks() { for { - select { - case <-s.ctx.Done(): + bd := s.readyBlocks.pop(s.ctx) + if s.ctx.Err() != nil { return - default: - } - - bd := s.readyBlocks.pop() - if bd == nil { - continue } if err := s.processBlockData(bd); err != nil { diff --git a/dot/sync/chain_sync_integeration_test.go b/dot/sync/chain_sync_integeration_test.go index b742226705..d5e12f2499 100644 --- a/dot/sync/chain_sync_integeration_test.go +++ b/dot/sync/chain_sync_integeration_test.go @@ -7,6 +7,7 @@ package sync import ( + "context" "errors" "fmt" "testing" @@ -666,7 +667,7 @@ func TestChainSync_doSync(t *testing.T) { workerErr = cs.doSync(req, make(map[peer.ID]struct{})) require.Nil(t, workerErr) - bd := readyBlocks.pop() + bd := readyBlocks.pop(context.Background()) require.NotNil(t, bd) require.Equal(t, resp.BlockData[0], bd) @@ -702,11 +703,11 @@ func TestChainSync_doSync(t *testing.T) { workerErr = cs.doSync(req, make(map[peer.ID]struct{})) require.Nil(t, workerErr) - bd = readyBlocks.pop() + bd = readyBlocks.pop(context.Background()) require.NotNil(t, bd) require.Equal(t, resp.BlockData[0], bd) - bd = readyBlocks.pop() + bd = readyBlocks.pop(context.Background()) require.NotNil(t, bd) require.Equal(t, resp.BlockData[1], bd) } @@ -760,9 +761,10 @@ func TestHandleReadyBlock(t *testing.T) { require.False(t, cs.pendingBlocks.hasBlock(header3.Hash())) require.True(t, cs.pendingBlocks.hasBlock(header2NotDescendant.Hash())) - require.Equal(t, block1.ToBlockData(), readyBlocks.pop()) - require.Equal(t, block2.ToBlockData(), readyBlocks.pop()) - require.Equal(t, block3.ToBlockData(), readyBlocks.pop()) + ctx := context.Background() + require.Equal(t, block1.ToBlockData(), readyBlocks.pop(ctx)) + require.Equal(t, block2.ToBlockData(), readyBlocks.pop(ctx)) + require.Equal(t, block3.ToBlockData(), readyBlocks.pop(ctx)) } func TestChainSync_determineSyncPeers(t *testing.T) { diff --git a/dot/sync/tip_syncer_integeration_test.go b/dot/sync/tip_syncer_integeration_test.go index b24c0c4dc9..4014e4c79f 100644 --- a/dot/sync/tip_syncer_integeration_test.go +++ b/dot/sync/tip_syncer_integeration_test.go @@ -7,6 +7,7 @@ package sync import ( + "context" "testing" "github.com/ChainSafe/gossamer/dot/network" @@ -235,7 +236,8 @@ func TestTipSyncer_handleTick_case3(t *testing.T) { require.NoError(t, err) require.Equal(t, []*worker(nil), w) require.False(t, s.pendingBlocks.hasBlock(header.Hash())) - require.Equal(t, block.ToBlockData(), s.readyBlocks.pop()) + readyBlockData := s.readyBlocks.pop(context.Background()) + require.Equal(t, block.ToBlockData(), readyBlockData) // add pending block w/ full block, but block is not ready as parent is unknown bs := new(syncmocks.BlockState) @@ -276,8 +278,9 @@ func TestTipSyncer_handleTick_case3(t *testing.T) { require.NoError(t, err) require.Equal(t, []*worker(nil), w) require.False(t, s.pendingBlocks.hasBlock(header.Hash())) - s.readyBlocks.pop() // first pop will remove parent - require.Equal(t, block.ToBlockData(), s.readyBlocks.pop()) + _ = s.readyBlocks.pop(context.Background()) // first pop removes the parent + readyBlockData = s.readyBlocks.pop(context.Background()) + require.Equal(t, block.ToBlockData(), readyBlockData) } func TestTipSyncer_hasCurrentWorker(t *testing.T) {