[AutoParallel] adapt for gpt-gen (#46771)
* for gpt-gen
* fix reshard
* adapt assign and shape op
* add dist_assign & unittest
* add conditional block unittest
* rename unittest
zhaoyinglia committed Oct 14, 2022
1 parent eee6b3a commit 31a437b
Showing 11 changed files with 458 additions and 21 deletions.

python/paddle/distributed/auto_parallel/completion.py (3 additions, 3 deletions)

@@ -17,7 +17,7 @@
 from paddle.fluid import core
 
-from .utils import is_gradient_clip_op
+from .utils import is_gradient_clip_op, __not_shape_var_type__
 from .operators import find_compatible_distributed_operator_impls
 from .dist_context import _node_id
 from .dist_attribute import TensorDistributedAttribute

@@ -491,14 +491,14 @@ def _find_nodes_related_to_cond(source_node):
                 for tensor_node in node.inputs:
                     if tensor_node.is_var() and tensor_node.var(
                     ) is not None:
-                        if tensor_node.var().type() == core.VarDesc.VarType.READER \
+                        if tensor_node.var().type() in __not_shape_var_type__ \
                             or len(tensor_node.var().shape()) != 1:
                             flag = False
                             break
                 for tensor_node in node.outputs:
                     if tensor_node.is_var() and tensor_node.var(
                     ) is not None:
-                        if tensor_node.var().type() == core.VarDesc.VarType.READER \
+                        if tensor_node.var().type() in __not_shape_var_type__ \
                             or len(tensor_node.var().shape()) != 1:
                             flag = False
                             break
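
For context: these loops decide whether a tensor attached to a condition-related op can be treated as a shape-like tensor. With this change, any variable whose type appears in __not_shape_var_type__ (READER and STEP_SCOPES, defined in utils.py further down) is ruled out, not just readers. A minimal standalone sketch of the predicate the loops apply (illustrative names only, not Paddle API):

# Illustrative sketch only, not Paddle code: a tensor node can stay
# "shape-related" only if its variable type is allowed and it is 1-D.
def could_be_shape_related(var_type, shape, not_shape_var_types):
    return var_type not in not_shape_var_types and len(shape) == 1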

python/paddle/distributed/auto_parallel/engine.py (6 additions, 2 deletions)

@@ -1139,8 +1139,10 @@ def prepare(self,
             self.to_mode(mode)
         if inputs or labels:
             self._skip_build = True
+            self._inputs_spec = inputs_spec
+            self._labels_spec = labels_spec
             self._inputs, self._labels = self._prepare_data_tensor(
-                inputs_spec, labels_spec, inputs, labels)
+                self._inputs_spec, self._labels_spec, inputs, labels)
             self._orig_main_prog = main_program
             if self._orig_main_prog is None:
                 self._orig_main_prog = static.default_main_program()

@@ -1152,9 +1154,11 @@ def prepare(self,
             else:
                 self._switch_mode(self._mode)
         elif inputs_spec or labels_spec:
+            self._inputs_spec = inputs_spec
+            self._labels_spec = labels_spec
             self._outside_dataloader = True
             self._inputs, self._labels = self._prepare_data_tensor(
-                inputs_spec, labels_spec)
+                self._inputs_spec, self._labels_spec)
             self._orig_main_prog = main_program
             if self._orig_main_prog is None:
                 self._orig_main_prog = static.default_main_program()

python/paddle/distributed/auto_parallel/operators/__init__.py (2 additions)

@@ -33,3 +33,5 @@
 from . import dist_fused_feedforward
 from . import dist_fused_attention
 from . import dist_reduce_sum_p
+from . import dist_shape
+from . import dist_assign

python/paddle/distributed/auto_parallel/operators/dist_assign.py (new file, 88 additions)

@@ -0,0 +1,88 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .dist_default import DistributedDefaultImpl0
from ..utils import compute_compatible_and_update_dim_mapping


class DistributedAssign(DistributedOperatorImplContainer):

    def __init__(self, op_type):
        super(DistributedAssign, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedAssign("assign"))


class DistributedAssignImpl(DistributedOperatorImpl):

    def __init__(self, name):
        super(DistributedAssignImpl, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        return True

    def is_output_compatible(self, dist_op):
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or \
                (not self.is_output_compatible(dist_op)):
            return False

        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        out_name = op_desc.output('Out')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        if x_dims_mapping != out_dims_mapping:
            return False

        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        out_name = op_desc.output('Out')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        for i in range(len(x_dims_mapping)):
            dim_changed = compute_compatible_and_update_dim_mapping(
                [x_dims_mapping, out_dims_mapping], [i, i])
            if dim_changed:
                changed = True

        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)


register_distributed_operator_impl("assign", DistributedAssignImpl("assign"))

python/paddle/distributed/auto_parallel/operators/dist_shape.py (new file, 73 additions)

@@ -0,0 +1,73 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .dist_default import DistributedDefaultImpl0
from ..utils import is_dim_shard


class DistributedShape(DistributedOperatorImplContainer):

    def __init__(self, op_type):
        super(DistributedShape, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedShape("shape"))


class DistributedShapeImpl(DistributedOperatorImpl):

    def __init__(self, name):
        super(DistributedShapeImpl, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

        assert len(out_dims_mapping) == 1
        if is_dim_shard(out_dims_mapping[0]):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or \
                (not self.is_output_compatible(dist_op)):
            return False

        return True

    def update_dims_mapping(self, dist_op):
        return False

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)


register_distributed_operator_impl("shape", DistributedShapeImpl("shape"))

python/paddle/distributed/auto_parallel/reshard.py (24 additions, 16 deletions)

@@ -34,6 +34,7 @@
 _g_gradient_clip_ops = [
     "sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div"
 ]
+_g_subblock_ops = ["while", "conditional_block"]
 
 
 def get_var_with_recursion(var_name, block, program):

@@ -42,11 +43,11 @@ def get_var_with_recursion(var_name, block, program):
     if var_name in block.vars:
         var = block.vars[var_name]
     else:
-        parent_block = program.blocks[block.parent_idx]
-        if var_name in parent_block.vars:
-            var = parent_block.vars[var_name]
-    assert var is not None, \
-        "{} is not found".format(var.name)
+        var = block._var_recursive(var_name)
+        # parent_block = program.blocks[block.parent_idx]
+        # if var_name in parent_block.vars:
+        #     var = parent_block.vars[var_name]
+    assert var is not None, "{} is not found".format(var.name)
 
     return var
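
The previous lookup only climbed one level, to the immediate parent block, which breaks once conditional blocks nest more deeply (as in the gpt-gen case); block._var_recursive walks every ancestor block. A conceptual sketch of that behaviour, where parent is a hypothetical attribute used only for the illustration:

# Conceptual sketch, not Paddle code: recursive variable lookup walks all
# ancestor scopes instead of stopping at the immediate parent.
def find_var_recursive(block, var_name):
    current = block
    while current is not None:
        if var_name in current.vars:
            return current.vars[var_name]
        current = current.parent  # hypothetical parent pointer for this sketch
    return None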


@@ -1075,7 +1076,9 @@ def change_while_op_input_and_output(auto_parallel_main_prog, dist_context):
     new_Out = []
     for var_name in while_op.output("Out"):
         for output_name in sub_block_op_outputs[::-1]:
-            if output_name.find(var_name) != -1:
+            if output_name.find(var_name) != -1 and (
+                    len(var_name) == len(output_name)
+                    or "@RESHARD" in output_name):
                 if output_name not in new_Out:
                     new_Out.append(output_name)
     assert new_Out
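
A bare substring test (output_name.find(var_name) != -1) can pair a while-op output with an unrelated sub-block output whose name merely contains the target name; the tightened condition accepts a candidate only when it is an exact match or carries the @RESHARD marker introduced by resharding. A small illustration of the predicate (the variable names and the exact @RESHARD suffix format are made up for the example):

# Mirrors the tightened match above for illustration; not taken verbatim from reshard.py.
def is_matching_output(var_name, output_name):
    return output_name.find(var_name) != -1 and (
        len(var_name) == len(output_name) or "@RESHARD" in output_name)


assert is_matching_output("cond_out", "cond_out")            # exact match
assert is_matching_output("cond_out", "cond_out@RESHARD_0")  # hypothetical reshard-renamed var
assert not is_matching_output("out", "out_linear_tmp")       # substring alone is rejected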

@@ -1104,13 +1107,15 @@ def is_special_op(self, op):
         return False
 
     def is_condition_replicative(self, op):
-        assert op.type == "while"
         sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
-        dist_op = self.dist_context.get_dist_op_for_program(op)
-        op_dist_attr = dist_op.dist_attr
 
+        if op.type == "while":
+            input_cond = op.input("Condition")
+        elif op.type == "conditional_block":
+            input_cond = op.input("Cond")
+
         # the dims mapping of condition tensor should be replicative
-        for var_name in op.input("Condition"):
+        for var_name in input_cond:
             var = get_var_with_recursion(var_name, sub_block,
                                          self.auto_parallel_main_prog)
             dist_tensor = self.dist_context.get_dist_tensor_for_program(var)
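
The condition tensor drives control flow, so every rank must hold the same full value and take the same branch; "replicative" here means every entry of the condition variable's dims mapping is -1. A one-line sketch of that check (illustrative, not the method body):

# Illustrative check: a dims mapping is replicative when no dimension is sharded.
def is_replicative(dims_mapping):
    return all(dim == -1 for dim in dims_mapping)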

@@ -1660,9 +1665,9 @@ def parse_op_desc(self, block, op_desc_seq, var_name, reshard_op,
                 op.desc.set_input(proto.inputs[0].name,
                                   op.input("X") + while_op_X_append)
 
-    def _get_while_op_input_attrs(self, op, var_name):
+    def _get_subblock_input_attrs(self, op, var_name):
         # NOTE: Multi while loop is not supported
-        assert op.type == "while"
+        assert op.type in _g_subblock_ops
         sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
         ops = sub_block.ops
         input_attrs = []

@@ -1713,8 +1718,8 @@ def _get_common_op_input_attrs(self, op, var_name):
     def get_op_input_attrs(self, op, var_name):
         op_input_attrs = []
 
-        if op.type == "while":
-            op_input_attrs = self._get_while_op_input_attrs(op, var_name)
+        if op.type in _g_subblock_ops:
+            op_input_attrs = self._get_subblock_input_attrs(op, var_name)
         else:
             op_input_attrs = self._get_common_op_input_attrs(op, var_name)

@@ -1818,7 +1823,7 @@ def _reshard_input(self, block):
             if dist_op is not None:
                 op_input_dist_attrs = [
                 ]  # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)]
-                if op.type == "while":
+                if op.type in _g_subblock_ops:
                     if not self.is_condition_replicative(op):
                         raise ValueError(
                             "Please check the condition due to the dims mapping is not replicative."

@@ -1832,6 +1837,8 @@
                 if op.type == "while":
                     # condition var process mesh is the same with op and dims_mapping is replicative, so it do not need reshard
                     input_var_names = op.input("X")
+                elif op.type == "conditional_block":
+                    input_var_names = op.input("Input")
                 else:
                     input_var_names = op.input_arg_names
                 # to avoid while op X order different

@@ -1984,11 +1991,12 @@ def _reshard_output(self, block):
         idx = 0
         # skip reader and ops whose process mesh is union
         skip_ops = [
-            "create_py_reader", "create_double_buffer_reader", "read", "while",
+            "create_py_reader", "create_double_buffer_reader", "read",
            "write_to_array", "read_from_array"
         ]
         global _g_special_ops
         skip_ops += _g_special_ops
+        skip_ops += _g_subblock_ops
         while idx < len(block.ops):
             pre_op_count = len(block.ops)
             op = block.ops[idx]

python/paddle/distributed/auto_parallel/utils.py (4 additions)

@@ -27,6 +27,10 @@
 from paddle.fluid.io import is_parameter, is_belong_to_optimizer
 from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
 
+__not_shape_var_type__ = [
+    core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
+]
+
 
 def get_logger(log_level, name="auto_parallel"):
     logger = logging.getLogger(name)

CMakeLists.txt (unit-test registration, 4 additions)

@@ -96,5 +96,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_interface MODULES test_interface)
   py_test_modules(test_strategy MODULES test_strategy)
   py_test_modules(test_pass_quantization MODULES test_pass_quantization)
+  py_test_modules(test_dist_shape MODULES test_dist_shape)
+  py_test_modules(test_dist_assign MODULES test_dist_assign)
+  py_test_modules(test_conditional_block_reshard MODULES
+                  test_conditional_block_reshard)
 
 endif()

(The remaining changed files are not loaded on this page.)