Skip to content

Commit

Permalink
[Fix] Fix dist.collect_results to keep all ranks' elements (#1469)
Browse files: browse the repository at this point in the history
  • Loading branch information
LZHgrla committed Jan 11, 2024
1 parent b51bf60 commit 109cd44
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions mmengine/dist/dist.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from torch.distributed import ProcessGroup

from itertools import zip_longest, chain
import mmengine
from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
get_default_group, barrier, get_data_device,
Expand Down Expand Up @@ -1010,8 +1010,10 @@ def collect_results_cpu(result_part: list,
part_list.append(pickle.load(f))
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
zipped_results = zip_longest(*part_list)
ordered_results = [
i for i in chain.from_iterable(zipped_results) if i is not None
]
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
# remove tmp dir
Expand All @@ -1032,8 +1034,10 @@ def _collect_results_device(result_part: list, size: int) -> Optional[list]:
if rank == 0:
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
zipped_results = zip_longest(*part_list)
ordered_results = [
i for i in chain.from_iterable(zipped_results) if i is not None
]
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
return ordered_results
Expand Down

0 comments on commit 109cd44

Please sign in to comment.