Skip to content

Commit

Permalink
fix all_gather_object with various length, test=allcases (#44718)
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio committed Aug 1, 2022
1 parent 3e8708b commit e48cb42
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
27 changes: 20 additions & 7 deletions python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,12 +1032,12 @@ def _convert_object_to_tensor(obj):
_pickler(f).dump(obj)
data = np.frombuffer(f.getvalue(), dtype=np.uint8)
tensor = paddle.to_tensor(data)
return tensor
return tensor, tensor.numel()


def _convert_tensor_to_object(tensor):
def _convert_tensor_to_object(tensor, len_of_tensor):
_unpickler = pickle.Unpickler
return _unpickler(io.BytesIO(tensor.numpy())).load()
return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()


def all_gather_object(object_list, obj, group=None):
Expand Down Expand Up @@ -1076,12 +1076,25 @@ def all_gather_object(object_list, obj, group=None):
assert in_dygraph_mode(
), "all_gather_object doesn't support static graph mode."

tensor = _convert_object_to_tensor(obj)
tensor, len_of_tensor = _convert_object_to_tensor(obj)

# gather len_of_tensor from all ranks
list_len_of_tensor = []
all_gather(list_len_of_tensor, len_of_tensor, group)
# get the max length from list
max_len_of_tensor = int(max(list_len_of_tensor).item())
# resize the input tensor to max length avoid hang in all gather
# Note(liyurui): Maybe we should support various length all_gather?
# Now this operation is efficient for we don't support resize in python.
numpy_data = tensor.numpy()
numpy_data = np.resize(numpy_data, [max_len_of_tensor])
input_tensor = paddle.to_tensor(numpy_data)

tensor_list = []
all_gather(tensor_list, tensor, group)
for tensor in tensor_list:
object_list.append(_convert_tensor_to_object(tensor))
all_gather(tensor_list, input_tensor, group)
for i, tensor in enumerate(tensor_list):
object_list.append(
_convert_tensor_to_object(tensor, list_len_of_tensor[i]))


def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def create_complex_test_data(shape=None, dtype=None, seed=None):
def create_pylist_test_data(shape=None, seed=None):
if seed:
np.random.seed(seed)
# Generate random shape test case for xxx_object api
shape = np.random.randint(0, high=100, size=(2)).tolist()
data = np.random.random(shape).tolist()
return data

Expand Down

0 comments on commit e48cb42

Please sign in to comment.