Skip to content

Commit

Permalink
generate map of extra attrs for ops (#44106)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg committed Jul 6, 2022
1 parent 07b68eb commit 24d07b7
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ paddle/fluid/API_PR.spec
paddle/fluid/eager/api/generated/*
paddle/fluid/op_use_default_grad_maker_DEV.spec
paddle/fluid/op_use_default_grad_maker_PR.spec
paddle/fluid/operators/ops_extra_info.h
paddle/phi/api/backward/backward_api.h
paddle/phi/api/backward/sparse_bw_api.h
paddle/phi/api/include/api.h
Expand Down
15 changes: 15 additions & 0 deletions paddle/phi/api/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ set(wrapped_infermeta_header_file
set(wrapped_infermeta_source_file
${CMAKE_SOURCE_DIR}/paddle/phi/infermeta/generated.cc)

# op extra info file
set(ops_extra_info_gen_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/ops_extra_info_gen.py)
set(api_compat_yaml_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/api_compat.yaml)
set(ops_extra_info_file
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.h)

if(NOT PYTHONINTERP_FOUND)
find_package(PythonInterp REQUIRED)
endif()
Expand Down Expand Up @@ -211,6 +219,13 @@ else()
message("remove ${generated_argument_mapping_path}")
endif()

# generate ops extra info
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --api_compat_yaml_path
${api_compat_yaml_file} --ops_extra_info_path ${ops_extra_info_file})
message("generate ${ops_extra_info_file}")

# generate forward api
add_custom_command(
OUTPUT ${api_header_file} ${api_source_file}
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/api_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,12 @@
x : Input
outputs :
out : Out

- api : conv2d
extra :
attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = 512, bool exhaustive_search = false]
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/generator/generate_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path,
api_args_map = yaml.safe_load(f)
# replace args name for OpMaker
for api_args in api_args_map:
if api_args['api'] not in forward_api_dict:
continue
forward_api_item = forward_api_dict[api_args['api']]
has_backward = True if forward_api_item['backward'] else False
if has_backward:
Expand Down
110 changes: 110 additions & 0 deletions paddle/phi/api/yaml/generator/ops_extra_info_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import re
import argparse


def map_code_template(attrs_str):
return f"""
#include "paddle/fluid/framework/attribute.h"
namespace paddle {{
const static std::unordered_map<std::string, paddle::framework::AttributeMap> extra_attrs_map = {{
{attrs_str}
}};
}} // namespace paddle
"""


ATTR_TYPE_STRING_MAP = {
'bool': 'bool',
'int': 'int',
'int64_t': 'int64_t',
'float': 'float',
'double': 'double',
'str': 'std::string',
'int[]': 'std::vector<int>',
'int64_t[]': 'std::vector<int64_t>',
'float[]': 'std::vector<float>',
'double[]': 'std::vector<double>',
'str[]': 'std::vector<std::string>'
}


def parse_attr(attr_str):
result = re.search(
r"(?P<attr_type>[a-z[\]]+)\s+(?P<name>[a-zA-Z0-9_]+)\s*=\s*(?P<default_val>\S+)",
attr_str)
return ATTR_TYPE_STRING_MAP[result.group('attr_type')], result.group(
'name'), result.group('default_val')


def generate_extra_info(api_compat_yaml_path, ops_extra_info_path):
compat_apis = []
with open(api_compat_yaml_path, 'rt') as f:
compat_apis = yaml.safe_load(f)

extra_map_str_list = []

for api_compat_args in compat_apis:
if 'extra' in api_compat_args:
extra_args_map = api_compat_args['extra']
# TODO(chenweihang): add inputs and outputs
if 'attrs' in extra_args_map:
attr_map_list = []
for attr in extra_args_map['attrs']:
attr_type, attr_name, default_val = parse_attr(attr)
if attr_type.startswith("std::vector"):
attr_map_list.append(
f"{{\"{attr_name}\", {attr_type}{default_val}}}")
else:
attr_map_list.append(
f"{{\"{attr_name}\", {attr_type}{{{default_val}}}}}"
)
api_extra_attr_map = ", ".join(attr_map_list)
extra_map_str_list.append(
f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_map} }}}}"
)

ops_extra_info_file = open(ops_extra_info_path, 'w')
ops_extra_info_file.write(map_code_template(",\n".join(extra_map_str_list)))
ops_extra_info_file.close()


def main():
parser = argparse.ArgumentParser(
description='Generate PaddlePaddle Extra Param Info for Op')
parser.add_argument('--api_compat_yaml_path',
help='path to api compat yaml file',
default='paddle/phi/api/yaml/api_compat.yaml')

parser.add_argument('--ops_extra_info_path',
help='output of generated extra_prama_info code file',
default='paddle/fluid/operators/ops_extra_info.h')

options = parser.parse_args()

api_compat_yaml_path = options.api_compat_yaml_path
ops_extra_info_path = options.ops_extra_info_path

generate_extra_info(api_compat_yaml_path, ops_extra_info_path)


if __name__ == '__main__':
main()

0 comments on commit 24d07b7

Please sign in to comment.