Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supported intermediate outputs for eager final state codegen #39767

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,26 @@ def ReadBwdFile(filepath):
######################
### Yaml Parsers ###
######################
def IntermediateValidationCheck(intermediate_outputs, forward_returns_list):
# intermediate_outputs : [name0, name1, ...]
# forward_returns_list : [[ret_name, type, orig_pos], ...]
"""
Check whether intermediate_outputs are positioned
at the very end of forward_returns_list
"""

intermediate_positions = range(
len(forward_returns_list) - len(intermediate_outputs),
len(forward_returns_list))
for ret_name, _, pos in forward_returns_list:
if ret_name in intermediate_outputs:
assert pos in intermediate_positions


def ParseIntermediate(string):
return [v.strip() for v in string.split(",")]


def ParseNoNeedBuffer(string):
# string: "x, y"
no_need_buffer_set = set()
Expand Down Expand Up @@ -742,11 +762,11 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
return node_creation_str


def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list):
def GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, intermediate_outputs):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
Expand Down Expand Up @@ -790,13 +810,20 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
inputs_call_args_str = ", ".join(inputs_call_list)

# Forward Full Logic
forward_call_str = f"auto api_result = paddle::experimental::{fwd_api_name}({inputs_call_args_str});"
if len(intermediate_outputs) == 0:
function_name = fwd_api_name
else:
function_name = fwd_api_name + "_intermediate"
forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});"

# Get return type list & outputs
num_outputs = len(forward_outputs_position_map.keys())
num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
returns_type_list = ["" for i in range(num_outputs)]
returns_list = ["" for i in range(num_outputs)]
for name, (rtype, pos) in forward_outputs_position_map.items():
if name in intermediate_outputs:
continue
if num_outputs == 1:
returns_list[0] = f"api_result"
else:
Expand Down Expand Up @@ -1037,6 +1064,12 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
print("Prased Forward Attrs List: ", forward_attrs_list)
print("Parsed Forward Returns List: ", forward_returns_list)

intermediate_outputs = []
if 'intermediate' in fwd_api.keys():
intermediate_outputs = ParseIntermediate(fwd_api['intermediate'])

IntermediateValidationCheck(intermediate_outputs, forward_returns_list)

# Collect Original Forward Inputs/Outputs and then perform validation checks
orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str)
Expand Down Expand Up @@ -1095,7 +1128,7 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list)
backward_grad_output_map, backward_attrs_list, intermediate_outputs)
print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str)
forward_definition_str += definition_declaration_pair[0]
Expand Down