diff --git a/Makefile b/Makefile index ff4cee4f5..d95250575 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ RUN_LONG_TESTS?=yes GO_1_10_AND_HIGHER=$(shell (printf '%s\n' go1.10 $(GOVERSION) | sort -cV >/dev/null 2>&1) && echo "yes") -all: test check system-test +all: test bench check system-test prepare: go get -u github.com/alecthomas/gometalinter @@ -57,6 +57,13 @@ else go test -v `go list ./... | grep -v vendor/` -gocheck.v=true endif +bench: +ifeq ($(GO_1_10_AND_HIGHER), yes) + go test -v ./... -run=nothing -bench=. -benchmem +else + go test -v `go list ./... | grep -v vendor/` -run=nothing -bench=. -benchmem +endif + mem.png: mem.dat mem.gp gnuplot mem.gp open mem.png diff --git a/deb/local.go b/deb/local.go index 9b2207c6a..229f6e1da 100644 --- a/deb/local.go +++ b/deb/local.go @@ -2,6 +2,7 @@ package deb import ( "bytes" + "errors" "fmt" "log" "sync" @@ -93,8 +94,8 @@ func (repo *LocalRepo) RefKey() []byte { // LocalRepoCollection does listing, updating/adding/deleting of LocalRepos type LocalRepoCollection struct { *sync.RWMutex - db database.Storage - list []*LocalRepo + db database.Storage + cache map[string]*LocalRepo } // NewLocalRepoCollection loads LocalRepos from DB and makes up collection @@ -102,43 +103,59 @@ func NewLocalRepoCollection(db database.Storage) *LocalRepoCollection { return &LocalRepoCollection{ RWMutex: &sync.RWMutex{}, db: db, + cache: make(map[string]*LocalRepo), } } -func (collection *LocalRepoCollection) loadList() { - if collection.list != nil { - return +func (collection *LocalRepoCollection) search(filter func(*LocalRepo) bool, unique bool) []*LocalRepo { + result := []*LocalRepo(nil) + for _, r := range collection.cache { + if filter(r) { + result = append(result, r) + } } - blobs := collection.db.FetchByPrefix([]byte("L")) - collection.list = make([]*LocalRepo, 0, len(blobs)) + if unique && len(result) > 0 { + return result + } - for _, blob := range blobs { + collection.db.ProcessByPrefix([]byte("L"), func(key, blob []byte) error { r := &LocalRepo{} if err := r.Decode(blob); err != nil { - log.Printf("Error decoding repo: %s\n", err) - } else { - collection.list = append(collection.list, r) + log.Printf("Error decoding local repo: %s\n", err) + return nil } - } + + if filter(r) { + if _, exists := collection.cache[r.UUID]; !exists { + collection.cache[r.UUID] = r + result = append(result, r) + if unique { + return errors.New("abort") + } + } + } + + return nil + }) + + return result } // Add appends new repo to collection and saves it func (collection *LocalRepoCollection) Add(repo *LocalRepo) error { - collection.loadList() + _, err := collection.ByName(repo.Name) - for _, r := range collection.list { - if r.Name == repo.Name { - return fmt.Errorf("local repo with name %s already exists", repo.Name) - } + if err == nil { + return fmt.Errorf("local repo with name %s already exists", repo.Name) } - err := collection.Update(repo) + err = collection.Update(repo) if err != nil { return err } - collection.list = append(collection.list, repo) + collection.cache[repo.UUID] = repo return nil } @@ -159,8 +176,6 @@ func (collection *LocalRepoCollection) Update(repo *LocalRepo) error { // LoadComplete loads additional information for local repo func (collection *LocalRepoCollection) LoadComplete(repo *LocalRepo) error { - collection.loadList() - encoded, err := collection.db.Get(repo.RefKey()) if err == database.ErrNotFound { return nil @@ -175,26 +190,39 @@ func (collection *LocalRepoCollection) LoadComplete(repo *LocalRepo) error { // ByName looks up repository by name func (collection *LocalRepoCollection) ByName(name string) (*LocalRepo, error) { - collection.loadList() - - for _, r := range collection.list { - if r.Name == name { - return r, nil - } + result := collection.search(func(r *LocalRepo) bool { return r.Name == name }, true) + if len(result) == 0 { + return nil, fmt.Errorf("local repo with name %s not found", name) } - return nil, fmt.Errorf("local repo with name %s not found", name) + + return result[0], nil } // ByUUID looks up repository by uuid func (collection *LocalRepoCollection) ByUUID(uuid string) (*LocalRepo, error) { - collection.loadList() + if r, ok := collection.cache[uuid]; ok { + return r, nil + } - for _, r := range collection.list { - if r.UUID == uuid { - return r, nil - } + key := (&LocalRepo{UUID: uuid}).Key() + + value, err := collection.db.Get(key) + if err == database.ErrNotFound { + return nil, fmt.Errorf("local repo with uuid %s not found", uuid) + } + + if err != nil { + return nil, err + } + + r := &LocalRepo{} + err = r.Decode(value) + + if err == nil { + collection.cache[r.UUID] = r } - return nil, fmt.Errorf("local repo with uuid %s not found", uuid) + + return r, err } // ForEach runs method for each repository @@ -212,30 +240,16 @@ func (collection *LocalRepoCollection) ForEach(handler func(*LocalRepo) error) e // Len returns number of remote repos func (collection *LocalRepoCollection) Len() int { - collection.loadList() - - return len(collection.list) + return len(collection.db.KeysByPrefix([]byte("L"))) } // Drop removes remote repo from collection func (collection *LocalRepoCollection) Drop(repo *LocalRepo) error { - collection.loadList() - - repoPosition := -1 - - for i, r := range collection.list { - if r == repo { - repoPosition = i - break - } - } - - if repoPosition == -1 { + if _, err := collection.db.Get(repo.Key()); err == database.ErrNotFound { panic("local repo not found!") } - collection.list[len(collection.list)-1], collection.list[repoPosition], collection.list = - nil, collection.list[len(collection.list)-1], collection.list[:len(collection.list)-1] + delete(collection.cache, repo.UUID) err := collection.db.Delete(repo.Key()) if err != nil { diff --git a/deb/local_test.go b/deb/local_test.go index e472463d9..82df0497a 100644 --- a/deb/local_test.go +++ b/deb/local_test.go @@ -124,6 +124,11 @@ func (s *LocalRepoCollectionSuite) TestByUUID(c *C) { r, err := s.collection.ByUUID(repo.UUID) c.Assert(err, IsNil) + c.Assert(r, Equals, repo) + + collection := NewLocalRepoCollection(s.db) + r, err = collection.ByUUID(repo.UUID) + c.Assert(err, IsNil) c.Assert(r.String(), Equals, repo.String()) } diff --git a/deb/remote.go b/deb/remote.go index 2646dd0d7..cc26eef17 100644 --- a/deb/remote.go +++ b/deb/remote.go @@ -3,6 +3,7 @@ package deb import ( "bytes" gocontext "context" + "errors" "fmt" "log" "net/url" @@ -654,8 +655,8 @@ func (repo *RemoteRepo) RefKey() []byte { // RemoteRepoCollection does listing, updating/adding/deleting of RemoteRepos type RemoteRepoCollection struct { *sync.RWMutex - db database.Storage - list []*RemoteRepo + db database.Storage + cache map[string]*RemoteRepo } // NewRemoteRepoCollection loads RemoteRepos from DB and makes up collection @@ -663,43 +664,59 @@ func NewRemoteRepoCollection(db database.Storage) *RemoteRepoCollection { return &RemoteRepoCollection{ RWMutex: &sync.RWMutex{}, db: db, + cache: make(map[string]*RemoteRepo), } } -func (collection *RemoteRepoCollection) loadList() { - if collection.list != nil { - return +func (collection *RemoteRepoCollection) search(filter func(*RemoteRepo) bool, unique bool) []*RemoteRepo { + result := []*RemoteRepo(nil) + for _, r := range collection.cache { + if filter(r) { + result = append(result, r) + } } - blobs := collection.db.FetchByPrefix([]byte("R")) - collection.list = make([]*RemoteRepo, 0, len(blobs)) + if unique && len(result) > 0 { + return result + } - for _, blob := range blobs { + collection.db.ProcessByPrefix([]byte("R"), func(key, blob []byte) error { r := &RemoteRepo{} if err := r.Decode(blob); err != nil { - log.Printf("Error decoding mirror: %s\n", err) - } else { - collection.list = append(collection.list, r) + log.Printf("Error decoding remote repo: %s\n", err) + return nil } - } + + if filter(r) { + if _, exists := collection.cache[r.UUID]; !exists { + collection.cache[r.UUID] = r + result = append(result, r) + if unique { + return errors.New("abort") + } + } + } + + return nil + }) + + return result } // Add appends new repo to collection and saves it func (collection *RemoteRepoCollection) Add(repo *RemoteRepo) error { - collection.loadList() + _, err := collection.ByName(repo.Name) - for _, r := range collection.list { - if r.Name == repo.Name { - return fmt.Errorf("mirror with name %s already exists", repo.Name) - } + if err == nil { + return fmt.Errorf("mirror with name %s already exists", repo.Name) } - err := collection.Update(repo) + err = collection.Update(repo) if err != nil { return err } - collection.list = append(collection.list, repo) + collection.cache[repo.UUID] = repo return nil } @@ -734,26 +751,38 @@ func (collection *RemoteRepoCollection) LoadComplete(repo *RemoteRepo) error { // ByName looks up repository by name func (collection *RemoteRepoCollection) ByName(name string) (*RemoteRepo, error) { - collection.loadList() - - for _, r := range collection.list { - if r.Name == name { - return r, nil - } + result := collection.search(func(r *RemoteRepo) bool { return r.Name == name }, true) + if len(result) == 0 { + return nil, fmt.Errorf("mirror with name %s not found", name) } - return nil, fmt.Errorf("mirror with name %s not found", name) + + return result[0], nil } // ByUUID looks up repository by uuid func (collection *RemoteRepoCollection) ByUUID(uuid string) (*RemoteRepo, error) { - collection.loadList() + if r, ok := collection.cache[uuid]; ok { + return r, nil + } - for _, r := range collection.list { - if r.UUID == uuid { - return r, nil - } + key := (&RemoteRepo{UUID: uuid}).Key() + + value, err := collection.db.Get(key) + if err == database.ErrNotFound { + return nil, fmt.Errorf("mirror with uuid %s not found", uuid) } - return nil, fmt.Errorf("mirror with uuid %s not found", uuid) + if err != nil { + return nil, err + } + + r := &RemoteRepo{} + err = r.Decode(value) + + if err == nil { + collection.cache[r.UUID] = r + } + + return r, err } // ForEach runs method for each repository @@ -771,30 +800,16 @@ func (collection *RemoteRepoCollection) ForEach(handler func(*RemoteRepo) error) // Len returns number of remote repos func (collection *RemoteRepoCollection) Len() int { - collection.loadList() - - return len(collection.list) + return len(collection.db.KeysByPrefix([]byte("R"))) } // Drop removes remote repo from collection func (collection *RemoteRepoCollection) Drop(repo *RemoteRepo) error { - collection.loadList() - - repoPosition := -1 - - for i, r := range collection.list { - if r == repo { - repoPosition = i - break - } - } - - if repoPosition == -1 { + if _, err := collection.db.Get(repo.Key()); err == database.ErrNotFound { panic("repo not found!") } - collection.list[len(collection.list)-1], collection.list[repoPosition], collection.list = - nil, collection.list[len(collection.list)-1], collection.list[:len(collection.list)-1] + delete(collection.cache, repo.UUID) err := collection.db.Delete(repo.Key()) if err != nil { diff --git a/deb/remote_test.go b/deb/remote_test.go index 568ec5af0..08fb7f64e 100644 --- a/deb/remote_test.go +++ b/deb/remote_test.go @@ -651,6 +651,11 @@ func (s *RemoteRepoCollectionSuite) TestByUUID(c *C) { r, err := s.collection.ByUUID(repo.UUID) c.Assert(err, IsNil) + c.Assert(r, Equals, repo) + + collection := NewRemoteRepoCollection(s.db) + r, err = collection.ByUUID(repo.UUID) + c.Assert(err, IsNil) c.Assert(r.String(), Equals, repo.String()) } diff --git a/deb/snapshot.go b/deb/snapshot.go index bdf324808..307433af8 100644 --- a/deb/snapshot.go +++ b/deb/snapshot.go @@ -173,8 +173,8 @@ func (s *Snapshot) Decode(input []byte) error { // SnapshotCollection does listing, updating/adding/deleting of Snapshots type SnapshotCollection struct { *sync.RWMutex - db database.Storage - list []*Snapshot + db database.Storage + cache map[string]*Snapshot } // NewSnapshotCollection loads Snapshots from DB and makes up collection @@ -182,43 +182,23 @@ func NewSnapshotCollection(db database.Storage) *SnapshotCollection { return &SnapshotCollection{ RWMutex: &sync.RWMutex{}, db: db, - } -} - -func (collection *SnapshotCollection) loadList() { - if collection.list != nil { - return - } - - blobs := collection.db.FetchByPrefix([]byte("S")) - collection.list = make([]*Snapshot, 0, len(blobs)) - - for _, blob := range blobs { - s := &Snapshot{} - if err := s.Decode(blob); err != nil { - log.Printf("Error decoding snapshot: %s\n", err) - } else { - collection.list = append(collection.list, s) - } + cache: map[string]*Snapshot{}, } } // Add appends new repo to collection and saves it func (collection *SnapshotCollection) Add(snapshot *Snapshot) error { - collection.loadList() - - for _, s := range collection.list { - if s.Name == snapshot.Name { - return fmt.Errorf("snapshot with name %s already exists", snapshot.Name) - } + _, err := collection.ByName(snapshot.Name) + if err == nil { + return fmt.Errorf("snapshot with name %s already exists", snapshot.Name) } - err := collection.Update(snapshot) + err = collection.Update(snapshot) if err != nil { return err } - collection.list = append(collection.list, snapshot) + collection.cache[snapshot.UUID] = snapshot return nil } @@ -245,70 +225,96 @@ func (collection *SnapshotCollection) LoadComplete(snapshot *Snapshot) error { return snapshot.packageRefs.Decode(encoded) } -// ByName looks up snapshot by name -func (collection *SnapshotCollection) ByName(name string) (*Snapshot, error) { - collection.loadList() +func (collection *SnapshotCollection) search(filter func(*Snapshot) bool, unique bool) []*Snapshot { + result := []*Snapshot(nil) + for _, s := range collection.cache { + if filter(s) { + result = append(result, s) + } + } - for _, s := range collection.list { - if s.Name == name { - return s, nil + if unique && len(result) > 0 { + return result + } + + collection.db.ProcessByPrefix([]byte("S"), func(key, blob []byte) error { + s := &Snapshot{} + if err := s.Decode(blob); err != nil { + log.Printf("Error decoding snapshot: %s\n", err) + return nil } + + if filter(s) { + if _, exists := collection.cache[s.UUID]; !exists { + collection.cache[s.UUID] = s + result = append(result, s) + if unique { + return errors.New("abort") + } + } + } + + return nil + }) + + return result +} + +// ByName looks up snapshot by name +func (collection *SnapshotCollection) ByName(name string) (*Snapshot, error) { + result := collection.search(func(s *Snapshot) bool { return s.Name == name }, true) + if len(result) > 0 { + return result[0], nil } + return nil, fmt.Errorf("snapshot with name %s not found", name) } // ByUUID looks up snapshot by UUID func (collection *SnapshotCollection) ByUUID(uuid string) (*Snapshot, error) { - collection.loadList() + if s, ok := collection.cache[uuid]; ok { + return s, nil + } - for _, s := range collection.list { - if s.UUID == uuid { - return s, nil - } + key := (&Snapshot{UUID: uuid}).Key() + + value, err := collection.db.Get(key) + if err == database.ErrNotFound { + return nil, fmt.Errorf("snapshot with uuid %s not found", uuid) } - return nil, fmt.Errorf("snapshot with uuid %s not found", uuid) + if err != nil { + return nil, err + } + + s := &Snapshot{} + err = s.Decode(value) + + if err == nil { + collection.cache[s.UUID] = s + } + + return s, err } // ByRemoteRepoSource looks up snapshots that have specified RemoteRepo as a source func (collection *SnapshotCollection) ByRemoteRepoSource(repo *RemoteRepo) []*Snapshot { - collection.loadList() - - var result []*Snapshot - - for _, s := range collection.list { - if s.SourceKind == SourceRemoteRepo && utils.StrSliceHasItem(s.SourceIDs, repo.UUID) { - result = append(result, s) - } - } - return result + return collection.search(func(s *Snapshot) bool { + return s.SourceKind == SourceRemoteRepo && utils.StrSliceHasItem(s.SourceIDs, repo.UUID) + }, false) } // ByLocalRepoSource looks up snapshots that have specified LocalRepo as a source func (collection *SnapshotCollection) ByLocalRepoSource(repo *LocalRepo) []*Snapshot { - collection.loadList() - - var result []*Snapshot - - for _, s := range collection.list { - if s.SourceKind == SourceLocalRepo && utils.StrSliceHasItem(s.SourceIDs, repo.UUID) { - result = append(result, s) - } - } - return result + return collection.search(func(s *Snapshot) bool { + return s.SourceKind == SourceLocalRepo && utils.StrSliceHasItem(s.SourceIDs, repo.UUID) + }, false) } // BySnapshotSource looks up snapshots that have specified snapshot as a source func (collection *SnapshotCollection) BySnapshotSource(snapshot *Snapshot) []*Snapshot { - collection.loadList() - - var result []*Snapshot - - for _, s := range collection.list { - if s.SourceKind == "snapshot" && utils.StrSliceHasItem(s.SourceIDs, snapshot.UUID) { - result = append(result, s) - } - } - return result + return collection.search(func(s *Snapshot) bool { + return s.SourceKind == "snapshot" && utils.StrSliceHasItem(s.SourceIDs, snapshot.UUID) + }, false) } // ForEach runs method for each snapshot @@ -326,15 +332,25 @@ func (collection *SnapshotCollection) ForEach(handler func(*Snapshot) error) err // ForEachSorted runs method for each snapshot following some sort order func (collection *SnapshotCollection) ForEachSorted(sortMethod string, handler func(*Snapshot) error) error { - collection.loadList() + blobs := collection.db.FetchByPrefix([]byte("S")) + list := make([]*Snapshot, 0, len(blobs)) + + for _, blob := range blobs { + s := &Snapshot{} + if err := s.Decode(blob); err != nil { + log.Printf("Error decoding snapshot: %s\n", err) + } else { + list = append(list, s) + } + } - sorter, err := newSnapshotSorter(sortMethod, collection) + sorter, err := newSnapshotSorter(sortMethod, list) if err != nil { return err } - for _, i := range sorter.list { - err = handler(collection.list[i]) + for _, s := range sorter.list { + err = handler(s) if err != nil { return err } @@ -346,30 +362,16 @@ func (collection *SnapshotCollection) ForEachSorted(sortMethod string, handler f // Len returns number of snapshots in collection // ForEach runs method for each snapshot func (collection *SnapshotCollection) Len() int { - collection.loadList() - - return len(collection.list) + return len(collection.db.KeysByPrefix([]byte("S"))) } // Drop removes snapshot from collection func (collection *SnapshotCollection) Drop(snapshot *Snapshot) error { - collection.loadList() - - snapshotPosition := -1 - - for i, s := range collection.list { - if s == snapshot { - snapshotPosition = i - break - } - } - - if snapshotPosition == -1 { + if _, err := collection.db.Get(snapshot.Key()); err == database.ErrNotFound { panic("snapshot not found!") } - collection.list[len(collection.list)-1], collection.list[snapshotPosition], collection.list = - nil, collection.list[len(collection.list)-1], collection.list[:len(collection.list)-1] + delete(collection.cache, snapshot.UUID) err := collection.db.Delete(snapshot.Key()) if err != nil { @@ -386,13 +388,12 @@ const ( ) type snapshotSorter struct { - list []int - collection *SnapshotCollection + list []*Snapshot sortMethod int } -func newSnapshotSorter(sortMethod string, collection *SnapshotCollection) (*snapshotSorter, error) { - s := &snapshotSorter{collection: collection} +func newSnapshotSorter(sortMethod string, list []*Snapshot) (*snapshotSorter, error) { + s := &snapshotSorter{list: list} switch sortMethod { case "time", "Time": @@ -403,11 +404,6 @@ func newSnapshotSorter(sortMethod string, collection *SnapshotCollection) (*snap return nil, fmt.Errorf("sorting method \"%s\" unknown", sortMethod) } - s.list = make([]int, len(collection.list)) - for i := range s.list { - s.list[i] = i - } - sort.Sort(s) return s, nil @@ -420,9 +416,9 @@ func (s *snapshotSorter) Swap(i, j int) { func (s *snapshotSorter) Less(i, j int) bool { switch s.sortMethod { case SortName: - return s.collection.list[s.list[i]].Name < s.collection.list[s.list[j]].Name + return s.list[i].Name < s.list[j].Name case SortTime: - return s.collection.list[s.list[i]].CreatedAt.Before(s.collection.list[s.list[j]].CreatedAt) + return s.list[i].CreatedAt.Before(s.list[j].CreatedAt) } panic("unknown sort method") } diff --git a/deb/snapshot_bench_test.go b/deb/snapshot_bench_test.go index 19997c2c8..0240b9f8b 100644 --- a/deb/snapshot_bench_test.go +++ b/deb/snapshot_bench_test.go @@ -36,3 +36,63 @@ func BenchmarkSnapshotCollectionForEach(b *testing.B) { }) } } + +func BenchmarkSnapshotCollectionByUUID(b *testing.B) { + const count = 1024 + + tmpDir := os.TempDir() + defer os.RemoveAll(tmpDir) + + db, _ := database.NewOpenDB(tmpDir) + defer db.Close() + + collection := NewSnapshotCollection(db) + + uuids := []string{} + for i := 0; i < count; i++ { + snapshot := NewSnapshotFromRefList(fmt.Sprintf("snapshot%d", i), nil, NewPackageRefList(), fmt.Sprintf("Snapshot number %d", i)) + if collection.Add(snapshot) != nil { + b.FailNow() + } + uuids = append(uuids, snapshot.UUID) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + collection = NewSnapshotCollection(db) + + if _, err := collection.ByUUID(uuids[i%len(uuids)]); err != nil { + b.FailNow() + } + } +} + +func BenchmarkSnapshotCollectionByName(b *testing.B) { + const count = 1024 + + tmpDir := os.TempDir() + defer os.RemoveAll(tmpDir) + + db, _ := database.NewOpenDB(tmpDir) + defer db.Close() + + collection := NewSnapshotCollection(db) + + for i := 0; i < count; i++ { + snapshot := NewSnapshotFromRefList(fmt.Sprintf("snapshot%d", i), nil, NewPackageRefList(), fmt.Sprintf("Snapshot number %d", i)) + if collection.Add(snapshot) != nil { + b.FailNow() + } + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + collection = NewSnapshotCollection(db) + + if _, err := collection.ByName(fmt.Sprintf("snapshot%d", i%count)); err != nil { + b.FailNow() + } + } +} diff --git a/deb/snapshot_test.go b/deb/snapshot_test.go index 85bf14ca6..0d1bd25bd 100644 --- a/deb/snapshot_test.go +++ b/deb/snapshot_test.go @@ -2,6 +2,7 @@ package deb import ( "errors" + "sort" "github.com/aptly-dev/aptly/database" @@ -158,6 +159,10 @@ func (s *SnapshotCollectionSuite) TestAddByNameByUUID(c *C) { snapshot, err = collection.ByUUID(s.snapshot1.UUID) c.Assert(err, IsNil) c.Assert(snapshot.String(), Equals, s.snapshot1.String()) + + snapshot, err = collection.ByUUID(s.snapshot2.UUID) + c.Assert(err, IsNil) + c.Assert(snapshot.String(), Equals, s.snapshot2.String()) } func (s *SnapshotCollectionSuite) TestUpdateLoadComplete(c *C) { @@ -193,6 +198,23 @@ func (s *SnapshotCollectionSuite) TestForEachAndLen(c *C) { c.Assert(err, Equals, e) } +func (s *SnapshotCollectionSuite) TestForEachSorted(c *C) { + s.collection.Add(s.snapshot2) + s.collection.Add(s.snapshot1) + s.collection.Add(s.snapshot4) + s.collection.Add(s.snapshot3) + + names := []string{} + + err := s.collection.ForEachSorted("name", func(snapshot *Snapshot) error { + names = append(names, snapshot.Name) + return nil + }) + c.Assert(err, IsNil) + + c.Check(sort.StringsAreSorted(names), Equals, true) +} + func (s *SnapshotCollectionSuite) TestFindByRemoteRepoSource(c *C) { c.Assert(s.collection.Add(s.snapshot1), IsNil) c.Assert(s.collection.Add(s.snapshot2), IsNil) @@ -230,7 +252,11 @@ func (s *SnapshotCollectionSuite) TestFindSnapshotSource(c *C) { c.Assert(s.collection.Add(snapshot4), IsNil) c.Assert(s.collection.Add(snapshot5), IsNil) - c.Check(s.collection.BySnapshotSource(s.snapshot1), DeepEquals, []*Snapshot{snapshot3, snapshot4}) + list := s.collection.BySnapshotSource(s.snapshot1) + sorter, _ := newSnapshotSorter("name", list) + sort.Sort(sorter) + + c.Check(sorter.list, DeepEquals, []*Snapshot{snapshot3, snapshot4}) c.Check(s.collection.BySnapshotSource(s.snapshot2), DeepEquals, []*Snapshot{snapshot3}) c.Check(s.collection.BySnapshotSource(snapshot5), DeepEquals, []*Snapshot(nil)) }