Skip to content

Commit

Permalink
add _get_phi_kernel_name interface (#47033)
Browse files Browse the repository at this point in the history
  • Loading branch information
JZZ-NOTE committed Oct 20, 2022
1 parent c74bf01 commit 4c92524
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 4 deletions.
7 changes: 7 additions & 0 deletions paddle/fluid/pybind/inference_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/phi/core/compat/convert_utils.h"

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/phi/core/cuda_stream.h"
Expand Down Expand Up @@ -401,6 +402,12 @@ void BindInferenceApi(py::module *m) {
new paddle_infer::Predictor(config));
return pred;
});
m->def(
"_get_phi_kernel_name",
[](const std::string &fluid_op_name) {
return phi::TransToPhiKernelName(fluid_op_name);
},
py::return_value_policy::reference);
m->def("copy_tensor", &CopyPaddleInferTensor);
m->def("paddle_dtype_size", &paddle::PaddleDtypeSize);
m->def("paddle_tensor_to_bytes", &SerializePDTensorToBytes);
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def to_list(s):
from .libpaddle import _get_current_stream
from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent
from .libpaddle import _set_current_stream
from .libpaddle import _get_phi_kernel_name
if sys.platform != 'win32':
from .libpaddle import _set_process_pids
from .libpaddle import _erase_process_pids
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
from .wrapper import Config, DataType, PlaceType, PrecisionType, Tensor, Predictor
from .wrapper import convert_to_mixed_precision

from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version
from ..core import create_predictor, get_version, _get_phi_kernel_name, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version
7 changes: 4 additions & 3 deletions python/paddle/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..fluid.inference import Tensor # noqa: F401
from ..fluid.inference import Predictor # noqa: F401
from ..fluid.inference import create_predictor # noqa: F401
from ..fluid.inference import _get_phi_kernel_name
from ..fluid.inference import get_version # noqa: F401
from ..fluid.inference import get_trt_compile_version # noqa: F401
from ..fluid.inference import get_trt_runtime_version # noqa: F401
Expand All @@ -28,7 +29,7 @@

__all__ = [ # noqa
'Config', 'DataType', 'PlaceType', 'PrecisionType', 'Tensor', 'Predictor',
'create_predictor', 'get_version', 'get_trt_compile_version',
'convert_to_mixed_precision', 'get_trt_runtime_version',
'get_num_bytes_of_data_type', 'PredictorPool'
'create_predictor', 'get_version', '_get_phi_kernel_name',
'get_trt_compile_version', 'convert_to_mixed_precision',
'get_trt_runtime_version', 'get_num_bytes_of_data_type', 'PredictorPool'
]

0 comments on commit 4c92524

Please sign in to comment.