diff --git a/channelmonitor/channelmonitor_test.go b/channelmonitor/channelmonitor_test.go index eefe3fc7..3a3fac76 100644 --- a/channelmonitor/channelmonitor_test.go +++ b/channelmonitor/channelmonitor_test.go @@ -611,3 +611,7 @@ func (m *mockChannelState) Stages() *datatransfer.ChannelStages { func (m *mockChannelState) ReceivedCids() []cid.Cid { panic("implement me") } + +func (m *mockChannelState) ReceivedCidsLen() int { + panic("implement me") +} diff --git a/channels/channel_state.go b/channels/channel_state.go index a8beb25d..90b76fad 100644 --- a/channels/channel_state.go +++ b/channels/channel_state.go @@ -48,7 +48,7 @@ type channelState struct { voucherResults []internal.EncodedVoucherResult voucherResultDecoder DecoderByTypeFunc voucherDecoder DecoderByTypeFunc - channelCIDsReader ChannelCIDsReader + receivedCids ReceivedCidsReader // stages tracks the timeline of events related to a data transfer, for // traceability purposes. @@ -100,13 +100,22 @@ func (c channelState) Voucher() datatransfer.Voucher { // ReceivedCids returns the cids received so far on this channel func (c channelState) ReceivedCids() []cid.Cid { - receivedCids, err := c.channelCIDsReader(c.ChannelID()) + receivedCids, err := c.receivedCids.ToArray(c.ChannelID()) if err != nil { log.Error(err) } return receivedCids } +// ReceivedCids returns the number of cids received so far on this channel +func (c channelState) ReceivedCidsLen() int { + len, err := c.receivedCids.Len(c.ChannelID()) + if err != nil { + log.Error(err) + } + return len +} + // Sender returns the peer id for the node that is sending data func (c channelState) Sender() peer.ID { return c.sender } @@ -190,7 +199,7 @@ func (c channelState) Stages() *datatransfer.ChannelStages { return c.stages } -func fromInternalChannelState(c internal.ChannelState, voucherDecoder DecoderByTypeFunc, voucherResultDecoder DecoderByTypeFunc, channelCIDsReader ChannelCIDsReader) datatransfer.ChannelState { +func fromInternalChannelState(c internal.ChannelState, voucherDecoder DecoderByTypeFunc, voucherResultDecoder DecoderByTypeFunc, receivedCidsReader ReceivedCidsReader) datatransfer.ChannelState { return channelState{ selfPeer: c.SelfPeer, isPull: c.Initiator == c.Recipient, @@ -209,7 +218,7 @@ func fromInternalChannelState(c internal.ChannelState, voucherDecoder DecoderByT voucherResults: c.VoucherResults, voucherResultDecoder: voucherResultDecoder, voucherDecoder: voucherDecoder, - channelCIDsReader: channelCIDsReader, + receivedCids: receivedCidsReader, stages: c.Stages, } } diff --git a/channels/channels.go b/channels/channels.go index 06bdc223..1e80f3d2 100644 --- a/channels/channels.go +++ b/channels/channels.go @@ -28,7 +28,10 @@ import ( type DecoderByTypeFunc func(identifier datatransfer.TypeIdentifier) (encoding.Decoder, bool) -type ChannelCIDsReader func(chid datatransfer.ChannelID) ([]cid.Cid, error) +type ReceivedCidsReader interface { + ToArray(chid datatransfer.ChannelID) ([]cid.Cid, error) + Len(chid datatransfer.ChannelID) (int, error) +} type Notifier func(datatransfer.Event, datatransfer.ChannelState) @@ -55,7 +58,6 @@ type Channels struct { voucherResultDecoder DecoderByTypeFunc stateMachines fsm.Group migrateStateMachines func(context.Context) error - cidLists cidlists.CIDLists seenCIDs *cidsets.CIDSetManager } @@ -78,7 +80,6 @@ func New(ds datastore.Batching, seenCIDsDS := namespace.Wrap(ds, datastore.NewKey("seencids")) c := &Channels{ - cidLists: cidLists, seenCIDs: cidsets.NewCIDSetManager(seenCIDsDS), notifier: notifier, voucherDecoder: voucherDecoder, @@ -123,7 +124,7 @@ func (c *Channels) dispatch(eventName fsm.EventName, channel fsm.StateType) { Timestamp: time.Now(), } - c.notifier(evt, fromInternalChannelState(realChannel, c.voucherDecoder, c.voucherResultDecoder, c.cidLists.ReadList)) + c.notifier(evt, c.fromInternalChannelState(realChannel)) // When the channel has been cleaned up, remove the caches of seen cids if evt.Code == datatransfer.CleanupComplete { @@ -180,10 +181,6 @@ func (c *Channels) CreateNew(selfPeer peer.ID, tid datatransfer.TransferID, base if err != nil { return datatransfer.ChannelID{}, err } - err = c.cidLists.CreateList(chid, nil) - if err != nil { - return datatransfer.ChannelID{}, err - } return chid, c.stateMachines.Send(chid, datatransfer.Open) } @@ -197,7 +194,7 @@ func (c *Channels) InProgress() (map[datatransfer.ChannelID]datatransfer.Channel channels := make(map[datatransfer.ChannelID]datatransfer.ChannelState, len(internalChannels)) for _, internalChannel := range internalChannels { channels[datatransfer.ChannelID{ID: internalChannel.TransferID, Responder: internalChannel.Responder, Initiator: internalChannel.Initiator}] = - fromInternalChannelState(internalChannel, c.voucherDecoder, c.voucherResultDecoder, c.cidLists.ReadList) + c.fromInternalChannelState(internalChannel) } return channels, nil } @@ -210,7 +207,7 @@ func (c *Channels) GetByID(ctx context.Context, chid datatransfer.ChannelID) (da if err != nil { return nil, NewErrNotFound(chid) } - return fromInternalChannelState(internalChannel, c.voucherDecoder, c.voucherResultDecoder, c.cidLists.ReadList), nil + return c.fromInternalChannelState(internalChannel), nil } // Accept marks a data transfer as accepted @@ -239,11 +236,6 @@ func (c *Channels) DataQueued(chid datatransfer.ChannelID, k cid.Cid, delta uint // Returns true if this is the first time the block has been received func (c *Channels) DataReceived(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) { - err := c.cidLists.AppendList(chid, k) - if err != nil { - return false, err - } - return c.fireProgressEvent(chid, datatransfer.DataReceived, datatransfer.DataReceivedProgress, k, delta) } @@ -361,12 +353,12 @@ func (c *Channels) HasChannel(chid datatransfer.ChannelID) (bool, error) { // blocks that have already been queued / sent / received func (c *Channels) removeSeenCIDCaches(chid datatransfer.ChannelID) error { progressStates := []datatransfer.EventCode{ - datatransfer.DataQueuedProgress, - datatransfer.DataSentProgress, - datatransfer.DataReceivedProgress, + datatransfer.DataQueued, + datatransfer.DataSent, + datatransfer.DataReceived, } for _, evt := range progressStates { - sid := cidsets.SetID(chid.String() + "/" + datatransfer.Events[evt]) + sid := seenCidsSetID(chid, evt) err := c.seenCIDs.DeleteSet(sid) if err != nil { return err @@ -388,7 +380,7 @@ func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransf } // Check if the block has already been seen - sid := cidsets.SetID(chid.String() + "/" + datatransfer.Events[evt]) + sid := seenCidsSetID(chid, evt) seen, err := c.seenCIDs.InsertSetCID(sid, k) if err != nil { return false, err @@ -424,3 +416,40 @@ func (c *Channels) checkChannelExists(chid datatransfer.ChannelID, code datatran } return nil } + +// Get the ID of the CID set for the given channel ID and event code. +// The CID set stores a unique list of queued / sent / received CIDs. +func seenCidsSetID(chid datatransfer.ChannelID, evt datatransfer.EventCode) cidsets.SetID { + return cidsets.SetID(chid.String() + "/" + datatransfer.Events[evt]) +} + +// Convert from the internally used channel state format to the externally exposed ChannelState +func (c *Channels) fromInternalChannelState(ch internal.ChannelState) datatransfer.ChannelState { + rcr := &receivedCidsReader{ + seenCIDs: c.seenCIDs, + } + return fromInternalChannelState(ch, c.voucherDecoder, c.voucherResultDecoder, rcr) +} + +// Implements the ReceivedCidsReader interface so that the internal channel +// state has access to the received CIDs. +// The interface is used (instead of passing these values directly) +// so the values can be loaded lazily. Reading all CIDs from the datastore +// is an expensive operation so we want to avoid doing it unless necessary. +// Note that the received CIDs get cleaned up when the channel completes, so +// these methods will return an empty array after that point. +type receivedCidsReader struct { + seenCIDs *cidsets.CIDSetManager +} + +func (r *receivedCidsReader) ToArray(chid datatransfer.ChannelID) ([]cid.Cid, error) { + sid := seenCidsSetID(chid, datatransfer.DataReceived) + return r.seenCIDs.SetToArray(sid) +} + +func (r *receivedCidsReader) Len(chid datatransfer.ChannelID) (int, error) { + sid := seenCidsSetID(chid, datatransfer.DataReceived) + return r.seenCIDs.SetLen(sid) +} + +var _ ReceivedCidsReader = (*receivedCidsReader)(nil) diff --git a/channels/channels_test.go b/channels/channels_test.go index 6ad58fad..eaf9e7ab 100644 --- a/channels/channels_test.go +++ b/channels/channels_test.go @@ -179,7 +179,7 @@ func TestChannels(t *testing.T) { state = checkEvent(ctx, t, received, datatransfer.DataReceived) require.Equal(t, uint64(100), state.Received()) require.Equal(t, uint64(100), state.Sent()) - require.Equal(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids()) + require.ElementsMatch(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids()) isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 25) require.NoError(t, err) @@ -187,7 +187,7 @@ func TestChannels(t *testing.T) { state = checkEvent(ctx, t, received, datatransfer.DataSent) require.Equal(t, uint64(100), state.Received()) require.Equal(t, uint64(100), state.Sent()) - require.Equal(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids()) + require.ElementsMatch(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids()) isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50) require.NoError(t, err) @@ -195,7 +195,7 @@ func TestChannels(t *testing.T) { state = checkEvent(ctx, t, received, datatransfer.DataReceived) require.Equal(t, uint64(100), state.Received()) require.Equal(t, uint64(100), state.Sent()) - require.Equal(t, []cid.Cid{cids[0], cids[1], cids[0]}, state.ReceivedCids()) + require.ElementsMatch(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids()) }) t.Run("pause/resume", func(t *testing.T) { @@ -613,7 +613,10 @@ func TestMigrationsV1(t *testing.T) { require.Equal(t, messages[i], channel.Message()) require.Equal(t, vouchers[i], channel.LastVoucher()) require.Equal(t, voucherResults[i], channel.LastVoucherResult()) - require.Equal(t, receivedCids[i], channel.ReceivedCids()) + // No longer relying on this migration to migrate CID lists as they + // have been deprecated since we moved to CID sets: + // https://github.com/filecoin-project/go-data-transfer/pull/217 + //require.Equal(t, receivedCids[i], channel.ReceivedCids()) } } diff --git a/cidlists/cidlists.go b/cidlists/cidlists.go index 9014ae29..8ca30826 100644 --- a/cidlists/cidlists.go +++ b/cidlists/cidlists.go @@ -12,6 +12,7 @@ import ( datatransfer "github.com/filecoin-project/go-data-transfer" ) +// Deprecated: CIDLists have now been replaced by CID sets (see cidsets directory). // CIDLists maintains files that contain a list of CIDs received for different data transfers type CIDLists interface { CreateList(chid datatransfer.ChannelID, initalCids []cid.Cid) error diff --git a/cidsets/cidsets.go b/cidsets/cidsets.go index 193c1d17..5949aee0 100644 --- a/cidsets/cidsets.go +++ b/cidsets/cidsets.go @@ -29,6 +29,16 @@ func (mgr *CIDSetManager) InsertSetCID(sid SetID, c cid.Cid) (exists bool, err e return mgr.getSet(sid).Insert(c) } +// SetToArray gets the set as an array of CIDs +func (mgr *CIDSetManager) SetToArray(sid SetID) ([]cid.Cid, error) { + return mgr.getSet(sid).ToArray() +} + +// SetLen gets the number of CIDs in the set +func (mgr *CIDSetManager) SetLen(sid SetID) (int, error) { + return mgr.getSet(sid).Len() +} + // DeleteSet deletes a CID set func (mgr *CIDSetManager) DeleteSet(sid SetID) error { return mgr.getSet(sid).Truncate() @@ -55,12 +65,13 @@ func (mgr *CIDSetManager) getSetDS(sid SetID) datastore.Batching { // cidSet persists a set of CIDs type cidSet struct { - lk sync.Mutex - ds datastore.Batching + lk sync.Mutex + ds datastore.Batching + len int // cached length of set, starts at -1 } func NewCIDSet(ds datastore.Batching) *cidSet { - return &cidSet{ds: ds} + return &cidSet{ds: ds, len: -1} } // Insert a CID into the set. @@ -69,15 +80,97 @@ func (s *cidSet) Insert(c cid.Cid) (exists bool, err error) { s.lk.Lock() defer s.lk.Unlock() + // Check if the key is in the set already k := datastore.NewKey(c.String()) has, err := s.ds.Has(k) if err != nil { return false, err } if has { + // Already in the set, just return true return true, nil } - return false, s.ds.Put(k, nil) + + // Get the length of the set + len, err := s.unlockedLen() + if err != nil { + return false, err + } + + // Add the new CID to the set + err = s.ds.Put(k, nil) + if err != nil { + return false, err + } + + // Increment the cached length of the set + s.len = len + 1 + + return false, nil +} + +// Returns the number of CIDs in the set +func (s *cidSet) Len() (int, error) { + s.lk.Lock() + defer s.lk.Unlock() + + return s.unlockedLen() +} + +func (s *cidSet) unlockedLen() (int, error) { + // If the length is already cached, return it + if s.len >= 0 { + return s.len, nil + } + + // Query the datastore for all keys + res, err := s.ds.Query(query.Query{KeysOnly: true}) + if err != nil { + return 0, err + } + + entries, err := res.Rest() + if err != nil { + return 0, err + } + + // Cache the length of the set + s.len = len(entries) + + return s.len, nil +} + +// Get all cids in the set as an array +func (s *cidSet) ToArray() ([]cid.Cid, error) { + s.lk.Lock() + defer s.lk.Unlock() + + res, err := s.ds.Query(query.Query{KeysOnly: true}) + if err != nil { + return nil, err + } + + entries, err := res.Rest() + if err != nil { + return nil, err + } + + cids := make([]cid.Cid, 0, len(entries)) + for _, entry := range entries { + // When we create a datastore Key, a "/" is automatically pre-pended, + // so here we need to remove the preceding "/" before parsing as a CID + k := entry.Key + if string(k[0]) == "/" { + k = k[1:] + } + + c, err := cid.Parse(k) + if err != nil { + return nil, err + } + cids = append(cids, c) + } + return cids, nil } // Truncate removes all CIDs in the set @@ -85,6 +178,7 @@ func (s *cidSet) Truncate() error { s.lk.Lock() defer s.lk.Unlock() + // Get all keys in the datastore res, err := s.ds.Query(query.Query{KeysOnly: true}) if err != nil { return err @@ -95,11 +189,13 @@ func (s *cidSet) Truncate() error { return err } + // Create a batch to perform all deletes as one operation batched, err := s.ds.Batch() if err != nil { return err } + // Add delete operations for each key to the batch for _, entry := range entries { err := batched.Delete(datastore.NewKey(entry.Key)) if err != nil { @@ -107,5 +203,14 @@ func (s *cidSet) Truncate() error { } } - return batched.Commit() + // Commit the batch + err = batched.Commit() + if err != nil { + return err + } + + // Set the cached length of the set to zero + s.len = 0 + + return nil } diff --git a/cidsets/cidsets_test.go b/cidsets/cidsets_test.go index 17ba704f..ae75a719 100644 --- a/cidsets/cidsets_test.go +++ b/cidsets/cidsets_test.go @@ -18,30 +18,194 @@ func TestCIDSetManager(t *testing.T) { setID1 := SetID("set1") setID2 := SetID("set2") + // set1: +cid1 exists, err := mgr.InsertSetCID(setID1, cid1) require.NoError(t, err) require.False(t, exists) + // set1: +cid1 (again) exists, err = mgr.InsertSetCID(setID1, cid1) require.NoError(t, err) require.True(t, exists) + // set2: +cid1 exists, err = mgr.InsertSetCID(setID2, cid1) require.NoError(t, err) require.False(t, exists) + // set2: +cid2 (again) exists, err = mgr.InsertSetCID(setID2, cid1) require.NoError(t, err) require.True(t, exists) + // delete set1 err = mgr.DeleteSet(setID1) require.NoError(t, err) + // set1: +cid1 exists, err = mgr.InsertSetCID(setID1, cid1) require.NoError(t, err) require.False(t, exists) + // set1: +cid1 (again) exists, err = mgr.InsertSetCID(setID2, cid1) require.NoError(t, err) require.True(t, exists) } + +func TestCIDSetToArray(t *testing.T) { + cids := testutil.GenerateCids(2) + cid1 := cids[0] + cid2 := cids[1] + + dstore := ds_sync.MutexWrap(ds.NewMapDatastore()) + mgr := NewCIDSetManager(dstore) + setID1 := SetID("set1") + + // Expect no items in set + len, err := mgr.SetLen(setID1) + require.NoError(t, err) + require.Equal(t, 0, len) + + arr, err := mgr.SetToArray(setID1) + require.NoError(t, err) + require.Len(t, arr, 0) + + // set1: +cid1 + exists, err := mgr.InsertSetCID(setID1, cid1) + require.NoError(t, err) + require.False(t, exists) + + // Expect 1 cid in set + len, err = mgr.SetLen(setID1) + require.NoError(t, err) + require.Equal(t, 1, len) + + arr, err = mgr.SetToArray(setID1) + require.NoError(t, err) + require.Len(t, arr, 1) + require.Equal(t, arr[0], cid1) + + // set1: +cid1 (again) + exists, err = mgr.InsertSetCID(setID1, cid1) + require.NoError(t, err) + require.True(t, exists) + + // Expect 1 cid in set + len, err = mgr.SetLen(setID1) + require.NoError(t, err) + require.Equal(t, 1, len) + + arr, err = mgr.SetToArray(setID1) + require.NoError(t, err) + require.Len(t, arr, 1) + require.Equal(t, arr[0], cid1) + + // set1: +cid2 + exists, err = mgr.InsertSetCID(setID1, cid2) + require.NoError(t, err) + require.False(t, exists) + + // Expect 2 cids in set + len, err = mgr.SetLen(setID1) + require.NoError(t, err) + require.Equal(t, 2, len) + + arr, err = mgr.SetToArray(setID1) + require.NoError(t, err) + require.Len(t, arr, 2) + require.Contains(t, arr, cid1) + require.Contains(t, arr, cid2) + + // Delete set1 + err = mgr.DeleteSet(setID1) + require.NoError(t, err) + + // Expect no items in set + len, err = mgr.SetLen(setID1) + require.NoError(t, err) + require.Equal(t, 0, len) + + arr, err = mgr.SetToArray(setID1) + require.NoError(t, err) + require.Len(t, arr, 0) +} + +// Add items to set then get the length (to make sure that internal caching +// is working correctly) +func TestCIDSetLenAfterInsert(t *testing.T) { + cids := testutil.GenerateCids(2) + cid1 := cids[0] + cid2 := cids[1] + + dstore := ds_sync.MutexWrap(ds.NewMapDatastore()) + mgr := NewCIDSetManager(dstore) + setID1 := SetID("set1") + + // set1: +cid1 + exists, err := mgr.InsertSetCID(setID1, cid1) + require.NoError(t, err) + require.False(t, exists) + + // set1: +cid2 + exists, err = mgr.InsertSetCID(setID1, cid2) + require.NoError(t, err) + require.False(t, exists) + + // Expect 2 cids in set + len, err := mgr.SetLen(setID1) + require.NoError(t, err) + require.Equal(t, 2, len) +} + +func TestCIDSetRestart(t *testing.T) { + cids := testutil.GenerateCids(3) + cid1 := cids[0] + cid2 := cids[1] + cid3 := cids[2] + + dstore := ds_sync.MutexWrap(ds.NewMapDatastore()) + mgr := NewCIDSetManager(dstore) + setID1 := SetID("set1") + + // set1: +cid1 + exists, err := mgr.InsertSetCID(setID1, cid1) + require.NoError(t, err) + require.False(t, exists) + + // set1: +cid2 + exists, err = mgr.InsertSetCID(setID1, cid2) + require.NoError(t, err) + require.False(t, exists) + + // Expect 2 cids in set + arr, err := mgr.SetToArray(setID1) + require.NoError(t, err) + require.Len(t, arr, 2) + require.Contains(t, arr, cid1) + require.Contains(t, arr, cid2) + + // Simulate a restart by creating a new CIDSetManager from the same + // datastore + mgr = NewCIDSetManager(dstore) + + // Expect 2 cids in set + arr, err = mgr.SetToArray(setID1) + require.NoError(t, err) + require.Len(t, arr, 2) + require.Contains(t, arr, cid1) + require.Contains(t, arr, cid2) + + // set1: +cid3 + exists, err = mgr.InsertSetCID(setID1, cid3) + require.NoError(t, err) + require.False(t, exists) + + // Expect 3 cids in set + arr, err = mgr.SetToArray(setID1) + require.NoError(t, err) + require.Len(t, arr, 3) + require.Contains(t, arr, cid1) + require.Contains(t, arr, cid2) + require.Contains(t, arr, cid3) +} diff --git a/impl/initiating_test.go b/impl/initiating_test.go index 3259817c..74f87ff3 100644 --- a/impl/initiating_test.go +++ b/impl/initiating_test.go @@ -386,7 +386,7 @@ func TestDataTransferRestartInitiating(t *testing.T) { require.Equal(t, openChannel.Selector, h.stor) require.True(t, openChannel.Message.IsRequest()) // received cids should be a part of the channel req - require.Equal(t, []cid.Cid{testCids[0], testCids[1]}, openChannel.DoNotSendCids) + require.ElementsMatch(t, []cid.Cid{testCids[0], testCids[1]}, openChannel.DoNotSendCids) receivedRequest, ok := openChannel.Message.(datatransfer.Request) require.True(t, ok) diff --git a/impl/responding_test.go b/impl/responding_test.go index 634db1d9..211b5aa9 100644 --- a/impl/responding_test.go +++ b/impl/responding_test.go @@ -673,7 +673,7 @@ func TestDataTransferRestartResponding(t *testing.T) { require.Equal(t, openChannel.Root, cidlink.Link{Cid: h.baseCid}) require.Equal(t, openChannel.Selector, h.stor) // assert do not send cids are sent - require.Equal(t, []cid.Cid{testCids[0], testCids[1]}, openChannel.DoNotSendCids) + require.ElementsMatch(t, []cid.Cid{testCids[0], testCids[1]}, openChannel.DoNotSendCids) require.False(t, openChannel.Message.IsRequest()) response, ok := openChannel.Message.(datatransfer.Response) require.True(t, ok) @@ -884,7 +884,7 @@ func TestDataTransferRestartResponding(t *testing.T) { require.Equal(t, openChannel.Selector, h.stor) require.True(t, openChannel.Message.IsRequest()) // received cids should be a part of the channel req - require.Equal(t, []cid.Cid{testCids[0], testCids[1]}, openChannel.DoNotSendCids) + require.ElementsMatch(t, openChannel.DoNotSendCids, testCids) // assert a restart request is in the channel request, ok := openChannel.Message.(datatransfer.Request) diff --git a/impl/restart_integration_test.go b/impl/restart_integration_test.go index 4c2c0d78..ed2a7337 100644 --- a/impl/restart_integration_test.go +++ b/impl/restart_integration_test.go @@ -98,6 +98,7 @@ func TestRestartPush(t *testing.T) { queued := make(chan uint64, totalIncrements*2) sent := make(chan uint64, totalIncrements*2) received := make(chan uint64, totalIncrements*2) + var receivedCids []cid.Cid receivedTillNow := atomic.NewInt32(0) // counters we will check at the end for correctness @@ -106,6 +107,7 @@ func TestRestartPush(t *testing.T) { var finishedPeers []peer.ID disConnChan := make(chan struct{}) + var chid datatransfer.ChannelID var subscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) { if event.Code == datatransfer.DataQueued { if channelState.Queued() > 0 { @@ -133,7 +135,18 @@ func TestRestartPush(t *testing.T) { } if channelState.Status() == datatransfer.Completed { finishedPeersLk.Lock() - finishedPeers = append(finishedPeers, channelState.SelfPeer()) + { + finishedPeers = append(finishedPeers, channelState.SelfPeer()) + + // When the receiving peer completes, record received CIDs + // before they get cleaned up + if channelState.SelfPeer() == rh.peer2 { + chs, err := rh.dt2.InProgressChannels(rh.testCtx) + require.NoError(t, err) + require.Len(t, chs, 1) + receivedCids = chs[chid].ReceivedCids() + } + } finishedPeersLk.Unlock() finished <- channelState.SelfPeer() } @@ -153,7 +166,7 @@ func TestRestartPush(t *testing.T) { rh.dt2.SubscribeToEvents(subscriber) // OPEN PUSH - chid := tc.openPushF(rh) + chid = tc.openPushF(rh) // wait for disconnection to happen <-disConnChan t.Logf("peers unlinked and disconnected, total increments received till now: %d", receivedTillNow.Load()) @@ -205,15 +218,7 @@ func TestRestartPush(t *testing.T) { require.NoError(t, err) // verify all cids are present on the receiver - chs, err := rh.dt2.InProgressChannels(rh.testCtx) - require.NoError(t, err) - require.Len(t, chs, 1) - cids := chs[chid].ReceivedCids() - set := cid.NewSet() - for _, c := range cids { - set.Add(c) - } - require.Equal(t, totalIncrements, set.Len()) + require.Equal(t, totalIncrements, len(receivedCids)) testutil.VerifyHasFile(rh.testCtx, t, rh.destDagService, rh.root, rh.origBytes) rh.sv.VerifyExpectations(t) @@ -304,6 +309,7 @@ func TestRestartPull(t *testing.T) { sent := make(chan uint64, totalIncrements) received := make(chan uint64, totalIncrements) receivedTillNow := atomic.NewInt32(0) + var receivedCids []cid.Cid // counters we will check at the end for correctness opens := atomic.NewInt32(0) @@ -311,6 +317,7 @@ func TestRestartPull(t *testing.T) { var finishedPeers []peer.ID disConnChan := make(chan struct{}) + var chid datatransfer.ChannelID var subscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) { if event.Code == datatransfer.DataQueued { if channelState.Queued() > 0 { @@ -333,7 +340,18 @@ func TestRestartPull(t *testing.T) { if channelState.Status() == datatransfer.Completed { finishedPeersLk.Lock() - finishedPeers = append(finishedPeers, channelState.SelfPeer()) + { + finishedPeers = append(finishedPeers, channelState.SelfPeer()) + + // When the receiving peer completes, record received CIDs + // before they get cleaned up + if channelState.SelfPeer() == rh.peer2 { + chs, err := rh.dt2.InProgressChannels(rh.testCtx) + require.NoError(t, err) + require.Len(t, chs, 1) + receivedCids = chs[chid].ReceivedCids() + } + } finishedPeersLk.Unlock() finished <- channelState.SelfPeer() } @@ -353,7 +371,7 @@ func TestRestartPull(t *testing.T) { rh.dt2.SubscribeToEvents(subscriber) // OPEN pull - chid := tc.openPullF(rh) + chid = tc.openPullF(rh) // wait for disconnection to happen select { @@ -407,15 +425,7 @@ func TestRestartPull(t *testing.T) { require.NoError(t, err) // verify all cids are present on the receiver - chs, err := rh.dt2.InProgressChannels(rh.testCtx) - require.NoError(t, err) - require.Len(t, chs, 1) - cids := chs[chid].ReceivedCids() - set := cid.NewSet() - for _, c := range cids { - set.Add(c) - } - require.Equal(t, totalIncrements, set.Len()) + require.Equal(t, totalIncrements, len(receivedCids)) testutil.VerifyHasFile(rh.testCtx, t, rh.destDagService, rh.root, rh.origBytes) rh.sv.VerifyExpectations(t) diff --git a/types.go b/types.go index 60a6e36c..0f65fcaf 100644 --- a/types.go +++ b/types.go @@ -132,6 +132,9 @@ type ChannelState interface { // ReceivedCids returns the cids received so far on the channel ReceivedCids() []cid.Cid + // ReceivedCidsLen returns the number of cids received so far on the channel + ReceivedCidsLen() int + // Queued returns the number of bytes read from the node and queued for sending Queued() uint64