Skip to content

Commit

Permalink
Adjusted python-level trace_op to accommodate final state Eager Dygraph (
Browse files Browse the repository at this point in the history
#39319)

* Removed debug info

* Added automatic code generation for final state Eager Dygraph

* Modified backward yaml

* Added EagerUtils helper functions for final state CodeGen

* Adjusted CMakeFiles to support compilation for final state auto generated codes

* Added python-c code generation for final state Eager Dygraph

* Fixed minor issue

* Fixed yaml.load() method failure

* Fixed minor issues

* Refactored Python-C Attributes Parsing Functions

* Fixed minor issue with Python-C AddFunctions

* Adjusted python-level trace_op to accommodate final state Eager Dygraph

* Added Logs for final state Eager Dygraph

* Fixed merge issues

* Fixed minor issue
  • Loading branch information
jim19930609 committed Feb 14, 2022
1 parent 74a150f commit ec8a0c1
Show file tree
Hide file tree
Showing 4 changed files with 376 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
import argparse
import os

# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
core_ops_returns_info = {}
core_ops_args_info = {}
core_ops_args_type_info = {}


def ParseArguments():
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -130,17 +136,16 @@ def ParseYamlArgs(string):
attrs_list = []

args = [x.strip() for x in string.strip().split(",")]

atype = r'((const )?\S+) '
aname = r'(\S+)'
aname = r'(.*)'
pattern = f'{atype}{aname}'
for i in range(len(args)):
arg = args[i]
m = re.search(pattern, arg)
arg_type = m.group(1)
arg_name = m.group(3).split("=")[0]
default_value = m.group(3).split("=")[1] if len(m.group(3).split(
"=")) > 1 else None
arg_type = m.group(1).strip()
arg_name = m.group(3).split("=")[0].strip()
default_value = m.group(3).split("=")[1].strip() if len(
m.group(3).split("=")) > 1 else None
if "Tensor" in arg_type:
assert default_value is None
inputs_list.append([arg_name, arg_type, i])
Expand Down Expand Up @@ -262,7 +267,6 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]

assert orig_attr_type == forward_attr_type
assert orig_attr_default == forward_attr_default
assert orig_attr_pos == forward_attr_pos
Expand Down Expand Up @@ -741,26 +745,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
# Get Function Args
num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys(
))
inputs_args_list = ["" for i in range(num_inputs)]
inputs_args_definition_list = ["" for i in range(num_inputs)]
inputs_args_declaration_list = ["" for i in range(num_inputs)]
inputs_call_list = ["" for i in range(num_inputs)]
for name, (ttype, pos) in forward_inputs_position_map.items():
inputs_call_list[pos] = f"{name}"
if IsPlainTensorType(ttype):
inputs_args_list[
inputs_args_definition_list[
pos] = f"const paddle::experimental::Tensor& {name}"
inputs_args_declaration_list[
pos] = f"const paddle::experimental::Tensor& {name}"
else:
assert IsVectorTensorType(ttype)
inputs_args_list[
inputs_args_definition_list[
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"
inputs_args_declaration_list[
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"

for name, atype, default_val, pos in forward_attrs_list:
inputs_call_list[pos] = name
if default_val is not None:
inputs_args_list[pos] = f"{atype} {name} = {default_val}"
inputs_args_declaration_list[
pos] = f"{atype} {name} = {default_val}"
else:
inputs_args_list[pos] = f"{atype} {name}"
inputs_args_declaration_list[pos] = f"{atype} {name}"
inputs_args_definition_list[pos] = f"{atype} {name}"

inputs_args_str = ", ".join(inputs_args_list)
inputs_args_declaration_str = ", ".join(inputs_args_declaration_list)
inputs_args_definition_str = ", ".join(inputs_args_definition_list)
inputs_call_args_str = ", ".join(inputs_call_list)

# Forward Full Logic
Expand Down Expand Up @@ -812,13 +824,95 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,

forward_function_name = GetForwardFunctionName(fwd_api_name)
forward_function_str = FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_str,
returns_type_str, forward_function_name, inputs_args_definition_str,
forward_call_str, node_creation_str, returns_str)
forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_str});"
forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});"

