add dist_attr for dist op and var (PaddlePaddle#35585)
* add dist_attr for dist op

* add unit test

* update input name

* update function name

* add unit test

* update CMakeLists.txt for CI

* fix dist_matmul

* fix compile error

* update matmul to matmul_v2
zhaoyinglia authored and AnnaTrainingG committed Sep 29, 2021
1 parent 13ad58b commit b04b36c
Showing 6 changed files with 354 additions and 58 deletions.
32 changes: 28 additions & 4 deletions paddle/fluid/operators/searchsorted_op.h
@@ -30,13 +30,37 @@ using Tensor = framework::Tensor;
template <typename T1, typename T2, typename OutType>
class GpuAndCpuSearchSortedCompute {
public:
static HOSTDEVICE bool IsNan(float x) { return ::isnan(x); }
static HOSTDEVICE bool IsNan(double x) { return ::isnan(x); }
static HOSTDEVICE bool IsNan(float x) {
#ifdef __NVCC__
return ::isnan(x);
#else
return std::isnan(x);
#endif
}
static HOSTDEVICE bool IsNan(double x) {
#ifdef __NVCC__
return ::isnan(x);
#else
return std::isnan(x);
#endif
}
static HOSTDEVICE bool IsNan(int x) { return false; }
static HOSTDEVICE bool IsNan(int64_t x) { return false; }

static HOSTDEVICE bool IsInf(float x) { return ::isinf(x); }
static HOSTDEVICE bool IsInf(double x) { return ::isinf(x); }
static HOSTDEVICE bool IsInf(float x) {
#ifdef __NVCC__
return ::isinf(x);
#else
return std::isinf(x);
#endif
}
static HOSTDEVICE bool IsInf(double x) {
#ifdef __NVCC__
return ::isinf(x);
#else
return std::isinf(x);
#endif
}
static HOSTDEVICE bool IsInf(int x) { return false; }
static HOSTDEVICE bool IsInf(int64_t x) { return false; }

47 changes: 47 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/common.py
@@ -114,3 +114,50 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
best_compatible_impl, idx = None, -1

return best_compatible_impl, idx


def copy_distributed_attr_for_var(src_op_dist_attr, var, src_var):
"""
copy src var's dist_attr to dst var
"""
import copy

auto_paralle_context = src_op_dist_attr.get_owner_context()
dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
dist_attr._owner_tensor = var
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)


def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr):
"""
copy src op's dist_attr to dst dist op
"""
from ..attribute import OperatorDistributedAttribute

auto_paralle_context = src_op_dist_attr.get_owner_context()
op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context)
auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc,
op_dist_attr)
auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
op_dist_attr)

op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())

for input_varname in dist_op.desc.input_arg_names():
input_var = dst_block.var(input_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
input_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)

for output_varname in dist_op.desc.output_arg_names():
output_var = dst_block.var(output_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
output_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dist_attr.set_output_dims_mapping(output_varname,
tensor_dims_mapping)
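
Taken together, the two helpers are used whenever a serial op is lowered into one or more distributed ops, as in dist_embedding.py and dist_matmul.py below: the intermediate tensor inherits the source tensor's dist_attr, and every newly appended op inherits the serial op's dist_attr. A minimal sketch of the calling pattern, assuming dst_block, op_dist_attr, X_var, Weight_var and Out_var are provided by an enclosing static_handle(...) as in those files (the temporary name "matmul_v2.tmp" is only illustrative):

    # Sketch only -- mirrors the pattern used by the dist op implementations
    # in this commit; dst_block, op_dist_attr, X_var, Weight_var and Out_var
    # are assumed to come from the enclosing static_handle(...) context.
    from paddle.fluid import core, unique_name
    from paddle.distributed.auto_parallel.operators.common import (
        copy_distributed_attr_for_var, copy_distributed_attr_for_dist_op)

    intermediate_var_0 = dst_block.create_var(
        name=unique_name.generate_with_ignorable_key("matmul_v2.tmp"),
        dtype=Out_var.dtype,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=Out_var.stop_gradient)
    # The intermediate tensor takes over Out_var's tensor dist_attr.
    copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var)

    matmul_v2_op = dst_block.append_op(
        type='matmul_v2',
        inputs={'X': [X_var], 'Y': [Weight_var]},
        outputs={'Out': [intermediate_var_0]},
        attrs={'trans_x': False, 'trans_y': False})
    # The newly inserted op takes over the serial op's dist_attr: process
    # mesh, impl index and the dims mapping of every input/output tensor.
    copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block, op_dist_attr)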
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -16,6 +16,8 @@
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from .common import copy_distributed_attr_for_var
from .common import copy_distributed_attr_for_dist_op
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -173,21 +175,24 @@ def static_handle(dst_block,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_var.stop_gradient)
# copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
Out_var)

check_variable_and_dtype(
Out_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'c_allreduce_sum')

dst_block.append_op(
c_embedding_op = dst_block.append_op(
type='c_embedding',
inputs={'Ids': [Ids_var],
'W': [Weight_var]},
outputs={'Out': [intermediate_var_0]},
attrs={"start_index": relative_idx})

# use_model_parallel
dst_block.append_op(
c_allreduce_sum_op = dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': [intermediate_var_0]},
outputs={'Out': [Out_var]},
@@ -197,6 +202,12 @@
'use_model_parallel': True,
})

# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_embedding_op, dst_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
op_dist_attr)

if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
74 changes: 54 additions & 20 deletions python/paddle/distributed/auto_parallel/operators/dist_matmul.py
Expand Up @@ -16,6 +16,8 @@
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from .common import copy_distributed_attr_for_var
from .common import copy_distributed_attr_for_dist_op
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -223,13 +225,16 @@ def static_handle(dst_block,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
# copy X_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
X_var)

check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')

dst_block.append_op(
c_identity_op = dst_block.append_op(
type='c_identity',
inputs={'X': [X_var]},
outputs={'Out': intermediate_var_0},
@@ -250,12 +255,18 @@
'alpha': 1,
}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
dst_block.append_op(
matmul_op = dst_block.append_op(
type='matmul',
inputs=inputs,
outputs={'Out': Out_var},
attrs=attrs)

# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_identity_op, dst_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(matmul_op, dst_block,
op_dist_attr)

if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
@@ -369,13 +380,17 @@ def static_handle(dst_block,
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
dst_block.append_op(
# copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
Out_var)

matmul_op = dst_block.append_op(
type='matmul',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)

dst_block.append_op(
c_allreduce_sum_op = dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': intermediate_var_0},
outputs={'Out': Out_var},
@@ -385,6 +400,12 @@
'use_model_parallel': True
})

# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(matmul_op, dst_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
op_dist_attr)

if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
@@ -540,15 +561,12 @@ def static_handle(dst_block,
Out_var = dst_block.var(output_name_mapping['Out'][0])

# TODO infer logic comm presentation
from ..process import new_process_group
from ..transpiler import _get_comm_group
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.topology,
model_parallel_axis,
process_mesh.process_group, rank_id)
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
# print("@@@@@@@@@@@@@@@@@@@@@ 5", group)

intermediate_var_0 = dst_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
@@ -558,13 +576,16 @@
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
# copy X_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
X_var)

check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')

dst_block.append_op(
c_identity_op = dst_block.append_op(
type='c_identity',
inputs={'X': [X_var]},
outputs={'Out': intermediate_var_0},
@@ -581,12 +602,18 @@
['float16', 'float32', 'float64'], 'linear')
attrs = {'trans_x': False, 'trans_y': False}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
dst_block.append_op(
matmul_v2_op = dst_block.append_op(
type='matmul_v2',
inputs=inputs,
outputs={'Out': Out_var},
attrs=attrs)

# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_identity_op, dst_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block,
op_dist_attr)

if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
@@ -675,15 +702,12 @@ def static_handle(dst_block,
Out_var = dst_block.var(output_name_mapping['Out'][0])

# TODO infer logic comm presentation
from ..process import new_process_group
from ..transpiler import _get_comm_group
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.topology,
model_parallel_axis,
process_mesh.process_group, rank_id)
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
# print("@@@@@@@@@@@@@@@@@@@@@ 4", group)

check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear')
Expand All @@ -699,13 +723,17 @@ def static_handle(dst_block,
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
dst_block.append_op(
# copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
Out_var)

matmul_v2_op = dst_block.append_op(
type='matmul_v2',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)

dst_block.append_op(
c_allreduce_sum_op = dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': intermediate_var_0},
outputs={'Out': Out_var},
@@ -715,6 +743,12 @@
'use_model_parallel': True
})

# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
op_dist_attr)

if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
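One further change shared by both row-parallel branches above: the arguments to _get_comm_group are reordered so that the mesh's process group comes first, before the mesh topology. A minimal sketch of the new call order, with rank_id and op_dist_attr assumed to come from the enclosing static_handle(...) context and the imports written relative to the operators package, as in the diff:

    # Call pattern after this change (the previous call passed
    # process_mesh.topology first).
    from ..process import new_process_group
    from ..transpiler import _get_comm_group

    model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
    )._get_model_parallel_info()
    group_ranks = _get_comm_group(process_mesh.process_group,  # process ids of the mesh
                                  process_mesh.topology,       # mesh shape
                                  model_parallel_axis,         # sharded mesh axis
                                  rank_id)                     # current rank
    group = new_process_group(group_ranks)
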
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -582,6 +582,8 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_lars_meta_optimizer MODULES test_fleet_lars_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_lamb_meta_optimizer MODULES test_fleet_lamb_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_partitioner MODULES test_auto_parallel_partitioner ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS})
endif(NOT WIN32)
endif(NOT APPLE)
if(WITH_DGC)
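
The two newly registered modules can also be run locally with standard unittest discovery; a hypothetical sketch, assuming paths relative to the Paddle source root and a distributed-enabled build (the CI itself invokes them through py_test_modules with ${dist_ENVS}):

    # Hypothetical local runner for the tests registered above.
    import unittest

    suite = unittest.defaultTestLoader.discover(
        "python/paddle/fluid/tests/unittests",
        pattern="test_auto_parallel_partitioner*.py")
    unittest.TextTestRunner(verbosity=2).run(suite)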
