[AutoParallel][PIR] fix vpp pass when open fused pass #68383

Open · wants to merge 3 commits into develop
70 changes: 47 additions & 23 deletions python/paddle/autograd/ir_backward.py
@@ -996,6 +996,51 @@ def create_backward_prune_set(
return outputs_set, inputs_set, no_gradvar_set


def _complete_grad_op_chunk_id(block, state):
is_dist_program = False
for op in block.ops:
if op.dist_attr is not None:
is_dist_program = True
break

if not is_dist_program:
return

for op in block.ops:
if op not in state.op_to_opgrad:
continue

if op.dist_attr is None:
op_chunk_id = -1
if op.name() == "builtin.split":
op_chunk_id = (
op.operand_source(0).get_defining_op().dist_attr.chunk_id
)
elif op.name() == "builtin.combine":
op_chunk_id = op.result(0).get_defining_op().dist_attr.chunk_id
else:
# TODO(luchang): need to support more ops such as pd_op.pylayer and so on
pass
else:
op_chunk_id = op.dist_attr.chunk_id
if op_chunk_id == -1 and op.name() == "dist_op.reshard":
op_chunk_id = (
op.operand_source(0).get_defining_op().dist_attr.chunk_id
)

for bwd_op in state.op_to_opgrad[op]:
if bwd_op.dist_attr is None:
continue
bwd_op.dist_attr = (
paddle.base.libpaddle.pir.create_op_dist_attribute(
bwd_op.dist_attr.process_mesh,
bwd_op.dist_attr.operands(),
bwd_op.dist_attr.results(),
op_chunk_id,
)
)


def calc_gradient_helper(
outputs: Value | Sequence[Value],
inputs: Value | Sequence[Value],
@@ -1057,29 +1102,8 @@ def calc_gradient_helper(
outputs_fwd_set, inputs_fwd_set, no_grad_set, state
)

# set struct name for grad ops
for op in block.ops:
if op in state.op_to_opgrad:
if op.dist_attr is None:
continue

op_chunk_id = op.dist_attr.chunk_id
if op_chunk_id == -1 and op.name() == "dist_op.reshard":
op_chunk_id = (
op.operand_source(0).get_defining_op().dist_attr.chunk_id
)

for bwd_op in state.op_to_opgrad[op]:
if bwd_op.dist_attr is None:
continue
bwd_op.dist_attr = (
paddle.base.libpaddle.pir.create_op_dist_attribute(
bwd_op.dist_attr.process_mesh,
bwd_op.dist_attr.operands(),
bwd_op.dist_attr.results(),
op_chunk_id,
)
)
# set chunk id for grad ops
_complete_grad_op_chunk_id(block, state)

remove_ops = []
if not is_inplace_net(backward_ops) and inputs:
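Note on the pattern used in _complete_grad_op_chunk_id above: the backward op's dist attribute is re-created through paddle.base.libpaddle.pir.create_op_dist_attribute(process_mesh, operands, results, chunk_id) with the desired chunk id. A minimal sketch of that pattern; set_chunk_id is an illustrative helper name and is not part of this PR:

    # Illustrative helper (not in this PR): stamp a chunk id onto an op by
    # re-creating its dist attribute with the same mesh, operands and results.
    def set_chunk_id(op, chunk_id):
        attr = op.dist_attr
        if attr is None:
            return
        op.dist_attr = paddle.base.libpaddle.pir.create_op_dist_attribute(
            attr.process_mesh, attr.operands(), attr.results(), chunk_id
        )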
22 changes: 20 additions & 2 deletions python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -30,6 +30,7 @@
choose_reshard_func,
copy_dist_attr_with_new_member,
copy_op_attr_with_new_member,
copy_process_mesh_with_new_member,
)
from .reshard_funcs.reshard_func_register import register_reshard_funcs
from .utils import (
@@ -148,11 +149,19 @@ def apply_partition_pass(program):
result.update_dist_attr(result_attr)

with auto_complete_op_role(program, ref_op_role):
prev_op = prev_var.get_defining_op()

# reshard output to assign out input
reshard_var_1 = paddle._C_ops.reshard_v2(
result, prev_var.dist_attr()
)
paddle.assign(reshard_var_1, prev_var)
assign_out = paddle._C_ops.assign_out_(reshard_var_1, prev_var)
assign_out.get_defining_op().dist_attr = (
copy_op_attr_with_new_member(
assign_out.get_defining_op().dist_attr,
new_chunk_id=prev_op.dist_attr.chunk_id,
)
)

if old_dist_attr == result.dist_attr():
continue
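The assign_out_ change above pairs the in-place write to prev_var with an explicit chunk id: the newly inserted assign op inherits the chunk id of the op that defines prev_var, so the vpp pass can place it in the correct chunk. A short sketch of that copy pattern, reusing the copy_op_attr_with_new_member helper imported in this file (illustrative only):

    # Illustrative only: give a newly inserted op the same chunk id as the
    # producer of the variable it writes into.
    ref_chunk_id = prev_var.get_defining_op().dist_attr.chunk_id
    new_op = assign_out.get_defining_op()
    new_op.dist_attr = copy_op_attr_with_new_member(
        new_op.dist_attr, new_chunk_id=ref_chunk_id
    )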
@@ -687,7 +696,7 @@ def _get_seg_struct_names(ops, seg_method):
seg_op_mesh = collections.OrderedDict()

for i in range(fwd_start_op_index, fwd_end_op_index + 1):
if ops[i].name() == "builtin.combine":
if ops[i].name() in dist_skip_op_list:
continue

struct_name = _extract_seg_method(ops[i], seg_method)
@@ -960,10 +969,19 @@ def complete_chunk_id(dist_program, pipeline_strategy):
new_dst_dist_attr = copy_dist_attr_with_new_member(
dst_dist_attr, new_process_mesh=dst_process_mesh
)
new_process_ids = (
src_process_mesh.process_ids + dst_process_mesh.process_ids
)
new_process_mesh = copy_process_mesh_with_new_member(
op.dist_attr.process_mesh,
new_process_ids=new_process_ids,
)

op.dist_attr = copy_op_attr_with_new_member(
op_dist_attr,
new_operands=[new_src_dist_attr],
new_results=[new_dst_dist_attr],
new_process_mesh=new_process_mesh,
)
elif reshard_func_name == "SameStatusReshardFunction":
op.result(0).replace_all_uses_with(var)
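For the cross-mesh branch above, the reshard op itself is given a process mesh built from the concatenated source and destination process ids, so its dist attribute stays consistent with the per-side operand and result attributes. A minimal sketch of that merge, assuming the helper names imported at the top of this file:

    # Illustrative only: rebuild the op dist attribute around a mesh that
    # spans both the source and destination process ids.
    union_ids = src_process_mesh.process_ids + dst_process_mesh.process_ids
    merged_mesh = copy_process_mesh_with_new_member(
        op.dist_attr.process_mesh, new_process_ids=union_ids
    )
    op.dist_attr = copy_op_attr_with_new_member(
        op.dist_attr,
        new_operands=[new_src_dist_attr],
        new_results=[new_dst_dist_attr],
        new_process_mesh=merged_mesh,
    )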
@@ -60,10 +60,11 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
else:
dst_value = paddle.assign(src_value)

src_chunk_id = src_value.get_defining_op().dist_attr.chunk_id
dst_value.set_type(dst_type)
dst_value.get_defining_op().dist_attr = (
paddle.base.libpaddle.pir.create_op_dist_attribute(
dst_mesh, [src_dist_attr], [dst_dist_attr]
dst_mesh, [src_dist_attr], [dst_dist_attr], src_chunk_id
)
)