return forward_function_str, forward_function_declaration_str


def CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
                              forward_outputs_position_map, forward_attrs_list):
    """Record per-op argument/return metadata for python-level API dispatch.

    Populates the module-level ``core_ops_args_info``,
    ``core_ops_args_type_info`` and ``core_ops_returns_info`` maps under the
    key ``"final_state_" + fwd_api_name``.

    Args:
        fwd_api_name: forward API name, e.g. ``""``.
        forward_inputs_position_map: ``{ name : (type, fwd_position) }``
        forward_outputs_position_map: ``{ name : (type, fwd_position) }``
        forward_attrs_list: ``[[attr_name, attr_type, default_value,
            orig_position], ...]``
    """
    op_key = "final_state_" + fwd_api_name
    total_args = len(forward_inputs_position_map) + len(forward_attrs_list)
    total_returns = len(forward_outputs_position_map)

    returns_info = core_ops_returns_info[op_key] = [""] * total_returns
    args_info = core_ops_args_info[op_key] = [""] * total_args
    args_type_info = core_ops_args_type_info[op_key] = [""] * total_args

    # Tensor inputs occupy their forward positions; record name and kind.
    for name, (ttype, pos) in forward_inputs_position_map.items():
        args_info[pos] = name
        if IsPlainTensorType(ttype):
            args_type_info[pos] = "tensor"
        else:
            assert IsVectorTensorType(ttype)
            args_type_info[pos] = "list"

    # Attributes fill the remaining argument slots (their type is left "").
    for name, _, _, pos in forward_attrs_list:
        args_info[pos] = name

    # Outputs only need their names, placed at forward positions.
    for name, (_, pos) in forward_outputs_position_map.items():
        returns_info[pos] = name


def GenerateCoreOpInfoDeclaration():
    """Return C++ ``extern`` declarations for the three final-state
    core-ops info maps, for inclusion in the generated forward header."""
    map_type = "std::unordered_map<std::string, std::vector<std::string>>"
    map_names = (
        "core_ops_final_state_args_info",
        "core_ops_final_state_args_type_info",
        "core_ops_final_state_returns_info",
    )
    decl_lines = [f"extern {map_type} {name};" for name in map_names]
    return "\n" + "\n".join(decl_lines) + "\n"


def GenerateCoreOpInfoDefinition():
    """Return C++ definitions of the three final-state core-ops info maps.

    Serializes the module-level ``core_ops_args_info``,
    ``core_ops_args_type_info`` and ``core_ops_returns_info`` dicts (filled
    in by ``CollectCoreOpsInformation``) into three
    ``std::unordered_map<std::string, std::vector<std::string>>``
    definitions for the generated forward .cc file.
    """
    CORE_OPS_INFO_TEMPLATE = """
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info = {{
{}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info = {{
{}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info = {{
{}
}};
"""

    def _map_initializer_entries(info_map):
        # One '{ "op_name", { "v0","v1",... } },' initializer line per op.
        # Shared by all three maps; replaces three copy-pasted loops.
        entries = []
        for op_name, value_list in info_map.items():
            value_str = ",".join('"' + v + '"' for v in value_list)
            entries.append(f'{{ "{op_name}", {{ {value_str} }} }},')
        return "\n".join(entries)

    return CORE_OPS_INFO_TEMPLATE.format(
        _map_initializer_entries(core_ops_args_info),
        _map_initializer_entries(core_ops_args_type_info),
        _map_initializer_entries(core_ops_returns_info))


def GenerateNodeCCFile(filepath, node_definition_str):
file_contents = """
#include "glog/logging.h"
Expand Down Expand Up @@ -856,6 +950,8 @@ def GenerateForwardCCFile(filepath, forward_definition_str):
#include "paddle/fluid/eager/api/utils/global_utils.h"
"""

