[Auto Parallel] Update comp cost and completion for gpt auto search #46387

Merged: 2 commits, Oct 8, 2022
72 changes: 72 additions & 0 deletions python/paddle/distributed/auto_parallel/completion.py
@@ -142,6 +142,7 @@ class Completer:
def __init__(self, dist_context):
assert dist_context is not None
self._dist_context = dist_context
self._has_prepared = False

def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
changed = False
@@ -719,6 +720,8 @@ def _update_process_mesh(self):
self._update_process_mesh_between_graphs()

def _prepare(self):
if self._has_prepared:
return
self._while_op_nodes = {}
self._array_nodes = {}
self._node_pairs_between_graphs = []
@@ -732,6 +735,8 @@ def _prepare(self):
if self._array_nodes.get(array_var_name, None) is None:
self._array_nodes[array_var_name] = []
self._array_nodes[array_var_name].append(node)
# Add the array input node
self._array_nodes[array_var_name].append(node.inputs[0])
if node.op().type() == "write_to_array":
array_var_name = node.op().output("Out")[0]
if self._array_nodes.get(array_var_name, None) is None:
@@ -752,6 +757,7 @@ def _prepare(self):
and after_node.var().name() == node.var().name():
self._node_pairs_between_graphs.append(
(after_node, node))
self._has_prepared = True

def complete_forward_annotation(self, serial_main_program=None):
""" Complete annotation for the partial annotated serial_main_program.
@@ -899,6 +905,72 @@ def _update_dist_attr_for_dp(self):
else:
dist_op.dist_attr = original_op_dist_attr

def _complete_tensor_dist_attr_by_op(self, serial_main_program=None):
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context._serial_main_program = serial_main_program

self._dist_context.initialize()

self._prepare()

has_set_dist_attr = set()

all_nodes = self._dist_context.serial_ordered_nodes
for node in all_nodes:
if node.is_op():
if node.op().type() in ["while"]:
continue
dist_op = self._dist_context.get_dist_op_for_graph(node)
op_dist_attr = dist_op.dist_attr
for tensor_node in node.inputs:
if tensor_node.is_var() and tensor_node.var() is not None:
# Skip the non-leaf var node
if len(tensor_node.inputs) != 0:
continue
tensor_desc = tensor_node.var()
tensor_name = tensor_desc.name()
tensor = dist_op.get_serial_input(tensor_name)
# Use the first op to set the tensor dist attr
if tensor_name in has_set_dist_attr:
continue
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_name) if tensor.is_parameter else [
-1 for i in tensor_desc.shape()
]
has_set_dist_attr.add(tensor_name)
for tensor_node in node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
tensor_name = tensor_node.var().name()
if tensor_name in has_set_dist_attr:
continue
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_name)
has_set_dist_attr.add(tensor_name)

self._update_process_mesh_for_specials()

self._update_process_mesh_between_graphs()

self._update_dims_mapping_for_special()

self._update_dims_mapping_between_graphs()

# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()

# Do the validation check and amend the completion where needed
self._dist_context.amend_dist_attr_for_program()

self._dist_context.validate_dist_attr_for_program()

def _complete_high_order_grad_annotation(self, serial_main_program=None):
"""
NOTE:
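A minimal sketch of how the new completion entry point is driven, mirroring the unit test added at the end of this PR. The import paths and the pre-built train_program are assumptions for illustration, not part of the diff.

# Sketch only (not part of the diff). Assumes `train_program` is a serial
# main program built elsewhere, e.g. by the helper used in the tests.
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

dist_context = DistributedContext()
completer = Completer(dist_context)

# Standard forward annotation first; _prepare() now runs once and is
# short-circuited afterwards by the new _has_prepared flag.
completer.complete_forward_annotation(train_program)

# New in this PR: set each tensor's process_mesh and dims_mapping from the
# dist_attr of the first op that reads or writes it, then amend and
# validate the program-level attributes.
completer._complete_tensor_dist_attr_by_op()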
55 changes: 55 additions & 0 deletions python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
@@ -167,6 +167,25 @@ def calc_time(self):
return 0


@register_op_cost
class DropoutGradOpCost(CompOpCost):
OP_TYPE = "dropout_grad"

def __init__(self, op=None, op_desc=None, cluster=None):
super(DropoutGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)

# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0

def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0


@register_op_cost
class ElementwiseAddOpCost(CompOpCost):
OP_TYPE = "elementwise_add"
@@ -395,6 +414,42 @@ def calc_time(self):
return 0


@register_op_cost
class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
OP_TYPE = "fused_softmax_mask_upper_triangle"

def __init__(self, op=None, op_desc=None, cluster=None):
super(FusedSoftmaxMaskUpperTriangleOpCost,
self).__init__(op=op, op_desc=op_desc, cluster=cluster)

# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0

def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0


@register_op_cost
class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
OP_TYPE = "fused_softmax_mask_upper_triangle_grad"

def __init__(self, op=None, op_desc=None, cluster=None):
super(FusedSoftmaxMaskUpperTriangleGradOpCost,
self).__init__(op=op, op_desc=op_desc, cluster=cluster)

# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0

def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0


@register_op_cost
class GatherOpCost(CompOpCost):
OP_TYPE = "gather"
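Each of the new cost classes in this file follows the same register-and-override pattern. Below is a sketch of that pattern for a hypothetical op type my_fused_op; the base_cost import path is an assumption about the cost package layout.

# Sketch only. `my_fused_op` is a hypothetical OP_TYPE, and the import
# path of CompOpCost/register_op_cost is assumed, not taken from the diff.
from paddle.distributed.auto_parallel.cost.base_cost import (
    CompOpCost, register_op_cost)


@register_op_cost
class MyFusedOpCost(CompOpCost):
    OP_TYPE = "my_fused_op"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(MyFusedOpCost, self).__init__(op=op,
                                            op_desc=op_desc,
                                            cluster=cluster)

    # Placeholder cost model, mirroring the classes added in this diff;
    # real FLOPs/time formulas can be filled in later.
    def calc_flops(self):
        return 0

    def calc_time(self):
        return 0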
@@ -82,6 +82,9 @@
from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2GradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import Unsqueeze2OpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import WriteToArrayOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import DropoutGradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleGradOpCost

from test_cluster import cluster_json

@@ -417,6 +420,22 @@ def test_comp_cost(self):
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)

op_cost = DropoutGradOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)

op_cost = FusedSoftmaxMaskUpperTriangleOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)

op_cost = FusedSoftmaxMaskUpperTriangleGradOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)

# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
@@ -187,6 +187,14 @@ def test_completer(self):
train_program)
# print_program_with_dist_attr(complete_train_program, dist_context)

def test_completer_by_dist_op(self):
train_program, start_program, dataloader, i, loss = get_program()
dist_context = DistributedContext()
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
complete_train_program = completer._complete_tensor_dist_attr_by_op()


if __name__ == "__main__":
unittest.main()