[PIR] fix onednn layout transform yaml format #60680

Merged
@@ -276,25 +276,27 @@ OneDNNPhiKernelInstruction::OneDNNPhiKernelInstruction(
   VLOG(6) << "finish process no need buffer";
 
   // Step2: build layout_transform information
-  if (op_attributes.count("layout_transform_arg")) {
-    auto layout_transform_arg = op_attributes.at("layout_transform_arg")
-                                    .dyn_cast<pir::StrAttribute>()
-                                    .AsString();
-    auto data_layout = op_attributes.at(layout_transform_arg)
-                           .dyn_cast<pir::StrAttribute>()
-                           .AsString();
-    input_layout_ = common::StringToDataLayout(data_layout);
-    std::vector<pir::Attribute> layout_transform_inputs_attr =
+  if (op_attributes.count("data_format_tensors")) {
+    if (op_attributes.count("data_format")) {
+      auto data_layout = op_attributes.at("data_format")
+                             .dyn_cast<pir::StrAttribute>()
+                             .AsString();
+      input_layout_ = common::StringToDataLayout(data_layout);
+    } else {
+      input_layout_ = phi::OneDNNContext::tls().get_cur_paddle_data_layout();
+    }
+
+    std::vector<pir::Attribute> data_format_tensors_attr =
         op->attributes()
-            .at("layout_transform_inputs")
+            .at("data_format_tensors")
             .dyn_cast<pir::ArrayAttribute>()
             .AsVector();
-    std::vector<std::string> layout_transform_inputs;
-    for (auto& attr : layout_transform_inputs_attr) {
+
+    for (auto& attr : data_format_tensors_attr) {
       auto pair = kernel_context_.InputRangeAt(value_exec_info_->GetIdByName(
           attr.dyn_cast<pir::StrAttribute>().AsString()));
       for (int i = pair.first; i < pair.second; ++i) {
-        layout_transform_inputs_.insert(i);
+        data_format_tensors_.insert(i);
       }
     }
   }
@@ -333,7 +335,7 @@ void OneDNNPhiKernelInstruction::Run() {
 
     // Handle 'layout_transform' in
     // ops_onednn_extra.yaml(GetKernelTypeForVar)
-    if (layout_transform_inputs_.count(i) &&
+    if (data_format_tensors_.count(i) &&
         input_layout_ != phi::DataLayout::kAnyLayout) {
       from_layout = input_layout_;
     }
@@ -68,7 +68,7 @@ class OneDNNPhiKernelInstruction : public InstructionBase {
 
   const ValueExecutionInfo* value_exec_info_;  // not owned
 
-  std::set<int> layout_transform_inputs_{};
+  std::set<int> data_format_tensors_{};
   phi::DataLayout input_layout_{phi::DataLayout::kAnyLayout};
   std::map<std::string, phi::Attribute> extra_attr_{};
   std::map<std::string, std::vector<std::string>> inputs_{};
@@ -202,25 +202,27 @@ OneDNNLegacyKernelInstruction::OneDNNLegacyKernelInstruction(
   VLOG(6) << "finish process no need buffer";
 
   // Step2: build layout_transform information
-  if (op_attributes.count("layout_transform_arg")) {
-    auto layout_transform_arg = op_attributes.at("layout_transform_arg")
-                                    .dyn_cast<pir::StrAttribute>()
-                                    .AsString();
-    auto data_layout = op_attributes.at(layout_transform_arg)
-                           .dyn_cast<pir::StrAttribute>()
-                           .AsString();
-    input_layout_ = common::StringToDataLayout(data_layout);
-    std::vector<pir::Attribute> layout_transform_inputs_attr =
+  if (op_attributes.count("data_format_tensors")) {
+    if (op_attributes.count("data_format")) {
+      auto data_layout = op_attributes.at("data_format")
+                             .dyn_cast<pir::StrAttribute>()
+                             .AsString();
+      input_layout_ = common::StringToDataLayout(data_layout);
+    } else {
+      input_layout_ = phi::OneDNNContext::tls().get_cur_paddle_data_layout();
+    }
+
+    std::vector<pir::Attribute> data_format_tensors_attr =
         op->attributes()
-            .at("layout_transform_inputs")
+            .at("data_format_tensors")
             .dyn_cast<pir::ArrayAttribute>()
             .AsVector();
-    std::vector<std::string> layout_transform_inputs;
+
     auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();
     std::string fluid_op_name = yaml_info_parser.GetOriginOpName();
-    for (auto& attr : layout_transform_inputs_attr) {
+    for (auto& attr : data_format_tensors_attr) {
       auto input_name = attr.dyn_cast<pir::StrAttribute>().AsString();
-      layout_transform_inputs_.insert(
+      data_format_tensors_.insert(
           op_normalizer.GetLegacyArgName(fluid_op_name, input_name));
     }
   }
@@ -249,7 +251,7 @@ void OneDNNLegacyKernelInstruction::Run() {
 
     // Handle 'layout_transform' in
     // ops_onednn_extra.yaml(GetKernelTypeForVar)
-    if (layout_transform_inputs_.count(*input_name) &&
+    if (data_format_tensors_.count(*input_name) &&
         input_layout_ != phi::DataLayout::kAnyLayout) {
       from_layout = input_layout_;
     }
@@ -67,7 +67,7 @@ class OneDNNLegacyKernelInstruction : public InstructionBase {
 
   const ValueExecutionInfo* value_exec_info_;  // not owned
 
-  std::set<std::string> layout_transform_inputs_{};
+  std::set<std::string> data_format_tensors_{};
   phi::DataLayout input_layout_{phi::DataLayout::kAnyLayout};
 };
 
41 changes: 19 additions & 22 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -30,7 +30,7 @@
 from op_kerneltype_gen import gen_kernel_type_for_var_str
 from op_member_func_gen import gen_op_get_inputs_outputs_str
 from op_verify_gen import gen_verify_func_str
-from ops_onednn_extra_parser import parse_extra_args, parse_layout_transform
+from ops_onednn_extra_parser import parse_data_format_tensors, parse_extra_args
 from parse_kernel_key_gen import gen_parse_kernel_key_str
 from vjp_interface_black_list import vjp_interface_black_list
 
@@ -233,7 +233,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
   std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
   std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
   std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
-  paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}, {{{extra_args}}}, "{layout_transform_arg}", {{{layout_transform_inputs}}}, {is_onednn_only}, {dynamic_fallback});
+  paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}, {{{extra_args}}}, {{{data_format_tensors}}}, {is_onednn_only}, {dynamic_fallback});
   return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
 }}
 """
@@ -490,12 +490,14 @@ def __init__(self, op_yaml_item, op_compat_item):
         # OneDNN info
         if "extra_args" in self.op_yaml_item:
             self.onednn_extra_args = self.op_yaml_item["extra_args"]
-            self.onednn_layout_transform = self.op_yaml_item["layout_transform"]
+            self.onednn_data_format_tensors = self.op_yaml_item[
+                "data_format_tensors"
+            ]
             self.is_onednn_only = self.op_yaml_item["is_onednn_only"]
             self.dynamic_fallback = self.op_yaml_item["dynamic_fallback"]
         else:
             self.onednn_extra_args = []
-            self.onednn_layout_transform = None
+            self.onednn_data_format_tensors = None
             self.is_onednn_only = False
             self.dynamic_fallback = False
 
@@ -1616,18 +1618,12 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
                     extra_args = '"' + '", "'.join(args_name) + '"'
                 else:
                     extra_args = ""
-                if op_info.onednn_layout_transform is None:
-                    layout_transform_arg, layout_transform_inputs = (
-                        "",
-                        "",
-                    )
+                if op_info.onednn_data_format_tensors is None:
+                    data_format_tensors = ""
                 else:
-                    (
-                        layout_transform_arg,
-                        layout_transform_inputs,
-                    ) = op_info.onednn_layout_transform
-                    layout_transform_inputs = (
-                        '"' + '", "'.join(layout_transform_inputs) + '"'
+                    data_format_tensors = op_info.onednn_data_format_tensors
+                    data_format_tensors = (
+                        '"' + '", "'.join(data_format_tensors) + '"'
                     )
 
                 op_info_func_str = OP_INFO_ONEDNN_TEMPLATE.format(
@@ -1645,8 +1641,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
                     view=view_str,
                     origin_op_name=op_info.op_yaml_item['name'],
                     extra_args=extra_args,
-                    layout_transform_arg=layout_transform_arg,
-                    layout_transform_inputs=layout_transform_inputs,
+                    data_format_tensors=data_format_tensors,
                     is_onednn_only="true"
                     if op_info.is_onednn_only
                     else "false",
@@ -1864,12 +1859,12 @@ def OpGenerator(
             item = {}
             item["is_onednn_only"] = False
             item["extra_args"] = parse_extra_args(op_name, op['extra_args'])
-            if 'layout_transform' in op:
-                item["layout_transform"] = parse_layout_transform(
-                    op_name, op['layout_transform']
+            if 'data_format_tensors' in op:
+                item["data_format_tensors"] = parse_data_format_tensors(
+                    op_name, op['data_format_tensors']
                 )
             else:
-                item["layout_transform"] = None
+                item["data_format_tensors"] = None
             if 'dynamic_fallback' in op:
                 item["dynamic_fallback"] = op['dynamic_fallback']
             else:
@@ -1924,7 +1919,9 @@ def OpGenerator(
                     onednn_item = ops_onednn_extra_map[op['name']]
                     op["is_onednn_only"] = onednn_item["is_onednn_only"]
                     op["extra_args"] = onednn_item["extra_args"]
-                    op["layout_transform"] = onednn_item["layout_transform"]
+                    op["data_format_tensors"] = onednn_item[
+                        "data_format_tensors"
+                    ]
                     op["dynamic_fallback"] = onednn_item["dynamic_fallback"]
                     op["attrs"] = op["attrs"] + onednn_item["attrs"]
                 else:
12 changes: 5 additions & 7 deletions paddle/fluid/pir/dialect/op_generator/ops_onednn_extra_parser.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import re
-from typing import Any, Dict, List, Tuple
+from typing import Dict, List, Tuple
 
 
 def parse_plain_list(s: str, sep=",") -> List[str]:
@@ -76,11 +76,9 @@ def parse_extra_args(op_name: str, arguments: str) -> List:
     return attrs
 
 
-def parse_layout_transform(
-    op_name: str, layout_transform: Dict[str, Any]
+def parse_data_format_tensors(
+    op_name: str, data_format_tensors: str
 ) -> Tuple[str, List]:
-    if layout_transform is None:
+    if data_format_tensors is None:
         return "", []
-    return layout_transform["arg_name"], parse_plain_list(
-        layout_transform["tensors"]
-    )
+    return parse_plain_list(data_format_tensors)
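
As a sanity check on the reworked parser, here is a self-contained sketch; the body of parse_plain_list is not part of this diff, so its comma-split-and-strip behavior is an assumption:

    # Minimal re-implementation for illustration; the real parse_plain_list
    # lives earlier in ops_onednn_extra_parser.py (body not shown in this diff).
    def parse_plain_list(s: str, sep=",") -> list:
        return [item.strip() for item in s.split(sep)]

    def parse_data_format_tensors(op_name: str, data_format_tensors: str):
        if data_format_tensors is None:
            return "", []
        return parse_plain_list(data_format_tensors)

    print(parse_data_format_tensors("conv2d_grad", "input, out_grad"))
    # expected output: ['input', 'out_grad']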
16 changes: 4 additions & 12 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
@@ -1,27 +1,19 @@
 
 - op : conv2d
   extra_args : bool is_test=false
-  layout_transform :
-    arg_name: data_format
-    tensors: input
+  data_format_tensors : input
 
 - op : conv2d_grad
   extra_args : bool is_test=false
-  layout_transform :
-    arg_name: data_format
-    tensors: input, out_grad
+  data_format_tensors : input, out_grad
 
 - op : lrn
   extra_args : bool is_test=false
-  layout_transform :
-    arg_name: data_format
-    tensors: x
+  data_format_tensors : x
 
 - op : lrn_grad
   extra_args : bool is_test=false
-  layout_transform :
-    arg_name: data_format
-    tensors: x, out, mid_out, out_grad
+  data_format_tensors : x, out, mid_out, out_grad
 
 # - op : matmul
 #   extra_args : str mkldnn_data_type="float32"
9 changes: 3 additions & 6 deletions paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h
@@ -94,8 +94,7 @@ struct OpRunTimeInfo {
   std::vector<std::pair<std::string, std::string>> inplace;
   std::vector<std::pair<std::string, std::string>> view;
   std::vector<std::string> extra_args;
-  std::string layout_transform_arg;
-  std::vector<std::string> layout_transform_inputs;
+  std::vector<std::string> data_format_tensors;
   bool is_onednn_only;
   bool dynamic_fallback;
 
@@ -108,8 +107,7 @@ struct OpRunTimeInfo {
       const std::vector<std::pair<std::string, std::string>>& inplace,
       const std::vector<std::pair<std::string, std::string>>& view,
       const std::vector<std::string>& extra_args = {},
-      const std::string& layout_transform_arg = "",
-      const std::vector<std::string>& layout_transform_inputs = {},
+      const std::vector<std::string>& data_format_tensors = {},
       bool is_onednn_only = false,
       bool dynamic_fallback = false)
       : infer_meta_func(infer_meta_func),
@@ -121,8 +119,7 @@
         inplace(inplace),
         view(view),
         extra_args(extra_args),
-        layout_transform_arg(layout_transform_arg),
-        layout_transform_inputs(layout_transform_inputs),
+        data_format_tensors(data_format_tensors),
         is_onednn_only(is_onednn_only),
         dynamic_fallback(dynamic_fallback) {}
 };
15 changes: 5 additions & 10 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -2086,18 +2086,13 @@ pir::Operation* BuildKernelOp(
       op_attribute.emplace(
           "extra_args",
           pir::ArrayAttribute::get(pir::IrContext::Instance(), extra_args));
-      op_attribute.emplace(
-          "layout_transform_arg",
-          pir::StrAttribute::get(
-              ctx, op_info_parser->OpRuntimeInfo().layout_transform_arg));
-      std::vector<pir::Attribute> layout_transform_inputs;
-      for (auto& input :
-           op_info_parser->OpRuntimeInfo().layout_transform_inputs) {
-        layout_transform_inputs.push_back(pir::StrAttribute::get(ctx, input));
+      std::vector<pir::Attribute> data_format_tensors;
+      for (auto& input : op_info_parser->OpRuntimeInfo().data_format_tensors) {
+        data_format_tensors.push_back(pir::StrAttribute::get(ctx, input));
       }
-      op_attribute.emplace("layout_transform_inputs",
+      op_attribute.emplace("data_format_tensors",
                            pir::ArrayAttribute::get(pir::IrContext::Instance(),
-                                                    layout_transform_inputs));
+                                                    data_format_tensors));
       op_attribute.emplace(
           "is_onednn_only",
           pir::BoolAttribute::get(