file_contents += GenerateCoreOpInfoDefinition()
file_contents += forward_definition_str
with open(filepath, 'a') as f:
f.write(file_contents)
Expand All @@ -871,6 +967,7 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
#include "paddle/fluid/framework/op_registry.h"
"""
file_contents += GenerateCoreOpInfoDeclaration()
file_contents += forward_function_declaration_str
with open(filepath, 'a') as f:
f.write(file_contents)
Expand Down Expand Up @@ -985,6 +1082,11 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
forward_definition_str += definition_declaration_pair[0]
forward_declaration_str += definition_declaration_pair[1]

# For python-level API dispatch
CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
forward_outputs_position_map,
forward_attrs_list)

# Generate Files
nodes_h_path = args.nodes_h_path
nodes_cc_path = args.nodes_cc_path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
PyThreadState *tstate = nullptr;
try
{{
VLOG(6) << "Running Eager Final State API: {}";
// Get EagerTensors from args
{}
Expand All @@ -129,16 +131,87 @@ def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
"""
python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
fwd_api_name, get_eager_tensor_str, parse_attributes_str,
fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str,
GetForwardFunctionName(fwd_api_name), dygraph_function_call_str)

python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}"
python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}},\n"

return python_c_function_str, python_c_function_reg_str


def GenerateCoreOpsInfoMap():
    """Return the Python-C glue exposing the final-state core-ops info maps.

    Returns a ``(definitions, registry)`` pair of C++ source strings:
    ``definitions`` holds three ``static PyObject *`` getters that wrap the
    ``core_ops_final_state_*`` maps via ``ToPyObject``; ``registry`` holds
    the matching ``PyMethodDef`` registration entries.
    """
    core_ops_infos_definition = """
static PyObject * eager_get_final_state_core_ops_args_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
return ToPyObject(core_ops_final_state_args_info);
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
static PyObject * eager_get_final_state_core_ops_args_type_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
return ToPyObject(core_ops_final_state_args_type_info);
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
static PyObject * eager_get_final_state_core_ops_returns_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
return ToPyObject(core_ops_final_state_returns_info);
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
"""

    # Quotes need no backslash-escaping inside a triple-quoted literal;
    # the runtime string is identical either way.
    core_ops_infos_registry = """
{"get_final_state_core_ops_args_info",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_info, METH_NOARGS,
"C++ interface function for eager_get_final_state_core_ops_args_info."},
{"get_final_state_core_ops_args_type_info",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_type_info,
METH_NOARGS,
"C++ interface function for eager_get_final_state_core_ops_args_type_info."},
{"get_final_state_core_ops_returns_info",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_returns_info,
METH_NOARGS, "C++ interface function for eager_get_final_state_core_ops_returns_info."},
"""

    return core_ops_infos_definition, core_ops_infos_registry


def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):

core_ops_infos_definition, core_ops_infos_registry = GenerateCoreOpsInfoMap(
)

python_c_function_str += core_ops_infos_definition
python_c_function_reg_str += core_ops_infos_registry
python_c_function_reg_str += "\n {nullptr,nullptr,0,nullptr}"

PYTHON_C_WRAPPER_TEMPLATE = """
#pragma once
Expand Down Expand Up @@ -215,12 +288,12 @@ def GeneratePythonCFile(filepath, python_c_str):
python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str)

python_c_function_reg_list.append("{nullptr,nullptr,0,nullptr}")
python_c_functions_str = "\n".join(python_c_function_list)
python_c_functions_reg_str = ",\n".join(python_c_function_reg_list)

python_c_str = GeneratePythonCWrappers(python_c_functions_str,
python_c_functions_reg_str)

print("Generated Python-C Codes: ", python_c_str)

output_path = args.output_path
Expand Down
Loading

0 comments on commit ec8a0c1

Please sign in to comment.