Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix all_gather_object to support various length object #44718

Merged
merged 1 commit into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,12 +1032,12 @@ def _convert_object_to_tensor(obj):
_pickler(f).dump(obj)
data = np.frombuffer(f.getvalue(), dtype=np.uint8)
tensor = paddle.to_tensor(data)
return tensor
return tensor, tensor.numel()


def _convert_tensor_to_object(tensor):
def _convert_tensor_to_object(tensor, len_of_tensor):
_unpickler = pickle.Unpickler
return _unpickler(io.BytesIO(tensor.numpy())).load()
return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()


def all_gather_object(object_list, obj, group=None):
Expand Down Expand Up @@ -1076,12 +1076,25 @@ def all_gather_object(object_list, obj, group=None):
assert in_dygraph_mode(
), "all_gather_object doesn't support static graph mode."

tensor = _convert_object_to_tensor(obj)
tensor, len_of_tensor = _convert_object_to_tensor(obj)

# gather len_of_tensor from all ranks
list_len_of_tensor = []
all_gather(list_len_of_tensor, len_of_tensor, group)
# get the max length from list
max_len_of_tensor = int(max(list_len_of_tensor).item())
# resize the input tensor to the max length to avoid a hang in all_gather
# Note(liyurui): Maybe we should support variable-length all_gather?
# For now this operation is inefficient because we don't support resize in Python.
numpy_data = tensor.numpy()
numpy_data = np.resize(numpy_data, [max_len_of_tensor])
input_tensor = paddle.to_tensor(numpy_data)

tensor_list = []
all_gather(tensor_list, tensor, group)
for tensor in tensor_list:
object_list.append(_convert_tensor_to_object(tensor))
all_gather(tensor_list, input_tensor, group)
for i, tensor in enumerate(tensor_list):
object_list.append(
_convert_tensor_to_object(tensor, list_len_of_tensor[i]))


def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def create_complex_test_data(shape=None, dtype=None, seed=None):
def create_pylist_test_data(shape=None, seed=None):
    """Build a randomly shaped nested list of random floats for testing.

    Args:
        shape: Accepted for signature compatibility with the other
            ``create_*_test_data`` helpers, but ignored: the shape is always
            re-drawn at random so the ``xxx_object`` APIs are exercised with
            payloads of differing size.
        seed: Optional seed for NumPy's global random state. Any value other
            than ``None`` (including ``0``) is applied, making the result
            reproducible.

    Returns:
        A nested Python list produced from a 2-D array whose two dimensions
        are each drawn from ``[0, 100)``; a zero-sized dimension yields an
        empty (or list-of-empty-lists) result.
    """
    # Compare against None rather than truthiness so that seed=0 still seeds.
    if seed is not None:
        np.random.seed(seed)
    # Generate a random-shape test case for the xxx_object APIs; this
    # intentionally overrides whatever ``shape`` the caller passed in.
    shape = np.random.randint(0, high=100, size=(2)).tolist()
    data = np.random.random(shape).tolist()
    return data

Expand Down