From 3ab9aef12edcffe63577e064705a963ea1dd3204 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 26 Jan 2022 21:07:03 +0800 Subject: [PATCH] [pten] remove deprecated fluid op kernel for pten (#38842) * update cmake file to remove fluid kernel * add pten declaration.h to where pybind.h used * fix sync_bn and tensorrt_engine * refine detection_library * fix interpreter_core * support eager legacy * fit eager legacy for pten * fall back to cpu if not found kernel * fix compile problem * fix compile problem * refine fallback logic * fit operator.run() * fix xpu compile * fit for new_exec * add REGISTER_OP_WITHOUT_GRADIENT * un-cache pt_kernel_context * fix compile * fix cudnn * fix compiling with on_infer * fix mkldnn * fix isfinite_v2 * fix xpu problem * fix op_device * refine fallback for xpu * fix xpu compile * merge develop * refine code format * fix compile * fix compile * add data_transfer * fix PreparePtenData * fix cpu context * merge develop * fix compile * fix error device context * fix xpu * fix dev_ctx --- cmake/operators.cmake | 262 +++++++++------- .../auto_code_generator/eager_generator.cc | 6 +- paddle/fluid/eager/legacy/op_runner.cc | 1 + .../fluid/eager/legacy/prepared_operator.cc | 111 ++++++- paddle/fluid/eager/legacy/prepared_operator.h | 14 + paddle/fluid/framework/CMakeLists.txt | 13 +- paddle/fluid/framework/async_executor.cc | 3 + .../fluid/framework/executor_thread_worker.cc | 3 + .../framework/new_executor/CMakeLists.txt | 2 +- .../framework/new_executor/interpretercore.cc | 2 - .../new_executor/interpretercore_util.cc | 69 +++-- .../new_executor/new_executor_defs.cc | 2 +- paddle/fluid/framework/operator.cc | 104 +++++-- paddle/fluid/framework/operator.h | 41 ++- paddle/fluid/framework/pten_utils.cc | 34 ++ paddle/fluid/framework/pten_utils.h | 9 + paddle/fluid/imperative/prepared_operator.cc | 291 +++--------------- paddle/fluid/imperative/prepared_operator.h | 242 +++++++++++++++ paddle/fluid/inference/io.cc | 3 + paddle/fluid/operators/CMakeLists.txt | 4 - paddle/fluid/operators/benchmark/op_tester.cc | 3 + paddle/fluid/operators/cast_op.cc | 38 ++- paddle/fluid/operators/cast_op.cu | 2 - .../operators/controlflow/CMakeLists.txt | 6 - .../fluid/operators/detection/CMakeLists.txt | 15 +- paddle/fluid/operators/fused/CMakeLists.txt | 15 - paddle/fluid/operators/isfinite_v2_op.cc | 98 ++++-- paddle/fluid/operators/isfinite_v2_op.cu | 49 ++- paddle/fluid/operators/nccl/CMakeLists.txt | 1 - .../fluid/operators/reduce_ops/CMakeLists.txt | 18 -- paddle/fluid/operators/sign_op.cc | 12 +- paddle/fluid/operators/sign_op_xpu.cc | 3 +- .../fluid/operators/tensorrt/CMakeLists.txt | 1 - .../fluid/platform/device/xpu/CMakeLists.txt | 3 +- .../pybind/eager_op_function_generator.cc | 3 + paddle/fluid/pybind/op_function_generator.cc | 5 +- paddle/pten/core/kernel_factory.h | 15 +- paddle/pten/core/type_defs.h | 11 + paddle/pten/kernels/CMakeLists.txt | 1 + 39 files changed, 945 insertions(+), 570 deletions(-) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 2d4aa1a815fff..d7742c3473724 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -1,6 +1,29 @@ # CMake file `unity_build` is used to handle Unity Build compilation. include(unity_build) set(PART_CUDA_KERNEL_FILES) + +function(find_register FILENAME PATTERN OUTPUT) +# find the op_name of REGISTER_OPERATOR(op_name, ...), REGISTER_OP_CPU_KERNEL(op_name, ...) , etc. +# set op_name to OUTPUT + set(options "") + set(oneValueArgs "") + set(multiValueArgs "") + file(READ ${FILENAME} CONTENT) + # message ("number of arguments sent to function: ${ARGC}") + # message ("all function arguments: ${ARGV}") + # message("PATTERN ${PATTERN}") + string(REGEX MATCH "${PATTERN}\\([ \t\r\n]*[a-z0-9_]*," register "${CONTENT}") + if (NOT register STREQUAL "") + string(REPLACE "${PATTERN}(" "" register "${register}") + string(REPLACE "," "" register "${register}") + # [ \t\r\n]+ is used for blank characters. + # Here we use '+' instead of '*' since it is a REPLACE operation. + string(REGEX REPLACE "[ \t\r\n]+" "" register "${register}") + endif() + + set(${OUTPUT} ${register} PARENT_SCOPE) +endfunction() + function(op_library TARGET) # op_library is a function to create op library. The interface is same as # cc_library. But it handle split GPU/CPU code and link some common library @@ -119,16 +142,16 @@ function(op_library TARGET) list(APPEND miopen_cu_cc_srcs ${src}) elseif(WITH_ROCM AND ${src} MATCHES ".*\\.cu.cc$") list(APPEND hip_cc_srcs ${src}) - elseif(${src} MATCHES ".*_cudnn_op.cu$") + elseif(WITH_GPU AND ${src} MATCHES ".*_cudnn_op.cu$") list(APPEND cudnn_cu_srcs ${src}) - elseif (${src} MATCHES ".*\\.cu$") + elseif (WITH_GPU AND ${src} MATCHES ".*\\.cu$") list(APPEND cu_srcs ${src}) - elseif(${src} MATCHES ".*_cudnn_op.cu.cc$") + elseif(WITH_GPU AND ${src} MATCHES ".*_cudnn_op.cu.cc$") list(APPEND cudnn_cu_cc_srcs ${src}) + elseif(WITH_GPU AND ${src} MATCHES ".*\\.cu.cc$") + list(APPEND cu_cc_srcs ${src}) elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$") list(APPEND mkldnn_cc_srcs ${src}) - elseif(${src} MATCHES ".*\\.cu.cc$") - list(APPEND cu_cc_srcs ${src}) elseif(WITH_XPU AND ${src} MATCHES ".*_op_xpu.cc$") list(APPEND xpu_cc_srcs ${src}) elseif(WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$") @@ -228,135 +251,136 @@ function(op_library TARGET) endif() endif() + list(LENGTH cu_srcs cu_srcs_len) + list(LENGTH hip_srcs hip_srcs_len) + list(LENGTH cu_cc_srcs cu_cc_srcs_len) + list(LENGTH hip_cc_srcs hip_cc_srcs_len) + list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len) + list(LENGTH xpu_cc_srcs xpu_cc_srcs_len) + list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len) + list(LENGTH npu_cc_srcs npu_cc_srcs_len) + list(LENGTH mlu_cc_srcs mlu_cc_srcs_len) + # Define operators that don't need pybind here. foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op" -"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" -"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" -"sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op" -"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op" -"fused_bn_add_activation_op" "fused_attention_op" "resnet_unit_op" "fused_feedforward_op") - - if ("${TARGET}" STREQUAL "${manual_pybind_op}") - set(pybind_flag 1) - endif() - endforeach() + "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op") + + if ("${TARGET}" STREQUAL "${manual_pybind_op}") + set(pybind_flag 1) + endif() + endforeach() # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h. # Note that it's enough to just adding one operator to pybind in a *_op.cc file. # And for detail pybind information, please see generated paddle/pybind/pybind.h. set(ORIGINAL_TARGET ${TARGET}) - file(READ ${TARGET}.cc TARGET_CONTENT) - string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}") - # [ \t\r\n]* is used for blank characters - string(REGEX MATCH "REGISTER_OPERATOR\\([ \t\r\n]*[a-z0-9_]*," one_register "${multi_register}") + string(REGEX REPLACE "_op" "" TARGET "${TARGET}") - if (one_register STREQUAL "") - string(REPLACE "_op" "" TARGET "${TARGET}") - else () - string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}") - string(REPLACE "," "" TARGET "${TARGET}") - # [ \t\r\n]+ is used for blank characters. - # Here we use '+' instead of '*' since it is a REPLACE operation. - string(REGEX REPLACE "[ \t\r\n]+" "" TARGET "${TARGET}") - endif() + foreach(cc_src ${cc_srcs}) + # pybind USE_OP_ITSELF + set(op_name "") + find_register(${cc_src} "REGISTER_OPERATOR" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n") + # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in mkldnn + set(TARGET ${op_name}) + set(pybind_flag 1) + endif() + + set(op_name "") + find_register(${cc_src} "REGISTER_OP_WITHOUT_GRADIENT" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n") + # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in mkldnn + set(TARGET ${op_name}) + set(pybind_flag 1) + endif() - # pybind USE_NO_KERNEL_OP - # HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel - string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}") - string(REPLACE "_op" "" TARGET "${TARGET}") - if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "") - file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n") - set(pybind_flag 1) - endif() + # pybind USE_OP_DEVICE_KERNEL for CPU + set(op_name "") + find_register(${cc_src} "REGISTER_OP_CPU_KERNEL" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CPU);\n") + # why change TARGET here? + # when building padle with on_infer, the REGISTER_OPERATOR(*_grad) will be removed before compiling (see details in remove_grad_op_and_kernel.py) + # in elementwise_op.cc, it will find REGISTER_OPERATOR(grad_add) and set TARGET to grad_add + # and, in the following "mkldnn" part, it will add USE_OP_DEVICE_KERNEL(grad_add, MKLDNN) to pybind.h + # however, grad_add has no mkldnn kernel. + set(TARGET ${op_name}) + set(pybind_flag 1) + endif() + endforeach() - # pybind USE_CPU_ONLY_OP - list(LENGTH cu_srcs cu_srcs_len) - list(LENGTH hip_srcs hip_srcs_len) - list(LENGTH cu_cc_srcs cu_cc_srcs_len) - list(LENGTH hip_cc_srcs hip_cc_srcs_len) - list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len) - list(LENGTH xpu_cc_srcs xpu_cc_srcs_len) - list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len) - list(LENGTH npu_cc_srcs npu_cc_srcs_len) - list(LENGTH mlu_cc_srcs mlu_cc_srcs_len) - if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND - ${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND - ${npu_cc_srcs_len} EQUAL 0 AND ${mlu_cc_srcs_len} EQUAL 0) - file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") - set(pybind_flag 1) - endif() + # pybind USE_OP_DEVICE_KERNEL for CUDA + list (APPEND cu_srcs ${cu_cc_srcs}) + # message("cu_srcs ${cu_srcs}") + foreach(cu_src ${cu_srcs}) + set(op_name "") + find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n") + set(pybind_flag 1) + endif() + endforeach() - # pybind USE_OP_DEVICE_KERNEL for CUDNN - list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len) - if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0) - if(${TARGET} STREQUAL "activation") - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n") - else() - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") - endif() - endif() - # pybind USE_OP_DEVICE_KERNEL for MIOPEN - list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len) - if (WITH_ROCM AND ${miopen_cu_cc_srcs_len} GREATER 0) - if(${TARGET} STREQUAL "activation") + # pybind USE_OP_DEVICE_KERNEL for CUDNN/MIOPEN + list(APPEND cudnn_cu_srcs ${cudnn_cu_cc_srcs}) + list(APPEND cudnn_cu_srcs ${miopen_cu_cc_srcs}) + list(APPEND cudnn_cu_srcs ${miopen_cu_srcs}) + list(LENGTH cudnn_cu_srcs cudnn_cu_srcs_len) + #message("cudnn_cu_srcs ${cudnn_cu_srcs}") + if(${cudnn_cu_srcs_len} GREATER 0 AND ${ORIGINAL_TARGET} STREQUAL "activation_op") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n") - else() - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") - endif() + else() + foreach(cudnn_src ${cudnn_cu_srcs}) + set(op_name "") + find_register(${cudnn_src} "REGISTER_OP_KERNEL" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDNN);\n") + set(pybind_flag 1) + endif() + endforeach() endif() - # pybind USE_OP_DEVICE_KERNEL for CUDNN - list(LENGTH cudnn_cu_srcs cudnn_cu_srcs_len) - if (WITH_GPU AND ${cudnn_cu_srcs_len} GREATER 0) - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") - endif() - # pybind USE_OP_DEVICE_KERNEL for MIOPEN - list(LENGTH miopen_cu_srcs miopen_cu_srcs_len) - if (WITH_ROCM AND ${miopen_cu_srcs_len} GREATER 0) - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") + if (WITH_XPU AND ${xpu_cc_srcs_len} GREATER 0) + if(${ORIGINAL_TARGET} STREQUAL "activation_op") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, XPU);\n") + else() + foreach(xpu_src ${xpu_cc_srcs}) + set(op_name "") + find_register(${xpu_src} "REGISTER_OP_XPU_KERNEL" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, XPU);\n") + set(pybind_flag 1) + endif() + endforeach() endif() - - if (WITH_XPU AND ${pybind_flag} EQUAL 0 AND ${xpu_cc_srcs_len} GREATER 0) - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, XPU);\n") endif() + # pybind USE_OP_DEVICE_KERNEL for NPU if (WITH_ASCEND_CL AND ${npu_cc_srcs_len} GREATER 0) - file(READ ${ORIGINAL_TARGET}_npu.cc TARGET_NPU_CONTENT) - # It is different from the logic above, becareful - string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\(.*" multi_npu_register "${TARGET_NPU_CONTENT}") - # [ \t\r\n]* is used for blank characters - string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\([ \t\r\n]*[a-z0-9_]*," one_npu_register "${multi_npu_register}") - - if (one_npu_register STREQUAL "") - string(REPLACE "_op" "" NPU_TARGET "${TARGET}") - else () - string(REPLACE "REGISTER_OP_NPU_KERNEL(" "" NPU_TARGET "${one_npu_register}") - string(REPLACE "," "" NPU_TARGET "${NPU_TARGET}") - # [ \t\r\n]+ is used for blank characters. - # Here we use '+' instead of '*' since it is a REPLACE operation. - string(REGEX REPLACE "[ \t\r\n]+" "" NPU_TARGET "${NPU_TARGET}") + foreach(npu_src ${npu_cc_srcs}) + set(op_name "") + find_register(${npu_src} "REGISTER_OP_NPU_KERNEL" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, NPU);\n") + set(pybind_flag 1) endif() - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${NPU_TARGET}, NPU);\n") + endforeach() endif() - if (WITH_MLU AND ${mlu_cc_srcs_len} GREATER 0) - file(READ ${ORIGINAL_TARGET}_mlu.cc TARGET_MLU_CONTENT) - # It is different from the logic above, becareful - string(REGEX MATCH "REGISTER_OP_MLU_KERNEL\\(.*" multi_mlu_register "${TARGET_MLU_CONTENT}") - # [ \t\r\n]* is used for blank characters - string(REGEX MATCH "REGISTER_OP_MLU_KERNEL\\([ \t\r\n]*[a-z0-9_]*," one_mlu_register "${multi_mlu_register}") - if (one_mlu_register STREQUAL "") - string(REPLACE "_op" "" MLU_TARGET "${TARGET}") - else () - string(REPLACE "REGISTER_OP_MLU_KERNEL(" "" MLU_TARGET "${one_mlu_register}") - string(REPLACE "," "" MLU_TARGET "${MLU_TARGET}") - # [ \t\r\n]+ is used for blank characters. - # Here we use '+' instead of '*' since it is a REPLACE operation. - string(REGEX REPLACE "[ \t\r\n]+" "" MLU_TARGET "${MLU_TARGET}") + # pybind USE_OP_DEVICE_KERNEL for MLU + if (WITH_MLU AND ${mlu_cc_srcs_len} GREATER 0) + foreach(mlu_src ${mlu_cc_srcs}) + set(op_name "") + find_register(${mlu_src} "REGISTER_OP_MLU_KERNEL" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, MLU);\n") + set(pybind_flag 1) endif() - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${MLU_TARGET}, MLU);\n") + endforeach() endif() # pybind USE_OP_DEVICE_KERNEL for MKLDNN @@ -377,10 +401,26 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, S8);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, U8);\n") else() - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") + foreach(mkldnn_src ${mkldnn_cc_srcs}) + set(op_name "") + find_register(${mkldnn_src} "REGISTER_OP_KERNEL" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, MKLDNN);\n") + set(pybind_flag 1) + endif() + endforeach() endif() endif() + # pybind USE_NO_KERNEL_OP + # HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel + string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}") + string(REPLACE "_op" "" TARGET "${TARGET}") + if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "") + file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n") + set(pybind_flag 1) + endif() + # pybind USE_OP if (${pybind_flag} EQUAL 0) # NOTE(*): activation use macro to regist the kernels, set use_op manually. diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index b79b69356b3ac..d0a5ad13f74ea 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -27,6 +27,9 @@ #include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/string/string_helper.h" +// pten +#include "paddle/pten/kernels/declarations.h" + #define NUM_CREATED_DUP_INPUTS 4 namespace paddle { @@ -535,7 +538,8 @@ static bool CheckOpProto(proto::OpProto* op_proto) { // Skip ooerator which is not inherit form OperatorWithKernel, like while, // since only OperatorWithKernel can run in dygraph mode. auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels(); - if (!all_kernels.count(op_type)) { + if (!all_kernels.count(op_type) && + !pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) { return false; } diff --git a/paddle/fluid/eager/legacy/op_runner.cc b/paddle/fluid/eager/legacy/op_runner.cc index 305d66d134c36..4f88346dab9c5 100644 --- a/paddle/fluid/eager/legacy/op_runner.cc +++ b/paddle/fluid/eager/legacy/op_runner.cc @@ -93,6 +93,7 @@ void OpRunImpl(const paddle::framework::OperatorBase& op, prepared_op.Run(*tmp_ins_ptr, outs, attrs, default_attrs); } + VLOG(6) << "Run Prepared Op end"; // TODO(jiabin): Set the output var's grad Forward DataType } diff --git a/paddle/fluid/eager/legacy/prepared_operator.cc b/paddle/fluid/eager/legacy/prepared_operator.cc index 3179b96807119..fcdf4162685c5 100644 --- a/paddle/fluid/eager/legacy/prepared_operator.cc +++ b/paddle/fluid/eager/legacy/prepared_operator.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/eager/legacy/prepared_operator.h" +#include "paddle/fluid/imperative/prepared_operator.h" #include "paddle/fluid/eager/legacy/infer_shape_context.h" #include "paddle/fluid/framework/data_type_transform.h" @@ -71,6 +72,21 @@ PreparedOp::PreparedOp( func_(func), dev_ctx_(dev_ctx) {} +PreparedOp::PreparedOp( + const paddle::framework::OperatorBase& op, + const paddle::framework::RuntimeContext& ctx, + const paddle::framework::OpKernelType& kernel_type, + const paddle::framework::KernelSignature& kernel_signature, + const pten::Kernel& pt_kernel, paddle::platform::DeviceContext* dev_ctx) + : op_(op), + ctx_(ctx), + kernel_type_(kernel_type), + func_(nullptr), + dev_ctx_(dev_ctx), + run_pten_kernel_(true), + pt_kernel_signature_(kernel_signature), + pt_kernel_(pt_kernel) {} + PreparedOp PrepareImpl(const NameTensorMap& ins, const NameTensorMap& outs, const paddle::framework::OperatorWithKernel& op, const paddle::platform::Place& place, @@ -104,17 +120,71 @@ PreparedOp PrepareImpl(const NameTensorMap& ins, const NameTensorMap& outs, auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; + // fit for pten + pten::KernelSignature pt_kernel_signature; + pten::KernelKey pt_kernel_key; + std::string pt_kernel_name; + if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) { + pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx); + VLOG(6) << pt_kernel_signature; + + pt_kernel_name = pt_kernel_signature.name; + pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key); + auto pt_kernel = pten::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_kernel_key); + + if (pt_kernel.IsValid()) { + VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name + << " | kernel key: " << pt_kernel_key + << " | kernel: " << pt_kernel; + + // TODO(chenweihang): using CPUKernel when miss device kernel case + return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, + pt_kernel, dev_ctx); + } else { + VLOG(6) << "Dynamic mode ChoosePtenKernel - kernel `" << pt_kernel_name + << "` not found."; + } + } + // 2. check if op[type] has kernel registered. auto& all_op_kernels = op.AllOpKernels(); auto kernels_iter = all_op_kernels.find(op.Type()); + + if (kernels_iter == all_op_kernels.end() || + kernels_iter->second.find(expected_kernel_key) == + kernels_iter->second.end() +#ifdef PADDLE_WITH_XPU + || + paddle::platform::is_xpu_place(expected_kernel_key.place_) && + !paddle::platform::is_xpu_support_op(op.Type(), + expected_kernel_key) || + paddle::platform::is_in_xpu_black_list(op.Type()) +#endif + ) { + if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) { + auto pt_cpu_kernel_key = + FallBackToCpu(expected_kernel_key, pt_kernel_key, op); + auto pt_cpu_kernel = pten::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_cpu_kernel_key); + if (pt_cpu_kernel.IsValid()) { + VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name + << " | kernel key: " << pt_cpu_kernel_key + << " | kernel: " << pt_cpu_kernel; + return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, + pt_cpu_kernel, dev_ctx); + } + } + } + PADDLE_ENFORCE_NE( kernels_iter, all_op_kernels.end(), paddle::platform::errors::NotFound( "There are no kernels which are registered in the %s operator.", op.Type())); - auto& kernels = kernels_iter->second; auto kernel_iter = kernels.find(expected_kernel_key); + #ifdef PADDLE_WITH_XPU if (paddle::platform::is_xpu_place(expected_kernel_key.place_) && (kernel_iter == kernels.end() || @@ -202,11 +272,46 @@ static void PreparedOpRunImpl( VLOG(6) << "Finish Runing Prepared Op"; } +static void PreparedOpRunPtImpl( + const paddle::framework::OperatorBase& op, + const paddle::framework::OpKernelType& kernel_type, + const paddle::framework::KernelSignature& pt_kernel_signature, + const pten::Kernel& pt_kernel, paddle::platform::DeviceContext* dev_ctx, + const NameTensorMap& ins, const NameTensorMap& outs, + const paddle::framework::AttributeMap& attrs, + const paddle::framework::AttributeMap& default_attrs) { + EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs, + op.Type()); + static_cast(op).InferShape( + &infer_shape_ctx); + + paddle::imperative::PreparePtenData( + pt_kernel, pt_kernel_signature, + static_cast(ins)); + + pten::KernelContext pt_kernel_context; + paddle::imperative::BuildDygraphPtenKernelContext( + pt_kernel_signature, pt_kernel, + static_cast(ins), + static_cast(outs), attrs, + default_attrs, dev_ctx, &pt_kernel_context); + + pt_kernel(&pt_kernel_context); + + // TODO(chenweihang): add debug flags later + // TODO(chenweihang): deal with complex cases later +} + void PreparedOp::Run(const NameTensorMap& ins, const NameTensorMap& outs, const paddle::framework::AttributeMap& attrs, const paddle::framework::AttributeMap& default_attrs) { - PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, attrs, - default_attrs); + if (run_pten_kernel_) { + PreparedOpRunPtImpl(op_, kernel_type_, pt_kernel_signature_, pt_kernel_, + dev_ctx_, ins, outs, attrs, default_attrs); + } else { + PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, + attrs, default_attrs); + } } std::shared_ptr PrepareData( diff --git a/paddle/fluid/eager/legacy/prepared_operator.h b/paddle/fluid/eager/legacy/prepared_operator.h index 7c448a7629646..c0cb56d99dc1b 100644 --- a/paddle/fluid/eager/legacy/prepared_operator.h +++ b/paddle/fluid/eager/legacy/prepared_operator.h @@ -55,6 +55,13 @@ class PreparedOp { const paddle::framework::OperatorWithKernel::OpKernelFunc& func, paddle::platform::DeviceContext* dev_ctx); + PreparedOp(const paddle::framework::OperatorBase& op, + const paddle::framework::RuntimeContext& ctx, + const paddle::framework::OpKernelType& kernel_type, + const paddle::framework::KernelSignature& kernel_signature, + const pten::Kernel& pt_kernel, + paddle::platform::DeviceContext* dev_ctx); + static PreparedOp Prepare( const NameTensorMap& ins, const NameTensorMap& outs, const paddle::framework::OperatorWithKernel& op, @@ -76,6 +83,13 @@ class PreparedOp { paddle::framework::OpKernelType kernel_type_; paddle::framework::OperatorWithKernel::OpKernelFunc func_; paddle::platform::DeviceContext* dev_ctx_; + + // NOTE(chenweihang): Similar op members are used to adapt to + // new pten kernel, if there is a better design in the future, + // we may polish the implementation here + bool run_pten_kernel_{false}; + paddle::framework::KernelSignature pt_kernel_signature_; + pten::Kernel pt_kernel_; }; } // namespace legacy diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index ce63a58d41ae0..bf6393544f715 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -185,10 +185,17 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co cc_test(no_need_buffer_vars_inference_test SRCS no_need_buffer_vars_inference_test.cc DEPS no_need_buffer_vars_inference layer) cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context) -cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place) cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_vars_inference) +cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place) + +IF(WITH_XPU) +cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info xpu_op_list) +ELSE() +cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info) +ENDIF() + IF(WITH_XPU) cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils @@ -403,8 +410,8 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer) cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer) cc_library(generator SRCS generator.cc DEPS enforce place) -cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info) -cc_library(infershape_utils SRCS infershape_utils.cc DEPS pten_utils attribute shape_inference op_utils) +cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference) + # Get the current working branch execute_process( diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc index 4c7ef2e600bc1..bb252c4fbdf2f 100644 --- a/paddle/fluid/framework/async_executor.cc +++ b/paddle/fluid/framework/async_executor.cc @@ -33,6 +33,9 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/pybind/pybind.h" +// pten +#include "paddle/pten/kernels/declarations.h" + namespace paddle { namespace framework { AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place) diff --git a/paddle/fluid/framework/executor_thread_worker.cc b/paddle/fluid/framework/executor_thread_worker.cc index b3fab80444a3f..668a28f6008ef 100644 --- a/paddle/fluid/framework/executor_thread_worker.cc +++ b/paddle/fluid/framework/executor_thread_worker.cc @@ -32,6 +32,9 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/timer.h" #include "paddle/fluid/pybind/pybind.h" + +// pten +#include "paddle/pten/kernels/declarations.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index e268bce87acf1..1c2e11a18eeb2 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -31,7 +31,7 @@ endif() # cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) # skip win32 since wget is not installed by default on windows machine. # skip COVERAGE_CI since the test runs slowly because of instrumentation. -if (WITH_TESTING AND NOT WIN32 AND NOT WITH_COVERAGE AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON") +if (WITH_CUDA AND WITH_TESTING AND NOT WIN32 AND NOT WITH_COVERAGE AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON") add_custom_target( download_program COMMAND wget -nc https://paddle-ci.gz.bcebos.com/new_exec/lm_main_program diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index f71a5b2c710ce..ef9c5b9213492 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -426,8 +426,6 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { (*instr_node.PtenKernel())(&pt_kernel_context); - op_with_kernel->WriteBackToOutputs( - instr_node.InnerRuntimeContext().get(), &pt_kernel_context); } else { instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); } diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 0371b12d009f3..fb0951e87aa16 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -358,15 +358,6 @@ void build_op_func_list(const platform::Place& place, op_with_kernel->Info().infer_shape_(&infer_shape_ctx); } - auto kernels_iter = all_op_kernels.find(op->Type()); - PADDLE_ENFORCE_NE( - kernels_iter, all_op_kernels.end(), - platform::errors::Unavailable( - "There are no kernels which are registered in the %s operator.", - op->Type())); - - OpKernelMap& kernels = kernels_iter->second; - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); @@ -404,26 +395,41 @@ void build_op_func_list(const platform::Place& place, dev_ctx = pool.Get(expected_kernel_key.place_); } op_func_node.dev_ctx_ = dev_ctx; - + VLOG(3) << op_with_kernel->Type() + << " : expected_kernel_key : " << expected_kernel_key; auto exec_ctx = ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context); - auto kernel_iter = kernels.find(expected_kernel_key); - PADDLE_ENFORCE_NE( - kernel_iter, kernels.end(), - platform::errors::NotFound( - "Operator (%s) does not have kernel for %s.", op->Type(), - KernelTypeToString(expected_kernel_key))); - auto run_pten_kernel = false; - - if (FLAGS_run_pten_kernel && - pten::KernelFactory::Instance().HasCompatiblePtenKernel( + if (pten::KernelFactory::Instance().HasCompatiblePtenKernel( op_with_kernel->Type())) { - op_with_kernel->ChoosePtenKernel(exec_ctx); - run_pten_kernel = op_with_kernel->PtenKernel()->IsValid(); - } + auto pt_kernel_key = op_with_kernel->ChoosePtenKernel(exec_ctx); + auto pt_kernel_name = op_with_kernel->PtenKernelSignature()->name; + if (op_with_kernel->PtenKernel()->IsValid()) { + run_pten_kernel = true; + } else { + auto kernels_iter = all_op_kernels.find(op_with_kernel->Type()); + if (kernels_iter == all_op_kernels.end() || + kernels_iter->second.find(expected_kernel_key) == + kernels_iter->second.end()) { + auto pt_cpu_kernel_key = FallBackToCpu( + expected_kernel_key, pt_kernel_key, *op_with_kernel); + op_with_kernel->ResetPtenKernel( + new pten::Kernel(pten::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_cpu_kernel_key))); + if (op_with_kernel->PtenKernel()->IsValid()) { + VLOG(6) << "Static mode PrepareImpl - kernel name: " + << pt_kernel_name + << " | kernel key: " << pt_cpu_kernel_key + << " | kernel: " << *(op_with_kernel->PtenKernel()); + run_pten_kernel = true; + } + } + } + } + VLOG(3) << op_with_kernel->Type() + << " : expected_kernel_key : " << expected_kernel_key; if (run_pten_kernel) { pten::KernelContext pt_kernel_context; op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx, @@ -431,9 +437,22 @@ void build_op_func_list(const platform::Place& place, op_func_node.pt_kernel_ = op_with_kernel->PtenKernel(); (*op_func_node.pt_kernel_)(&pt_kernel_context); - op_with_kernel->WriteBackToOutputs(&runtime_context, - &pt_kernel_context); } else { + auto kernels_iter = all_op_kernels.find(op->Type()); + PADDLE_ENFORCE_NE( + kernels_iter, all_op_kernels.end(), + platform::errors::Unavailable( + "There are no kernels which are registered in the %s operator.", + op->Type())); + OpKernelMap& kernels = kernels_iter->second; + + auto kernel_iter = kernels.find(expected_kernel_key); + PADDLE_ENFORCE_NE( + kernel_iter, kernels.end(), + platform::errors::NotFound( + "Operator (%s) does not have kernel for %s.", op->Type(), + KernelTypeToString(expected_kernel_key))); + // TODO(zhiqiu): add fallback logic op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); op_func_node.kernel_func_(exec_ctx); } diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index 6c5e98489ef5a..c72cbda008f3b 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -313,7 +313,7 @@ bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const { return ((op_with_kernel.kernel_type()) && (op_with_kernel.kernel_type()->data_layout_ == framework::DataLayout::kMKLDNN)); - } catch (std::bad_cast exp) { + } catch (std::bad_cast& exp) { return false; } } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 087a817d03af1..426b5ac8ffdde 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/op_call_stack.h" +#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/unused_var_check.h" @@ -1144,22 +1145,80 @@ void OperatorWithKernel::RunImpl(const Scope& scope, #endif auto exe_ctx = ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx); + // using cache + if (kernel_type_.get()) { + dev_ctx = pool.Get(kernel_type_->place_); + } // TODO(chenweihang): Now we are still reusing a lot of the original fluid // implementation, this is a gradual replacement process // TODO(chenweihang): in the first phase of project, we only support CPU, CUDA // and RCOM backend, the XPU, NPU and MKLDNN will be supported in the second // phase - if (FLAGS_run_pten_kernel && - pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) { + pten::KernelKey pt_kernel_key; + std::string pt_kernel_name; + if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) { if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) { - ChoosePtenKernel(exe_ctx); + pt_kernel_signature_.reset(new KernelSignature( + std::move(this->GetExpectedPtenKernelArgs(exe_ctx)))); + VLOG(6) << *pt_kernel_signature_.get(); + + kernel_type_.reset( + new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx)))); + dev_ctx = pool.Get(kernel_type_->place_); + + pt_kernel_name = pt_kernel_signature_->name; + pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get()); + pt_kernel_.reset( + new pten::Kernel(pten::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_kernel_key))); + + if (pt_kernel_->IsValid()) { + VLOG(6) << "Static mode ChoosePtenKernel - kernel name: " + << pt_kernel_name << " | kernel key: " << pt_kernel_key + << " | kernel: " << *pt_kernel_; + } else { + VLOG(6) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name + << "` not found."; + } + } + if (pt_kernel_->IsValid()) { + run_pten_kernel_ = true; + } else { + auto& all_op_kernels = AllOpKernels(); + auto kernels_iter = all_op_kernels.find(type_); + if (kernels_iter == all_op_kernels.end() || + kernels_iter->second.find(*kernel_type_.get()) == + kernels_iter->second.end() +#ifdef PADDLE_WITH_XPU + || + paddle::platform::is_xpu_place(kernel_type_->place_) && // NOLINT + !paddle::platform::is_xpu_support_op( + type_, *kernel_type_.get()) // NOLINT + || paddle::platform::is_in_xpu_black_list(type_) +#endif + ) { + auto pt_cpu_kernel_key = + FallBackToCpu(*kernel_type_.get(), pt_kernel_key, *this); + pt_kernel_.reset( + new pten::Kernel(pten::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_cpu_kernel_key))); + + dev_ctx = pool.Get(platform::CPUPlace()); + + if (pt_kernel_->IsValid()) { + VLOG(6) << "Static mode PrepareImpl - kernel name: " << pt_kernel_name + << " | kernel key: " << pt_cpu_kernel_key + << " | kernel: " << *pt_kernel_; + run_pten_kernel_ = true; + } + } } - run_pten_kernel_ = pt_kernel_->IsValid(); } if (!run_pten_kernel_) { if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) { ChooseKernel(exe_ctx); + dev_ctx = pool.Get(kernel_type_->place_); } } @@ -1178,10 +1237,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, const Scope& exec_scope = (transfer_scope == nullptr ? scope : *transfer_scope); - if (!(kernel_type_->place_ == dev_ctx->GetPlace())) { - dev_ctx = pool.Get(kernel_type_->place_); - } - if (!all_kernels_must_compute_runtime_shape_) { platform::RecordEvent record_event("infer_shape", platform::EventRole::kInnerOp); @@ -1201,6 +1256,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, if (run_pten_kernel_) { pten::KernelContext pt_kernel_context; // Do data transform before building KernelContext + // TODO(zhiqiu): support TransferInplaceVarsBack PreparePtenData(exec_scope, *pt_kernel_, *pt_kernel_signature_, runtime_ctx); BuildPtenKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context); @@ -1289,7 +1345,8 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( return expected_kernel_key; } -void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const { +pten::KernelKey OperatorWithKernel::ChoosePtenKernel( + const ExecutionContext& ctx) const { pt_kernel_signature_.reset( new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx)))); VLOG(6) << *pt_kernel_signature_.get(); @@ -1311,6 +1368,7 @@ void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const { VLOG(6) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name << "` not found."; } + return pt_kernel_key; } void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { @@ -1839,25 +1897,21 @@ Scope* OperatorWithKernel::PreparePtenData( continue; } - // TODO(zyfncg): Now there is no kernel which need to transform input - // data, so we commented out following code temporarily, - // and it will be used in the future. + VLOG(3) << "PTen Transform Variable " << input_names[i] << " from " + << tensor_in->place() << " to " << expected_place; - // VLOG(3) << "PTen Transform Variable " << input_names[i] << " from " - // << tensor_in->place() << " to " << expected_place; - - // if (!new_scope) { - // new_scope = &scope.NewScope(); - // } + if (!new_scope) { + new_scope = &scope.NewScope(); + } - // // Create new var with the same name in transfer scopes - // auto* trans_var = new_scope->Var(input_names[i]); - // ins_vector[i] = trans_var; + // Create new var with the same name in transfer scopes + auto* trans_var = new_scope->Var(input_names[i]); + ins_vector[offset] = trans_var; - // // Do transfer - // Tensor out; - // framework::TensorCopySync(*tensor_in, expected_place, &out); - // SetTensorToVariable(*var, out, trans_var); + // Do transfer + Tensor out; + framework::TensorCopySync(*tensor_in, expected_place, &out); + SetTensorToVariable(*var, out, trans_var); } } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index c280eeaa0fa57..9ad13299a3773 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -525,13 +525,27 @@ class OperatorWithKernel : public OperatorBase { } bool SupportGPU() const override { - auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); - return std::any_of(op_kernels.begin(), op_kernels.end(), - [](OpKernelMap::const_reference kern_pair) { - return platform::is_gpu_place(kern_pair.first.place_); - }); + auto pten_kernels = pten::KernelFactory::Instance().SelectKernelMap( + pten::TransToPtenKernelName(type_)); + auto has_pten_kernel = std::any_of( + pten_kernels.begin(), pten_kernels.end(), + [](pten::KernelFactory::KernelKeyMap::const_reference kern_pair) { + return kern_pair.first.backend() == pten::Backend::GPU; + }); + if (has_pten_kernel) { + return true; + } else { + auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); + return std::any_of( + op_kernels.begin(), op_kernels.end(), + [](OpKernelMap::const_reference kern_pair) { + return platform::is_gpu_place(kern_pair.first.place_); + }); + } } + bool SupportNPU() const override { + // TODO(zhiqiu): support pten if needed? auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); return std::any_of(op_kernels.begin(), op_kernels.end(), [](OpKernelMap::const_reference kern_pair) { @@ -539,6 +553,7 @@ class OperatorWithKernel : public OperatorBase { }); } bool SupportMLU() const override { + // TODO(zhiqiu): support pten if needed? auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); return std::any_of(op_kernels.begin(), op_kernels.end(), [](OpKernelMap::const_reference kern_pair) { @@ -583,18 +598,18 @@ class OperatorWithKernel : public OperatorBase { * When selecting Kernel during Op execution, select the arguments of the * original Op according to the GetExpectedPtenKernelArgs returned arguments. */ - virtual KernelSignature GetExpectedPtenKernelArgs( + virtual pten::KernelSignature GetExpectedPtenKernelArgs( const ExecutionContext& ctx) const; /* member functions for adapting to pten lib */ - void ChoosePtenKernel(const ExecutionContext& ctx) const; + pten::KernelKey ChoosePtenKernel(const ExecutionContext& ctx) const; /** * Transfer data place for pten kernel * Is this really needed? */ Scope* PreparePtenData(const Scope& scope, const pten::Kernel& pt_kernel, - const KernelSignature& pt_kernel_signature, + const pten::KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const; void BuildPtenKernelContext(const RuntimeContext& ctx, @@ -604,8 +619,16 @@ class OperatorWithKernel : public OperatorBase { void WriteBackToOutputs(RuntimeContext* ctx, pten::KernelContext* pt_kernel_context) const; + pten::KernelSignature* PtenKernelSignature() const { + return pt_kernel_signature_.get(); + } + pten::Kernel* PtenKernel() const { return pt_kernel_.get(); } + void ResetPtenKernel(pten::Kernel* kernel) const { + return pt_kernel_.reset(kernel); + } + const OpKernelType* kernel_type() const { return kernel_type_.get(); } private: @@ -662,7 +685,7 @@ class OperatorWithKernel : public OperatorBase { // new pten kernel, if there is a better design in the future, // we may polish the implementation here mutable bool run_pten_kernel_ = false; - mutable std::unique_ptr pt_kernel_signature_; + mutable std::unique_ptr pt_kernel_signature_; mutable std::unique_ptr pt_kernel_; }; diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc index dc20aaffec9ca..336f8423d6f0b 100644 --- a/paddle/fluid/framework/pten_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -90,6 +90,40 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey( return pten::KernelKey(backend, layout, dtype); } +pten::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, + const pten::KernelKey& kernel_key, + const framework::OperatorBase& op) { +#ifdef PADDLE_WITH_XPU + if (platform::is_xpu_place(expected_kernel_key.place_) || + paddle::platform::is_in_xpu_black_list(op.Type())) { + VLOG(3) << "pten missing XPU kernel: " << op.Type() + << ", expected_kernel_key:" << expected_kernel_key + << ", fallbacking to CPU one!"; + return pten::KernelKey(pten::Backend::CPU, kernel_key.layout(), + kernel_key.dtype()); + } +#endif +#ifdef PADDLE_WITH_ASCEND_CL + if (platform::is_npu_place(expected_kernel_key.place_)) { + VLOG(3) << "pten missing NPU kernel: " << op.Type() + << ", expected_kernel_key:" << expected_kernel_key + << ", fallbacking to CPU one!"; + return pten::KernelKey(pten::Backend::CPU, kernel_key.layout(), + kernel_key.dtype()); + } +#endif +#ifdef PADDLE_WITH_MLU + if (platform::is_mlu_place(expected_kernel_key.place_)) { + VLOG(3) << "pten missing MLU kernel: " << op.Type() + << ", expected_kernel_key:" << expected_kernel_key + << ", fallbacking to CPU one!"; + return pten::KernelKey(pten::Backend::CPU, kernel_key.layout(), + kernel_key.dtype()); + } +#endif + return pten::KernelKey(); +} + const paddle::SmallVector& KernelArgsNameMakerByOpProto::GetInputArgsNames() { for (int i = 0; i < op_proto_->inputs_size(); ++i) { diff --git a/paddle/fluid/framework/pten_utils.h b/paddle/fluid/framework/pten_utils.h index 9b1019f658237..2d335fc9c9894 100644 --- a/paddle/fluid/framework/pten_utils.h +++ b/paddle/fluid/framework/pten_utils.h @@ -24,12 +24,18 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/place.h" + +#include "paddle/fluid/framework/operator.h" #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/core/compat/arg_map_context.h" #include "paddle/pten/core/kernel_factory.h" #include "paddle/utils/flat_hash_map.h" #include "paddle/utils/small_vector.h" +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/platform/device/xpu/xpu_op_list.h" +#endif + namespace paddle { namespace framework { @@ -41,6 +47,9 @@ OpKernelType TransPtenKernelKeyToOpKernelType( const pten::KernelKey& kernel_key); pten::KernelKey TransOpKernelTypeToPtenKernelKey( const OpKernelType& kernel_type); +pten::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, + const pten::KernelKey& kernel_key, + const framework::OperatorBase& op); /* Kernel Args parse */ diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index d9a21c9247b93..5d6df145ab356 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -55,21 +55,6 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var) { } } -static const framework::Attribute& GetAttr( - const framework::AttributeMap& attrs, - const framework::AttributeMap& default_attrs, const std::string& name) { - auto it = attrs.find(name); - bool found = it != attrs.end(); - if (!found) { - it = default_attrs.find(name); - found = it != default_attrs.end(); - } - PADDLE_ENFORCE_EQ( - found, true, - platform::errors::NotFound("(%s) is not found in AttributeMap.", name)); - return it->second; -} - template static void HandleComplexGradToRealGrad(const NameVarMap& outs) { for (auto& pair : outs) { @@ -152,6 +137,9 @@ PreparedOp PrepareImpl(const NameVarMap& ins, } } #endif + // NOTE(zhiqiu): for kernels on given device, for example NPU, the order to + // choose is: + // pten npu kernel > fluid npu kernel > pten cpu kernel > fluid cpu kernel // 1. get expected kernel key auto dygraph_exe_ctx = DygraphExecutionContext( @@ -159,13 +147,15 @@ PreparedOp PrepareImpl(const NameVarMap& ins, auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; - if (FLAGS_run_pten_kernel && - pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) { - auto pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx); + framework::KernelSignature pt_kernel_signature; + pten::KernelKey pt_kernel_key; + std::string pt_kernel_name; + if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) { + pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx); VLOG(6) << pt_kernel_signature; - auto pt_kernel_name = pt_kernel_signature.name; - auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key); + pt_kernel_name = pt_kernel_signature.name; + pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key); auto pt_kernel = pten::KernelFactory::Instance().SelectKernel( pt_kernel_name, pt_kernel_key); @@ -191,14 +181,42 @@ PreparedOp PrepareImpl(const NameVarMap& ins, // 2. check if op[type] has kernel registered. auto& all_op_kernels = op.AllOpKernels(); auto kernels_iter = all_op_kernels.find(op.Type()); + + if ((kernels_iter == all_op_kernels.end() || + kernels_iter->second.find(expected_kernel_key) == + kernels_iter->second.end()) +#ifdef PADDLE_WITH_XPU + || + paddle::platform::is_xpu_place(expected_kernel_key.place_) && + !paddle::platform::is_xpu_support_op(op.Type(), + expected_kernel_key) || + paddle::platform::is_in_xpu_black_list(op.Type()) +#endif + ) { + if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) { + auto pt_cpu_kernel_key = + FallBackToCpu(expected_kernel_key, pt_kernel_key, op); + auto pt_cpu_kernel = pten::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_cpu_kernel_key); + if (pt_cpu_kernel.IsValid()) { + VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name + << " | kernel key: " << pt_cpu_kernel_key + << " | kernel: " << pt_cpu_kernel; + auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); + return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, + pt_cpu_kernel, cpu_ctx); + } + } + } + PADDLE_ENFORCE_NE( kernels_iter, all_op_kernels.end(), platform::errors::NotFound( "There are no kernels which are registered in the %s operator.", op.Type())); - auto& kernels = kernels_iter->second; auto kernel_iter = kernels.find(expected_kernel_key); + #ifdef PADDLE_WITH_XPU if (paddle::platform::is_xpu_place(expected_kernel_key.place_) && (kernel_iter == kernels.end() || @@ -264,237 +282,6 @@ PreparedOp PreparedOp::Prepare(const NameVarMap& ins, default_attrs); } -template -void PreparePtenData(const pten::Kernel& pt_kernel, - const framework::KernelSignature& pt_kernel_signature, - const NameVarMap& ins) { - auto& input_names = std::get<0>(pt_kernel_signature.args); - auto& input_defs = pt_kernel.args_def().input_defs(); - - PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), - platform::errors::InvalidArgument( - "the size of inputs_args names (%d) must be equal to " - "the size of kernel input_defs (%d).", - input_names.size(), input_defs.size())); - - for (size_t i = 0; i < input_names.size(); ++i) { - auto& in_def = input_defs.at(i); - auto& ins_vector = ins.at(input_names[i]); - - for (size_t offset = 0; offset < ins_vector.size(); ++offset) { - auto var_base = ins_vector[offset]; - const auto* tensor_in = GetTensorFromVar(var_base->Var()); - if (tensor_in && tensor_in->IsInitialized()) { - auto expected_place = pten::TransToFluidPlace(in_def.backend); - if (platform::is_same_place(tensor_in->place(), expected_place)) { - continue; - } - - // TODO(zyfncg): Now there is no kernel which need to transform input - // data, so we commented out following code temporarily, - // and it will be used in the future. - - // VLOG(3) << "Pten Transform Variable " << var_base->Name() << " from " - // << tensor_in->place() << " to " << expected_place; - - // framework::Tensor tmp_tensor; - // framework::TensorCopySync(*tensor_in, expected_place, &tmp_tensor); - - // SetTensorToVariable(var_base->Var(), tmp_tensor, - // var_base->MutableVar()); - } - } - } -} - -template -static void BuildDygraphPtenKernelContext( - const framework::KernelSignature& pt_kernel_signature, - const pten::Kernel& pt_kernel, const NameVarMap& ins, - const NameVarMap& outs, const framework::AttributeMap& attrs, - const framework::AttributeMap& default_attrs, - platform::DeviceContext* dev_ctx, pten::KernelContext* kernel_ctx) { - kernel_ctx->SetDeviceContext(dev_ctx); - - auto& input_names = std::get<0>(pt_kernel_signature.args); - auto& attr_names = std::get<1>(pt_kernel_signature.args); - auto& output_names = std::get<2>(pt_kernel_signature.args); - - auto& input_defs = pt_kernel.args_def().input_defs(); - auto& output_defs = pt_kernel.args_def().output_defs(); - auto& attr_defs = pt_kernel.args_def().attribute_defs(); - - PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), - platform::errors::InvalidArgument( - "the size of inputs_args names (%d) must be equal to " - "the size of kernel input_defs (%d).", - input_names.size(), input_defs.size())); - - PADDLE_ENFORCE_EQ(output_names.size(), output_defs.size(), - platform::errors::InvalidArgument( - "the size of outputs_args names (%d) must be equal to " - "the size of kernel output_defs (%d).", - output_names.size(), output_defs.size())); - - PADDLE_ENFORCE_EQ(attr_names.size(), attr_defs.size(), - platform::errors::InvalidArgument( - "the size of attribute_args names (%d) must be equal " - "to the size of kernel attribute_defs (%d).", - attr_names.size(), attr_defs.size())); - - for (size_t i = 0; i < input_names.size(); ++i) { - auto& ins_vector = ins.at(input_names[i]); - - size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second); - size_t end_idx = start_idx + ins_vector.size(); - - for (size_t offset = 0; offset < ins_vector.size(); ++offset) { - const auto* tensor_in = GetTensorFromVar(ins_vector[offset]->Var()); - kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in); - } - kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i); - } - - for (size_t i = 0; i < output_names.size(); ++i) { - size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second); - - auto iter = outs.find(output_names[i]); - if (iter == outs.end()) { - kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr}); - kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1), - i); - continue; - } - - auto& outs_vector = iter->second; - size_t end_idx = start_idx + outs_vector.size(); - - for (size_t offset = 0; offset < outs_vector.size(); ++offset) { - if (outs_vector[offset] == nullptr) { - kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr}); - continue; - } - auto* var = outs_vector[offset]->MutableVar(); - framework::Tensor* tensor_out = nullptr; - if (var->template IsType()) { - tensor_out = var->template GetMutable(); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported output `%s` type when call pt kernel.", - framework::ToTypeName(var->Type()))); - } // TODO(zyfncg): Add support for SelectedRows - - experimental::ResetTensorByArgDef(tensor_out, output_defs.at(i)); - framework::SetAllocationForOutputTenosr( - tensor_out, pten::TransToFluidPlace(output_defs.at(i).backend)); - - kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out); - } - kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i); - } - - for (size_t i = 0; i < attr_names.size(); ++i) { - if (attr_defs[i].type_index == std::type_index(typeid(pten::ScalarArray))) { - if (attrs.find(attr_names[i]) != - attrs.end()) { // shape is in the attribute - auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { - kernel_ctx->EmplaceBackAttr(std::move( - pten::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { - kernel_ctx->EmplaceBackAttr(std::move( - pten::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to VectorTensor when " - "construct KernelContext.", - attr_names[i])); - } - } else { // shape is in the input - auto& ins_vector = ins.at(attr_names[i]); - if (ins_vector.size() == 1) { // ShapeTensor - kernel_ctx->EmplaceBackAttr(std::move( - experimental::MakePtenScalarArrayFromVar(ins_vector[0]->Var()))); - } else { // ShapeTensorList - std::vector variables; - variables.reserve(ins_vector.size()); - for (const auto& var_base : ins_vector) { - variables.push_back(var_base->MutableVar()); - } - kernel_ctx->EmplaceBackAttr(std::move( - experimental::MakePtenScalarArrayFromVarList(variables))); - } - } - } else if (attr_defs[i].type_index == - std::type_index(typeid(pten::Scalar))) { - // TODO(chenweihang): support other attrs later - // TODO(zhangyunfei): Scalar should hold scaler type, and we should check - // attribtue type by attr_defs - if (attrs.find(attr_names[i]) != attrs.end() || - default_attrs.find(attr_names[i]) != - default_attrs.end()) { // scalar is in the attribute - auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (std::type_index(attr.type()) == std::type_index(typeid(float))) { - kernel_ctx->EmplaceBackAttr( - std::move(pten::Scalar(BOOST_GET_CONST(float, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::string))) { - kernel_ctx->EmplaceBackAttr( - std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(int))) { - kernel_ctx->EmplaceBackAttr( - std::move(pten::Scalar(BOOST_GET_CONST(int, attr)))); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to Scalar when construct " - "KernelContext in dygraph.", - attr_names[i])); - } - } else { // scalar is in the input - auto& ins_vector = ins.at(attr_names[i]); - kernel_ctx->EmplaceBackAttr(std::move( - experimental::MakePtenScalarFromVar(ins_vector[0]->Var()))); - } - - } else { - // TODO(chenweihang): support other attrs later - auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (attr_defs[i].type_index == std::type_index(typeid(int))) { - kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { - kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { - kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(pten::DataType))) { - auto data_type = pten::TransToPtenDataType( - static_cast( - BOOST_GET_CONST(int, attr))); - kernel_ctx->EmplaceBackAttr(data_type); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { - if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { - // Emplace Back Attr according to the type of Pten_Kernel args. - const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); - const std::vector vector_int64_attr(vector_int_attr.begin(), - vector_int_attr.end()); - kernel_ctx->EmplaceBackAttr(vector_int64_attr); - } - // TODO(YuanRisheng) Need support vector attr - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` when construct " - "KernelContext in dygraph.", - attr_names[i])); - } - } - } -} - template static void PreparedOpRunImpl( const framework::OperatorBase& op, const framework::RuntimeContext& ctx, diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 1a66fe0a05620..f9165e8ee23e3 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -194,5 +194,247 @@ class PreparedOp { pten::Kernel pt_kernel_; }; +const inline framework::Attribute& GetAttr( + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs, const std::string& name) { + auto it = attrs.find(name); + bool found = it != attrs.end(); + if (!found) { + it = default_attrs.find(name); + found = it != default_attrs.end(); + } + PADDLE_ENFORCE_EQ( + found, true, + platform::errors::NotFound("(%s) is not found in AttributeMap.", name)); + return it->second; +} + +template +void BuildDygraphPtenKernelContext( + const framework::KernelSignature& pt_kernel_signature, + const pten::Kernel& pt_kernel, const NameVarMap& ins, + const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs, + platform::DeviceContext* dev_ctx, pten::KernelContext* kernel_ctx) { + kernel_ctx->SetDeviceContext(dev_ctx); + + auto& input_names = std::get<0>(pt_kernel_signature.args); + auto& attr_names = std::get<1>(pt_kernel_signature.args); + auto& output_names = std::get<2>(pt_kernel_signature.args); + + auto& input_defs = pt_kernel.args_def().input_defs(); + auto& output_defs = pt_kernel.args_def().output_defs(); + auto& attr_defs = pt_kernel.args_def().attribute_defs(); + + PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), + platform::errors::InvalidArgument( + "the size of inputs_args names (%d) must be equal to " + "the size of kernel input_defs (%d).", + input_names.size(), input_defs.size())); + + PADDLE_ENFORCE_EQ(output_names.size(), output_defs.size(), + platform::errors::InvalidArgument( + "the size of outputs_args names (%d) must be equal to " + "the size of kernel output_defs (%d).", + output_names.size(), output_defs.size())); + + PADDLE_ENFORCE_EQ(attr_names.size(), attr_defs.size(), + platform::errors::InvalidArgument( + "the size of attribute_args names (%d) must be equal " + "to the size of kernel attribute_defs (%d).", + attr_names.size(), attr_defs.size())); + + for (size_t i = 0; i < input_names.size(); ++i) { + auto& ins_vector = ins.at(input_names[i]); + + size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second); + size_t end_idx = start_idx + ins_vector.size(); + + for (size_t offset = 0; offset < ins_vector.size(); ++offset) { + const auto* tensor_in = GetTensorFromVar(ins_vector[offset]->Var()); + kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in); + } + kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i); + } + + for (size_t i = 0; i < output_names.size(); ++i) { + size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second); + + auto iter = outs.find(output_names[i]); + if (iter == outs.end()) { + kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr}); + kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1), + i); + continue; + } + + auto& outs_vector = iter->second; + size_t end_idx = start_idx + outs_vector.size(); + + for (size_t offset = 0; offset < outs_vector.size(); ++offset) { + if (outs_vector[offset] == nullptr) { + kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr}); + continue; + } + auto* var = outs_vector[offset]->MutableVar(); + framework::Tensor* tensor_out = nullptr; + if (var->template IsType()) { + tensor_out = var->template GetMutable(); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported output `%s` type when call pt kernel.", + framework::ToTypeName(var->Type()))); + } // TODO(zyfncg): Add support for SelectedRows + + experimental::ResetTensorByArgDef(tensor_out, output_defs.at(i)); + framework::SetAllocationForOutputTenosr( + tensor_out, pten::TransToFluidPlace(output_defs.at(i).backend)); + + kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out); + } + kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i); + } + + for (size_t i = 0; i < attr_names.size(); ++i) { + if (attr_defs[i].type_index == std::type_index(typeid(pten::ScalarArray))) { + if (attrs.find(attr_names[i]) != + attrs.end()) { // shape is in the attribute + auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + kernel_ctx->EmplaceBackAttr(std::move( + pten::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + kernel_ctx->EmplaceBackAttr(std::move( + pten::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to VectorTensor when " + "construct KernelContext.", + attr_names[i])); + } + } else { // shape is in the input + auto& ins_vector = ins.at(attr_names[i]); + if (ins_vector.size() == 1) { // ShapeTensor + kernel_ctx->EmplaceBackAttr(std::move( + experimental::MakePtenScalarArrayFromVar(ins_vector[0]->Var()))); + } else { // ShapeTensorList + std::vector variables; + variables.reserve(ins_vector.size()); + for (const auto& var_base : ins_vector) { + variables.push_back(var_base->MutableVar()); + } + kernel_ctx->EmplaceBackAttr(std::move( + experimental::MakePtenScalarArrayFromVarList(variables))); + } + } + } else if (attr_defs[i].type_index == + std::type_index(typeid(pten::Scalar))) { + // TODO(chenweihang): support other attrs later + // TODO(zhangyunfei): Scalar should hold scaler type, and we should check + // attribtue type by attr_defs + if (attrs.find(attr_names[i]) != attrs.end() || + default_attrs.find(attr_names[i]) != + default_attrs.end()) { // scalar is in the attribute + auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); + if (std::type_index(attr.type()) == std::type_index(typeid(float))) { + kernel_ctx->EmplaceBackAttr( + std::move(pten::Scalar(BOOST_GET_CONST(float, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::string))) { + kernel_ctx->EmplaceBackAttr( + std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(int))) { + kernel_ctx->EmplaceBackAttr( + std::move(pten::Scalar(BOOST_GET_CONST(int, attr)))); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to Scalar when construct " + "KernelContext in dygraph.", + attr_names[i])); + } + } else { // scalar is in the input + auto& ins_vector = ins.at(attr_names[i]); + kernel_ctx->EmplaceBackAttr(std::move( + experimental::MakePtenScalarFromVar(ins_vector[0]->Var()))); + } + + } else { + // TODO(chenweihang): support other attrs later + auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); + if (attr_defs[i].type_index == std::type_index(typeid(int))) { + kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); + } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { + kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); + } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { + kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); + } else if (attr_defs[i].type_index == + std::type_index(typeid(pten::DataType))) { + auto data_type = pten::TransToPtenDataType( + static_cast( + BOOST_GET_CONST(int, attr))); + kernel_ctx->EmplaceBackAttr(data_type); + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + // Emplace Back Attr according to the type of Pten_Kernel args. + const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); + const std::vector vector_int64_attr(vector_int_attr.begin(), + vector_int_attr.end()); + kernel_ctx->EmplaceBackAttr(vector_int64_attr); + } + // TODO(YuanRisheng) Need support vector attr + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` when construct " + "KernelContext in dygraph.", + attr_names[i])); + } + } + } +} + +template +void PreparePtenData(const pten::Kernel& pt_kernel, + const framework::KernelSignature& pt_kernel_signature, + const NameVarMap& ins) { + auto& input_names = std::get<0>(pt_kernel_signature.args); + auto& input_defs = pt_kernel.args_def().input_defs(); + + PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), + platform::errors::InvalidArgument( + "the size of inputs_args names (%d) must be equal to " + "the size of kernel input_defs (%d).", + input_names.size(), input_defs.size())); + + for (size_t i = 0; i < input_names.size(); ++i) { + auto& in_def = input_defs.at(i); + auto& ins_vector = ins.at(input_names[i]); + + for (size_t offset = 0; offset < ins_vector.size(); ++offset) { + auto var_base = ins_vector[offset]; + const auto* tensor_in = GetTensorFromVar(var_base->Var()); + if (tensor_in && tensor_in->IsInitialized()) { + auto expected_place = pten::TransToFluidPlace(in_def.backend); + if (platform::is_same_place(tensor_in->place(), expected_place)) { + continue; + } + + VLOG(3) << "Pten Transform Variable " << input_names[i] << " from " + << tensor_in->place() << " to " << expected_place; + + framework::Tensor tmp_tensor; + framework::TensorCopySync(*tensor_in, expected_place, &tmp_tensor); + + SetTensorToVariable(var_base->Var(), tmp_tensor, + var_base->MutableVar()); + } + } + } +} + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index f976e217bab1a..73dc34b22dd89 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -26,6 +26,9 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/pybind/pybind.h" +// pten +#include "paddle/pten/kernels/declarations.h" + DEFINE_string(devices, "", "The devices to be used which is joined by comma."); DEFINE_int32(math_num_threads, 1, "Number of threads used to run math functions."); diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index cbc61fc804397..b87cdf6f6df19 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -115,10 +115,8 @@ if (WITH_GPU OR WITH_ROCM) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() op_library(sync_batch_norm_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n") if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT PADDLE_WITH_ARM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.3) ) op_library(sparse_attention_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sparse_attention);\n") endif() else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) @@ -142,7 +140,6 @@ endif() if (WITH_ASCEND_CL) op_library(sync_batch_norm_op) - file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(sync_batch_norm);\n") endif() op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) @@ -153,7 +150,6 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) if (WITH_DGC) op_library(dgc_op DEPS dgc) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(dgc);\n") set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc) endif() diff --git a/paddle/fluid/operators/benchmark/op_tester.cc b/paddle/fluid/operators/benchmark/op_tester.cc index ace2b656e8efb..7b90dddca069f 100644 --- a/paddle/fluid/operators/benchmark/op_tester.cc +++ b/paddle/fluid/operators/benchmark/op_tester.cc @@ -24,6 +24,9 @@ limitations under the License. */ #include "paddle/fluid/platform/timer.h" #include "paddle/fluid/pybind/pybind.h" +// pten +#include "paddle/pten/kernels/declarations.h" + namespace paddle { namespace operators { namespace benchmark { diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 90d665eb93bcb..9e2fe6e2d066e 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -133,23 +133,27 @@ class CastOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; -#define REGISTER_CAST_CPU_BASE(op_name, ...) \ - REGISTER_OPERATOR(op_name, ops::CastOp, \ - ops::CastOpGradMaker, \ - ops::CastOpGradMaker, \ - ops::CastOpProtoMaker); \ - REGISTER_OP_CPU_KERNEL( \ - op_name, ops::CastOpKernel, ops::CastOpKernel, \ - ops::CastOpKernel, ops::CastOpKernel, \ - ops::CastOpKernel, ops::CastOpKernel, \ - ops::CastOpKernel, ops::CastOpKernel, \ - ops::CastOpKernel, \ - ops::CastOpKernel, \ - ops::CastOpKernel>, \ - ops::CastOpKernel>); - -REGISTER_CAST_CPU_BASE(cast) + +// cast use pten kernel, so no need to REGISTER_OP_CPU_KERNEL here. +REGISTER_OPERATOR(cast, ops::CastOp, + ops::CastOpGradMaker, + ops::CastOpGradMaker, + ops::CastOpProtoMaker); + // [ why register transfer_dtype_op alias with cast_op? ] // In case of InterpreterCore, if we reuse cast_op, we cannot distinguish // which cast_op is inserted by new executor when we do profiling. -REGISTER_CAST_CPU_BASE(transfer_dtype) +REGISTER_OPERATOR(transfer_dtype, ops::CastOp, + ops::CastOpGradMaker, + ops::CastOpGradMaker, + ops::CastOpProtoMaker); +REGISTER_OP_CPU_KERNEL( + transfer_dtype, ops::CastOpKernel, + ops::CastOpKernel, ops::CastOpKernel, + ops::CastOpKernel, ops::CastOpKernel, + ops::CastOpKernel, ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel>, + ops::CastOpKernel>); diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 8ae68435ada10..5c7dd0e2561fa 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -30,10 +30,8 @@ using CUDA = paddle::platform::CUDADeviceContext; ops::CastOpKernel>, ##__VA_ARGS__); #if !defined(PADDLE_WITH_HIP) -REGISTER_CAST_CUDA_BASE(cast, ops::CastOpKernel) // See [ why register transfer_dtype_op alias with cast_op? ] in cast_op.cc REGISTER_CAST_CUDA_BASE(transfer_dtype, ops::CastOpKernel) #else -REGISTER_CAST_CUDA_BASE(cast) REGISTER_CAST_CUDA_BASE(transfer_dtype) #endif diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt index d2ad93bbae921..1a2df2a0c7ba3 100644 --- a/paddle/fluid/operators/controlflow/CMakeLists.txt +++ b/paddle/fluid/operators/controlflow/CMakeLists.txt @@ -22,9 +22,3 @@ endif() file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP(logical_and);\nUSE_OP(logical_or);\nUSE_OP(logical_xor);\nUSE_OP(logical_not);\n") file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n") - -if(WITH_XPU) - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(equal, XPU);\nUSE_OP_DEVICE_KERNEL(not_equal, XPU);\n") - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(less_than, XPU);\nUSE_OP_DEVICE_KERNEL(less_equal, XPU);\n") - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(greater_than, XPU);\nUSE_OP_DEVICE_KERNEL(greater_equal, XPU);\n") -endif() diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index a85bca3646499..1ebafa5459857 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -8,7 +8,20 @@ function(detection_library TARGET_NAME) set(pybind_flag 0) cmake_parse_arguments(detection_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - op_library(${TARGET_NAME} SRCS ${detection_library_SRCS} DEPS ${common_deps} ${detection_library_DEPS}) + set(srcs) + # filter cuda source file when not build with cuda/rocm + foreach(src ${detection_library_SRCS}) + if (NOT WITH_GPU AND NOT WITH_ROCM) + if(${src} MATCHES ".*\\.cc$") + list(APPEND srcs ${src}) + endif() + else() + list(APPEND srcs ${src}) + endif() + endforeach() + + op_library(${TARGET_NAME} SRCS ${srcs} DEPS ${common_deps} ${detection_library_DEPS}) + set(LOCAL_DETECTION_LIBS ${TARGET_NAME} ${LOCAL_DETECTION_LIBS} diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index eec925b2c057b..67287afa6ae50 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -24,8 +24,6 @@ register_operators(EXCLUDES # fusion_gru_op does not have CUDA kernel op_library(fusion_gru_op) op_library(fusion_lstm_op) -file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_gru);\n") -file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_lstm);\n") if (WITH_GPU OR WITH_ROCM) @@ -33,46 +31,36 @@ if (WITH_GPU OR WITH_ROCM) # HIP not support bn act fuse in MIOPEN if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7401)) op_library(fused_bn_activation_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_batch_norm_act);\n") endif() # conv_fusion_op needs cudnn 7 above if (NOT ${CUDNN_VERSION} VERSION_LESS 7100) op_library(conv_fusion_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n") endif() # fusion_transpose_flatten_concat_op # HIP not support cudnnTransformTensor if(NOT WITH_ROCM) op_library(fusion_transpose_flatten_concat_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n") endif() # fusion_conv_inception_op needs cudnn 7 above # HIP not support cudnnConvolutionBiasActivationForward if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7100)) op_library(fusion_conv_inception_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_inception_fusion);\n") endif() # fused_fc_elementwise_layernorm_op op_library(fused_fc_elementwise_layernorm_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_fc_elementwise_layernorm);\n") # multihead_matmul_op op_library(multihead_matmul_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(multihead_matmul);\n") op_library(skip_layernorm_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(skip_layernorm);\n") op_library(fused_embedding_eltwise_layernorm_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_embedding_eltwise_layernorm);\n") # fusion_group if(NOT APPLE AND NOT WIN32) op_library(fusion_group_op DEPS device_code) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_group);\n") cc_test(test_fusion_group_op SRCS fusion_group_op_test.cc DEPS fusion_group_op) endif() # fused_bn_add_activation # HIP not support bn act fuse in MIOPEN if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7401)) op_library(fused_bn_add_activation_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n") endif() # fused_dropout # only support CUDA @@ -82,15 +70,12 @@ if (WITH_GPU OR WITH_ROCM) nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) op_library(fused_feedforward_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_feedforward);\n") # fused_attention_op op_library(fused_attention_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n") endif() # resnet_unit needs cudnn 8.0 above if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000)) op_library(resnet_unit_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(resnet_unit);\n") cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory) cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory) endif() diff --git a/paddle/fluid/operators/isfinite_v2_op.cc b/paddle/fluid/operators/isfinite_v2_op.cc index 3b48a41ed4f75..05a394e682d67 100644 --- a/paddle/fluid/operators/isfinite_v2_op.cc +++ b/paddle/fluid/operators/isfinite_v2_op.cc @@ -103,39 +103,69 @@ element of X as a tensor. namespace ops = paddle::operators; -#define REGISTER_V2OP_MAKER(op_type, comment) \ - namespace paddle { \ - namespace operators { \ - class _##op_type##OverflowV2OpMaker \ - : public ::paddle::operators::OverflowV2OpMaker { \ - protected: \ - std::string GetName() const { return #op_type; } \ - std::string GetComments() const { return comment; } \ - }; \ - } \ - } \ - REGISTER_OPERATOR( \ - op_type, ops::OverflowV2Op, ops::_##op_type##OverflowV2OpMaker, \ - paddle::framework::EmptyGradOpMaker, \ - paddle::framework::EmptyGradOpMaker) - -#define REGISTER_OVERFLOW_CPU_KERNEL(op_type, functor) \ - REGISTER_OP_CPU_KERNEL( \ - op_type, ops::OverflowKernel, \ - ops::OverflowKernel, \ - ops::OverflowKernel, \ - ops::OverflowKernel, \ - ops::OverflowKernel); - -REGISTER_V2OP_MAKER(isinf_v2, "isinfv2(X)"); -REGISTER_V2OP_MAKER(isnan_v2, "isnanv2(X)"); +#define REGISTER_V2OP_MAKER(op_type, comment) \ + namespace paddle { \ + namespace operators { \ + class _##op_type##OverflowV2OpMaker \ + : public ::paddle::operators::OverflowV2OpMaker { \ + protected: \ + std::string GetName() const { return #op_type; } \ + std::string GetComments() const { return comment; } \ + }; \ + } \ + } + +REGISTER_V2OP_MAKER(isinf_v2, "isinfv2(X)") +REGISTER_V2OP_MAKER(isnan_v2, "isnanv2(X)") REGISTER_V2OP_MAKER(isfinite_v2, "isfinitev2(X)"); -REGISTER_OVERFLOW_CPU_KERNEL(isinf_v2, InfinityV2Functor); -REGISTER_OVERFLOW_CPU_KERNEL(isnan_v2, NANV2Functor); -REGISTER_OVERFLOW_CPU_KERNEL(isfinite_v2, IsfiniteV2Functor); +REGISTER_OPERATOR( + isinf_v2, ops::OverflowV2Op, ops::_isinf_v2OverflowV2OpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OPERATOR( + isnan_v2, ops::OverflowV2Op, ops::_isnan_v2OverflowV2OpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OPERATOR( + isfinite_v2, ops::OverflowV2Op, ops::_isfinite_v2OverflowV2OpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(isnan_v2, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel); + +REGISTER_OP_CPU_KERNEL( + isinf_v2, ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel); + +REGISTER_OP_CPU_KERNEL( + isfinite_v2, ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel); diff --git a/paddle/fluid/operators/isfinite_v2_op.cu b/paddle/fluid/operators/isfinite_v2_op.cu index 4a6d818d0501e..1b9f19d36dfa0 100644 --- a/paddle/fluid/operators/isfinite_v2_op.cu +++ b/paddle/fluid/operators/isfinite_v2_op.cu @@ -18,19 +18,38 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; -#define REGISTER_OVERFLOW_CUDA_KERNEL(op_type, functor) \ - REGISTER_OP_CUDA_KERNEL( \ - op_type, ops::OverflowKernel, \ - ops::OverflowKernel, \ - ops::OverflowKernel, \ - ops::OverflowKernel, \ - ops::OverflowKernel); +REGISTER_OP_CUDA_KERNEL(isnan_v2, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel); -REGISTER_OVERFLOW_CUDA_KERNEL(isinf_v2, InfinityV2Functor); -REGISTER_OVERFLOW_CUDA_KERNEL(isnan_v2, NANV2Functor); -REGISTER_OVERFLOW_CUDA_KERNEL(isfinite_v2, IsfiniteV2Functor); +REGISTER_OP_CUDA_KERNEL( + isinf_v2, ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel); + +REGISTER_OP_CUDA_KERNEL( + isfinite_v2, ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel, + ops::OverflowKernel); diff --git a/paddle/fluid/operators/nccl/CMakeLists.txt b/paddle/fluid/operators/nccl/CMakeLists.txt index 9a412228255d0..b3d53f0d39020 100644 --- a/paddle/fluid/operators/nccl/CMakeLists.txt +++ b/paddle/fluid/operators/nccl/CMakeLists.txt @@ -12,7 +12,6 @@ endif() if(WITH_GPU OR WITH_ROCM) op_library(nccl_op DEPS nccl_common) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n") set(OPERATOR_DEPS ${OPERATOR_DEPS} nccl_common PARENT_SCOPE) endif() diff --git a/paddle/fluid/operators/reduce_ops/CMakeLists.txt b/paddle/fluid/operators/reduce_ops/CMakeLists.txt index 846d362fb522d..9a2abfd93d066 100644 --- a/paddle/fluid/operators/reduce_ops/CMakeLists.txt +++ b/paddle/fluid/operators/reduce_ops/CMakeLists.txt @@ -13,24 +13,6 @@ else() register_operators() endif() -if(WITH_GPU OR WITH_ROCM) - file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.part.cu") - string(REPLACE ".part.cu" "" OPS "${OPS}") - - foreach(src ${OPS}) - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.part.cu) - set(CUDA_KERNEL_FILE ${CMAKE_CURRENT_SOURCE_DIR}/${src}.part.cu) - file(READ ${CUDA_KERNEL_FILE} TARGET_CONTENT) - string(REGEX MATCH "REGISTER_OP_CUDA_KERNEL\\(\\n?([^,]+),.*" MATCHED ${TARGET_CONTENT}) - if (MATCHED) - string(STRIP ${CMAKE_MATCH_1} MATCHED) - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${MATCHED}, CUDA);\n") - endif() - - endif() - endforeach() -endif() - if(WITH_GPU) if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) nv_test(check_reduce_rank_test SRCS check_reduce_rank_test.cu DEPS tensor cub) diff --git a/paddle/fluid/operators/sign_op.cc b/paddle/fluid/operators/sign_op.cc index f36124078054e..3fd2a5bc5e4c8 100644 --- a/paddle/fluid/operators/sign_op.cc +++ b/paddle/fluid/operators/sign_op.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sign_op.h" #include #include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/float16.h" #include "paddle/pten/core/infermeta_utils.h" #include "paddle/pten/infermeta/unary.h" @@ -65,13 +65,3 @@ REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker, ops::SignGradMaker, ops::SignGradMaker, SignInferShapeFunctor); -REGISTER_OP_CPU_KERNEL( - sign, ops::SignKernel, - ops::SignKernel); - -REGISTER_OP_CUDA_KERNEL( - sign, - paddle::operators::SignKernel, - paddle::operators::SignKernel, - paddle::operators::SignKernel); diff --git a/paddle/fluid/operators/sign_op_xpu.cc b/paddle/fluid/operators/sign_op_xpu.cc index 8b3beb2fb397b..22934cf482159 100644 --- a/paddle/fluid/operators/sign_op_xpu.cc +++ b/paddle/fluid/operators/sign_op_xpu.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef PADDLE_WITH_XPU - -#include "paddle/fluid/operators/sign_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/tensorrt/CMakeLists.txt b/paddle/fluid/operators/tensorrt/CMakeLists.txt index 0ab66f2fdceaf..a7f18245ab9e9 100644 --- a/paddle/fluid/operators/tensorrt/CMakeLists.txt +++ b/paddle/fluid/operators/tensorrt/CMakeLists.txt @@ -1,5 +1,4 @@ op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter infer_io_utils analysis_helper) -file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(tensorrt_engine);\n") nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc DEPS tensorrt_engine_op analysis) diff --git a/paddle/fluid/platform/device/xpu/CMakeLists.txt b/paddle/fluid/platform/device/xpu/CMakeLists.txt index d292ce130eb34..28573eb0c1e4c 100644 --- a/paddle/fluid/platform/device/xpu/CMakeLists.txt +++ b/paddle/fluid/platform/device/xpu/CMakeLists.txt @@ -4,7 +4,8 @@ endif() set(XPU_CTX_DEPS xpulib ssl crypto rt z resolv dl) + cc_library(xpu_info SRCS xpu_info.cc DEPS gflags glog enforce xpulib device_context place pten_xpu_info) -cc_library(xpu_op_list SRCS xpu_op_list.cc DEPS gflags glog enforce xpulib device_context) +cc_library(xpu_op_list SRCS xpu_op_list.cc DEPS gflags glog enforce xpulib device_context op_kernel_type) add_subdirectory(tests) diff --git a/paddle/fluid/pybind/eager_op_function_generator.cc b/paddle/fluid/pybind/eager_op_function_generator.cc index 090604ab4ee1a..b38c0eeaf96fb 100644 --- a/paddle/fluid/pybind/eager_op_function_generator.cc +++ b/paddle/fluid/pybind/eager_op_function_generator.cc @@ -32,6 +32,9 @@ #endif #include "paddle/fluid/pybind/op_function_generator.h" +// pten +#include "paddle/pten/kernels/declarations.h" + // clang-format off const char* OUT_INITIALIZER_TEMPLATE = R"({"%s", {std::shared_ptr(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})"; diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 957c0b0ee6d1d..508ff493a937d 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -32,6 +32,9 @@ #include "paddle/fluid/framework/fleet/ascend_wrapper.h" #endif +// pten +#include "paddle/pten/kernels/declarations.h" + // NOTE(pangyoki): Inplace OP with duplicable input. // The set includes inplace ops that have duplicable input. // The first Varbase in input needs to be specified for the inplace strategy @@ -395,7 +398,7 @@ GenerateOpFunctions() { continue; } auto& op_type = op_proto->type(); - // Skip ooerator which is not inherit form OperatorWithKernel, like while, + // Skip operator which is not inherit form OperatorWithKernel, like while, // since only OperatorWithKernel can run in dygraph mode. // if the pten lib contains op kernel, we still generate ops method if (!all_kernels.count(op_type) && diff --git a/paddle/pten/core/kernel_factory.h b/paddle/pten/core/kernel_factory.h index 8a100451cd4a8..25e3439a6408d 100644 --- a/paddle/pten/core/kernel_factory.h +++ b/paddle/pten/core/kernel_factory.h @@ -208,14 +208,14 @@ class Kernel { */ class KernelFactory { public: - // replaced by paddle::flat_hash_map later - using KernelMap = paddle::flat_hash_map< - std::string, - paddle::flat_hash_map>; + using KernelKeyMap = + paddle::flat_hash_map; + + using KernelNameMap = paddle::flat_hash_map; static KernelFactory& Instance(); - KernelMap& kernels() { return kernels_; } + KernelNameMap& kernels() { return kernels_; } bool HasCompatiblePtenKernel(const std::string& op_type) const { return kernels_.find(TransToPtenKernelName(op_type)) != kernels_.end(); @@ -232,13 +232,12 @@ class KernelFactory { Kernel SelectKernel(const std::string& kernel_name, const KernelKey& kernel_key) const; - paddle::flat_hash_map SelectKernelMap( - const std::string& kernel_name) const; + KernelKeyMap SelectKernelMap(const std::string& kernel_name) const; private: KernelFactory() = default; - KernelMap kernels_; + KernelNameMap kernels_; }; inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) { diff --git a/paddle/pten/core/type_defs.h b/paddle/pten/core/type_defs.h index 13e7bb51c2e1b..eb5459b1b6ea7 100644 --- a/paddle/pten/core/type_defs.h +++ b/paddle/pten/core/type_defs.h @@ -23,6 +23,9 @@ limitations under the License. */ #include +namespace egr { +class EagerTensor; +} namespace paddle { namespace framework { // The order should be as same as framework.proto @@ -71,6 +74,13 @@ template <> struct NameVarMapTrait { using Type = std::map; }; + +template <> +struct NameVarMapTrait { + using Type = + std::map>>; +}; + } // namespace details template @@ -78,6 +88,7 @@ using NameVarMap = typename details::NameVarMapTrait::Type; using NameVarBaseMap = NameVarMap; using NameVariableWrapperMap = NameVarMap; +using NameTensorMap = NameVarMap; using VariableWrapperList = std::vector>; diff --git a/paddle/pten/kernels/CMakeLists.txt b/paddle/pten/kernels/CMakeLists.txt index 615b80be592a0..e14c2f6b6c47c 100644 --- a/paddle/pten/kernels/CMakeLists.txt +++ b/paddle/pten/kernels/CMakeLists.txt @@ -1,6 +1,7 @@ set(kernel_declare_file ${PADDLE_BINARY_DIR}/paddle/pten/kernels/declarations.h.tmp CACHE INTERNAL "declarations.h file") set(kernel_declare_file_final ${PADDLE_BINARY_DIR}/paddle/pten/kernels/declarations.h) file(WRITE ${kernel_declare_file} "// Generated by the paddle/pten/kernels/CMakeLists.txt. DO NOT EDIT!\n\n#pragma once\n\n") +file(APPEND ${kernel_declare_file} "#include \"paddle/pten/core/kernel_registry.h\"\n\n") # pten functors and functions called by kernels add_subdirectory(funcs)