[AutoParallel] adapt for gpt-gen #46771

Merged · 7 commits · Oct 14, 2022
6 changes: 3 additions & 3 deletions python/paddle/distributed/auto_parallel/completion.py
@@ -17,7 +17,7 @@

from paddle.fluid import core

from .utils import is_gradient_clip_op
from .utils import is_gradient_clip_op, __not_shape_var_type__
from .operators import find_compatible_distributed_operator_impls
from .dist_context import _node_id
from .dist_attribute import TensorDistributedAttribute
@@ -491,14 +491,14 @@ def _find_nodes_related_to_cond(source_node):
for tensor_node in node.inputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER \
if tensor_node.var().type() in __not_shape_var_type__ \
or len(tensor_node.var().shape()) != 1:
flag = False
break
for tensor_node in node.outputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER \
if tensor_node.var().type() in __not_shape_var_type__ \
or len(tensor_node.var().shape()) != 1:
flag = False
break
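The hunk above generalizes the old hard-coded READER check: any variable whose type is listed in __not_shape_var_type__ (defined in utils.py later in this diff, covering READER and STEP_SCOPES) is excluded when looking for 1-D, shape-like tensors around a conditional. A minimal sketch of the resulting predicate, assuming only what the diff itself shows (the helper name is illustrative, not part of the codebase):

# Sketch of the check in _find_nodes_related_to_cond after this change.
from paddle.fluid import core

__not_shape_var_type__ = [
    core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
]

def could_hold_shape(var_desc):
    # A var qualifies only if it is not a reader/step-scope var
    # and its value is 1-D.
    return (var_desc.type() not in __not_shape_var_type__
            and len(var_desc.shape()) == 1)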
8 changes: 6 additions & 2 deletions python/paddle/distributed/auto_parallel/engine.py
@@ -1139,8 +1139,10 @@ def prepare(self,
self.to_mode(mode)
if inputs or labels:
self._skip_build = True
self._inputs_spec = inputs_spec
self._labels_spec = labels_spec
self._inputs, self._labels = self._prepare_data_tensor(
inputs_spec, labels_spec, inputs, labels)
self._inputs_spec, self._labels_spec, inputs, labels)
self._orig_main_prog = main_program
if self._orig_main_prog is None:
self._orig_main_prog = static.default_main_program()
@@ -1152,9 +1154,11 @@
else:
self._switch_mode(self._mode)
elif inputs_spec or labels_spec:
self._inputs_spec = inputs_spec
self._labels_spec = labels_spec
self._outside_dataloader = True
self._inputs, self._labels = self._prepare_data_tensor(
inputs_spec, labels_spec)
self._inputs_spec, self._labels_spec)
self._orig_main_prog = main_program
if self._orig_main_prog is None:
self._orig_main_prog = static.default_main_program()
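Both branches of Engine.prepare() now cache the user-provided specs on the engine (self._inputs_spec / self._labels_spec) and build the feed tensors from those cached attributes rather than from the raw arguments, so later stages can read the specs back from the engine. A framework-free sketch of the pattern (class and method names are illustrative, not Paddle's Engine):

class EngineSketch:

    def _prepare_data_tensor(self, inputs_spec, labels_spec,
                             inputs=None, labels=None):
        # Stand-in for the real tensor construction.
        return inputs, labels

    def prepare(self, inputs_spec=None, labels_spec=None,
                inputs=None, labels=None):
        # Store the specs first ...
        self._inputs_spec = inputs_spec
        self._labels_spec = labels_spec
        # ... then derive the tensors from the stored copies.
        self._inputs, self._labels = self._prepare_data_tensor(
            self._inputs_spec, self._labels_spec, inputs, labels)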
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/__init__.py
@@ -33,3 +33,5 @@
from . import dist_fused_feedforward
from . import dist_fused_attention
from . import dist_reduce_sum_p
from . import dist_shape
from . import dist_assign
88 changes: 88 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/dist_assign.py
@@ -0,0 +1,88 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .dist_default import DistributedDefaultImpl0
from ..utils import compute_compatible_and_update_dim_mapping


class DistributedAssign(DistributedOperatorImplContainer):

def __init__(self, op_type):
super(DistributedAssign, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedAssign("assign"))


class DistributedAssignImpl(DistributedOperatorImpl):

def __init__(self, name):
super(DistributedAssignImpl, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True

def is_input_compatible(self, dist_op):
return True

def is_output_compatible(self, dist_op):
return True

def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False

op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

if x_dims_mapping != out_dims_mapping:
return False

return True

def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

for i in range(len(x_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True

return changed

@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)


register_distributed_operator_impl("assign", DistributedAssignImpl("assign"))
73 changes: 73 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/dist_shape.py
@@ -0,0 +1,73 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .dist_default import DistributedDefaultImpl0
from ..utils import is_dim_shard


class DistributedShape(DistributedOperatorImplContainer):

def __init__(self, op_type):
super(DistributedShape, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedShape("shape"))


class DistributedShapeImpl(DistributedOperatorImpl):

def __init__(self, name):
super(DistributedShapeImpl, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True

def is_input_compatible(self, dist_op):
return True

def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

assert len(out_dims_mapping) == 1
if is_dim_shard(out_dims_mapping[0]):
return False

return True

def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False

return True

def update_dims_mapping(self, dist_op):
return False

@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)


register_distributed_operator_impl("shape", DistributedShapeImpl("shape"))
40 changes: 24 additions & 16 deletions python/paddle/distributed/auto_parallel/reshard.py
@@ -34,6 +34,7 @@
_g_gradient_clip_ops = [
"sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div"
]
_g_subblock_ops = ["while", "conditional_block"]


def get_var_with_recursion(var_name, block, program):
Expand All @@ -42,11 +43,11 @@ def get_var_with_recursion(var_name, block, program):
if var_name in block.vars:
var = block.vars[var_name]
else:
parent_block = program.blocks[block.parent_idx]
if var_name in parent_block.vars:
var = parent_block.vars[var_name]
assert var is not None, \
"{} is not found".format(var.name)
var = block._var_recursive(var_name)
# parent_block = program.blocks[block.parent_idx]
# if var_name in parent_block.vars:
# var = parent_block.vars[var_name]
assert var is not None, "{} is not found".format(var.name)

return var
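The old code only looked one level up, in the immediate parent block; block._var_recursive walks the whole block hierarchy. A rough illustration of that recursive lookup (not Paddle's implementation):

# Illustrative only: climb from the current block toward the root block.
def find_var_recursive(program, block, var_name):
    cur = block
    while True:
        if var_name in cur.vars:
            return cur.vars[var_name]
        if cur.parent_idx < 0:   # root block reached, nothing found
            return None
        cur = program.blocks[cur.parent_idx]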

@@ -1075,7 +1076,9 @@ def change_while_op_input_and_output(auto_parallel_main_prog, dist_context):
new_Out = []
for var_name in while_op.output("Out"):
for output_name in sub_block_op_outputs[::-1]:
if output_name.find(var_name) != -1:
if output_name.find(var_name) != -1 and (
len(var_name) == len(output_name)
or "@RESHARD" in output_name):
if output_name not in new_Out:
new_Out.append(output_name)
assert new_Out
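The added condition tightens the match: a sub-block output is now taken as the while op's Out only when it is exactly the variable name or a reshard-renamed copy containing "@RESHARD"; a bare substring hit no longer qualifies. A standalone illustration of the new predicate:

def matches_while_out(var_name, output_name):
    return (output_name.find(var_name) != -1
            and (len(var_name) == len(output_name)
                 or "@RESHARD" in output_name))

print(matches_while_out("x", "x"))            # True: exact name
print(matches_while_out("x", "x@RESHARD_0"))  # True: resharded copy (name assumed)
print(matches_while_out("x", "max_len"))      # False: would have matched before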
@@ -1104,13 +1107,15 @@ def is_special_op(self, op):
return False

def is_condition_replicative(self, op):
assert op.type == "while"
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
dist_op = self.dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr

if op.type == "while":
input_cond = op.input("Condition")
elif op.type == "conditional_block":
input_cond = op.input("Cond")

# the dims mapping of condition tensor should be replicative
for var_name in op.input("Condition"):
for var_name in input_cond:
var = get_var_with_recursion(var_name, sub_block,
self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(var)
@@ -1660,9 +1665,9 @@ def parse_op_desc(self, block, op_desc_seq, var_name, reshard_op,
op.desc.set_input(proto.inputs[0].name,
op.input("X") + while_op_X_append)

def _get_while_op_input_attrs(self, op, var_name):
def _get_subblock_input_attrs(self, op, var_name):
# NOTE: Multi while loop is not supported
assert op.type == "while"
assert op.type in _g_subblock_ops
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
input_attrs = []
@@ -1713,8 +1718,8 @@ def _get_common_op_input_attrs(self, op, var_name):
def get_op_input_attrs(self, op, var_name):
op_input_attrs = []

if op.type == "while":
op_input_attrs = self._get_while_op_input_attrs(op, var_name)
if op.type in _g_subblock_ops:
op_input_attrs = self._get_subblock_input_attrs(op, var_name)
else:
op_input_attrs = self._get_common_op_input_attrs(op, var_name)

@@ -1818,7 +1823,7 @@ def _reshard_input(self, block):
if dist_op is not None:
op_input_dist_attrs = [
] # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)]
if op.type == "while":
if op.type in _g_subblock_ops:
if not self.is_condition_replicative(op):
raise ValueError(
"Please check the condition due to the dims mapping is not replicative."
@@ -1832,6 +1837,8 @@
if op.type == "while":
# condition var process mesh is the same with op and dims_mapping is replicative, so it do not need reshard
input_var_names = op.input("X")
elif op.type == "conditional_block":
input_var_names = op.input("Input")
else:
input_var_names = op.input_arg_names
# to avoid while op X order different
@@ -1984,11 +1991,12 @@ def _reshard_output(self, block):
idx = 0
# skip reader and ops whose process mesh is union
skip_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "while",
"create_py_reader", "create_double_buffer_reader", "read",
"write_to_array", "read_from_array"
]
global _g_special_ops
skip_ops += _g_special_ops
skip_ops += _g_subblock_ops
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
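Taken together, the reshard changes treat "conditional_block" as a sub-block op alongside "while": both are grouped in _g_subblock_ops, both are skipped in _reshard_output, and both go through the renamed _get_subblock_input_attrs; they differ only in the slot names that expose the condition tensor and the outer inputs. A compact sketch of that dispatch, using only the slot names visible in the diff:

_g_subblock_ops = ["while", "conditional_block"]

def condition_input_names(op):
    # Condition tensor slot differs between the two sub-block ops.
    if op.type == "while":
        return op.input("Condition")
    elif op.type == "conditional_block":
        return op.input("Cond")
    return []

def outer_input_names(op):
    # Outer-scope inputs that may need resharding.
    if op.type == "while":
        return op.input("X")
    elif op.type == "conditional_block":
        return op.input("Input")
    return op.input_arg_names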
4 changes: 4 additions & 0 deletions python/paddle/distributed/auto_parallel/utils.py
@@ -27,6 +27,10 @@
from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute

__not_shape_var_type__ = [
core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
]


def get_logger(log_level, name="auto_parallel"):
logger = logging.getLogger(name)
CMakeLists.txt for the auto_parallel unit tests (full path not captured in this view)
@@ -96,5 +96,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_interface MODULES test_interface)
py_test_modules(test_strategy MODULES test_strategy)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
py_test_modules(test_dist_shape MODULES test_dist_shape)
py_test_modules(test_dist_assign MODULES test_dist_assign)
py_test_modules(test_conditional_block_reshard MODULES
test_conditional_block_reshard)

endif()