Complete the dtypes for all_gather, add all_gather_object api #44417

Merged · 7 commits · Jul 28, 2022
9 changes: 9 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupGloo.cc
@@ -79,6 +79,15 @@ namespace distributed {
case experimental::DataType::INT64: \
func<int64_t>(args); \
break; \
case experimental::DataType::INT8: \
func<int8_t>(args); \
break; \
case experimental::DataType::UINT8: \
func<uint8_t>(args); \
break; \
case experimental::DataType::BOOL: \
func<bool>(args); \
break; \
default: \
VLOG(0) << "Error: Unknown DataType."; \
exit(-1); \
3 changes: 3 additions & 0 deletions paddle/fluid/operators/collective/c_allgather_op.cc
@@ -94,4 +94,7 @@ REGISTER_OP_CPU_KERNEL(c_allgather,
ops::CAllGatherOpCPUKernel<double>,
ops::CAllGatherOpCPUKernel<int>,
ops::CAllGatherOpCPUKernel<int64_t>,
ops::CAllGatherOpCPUKernel<uint8_t>,
ops::CAllGatherOpCPUKernel<int8_t>,
ops::CAllGatherOpCPUKernel<bool>,
ops::CAllGatherOpCPUKernel<plat::float16>);
3 changes: 3 additions & 0 deletions paddle/fluid/operators/collective/c_allgather_op.cu.cc
@@ -100,5 +100,8 @@ REGISTER_OP_CUDA_KERNEL(c_allgather,
ops::CAllGatherOpCUDAKernel<plat::bfloat16>,
#endif
ops::CAllGatherOpCUDAKernel<int>,
ops::CAllGatherOpCUDAKernel<uint8_t>,
ops::CAllGatherOpCUDAKernel<int8_t>,
ops::CAllGatherOpCUDAKernel<int64_t>,
ops::CAllGatherOpCUDAKernel<bool>,
ops::CAllGatherOpCUDAKernel<plat::float16>);
10 changes: 10 additions & 0 deletions paddle/fluid/platform/device/gpu/nccl_helper.h
@@ -55,6 +55,10 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return ncclFloat16;
} else if (type == framework::proto::VarType::INT8) {
return ncclInt8;
} else if (type == framework::proto::VarType::UINT8) {
return ncclUint8;
} else if (type == framework::proto::VarType::BOOL) {
return ncclUint8;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
} else if (type == framework::proto::VarType::BF16) {
return ncclBfloat16;
@@ -76,6 +80,12 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
return ncclInt64;
} else if (type == experimental::DataType::FLOAT16) {
return ncclFloat16;
} else if (type == experimental::DataType::UINT8) {
return ncclUint8;
} else if (type == experimental::DataType::INT8) {
return ncclInt8;
} else if (type == experimental::DataType::BOOL) {
return ncclUint8;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
} else if (type == experimental::DataType::BFLOAT16) {
return ncclBfloat16;
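NCCL has no native boolean type, so the mapping above routes both VarType::BOOL and DataType::BOOL to ncclUint8, alongside the new UINT8 and INT8 cases. A minimal sketch of what these registrations and mappings enable, assuming a two-process GPU launch (the tensor values are purely illustrative and not part of this diff):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    # bool data travels over NCCL as uint8, per the mapping added above
    mask = paddle.to_tensor([True, False, dist.get_rank() == 0])
    gathered = []
    dist.all_gather(gathered, mask)  # gathered ends up holding one bool tensor per rank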
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/split_kernel.cc
@@ -72,5 +72,7 @@ PD_REGISTER_KERNEL(split,
int64_t,
int,
bool,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/split_kernel.cu
@@ -71,5 +71,7 @@ PD_REGISTER_KERNEL(split,
int64_t,
int,
bool,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
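The split kernels pick up uint8 and int8 because all_gather collects everything into one concatenated tensor and then slices it back into per-rank pieces with paddle.split (see collective.py below), so split has to cover every dtype all_gather accepts. A rough, single-process sketch of that slicing step with made-up shapes:

    import numpy as np
    import paddle

    nranks = 2
    # stands in for the concatenated all_gather output buffer
    out = paddle.to_tensor(np.arange(8, dtype=np.uint8))
    per_rank = paddle.split(out, nranks, axis=0)  # two uint8 tensors of shape [4]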
8 changes: 5 additions & 3 deletions python/paddle/distributed/__init__.py
@@ -31,6 +31,7 @@
from .collective import all_reduce # noqa: F401
from .collective import reduce # noqa: F401
from .collective import all_gather # noqa: F401
from .collective import all_gather_object # noqa: F401
from .collective import scatter # noqa: F401
from .collective import barrier # noqa: F401
from .collective import ReduceOp # noqa: F401
@@ -71,7 +72,8 @@
"init_parallel_env", "gloo_init_parallel_env", "gloo_barrier",
"gloo_release", "QueueDataset", "split", "CountFilterEntry",
"ShowClickEntry", "get_world_size", "get_group", "all_gather",
"InMemoryDataset", "barrier", "all_reduce", "alltoall", "send", "reduce",
"recv", "ReduceOp", "wait", "get_rank", "ProbabilityEntry", "ParallelMode",
"is_initialized", "isend", "irecv", "reduce_scatter"
"all_gather_object", "InMemoryDataset", "barrier", "all_reduce", "alltoall",
"send", "reduce", "recv", "ReduceOp", "wait", "get_rank",
"ProbabilityEntry", "ParallelMode", "is_initialized", "isend", "irecv",
"reduce_scatter"
]
113 changes: 93 additions & 20 deletions python/paddle/distributed/collective.py
@@ -14,6 +14,8 @@

import numpy as np
import os
import pickle
import io
from datetime import timedelta
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import Variable
@@ -927,9 +929,9 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):

Args:
tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32 or int64.
should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
group (Group): The group instance returned by new_group, or None for the global default group.
use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False).
Defaults to True.
@@ -941,29 +943,33 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
.. code-block:: python

# required: distributed
import numpy as np
import paddle
from paddle.distributed import init_parallel_env

paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
tensor_list = []
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data1 = np.array([[4, 5, 6], [4, 5, 6]])
np_data2 = np.array([[4, 5, 6], [4, 5, 6]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
data1 = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
paddle.distributed.all_gather(tensor_list, data1)
else:
np_data1 = np.array([[1, 2, 3], [1, 2, 3]])
np_data2 = np.array([[1, 2, 3], [1, 2, 3]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
data2 = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
paddle.distributed.all_gather(tensor_list, data2)
"""
if group is not None and not group.is_member():
return

def convert_to_complex(list_of_tensor):
list_of_complex = []
for tensor in list_of_tensor:
list_of_complex.append(paddle.as_complex(tensor))
return list_of_complex

is_input_complex = (tensor.dtype == paddle.complex64
or tensor.dtype == paddle.complex128)
if is_input_complex:
tensor = paddle.as_real(tensor)

if in_dygraph_mode():
group = _get_default_group() if group is None else group
if len(tensor_list) == 0:
@@ -975,7 +981,11 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
task = group.process_group.all_gather(tensor, out)
task.wait()
tensor_list.clear()
tensor_list.extend(paddle.split(out, group.nranks, 0))
list_of_tensor = paddle.split(out, group.nranks, 0)
if is_input_complex:
tensor_list.extend(convert_to_complex(list_of_tensor))
else:
tensor_list.extend(list_of_tensor)
return

ring_id = 0 if group is None else group.id
@@ -992,13 +1002,14 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
raise ValueError("The type of 'tensor_list' for all_gather "
"should be list.")
for elem in tensor_list:
check_variable_and_dtype(
elem, 'tensor_list',
['float16', 'float32', 'float64', 'int32', 'int64'],
'all_gather')
check_variable_and_dtype(
tensor, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather')
check_variable_and_dtype(elem, 'tensor_list', [
'float16', 'float32', 'float64', 'int32', 'int64', 'bool',
'int8', 'uint8', 'complex64', 'complex128'
], 'all_gather')
check_variable_and_dtype(tensor, 'tensor', [
'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'int8',
'uint8', 'complex64', 'complex128'
], 'all_gather')
helper.append_op(type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [out]},
@@ -1008,7 +1019,69 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
'nranks': nranks
})

tensor_list.extend(paddle.split(out, nranks, 0))
list_of_tensor = paddle.split(out, nranks, 0)
if is_input_complex:
tensor_list.extend(convert_to_complex(list_of_tensor))
else:
tensor_list.extend(list_of_tensor)


def _convert_object_to_tensor(obj):
_pickler = pickle.Pickler
f = io.BytesIO()
_pickler(f).dump(obj)
data = np.frombuffer(f.getvalue(), dtype=np.uint8)
tensor = paddle.to_tensor(data)
return tensor


def _convert_tensor_to_object(tensor):
_unpickler = pickle.Unpickler
return _unpickler(io.BytesIO(tensor.numpy())).load()


def all_gather_object(object_list, obj, group=None):
"""

Gather picklable objects from all participators so that every rank gets the result. Similar to all_gather(), but Python objects can be passed in.

Args:
object_list (list): A list of output objects. The data type of every element in the list is the same as the input obj.
obj (Any): The picklable object to send.
group (Group): The group instance returned by new_group, or None for the global default group.

Returns:
None.

Warning:
This API only supports the dygraph mode.

Examples:
.. code-block:: python

# required: distributed
import paddle
import paddle.distributed as dist

paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
dist.init_parallel_env()
object_list = []
if paddle.distributed.ParallelEnv().local_rank == 0:
obj = {"foo": [1, 2, 3]}
paddle.distributed.all_gather_object(object_list, obj)
else:
obj = {"bar": [4, 5, 6]}
paddle.distributed.all_gather_object(object_list, obj)
"""
assert in_dygraph_mode(
), "all_gather_object doesn't support static graph mode."

tensor = _convert_object_to_tensor(obj)

tensor_list = []
all_gather(tensor_list, tensor, group)
for tensor in tensor_list:
object_list.append(_convert_tensor_to_object(tensor))


def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
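Complex inputs are handled in all_gather above by viewing them as real tensors with a trailing dimension of size 2, gathering, and converting each split piece back; a hedged single-tensor sketch of that round trip:

    import paddle

    x = paddle.to_tensor([1 + 2j, 3 + 4j], dtype='complex64')
    as_real = paddle.as_real(x)        # float32, shape [2, 2]
    back = paddle.as_complex(as_real)  # recovers the original complex values

all_gather_object builds on the uint8 path: the object is pickled into a uint8 tensor, gathered with all_gather, and unpickled on every rank. A minimal sketch of the serialization round trip performed by the two helper functions above (names here are illustrative):

    import io
    import pickle
    import numpy as np
    import paddle

    def obj_to_tensor(obj):
        buf = io.BytesIO()
        pickle.Pickler(buf).dump(obj)
        return paddle.to_tensor(np.frombuffer(buf.getvalue(), dtype=np.uint8))

    def tensor_to_obj(tensor):
        return pickle.Unpickler(io.BytesIO(tensor.numpy())).load()

    assert tensor_to_obj(obj_to_tensor({"foo": [1, 2, 3]})) == {"foo": [1, 2, 3]}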
6 changes: 5 additions & 1 deletion python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -183,6 +183,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
list(REMOVE_ITEM TEST_OPS test_new_group_api)
list(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
list(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
list(REMOVE_ITEM TEST_OPS test_collective_allgather_object_api)
list(REMOVE_ITEM TEST_OPS test_collective_alltoall_api)
list(REMOVE_ITEM TEST_OPS test_collective_global_gather)
list(REMOVE_ITEM TEST_OPS test_collective_global_scatter)
@@ -1598,7 +1599,9 @@ if(APPLE)
endif()

if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 300)
set_tests_properties(test_collective_allgather_object_api PROPERTIES TIMEOUT
120)
set_tests_properties(test_collective_alltoall_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_global_gather PROPERTIES TIMEOUT 200)
set_tests_properties(test_collective_global_scatter PROPERTIES TIMEOUT 200)
@@ -1629,6 +1632,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
test_new_group_api
test_collective_broadcast_api
test_collective_allgather_api
test_collective_allgather_object_api
test_collective_alltoall_api
test_collective_global_gather
test_collective_global_scatter
50 changes: 43 additions & 7 deletions python/paddle/fluid/tests/unittests/collective_allgather_api.py
@@ -30,28 +30,64 @@
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
import pickle
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
import test_collective_api_base as test_base

paddle.enable_static()


class TestCollectiveAllgatherAPI(TestCollectiveAPIRunnerBase):
class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):

def __init__(self):
self.global_ring_id = 0

def get_model(self, main_prog, startup_program, rank):
def get_model(self, main_prog, startup_program, rank, dtype=None):
dtype = "float32" if dtype is None else dtype
with fluid.program_guard(main_prog, startup_program):
tensor_list = []
tindata = layers.data(name="tindata",
shape=[10, 1000],
dtype='float32')
tindata = layers.data(name="tindata", shape=[10, 1000], dtype=dtype)
paddle.distributed.all_gather(tensor_list, tindata)
return tensor_list

def run_trainer(self, args):
train_prog = fluid.Program()
startup_prog = fluid.Program()
endpoints = args["endpoints"].split(",")
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
paddle.distributed.init_parallel_env()
if args['backend'] == 'nccl':
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(
device_id) #if args.use_gpu else fluid.CPUPlace()
elif args['backend'] == 'bkcl':
device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
place = fluid.XPUPlace(device_id)
else:
place = fluid.CPUPlace()
indata = test_base.create_test_data(shape=(10, 1000),
dtype=args["dtype"],
seed=os.getpid())
assert args[
'static_mode'] == 1, "collective_allgather_api only supports static mode"
result = self.get_model(train_prog,
startup_prog,
rank,
dtype=args["dtype"])
exe = fluid.Executor(place)
exe.run(startup_prog)
fetch_list = []
for elem in result:
fetch_list.append(elem.name)
out = exe.run(train_prog,
feed={'tindata': indata},
fetch_list=fetch_list)
sys.stdout.buffer.write(pickle.dumps(out))


if __name__ == "__main__":
runtime_main(TestCollectiveAllgatherAPI, "allgather")
test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather")
37 changes: 37 additions & 0 deletions (new file)
@@ -0,0 +1,37 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import paddle
import paddle.fluid as fluid
import unittest
import test_collective_api_base as test_base


class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):

def __init__(self):
self.global_ring_id = 0

def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
tensor_list = []
paddle.distributed.all_gather(tensor_list, tindata)
return [tensor.numpy() for tensor in tensor_list]


if __name__ == "__main__":
test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather")