From 8c7ee8c21b28ad9332652e5a1f7b62ca6b4736e0 Mon Sep 17 00:00:00 2001
From: Feiyu Chan
Date: Fri, 18 Feb 2022 17:43:14 +0800
Subject: [PATCH] [Pten] blas and lapck migration (#39587)

* move blas related files
* move lapack related files
---
 cmake/generic.cmake | 39 +
 paddle/fluid/distributed/common/utils.h | 8 +-
 .../ps/service/communicator/communicator.cc | 14 +-
 .../ps/service/communicator/communicator.h | 2 +-
 .../ir/embedding_fc_lstm_fuse_pass.cc | 10 +-
 .../fluid/imperative/gradient_accumulator.cc | 6 +-
 .../tensorrt/plugin/qkv_to_context_plugin.cu | 2 +-
 paddle/fluid/operators/activation_op.h | 2 +-
 paddle/fluid/operators/addmm_op.h | 6 +-
 paddle/fluid/operators/affine_grid_op.h | 6 +-
 paddle/fluid/operators/atan2_op.h | 2 +-
 paddle/fluid/operators/attention_lstm_op.cc | 4 +-
 paddle/fluid/operators/batch_fc_op.cu | 6 +-
 .../operators/bilinear_tensor_product_op.h | 6 +-
 paddle/fluid/operators/bmm_op.h | 20 +-
 paddle/fluid/operators/center_loss_op.h | 4 +-
 paddle/fluid/operators/cholesky_op.h | 8 +-
 paddle/fluid/operators/cholesky_solve_op.h | 17 +-
 paddle/fluid/operators/conv_op.h | 8 +-
 paddle/fluid/operators/conv_transpose_op.h | 6 +-
 .../operators/deformable_conv_filter.cu.h | 2 +-
 paddle/fluid/operators/deformable_conv_func.h | 2 +-
 paddle/fluid/operators/deformable_conv_op.cu | 6 +-
 paddle/fluid/operators/deformable_conv_op.h | 6 +-
 .../fluid/operators/deformable_conv_v1_op.cu | 6 +-
 .../fluid/operators/deformable_conv_v1_op.h | 6 +-
 .../operators/deformable_psroi_pooling_op.cc | 2 +-
 .../operators/deformable_psroi_pooling_op.cu | 2 +-
 .../operators/deformable_psroi_pooling_op.h | 2 +-
 paddle/fluid/operators/eig_op.h | 6 +-
 paddle/fluid/operators/eigvals_op.h | 18 +-
 .../elementwise/elementwise_mul_op.cc | 2 +-
 paddle/fluid/operators/fake_quantize_op.h | 2 +-
 paddle/fluid/operators/flatten_op.h | 2 +-
 paddle/fluid/operators/fsp_op.h | 18 +-
 .../fluid/operators/fused/attn_feed_forward.h | 6 +-
 paddle/fluid/operators/fused/attn_gemm.h | 6 +-
 paddle/fluid/operators/fused/fmha_ref.h | 4 +-
 .../fused_embedding_eltwise_layernorm_op.cu | 2 +-
 .../fused/fused_embedding_fc_lstm_op.cc | 6 +-
 .../fused/fused_embedding_seq_pool_op.h | 6 +-
 .../fused_fc_elementwise_layernorm_op.cu | 4 +-
 .../operators/fused/fused_feedforward_op.cc | 6 +-
 .../operators/fused/fused_feedforward_op.cu | 25 +-
 paddle/fluid/operators/fused/fusion_gru_op.cc | 6 +-
 .../fluid/operators/fused/fusion_lstm_op.cc | 6 +-
 .../fused/fusion_seqconv_eltadd_relu_op.cc | 2 +-
 .../fused/fusion_seqexpand_concat_fc_op.cc | 4 +-
 paddle/fluid/operators/fused/multi_gru_op.cc | 2 +-
 .../operators/fused/multihead_matmul_op.cu | 5 +-
 .../operators/fused/skip_layernorm_op.cu | 2 +-
 paddle/fluid/operators/gelu_op.h | 33 +-
 paddle/fluid/operators/group_norm_op.h | 2 +-
 paddle/fluid/operators/gru_op.cc | 4 +-
 paddle/fluid/operators/gru_unit_op.h | 6 +-
 .../fluid/operators/hierarchical_sigmoid_op.h | 2 +-
 paddle/fluid/operators/index_select_op.h | 4 +-
 paddle/fluid/operators/inverse_op.h | 15 +-
 paddle/fluid/operators/layer_norm_op.h | 6 +-
 .../fluid/operators/lookup_table_dequant_op.h | 2 +-
 paddle/fluid/operators/lookup_table_op.h | 9 +-
 paddle/fluid/operators/lookup_table_v2_op.h | 5 +-
 paddle/fluid/operators/lrn_op.cc | 4 +-
 paddle/fluid/operators/lstm_op.h | 6 +-
 paddle/fluid/operators/lstmp_op.h | 6 +-
 paddle/fluid/operators/lstsq_op.h | 48 +-
 paddle/fluid/operators/lu_op.cc | 4 +-
 paddle/fluid/operators/lu_op.h | 36 +-
 .../fluid/operators/match_matrix_tensor_op.cc | 6 +-
 paddle/fluid/operators/math/CMakeLists.txt | 44 -
 .../operators/math/bert_encoder_functor.cu | 6 +-
 paddle/fluid/operators/math/blas_impl.cu.h | 1804 ----------
 paddle/fluid/operators/math/blas_impl.h | 1860 -----------
 paddle/fluid/operators/math/blas_impl.hip.h | 1379 --------
 paddle/fluid/operators/math/context_project.h | 4 +-
 .../operators/math/eigen_values_vectors.h | 10 +-
 paddle/fluid/operators/math/fc.cc | 4 +-
 paddle/fluid/operators/math/fc.cu | 4 +-
 paddle/fluid/operators/math/gru_compute.cc | 10 +-
 paddle/fluid/operators/math/gru_compute.cu | 6 +-
 .../fluid/operators/math/lapack_function.cc | 226 --
 paddle/fluid/operators/math/lapack_function.h | 66 -
 .../fluid/operators/math/matrix_bit_code.cc | 12 +-
 paddle/fluid/operators/math/matrix_bit_code.h | 2 +-
 paddle/fluid/operators/math/matrix_inverse.cc | 2 +-
 .../fluid/operators/math/matrix_inverse.cu.cc | 4 +-
 paddle/fluid/operators/math/matrix_solve.cc | 4 +-
 .../fluid/operators/math/matrix_solve.cu.cc | 6 +-
 .../operators/math/selected_rows_functor.cc | 16 +-
 .../operators/math/selected_rows_functor.h | 2 +-
 .../fluid/operators/math/sequence_pooling.cc | 4 +-
 paddle/fluid/operators/matmul_op.cc | 41 +-
 paddle/fluid/operators/matmul_op_xpu.cc | 16 +-
 paddle/fluid/operators/matmul_v2_op.h | 8 +-
 paddle/fluid/operators/matmul_v2_op_xpu.cc | 8 +-
 paddle/fluid/operators/matrix_power_op.h | 12 +-
 .../operators/mkldnn/matmul_mkldnn_op.cc | 29 +-
 .../fluid/operators/mkldnn/matmul_mkldnn_op.h | 2 +-
 .../operators/mkldnn/matmul_v2_mkldnn_op.cc | 9 +-
 paddle/fluid/operators/mul_op.h | 8 +-
 paddle/fluid/operators/multi_dot_op.cc | 52 +-
 paddle/fluid/operators/mv_op.cu | 2 +-
 paddle/fluid/operators/mv_op.h | 6 +-
 paddle/fluid/operators/rank_attention_op.cu | 6 +-
 paddle/fluid/operators/repeat_interleave_op.h | 2 +-
 paddle/fluid/operators/rnn_op.h | 62 +-
 paddle/fluid/operators/scatter.h | 4 +-
 paddle/fluid/operators/search_compute.h | 8 +-
 .../operators/sequence_ops/sequence_conv_op.h | 4 +-
 paddle/fluid/operators/solve_op.h | 20 +-
 paddle/fluid/operators/spectral_norm_op.h | 6 +-
 paddle/fluid/operators/squeeze_op.h | 2 +-
 paddle/fluid/operators/svd_helper.h | 10 +-
 paddle/fluid/operators/tree_conv_op.h | 6 +-
 paddle/fluid/operators/triangular_solve_op.h | 15 +-
 paddle/fluid/operators/unsqueeze_op.h | 2 +-
 paddle/fluid/operators/var_conv_2d_op.cc | 6 +-
 paddle/pten/kernels/cpu/elementwise.h | 12 +-
 paddle/pten/kernels/funcs/CMakeLists.txt | 41 +-
 paddle/pten/kernels/funcs/blas/CMakeLists.txt | 1 +
 .../math => pten/kernels/funcs/blas}/blas.cc | 24 +-
 .../math => pten/kernels/funcs/blas}/blas.h | 295 +-
 paddle/pten/kernels/funcs/blas/blas_impl.cu.h | 2941 +++++++++++++++++
 paddle/pten/kernels/funcs/blas/blas_impl.h | 2530 ++++++++++++++
 .../pten/kernels/funcs/blas/blas_impl.hip.h | 2276 +++++++++++++
 paddle/pten/kernels/funcs/functors.h | 18 -
 .../pten/kernels/funcs/lapack/CMakeLists.txt | 1 +
 .../kernels/funcs/lapack/lapack_function.cc | 509 +++
 .../kernels/funcs/lapack/lapack_function.h | 128 +
 paddle/pten/kernels/funcs/math_function.cu | 8 +-
 .../kernels/impl/matmul_grad_kernel_impl.h | 16 +-
 paddle/pten/kernels/impl/matmul_kernel_impl.h | 4 +-
 .../pten/tests/kernels/test_math_function.cc | 91 +-
 .../pten/tests/kernels/test_math_function.cu | 9 +-
 134 files changed, 9215 insertions(+), 6058 deletions(-)
 delete mode 100644 paddle/fluid/operators/math/blas_impl.cu.h
 delete mode 100644 paddle/fluid/operators/math/blas_impl.h
 delete mode 100644 paddle/fluid/operators/math/blas_impl.hip.h
 delete mode 100644 paddle/fluid/operators/math/lapack_function.cc
 delete mode 100644 paddle/fluid/operators/math/lapack_function.h
 create mode 100644 paddle/pten/kernels/funcs/blas/CMakeLists.txt
 rename paddle/{fluid/operators/math => pten/kernels/funcs/blas}/blas.cc (75%)
 rename paddle/{fluid/operators/math => pten/kernels/funcs/blas}/blas.h (59%)
 create mode 100644 paddle/pten/kernels/funcs/blas/blas_impl.cu.h
 create mode 100644 paddle/pten/kernels/funcs/blas/blas_impl.h
 create mode 100644 paddle/pten/kernels/funcs/blas/blas_impl.hip.h
 create mode 100644 paddle/pten/kernels/funcs/lapack/CMakeLists.txt
 create mode 100644 paddle/pten/kernels/funcs/lapack/lapack_function.cc
 create mode 100644 paddle/pten/kernels/funcs/lapack/lapack_function.h

diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index 6655963e728f1..24e3c07215af8 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -1036,3 +1036,42 @@ function(generate_dummy_static_lib)
   add_library(${dummy_LIB_NAME} STATIC ${dummy_FILE_PATH})
 endfunction()
 
+function(math_library TARGET)
+  # math_library is a function to create math library.
+  # The interface is the same as cc_library.
+  # But it handle split GPU/CPU code and link some common library.
+  set(cc_srcs)
+  set(cu_srcs)
+  set(hip_srcs)
+  set(math_common_deps device_context framework_proto enforce)
+  if (WITH_GPU)
+    if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0)
+      list(APPEND math_common_deps cub)
+    else()
+      list(APPEND math_common_deps)
+    endif()
+  endif()
+  set(multiValueArgs DEPS)
+  cmake_parse_arguments(math_library "${options}" "${oneValueArgs}"
+                        "${multiValueArgs}" ${ARGN})
+
+  if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
+    list(APPEND cc_srcs ${TARGET}.cc)
+  endif()
+  if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
+    list(APPEND cu_srcs ${TARGET}.cu)
+  endif()
+  if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
+    list(APPEND cu_srcs ${TARGET}.cu.cc)
+  endif()
+
+  list(LENGTH cc_srcs cc_srcs_len)
+  if (WITH_GPU)
+    nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
+  elseif (WITH_ROCM)
+    hip_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
+  elseif(${cc_srcs_len} GREATER 0)
+    cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
+  endif()
+endfunction()
+
diff --git a/paddle/fluid/distributed/common/utils.h b/paddle/fluid/distributed/common/utils.h
index 85b89d75b98b6..d50423bd4b1a8 100644
--- a/paddle/fluid/distributed/common/utils.h
+++ b/paddle/fluid/distributed/common/utils.h
@@ -24,18 +24,16 @@
 #include
 #include
 
-#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/pten/kernels/funcs/blas/blas.h"
 
 namespace paddle {
 namespace distributed {
 
 template
-inline paddle::operators::math::BlasT
-GetBlas() {
+inline pten::funcs::BlasT GetBlas() {
   paddle::platform::CPUDeviceContext cpu_ctx;
-  return paddle::operators::math::GetBlas(cpu_ctx);
+  return pten::funcs::GetBlas(cpu_ctx);
 }
 
 template
diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.cc b/paddle/fluid/distributed/ps/service/communicator/communicator.cc
index 99973ee8bdd74..f47415812e51d 100644
--- a/paddle/fluid/distributed/ps/service/communicator/communicator.cc
+++ b/paddle/fluid/distributed/ps/service/communicator/communicator.cc
@@ -1161,8 +1161,7 @@ void GeoCommunicator::SendDense(const CommContext &send_ctx) {
   t_delta->mutable_data(t_latest.dims(), cpu_ctx.GetPlace());
 
   auto blas =
-      paddle::operators::math::GetBlas(
-          cpu_ctx);
+      pten::funcs::GetBlas(cpu_ctx);
blas.VSUB(t_latest.numel(), t_latest.data(), t_timestamp->data(), t_delta->data()); @@ -1201,8 +1200,7 @@ void GeoCommunicator::RecvDense(const CommContext &send_ctx) { t_delta->mutable_data(t_latest->dims(), cpu_ctx.GetPlace()); auto blas = - paddle::operators::math::GetBlas( - cpu_ctx); + pten::funcs::GetBlas(cpu_ctx); blas.VSUB(t_latest->numel(), t_pserver.data(), t_old->data(), t_delta->data()); blas.VADD(t_latest->numel(), t_latest->data(), @@ -1303,9 +1301,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, t_delta->set_rows(sparse_ids); t_delta->set_height(t_latest.dims()[0]); - auto blas = - paddle::operators::math::GetBlas( - cpu_ctx); + auto blas = pten::funcs::GetBlas(cpu_ctx); float coefficient = 1.0 / static_cast(trainers_); std::vector push_g_vec; @@ -1371,9 +1367,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, v_delta.resize(numel); paddle::platform::CPUDeviceContext cpu_ctx; - auto blas = - paddle::operators::math::GetBlas( - cpu_ctx); + auto blas = pten::funcs::GetBlas(cpu_ctx); for (auto j = 0; j < static_cast(keys.size()); ++j) { VLOG(5) << "DEBUG GeoCommunicator::RecvSparse recv sparse key" << keys[j] diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.h b/paddle/fluid/distributed/ps/service/communicator/communicator.h index 9f8c998d3a1c2..7e5a229aa86c5 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.h +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.h @@ -34,12 +34,12 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable_helper.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/split.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" #include "paddle/fluid/distributed/ps/service/ps_client.h" diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc index dc0459493c46a..844cdd6e69887 100644 --- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc @@ -17,7 +17,7 @@ #include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace framework { @@ -121,14 +121,14 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, // broadcast biases std::vector ones(m, 1.0f); - paddle::operators::math::CBlas::GEMM( + pten::funcs::CBlas::GEMM( CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, 1, alpha, &ones[0], 1, &combined_biases[0], n, 0.0f, embeddings_data, n); // Wx*embeddings + biases - paddle::operators::math::CBlas::GEMM( - CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha, - embedding_data, k, weightx_data, n, beta, embeddings_data, n); + pten::funcs::CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, + m, n, k, alpha, embedding_data, k, + weightx_data, n, beta, embeddings_data, n); op_desc.SetInput("Embeddings", {embeddings}); op_desc.SetInput("H0", {}); diff --git a/paddle/fluid/imperative/gradient_accumulator.cc 
b/paddle/fluid/imperative/gradient_accumulator.cc index 382d1b0591cbe..168923e819daa 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -22,13 +22,13 @@ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/selected_rows_utils.h" #include "paddle/fluid/imperative/layer.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/profiler.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" #ifdef PADDLE_WITH_XPU #include "xpu/refactor/math.h" @@ -86,7 +86,7 @@ class TensorAddFunctor : public boost::static_visitor<> { void operator()(const platform::CPUPlace& place) const { platform::CPUDeviceContext* ctx = dynamic_cast( platform::DeviceContextPool::Instance().Get(place)); - auto blas = operators::math::GetBlas(*ctx); + auto blas = pten::funcs::GetBlas(*ctx); blas.AXPY(numel_, 1., x_, y_); } @@ -118,7 +118,7 @@ class TensorAddFunctor : public boost::static_visitor<> { platform::CUDADeviceContext* ctx = dynamic_cast( platform::DeviceContextPool::Instance().Get(place)); - auto blas = operators::math::GetBlas(*ctx); + auto blas = pten::funcs::GetBlas(*ctx); blas.AXPY(numel_, 1., x_, y_); } #else diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu index 8e59fc1355a75..ad7fc2567dc0c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu @@ -22,8 +22,8 @@ #include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index a089f6b4a3c19..41448ef6345e0 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -28,9 +28,9 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif diff --git a/paddle/fluid/operators/addmm_op.h b/paddle/fluid/operators/addmm_op.h index 8fe73d81b0272..52b1a339c6397 100644 --- a/paddle/fluid/operators/addmm_op.h +++ b/paddle/fluid/operators/addmm_op.h @@ -19,7 +19,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace ops = paddle::operators; @@ -94,7 +94,7 @@ class AddMMKernel : public framework::OpKernel { float alpha = context.template Attr("Alpha"); float beta = context.template Attr("Beta"); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); // calc broadcast dim Array2 bcast_dims; @@ -146,7 +146,7 @@ class AddMMGradKernel : public framework::OpKernel { } auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); if (dinput) { dinput->mutable_data(ctx.GetPlace()); total_elems = in_dims[0] * in_dims[1]; diff --git a/paddle/fluid/operators/affine_grid_op.h b/paddle/fluid/operators/affine_grid_op.h index 129c7a61a7876..39dd95120dac2 100644 --- a/paddle/fluid/operators/affine_grid_op.h +++ b/paddle/fluid/operators/affine_grid_op.h @@ -16,7 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -122,7 +122,7 @@ class AffineGridOpKernel : public framework::OpKernel { GetIdxMap(n, h, w, align_corners, &grid, ctx); // output = grid * theta.T // TODO(wanghaoshuang): Refine batched matrix multiply - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); for (int i = 0; i < n; ++i) { Tensor sliced_grid = grid.Slice(i, i + 1).Resize( {static_cast(h) * static_cast(w), 3}); @@ -165,7 +165,7 @@ class AffineGridGradOpKernel : public framework::OpKernel { GetIdxMap(n, h, w, align_corners, &grid, ctx); // output = grid * theta.T // TODO(wanghaoshuang): Refine batched matrix multiply - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); for (int i = 0; i < n; ++i) { Tensor sliced_grid = grid.Slice(i, i + 1).Resize( {static_cast(h) * static_cast(w), 3}); diff --git a/paddle/fluid/operators/atan2_op.h b/paddle/fluid/operators/atan2_op.h index 8ed0fda843d47..94d41ea379001 100644 --- a/paddle/fluid/operators/atan2_op.h +++ b/paddle/fluid/operators/atan2_op.h @@ -17,10 +17,10 @@ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/for_range.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 593a1b861cb0d..e35a1a43d7aaf 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -14,10 +14,10 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/attention_lstm_op.h" #include -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/platform/cpu_info.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -373,7 +373,7 @@ class AttentionLSTMKernel : public framework::OpKernel { T* lstm_x_data = lstm_x->mutable_data(ctx.GetPlace()); T* lstm_out_data = lstm_out->mutable_data(ctx.GetPlace()); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1 auto& dev_ctx = ctx.template device_context(); diff --git a/paddle/fluid/operators/batch_fc_op.cu b/paddle/fluid/operators/batch_fc_op.cu index c326929a14680..ad15ca5576a6d 100644 --- a/paddle/fluid/operators/batch_fc_op.cu +++ b/paddle/fluid/operators/batch_fc_op.cu @@ -15,9 +15,9 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/batch_fc_op.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -112,7 +112,7 @@ class BatchFCCUDAKernel : public framework::OpKernel { int64_t strideA = ins_num * in_dim; int64_t strideB = in_dim * out_dim; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); blas.BatchedGEMM(transA, transB, ins_num, out_dim, in_dim, alpha, in_data, w_data, beta, out_data, slot_pairs_num, strideA, strideB); add_bias(ctx.cuda_device_context().stream(), out_data, slot_pairs_num, @@ -165,7 +165,7 @@ class BatchFCGradOpCUDAKernel : public framework::OpKernel { add_bias_grad(ctx.cuda_device_context().stream(), dout_data, slot_pairs_num, ins_num, out_dim, db_data); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); T alpha = 1; T beta = 0; diff --git a/paddle/fluid/operators/bilinear_tensor_product_op.h b/paddle/fluid/operators/bilinear_tensor_product_op.h index c7eb70c290e17..26f7b78862d1f 100644 --- a/paddle/fluid/operators/bilinear_tensor_product_op.h +++ b/paddle/fluid/operators/bilinear_tensor_product_op.h @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -61,7 +61,7 @@ class BilinearTensorProductKernel : public framework::OpKernel { auto output_col_vec = output_mat.chip(i, 1); Tensor weight_mat = weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim})); - math::GetBlas(dev_ctx).GEMM( + pten::funcs::GetBlas(dev_ctx).GEMM( CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, x->data(), weight_mat.data(), 0, left_mul.data()); output_col_vec.device(place) = @@ -127,7 +127,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { d_weight->mutable_data(ctx.GetPlace()); } - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); // Caculate the Output(X@Grad) and Output(Y@Grad). 
if (d_x || d_y || d_weight) { diff --git a/paddle/fluid/operators/bmm_op.h b/paddle/fluid/operators/bmm_op.h index 7a0ddd4582341..1116a3a0ec702 100644 --- a/paddle/fluid/operators/bmm_op.h +++ b/paddle/fluid/operators/bmm_op.h @@ -20,7 +20,7 @@ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { namespace operators { @@ -28,7 +28,7 @@ namespace operators { using Tensor = framework::Tensor; static void ReshapeTensorIntoMatrixSequence( - framework::Tensor *x, const math::MatDescriptor &descriptor) { + framework::Tensor *x, const pten::funcs::MatDescriptor &descriptor) { int64_t h, w; h = descriptor.height_; w = descriptor.width_; @@ -45,8 +45,8 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x, bool trans_y) { auto x_dim = x->dims(); auto y_dim = y->dims(); - auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, false); - auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, false); + auto mat_dim_x = pten::funcs::CreateMatrixDescriptor(x_dim, 0, false); + auto mat_dim_y = pten::funcs::CreateMatrixDescriptor(y_dim, 0, false); out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), mat_dim_x.height_, mat_dim_y.width_}); @@ -68,10 +68,10 @@ class BmmKernel : public framework::OpKernel { return; } - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(x.dims(), 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(y.dims(), 0, false); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(x.dims(), 0, false); + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(y.dims(), 0, false); // auto scale = static_cast(context.Attr("alpha")); blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0)); @@ -86,9 +86,9 @@ class BmmGradKernel : public framework::OpKernel { const framework::Tensor &b, bool trans_b, framework::Tensor *out) const { out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + auto blas = pten::funcs::GetBlas(context); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b); blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0)); } diff --git a/paddle/fluid/operators/center_loss_op.h b/paddle/fluid/operators/center_loss_op.h index 565b1cee9f785..2b713d3d82cff 100644 --- a/paddle/fluid/operators/center_loss_op.h +++ b/paddle/fluid/operators/center_loss_op.h @@ -19,8 +19,8 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/transform.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -84,7 +84,7 @@ class CenterLossKernel : public framework::OpKernel { int numel = centers_diffacc.numel(); std::memset(centers_diffacc_data, 0, sizeof(T) * numel); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); int tLabel; const T *x_index; diff --git a/paddle/fluid/operators/cholesky_op.h b/paddle/fluid/operators/cholesky_op.h index 15dd8315362ed..d65c69164d965 100644 --- a/paddle/fluid/operators/cholesky_op.h +++ b/paddle/fluid/operators/cholesky_op.h @@ -19,9 +19,9 @@ limitations under the License. */ #include "Eigen/Cholesky" #include "Eigen/Core" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/platform/for_range.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -323,9 +323,9 @@ class CholeskyGradKernel : public framework::OpKernel { /*! phi = matmul(L.transpose(-1, -2), grad) */ Tensor middle; auto* middle_data = middle.mutable_data(dims, context.GetPlace()); - auto trans_desc = math::CreateMatrixDescriptor(dims, 0, true); - auto no_trans_desc = math::CreateMatrixDescriptor(dims, 0, false); - auto blas = math::GetBlas(context); + auto trans_desc = pten::funcs::CreateMatrixDescriptor(dims, 0, true); + auto no_trans_desc = pten::funcs::CreateMatrixDescriptor(dims, 0, false); + auto blas = pten::funcs::GetBlas(context); blas.MatMul(l, trans_desc, l_grad, no_trans_desc, T(1), &middle, T(0)); /*! phi.tril_().diagonal(0, -2, -1).mul_(0.5) */ diff --git a/paddle/fluid/operators/cholesky_solve_op.h b/paddle/fluid/operators/cholesky_solve_op.h index 2c92969225f3b..ed0211dcd9a56 100644 --- a/paddle/fluid/operators/cholesky_solve_op.h +++ b/paddle/fluid/operators/cholesky_solve_op.h @@ -15,11 +15,11 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/lapack_function.h" #include "paddle/fluid/operators/solve_op.h" #include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/operators/triangular_solve_op.h" #include "paddle/fluid/platform/complex.h" +#include "paddle/pten/kernels/funcs/lapack/lapack_function.h" #include "paddle/pten/kernels/math_kernel.h" namespace paddle { @@ -38,8 +38,8 @@ class CholeskySolveFunctor { void operator()(const platform::CPUDeviceContext &dev_ctx, bool upper, int n, int nrhs, T *Adata, int lda, T *Bdata, int *devInfo) { char uplo = upper ? 
'U' : 'L'; - math::lapackCholeskySolve(uplo, n, nrhs, Adata, lda, Bdata, lda, - devInfo); + pten::funcs::lapackCholeskySolve(uplo, n, nrhs, Adata, lda, Bdata, lda, + devInfo); } }; @@ -168,7 +168,7 @@ class CholeskySolveGradKernel : public framework::OpKernel { db->Resize(bin->dims()); } - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); // calculate out's conjugate for complex framework::Tensor out_conj(out->type()); @@ -182,8 +182,8 @@ class CholeskySolveGradKernel : public framework::OpKernel { framework::Tensor commonterm(out->type()); auto outdims = out_conj.dims(); auto dbdims = db_bst.dims(); - auto mat_dim_a = math::CreateMatrixDescriptor(outdims, 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(dbdims, 0, false); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(outdims, 0, false); + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(dbdims, 0, false); auto cmtdim = outdims; cmtdim[cmtdim.size() - 2] = dbdims[dbdims.size() - 2]; commonterm.Resize(cmtdim); @@ -207,9 +207,10 @@ class CholeskySolveGradKernel : public framework::OpKernel { DeviceContext>::TYPE &>(dev_ctx), commonterm, commonterm_conj, -1, &commonterm); - auto mat_dim_u = math::CreateMatrixDescriptor(u_bst.dims(), 0, false); + auto mat_dim_u = + pten::funcs::CreateMatrixDescriptor(u_bst.dims(), 0, false); auto mat_dim_c = - math::CreateMatrixDescriptor(commonterm.dims(), 0, false); + pten::funcs::CreateMatrixDescriptor(commonterm.dims(), 0, false); Tensor du_bst(uin->type()); // get upper or lower triangular diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index 956ba80f32bd1..37af2644c4b43 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -21,10 +21,10 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/layout_utils.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/vol2col.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -332,7 +332,7 @@ class GemmConvKernel : public framework::OpKernel { math::Vol2ColFunctor vol2col; math::Im2ColFunctor im2col; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); for (int i = 0; i < batch_size; i++) { Tensor in_batch = transformed_input.Slice(i, i + 1).Resize(in_matrix_shape); @@ -486,7 +486,7 @@ class GemmConvGradKernel : public framework::OpKernel { } pten::funcs::SetConstant set_zero; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); if (input_grad) { input_grad->mutable_data(context.GetPlace()); @@ -693,7 +693,7 @@ class GemmConvDoubleGradKernel : public framework::OpKernel { } pten::funcs::SetConstant set_zero; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); // dx convolution double grad: gemm + col2im(col2vol) // dx = ddw * dy ==> dx(N, Cin, H, W), ddw(Cout, Cin, kh, kw), dy(N, Cout, diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h index 7b1fb6901e39b..2dd2791cfb18d 100644 --- a/paddle/fluid/operators/conv_transpose_op.h +++ b/paddle/fluid/operators/conv_transpose_op.h @@ -20,11 +20,11 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/vol2col.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -228,7 +228,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { output->mutable_data(context.GetPlace()); pten::funcs::SetConstant set_zero; auto& dev_ctx = context.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); set_zero(dev_ctx, output, static_cast(0)); int in_step = @@ -425,7 +425,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // im2col + gemm (similar to conv-forward) // input need to compute gradient auto& dev_ctx = context.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); if (input_grad || filter_grad) { Tensor col; col.mutable_data(col_shape, context.GetPlace()); diff --git a/paddle/fluid/operators/deformable_conv_filter.cu.h b/paddle/fluid/operators/deformable_conv_filter.cu.h index 75d16ae0d43db..85e2dce420d8e 100644 --- a/paddle/fluid/operators/deformable_conv_filter.cu.h +++ b/paddle/fluid/operators/deformable_conv_filter.cu.h @@ -22,7 +22,7 @@ // \author Yi Li, Guodong Zhang, Jifeng Dai #pragma once -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" template diff --git a/paddle/fluid/operators/deformable_conv_func.h b/paddle/fluid/operators/deformable_conv_func.h index 134a1ea06d946..6c45675e35d7a 100644 --- a/paddle/fluid/operators/deformable_conv_func.h +++ b/paddle/fluid/operators/deformable_conv_func.h @@ -22,8 +22,8 @@ // \author Yi Li, Guodong Zhang, Jifeng Dai #pragma once -#include "paddle/fluid/operators/math/blas.h" #include "paddle/pten/core/hostdevice.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" template diff --git a/paddle/fluid/operators/deformable_conv_op.cu b/paddle/fluid/operators/deformable_conv_op.cu index 97d2f71758fb5..0392278f7bc22 100644 --- a/paddle/fluid/operators/deformable_conv_op.cu +++ b/paddle/fluid/operators/deformable_conv_op.cu @@ -25,8 +25,8 @@ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/deformable_conv_op.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -513,7 +513,7 @@ class DeformableConvCUDAKernel : public framework::OpKernel { int input_offset_dim = offset.numel() / offset.dims()[0]; int input_mask_dim = mask.numel() / mask.dims()[0]; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); const T* input_ptr = input->data(); const T* offset_ptr = offset.data(); @@ -624,7 +624,7 @@ class DeformableConvGradCUDAKernel : public framework::OpKernel { col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); pten::funcs::SetConstant set_zero; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); col_buffer.mutable_data(ctx.GetPlace()); 
col_buffer_3d.mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/deformable_conv_op.h b/paddle/fluid/operators/deformable_conv_op.h index a5c0404ed3a5d..ce8d6cbd41eab 100644 --- a/paddle/fluid/operators/deformable_conv_op.h +++ b/paddle/fluid/operators/deformable_conv_op.h @@ -26,7 +26,7 @@ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/deformable_conv_func.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -382,7 +382,7 @@ class DeformableConvCPUKernel : public framework::OpKernel { int input_dim = input->numel() / input->dims()[0]; int input_offset_dim = offset->numel() / offset->dims()[0]; int input_mask_dim = mask->numel() / mask->dims()[0]; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); const T* input_ptr = input->data(); const T* offset_ptr = offset->data(); const T* mask_ptr = mask->data(); @@ -490,7 +490,7 @@ class DeformableConvGradCPUKernel : public framework::OpKernel { col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); pten::funcs::SetConstant set_zero; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); col_buffer.mutable_data(ctx.GetPlace()); col_buffer_3d.mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/deformable_conv_v1_op.cu b/paddle/fluid/operators/deformable_conv_v1_op.cu index 8f6c5a226bc86..20dbdde31fbfb 100644 --- a/paddle/fluid/operators/deformable_conv_v1_op.cu +++ b/paddle/fluid/operators/deformable_conv_v1_op.cu @@ -28,8 +28,8 @@ #include "paddle/fluid/operators/deformable_conv_filter.cu.h" #include "paddle/fluid/operators/deformable_conv_func.h" #include "paddle/fluid/operators/deformable_conv_v1_op.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -381,7 +381,7 @@ class DeformableConvV1CUDAKernel : public framework::OpKernel { int input_dim = input->numel() / input->dims()[0]; int input_offset_dim = offset.numel() / offset.dims()[0]; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); const T* input_ptr = input->data(); const T* offset_ptr = offset.data(); @@ -490,7 +490,7 @@ class DeformableConvV1GradCUDAKernel : public framework::OpKernel { col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); pten::funcs::SetConstant set_zero; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); col_buffer.mutable_data(ctx.GetPlace()); col_buffer_3d.mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/deformable_conv_v1_op.h b/paddle/fluid/operators/deformable_conv_v1_op.h index 1ddc31c93eaaa..95c78db0c4809 100644 --- a/paddle/fluid/operators/deformable_conv_v1_op.h +++ b/paddle/fluid/operators/deformable_conv_v1_op.h @@ -27,7 +27,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/deformable_conv_func.h" #include "paddle/fluid/operators/deformable_conv_op.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -348,7 +348,7 @@ class DeformableConvV1CPUKernel : public framework::OpKernel { std::vector input_shape_vec = framework::vectorize(input_shape); int input_dim = 
input->numel() / input->dims()[0]; int input_offset_dim = offset->numel() / offset->dims()[0]; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); const T* input_ptr = input->data(); const T* offset_ptr = offset->data(); col_buffer.mutable_data(ctx.GetPlace()); @@ -452,7 +452,7 @@ class DeformableConvV1GradCPUKernel : public framework::OpKernel { col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); pten::funcs::SetConstant set_zero; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); col_buffer.mutable_data(ctx.GetPlace()); col_buffer_3d.mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/deformable_psroi_pooling_op.cc b/paddle/fluid/operators/deformable_psroi_pooling_op.cc index bba859aed6d7f..f63221634d5c8 100644 --- a/paddle/fluid/operators/deformable_psroi_pooling_op.cc +++ b/paddle/fluid/operators/deformable_psroi_pooling_op.cc @@ -16,7 +16,7 @@ #include #include #include -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/deformable_psroi_pooling_op.cu b/paddle/fluid/operators/deformable_psroi_pooling_op.cu index 95f05963cd1f6..02e4a1fb46924 100644 --- a/paddle/fluid/operators/deformable_psroi_pooling_op.cu +++ b/paddle/fluid/operators/deformable_psroi_pooling_op.cu @@ -30,8 +30,8 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/deformable_psroi_pooling_op.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/deformable_psroi_pooling_op.h b/paddle/fluid/operators/deformable_psroi_pooling_op.h index 08b8342a1fd69..f57554aeeed41 100644 --- a/paddle/fluid/operators/deformable_psroi_pooling_op.h +++ b/paddle/fluid/operators/deformable_psroi_pooling_op.h @@ -26,7 +26,7 @@ #include #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/eig_op.h b/paddle/fluid/operators/eig_op.h index f822802d305e9..a53cacdd2086e 100644 --- a/paddle/fluid/operators/eig_op.h +++ b/paddle/fluid/operators/eig_op.h @@ -17,12 +17,12 @@ #include #include #include -#include "paddle/fluid/operators/math/lapack_function.h" #include "paddle/fluid/operators/math/matrix_solve.h" #include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/pten/kernels/funcs/complex_functors.h" +#include "paddle/pten/kernels/funcs/lapack/lapack_function.h" #include "paddle/pten/kernels/funcs/math_function.h" #define EPSILON 1e-6 @@ -94,7 +94,7 @@ void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info, // call lapackEig once to compute the size of work; T computed_work_size; - math::lapackEig>( + pten::funcs::lapackEig>( jobvl, jobvr, order, input_data, lda, values_data, lvector_data, ldvl, rvector_data, ldvr, &computed_work_size, lwork, rwork_data, &info); @@ -109,7 +109,7 @@ void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info, T* current_values = &values_data[i * values_stride]; T* current_rvectors = 
&rvector_data[i * matrix_stride]; - math::lapackEig>( + pten::funcs::lapackEig>( jobvl, jobvr, order, current_matrix, lda, current_values, lvector_data, ldvl, current_rvectors, ldvr, work_data, lwork, rwork_data, &info); PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/operators/eigvals_op.h b/paddle/fluid/operators/eigvals_op.h index a069ea164c94c..7ad37558a2ab2 100644 --- a/paddle/fluid/operators/eigvals_op.h +++ b/paddle/fluid/operators/eigvals_op.h @@ -20,9 +20,9 @@ #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/allocation/allocator.h" -#include "paddle/fluid/operators/math/lapack_function.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/pten/kernels/funcs/complex_functors.h" +#include "paddle/pten/kernels/funcs/lapack/lapack_function.h" namespace paddle { namespace operators { @@ -103,11 +103,11 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input, required_work_mem, work_mem)); int info = 0; - math::lapackEig('N', 'N', static_cast(n_dim), a.template data(), - static_cast(n_dim), w_data, NULL, 1, NULL, 1, - work->template data(), - static_cast(work_mem / sizeof(T)), - static_cast(NULL), &info); + pten::funcs::lapackEig('N', 'N', static_cast(n_dim), + a.template data(), static_cast(n_dim), + w_data, NULL, 1, NULL, 1, work->template data(), + static_cast(work_mem / sizeof(T)), + static_cast(NULL), &info); std::string name = "framework::platform::dynload::dgeev_"; if (framework::TransToProtoVarType(input.dtype()) == @@ -153,7 +153,7 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input, required_rwork_mem, rwork_mem)); int info = 0; - math::lapackEig>( + pten::funcs::lapackEig>( 'N', 'N', static_cast(n_dim), a.template data(), static_cast(n_dim), output->template data(), NULL, 1, NULL, 1, work->template data(), static_cast(work_mem / sizeof(T)), @@ -187,10 +187,10 @@ class EigvalsKernel : public framework::OpKernel { // query workspace size T qwork; int info; - math::lapackEig>( + pten::funcs::lapackEig>( 'N', 'N', static_cast(n_dim), input_matrices[0].template data(), static_cast(n_dim), NULL, NULL, 1, NULL, 1, &qwork, -1, - static_cast*>(NULL), &info); + static_cast*>(NULL), &info); int64_t lwork = static_cast(qwork); Tensor work, rwork; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index a5683c9e88a56..9ac863ddd848f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -28,7 +28,7 @@ struct SameDimsElemwiseMul< void operator()(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, framework::Tensor *z) { - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); blas.VMUL(x->numel(), x->data(), y->data(), z->data()); } }; diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index c31139611e84c..37cb543aac960 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -19,9 +19,9 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/transform.h" #include "paddle/pten/core/hostdevice.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/flatten_op.h b/paddle/fluid/operators/flatten_op.h index 15e820a9ee366..4a586bc9e98c0 100644 --- a/paddle/fluid/operators/flatten_op.h +++ b/paddle/fluid/operators/flatten_op.h @@ -16,12 +16,12 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/pten_utils.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/pooling.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/pten/kernels/empty_kernel.h" #include "paddle/pten/kernels/flatten_grad_kernel.h" #include "paddle/pten/kernels/flatten_kernel.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/fsp_op.h b/paddle/fluid/operators/fsp_op.h index 999c3ae3747e9..0113326e9b126 100644 --- a/paddle/fluid/operators/fsp_op.h +++ b/paddle/fluid/operators/fsp_op.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -39,16 +39,16 @@ class FSPOpKernel : public framework::OpKernel { auto height = x_dims[2]; auto width = x_dims[3]; - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); - math::MatDescriptor x_mat_desc; + pten::funcs::MatDescriptor x_mat_desc; x_mat_desc.height_ = x_channel; x_mat_desc.width_ = height * width; x_mat_desc.batch_size_ = batch_size; x_mat_desc.stride_ = x_channel * height * width; x_mat_desc.trans_ = false; - math::MatDescriptor y_mat_desc; + pten::funcs::MatDescriptor y_mat_desc; y_mat_desc.height_ = height * width; y_mat_desc.width_ = y_channel; y_mat_desc.batch_size_ = batch_size; @@ -78,7 +78,7 @@ class FSPGradOpKernel : public framework::OpKernel { int64_t h = 0; int64_t w = 0; - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); pten::funcs::SetConstant set_zero; if (d_x != nullptr) { d_x->mutable_data(context.GetPlace()); @@ -89,14 +89,14 @@ class FSPGradOpKernel : public framework::OpKernel { h = y_dims[2]; w = y_dims[3]; - math::MatDescriptor d_out_mat_desc; + pten::funcs::MatDescriptor d_out_mat_desc; d_out_mat_desc.height_ = x_channel; d_out_mat_desc.width_ = y_channel; d_out_mat_desc.batch_size_ = batch_size; d_out_mat_desc.stride_ = x_channel * y_channel; d_out_mat_desc.trans_ = false; - math::MatDescriptor y_mat_desc; + pten::funcs::MatDescriptor y_mat_desc; y_mat_desc.height_ = y_channel; y_mat_desc.width_ = h * w; y_mat_desc.batch_size_ = batch_size; @@ -116,14 +116,14 @@ class FSPGradOpKernel : public framework::OpKernel { h = x_dims[2]; w = x_dims[3]; - math::MatDescriptor d_out_mat_desc; + pten::funcs::MatDescriptor d_out_mat_desc; d_out_mat_desc.height_ = y_channel; d_out_mat_desc.width_ = x_channel; d_out_mat_desc.batch_size_ = batch_size; d_out_mat_desc.stride_ = x_channel * y_channel; d_out_mat_desc.trans_ = true; - math::MatDescriptor x_mat_desc; + pten::funcs::MatDescriptor x_mat_desc; 
x_mat_desc.height_ = x_channel; x_mat_desc.width_ = h * w; x_mat_desc.batch_size_ = batch_size; diff --git a/paddle/fluid/operators/fused/attn_feed_forward.h b/paddle/fluid/operators/fused/attn_feed_forward.h index e7eba2da63663..f9d4d48e81f7e 100644 --- a/paddle/fluid/operators/fused/attn_feed_forward.h +++ b/paddle/fluid/operators/fused/attn_feed_forward.h @@ -15,8 +15,8 @@ limitations under the License. */ #pragma once #include "paddle/fluid/operators/fused/attn_bias_add.cu.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -47,7 +47,7 @@ class FeedForward { // column-major: (m,n,k) = output_size,bsz_seq,input_size (weight*input=out) // here: (m,n,k) = bsz_seq,output_size,input_size (input*weight=out) - auto blas = math::GetBlas(dev_ctx_); + auto blas = pten::funcs::GetBlas(dev_ctx_); blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha, input_data, weight_data, beta, output_data); if (compute_bias_) { @@ -60,7 +60,7 @@ class FeedForward { T* d_weight, T* d_bias) { T alpha = static_cast(1.0); T beta = static_cast(0.0); - auto blas = math::GetBlas(dev_ctx_); + auto blas = pten::funcs::GetBlas(dev_ctx_); // column-major: gemm-nt, get d_weight. CBLAS_TRANSPOSE transA = CblasTrans; diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h index 48f520d60b735..638ca5c80e1d7 100644 --- a/paddle/fluid/operators/fused/attn_gemm.h +++ b/paddle/fluid/operators/fused/attn_gemm.h @@ -11,8 +11,8 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" @@ -56,7 +56,7 @@ class AttnMatMul { T beta = static_cast(0.0); // here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) - auto blas = math::GetBlas(dev_ctx_); + auto blas = pten::funcs::GetBlas(dev_ctx_); blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha, input->data(), weight->data(), beta, output->data()); if (compute_bias_) { @@ -80,7 +80,7 @@ class AttnMatMul { framework::Tensor* d_bias) { T alpha = static_cast(1.0); T beta = static_cast(0.0); - auto blas = math::GetBlas(dev_ctx_); + auto blas = pten::funcs::GetBlas(dev_ctx_); CBLAS_TRANSPOSE dB_transA = CblasNoTrans; CBLAS_TRANSPOSE dB_transB = CblasNoTrans; diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 8c080f97cba82..302532b625b54 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -99,7 +99,7 @@ class FMHARef { // q*k^t, batched_gemm CBLAS_TRANSPOSE transA = CblasNoTrans; CBLAS_TRANSPOSE transB = CblasTrans; - auto blas = math::GetBlas(dev_ctx_); + auto blas = pten::funcs::GetBlas(dev_ctx_); int gemm_batch_size = batch_size_ * num_head_; int gemm_m = seq_len_; int gemm_n = seq_len_; @@ -174,7 +174,7 @@ class FMHARef { Tensor* softmax_out_grad_tensor, Tensor* src_mask_out_grad_tensor, Tensor* qk_out_grad_tensor, Tensor* transpose_2_out_grad_tensor, Tensor* src_mask_grad_tensor, Tensor* qkv_input_grad_tensor) { - auto blas = math::GetBlas(dev_ctx_); + auto blas = pten::funcs::GetBlas(dev_ctx_); int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; int k_size = q_size; int softmax_axis = 
-1; diff --git a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu index c2c617b8d5238..dd32ef71e3be6 100644 --- a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu +++ b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu @@ -18,7 +18,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc index 79fa268f3884b..8aec5444677d1 100644 --- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc @@ -14,10 +14,10 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h" #include -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/platform/cpu_info.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -364,7 +364,7 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel { T* xx_data = xx->mutable_data(place); T* h_out_data = hidden_out->mutable_data(place); T* c_out_data = cell_out->mutable_data(place); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); for (int64_t i = 0; i < ids_numel; ++i) { PADDLE_ENFORCE_LT( @@ -475,7 +475,7 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel { math::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); for (int64_t i = 0; i < ids_numel; ++i) { PADDLE_ENFORCE_LT( diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h index fc782dc551175..468e7a5786642 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h @@ -23,7 +23,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows_utils.h" #include "paddle/fluid/operators/jit/kernels.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -179,7 +179,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel { const int m = batch_size * idx_width; const int n = table_width; const int k = table_height; - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); blas.CSRMM(&transa, &m, &n, &k, &alpha, matdescra, (const T *)csr_vals, (const int *)csr_colmuns, (const int *)csr_row_idx, (const int *)csr_row_idx + 1, weights, &n, &beta, output, &n); @@ -277,7 +277,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel { csr_colmuns, csr_row_idx, padding_idx); auto *d_output_data = d_output->data(); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); int width = static_cast(table_dim[1]); int num_seq = batch_size * idx_width; LOG(INFO) << "num seq = " << num_seq << " width = " << width; diff --git a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu b/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu index ebda9bbaa8b81..d37ac322da704 100644 --- a/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu +++ b/paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu @@ -21,9 +21,9 @@ namespace cub = hipcub; #endif #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -150,7 +150,7 @@ class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel { T* out_data = out->mutable_data(ctx.GetPlace()); auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); blas.GEMM(false, false, M, N, K, static_cast(1.0), x_data, K, w_data, N, static_cast(0.0), out_data, N); diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cc b/paddle/fluid/operators/fused/fused_feedforward_op.cc index 1815a4dd71c78..1ef9edcb911af 100644 --- a/paddle/fluid/operators/fused/fused_feedforward_op.cc +++ b/paddle/fluid/operators/fused/fused_feedforward_op.cc @@ -17,8 +17,8 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/matmul_v2_op.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -49,8 +49,8 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel { "fused_feedforward"); auto dim_x = context->GetInputDim("X"); - auto mat_dim_x = - math::CreateMatrixDescriptor(RowMatrixFromVector(dim_x), 0, false); + auto mat_dim_x = pten::funcs::CreateMatrixDescriptor( + RowMatrixFromVector(dim_x), 0, false); // verify for the pre layer_norm, the feature size must be larger than 1 PADDLE_ENFORCE_GT( mat_dim_x.width_, static_cast(1), diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu index 934ce78e715bb..2709c206ed39c 100644 --- a/paddle/fluid/operators/fused/fused_feedforward_op.cu +++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu @@ -14,8 +14,8 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/matmul_v2_op.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/fused/fused_dropout_helper.h" @@ -32,11 +32,11 @@ class FusedFeedForwardKernel : public framework::OpKernel { void MatMul(const platform::CUDADeviceContext& ctx, const framework::Tensor& a, const framework::Tensor& b, framework::Tensor* c) const { - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto a_2d = FoldInitDims(a); auto b_2d = FoldInitDims(b); - auto mat_dim_a = math::CreateMatrixDescriptor(a_2d.dims(), 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(b_2d.dims(), 0, false); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, false); + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, false); T alpha = static_cast(1.0); blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0)); } @@ -173,8 +173,8 @@ class FusedFeedForwardKernel : public framework::OpKernel { dropout2_out->mutable_data(place); auto x_dim = x->dims(); - auto mat_dim_x = - math::CreateMatrixDescriptor(RowMatrixFromVector(x_dim), 0, false); + auto mat_dim_x = pten::funcs::CreateMatrixDescriptor( + RowMatrixFromVector(x_dim), 0, false); auto dim = linear1_weight->dims(); int d_model = dim[0]; @@ -197,12 +197,13 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { const framework::Tensor& d_out, const framework::Tensor& a, const framework::Tensor& b, framework::Tensor* d_a, framework::Tensor* d_b) const { - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto a_2d = FoldInitDims(a); auto b_2d = FoldInitDims(b); - auto mat_dim_a = math::CreateMatrixDescriptor(a_2d.dims(), 0, true); - auto mat_dim_b = math::CreateMatrixDescriptor(b_2d.dims(), 0, true); - auto mat_dim_dout = math::CreateMatrixDescriptor(d_out.dims(), 0, false); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, true); + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, true); + auto mat_dim_dout = + pten::funcs::CreateMatrixDescriptor(d_out.dims(), 0, false); T alpha = static_cast(1.0); blas.MatMul(d_out, mat_dim_dout, b, mat_dim_b, alpha, d_a, T(0)); blas.MatMul(a, mat_dim_a, d_out, 
mat_dim_dout, alpha, d_b, T(0)); @@ -403,8 +404,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { d_linear2_weight->mutable_data(place); auto x_dim = x.dims(); - auto mat_dim_x = - math::CreateMatrixDescriptor(RowMatrixFromVector(x_dim), 0, false); + auto mat_dim_x = pten::funcs::CreateMatrixDescriptor( + RowMatrixFromVector(x_dim), 0, false); auto linear1_weight_dim = linear1_weight.dims(); int d_model = linear1_weight_dim[0]; diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index e0ecd2cab535a..92904ed126a8e 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -18,9 +18,9 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/jit/kernels.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -295,7 +295,7 @@ class FusionGRUKernel : public framework::OpKernel { const T* h0_data = h0 ? h0->data() : nullptr; const T* wh_state_data = wh_data + D * D2; T* hidden_out_data = hidden_out->mutable_data(place); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto& dev_ctx = ctx.template device_context(); math::FCFunctor fc; @@ -367,7 +367,7 @@ class FusionGRUKernel : public framework::OpKernel { T* batched_out_data = batched_out->mutable_data(place); hidden_out->mutable_data(place); auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); math::LoDTensor2BatchFunctor to_batch; math::FCFunctor fc; diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 42bf784b2af4f..7834c91a6595f 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -15,9 +15,9 @@ limitations under the License. 
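For reference, the descriptor-based MatMul calls in the fused feedforward kernels implement the usual linear-layer algebra: with Out = A*B, the backward pass is dA = dOut*B^T and dB = A^T*dOut, which is exactly what the transposed matrix descriptors express. A framework-free sketch of that algebra (plain C++ with a naive row-major matmul; the helper name and layout are illustrative, not Paddle APIs):

    // Standalone illustration of the matmul forward/backward used by the
    // fused feedforward kernels. Naive row-major GEMM, no Paddle dependencies.
    #include <cstdio>
    #include <vector>

    // C[m x n] = op(A) * op(B), where op() optionally transposes the stored matrix.
    static void matmul(const std::vector<double>& A, bool transA,
                       const std::vector<double>& B, bool transB,
                       int m, int n, int k, std::vector<double>* C) {
      C->assign(static_cast<size_t>(m) * n, 0.0);
      for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
          for (int p = 0; p < k; ++p) {
            double a = transA ? A[p * m + i] : A[i * k + p];
            double b = transB ? B[j * k + p] : B[p * n + j];
            (*C)[i * n + j] += a * b;
          }
    }

    int main() {
      const int M = 2, K = 3, N = 2;
      std::vector<double> a = {1, 2, 3, 4, 5, 6};    // M x K
      std::vector<double> b = {1, 0, 0, 1, 1, 1};    // K x N
      std::vector<double> out, d_out(M * N, 1.0), d_a, d_b;

      matmul(a, false, b, false, M, N, K, &out);     // forward: Out = A * B
      matmul(d_out, false, b, true, M, K, N, &d_a);  // dA = dOut * B^T
      matmul(a, true, d_out, false, K, N, M, &d_b);  // dB = A^T * dOut
      std::printf("dA[0]=%g dB[0]=%g\n", d_a[0], d_b[0]);
      return 0;
    }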
*/ #include "paddle/fluid/operators/fused/fusion_lstm_op.h" #include #include "paddle/fluid/operators/jit/kernels.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -343,7 +343,7 @@ class FuisonLSTMKernel : public framework::OpKernel { T* xx_data = xx->mutable_data(place); T* h_out_data = hidden_out->mutable_data(place); T* c_out_data = cell_out->mutable_data(place); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto& dev_ctx = ctx.template device_context(); math::FCFunctor fc; @@ -423,7 +423,7 @@ class FuisonLSTMKernel : public framework::OpKernel { math::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); math::FCFunctor fc; if (M > D4) { fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data()); diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc index c5a291f10b2ea..5432d0144f9eb 100644 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc @@ -15,8 +15,8 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h" #include // for min, max #include -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc index 2f52ee226bc5f..a9305f3c6d6ec 100644 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc @@ -14,10 +14,10 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h" #include -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/platform/cpu_info.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -209,7 +209,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { T* out_data = out->mutable_data(ctx.GetPlace()); T* fc_out_data = fc_out->mutable_data(ctx.GetPlace()); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto& dev_ctx = ctx.template device_context(); math::FCFunctor fc; diff --git a/paddle/fluid/operators/fused/multi_gru_op.cc b/paddle/fluid/operators/fused/multi_gru_op.cc index 922b8496441bc..4b4e3b840bb14 100644 --- a/paddle/fluid/operators/fused/multi_gru_op.cc +++ b/paddle/fluid/operators/fused/multi_gru_op.cc @@ -18,9 +18,9 @@ limitations under the License. 
*/ #include #include #include "paddle/fluid/operators/jit/kernels.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu index 69056189ac221..4fe6dccf4a184 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cu +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu @@ -17,7 +17,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -211,7 +211,8 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { auto *temp_out_data = temp_out_tensor.mutable_data(context.GetPlace()); // (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H) - auto blas = math::GetBlas(device_ctx); + auto blas = + pten::funcs::GetBlas(device_ctx); blas.MatMul(input_matrix, w_matrix, &temp_out_tensor); // temp_out_tensor.Resize(temp_out_dims); diff --git a/paddle/fluid/operators/fused/skip_layernorm_op.cu b/paddle/fluid/operators/fused/skip_layernorm_op.cu index 74cd9127711b1..2ee2177e1b431 100644 --- a/paddle/fluid/operators/fused/skip_layernorm_op.cu +++ b/paddle/fluid/operators/fused/skip_layernorm_op.cu @@ -17,7 +17,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/gelu_op.h b/paddle/fluid/operators/gelu_op.h index a913b8a111279..048c27f201794 100644 --- a/paddle/fluid/operators/gelu_op.h +++ b/paddle/fluid/operators/gelu_op.h @@ -20,8 +20,8 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" @@ -63,13 +63,13 @@ struct GeluFunctor { int n = std::min(x.size(), out.size()); std::memset(out_data, 0, n * sizeof(T)); - math::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, out_data, - 1); - math::CBlas::VMERF(n, out_data, out_data, VML_LA); + pten::funcs::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, + out_data, 1); + pten::funcs::CBlas::VMERF(n, out_data, out_data, VML_LA); for (int i = 0; i < n; i++) { out_data[i] += static_cast(1); } - math::CBlas::VMUL(n, x_data, out_data, out_data); + pten::funcs::CBlas::VMUL(n, x_data, out_data, out_data); for (int i = 0; i < n; i++) { out_data[i] *= static_cast(0.5); } @@ -138,24 +138,25 @@ struct GeluGradFunctor { std::memset(second, 0, n * sizeof(T)); // first = (0.5 * (1 + erf(x / sqrt(2)))) - math::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, first, 1); - math::CBlas::VMERF(n, first, first, VML_LA); + pten::funcs::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, + first, 1); + pten::funcs::CBlas::VMERF(n, first, first, VML_LA); for (int i = 0; i < n; i++) { first[i] += static_cast(1); } - math::CBlas::SCAL(n, static_cast(0.5), first, 1); + pten::funcs::CBlas::SCAL(n, static_cast(0.5), first, 1); // second = (0.5 * 2/sqrt(pi) * 1/sqrt(2) * x * exp(-0.5 * x^2)) - math::CBlas::VSQUARE(n, x_data, second); - math::CBlas::SCAL(n, -static_cast(0.5), second, 1); - math::CBlas::VEXP(n, second, second); - math::CBlas::VMUL(n, x_data, second, second); - math::CBlas::SCAL(n, static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2), - second, 1); + pten::funcs::CBlas::VSQUARE(n, x_data, second); + pten::funcs::CBlas::SCAL(n, -static_cast(0.5), second, 1); + pten::funcs::CBlas::VEXP(n, second, second); + pten::funcs::CBlas::VMUL(n, x_data, second, second); + pten::funcs::CBlas::SCAL( + n, static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2), second, 1); // dx = dout * (first + second); - math::CBlas::VADD(n, first, second, first); - math::CBlas::VMUL(n, dout_data, first, dx_data); + pten::funcs::CBlas::VADD(n, first, second, first); + pten::funcs::CBlas::VMUL(n, dout_data, first, dx_data); std::free(first); std::free(second); diff --git a/paddle/fluid/operators/group_norm_op.h b/paddle/fluid/operators/group_norm_op.h index 3fc2d413b6cef..7f0bc7b258208 100644 --- a/paddle/fluid/operators/group_norm_op.h +++ b/paddle/fluid/operators/group_norm_op.h @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 20956e3cdbbde..5e1594feecf1d 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -15,9 +15,9 @@ limitations under the License. 
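The vectorized GELU above assembles 0.5 * x * (1 + erf(x / sqrt(2))) out of CBlas AXPY/VMERF/VMUL calls, and its gradient is the sum of the two terms labelled "first" and "second" in the comments. The same expressions in scalar form, as a self-contained reference (standard <cmath> only, using the same POSIX constants as the kernel; this is not Paddle code):

    #include <cmath>
    #include <cstdio>

    // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    double gelu(double x) {
      return 0.5 * x * (1.0 + std::erf(x * M_SQRT1_2));
    }

    // d gelu / dx = 0.5 * (1 + erf(x / sqrt(2)))              ("first")
    //             + x * exp(-0.5 * x^2) / sqrt(2 * pi)        ("second")
    // Note 0.5 * M_2_SQRTPI * M_SQRT1_2 == 1 / sqrt(2 * pi).
    double gelu_grad(double x) {
      double first = 0.5 * (1.0 + std::erf(x * M_SQRT1_2));
      double second =
          x * std::exp(-0.5 * x * x) * 0.5 * M_2_SQRTPI * M_SQRT1_2;
      return first + second;
    }

    int main() {
      std::printf("gelu(1)=%f dgelu(1)=%f\n", gelu(1.0), gelu_grad(1.0));
      return 0;
    }

The kernel then multiplies (first + second) by dout element-wise to obtain dx, matching the final VADD/VMUL pair.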
*/ #include "paddle/fluid/operators/gru_op.h" #include #include -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" #include "paddle/fluid/operators/math/detail/gru_kernel.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" DECLARE_int32(paddle_num_threads); @@ -355,7 +355,7 @@ class GRUCPUKernel : public framework::OpKernel { #ifdef PADDLE_WITH_MKLML // use MKL packed to speedup GEMM if (FLAGS_paddle_num_threads >= 4) { - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, frame_size * 2 /*width of weight*/, frame_size /*height of height*/); diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h index b727da4ae0cd3..951441677c799 100644 --- a/paddle/fluid/operators/gru_unit_op.h +++ b/paddle/fluid/operators/gru_unit_op.h @@ -17,8 +17,8 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/place.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -87,7 +87,7 @@ class GRUUnitKernel : public framework::OpKernel { const T* weight_data = weight->data(); T* gate_data = gate->data(); T* reset_hidden_prev_data = reset_hidden_prev->data(); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); blas.GEMM(false, false, batch_size, 2 * frame_size, frame_size, 1, hidden_prev_data, frame_size, weight_data, frame_size * 2, 1, gate_data, frame_size * 3); @@ -204,7 +204,7 @@ class GRUUnitGradKernel : public framework::OpKernel { d_g.slice(c_offsets, extents), d_h * u); } // backward for reset_hidden_prev - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, gate_grad_data + frame_size * 2, frame_size * 3, weight_data + frame_size * frame_size * 2, frame_size, 0, diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 5734e247f4dfc..59bd716bc79f1 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -166,7 +166,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { // softrelu derivative - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto* pre_out_grad_data = pre_out_grad.data(); auto* pre_out_data = pre_out.template data(); diff --git a/paddle/fluid/operators/index_select_op.h b/paddle/fluid/operators/index_select_op.h index 5fee9bfd8547c..7ab62a07b4a5f 100644 --- a/paddle/fluid/operators/index_select_op.h +++ b/paddle/fluid/operators/index_select_op.h @@ -15,7 +15,7 @@ #pragma once #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -141,7 +141,7 @@ struct IndexSelectAdd< typename std::enable_if::value>::type> { void operator()(const framework::ExecutionContext& ctx, int slice_size, const T* src_pointer, const T* p_pointer, T* dist_pointer) { - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer); } }; diff --git 
a/paddle/fluid/operators/inverse_op.h b/paddle/fluid/operators/inverse_op.h index c1859a26f360b..bfe449800c733 100644 --- a/paddle/fluid/operators/inverse_op.h +++ b/paddle/fluid/operators/inverse_op.h @@ -15,8 +15,8 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/matrix_inverse.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -48,19 +48,22 @@ class InverseGradKernel : public framework::OpKernel { if (a_grad) { a_grad->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); auto& dev_ctx = context.template device_context(); framework::Tensor tmp_out = context.AllocateTmpTensor(a_inv->dims(), dev_ctx); auto mat_dim_a0 = - math::CreateMatrixDescriptor(a_inv_grad->dims(), 0, false); - auto mat_dim_b0 = math::CreateMatrixDescriptor(a_inv->dims(), 0, true); + pten::funcs::CreateMatrixDescriptor(a_inv_grad->dims(), 0, false); + auto mat_dim_b0 = + pten::funcs::CreateMatrixDescriptor(a_inv->dims(), 0, true); blas.MatMul(*a_inv_grad, mat_dim_a0, *a_inv, mat_dim_b0, T(1), &tmp_out, T(0)); - auto mat_dim_a1 = math::CreateMatrixDescriptor(a_inv->dims(), 0, true); - auto mat_dim_b1 = math::CreateMatrixDescriptor(tmp_out.dims(), 0, false); + auto mat_dim_a1 = + pten::funcs::CreateMatrixDescriptor(a_inv->dims(), 0, true); + auto mat_dim_b1 = + pten::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false); blas.MatMul(*a_inv, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), a_grad, T(0)); } } diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index b7916f44d3c33..bdbf91119d393 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \ !defined(__OSX__) #include "paddle/fluid/operators/jit/kernels.h" @@ -61,7 +61,7 @@ class RowwiseMean2D { } void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, framework::Tensor* out) { - math::GetBlas(context).GEMV( + pten::funcs::GetBlas(context).GEMV( false, left_, right_, 1., input.data(), divisor_.data(), 0., out->data()); } @@ -108,7 +108,7 @@ class ColwiseSum2D { void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, framework::Tensor* out) { - math::GetBlas(context).GEMV( + pten::funcs::GetBlas(context).GEMV( true, left_, right_, 1., input.data(), divisor_.data(), 0., out->data()); } diff --git a/paddle/fluid/operators/lookup_table_dequant_op.h b/paddle/fluid/operators/lookup_table_dequant_op.h index 475d0922ccc69..adb1fab741fe9 100644 --- a/paddle/fluid/operators/lookup_table_dequant_op.h +++ b/paddle/fluid/operators/lookup_table_dequant_op.h @@ -22,7 +22,7 @@ limitations under the License. 
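The two MatMul calls in InverseGradKernel follow from differentiating A * A^-1 = I, which gives dA^-1 = -A^-1 * dA * A^-1 and hence grad_A = -(A^-1)^T * grad_out * (A^-1)^T; the kernel computes tmp_out = grad_out * (A^-1)^T first and then scales the second product by T(-1). A scalar sanity check of the same identity (illustrative only, not Paddle code):

    // d(a^-1)/da = -a^-2, i.e. grad_a = -(a^-1) * grad_out * (a^-1) in 1x1 form.
    #include <cstdio>

    int main() {
      double a = 3.0, grad_out = 1.0, eps = 1e-6;
      double analytic = -grad_out / (a * a);
      double numeric = (1.0 / (a + eps) - 1.0 / (a - eps)) / (2.0 * eps);
      std::printf("analytic=%.8f numeric=%.8f\n", analytic, numeric);
      return 0;
    }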
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows_utils.h" #include "paddle/fluid/framework/var_type_traits.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index c0c6809656587..e05b333fa0e95 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows_utils.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -109,8 +109,8 @@ class LookupTableKernel : public framework::OpKernel { memcpy(output + i * row_width, table + id_index * row_width, row_width * sizeof(T)); } else { - auto blas = - math::GetBlas(context); + auto blas = pten::funcs::GetBlas( + context); blas.VCOPY(row_width, table + id_index * row_width, output + i * row_width); } @@ -137,7 +137,8 @@ class LookupTableKernel : public framework::OpKernel { memcpy(output + i * row_width, table + id_index * row_width, row_width * sizeof(T)); } else { - auto blas = math::GetBlas(context); + auto blas = + pten::funcs::GetBlas(context); blas.VCOPY(row_width, table + id_index * row_width, output + i * row_width); } diff --git a/paddle/fluid/operators/lookup_table_v2_op.h b/paddle/fluid/operators/lookup_table_v2_op.h index 88149f53ac7e1..39ba4d5ac619f 100644 --- a/paddle/fluid/operators/lookup_table_v2_op.h +++ b/paddle/fluid/operators/lookup_table_v2_op.h @@ -22,7 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows_utils.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -124,7 +124,8 @@ struct LookupTableV2CPUFunctor { memcpy(output + i * row_width, table + id_index * row_width, row_width * sizeof(T)); } else { - auto blas = math::GetBlas(context_); + auto blas = + pten::funcs::GetBlas(context_); blas.VCOPY(row_width, table + id_index * row_width, output + i * row_width); } diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index bee8b5396af5f..cedb74c27ed8e 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -16,7 +16,7 @@ limitations under the License. 
*/ #include #include #include -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" @@ -35,7 +35,7 @@ struct LRNFunctor { framework::Tensor* mid, int N, int C, int H, int W, int n, T k, T alpha, T beta, const DataLayout data_layout) { auto place = ctx.GetPlace(); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); pten::funcs::Transpose transpose; auto& dev_ctx = ctx.template device_context(); Tensor in_transpose, mid_transpose, out_transpose; diff --git a/paddle/fluid/operators/lstm_op.h b/paddle/fluid/operators/lstm_op.h index df94952a9a693..ea63eea8f8072 100644 --- a/paddle/fluid/operators/lstm_op.h +++ b/paddle/fluid/operators/lstm_op.h @@ -15,10 +15,10 @@ limitations under the License. */ #pragma once #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -128,7 +128,7 @@ class LSTMKernel : public framework::OpKernel { auto cand_act = math::detail::GetActivationType( ctx.Attr("candidate_activation")); - auto blas = math::GetBlas(device_ctx); + auto blas = pten::funcs::GetBlas(device_ctx); for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -302,7 +302,7 @@ class LSTMGradKernel : public framework::OpKernel { auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; - auto blas = math::GetBlas(device_ctx); + auto blas = pten::funcs::GetBlas(device_ctx); for (int n = static_cast(num_batch) - 1; n >= 0; n--) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); diff --git a/paddle/fluid/operators/lstmp_op.h b/paddle/fluid/operators/lstmp_op.h index c63184f76e702..82f508b73c2c6 100644 --- a/paddle/fluid/operators/lstmp_op.h +++ b/paddle/fluid/operators/lstmp_op.h @@ -18,12 +18,12 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/transform.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -185,7 +185,7 @@ class LSTMPKernel : public framework::OpKernel { auto proj_act = math::detail::GetActivationType( ctx.Attr("proj_activation")); auto& place = *ctx.template device_context().eigen_device(); - auto blas = math::GetBlas(device_ctx); + auto blas = pten::funcs::GetBlas(device_ctx); for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -405,7 +405,7 @@ class LSTMPGradKernel : public framework::OpKernel { auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; - auto blas = math::GetBlas(device_ctx); + auto blas = pten::funcs::GetBlas(device_ctx); for (int n = static_cast(num_batch) - 1; n >= 0; n--) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); diff --git a/paddle/fluid/operators/lstsq_op.h b/paddle/fluid/operators/lstsq_op.h index f39d65d681f2f..a6bb94c34c854 100644 --- a/paddle/fluid/operators/lstsq_op.h +++ b/paddle/fluid/operators/lstsq_op.h @@ -19,13 +19,13 @@ #include #include "paddle/fluid/operators/eig_op.h" #include "paddle/fluid/operators/math/eigen_values_vectors.h" -#include "paddle/fluid/operators/math/lapack_function.h" #include "paddle/fluid/operators/math/matrix_solve.h" #include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/operators/triangular_solve_op.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/pten/kernels/funcs/complex_functors.h" +#include "paddle/pten/kernels/funcs/lapack/lapack_function.h" #include "paddle/pten/kernels/funcs/math_function.h" #define EPSILON 1e-6 @@ -153,20 +153,21 @@ class LstsqCPUKernel : public framework::OpKernel { int iwkopt = 0; if (driver == LapackDriverType::Gels) { - math::lapackGels('N', m, n, nrhs, x_vector, lda, y_vector, ldb, &wkopt, - lwork, &info); + pten::funcs::lapackGels('N', m, n, nrhs, x_vector, lda, y_vector, ldb, + &wkopt, lwork, &info); } else if (driver == LapackDriverType::Gelsd) { - math::lapackGelsd(m, n, nrhs, x_vector, lda, y_vector, ldb, s_working_ptr, - static_cast(rcond), &rank_32, &wkopt, lwork, - &rwkopt, &iwkopt, &info); + pten::funcs::lapackGelsd(m, n, nrhs, x_vector, lda, y_vector, ldb, + s_working_ptr, static_cast(rcond), + &rank_32, &wkopt, lwork, &rwkopt, &iwkopt, + &info); } else if (driver == LapackDriverType::Gelsy) { - math::lapackGelsy(m, n, nrhs, x_vector, lda, y_vector, ldb, jpvt_data, - static_cast(rcond), &rank_32, &wkopt, lwork, - &rwkopt, &info); + pten::funcs::lapackGelsy(m, n, nrhs, x_vector, lda, y_vector, ldb, + jpvt_data, static_cast(rcond), + &rank_32, &wkopt, lwork, &rwkopt, &info); } else if (driver == LapackDriverType::Gelss) { - math::lapackGelss(m, n, nrhs, x_vector, lda, y_vector, ldb, s_working_ptr, - static_cast(rcond), &rank_32, &wkopt, lwork, - &rwkopt, &info); + pten::funcs::lapackGelss(m, n, nrhs, x_vector, lda, y_vector, ldb, + s_working_ptr, static_cast(rcond), + &rank_32, &wkopt, lwork, &rwkopt, 
&info); } lwork = std::max(1, static_cast(pten::funcs::Real(wkopt))); @@ -206,20 +207,21 @@ class LstsqCPUKernel : public framework::OpKernel { s_working_ptr = s_working_ptr ? &s_data[i * s_stride] : nullptr; if (driver == LapackDriverType::Gels) { - math::lapackGels('N', m, n, nrhs, x_input, lda, y_input, ldb, work_data, - lwork, &info); + pten::funcs::lapackGels('N', m, n, nrhs, x_input, lda, y_input, ldb, + work_data, lwork, &info); } else if (driver == LapackDriverType::Gelsd) { - math::lapackGelsd(m, n, nrhs, x_input, lda, y_input, ldb, s_working_ptr, - static_cast(rcond), &rank_32, work_data, - lwork, rwork_data, iwork_data, &info); + pten::funcs::lapackGelsd(m, n, nrhs, x_input, lda, y_input, ldb, + s_working_ptr, static_cast(rcond), + &rank_32, work_data, lwork, rwork_data, + iwork_data, &info); } else if (driver == LapackDriverType::Gelsy) { - math::lapackGelsy(m, n, nrhs, x_input, lda, y_input, ldb, jpvt_data, - static_cast(rcond), &rank_32, work_data, - lwork, rwork_data, &info); + pten::funcs::lapackGelsy(m, n, nrhs, x_input, lda, y_input, ldb, + jpvt_data, static_cast(rcond), + &rank_32, work_data, lwork, rwork_data, &info); } else if (driver == LapackDriverType::Gelss) { - math::lapackGelss(m, n, nrhs, x_input, lda, y_input, ldb, s_working_ptr, - static_cast(rcond), &rank_32, work_data, - lwork, rwork_data, &info); + pten::funcs::lapackGelss(m, n, nrhs, x_input, lda, y_input, ldb, + s_working_ptr, static_cast(rcond), + &rank_32, work_data, lwork, rwork_data, &info); } PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/operators/lu_op.cc b/paddle/fluid/operators/lu_op.cc index aff6a77762fa3..a51c0c1898a29 100644 --- a/paddle/fluid/operators/lu_op.cc +++ b/paddle/fluid/operators/lu_op.cc @@ -142,8 +142,8 @@ class LUKernel : public framework::OpKernel { auto out_data_item = &out_data[b * m * n]; int *info_data_item = &info_data[b]; int *ipiv_data_item = &ipiv_data[b * std::min(m, n)]; - math::lapackLu(m, n, out_data_item, lda, ipiv_data_item, - info_data_item); + pten::funcs::lapackLu(m, n, out_data_item, lda, ipiv_data_item, + info_data_item); } *out = helper.Transpose(*out); } diff --git a/paddle/fluid/operators/lu_op.h b/paddle/fluid/operators/lu_op.h index 0d05d766e67fb..aedae9f93458d 100644 --- a/paddle/fluid/operators/lu_op.h +++ b/paddle/fluid/operators/lu_op.h @@ -15,11 +15,11 @@ limitations under the License. 
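The lapackGels/Gelsd/Gelsy/Gelss calls above use LAPACK's standard two-phase workspace protocol: a first call with lwork = -1 only writes the optimal workspace size into the first work element, then the kernel allocates that much memory and calls again to solve. A minimal standalone illustration of the same protocol (assumes linking against a Fortran LAPACK that exports dgels_; this is not the Paddle wrapper):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    extern "C" void dgels_(const char* trans, const int* m, const int* n,
                           const int* nrhs, double* a, const int* lda,
                           double* b, const int* ldb, double* work,
                           const int* lwork, int* info);

    int main() {
      int m = 3, n = 2, nrhs = 1, lda = 3, ldb = 3, info = 0;
      // Column-major A (3x2) and right-hand side b (3x1): min ||A x - b||_2.
      std::vector<double> a = {1, 1, 1, 0, 1, 2};
      std::vector<double> b = {6, 0, 0};

      // Phase 1: lwork = -1 -> dgels_ only reports the optimal size in wkopt.
      double wkopt = 0.0;
      int lwork = -1;
      dgels_("N", &m, &n, &nrhs, a.data(), &lda, b.data(), &ldb, &wkopt,
             &lwork, &info);

      // Phase 2: allocate the reported workspace and solve for real.
      lwork = std::max(1, static_cast<int>(wkopt));
      std::vector<double> work(lwork);
      dgels_("N", &m, &n, &nrhs, a.data(), &lda, b.data(), &ldb, work.data(),
             &lwork, &info);
      // The first n entries of b now hold the least-squares solution.
      std::printf("info=%d x=(%.3f, %.3f)\n", info, b[0], b[1]);
      return 0;
    }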
*/ #pragma once #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/lapack_function.h" #include "paddle/fluid/operators/set_value_op.h" #include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/operators/triangular_solve_op.h" #include "paddle/fluid/operators/tril_triu_op.h" +#include "paddle/pten/kernels/funcs/lapack/lapack_function.h" #include "paddle/pten/kernels/math_kernel.h" namespace paddle { @@ -489,7 +489,7 @@ class LUGradKernel : public framework::OpKernel { const auto& dev_ctx = ctx.template device_context(); math::DeviceIndependenceTensorOperations helper(ctx); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto xdims = xin->dims(); int xrank = xdims.size(); @@ -519,9 +519,9 @@ class LUGradKernel : public framework::OpKernel { phi_L.mutable_data(ctx.GetPlace()); phi_U.Resize(UmHdims); phi_U.mutable_data(ctx.GetPlace()); - auto mat_dim_l = math::CreateMatrixDescriptor(LmHdims, 0, false); - auto mat_dim_u = math::CreateMatrixDescriptor(UmHdims, 0, false); - auto mat_dim_g = math::CreateMatrixDescriptor(graddims, 0, false); + auto mat_dim_l = pten::funcs::CreateMatrixDescriptor(LmHdims, 0, false); + auto mat_dim_u = pten::funcs::CreateMatrixDescriptor(UmHdims, 0, false); + auto mat_dim_g = pten::funcs::CreateMatrixDescriptor(graddims, 0, false); blas.MatMul(L_narrow_mH, mat_dim_l, grad_narrow, mat_dim_g, static_cast(1), &phi_L, static_cast(0)); @@ -567,10 +567,10 @@ class LUGradKernel : public framework::OpKernel { Tensor_Conj(dev_ctx, U_complement_mH, &U_complement_mH); - auto mat_dim_g = - math::CreateMatrixDescriptor(U_grad_complement.dims(), 0, false); - auto mat_dim_u = - math::CreateMatrixDescriptor(U_complement_mH.dims(), 0, false); + auto mat_dim_g = pten::funcs::CreateMatrixDescriptor( + U_grad_complement.dims(), 0, false); + auto mat_dim_u = pten::funcs::CreateMatrixDescriptor( + U_complement_mH.dims(), 0, false); auto phidims = UmHdims; phidims[UmHdims.size() - 2] = k; phidims[UmHdims.size() - 1] = k; @@ -623,8 +623,10 @@ class LUGradKernel : public framework::OpKernel { triangular_solve(dev_ctx, L_narrow_mH, psi, &psi_tmp, true, false, true); - auto mat_dim_p = math::CreateMatrixDescriptor(Pmat.dims(), 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(psi_tmp.dims(), 0, false); + auto mat_dim_p = + pten::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false); + auto mat_dim_b = + pten::funcs::CreateMatrixDescriptor(psi_tmp.dims(), 0, false); blas.MatMul(Pmat, mat_dim_p, psi_tmp, mat_dim_b, static_cast(1), dx, static_cast(0)); } else { @@ -636,10 +638,10 @@ class LUGradKernel : public framework::OpKernel { framework::Tensor L_complement_mH = helper.Transpose(L_complement); Tensor_Conj(dev_ctx, L_complement_mH, &L_complement_mH); - auto mat_dim_g = - math::CreateMatrixDescriptor(L_grad_complement.dims(), 0, false); + auto mat_dim_g = pten::funcs::CreateMatrixDescriptor( + L_grad_complement.dims(), 0, false); auto mat_dim_u = - math::CreateMatrixDescriptor(L_complement_mH.dims(), 0, false); + pten::funcs::CreateMatrixDescriptor(L_complement_mH.dims(), 0, false); auto phidims = LmHdims; phidims[LmHdims.size() - 2] = k; phidims[LmHdims.size() - 1] = k; @@ -685,8 +687,10 @@ class LUGradKernel : public framework::OpKernel { psi_tmp.Resize(psi.dims()); psi_tmp.mutable_data(ctx.GetPlace()); - auto mat_dim_p = math::CreateMatrixDescriptor(Pmat.dims(), 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(psi.dims(), 0, false); + auto mat_dim_p = + 
pten::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false); + auto mat_dim_b = + pten::funcs::CreateMatrixDescriptor(psi.dims(), 0, false); blas.MatMul(Pmat, mat_dim_p, psi, mat_dim_b, static_cast(1), &psi_tmp, static_cast(0)); psi_tmp = helper.Transpose(psi_tmp); diff --git a/paddle/fluid/operators/match_matrix_tensor_op.cc b/paddle/fluid/operators/match_matrix_tensor_op.cc index e95aef8eb563f..b7f67041b0308 100644 --- a/paddle/fluid/operators/match_matrix_tensor_op.cc +++ b/paddle/fluid/operators/match_matrix_tensor_op.cc @@ -249,7 +249,7 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel { memset(bottom_l_trans_data, 0.0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T)); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); call_gemm(blas, CblasNoTrans, CblasNoTrans, x->dims()[0], dim_t * dim_in, dim_in, 1.0f, bottom_l_data, t_data, 0.0f, bottom_l_trans_data); @@ -262,7 +262,7 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel { const auto* l_t_data = bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in; const auto* r_data = bottom_r_data + offset_r[b] * dim_in; - auto blas_2 = math::GetBlas(ctx); + auto blas_2 = pten::funcs::GetBlas(ctx); call_gemm_with_lda(blas_2, CblasNoTrans, CblasTrans, len_l, len_r, dim_in, 1.0f, l_t_data, r_data, 0.0f, top_data, dim_t * dim_in); @@ -346,7 +346,7 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel { } } - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto* t_data = w->data(); auto* d_w = ctx.Output(framework::GradVarName("W")); diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index a97e2ecfce701..52310ab9d48eb 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -1,44 +1,5 @@ add_subdirectory(detail) -function(math_library TARGET) - # math_library is a function to create math library. - # The interface is the same as cc_library. - # But it handle split GPU/CPU code and link some common library. 
- set(cc_srcs) - set(cu_srcs) - set(hip_srcs) - set(math_common_deps device_context framework_proto enforce) - if (WITH_GPU) - if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) - list(APPEND math_common_deps cub) - else() - list(APPEND math_common_deps) - endif() - endif() - set(multiValueArgs DEPS) - cmake_parse_arguments(math_library "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN}) - - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) - list(APPEND cc_srcs ${TARGET}.cc) - endif() - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) - list(APPEND cu_srcs ${TARGET}.cu) - endif() - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc) - list(APPEND cu_srcs ${TARGET}.cu.cc) - endif() - - list(LENGTH cc_srcs cc_srcs_len) - if (WITH_GPU) - nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) - elseif (WITH_ROCM) - hip_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) - elseif(${cc_srcs_len} GREATER 0) - cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) - endif() -endfunction() - if (WITH_ASCEND_CL) cc_library(beam_search_npu SRCS beam_search_npu.cc DEPS npu_op_runner) endif() @@ -59,9 +20,6 @@ math_library(sampler DEPS generator) math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) - -cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) -# math_library(math_function DEPS blas dense_tensor tensor) math_library(maxouting) math_library(pooling) @@ -82,8 +40,6 @@ else() math_library(beam_search DEPS math_function) endif() math_library(fc DEPS blas) -math_library(lapack_function DEPS dynload_lapack) - math_library(matrix_bit_code) math_library(unpooling) diff --git a/paddle/fluid/operators/math/bert_encoder_functor.cu b/paddle/fluid/operators/math/bert_encoder_functor.cu index b9b209646dbcf..1c94fa67f8b42 100644 --- a/paddle/fluid/operators/math/bert_encoder_functor.cu +++ b/paddle/fluid/operators/math/bert_encoder_functor.cu @@ -17,8 +17,8 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_cuda_utils.h" namespace paddle { @@ -502,7 +502,7 @@ inline void MatMulWithHeadQK(const platform::CUDADeviceContext &context, typedef typename CUDATypeTraits::TYPE run_type; auto blas = - operators::math::GetBlas(context); + pten::funcs::GetBlas(context); auto stream = context.stream(); blas.BatchedGEMM( @@ -568,7 +568,7 @@ inline void MatMulWithHeadQKV(const platform::CUDADeviceContext &context, typedef typename CUDATypeTraits::TYPE run_type; auto blas = - operators::math::GetBlas(context); + pten::funcs::GetBlas(context); auto stream = context.stream(); CBLAS_TRANSPOSE transA = !qk_trans ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = !v_trans ? CblasNoTrans : CblasTrans; diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h deleted file mode 100644 index 0e6b63be90ef6..0000000000000 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ /dev/null @@ -1,1804 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/platform/dynload/cublas.h" -#include "paddle/pten/kernels/funcs/math_function.h" - -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/pten/backends/gpu/gpu_context.h" - -DECLARE_bool(enable_cublas_tensor_op_math); - -namespace paddle { -namespace operators { -namespace math { - -template -struct CUBlas; - -template <> -struct CUBlas { - template - static void GEMM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasSgemm(args...)); - } - - template - static void AXPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasSaxpy(args...)); - } - - template - static void SCAL(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasSscal(args...)); - } - - template - static void VCOPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasScopy(args...)); - } - - template - static void GEMV(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasSgemv(args...)); - } - - template - static void GEMM_STRIDED_BATCH(ARGS... args) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasSgemmStridedBatched(args...)); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "SgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(platform::CUDADeviceContext *dev_ctx, - cublasOperation_t transa, cublasOperation_t transb, int m, - int n, int k, const float *alpha, const void *A, - cudaDataType_t Atype, int lda, const void *B, - cudaDataType_t Btype, int ldb, const float *beta, void *C, - cudaDataType_t Ctype, int ldc) { -// Because the gcc 4.8 doesn't expand template parameter pack that -// appears in a lambda-expression, I can not use template parameter pack -// here. -#if CUDA_VERSION >= 8000 - VLOG(5) << "use_tensor_op_math: " - << (dev_ctx->tensor_core_available() ? "True" : "False"); - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasSgemmEx( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc)); - }); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "cublasSgemmEx is not supported on cuda <= 7.5")); -#endif - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. 
- // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(pten::GPUContext *dev_ctx, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const void *A, cudaDataType_t Atype, - int lda, const void *B, cudaDataType_t Btype, int ldb, - const float *beta, void *C, cudaDataType_t Ctype, - int ldc) { -// Because the gcc 4.8 doesn't expand template parameter pack that -// appears in a lambda-expression, I can not use template parameter pack -// here. -#if CUDA_VERSION >= 8000 - VLOG(5) << "use_tensor_op_math: " - << (dev_ctx->tensor_core_available() ? "True" : "False"); - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasSgemmEx( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc)); - }); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "cublasSgemmEx is not supported on cuda <= 7.5")); -#endif - } - - template - static void TRSM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasStrsm(args...)); - } - - template - static void GETRF_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasSgetrfBatched(args...)); - } - - template - static void GETRI_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasSgetriBatched(args...)); - } - - template - static void MATINV_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasSmatinvBatched(args...)); - } - - template - static void GETRS_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasSgetrsBatched(args...)); - } - - template - static void TRSM_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasStrsmBatched(args...)); - } -}; - -template <> -struct CUBlas { - template - static void GEMM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasDgemm(args...)); - } - - template - static void AXPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasDaxpy(args...)); - } - - template - static void SCAL(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasDscal(args...)); - } - - template - static void VCOPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasDcopy(args...)); - } - - template - static void GEMV(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasDgemv(args...)); - } - - template - static void GEMM_STRIDED_BATCH(ARGS... args) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasDgemmStridedBatched(args...)); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "DgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - template - static void GEMM_EX(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "Currently there are not cublasDgemmEx.")); - } - - template - static void TRSM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasDtrsm(args...)); - } - - template - static void GETRF_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasDgetrfBatched(args...)); - } - - template - static void GETRI_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasDgetriBatched(args...)); - } - - template - static void MATINV_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasDmatinvBatched(args...)); - } - - template - static void GETRS_BATCH(ARGS... 
args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasDgetrsBatched(args...)); - } - - template - static void TRSM_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasDtrsmBatched(args...)); - } -}; - -template <> -struct CUBlas { - using float16 = platform::float16; - - static void GEMM(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float16 *alpha, const float16 *A, int lda, - const float16 *B, int ldb, const float16 *beta, float16 *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasHgemm(handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast<__half *>(C), ldc)); - } - - static void GEMM_STRIDED_BATCH(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float16 *alpha, const float16 *A, - int lda, long long int strideA, // NOLINT - const float16 *B, // NOLINT - int ldb, long long int strideB, // NOLINT - const float16 *beta, float16 *C, int ldc, - long long int strideC, // NOLINT - int batchCount) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasHgemmStridedBatched( - handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(beta), reinterpret_cast<__half *>(C), - ldc, strideC, batchCount)); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "HgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(platform::CUDADeviceContext *dev_ctx, - cublasOperation_t transa, cublasOperation_t transb, int m, - int n, int k, const void *alpha, const void *A, - cudaDataType_t Atype, int lda, const void *B, - cudaDataType_t Btype, int ldb, const void *beta, void *C, - cudaDataType_t Ctype, int ldc, - cudaDataType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, computeType, algo)); - }); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. 
- // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(pten::GPUContext *dev_ctx, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const void *alpha, const void *A, cudaDataType_t Atype, - int lda, const void *B, cudaDataType_t Btype, int ldb, - const void *beta, void *C, cudaDataType_t Ctype, int ldc, - cudaDataType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, computeType, algo)); - }); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } -}; - -template <> -struct CUBlas> { - static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m, - int n, const platform::complex *alpha, - const platform::complex *A, int lda, - const platform::complex *B, int ldb, - const platform::complex *beta, - platform::complex *C, int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasCgemv( - handle, transa, m, n, reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc)); - } - - static void AXPY(cublasHandle_t handle, int n, - const platform::complex *alpha, - const platform::complex *X, const int incX, - platform::complex *Y, const int incY) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasCaxpy( - handle, n, reinterpret_cast(alpha), - reinterpret_cast(X), incX, - reinterpret_cast(Y), incY)); - } - - static void GEMM_STRIDED_BATCH(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const platform::complex *alpha, - const platform::complex *A, int lda, - long long int strideA, // NOLINT - const platform::complex *B, // NOLINT - int ldb, long long int strideB, // NOLINT - const platform::complex *beta, - platform::complex *C, int ldc, - long long int strideC, // NOLINT - int batchCount) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasCgemmStridedBatched( - handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(beta), - reinterpret_cast(C), ldc, strideC, batchCount)); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "CgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - static void GEMM(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const platform::complex *alpha, - const platform::complex *A, int lda, - const platform::complex *B, int ldb, - const platform::complex *beta, - platform::complex *C, int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasCgemm( - handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc)); - } - - static void TRSM(cublasHandle_t handle, cublasSideMode_t side, - cublasFillMode_t uplo, cublasOperation_t transa, - 
cublasDiagType_t diag, int m, int n, - const paddle::platform::complex *alpha, - const paddle::platform::complex *A, int lda, - paddle::platform::complex *B, int ldb) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasCtrsm( - handle, side, uplo, transa, diag, m, n, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb)); - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(platform::CUDADeviceContext *dev_ctx, - cublasOperation_t transa, cublasOperation_t transb, int m, - int n, int k, const void *alpha, const void *A, - cudaDataType_t Atype, int lda, const void *B, - cudaDataType_t Btype, int ldb, const void *beta, void *C, - cudaDataType_t Ctype, int ldc, - cudaDataType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, computeType, algo)); - }); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(pten::GPUContext *dev_ctx, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const void *alpha, const void *A, cudaDataType_t Atype, - int lda, const void *B, cudaDataType_t Btype, int ldb, - const void *beta, void *C, cudaDataType_t Ctype, int ldc, - cudaDataType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? 
"True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, computeType, algo)); - }); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } - - static void TRSM_BATCH(cublasHandle_t handle, cublasSideMode_t side, - cublasFillMode_t uplo, cublasOperation_t transa, - cublasDiagType_t diag, int m, int n, - const paddle::platform::complex *alpha, - const paddle::platform::complex **A, int lda, - paddle::platform::complex **B, int ldb, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasCtrsmBatched( - handle, side, uplo, transa, diag, m, n, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, batch_size)); - } -}; - -template <> -struct CUBlas> { - static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m, - int n, const platform::complex *alpha, - const platform::complex *A, int lda, - const platform::complex *B, int ldb, - const platform::complex *beta, - platform::complex *C, int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasZgemv( - handle, transa, m, n, reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc)); - } - - static void AXPY(cublasHandle_t handle, int n, - const platform::complex *alpha, - const platform::complex *X, const int incX, - platform::complex *Y, const int incY) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasZaxpy( - handle, n, reinterpret_cast(alpha), - reinterpret_cast(X), incX, - reinterpret_cast(Y), incY)); - } - - static void GEMM_STRIDED_BATCH(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const platform::complex *alpha, - const platform::complex *A, int lda, - long long int strideA, // NOLINT - const platform::complex *B, // NOLINT - int ldb, long long int strideB, // NOLINT - const platform::complex *beta, - platform::complex *C, int ldc, - long long int strideC, // NOLINT - int batchCount) { -#if CUDA_VERSION >= 8000 - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasZgemmStridedBatched( - handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(beta), - reinterpret_cast(C), ldc, strideC, batchCount)); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "CgemmStridedBatched is not supported on cuda <= 7.5")); -#endif - } - - static void GEMM(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const platform::complex *alpha, - const platform::complex *A, int lda, - const platform::complex *B, int ldb, - const platform::complex *beta, - platform::complex *C, int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasZgemm( - handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc)); - } - - static void TRSM(cublasHandle_t handle, cublasSideMode_t side, - cublasFillMode_t uplo, cublasOperation_t transa, - cublasDiagType_t diag, int m, int n, - const paddle::platform::complex *alpha, - const paddle::platform::complex *A, int lda, - paddle::platform::complex *B, int ldb) { - 
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasZtrsm( - handle, side, uplo, transa, diag, m, n, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb)); - } - - static void TRSM_BATCH(cublasHandle_t handle, cublasSideMode_t side, - cublasFillMode_t uplo, cublasOperation_t transa, - cublasDiagType_t diag, int m, int n, - const paddle::platform::complex *alpha, - const paddle::platform::complex **A, int lda, - paddle::platform::complex **B, int ldb, - int batch_size) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasZtrsmBatched( - handle, side, uplo, transa, diag, m, n, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, batch_size)); - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(platform::CUDADeviceContext *dev_ctx, - cublasOperation_t transa, cublasOperation_t transb, int m, - int n, int k, const void *alpha, const void *A, - cudaDataType_t Atype, int lda, const void *B, - cudaDataType_t Btype, int ldb, const void *beta, void *C, - cudaDataType_t Ctype, int ldc, - cudaDataType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, computeType, algo)); - }); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(pten::GPUContext *dev_ctx, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const void *alpha, const void *A, cudaDataType_t Atype, - int lda, const void *B, cudaDataType_t Btype, int ldb, - const void *beta, void *C, cudaDataType_t Ctype, int ldc, - cudaDataType_t computeType) { -#if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - bool use_tensor_op_math = dev_ctx->tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); -#endif // CUDA_VERSION >= 9000 - - dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, computeType, algo)); - }); -#else - PADDLE_THROW(platform::errors::Unimplemented( - "cublasGemmEx is not supported on cuda <= 7.5")); -#endif - } -}; - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, - int N, int K, T alpha, const T *A, - const T *B, T beta, T *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? 
N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - -#if CUDA_VERSION >= 8000 - if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B, - CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C, - CUDA_R_32F, N); - } else { -#endif // CUDA_VERSION >= 8000 - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, - lda, &beta, C, N); - }); - -#if CUDA_VERSION >= 8000 - } -#endif // CUDA_VERSION >= 8000 -} - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T *A, const T *B, T beta, - T *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - -#if CUDA_VERSION >= 8000 - if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B, - CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C, - CUDA_R_32F, N); - } else { -#endif // CUDA_VERSION >= 8000 - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, - lda, &beta, C, N); - }); - -#if CUDA_VERSION >= 8000 - } -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::float16 alpha, const platform::float16 *A, - const platform::float16 *B, platform::float16 beta, - platform::float16 *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas fp16 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - -#if CUDA_VERSION >= 8000 - // cublasHgemm does true FP16 computation which is slow for non-Volta - // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: - // input/output in fp16, computation in fp32, which can also be accelerated - // using tensor cores in volta GPUs. 
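All of the GEMM wrappers above follow the same device-side convention: cuBLAS expects column-major (Fortran-order) storage, so a row-major product C = A * B is issued as the column-major product C^T = B^T * A^T, with the operands and the M/N dimensions swapped. A minimal standalone sketch of that call pattern (the function and pointer names here are illustrative only, not part of this patch):

    // Row-major C[M x N] = A[M x K] * B[K x N], computed with cuBLAS by
    // issuing the column-major product C^T = B^T * A^T (operands swapped,
    // M and N swapped). Leading dimensions are the row-major row strides.
    void RowMajorSgemm(cublasHandle_t handle, int M, int N, int K,
                       const float *d_A, const float *d_B, float *d_C) {
      const float alpha = 1.0f, beta = 0.0f;
      cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
                  &alpha, d_B, /*ldb=*/N, d_A, /*lda=*/K, &beta, d_C, /*ldc=*/N);
    }

This is why the wrappers pass (cuTransB, cuTransA, N, M, K, ..., B, ldb, A, lda, ..., C, N) rather than the argument order of the row-major cblas interface.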
- auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16F, ldb, A, - CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, - &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, - N); - }); -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, platform::float16 alpha, - const platform::float16 *A, - const platform::float16 *B, - platform::float16 beta, - platform::float16 *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas fp16 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - -#if CUDA_VERSION >= 8000 - // cublasHgemm does true FP16 computation which is slow for non-Volta - // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: - // input/output in fp16, computation in fp32, which can also be accelerated - // using tensor cores in volta GPUs. - auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16F, ldb, A, - CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, - &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, - N); - }); -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::bfloat16 alpha, const platform::bfloat16 *A, - const platform::bfloat16 *B, platform::bfloat16 beta, - platform::bfloat16 *C) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 80, - platform::errors::InvalidArgument( - "cublas fp16 gemm requires GPU compute capability >= 80," - "but received %d", - context_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = context_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? 
"True" : "False"); - context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx( - handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A, - CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo)); - }); -#else - // raise error - PADDLE_THROW(platform::errors::Unimplemented( - "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); - -#endif // CUDA_VERSION >= 11000 -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, platform::bfloat16 alpha, - const platform::bfloat16 *A, - const platform::bfloat16 *B, - platform::bfloat16 beta, - platform::bfloat16 *C) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 80, - platform::errors::InvalidArgument( - "cublas bf16 gemm requires GPU compute capability >= 80," - "but received %d", - context_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; - bool use_tensor_op_math = context_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - - context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx( - handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A, - CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo)); - }); -#else - // raise error - PADDLE_THROW(platform::errors::Unimplemented( - "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); - -#endif // CUDA_VERSION >= 11000 -} - -template <> -template <> -inline void Blas::GEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::complex alpha, const platform::complex *A, - const platform::complex *B, platform::complex beta, - platform::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas complex64 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = thrust::complex(beta.real, beta.imag); - -#if CUDA_VERSION >= 8000 - // cublasHgemm does true FP16 computation which is slow for non-Volta - // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: - // input/output in fp16, computation in fp32, which can also be accelerated - // using tensor cores in volta GPUs. 
- auto &cuda_ctx = const_cast(context_); - CUBlas>::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_32F, ldb, A, - CUDA_C_32F, lda, &c_beta, C, CUDA_C_32F, N, CUDA_C_32F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas>::GEMM(handle, cuTransB, cuTransA, N, M, K, - &c_alpha, h_B, ldb, h_A, lda, - &c_beta, h_C, N); - }); -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, platform::complex alpha, - const platform::complex *A, - const platform::complex *B, - platform::complex beta, - platform::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas complex64 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = thrust::complex(beta.real, beta.imag); - -#if CUDA_VERSION >= 8000 - // cublasHgemm does true FP16 computation which is slow for non-Volta - // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: - // input/output in fp16, computation in fp32, which can also be accelerated - // using tensor cores in volta GPUs. - auto &cuda_ctx = const_cast(context_); - CUBlas>::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_32F, ldb, A, - CUDA_C_32F, lda, &c_beta, C, CUDA_C_32F, N, CUDA_C_32F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas>::GEMM(handle, cuTransB, cuTransA, N, M, K, - &c_alpha, h_B, ldb, h_A, lda, - &c_beta, h_C, N); - }); -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::complex alpha, const platform::complex *A, - const platform::complex *B, platform::complex beta, - platform::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas complex128 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = - thrust::complex(beta.real, beta.imag); - -#if CUDA_VERSION >= 8000 - // cublasHgemm does true FP16 computation which is slow for non-Volta - // GPUs. 
So use cublasGemmEx instead which does pesudo FP16 computation: - // input/output in fp16, computation in fp32, which can also be accelerated - // using tensor cores in volta GPUs. - auto &cuda_ctx = const_cast(context_); - CUBlas>::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_64F, ldb, A, - CUDA_C_64F, lda, &c_beta, C, CUDA_C_64F, N, CUDA_C_64F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas>::GEMM(handle, cuTransB, cuTransA, N, M, K, - &c_alpha, h_B, ldb, h_A, lda, - &c_beta, h_C, N); - }); -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, platform::complex alpha, - const platform::complex *A, - const platform::complex *B, - platform::complex beta, - platform::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas complex128 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = - thrust::complex(beta.real, beta.imag); - -#if CUDA_VERSION >= 8000 - // cublasHgemm does true FP16 computation which is slow for non-Volta - // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: - // input/output in fp16, computation in fp32, which can also be accelerated - // using tensor cores in volta GPUs. - auto &cuda_ctx = const_cast(context_); - CUBlas>::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_64F, ldb, A, - CUDA_C_64F, lda, &c_beta, C, CUDA_C_64F, N, CUDA_C_64F); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas>::GEMM(handle, cuTransB, cuTransA, N, M, K, - &c_alpha, h_B, ldb, h_A, lda, - &c_beta, h_C, N); - }); -#endif // CUDA_VERSION >= 8000 -} - -template <> -template -void Blas::GEMM(bool transA, bool transB, int M, - int N, int K, T alpha, const T *A, - int lda, const T *B, int ldb, - T beta, T *C, int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cuTransB = transB ? 
CUBLAS_OP_T : CUBLAS_OP_N; - -#if CUDA_VERSION >= 8000 - if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B, - CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C, - CUDA_R_32F, ldc); - } else { -#endif // CUDA_VERSION >= 8000 - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, - lda, &beta, C, ldc); - }); - -#if CUDA_VERSION >= 8000 - } -#endif // CUDA_VERSION >= 8000 -} - -template <> -template -void Blas::GEMM(bool transA, bool transB, int M, int N, int K, - T alpha, const T *A, int lda, const T *B, - int ldb, T beta, T *C, int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - -#if CUDA_VERSION >= 8000 - if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B, - CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C, - CUDA_R_32F, ldc); - } else { -#endif // CUDA_VERSION >= 8000 - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, - lda, &beta, C, ldc); - }); - -#if CUDA_VERSION >= 8000 - } -#endif // CUDA_VERSION >= 8000 -} - -template <> -template <> -inline void Blas::GEMM( - bool transA, bool transB, int M, int N, int K, platform::float16 alpha, - const platform::float16 *A, int lda, const platform::float16 *B, int ldb, - platform::float16 beta, platform::float16 *C, int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, - B, ldb, A, lda, &beta, C, ldc); - }); -} - -template <> -template <> -inline void Blas::GEMM(bool transA, bool transB, int M, int N, - int K, platform::float16 alpha, - const platform::float16 *A, int lda, - const platform::float16 *B, int ldb, - platform::float16 beta, - platform::float16 *C, int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cuTransB = transB ? 
CUBLAS_OP_T : CUBLAS_OP_N; - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, - B, ldb, A, lda, &beta, C, ldc); - }); -} - -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, - T *y) const { - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); - }); -} - -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); - }); -} - -template <> -template -void Blas::SCAL(int n, const T alpha, T *x) const { - context_.CublasCall( - [&](cublasHandle_t handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); -} - -template <> -template -void Blas::SCAL(int n, const T alpha, T *x) const { - context_.CublasCall( - [&](cublasHandle_t handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); -} - -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - context_.CublasCall( - [&](cublasHandle_t handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); -} - -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - context_.CublasCall( - [&](cublasHandle_t handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); -} - -template <> -template -void Blas::GEMV(bool trans_a, int M, int N, - T alpha, const T *A, const T *B, - T beta, T *C) const { - cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); - }); -} - -template <> -template -void Blas::GEMV(bool trans_a, int M, int N, T alpha, - const T *A, const T *B, T beta, T *C) const { - cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); - }); -} - -template <> -template <> -inline void Blas::GEMV( - bool trans_a, int M, int N, platform::float16 alpha, - const platform::float16 *A, const platform::float16 *B, - platform::float16 beta, platform::float16 *C) const { - // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it. - if (trans_a) { - this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, - alpha, B, A, beta, C); - } else { - this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, - alpha, A, B, beta, C); - } -} - -template <> -template <> -inline void Blas::GEMV(bool trans_a, int M, int N, - platform::float16 alpha, - const platform::float16 *A, - const platform::float16 *B, - platform::float16 beta, - platform::float16 *C) const { - // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it. - if (trans_a) { - this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, - alpha, B, A, beta, C); - } else { - this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, - alpha, A, B, beta, C); - } -} - -template <> -template <> -inline void Blas::GEMV( - bool trans_a, int M, int N, platform::bfloat16 alpha, - const platform::bfloat16 *A, const platform::bfloat16 *B, - platform::bfloat16 beta, platform::bfloat16 *C) const { - // Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve - // it. 
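Because cuBLAS provides no half- or bfloat16-precision GEMV, the specializations above and below reinterpret the vector as a degenerate matrix and reuse the GEMM path: y = A * x becomes GEMM(NoTrans, NoTrans, M, 1, N, alpha, A, x, beta, y), and y = A^T * x becomes GEMM(NoTrans, NoTrans, 1, N, M, alpha, x, A, beta, y). A small caller-side sketch under that assumption (the helper name is hypothetical; Blas<> is the wrapper defined in this file):

    // Hypothetical helper: emulate y[M] = A[M x N] * x[N] in float16 by
    // viewing x and y as one-column matrices and reusing the GEMM wrapper.
    template <typename DeviceContext>
    void HalfGemvViaGemm(const paddle::operators::math::Blas<DeviceContext> &blas,
                         int M, int N, paddle::platform::float16 alpha,
                         const paddle::platform::float16 *A,
                         const paddle::platform::float16 *x,
                         paddle::platform::float16 beta,
                         paddle::platform::float16 *y) {
      blas.template GEMM<paddle::platform::float16>(
          CblasNoTrans, CblasNoTrans, M, /*N=*/1, /*K=*/N, alpha, A, x, beta, y);
    }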
- if (trans_a) { - this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, - alpha, B, A, beta, C); - } else { - this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, - alpha, A, B, beta, C); - } -} - -template <> -template <> -inline void Blas::GEMV(bool trans_a, int M, int N, - platform::bfloat16 alpha, - const platform::bfloat16 *A, - const platform::bfloat16 *B, - platform::bfloat16 beta, - platform::bfloat16 *C) const { - // Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve - // it. - if (trans_a) { - this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, - alpha, B, A, beta, C); - } else { - this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, - alpha, A, B, beta, C); - } -} - -template <> -template -void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T *A, const T *B, T beta, T *C, int batchCount, - int64_t strideA, int64_t strideB) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - -#if CUDA_VERSION >= 9010 - if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || - std::is_same::value) { - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = context_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); - - auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; - context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx( - handle, cuTransB, cuTransA, N, M, K, &alpha, B, fp, ldb, strideB, A, - fp, lda, strideA, &beta, C, fp, ldc, strideC, batchCount, fp, algo)); - }); - } else { -#endif // CUDA_VERSION >= 9010 - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha, - B, ldb, strideB, A, lda, strideA, &beta, C, - ldc, strideC, batchCount); - }); - -#if CUDA_VERSION >= 9010 - } -#endif // CUDA_VERSION >= 9010 -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, T alpha, const T *A, const T *B, - T beta, T *C, int batchCount, - int64_t strideA, - int64_t strideB) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - -#if CUDA_VERSION >= 9010 - if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || - std::is_same::value) { - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = context_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); - - auto fp = std::is_same::value ? 
CUDA_R_32F : CUDA_R_16F; - context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx( - handle, cuTransB, cuTransA, N, M, K, &alpha, B, fp, ldb, strideB, A, - fp, lda, strideA, &beta, C, fp, ldc, strideC, batchCount, fp, algo)); - }); - } else { -#endif // CUDA_VERSION >= 9010 - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha, - B, ldb, strideB, A, lda, strideA, &beta, C, - ldc, strideC, batchCount); - }); - -#if CUDA_VERSION >= 9010 - } -#endif // CUDA_VERSION >= 9010 -} - -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::bfloat16 alpha, const platform::bfloat16 *A, - const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C, - int batchCount, int64_t strideA, int64_t strideB) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = context_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - - context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx( - handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, - strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc, - strideC, batchCount, CUBLAS_COMPUTE_32F, algo)); - }); -#else - // raise error - PADDLE_THROW(platform::errors::Unimplemented( - "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " - "11")); -#endif // CUDA_VERSION >= 11000 -} - -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::bfloat16 alpha, const platform::bfloat16 *A, - const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C, - int batchCount, int64_t strideA, int64_t strideB) const { -#if CUDA_VERSION >= 11000 - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t strideC = M * N; - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = context_.tensor_core_available(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? 
"True" : "False"); - - context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx( - handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, - strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc, - strideC, batchCount, CUBLAS_COMPUTE_32F, algo)); - }); -#else - // raise error - PADDLE_THROW(platform::errors::Unimplemented( - "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " - "11")); -#endif // CUDA_VERSION >= 11000 -} - -template <> -template -void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], B[k], beta, - C[k]); - } -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, T alpha, const T **A, - const T **B, T beta, T **C, - int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], B[k], beta, - C[k]); - } -} - -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::float16 alpha, const platform::float16 **A, - const platform::float16 **B, platform::float16 beta, platform::float16 **C, - int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], - B[k], beta, C[k]); - } -} - -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::float16 alpha, const platform::float16 **A, - const platform::float16 **B, platform::float16 beta, platform::float16 **C, - int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], - B[k], beta, C[k]); - } -} - -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::bfloat16 alpha, const platform::bfloat16 **A, - const platform::bfloat16 **B, platform::bfloat16 beta, - platform::bfloat16 **C, int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, - A[k], B[k], beta, C[k]); - } -} - -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::bfloat16 alpha, const platform::bfloat16 **A, - const platform::bfloat16 **B, platform::bfloat16 beta, - platform::bfloat16 **C, int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, - A[k], B[k], beta, C[k]); - } -} - -template <> -template -void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, int M, int N, - T alpha, const T *A, int lda, T *B, - int ldb) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - cublasSideMode_t cuSide = - (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; - cublasFillMode_t cuUplo = - (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; - cublasDiagType_t cuDiag = - (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::TRSM(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, - lda, B, ldb); - }); -} - -template <> -template -void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, CBLAS_DIAG diag, - int M, int N, T alpha, const T *A, int lda, - T *B, int ldb) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - cublasSideMode_t cuSide = - (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; - cublasFillMode_t cuUplo = - (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasDiagType_t cuDiag = - (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::TRSM(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, - lda, B, ldb); - }); -} - -template <> -template -void Blas::BatchedGETRF(int n, T **a, int *ipiv, - int *info, - int batch_size) const { - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedGETRF(int n, T **a, int *ipiv, int *info, - int batch_size) const { - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedGETRI(int n, const T **a, - const int *ipiv, T **a_inv, - int *info, - int batch_size) const { - PADDLE_ENFORCE_NE( - a_inv, a, - platform::errors::InvalidArgument( - "cuBLAS fuction 'cublasgetrfBatched' cannot be executed " - "in-place. The memory space of output matrix (address: %p) cannot " - "overlap memory space of input matrix (address: %p).", - a_inv, a)); - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedGETRI(int n, const T **a, const int *ipiv, - T **a_inv, int *info, - int batch_size) const { - PADDLE_ENFORCE_NE( - a_inv, a, - platform::errors::InvalidArgument( - "cuBLAS fuction 'cublasgetrfBatched' cannot be executed " - "in-place. The memory space of output matrix (address: %p) cannot " - "overlap memory space of input matrix (address: %p).", - a_inv, a)); - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedMatInv(int n, const T **a, - T **a_inv, int *info, - int batch_size) const { - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedMatInv(int n, const T **a, T **a_inv, - int *info, int batch_size) const { - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedGETRS( - CBLAS_TRANSPOSE trans, int n, int nrhs, const T **a, int lda, int *ipiv, - T **b, int ldb, int *info, int batch_size) const { - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTrans = - (trans == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRS_BATCH(handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, - batch_size); - }); -} - -template <> -template -void Blas::BatchedGETRS(CBLAS_TRANSPOSE trans, int n, - int nrhs, const T **a, int lda, - int *ipiv, T **b, int ldb, int *info, - int batch_size) const { - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTrans = - (trans == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GETRS_BATCH(handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, - batch_size); - }); -} - -template <> -template -void Blas::BatchedTRSM( - CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, CBLAS_DIAG diag, - int M, int N, T alpha, const T **A, int lda, T **B, int ldb, - int batch_size) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - cublasSideMode_t cuSide = - (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; - cublasFillMode_t cuUplo = - (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasDiagType_t cuDiag = - (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::TRSM_BATCH(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, - &alpha, A, lda, B, ldb, batch_size); - }); -} - -template <> -template -void Blas::BatchedTRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, int M, int N, T alpha, - const T **A, int lda, T **B, int ldb, - int batch_size) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - cublasSideMode_t cuSide = - (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; - cublasFillMode_t cuUplo = - (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - // use CUBLAS_OP_C (conjugate transpose) for complex - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasDiagType_t cuDiag = - (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - - context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::TRSM_BATCH(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, - &alpha, A, lda, B, ldb, batch_size); - }); -} - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h deleted file mode 100644 index 8e0075c42eb2c..0000000000000 --- a/paddle/fluid/operators/math/blas_impl.h +++ /dev/null @@ -1,1860 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
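The CPU-side blas_impl.h removed in this hunk is organised around a per-scalar-type dispatch struct: CBlas<T> forwards each call to the MKL cblas_* routine (through dynload) when PADDLE_WITH_MKLML is defined, to a plain CBLAS build otherwise, and specializations for types without a CPU BLAS routine (such as float16) simply raise an Unimplemented error. A minimal sketch of that pattern with illustrative names (MyCBlas is not the real CBlas<T> symbol):

    // Per-type traits struct that forwards variadic arguments to the matching
    // CBLAS routine, so templated kernels can call MyCBlas<T>::GEMM(...)
    // without caring about the scalar type.
    #include <cblas.h>

    template <typename T>
    struct MyCBlas;

    template <>
    struct MyCBlas<float> {
      template <typename... ARGS>
      static void GEMM(ARGS... args) { cblas_sgemm(args...); }
    };

    template <>
    struct MyCBlas<double> {
      template <typename... ARGS>
      static void GEMM(ARGS... args) { cblas_dgemm(args...); }
    };

The remainder of the hunk shows the real version of this dispatch, including the MKL-specific extensions (packed GEMM, batched GEMM, vector math) and the error-raising specializations.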
-#pragma once -#include "paddle/pten/backends/cpu/cpu_context.h" -#ifdef PADDLE_WITH_MKLML -#include -#endif - -#include -#include -#include -#include - -#include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/complex.h" -#include "paddle/pten/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { -namespace math { -namespace detail { - -template -static void axpy(int n, const T alpha, const T *x, const int incx, T *y, - const int incy) { - // Y = Y + alpha * X - while (n-- > 0) { - *y += alpha * *x; - y = y + incy; - x = x + incx; - } -} -} // namespace detail - -template -struct CBlas; - -template <> -struct CBlas { - template - static void VCOPY(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "Blas VCOPY do not supported on CPU, please check your code")); - } -}; - -template <> -struct CBlas { - template - static void VCOPY(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "Blas VCOPY do not supported on CPU, please check your code")); - } -}; - -template <> -struct CBlas { - template - static void AXPY(ARGS... args) { - detail::axpy(args...); - } - - template - static void VCOPY(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "Blas VCOPY do not supported on CPU with bfloat16," - " please check your code")); - } -}; - -#ifdef PADDLE_WITH_MKLML -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - platform::dynload::cblas_sgemm(args...); - } - - template - static float *GEMM_ALLOC(ARGS... args) { - return platform::dynload::cblas_sgemm_alloc(args...); - } - - template - static void GEMM_PACK(ARGS... args) { - platform::dynload::cblas_sgemm_pack(args...); - } - - template - static void GEMM_COMPUTE(ARGS... args) { - platform::dynload::cblas_sgemm_compute(args...); - } - - template - static void GEMM_FREE(ARGS... args) { - platform::dynload::cblas_sgemm_free(args...); - } - -#ifdef PADDLE_WITH_LIBXSMM - template - static void SMM_GEMM(ARGS... args) { - libxsmm_sgemm(args...); - } -#endif - - template - static void AXPY(ARGS... args) { - platform::dynload::cblas_saxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - platform::dynload::cblas_scopy(args...); - } - - template - static void GEMV(ARGS... args) { - platform::dynload::cblas_sgemv(args...); - } - - template - static float DOT(ARGS... args) { - return platform::dynload::cblas_sdot(args...); - } - - template - static void SCAL(ARGS... args) { - platform::dynload::cblas_sscal(args...); - } - - template - static float ASUM(ARGS... args) { - return platform::dynload::cblas_sasum(args...); - } - - template - static void GEMM_BATCH(ARGS... args) { - platform::dynload::cblas_sgemm_batch(args...); - } - - template - static void VADD(ARGS... args) { - platform::dynload::vsAdd(args...); - } - - template - static void VSUB(ARGS... args) { - platform::dynload::vsSub(args...); - } - - template - static void VMUL(ARGS... args) { - platform::dynload::vsMul(args...); - } - - template - static void VDIV(ARGS... args) { - platform::dynload::vsDiv(args...); - } - - template - static void VEXP(ARGS... args) { - platform::dynload::vsExp(args...); - } - - template - static void VSQUARE(ARGS... args) { - platform::dynload::vsSqr(args...); - } - - template - static void VPOW(ARGS... args) { - platform::dynload::vsPowx(args...); - } - - template - static void VINV(ARGS... args) { - platform::dynload::vsInv(args...); - } - - template - static void VMERF(ARGS... 
args) { - platform::dynload::vmsErf(args...); - } -#if !defined(_WIN32) - template - static void CSRMM(ARGS... args) { - platform::dynload::mkl_scsrmm(args...); - } -#endif - - template - static void TRSM(ARGS... args) { - platform::dynload::cblas_strsm(args...); - } -}; - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - platform::dynload::cblas_dgemm(args...); - } - - template - static double *GEMM_ALLOC(ARGS... args) { - return platform::dynload::cblas_dgemm_alloc(args...); - } - - template - static void GEMM_PACK(ARGS... args) { - platform::dynload::cblas_dgemm_pack(args...); - } - - template - static void GEMM_COMPUTE(ARGS... args) { - platform::dynload::cblas_dgemm_compute(args...); - } - - template - static void GEMM_FREE(ARGS... args) { - platform::dynload::cblas_dgemm_free(args...); - } - -#ifdef PADDLE_WITH_LIBXSMM - template - static void SMM_GEMM(ARGS... args) { - libxsmm_dgemm(args...); - } -#endif - - template - static void AXPY(ARGS... args) { - platform::dynload::cblas_daxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - platform::dynload::cblas_dcopy(args...); - } - - template - static void GEMV(ARGS... args) { - platform::dynload::cblas_dgemv(args...); - } - - template - static double DOT(ARGS... args) { - return platform::dynload::cblas_ddot(args...); - } - - template - static void SCAL(ARGS... args) { - platform::dynload::cblas_dscal(args...); - } - - template - static double ASUM(ARGS... args) { - return platform::dynload::cblas_dasum(args...); - } - - template - static void GEMM_BATCH(ARGS... args) { - platform::dynload::cblas_dgemm_batch(args...); - } - - template - static void VADD(ARGS... args) { - platform::dynload::vdAdd(args...); - } - - template - static void VSUB(ARGS... args) { - platform::dynload::vdSub(args...); - } - - template - static void VMUL(ARGS... args) { - platform::dynload::vdMul(args...); - } - - template - static void VDIV(ARGS... args) { - platform::dynload::vdDiv(args...); - } - - template - static void VEXP(ARGS... args) { - platform::dynload::vdExp(args...); - } - - template - static void VSQUARE(ARGS... args) { - platform::dynload::vdSqr(args...); - } - - template - static void VPOW(ARGS... args) { - platform::dynload::vdPowx(args...); - } - - template - static void VINV(ARGS... args) { - platform::dynload::vdInv(args...); - } - - template - static void VMERF(ARGS... args) { - platform::dynload::vmdErf(args...); - } -#if !defined(_WIN32) - template - static void CSRMM(ARGS... args) { - platform::dynload::mkl_dcsrmm(args...); - } -#endif - - template - static void TRSM(ARGS... args) { - platform::dynload::cblas_dtrsm(args...); - } -}; - -template <> -struct CBlas> { - template - static void AXPY(int n, const paddle::platform::complex alpha, - const paddle::platform::complex *X, const int incX, - paddle::platform::complex *Y, const int incY) { - platform::dynload::cblas_caxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void VCOPY(ARGS... args) { - platform::dynload::cblas_ccopy(args...); - } - - // the libmklml_intel.so paddle used has no vcAdd, vcSub, - // vcMul, vcDiv apis before rebuild from source - // so replace with the raw operator methods - /* - template - static void VADD(ARGS... args) { - platform::dynload::vcAdd(args...); - } - - template - static void VSUB(ARGS... args) { - platform::dynload::vcSub(args...); - } - - template - static void VMUL(ARGS... args) { - platform::dynload::vcMul(args...); - } - - template - static void VDIV(ARGS... 
args) { - platform::dynload::vcDiv(args...); - } - */ - - template - static void VADD(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] + b[i]; - } - } - - template - static void VSUB(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] - b[i]; - } - } - - template - static void VMUL(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] * b[i]; - } - } - template - static void VDIV(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] / b[i]; - } - } - - template - static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - const paddle::platform::complex *X, int incx, - paddle::platform::complex beta, - paddle::platform::complex *Y, int incy) { - const void *a_ = (const void *)(A); - const void *x_ = (const void *)(X); - void *y_ = static_cast(Y); - platform::dynload::cblas_cgemv(layout, trans, M, N, &alpha, a_, lda, x_, - incx, &beta, y_, incy); - } - - template - static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a, - CBLAS_TRANSPOSE trans_b, int M, int N, int K, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - const paddle::platform::complex *B, int ldb, - paddle::platform::complex beta, - paddle::platform::complex *C, int ldc) { - const void *a_ = (const void *)(A); - const void *b_ = (const void *)(B); - void *c_ = static_cast(C); - platform::dynload::cblas_cgemm(layout, trans_a, trans_b, M, N, K, &alpha, - a_, lda, b_, ldb, &beta, c_, ldc); - } - - static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - paddle::platform::complex *B, int ldb) { - const void *a_ = (const void *)(A); - void *b_ = static_cast(B); - platform::dynload::cblas_ctrsm(layout, side, uplo, trans_a, diag, M, N, - &alpha, a_, lda, b_, ldb); - } - - template - static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a, - CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K, - paddle::platform::complex *alpha, - const paddle::platform::complex **A, - const int *lda, - const paddle::platform::complex **B, - const int *ldb, paddle::platform::complex *beta, - paddle::platform::complex **C, const int *ldc, - int group_count, int *group_size) { - const void **A_void = (const void **)(&(*A)); - const void **B_void = (const void **)(&(*B)); - void **C_void = reinterpret_cast(C); - - platform::dynload::cblas_cgemm_batch(layout, trans_a, trans_b, M, N, K, - alpha, A_void, lda, B_void, ldb, beta, - C_void, ldc, group_count, group_size); - } - - template - static void GEMM_EX(ARGS... args) { - platform::dynload::cblas_cgemm_batch(args...); - } -}; - -template <> -struct CBlas> { - template - static void AXPY(int n, const paddle::platform::complex alpha, - const paddle::platform::complex *X, const int incX, - paddle::platform::complex *Y, const int incY) { - platform::dynload::cblas_zaxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void VCOPY(ARGS... 
args) { - platform::dynload::cblas_zcopy(args...); - } - - // the libmklml_intel.so paddle used has no vzAdd, vzSub, - // vzMul, vzDiv apis before rebuild from source - // so replace with the raw operator methods - /* - template - static void VADD(ARGS... args) { - platform::dynload::vzAdd(args...); - } - - template - static void VSUB(ARGS... args) { - platform::dynload::vzSub(args...); - } - - template - static void VMUL(ARGS... args) { - platform::dynload::vzMul(args...); - } - - template - static void VDIV(ARGS... args) { - platform::dynload::vzDiv(args...); - } - */ - - template - static void VADD(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] + b[i]; - } - } - - template - static void VSUB(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] - b[i]; - } - } - - template - static void VMUL(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] * b[i]; - } - } - template - static void VDIV(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] / b[i]; - } - } - - template - static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - const paddle::platform::complex *X, int incx, - paddle::platform::complex beta, - paddle::platform::complex *Y, int incy) { - const void *a_ = (const void *)(A); - const void *x_ = (const void *)(X); - void *y_ = static_cast(Y); - platform::dynload::cblas_zgemv(layout, trans, M, N, &alpha, a_, lda, x_, - incx, &beta, y_, incy); - } - - template - static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a, - CBLAS_TRANSPOSE trans_b, int M, int N, int K, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - const paddle::platform::complex *B, int ldb, - paddle::platform::complex beta, - paddle::platform::complex *C, int ldc) { - const void *a_ = (const void *)(A); - const void *b_ = (const void *)(B); - void *c_ = static_cast(C); - platform::dynload::cblas_zgemm(layout, trans_a, trans_b, M, N, K, &alpha, - a_, lda, b_, ldb, &beta, c_, ldc); - } - - static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - paddle::platform::complex *B, int ldb) { - const void *a_ = (const void *)(A); - void *b_ = static_cast(B); - platform::dynload::cblas_ztrsm(layout, side, uplo, trans_a, diag, M, N, - &alpha, a_, lda, b_, ldb); - } - - template - static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a, - CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K, - paddle::platform::complex *alpha, - const paddle::platform::complex **A, - const int *lda, - const paddle::platform::complex **B, - const int *ldb, - paddle::platform::complex *beta, - paddle::platform::complex **C, const int *ldc, - int group_count, int *group_size) { - const void **A_void = (const void **)(&(*A)); - const void **B_void = (const void **)(&(*B)); - void **C_void = reinterpret_cast(C); - - platform::dynload::cblas_zgemm_batch(layout, trans_a, trans_b, M, N, K, - alpha, A_void, lda, 
B_void, ldb, beta, - C_void, ldc, group_count, group_size); - } - - template - static void GEMM_EX(ARGS... args) { - platform::dynload::cblas_zgemm_batch(args...); - } -}; - -#else - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - cblas_sgemm(args...); - } - - template - static void AXPY(ARGS... args) { - cblas_saxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - cblas_scopy(args...); - } - - template - static void GEMV(ARGS... args) { - cblas_sgemv(args...); - } - - template - static void TRSM(ARGS... args) { - cblas_strsm(args...); - } -}; - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - cblas_dgemm(args...); - } - - template - static void AXPY(ARGS... args) { - cblas_daxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - cblas_dcopy(args...); - } - - template - static void GEMV(ARGS... args) { - cblas_dgemv(args...); - } - - template - static void TRSM(ARGS... args) { - cblas_dtrsm(args...); - } -}; - -template <> -struct CBlas> { - template - static void VCOPY(ARGS... args) { - cblas_ccopy(args...); - } - - template - static void AXPY(int n, const paddle::platform::complex alpha, - const paddle::platform::complex *X, const int incX, - paddle::platform::complex *Y, const int incY) { - cblas_caxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, - const int M, const int N, - const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - const paddle::platform::complex *X, const int incX, - const paddle::platform::complex beta, - paddle::platform::complex *Y, const int incY) { - cblas_cgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); - } - - template - static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - const paddle::platform::complex *B, const int ldb, - const paddle::platform::complex beta, - paddle::platform::complex *C, const int ldc) { - cblas_cgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, - C, ldc); - } - - static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side, - const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA, - const CBLAS_DIAG diag, const int M, const int N, - const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - paddle::platform::complex *B, const int ldb) { - cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); - } -}; - -template <> -struct CBlas> { - template - static void VCOPY(ARGS... 
args) { - cblas_zcopy(args...); - } - - template - static void AXPY(int n, const paddle::platform::complex alpha, - const paddle::platform::complex *X, const int incX, - paddle::platform::complex *Y, const int incY) { - cblas_zaxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, - const int M, const int N, - const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - const paddle::platform::complex *X, const int incX, - const paddle::platform::complex beta, - paddle::platform::complex *Y, const int incY) { - cblas_zgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); - } - - template - static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - const paddle::platform::complex *B, const int ldb, - const paddle::platform::complex beta, - paddle::platform::complex *C, const int ldc) { - cblas_zgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, - C, ldc); - } - - static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side, - const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA, - const CBLAS_DIAG diag, const int M, const int N, - const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - paddle::platform::complex *B, const int ldb) { - cblas_ztrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); - } -}; - -#endif - -template <> -struct CBlas { - static void GEMM(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 GEMM not supported on CPU, please check your code")); - } - - static void SMM_GEMM(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 SMM_GEMM not supported on CPU, please check your code")); - } - static void VMUL(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 VMUL not supported on CPU, please check your code")); - } - static void VEXP(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 VEXP not supported on CPU, please check your code")); - } - static void VSQUARE(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 VSQUARE not supported on CPU, please check your code")); - } - static void VPOW(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 VPOW not supported on CPU, please check your code")); - } - static void DOT(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 DOT not supported on CPU, please check your code")); - }; - static void SCAL(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 SCAL not supported on CPU, please check your code")); - }; - static void ASUM(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 ASUM not supported on CPU, please check your code")); - }; -#ifdef PADDLE_WITH_MKLML - static void GEMM_BATCH(...) 
{ - PADDLE_THROW(platform::errors::Unimplemented( - "float16 GEMM_BATCH not supported on CPU, please check your code")); - } -#endif -}; - -#ifdef PADDLE_WITH_MKLML -template <> -template -T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, - const int M, const int N, - const int K) const { - return CBlas::GEMM_ALLOC(id, M, N, K); -} -template <> -template -T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, - const int N, const int K) const { - return CBlas::GEMM_ALLOC(id, M, N, K); -} - -template <> -template -void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, - const CBLAS_TRANSPOSE trans, - int M, int N, int K, - const T alpha, const T *src, - const int ld, T *dst) const { - CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); -} -template <> -template -void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, - const CBLAS_TRANSPOSE trans, int M, - int N, int K, const T alpha, - const T *src, const int ld, - T *dst) const { - CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); -} - -template <> -template -void Blas::GEMM_COMPUTE( - int transA, int transB, int M, int N, int K, const T *A, const int lda, - const T *B, const int ldb, T beta, T *C, const int ldc) const { - CBlas::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, - beta, C, ldc); -} -template <> -template -void Blas::GEMM_COMPUTE(int transA, int transB, int M, int N, - int K, const T *A, const int lda, - const T *B, const int ldb, T beta, - T *C, const int ldc) const { - CBlas::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, - beta, C, ldc); -} - -template <> -template -void Blas::GEMM_FREE(T *data) const { - CBlas::GEMM_FREE(data); -} -template <> -template -void Blas::GEMM_FREE(T *data) const { - CBlas::GEMM_FREE(data); -} -#endif - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, - int N, int K, T alpha, const T *A, - const T *B, T beta, T *C) const { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T *A, const T *B, T beta, - T *C) const { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} - -template <> -template -void Blas::GEMM(bool transA, bool transB, int M, - int N, int K, T alpha, const T *A, - int lda, const T *B, int ldb, - T beta, T *C, int ldc) const { - CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, - transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, - lda, B, ldb, beta, C, ldc); -} -template <> -template -void Blas::GEMM(bool transA, bool transB, int M, int N, int K, - T alpha, const T *A, int lda, const T *B, - int ldb, T beta, T *C, int ldc) const { - CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, - transB == false ? 
CblasNoTrans : CblasTrans, M, N, K, alpha, A, - lda, B, ldb, beta, C, ldc); -} - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, - int N, int K, T alpha, const T *A, - int lda, const T *B, int ldb, - T beta, T *C, int ldc) const { - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T *A, int lda, const T *B, - int ldb, T beta, T *C, int ldc) const { - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} - -template -template -void Blas::MatMul(const framework::Tensor &mat_a, bool trans_a, - const framework::Tensor &mat_b, bool trans_b, - T alpha, framework::Tensor *mat_out, - T beta) const { - auto dim_a = mat_a.dims(); - auto dim_b = mat_b.dims(); - auto dim_out = mat_out->dims(); - PADDLE_ENFORCE_EQ( - dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, true, - platform::errors::InvalidArgument( - "The input and output of matmul should be matrix, the dim size must " - "be 2," - "but received dim size input_a:%d, input_b:%d, output:%d", - dim_a.size(), dim_b.size(), dim_out.size())); - PADDLE_ENFORCE_EQ( - mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), true, - platform::errors::InvalidArgument("The places of matrices in the matmul " - "should be same, please check your " - "code.")); - - int M = dim_out[0]; - int N = dim_out[1]; - int K = !trans_a ? dim_a[1] : dim_a[0]; - - CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !trans_b ? CblasNoTrans : CblasTrans; - - this->GEMM(transA, transB, M, N, K, alpha, mat_a.data(), mat_b.data(), - beta, mat_out->data()); -} - -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, - T *y) const { - CBlas::AXPY(n, alpha, x, 1, y, 1); -} -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - CBlas::AXPY(n, alpha, x, 1, y, 1); -} - -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - CBlas::VCOPY(n, x, 1, y, 1); -} -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - CBlas::VCOPY(n, x, 1, y, 1); -} - -template <> -template -void Blas::VADD(int n, const T *x, const T *y, - T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VADD(n, x, y, z); -#else - if (x == z) { - this->template AXPY(n, (T)(1.), y, z); - } else { - this->template VCOPY(n, y, z); - this->template AXPY(n, (T)(1.), x, z); - } -#endif -} -template <> -template -void Blas::VADD(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VADD(n, x, y, z); -#else - if (x == z) { - this->template AXPY(n, (T)(1.), y, z); - } else { - this->template VCOPY(n, y, z); - this->template AXPY(n, (T)(1.), x, z); - } -#endif -} - -template <> -template -void Blas::VSUB(int n, const T *x, const T *y, - T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSUB(n, x, y, z); -#else - // try to find if openblas support vsub - for (int i = 0; i < n; ++i) { - z[i] = x[i] - y[i]; - } -#endif -} -template <> -template -void Blas::VSUB(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSUB(n, x, y, z); -#else - // try to find if openblas support vsub - for (int i = 0; i < n; ++i) { - z[i] = x[i] - y[i]; - } -#endif -} - -template <> -template -void Blas::VMUL(int n, const T *x, const T *y, - T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VMUL(n, x, y, 
z); -#else - // try to find if openblas support vmul - for (int i = 0; i < n; ++i) { - z[i] = x[i] * y[i]; - } -#endif -} -template <> -template -void Blas::VMUL(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VMUL(n, x, y, z); -#else - // try to find if openblas support vmul - for (int i = 0; i < n; ++i) { - z[i] = x[i] * y[i]; - } -#endif -} - -template <> -template -void Blas::VDIV(int n, const T *x, const T *y, - T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VDIV(n, x, y, z); -#else - // try to find if openblas support vdiv - for (int i = 0; i < n; ++i) { - z[i] = x[i] / y[i]; - } -#endif -} -template <> -template -void Blas::VDIV(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VDIV(n, x, y, z); -#else - // try to find if openblas support vdiv - for (int i = 0; i < n; ++i) { - z[i] = x[i] / y[i]; - } -#endif -} - -template <> -template -void Blas::VEXP(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VEXP(n, x, y); -#else - // try to find if openblas support vexp - for (int i = 0; i < n; ++i) { - y[i] = std::exp(x[i]); - } -#endif -} -template <> -template -void Blas::VEXP(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VEXP(n, x, y); -#else - // try to find if openblas support vexp - for (int i = 0; i < n; ++i) { - y[i] = std::exp(x[i]); - } -#endif -} - -template <> -template -void Blas::VSQUARE(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSQUARE(n, x, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = x[i] * x[i]; - } -#endif -} -template <> -template -void Blas::VSQUARE(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSQUARE(n, x, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = x[i] * x[i]; - } -#endif -} - -template <> -template -void Blas::VPOW(int n, const T *x, T a, - T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VPOW(n, x, a, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = std::pow(x[i], a); - } -#endif -} -template <> -template -void Blas::VPOW(int n, const T *x, T a, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VPOW(n, x, a, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = std::pow(x[i], a); - } -#endif -} - -template <> -template -T Blas::DOT(int n, const T *x, const T *y) const { -#ifdef PADDLE_WITH_MKLML - return CBlas::DOT(n, x, 1, y, 1); -#else - // try to find if openblas support cblas_dot - T sum = 0; - for (int i = 0; i < n; ++i) { - sum += x[i] * y[i]; - } - return sum; -#endif -} -template <> -template -T Blas::DOT(int n, const T *x, const T *y) const { -#ifdef PADDLE_WITH_MKLML - return CBlas::DOT(n, x, 1, y, 1); -#else - // try to find if openblas support cblas_dot - T sum = 0; - for (int i = 0; i < n; ++i) { - sum += x[i] * y[i]; - } - return sum; -#endif -} - -template <> -template -void Blas::SCAL(int n, const T a, T *x) const { -#ifdef PADDLE_WITH_MKLML - CBlas::SCAL(n, a, x, 1); -#else - // try to find if openblas support cblas_scal - for (int i = 0; i < n; ++i) { - x[i] = a * x[i]; - } -#endif -} -template <> -template -void Blas::SCAL(int n, const T a, T *x) const { -#ifdef PADDLE_WITH_MKLML - CBlas::SCAL(n, a, x, 1); -#else - // try to find if openblas support cblas_scal - for (int i = 0; i < n; ++i) { - x[i] = a * x[i]; - } -#endif -} - -template <> -template -T Blas::ASUM(int n, T *x, int inc) const { - auto sum = static_cast(0.0); -#ifdef PADDLE_WITH_MKLML - sum = CBlas::ASUM(n, x, inc); -#else - // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum - for (int c = 0; c 
< n; ++c) { - sum += x[c]; - } -#endif - return sum; -} -template <> -template -T Blas::ASUM(int n, T *x, int inc) const { - auto sum = static_cast(0.0); -#ifdef PADDLE_WITH_MKLML - sum = CBlas::ASUM(n, x, inc); -#else - // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum - for (int c = 0; c < n; ++c) { - sum += x[c]; - } -#endif - return sum; -} - -template <> -template -void Blas::GEMV(bool trans_a, int M, int N, T alpha, - const T *A, const T *B, T beta, - T *C) const { - CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; - CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); -} -template <> -template -void Blas::GEMV(bool trans_a, int M, int N, T alpha, - const T *A, const T *B, T beta, T *C) const { - CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; - CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); -} - -template <> -template -void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T *A, const T *B, T beta, T *C, int batchCount, - int64_t strideA, int64_t strideB) const { - PADDLE_ENFORCE_NOT_NULL( - A, platform::errors::InvalidArgument("Pointer A should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - B, platform::errors::InvalidArgument("Pointer B should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - C, platform::errors::InvalidArgument("Pointer C should not be null.")); -#ifdef PADDLE_WITH_MKLML - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA]; - b_array[k] = &B[k * strideB]; - c_array[k] = &C[k * M * N]; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, - a_array.data(), &lda, b_array.data(), &ldb, &beta, - c_array.data(), &ldc, 1 /* group_count */, &batchCount); -#else - for (int k = 0; k < batchCount; ++k) { - auto *Ak = &A[k * strideA]; - auto *Bk = &B[k * strideB]; - auto *Ck = &C[k * M * N]; - this->template GEMM(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck); - } -#endif -} -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, T alpha, const T *A, const T *B, - T beta, T *C, int batchCount, - int64_t strideA, - int64_t strideB) const { - PADDLE_ENFORCE_NOT_NULL( - A, platform::errors::InvalidArgument("Pointer A should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - B, platform::errors::InvalidArgument("Pointer B should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - C, platform::errors::InvalidArgument("Pointer C should not be null.")); -#ifdef PADDLE_WITH_MKLML - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? 
N : K; - int ldc = N; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA]; - b_array[k] = &B[k * strideB]; - c_array[k] = &C[k * M * N]; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, - a_array.data(), &lda, b_array.data(), &ldb, &beta, - c_array.data(), &ldc, 1 /* group_count */, &batchCount); -#else - for (int k = 0; k < batchCount; ++k) { - auto *Ak = &A[k * strideA]; - auto *Bk = &B[k * strideB]; - auto *Ck = &C[k * M * N]; - this->template GEMM(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck); - } -#endif -} - -template <> -template -void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const { -#ifdef PADDLE_WITH_MKLML - const int lda = (std::max)((transA == CblasNoTrans) ? K : M, 1); - const int ldb = (std::max)((transB == CblasNoTrans) ? N : K, 1); - const int ldc = (std::max)(N, 1); - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, A, - &lda, B, &ldb, &beta, C, &ldc, 1 /* group_count */, - &batchCount); -#else - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], B[k], beta, - C[k]); - } -#endif -} -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, T alpha, const T **A, - const T **B, T beta, T **C, - int batchCount) const { -#ifdef PADDLE_WITH_MKLML - const int lda = (std::max)((transA == CblasNoTrans) ? K : M, 1); - const int ldb = (std::max)((transB == CblasNoTrans) ? N : K, 1); - const int ldc = (std::max)(N, 1); - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, A, - &lda, B, &ldb, &beta, C, &ldc, 1 /* group_count */, - &batchCount); -#else - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], B[k], beta, - C[k]); - } -#endif -} - -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ - !defined(PADDLE_WITH_HIP) // @{ Group Blas MKLML: BatchedGEMMWithHead -template <> -template -void Blas::BatchedGEMMWithHead( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int W1, int H1, int W2, - int H2, T alpha, const T *A, const T *B, T beta, T *C, int batchCount, - int64_t strideA, int64_t strideB, int64_t head_number, - bool split_b_vertical) const { - int lda = (transA == CblasNoTrans) ? W1 : H1; - int ldb = (transB == CblasNoTrans) ? W2 : H2; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - - if (split_b_vertical) { - int ldc = W2; - int sub_width = W2 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? 
i * (W2 / head_number) - : i * (W2 / head_number) * H2; - int sub_matC_offset = i * W2 / head_number; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &sub_width, - &H2, &alpha, a_array.data(), &lda, b_array.data(), - &ldb, &beta, c_array.data(), &ldc, - 1 /* group_count */, &batchCount); - } - - } else { - PADDLE_ENFORCE_EQ( - W1, H2, - platform::errors::InvalidArgument( - "The fisrt matrix width should be same as second matrix height," - "but received fisrt matrix width %d" - ", second matrix height %d", - W1, H2)); - int ldc = W2 * head_number; - int sub_width = W1 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? i * (W1 / head_number) * W2 - : i * (W1 / head_number); - int sub_matC_offset = i * W2; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &W2, - &sub_width, &alpha, a_array.data(), &lda, - b_array.data(), &ldb, &beta, c_array.data(), &ldc, - 1 /* group_count */, &batchCount); - } - } -} -template <> -template -void Blas::BatchedGEMMWithHead( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int W1, int H1, int W2, - int H2, T alpha, const T *A, const T *B, T beta, T *C, int batchCount, - int64_t strideA, int64_t strideB, int64_t head_number, - bool split_b_vertical) const { - int lda = (transA == CblasNoTrans) ? W1 : H1; - int ldb = (transB == CblasNoTrans) ? W2 : H2; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - - if (split_b_vertical) { - int ldc = W2; - int sub_width = W2 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? i * (W2 / head_number) - : i * (W2 / head_number) * H2; - int sub_matC_offset = i * W2 / head_number; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &sub_width, - &H2, &alpha, a_array.data(), &lda, b_array.data(), - &ldb, &beta, c_array.data(), &ldc, - 1 /* group_count */, &batchCount); - } - - } else { - PADDLE_ENFORCE_EQ( - W1, H2, - platform::errors::InvalidArgument( - "The fisrt matrix width should be same as second matrix height," - "but received fisrt matrix width %d" - ", second matrix height %d", - W1, H2)); - int ldc = W2 * head_number; - int sub_width = W1 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? 
i * (W1 / head_number) * W2 - : i * (W1 / head_number); - int sub_matC_offset = i * W2; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &W2, - &sub_width, &alpha, a_array.data(), &lda, - b_array.data(), &ldb, &beta, c_array.data(), &ldc, - 1 /* group_count */, &batchCount); - } - } -} -#endif // @} End Group Blas MKLML: BatchedGEMMWithHead - -template -template -void Blas::MatMul(const int M, const int N, const int K, - const T *A, const T *B, T *C) const { - this->template GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, - static_cast(1), A, K, B, N, static_cast(0), C, - N); -} - -template <> -template -void Blas::MatMul(const int M, const int N, - const int K, const T *A, - const T *B, T *C) const { -#ifdef PADDLE_WITH_LIBXSMM - // Refer to https://github.com/hfp/libxsmm/blob/master/README.md - // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; - - // Since the matrix is very small, - // so the unit of calculation is already very fast, - // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, - // use xsmm directly. - // Note: SMM use ColMajor - const char transa = 'N'; - const char transb = 'N'; - const T alpha = static_cast(1); - const T beta = static_cast(0); - CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, - C, &N); - return; -#endif - - CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, - static_cast(1), A, K, B, N, static_cast(0), C, N); -} -template <> -template -void Blas::MatMul(const int M, const int N, const int K, - const T *A, const T *B, T *C) const { -#ifdef PADDLE_WITH_LIBXSMM - // Refer to https://github.com/hfp/libxsmm/blob/master/README.md - // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; - - // Since the matrix is very small, - // so the unit of calculation is already very fast, - // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, - // use xsmm directly. - // Note: SMM use ColMajor - const char transa = 'N'; - const char transb = 'N'; - const T alpha = static_cast(1); - const T beta = static_cast(0); - CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, - C, &N); - return; -#endif - - CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, - static_cast(1), A, K, B, N, static_cast(0), C, N); -} - -template -template -void Blas::MatMul(const framework::Tensor &mat_a, - const MatDescriptor &dim_a, - const framework::Tensor &mat_b, - const MatDescriptor &dim_b, T alpha, - framework::Tensor *mat_out, T beta) const { - MatMul(mat_a.data(), dim_a, mat_b.data(), dim_b, alpha, - mat_out->data(), beta); -} - -template -template -void Blas::MatMul(const T *mat_a, const MatDescriptor &dim_a, - const T *mat_b, const MatDescriptor &dim_b, - T alpha, T *mat_out, T beta) const { - PADDLE_ENFORCE_EQ( - dim_a.width_, dim_b.height_, - platform::errors::InvalidArgument( - "The fisrt matrix width should be same as second matrix height," - "but received fisrt matrix width %d" - ", second matrix height %d", - dim_a.width_, dim_b.height_)); - - CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !dim_b.trans_ ? 
CblasNoTrans : CblasTrans; - if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { - this->template GEMM(transA, transB, dim_a.height_, dim_b.width_, - dim_a.width_, alpha, mat_a, mat_b, beta, mat_out); - } else { - PADDLE_ENFORCE_EQ( - dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || - dim_b.batch_size_ == 0, - true, platform::errors::InvalidArgument( - "dim_a.batch_size should be equal to dim_b.batch_size, or " - "one of dim_a.batch_size and dim_b.batch_size should be 0. " - "But got dim_a.batch_size = %d, dim_b.batch_size = %d.", - dim_a.batch_size_, dim_b.batch_size_)); - this->template BatchedGEMM( - transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, mat_a, - mat_b, beta, mat_out, - dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, - dim_a.stride_, dim_b.stride_); - } -} - -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ - !defined(PADDLE_WITH_HIP) -// @{ Group Blas MKLML: MatMulWithHead -/* - * Multiple two matrixes with multiple heads - * - * A new parameter, i.e head_number is added compared to normal MatMul. - * The head_number describes the number of heads a matrix is vertically - * split. - * - * When user calls this API, the multiplication of two big matrixes is split - * into multiplication of several (head_number_) small matrixes. e.g. if Mat A - * is [3, 24] and Mat B is [24, 4], when multiple A and B with head_number as - * 4, Mat A will be split as 4 matrix of [3, 6] and Mat B will be - * (horizontally) split as 4 matrix of [6, 4]. The result of final matrix - * will be 4 matrix of [3, 4], i.e. [3, 16]. - * Another example is A is [3, 8], B is [2, 16], head_number is 4. In this - * case, A will be split as [3, 2], B will be (vertically) split as - * [2, 4]. The final result will be 4 matrix of 4 matrix of [3,4], i.e. [3, 16] - */ -template -template -void Blas::MatMulWithHead(const framework::Tensor &mat_a, - const MatDescriptor &dim_a, - const framework::Tensor &mat_b, - const MatDescriptor &dim_b, T alpha, - int head_number, - framework::Tensor *mat_out, T beta, - bool mat_b_split_vertical) const { - PADDLE_ENFORCE_EQ( - dim_a.width_ % head_number, 0, - platform::errors::InvalidArgument( - "The first input width must be some times the head number" - "but received first input width %d" - ", head_number %d", - dim_a.width_, head_number)); - PADDLE_ENFORCE_GE(head_number, 1, - platform::errors::InvalidArgument( - "The head number should be greater equal 1," - "but received head number %d", - head_number)); - PADDLE_ENFORCE_LE( - head_number, dim_a.width_, - platform::errors::InvalidArgument( - "The head number should be less equal first input width," - "but received first input width %d" - ", head_number %d", - dim_a.width_, head_number)); - CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !dim_b.trans_ ? 
CblasNoTrans : CblasTrans; - - if (mat_b_split_vertical) { - PADDLE_ENFORCE_EQ( - dim_b.height_, dim_a.width_ / head_number, - platform::errors::InvalidArgument( - "The second input height should be equal than first input width," - "but received second input height %d, first input width %d", - dim_b.height_, dim_a.width_ / head_number)); - PADDLE_ENFORCE_EQ( - dim_a.width_ % head_number, 0, - platform::errors::InvalidArgument( - "The second input width should be some times the head number" - "but received second input width %d" - ", head_number %d", - dim_b.width_, head_number)); - } - - if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { - int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_; - int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_; - int sub_matA_offset; - int sub_matB_offset; - int sub_matC_offset; - int sub_mat_M = dim_a.height_; - int sub_mat_N; - int sub_mat_K; - int ldc; - - for (int i = 0; i < head_number; i++) { - sub_matA_offset = dim_a.trans_ - ? i * (dim_a.width_ / head_number) * dim_a.height_ - : i * (dim_a.width_ / head_number); - if (mat_b_split_vertical) { - sub_matB_offset = dim_b.trans_ - ? i * (dim_b.width_ / head_number) * dim_b.height_ - : i * (dim_b.width_ / head_number); - sub_matC_offset = i * dim_b.width_ / head_number; - - sub_mat_N = dim_b.width_ / head_number; - sub_mat_K = dim_b.height_; - - ldc = dim_b.width_; - } else { - sub_matB_offset = - dim_b.trans_ ? i * (dim_b.height_ / head_number) - : i * (dim_b.height_ / head_number) * dim_b.width_; - sub_matC_offset = i * dim_b.width_; - - sub_mat_N = dim_b.width_; - sub_mat_K = dim_a.width_ / head_number; - - ldc = head_number * dim_b.width_; - } - - this->template GEMM(transA, transB, sub_mat_M, sub_mat_N, sub_mat_K, - alpha, mat_a.data() + sub_matA_offset, lda, - mat_b.data() + sub_matB_offset, ldb, beta, - mat_out->data() + sub_matC_offset, ldc); - } - } else { - PADDLE_ENFORCE_EQ( - (dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || - dim_b.batch_size_ == 0), - true, - platform::errors::InvalidArgument( - "The first input batch size should be equal than second input," - "either two input batch size is 0, but received first input batch " - "size" - " %d, second input batch size %d", - dim_a.batch_size_, dim_b.batch_size_)); - - this->template BatchedGEMMWithHead( - transA, transB, dim_a.width_, dim_a.height_, dim_b.width_, - dim_b.height_, alpha, mat_a.data(), mat_b.data(), beta, - mat_out->data(), - dim_a.batch_size_ == 0 ? 
dim_b.batch_size_ : dim_a.batch_size_, - dim_a.stride_, dim_b.stride_, head_number, mat_b_split_vertical); - } -} -#endif // @} End Group Blas MKLML: MatMulWithHead - -template -template -void Blas::VINV(int n, const T *a, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VINV(n, a, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = 1.0 / a[i]; - } -#endif -} - -template <> -template -void Blas::VMERF(int n, const T *a, T *y, - int64_t mode) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VMERF(n, a, y, mode); -#else - for (int i = 0; i < n; ++i) { - y[i] = std::erf(a[i]); - } -#endif -} -template <> -template -void Blas::VMERF(int n, const T *a, T *y, - int64_t mode) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VMERF(n, a, y, mode); -#else - for (int i = 0; i < n; ++i) { - y[i] = std::erf(a[i]); - } -#endif -} - -#ifdef PADDLE_WITH_MKLML -template <> -template -void Blas::CSRMM( - const char *transa, const int *m, const int *n, const int *k, - const T *alpha, const char *matdescra, const T *val, const int *indx, - const int *pntrb, const int *pntre, const T *b, const int *ldb, - const T *beta, T *c, const int *ldc) const { - CBlas::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b, - ldb, beta, c, ldc); -} -template <> -template -void Blas::CSRMM(const char *transa, const int *m, - const int *n, const int *k, const T *alpha, - const char *matdescra, const T *val, - const int *indx, const int *pntrb, - const int *pntre, const T *b, const int *ldb, - const T *beta, T *c, const int *ldc) const { - CBlas::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b, - ldb, beta, c, ldc); -} -#endif - -template <> -template -void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, int M, int N, - T alpha, const T *A, int lda, T *B, - int ldb) const { - CBlas::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda, - B, ldb); -} -template <> -template -void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, CBLAS_DIAG diag, - int M, int N, T alpha, const T *A, int lda, - T *B, int ldb) const { - CBlas::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda, - B, ldb); -} - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/blas_impl.hip.h b/paddle/fluid/operators/math/blas_impl.hip.h deleted file mode 100644 index 9518da89edeb0..0000000000000 --- a/paddle/fluid/operators/math/blas_impl.hip.h +++ /dev/null @@ -1,1379 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
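Between the CPU implementation deleted above and the ROCm/HIP implementation deleted below, operator code reaches these specializations through the Blas<DeviceContext> facade rather than by calling CBlas/CUBlas directly. A minimal usage sketch follows, assuming the GetBlas helper declared alongside these files keeps its usual shape in pten::funcs after the move to paddle/pten/kernels/funcs/blas; the kernel name and tensor arguments are illustrative only and are not part of this patch.

    #include "paddle/pten/kernels/funcs/blas/blas.h"

    // Illustrative fragment: out = a * b on CPU through the Blas facade.
    template <typename T>
    void MatMulSketch(const paddle::platform::CPUDeviceContext &ctx,
                      const paddle::framework::Tensor &a,
                      const paddle::framework::Tensor &b,
                      paddle::framework::Tensor *out) {
      // GetBlas returns the Blas<CPUDeviceContext> wrapper whose MatMul
      // overload (removed above) checks shapes and places, then dispatches to
      // CBlas<T>::GEMM in row-major order. `out` must already hold the
      // result shape [a.dims()[0], b.dims()[1]].
      auto blas =
          pten::funcs::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
      blas.MatMul(a, /*trans_a=*/false, b, /*trans_b=*/false,
                  /*alpha=*/static_cast<T>(1), out, /*beta=*/static_cast<T>(0));
    }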
- -#pragma once - -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/dynload/rocblas.h" -#include "paddle/pten/backends/gpu/gpu_context.h" -#include "paddle/pten/kernels/funcs/math_function.h" - -DECLARE_bool(enable_cublas_tensor_op_math); - -namespace paddle { -namespace operators { -namespace math { - -template -struct CUBlas; - -template <> -struct CUBlas { - template - static void GEMM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_sgemm(args...)); - } - - template - static void AXPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_saxpy(args...)); - } - - template - static void SCAL(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_sscal(args...)); - } - - template - static void VCOPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_scopy(args...)); - } - - template - static void GEMV(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_sgemv(args...)); - } - - template - static void GEMM_STRIDED_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::rocblas_sgemm_strided_batched(args...)); - } - - // HIP not supportted, refer to the doc here: - // https://github.com/ROCm-Developer-Tools/HIP/blob/roc-3.5.x/docs/markdown/CUBLAS_API_supported_by_HIP.md - template - static void GEMM_EX(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "cublasSgemmEx is not supported on HIP platform.")); - } - - template - static void TRSM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_strsm(args...)); - } - - template - static void GETRF_BATCH(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "cublasSgetrfBatched is not supported on HIP platform.")); - } - - template - static void GETRI_BATCH(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "cublasSgetriBatched is not supported on HIP platform.")); - } - - template - static void MATINV_BATCH(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "cublasSmatinvBatched is not supported on HIP platform.")); - } - - template - static void TRSM_BATCH(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "cublasStrsmBatched is not supported on HIP platform.")); - } -}; - -template <> -struct CUBlas { - template - static void GEMM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_dgemm(args...)); - } - - template - static void AXPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_daxpy(args...)); - } - - template - static void SCAL(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_dscal(args...)); - } - - template - static void VCOPY(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_dcopy(args...)); - } - - template - static void GEMV(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_dgemv(args...)); - } - - template - static void GEMM_STRIDED_BATCH(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::rocblas_dgemm_strided_batched(args...)); - } - - template - static void GEMM_EX(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "Currently there are not cublasDgemmEx.")); - } - - template - static void TRSM(ARGS... args) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_dtrsm(args...)); - } - - template - static void GETRF_BATCH(ARGS... 
args) { - PADDLE_THROW(platform::errors::Unimplemented( - "cublasDgetrfBatched is not supported on HIP platform.")); - } - - template - static void GETRI_BATCH(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "cublasDgetriBatched is not supported on HIP platform.")); - } - - template - static void MATINV_BATCH(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "cublasDmatinvBatched is not supported on HIP platform.")); - } - - template - static void TRSM_BATCH(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "cublasDtrsmBatched is not supported on HIP platform.")); - } -}; - -template <> -struct CUBlas { - using float16 = platform::float16; - - static void GEMM(rocblas_handle handle, rocblas_operation transa, - rocblas_operation transb, int m, int n, int k, - const float16 *alpha, const float16 *A, int lda, - const float16 *B, int ldb, const float16 *beta, float16 *C, - int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_hgemm( - handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc)); - } - - static void GEMM_STRIDED_BATCH(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, int m, int n, int k, - const float16 *alpha, const float16 *A, - int lda, long long int strideA, // NOLINT - const float16 *B, // NOLINT - int ldb, long long int strideB, // NOLINT - const float16 *beta, float16 *C, int ldc, - long long int strideC, // NOLINT - int batchCount) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_hgemm_strided_batched( - handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(beta), - reinterpret_cast(C), ldc, strideC, batchCount)); - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. 
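To make the Tensor Core note concrete: callers never invoke GEMM_EX directly. The platform::float16 overload of Blas<...>::GEMM further down converts alpha/beta to float and forwards to GEMM_EX with f16 storage types and an f32 compute type, so an accelerated algorithm can be chosen whenever the device supports it. A minimal sketch of such a call, again assuming GetBlas lives in pten::funcs after the move; the function and pointer names here are illustrative only.

    // Illustrative fragment: fp16 GEMM that ends up in the GEMM_EX path below.
    using float16 = paddle::platform::float16;

    void HalfGemmSketch(const paddle::platform::CUDADeviceContext &dev_ctx,
                        int M, int N, int K, const float16 *d_A,
                        const float16 *d_B, float16 *d_C) {
      auto blas = pten::funcs::GetBlas<paddle::platform::CUDADeviceContext,
                                       float16>(dev_ctx);
      // Row-major C[MxN] = A[MxK] * B[KxN]; the wrapper swaps the operands for
      // the column-major rocBLAS call and requests f32 accumulation, which is
      // what lets rocblas_gemm_ex pick an accelerated algorithm when one is
      // available.
      blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<float16>(1),
                d_A, d_B, static_cast<float16>(0), d_C);
    }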
- // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(platform::CUDADeviceContext *dev_ctx, - rocblas_operation transa, rocblas_operation transb, int m, - int n, int k, const void *alpha, const void *A, - rocblas_datatype Atype, int lda, const void *B, - rocblas_datatype Btype, int ldb, const void *beta, - void *C, rocblas_datatype Ctype, int ldc, - rocblas_datatype computeType) { - rocblas_gemm_algo algo = rocblas_gemm_algo_standard; - dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0)); - }); - } - template - static void GEMM_EX(pten::GPUContext *dev_ctx, rocblas_operation transa, - rocblas_operation transb, int m, int n, int k, - const void *alpha, const void *A, rocblas_datatype Atype, - int lda, const void *B, rocblas_datatype Btype, int ldb, - const void *beta, void *C, rocblas_datatype Ctype, - int ldc, rocblas_datatype computeType) { - rocblas_gemm_algo algo = rocblas_gemm_algo_standard; - dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0)); - }); - } -}; - -template <> -struct CUBlas> { - static void GEMV(rocblas_handle handle, rocblas_operation transa, int m, - int n, const platform::complex *alpha, - const platform::complex *A, int lda, - const platform::complex *B, int ldb, - const platform::complex *beta, - platform::complex *C, int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_cgemv( - handle, transa, m, n, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc)); - } - - static void AXPY(rocblas_handle handle, int n, - const platform::complex *alpha, - const platform::complex *X, const int incX, - platform::complex *Y, const int incY) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_caxpy( - handle, n, reinterpret_cast(alpha), - reinterpret_cast(X), incX, - reinterpret_cast(Y), incY)); - } - - static void GEMM_STRIDED_BATCH(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, int m, int n, int k, - const platform::complex *alpha, - const platform::complex *A, int lda, - long long int strideA, // NOLINT - const platform::complex *B, // NOLINT - int ldb, long long int strideB, // NOLINT - const platform::complex *beta, - platform::complex *C, int ldc, - long long int strideC, // NOLINT - int batchCount) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_cgemm_strided_batched( - handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(beta), - reinterpret_cast(C), ldc, strideC, - batchCount)); - } - - static void GEMM(rocblas_handle handle, rocblas_operation transa, - rocblas_operation transb, int m, int n, int k, - const platform::complex *alpha, - const platform::complex *A, int lda, - const platform::complex *B, int ldb, - const platform::complex *beta, - platform::complex *C, int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_cgemm( - handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, 
- reinterpret_cast(beta), - reinterpret_cast(C), ldc)); - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(platform::CUDADeviceContext *dev_ctx, - rocblas_operation transa, rocblas_operation transb, int m, - int n, int k, const void *alpha, const void *A, - rocblas_datatype Atype, int lda, const void *B, - rocblas_datatype Btype, int ldb, const void *beta, - void *C, rocblas_datatype Ctype, int ldc, - rocblas_datatype computeType) { - rocblas_gemm_algo algo = rocblas_gemm_algo_standard; - dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0)); - }); - } - template - static void GEMM_EX(pten::GPUContext *dev_ctx, rocblas_operation transa, - rocblas_operation transb, int m, int n, int k, - const void *alpha, const void *A, rocblas_datatype Atype, - int lda, const void *B, rocblas_datatype Btype, int ldb, - const void *beta, void *C, rocblas_datatype Ctype, - int ldc, rocblas_datatype computeType) { - rocblas_gemm_algo algo = rocblas_gemm_algo_standard; - dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0)); - }); - } -}; - -template <> -struct CUBlas> { - static void GEMV(rocblas_handle handle, rocblas_operation transa, int m, - int n, const platform::complex *alpha, - const platform::complex *A, int lda, - const platform::complex *B, int ldb, - const platform::complex *beta, - platform::complex *C, int ldc) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_zgemv( - handle, transa, m, n, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc)); - } - - static void AXPY(rocblas_handle handle, int n, - const platform::complex *alpha, - const platform::complex *X, const int incX, - platform::complex *Y, const int incY) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_zaxpy( - handle, n, reinterpret_cast(alpha), - reinterpret_cast(X), incX, - reinterpret_cast(Y), incY)); - } - - static void GEMM_STRIDED_BATCH(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, int m, int n, int k, - const platform::complex *alpha, - const platform::complex *A, int lda, - long long int strideA, // NOLINT - const platform::complex *B, // NOLINT - int ldb, long long int strideB, // NOLINT - const platform::complex *beta, - platform::complex *C, int ldc, - long long int strideC, // NOLINT - int batchCount) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_zgemm_strided_batched( - handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(beta), - reinterpret_cast(C), ldc, strideC, - batchCount)); - } - - static void GEMM(rocblas_handle handle, rocblas_operation transa, - rocblas_operation transb, int m, int n, int k, - const platform::complex *alpha, - const platform::complex *A, int lda, - const platform::complex *B, int ldb, - const platform::complex *beta, - platform::complex *C, int ldc) { - 
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_zgemm( - handle, transa, transb, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc)); - } - - // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. - // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode - template - static void GEMM_EX(platform::CUDADeviceContext *dev_ctx, - rocblas_operation transa, rocblas_operation transb, int m, - int n, int k, const void *alpha, const void *A, - rocblas_datatype Atype, int lda, const void *B, - rocblas_datatype Btype, int ldb, const void *beta, - void *C, rocblas_datatype Ctype, int ldc, - rocblas_datatype computeType) { - rocblas_gemm_algo algo = rocblas_gemm_algo_standard; - dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0)); - }); - } - template - static void GEMM_EX(pten::GPUContext *dev_ctx, rocblas_operation transa, - rocblas_operation transb, int m, int n, int k, - const void *alpha, const void *A, rocblas_datatype Atype, - int lda, const void *B, rocblas_datatype Btype, int ldb, - const void *beta, void *C, rocblas_datatype Ctype, - int ldc, rocblas_datatype computeType) { - rocblas_gemm_algo algo = rocblas_gemm_algo_standard; - dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0)); - }); - } -}; - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, - int N, int K, T alpha, const T *A, - const T *B, T beta, T *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, - &beta, C, N); - }); -} -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T *A, const T *B, T beta, - T *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? 
rocblas_operation_none - : rocblas_operation_transpose; - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, - &beta, C, N); - }); -} - -template <> -template <> -inline void Blas::GEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::float16 alpha, const platform::float16 *A, - const platform::float16 *B, platform::float16 beta, - platform::float16 *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas fp16 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, - rocblas_datatype_f16_r, ldb, A, rocblas_datatype_f16_r, lda, &h_beta, C, - rocblas_datatype_f16_r, N, rocblas_datatype_f32_r); -} -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, platform::float16 alpha, - const platform::float16 *A, - const platform::float16 *B, - platform::float16 beta, - platform::float16 *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas fp16 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, - rocblas_datatype_f16_r, ldb, A, rocblas_datatype_f16_r, lda, &h_beta, C, - rocblas_datatype_f16_r, N, rocblas_datatype_f32_r); -} - -template <> -template <> -inline void Blas::GEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::bfloat16 alpha, const platform::bfloat16 *A, - const platform::bfloat16 *B, platform::bfloat16 beta, - platform::bfloat16 *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? 
rocblas_operation_none - : rocblas_operation_transpose; - // TODO(zhiqiu): 80 has the same meaning for rocm and cuda? - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 80, - platform::errors::InvalidArgument( - "rocblas fp16 gemm requires GPU compute capability >= 80," - "but received %d", - context_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - rocblas_gemm_algo algo = rocblas_gemm_algo_standard; - - context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex( - handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, - rocblas_datatype_bf16_r, ldb, A, rocblas_datatype_bf16_r, lda, &h_beta, - C, rocblas_datatype_bf16_r, N, C, rocblas_datatype_bf16_r, N, - rocblas_datatype_f32_r, algo, 0, 0)); - }); -} - -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, platform::bfloat16 alpha, - const platform::bfloat16 *A, - const platform::bfloat16 *B, - platform::bfloat16 beta, - platform::bfloat16 *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - // TODO(zhiqiu): 80 has the same meaning for rocm and cuda? - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 80, - platform::errors::InvalidArgument( - "rocblas fp16 gemm requires GPU compute capability >= 80," - "but received %d", - context_.GetComputeCapability())); - - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - rocblas_gemm_algo algo = rocblas_gemm_algo_standard; - - context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex( - handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, - rocblas_datatype_bf16_r, ldb, A, rocblas_datatype_bf16_r, lda, &h_beta, - C, rocblas_datatype_bf16_r, N, C, rocblas_datatype_bf16_r, N, - rocblas_datatype_f32_r, algo, 0, 0)); - }); -} - -template <> -template <> -inline void Blas::GEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::complex alpha, const platform::complex *A, - const platform::complex *B, platform::complex beta, - platform::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? 
rocblas_operation_none - : rocblas_operation_transpose; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas complex64 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = thrust::complex(beta.real, beta.imag); - - auto &cuda_ctx = const_cast(context_); - CUBlas>::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, - rocblas_datatype_f32_c, ldb, A, rocblas_datatype_f32_c, lda, &c_beta, C, - rocblas_datatype_f32_c, N, rocblas_datatype_f32_c); -} -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, platform::complex alpha, - const platform::complex *A, - const platform::complex *B, - platform::complex beta, - platform::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas complex64 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = thrust::complex(beta.real, beta.imag); - - auto &cuda_ctx = const_cast(context_); - CUBlas>::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, - rocblas_datatype_f32_c, ldb, A, rocblas_datatype_f32_c, lda, &c_beta, C, - rocblas_datatype_f32_c, N, rocblas_datatype_f32_c); -} - -template <> -template <> -inline void Blas::GEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::complex alpha, const platform::complex *A, - const platform::complex *B, platform::complex beta, - platform::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? 
rocblas_operation_none - : rocblas_operation_transpose; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas complex128 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = - thrust::complex(beta.real, beta.imag); - - auto &cuda_ctx = const_cast(context_); - CUBlas>::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, - rocblas_datatype_f64_c, ldb, A, rocblas_datatype_f64_c, lda, &c_beta, C, - rocblas_datatype_f64_c, N, rocblas_datatype_f64_c); -} -template <> -template <> -inline void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, platform::complex alpha, - const platform::complex *A, - const platform::complex *B, - platform::complex beta, - platform::complex *C) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, - platform::errors::InvalidArgument( - "cublas complex128 gemm requires GPU compute capability >= 53," - "but received %d", - context_.GetComputeCapability())); - - thrust::complex c_alpha = - thrust::complex(alpha.real, alpha.imag); - thrust::complex c_beta = - thrust::complex(beta.real, beta.imag); - - auto &cuda_ctx = const_cast(context_); - CUBlas>::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, - rocblas_datatype_f64_c, ldb, A, rocblas_datatype_f64_c, lda, &c_beta, C, - rocblas_datatype_f64_c, N, rocblas_datatype_f64_c); -} - -template <> -template -void Blas::GEMM(bool transA, bool transB, int M, - int N, int K, T alpha, const T *A, - int lda, const T *B, int ldb, - T beta, T *C, int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - rocblas_operation cuTransA = - transA ? rocblas_operation_transpose : rocblas_operation_none; - rocblas_operation cuTransB = - transB ? rocblas_operation_transpose : rocblas_operation_none; - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, - &beta, C, ldc); - }); -} -template <> -template -void Blas::GEMM(bool transA, bool transB, int M, int N, int K, - T alpha, const T *A, int lda, const T *B, - int ldb, T beta, T *C, int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - rocblas_operation cuTransA = - transA ? rocblas_operation_transpose : rocblas_operation_none; - rocblas_operation cuTransB = - transB ? 
rocblas_operation_transpose : rocblas_operation_none; - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, - &beta, C, ldc); - }); -} - -template <> -template <> -inline void Blas::GEMM( - bool transA, bool transB, int M, int N, int K, platform::float16 alpha, - const platform::float16 *A, int lda, const platform::float16 *B, int ldb, - platform::float16 beta, platform::float16 *C, int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - rocblas_operation cuTransA = - transA ? rocblas_operation_transpose : rocblas_operation_none; - rocblas_operation cuTransB = - transB ? rocblas_operation_transpose : rocblas_operation_none; - - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, - B, ldb, A, lda, &beta, C, ldc); - }); -} -template <> -template <> -inline void Blas::GEMM(bool transA, bool transB, int M, int N, - int K, platform::float16 alpha, - const platform::float16 *A, int lda, - const platform::float16 *B, int ldb, - platform::float16 beta, - platform::float16 *C, int ldc) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - rocblas_operation cuTransA = - transA ? rocblas_operation_transpose : rocblas_operation_none; - rocblas_operation cuTransB = - transB ? rocblas_operation_transpose : rocblas_operation_none; - - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, - B, ldb, A, lda, &beta, C, ldc); - }); -} - -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, - T *y) const { - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); - }); -} -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); - }); -} - -template <> -template -void Blas::SCAL(int n, const T alpha, T *x) const { - context_.CublasCall( - [&](rocblas_handle handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); -} -template <> -template -void Blas::SCAL(int n, const T alpha, T *x) const { - context_.CublasCall( - [&](rocblas_handle handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); -} - -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - context_.CublasCall( - [&](rocblas_handle handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); -} -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - context_.CublasCall( - [&](rocblas_handle handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); -} - -template <> -template -void Blas::GEMV(bool trans_a, int M, int N, - T alpha, const T *A, const T *B, - T beta, T *C) const { - rocblas_operation cuTransA = - !trans_a ? rocblas_operation_transpose : rocblas_operation_none; - - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); - }); -} -template <> -template -void Blas::GEMV(bool trans_a, int M, int N, T alpha, - const T *A, const T *B, T beta, T *C) const { - rocblas_operation cuTransA = - !trans_a ? 
rocblas_operation_transpose : rocblas_operation_none; - - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); - }); -} - -template <> -template <> -inline void Blas::GEMV( - bool trans_a, int M, int N, platform::float16 alpha, - const platform::float16 *A, const platform::float16 *B, - platform::float16 beta, platform::float16 *C) const { - // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it. - if (trans_a) { - this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, - alpha, B, A, beta, C); - } else { - this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, - alpha, A, B, beta, C); - } -} -template <> -template <> -inline void Blas::GEMV(bool trans_a, int M, int N, - platform::float16 alpha, - const platform::float16 *A, - const platform::float16 *B, - platform::float16 beta, - platform::float16 *C) const { - // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it. - if (trans_a) { - this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, - alpha, B, A, beta, C); - } else { - this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, - alpha, A, B, beta, C); - } -} - -template <> -template <> -inline void Blas::GEMV( - bool trans_a, int M, int N, platform::bfloat16 alpha, - const platform::bfloat16 *A, const platform::bfloat16 *B, - platform::bfloat16 beta, platform::bfloat16 *C) const { - // Because rocblas doesn't support bfloat16 gemv, we use gemmex to achieve it. - if (trans_a) { - this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, - alpha, B, A, beta, C); - } else { - this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, - alpha, A, B, beta, C); - } -} -template <> -template <> -inline void Blas::GEMV(bool trans_a, int M, int N, - platform::bfloat16 alpha, - const platform::bfloat16 *A, - const platform::bfloat16 *B, - platform::bfloat16 beta, - platform::bfloat16 *C) const { - // Because rocblas doesn't support bfloat16 gemv, we use gemmex to achieve it. - if (trans_a) { - this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, - alpha, B, A, beta, C); - } else { - this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, - alpha, A, B, beta, C); - } -} - -template <> -template -void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T *A, const T *B, T beta, T *C, int batchCount, - int64_t strideA, int64_t strideB) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - const int64_t strideC = M * N; - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha, - B, ldb, strideB, A, lda, strideA, &beta, C, - ldc, strideC, batchCount); - }); -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, T alpha, const T *A, const T *B, - T beta, T *C, int batchCount, - int64_t strideA, - int64_t strideB) const { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? 
K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - const int64_t strideC = M * N; - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha, - B, ldb, strideB, A, lda, strideA, &beta, C, - ldc, strideC, batchCount); - }); -} - -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::bfloat16 alpha, const platform::bfloat16 *A, - const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C, - int batchCount, int64_t strideA, int64_t strideB) const { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - const int64_t strideC = M * N; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - rocblas_gemm_algo algo = rocblas_gemm_algo_standard; - - context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::rocblas_gemm_strided_batched_ex( - handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, - rocblas_datatype_bf16_r, ldb, strideB, A, rocblas_datatype_bf16_r, - lda, strideA, &h_beta, C, rocblas_datatype_bf16_r, ldc, strideC, C, - rocblas_datatype_bf16_r, ldc, strideC, batchCount, - rocblas_datatype_f32_r, algo, 0, 0)); - }); -} - -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::bfloat16 alpha, const platform::bfloat16 *A, - const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C, - int batchCount, int64_t strideA, int64_t strideB) const { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - const int64_t strideC = M * N; - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (transB == CblasNoTrans) - ? 
rocblas_operation_none - : rocblas_operation_transpose; - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - rocblas_gemm_algo algo = rocblas_gemm_algo_standard; - - context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::rocblas_gemm_strided_batched_ex( - handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, - rocblas_datatype_bf16_r, ldb, strideB, A, rocblas_datatype_bf16_r, - lda, strideA, &h_beta, C, rocblas_datatype_bf16_r, ldc, strideC, C, - rocblas_datatype_bf16_r, ldc, strideC, batchCount, - rocblas_datatype_f32_r, algo, 0, 0)); - }); -} - -template <> -template -void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], B[k], beta, - C[k]); - } -} - -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, T alpha, const T **A, - const T **B, T beta, T **C, - int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], B[k], beta, - C[k]); - } -} - -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::float16 alpha, const platform::float16 **A, - const platform::float16 **B, platform::float16 beta, platform::float16 **C, - int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], - B[k], beta, C[k]); - } -} -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::float16 alpha, const platform::float16 **A, - const platform::float16 **B, platform::float16 beta, platform::float16 **C, - int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], - B[k], beta, C[k]); - } -} - -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::bfloat16 alpha, const platform::bfloat16 **A, - const platform::bfloat16 **B, platform::bfloat16 beta, - platform::bfloat16 **C, int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, - A[k], B[k], beta, C[k]); - } -} - -template <> -template <> -inline void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::bfloat16 alpha, const platform::bfloat16 **A, - const platform::bfloat16 **B, platform::bfloat16 beta, - platform::bfloat16 **C, int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, - A[k], B[k], beta, C[k]); - } -} - -template <> -template -void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, int M, int N, - T alpha, const T *A, int lda, T *B, - int ldb) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - rocblas_side cuSide = - (side == CblasLeft) ? rocblas_side_right : rocblas_side_left; - rocblas_fill cuUplo = - (uplo == CblasLower) ? 
rocblas_fill_upper : rocblas_fill_lower; - // use CUBLAS_OP_C (conjugate transpose) for complex - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_diagonal cuDiag = - (diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit; - - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::TRSM(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, - lda, B, ldb); - }); -} -template <> -template -void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, CBLAS_DIAG diag, - int M, int N, T alpha, const T *A, int lda, - T *B, int ldb) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - rocblas_side cuSide = - (side == CblasLeft) ? rocblas_side_right : rocblas_side_left; - rocblas_fill cuUplo = - (uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower; - // use CUBLAS_OP_C (conjugate transpose) for complex - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_diagonal cuDiag = - (diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit; - - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::TRSM(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, - lda, B, ldb); - }); -} - -template <> -template -void Blas::BatchedGETRF(int n, T **a, int *ipiv, - int *info, - int batch_size) const { - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); - }); -} -template <> -template -void Blas::BatchedGETRF(int n, T **a, int *ipiv, int *info, - int batch_size) const { - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedGETRI(int n, const T **a, - const int *ipiv, T **a_inv, - int *info, - int batch_size) const { - PADDLE_ENFORCE_NE( - a_inv, a, - platform::errors::InvalidArgument( - "cuBLAS fuction 'cublasgetrfBatched' cannot be executed " - "in-place. The memory space of output matrix (address: %p) cannot " - "overlap memory space of input matrix (address: %p).", - a_inv, a)); - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); - }); -} -template <> -template -void Blas::BatchedGETRI(int n, const T **a, const int *ipiv, - T **a_inv, int *info, - int batch_size) const { - PADDLE_ENFORCE_NE( - a_inv, a, - platform::errors::InvalidArgument( - "cuBLAS fuction 'cublasgetrfBatched' cannot be executed " - "in-place. 
The memory space of output matrix (address: %p) cannot " - "overlap memory space of input matrix (address: %p).", - a_inv, a)); - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedMatInv(int n, const T **a, - T **a_inv, int *info, - int batch_size) const { - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); - }); -} -template <> -template -void Blas::BatchedMatInv(int n, const T **a, T **a_inv, - int *info, int batch_size) const { - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); - }); -} - -template <> -template -void Blas::BatchedGETRS( - CBLAS_TRANSPOSE trans, int n, int nrhs, const T **a, int lda, int *ipiv, - T **b, int ldb, int *info, int batch_size) const { - rocblas_operation cuTrans = (trans == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GETRS_BATCH(handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, - batch_size); - }); -} -template <> -template -void Blas::BatchedGETRS(CBLAS_TRANSPOSE trans, int n, - int nrhs, const T **a, int lda, - int *ipiv, T **b, int ldb, int *info, - int batch_size) const { - rocblas_operation cuTrans = (trans == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::GETRS_BATCH(handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, - batch_size); - }); -} - -template <> -template -void Blas::BatchedTRSM( - CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, CBLAS_DIAG diag, - int M, int N, T alpha, const T **A, int lda, T **B, int ldb, - int batch_size) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - rocblas_side cuSide = - (side == CblasLeft) ? rocblas_side_right : rocblas_side_left; - rocblas_fill cuUplo = - (uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower; - // use CUBLAS_OP_C (conjugate transpose) for complex - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_diagonal cuDiag = - (diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit; - - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::TRSM_BATCH(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, - &alpha, A, lda, B, ldb, batch_size); - }); -} -template <> -template -void Blas::BatchedTRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, int M, int N, T alpha, - const T **A, int lda, T **B, int ldb, - int batch_size) const { - // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` - // where ' stands for transpose - rocblas_side cuSide = - (side == CblasLeft) ? rocblas_side_right : rocblas_side_left; - rocblas_fill cuUplo = - (uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower; - // use CUBLAS_OP_C (conjugate transpose) for complex - rocblas_operation cuTransA = (transA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_diagonal cuDiag = - (diag == CblasUnit) ? 
rocblas_diagonal_unit : rocblas_diagonal_non_unit; - - context_.CublasCall([&](rocblas_handle handle) { - CUBlas::TRSM_BATCH(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, - &alpha, A, lda, B, ldb, batch_size); - }); -} - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/context_project.h b/paddle/fluid/operators/math/context_project.h index ac2cd2a996173..7e5d988861bd4 100644 --- a/paddle/fluid/operators/math/context_project.h +++ b/paddle/fluid/operators/math/context_project.h @@ -18,8 +18,8 @@ limitations under the License. */ #include #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/im2col.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -223,7 +223,7 @@ class ContextProjectGradFunctor { int input_row_begin, input_row_end; int sequence_height, sequence_width; sequence_width = in.dims()[1]; - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); if (input_grad) { for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h index b946d4d072ba2..5ec808613266a 100644 --- a/paddle/fluid/operators/math/eigen_values_vectors.h +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -15,8 +15,8 @@ #pragma once #include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/math/lapack_function.h" #include "paddle/fluid/operators/svd_helper.h" +#include "paddle/pten/kernels/funcs/lapack/lapack_function.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/dynload/cusolver.h" #endif // PADDLE_WITH_CUDA @@ -98,9 +98,9 @@ struct MatrixEighFunctor { int info = 0; // Call lapackEigh to get the optimal size of work data - math::lapackEigh(jobz, uplo, n, input_vector, lda, out_value, - &lwork_opt, lwork, &rwork_opt, lrwork, - &iwork_opt, liwork, &info); + pten::funcs::lapackEigh( + jobz, uplo, n, input_vector, lda, out_value, &lwork_opt, lwork, + &rwork_opt, lrwork, &iwork_opt, liwork, &info); lwork = std::max(1, static_cast(lwork_opt)); liwork = std::max(1, iwork_opt); @@ -123,7 +123,7 @@ struct MatrixEighFunctor { for (auto i = 0; i < batch_size; i++) { auto *value_data = out_value + i * values_stride; auto *input_data = input_vector + i * vector_stride; - math::lapackEigh>( + pten::funcs::lapackEigh>( jobz, uplo, n, input_data, lda, value_data, work_data, lwork, rwork_data, lrwork, iwork_data, liwork, &info); CheckEighResult(i, info); diff --git a/paddle/fluid/operators/math/fc.cc b/paddle/fluid/operators/math/fc.cc index 38519a770c346..9229d670966a8 100644 --- a/paddle/fluid/operators/math/fc.cc +++ b/paddle/fluid/operators/math/fc.cc @@ -15,7 +15,7 @@ limitations under the License. 
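The rocBLAS wrappers deleted above all rely on the same trick, spelled out in their "fortran order" comments: rocBLAS, like cuBLAS, expects column-major operands, so a row-major C = op(A) * op(B) is computed as the column-major product C^T = op(B)^T * op(A)^T, which is why every call passes cuTransB/cuTransA and B/A in swapped order with N and M exchanged. A standalone sketch of that equivalence in plain C++ (not Paddle code, no GPU involved):

// Reference column-major GEMM: C(m x n) = A(m x k) * B(k x n), no transposes.
static void gemm_colmajor(int m, int n, int k, const float* A, int lda,
                          const float* B, int ldb, float* C, int ldc) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += A[i + p * lda] * B[p + j * ldb];
      C[i + j * ldc] = acc;
    }
  }
}

// Row-major C(M x N) = A(M x K) * B(K x N), expressed as the column-major
// product C^T = B^T * A^T: B is passed first, M and N are swapped, and the
// leading dimensions are the row lengths, exactly as in the wrappers above.
static void gemm_rowmajor(int M, int N, int K, const float* A, const float* B,
                          float* C) {
  gemm_colmajor(N, M, K, B, /*lda=*/N, A, /*ldb=*/K, C, /*ldc=*/N);
}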
*/ #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/jit/kernels.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -28,7 +28,7 @@ class FCFunctor { const int N, const int K, const T* X, const T* W, T* Y, const T* B = nullptr, bool relu = false, bool padding_weights = false) { - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); framework::Tensor Y1; T* Y1_data = nullptr; if (padding_weights) { diff --git a/paddle/fluid/operators/math/fc.cu b/paddle/fluid/operators/math/fc.cu index 69f62d1d53d72..ade7df5866f93 100644 --- a/paddle/fluid/operators/math/fc.cu +++ b/paddle/fluid/operators/math/fc.cu @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -85,7 +85,7 @@ class FCFunctor { padding_weights, false, platform::errors::PermissionDenied( "Weight padding in fc can not be used in GPU scope.")); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); blas.GEMM(false, false, M, N, K, static_cast(1.0), X, K, W, N, static_cast(0.0), Y, N); if (B == NULL) { diff --git a/paddle/fluid/operators/math/gru_compute.cc b/paddle/fluid/operators/math/gru_compute.cc index b7a3974ae33e7..515988e587e15 100644 --- a/paddle/fluid/operators/math/gru_compute.cc +++ b/paddle/fluid/operators/math/gru_compute.cc @@ -11,9 +11,9 @@ limitations under the License. */ #include "paddle/fluid/operators/math/gru_compute.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" #include "paddle/fluid/operators/math/detail/gru_kernel.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace platform { @@ -33,7 +33,7 @@ struct GRUUnitFunctor { const detail::ActivationType active_gate, bool origin_mode) { #if !defined(__NVCC__) && !defined(__HIPCC___) - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); if (value.prev_out_value) { blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1, value.prev_out_value, frame_size, value.gate_weight, @@ -70,7 +70,7 @@ struct GRUUnitGradFunctor { detail::backward_state_grad(detail::backward::gru_stateGrad(), value, grad, frame_size, batch_size, active_node, origin_mode); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); if (value.prev_out_value && grad.prev_out_grad) { blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, grad.gate_grad + frame_size * 2, frame_size * 3, @@ -109,7 +109,7 @@ struct GRUUnitFunctorV2 { const detail::ActivationType active_node, const detail::ActivationType active_gate) { #if !defined(__NVCC__) && !defined(__HIPCC___) - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); if (value.prev_out_value) { blas.GEMM(CblasNoTrans, CblasTrans, batch_size, frame_size, frame_size, 1, value.prev_out_value, value.state_weight, 0, @@ -147,7 +147,7 @@ struct GRUUnitGradFunctorV2 { // grad_reset_output, grad_reset_gate detail::cpu_gru_backward(context, detail::backward::gru(), value, grad, frame_size, batch_size, active_node, active_gate); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); if (grad.prev_out_grad && 
value.prev_out_value) { // update prev_out_grad blas.GEMM(false, false, batch_size, frame_size, frame_size, 1, diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu index cf3d57b0630da..2ac5531d1ff4e 100644 --- a/paddle/fluid/operators/math/gru_compute.cu +++ b/paddle/fluid/operators/math/gru_compute.cu @@ -10,10 +10,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/gru_gpu_kernel.h" #include "paddle/fluid/operators/math/detail/gru_kernel.h" #include "paddle/fluid/operators/math/gru_compute.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -77,7 +77,7 @@ struct GRUUnitFunctor { threads = dim3(32, 32); grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32); } - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); if (value.prev_out_value) { blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1, value.prev_out_value, frame_size, value.gate_weight, @@ -162,7 +162,7 @@ struct GRUUnitGradFunctor { grad.output_grad, frame_size, batch_size, active_node, origin_mode); } - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); if (value.prev_out_value && grad.prev_out_grad) { blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, diff --git a/paddle/fluid/operators/math/lapack_function.cc b/paddle/fluid/operators/math/lapack_function.cc deleted file mode 100644 index 33fa2efb12c1b..0000000000000 --- a/paddle/fluid/operators/math/lapack_function.cc +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
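The file deleted below is a thin wrapper over raw LAPACK entry points (sgetrf_/dgetrf_ for LU, ssyevd_/dsyevd_/cheevd_/zheevd_ for the symmetric/Hermitian eigenproblem, and so on); its caller in eigen_values_vectors.h above follows the standard LAPACK two-phase workspace query: call the routine once with lwork = -1 to learn the optimal work-array sizes, then allocate and call again to do the real computation. A standalone sketch of that convention for the real single-precision case (error handling omitted; not Paddle code):

#include <algorithm>
#include <vector>

extern "C" void ssyevd_(char* jobz, char* uplo, int* n, float* a, int* lda,
                        float* w, float* work, int* lwork, int* iwork,
                        int* liwork, int* info);

// Eigen-decomposition of a symmetric n x n matrix stored in `a`
// (column-major): eigenvalues land in `w`, eigenvectors overwrite `a`.
void symmetric_eigh(int n, float* a, float* w) {
  char jobz = 'V', uplo = 'L';
  int lda = n, info = 0;
  float lwork_opt = 0.0f;
  int liwork_opt = 0;
  int lwork = -1, liwork = -1;
  // Phase 1: workspace query only; the optimal sizes come back in the first
  // elements of the (one-element) work arrays.
  ssyevd_(&jobz, &uplo, &n, a, &lda, w, &lwork_opt, &lwork, &liwork_opt,
          &liwork, &info);
  lwork = std::max(1, static_cast<int>(lwork_opt));
  liwork = std::max(1, liwork_opt);
  std::vector<float> work(lwork);
  std::vector<int> iwork(liwork);
  // Phase 2: the actual decomposition.
  ssyevd_(&jobz, &uplo, &n, a, &lda, w, work.data(), &lwork, iwork.data(),
          &liwork, &info);
}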
- -#include "paddle/fluid/operators/math/lapack_function.h" -#include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/platform/dynload/lapack.h" - -namespace paddle { -namespace operators { -namespace math { - -// LU (for example) -template <> -void lapackLu(int m, int n, double *a, int lda, int *ipiv, int *info) { - platform::dynload::dgetrf_(&m, &n, a, &lda, ipiv, info); -} - -template <> -void lapackLu(int m, int n, float *a, int lda, int *ipiv, int *info) { - platform::dynload::sgetrf_(&m, &n, a, &lda, ipiv, info); -} - -// eigh -template <> -void lapackEigh(char jobz, char uplo, int n, float *a, int lda, float *w, - float *work, int lwork, float *rwork, int lrwork, - int *iwork, int liwork, int *info) { - (void)rwork; // unused - (void)lrwork; // unused - platform::dynload::ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, - &liwork, info); -} - -template <> -void lapackEigh(char jobz, char uplo, int n, double *a, int lda, - double *w, double *work, int lwork, double *rwork, - int lrwork, int *iwork, int liwork, int *info) { - (void)rwork; // unused - (void)lrwork; // unused - platform::dynload::dsyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, - &liwork, info); -} - -template <> -void lapackEigh, float>( - char jobz, char uplo, int n, platform::complex *a, int lda, float *w, - platform::complex *work, int lwork, float *rwork, int lrwork, - int *iwork, int liwork, int *info) { - platform::dynload::cheevd_(&jobz, &uplo, &n, - reinterpret_cast *>(a), &lda, - w, reinterpret_cast *>(work), - &lwork, rwork, &lrwork, iwork, &liwork, info); -} - -template <> -void lapackEigh, double>( - char jobz, char uplo, int n, platform::complex *a, int lda, - double *w, platform::complex *work, int lwork, double *rwork, - int lrwork, int *iwork, int liwork, int *info) { - platform::dynload::zheevd_(&jobz, &uplo, &n, - reinterpret_cast *>(a), &lda, - w, reinterpret_cast *>(work), - &lwork, rwork, &lrwork, iwork, &liwork, info); -} - -// Eig -template <> -void lapackEig(char jobvl, char jobvr, int n, double *a, int lda, - double *w, double *vl, int ldvl, double *vr, int ldvr, - double *work, int lwork, double *rwork, int *info) { - double *wr = w; - double *wi = w + n; - (void)rwork; // unused - platform::dynload::dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, - &ldvr, work, &lwork, info); -} - -template <> -void lapackEig(char jobvl, char jobvr, int n, float *a, int lda, - float *w, float *vl, int ldvl, float *vr, int ldvr, - float *work, int lwork, float *rwork, int *info) { - float *wr = w; - float *wi = w + n; - (void)rwork; // unused - platform::dynload::sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, - &ldvr, work, &lwork, info); -} - -template <> -void lapackEig, double>( - char jobvl, char jobvr, int n, platform::complex *a, int lda, - platform::complex *w, platform::complex *vl, int ldvl, - platform::complex *vr, int ldvr, platform::complex *work, - int lwork, double *rwork, int *info) { - platform::dynload::zgeev_( - &jobvl, &jobvr, &n, reinterpret_cast *>(a), &lda, - reinterpret_cast *>(w), - reinterpret_cast *>(vl), &ldvl, - reinterpret_cast *>(vr), &ldvr, - reinterpret_cast *>(work), &lwork, rwork, info); -} - -template <> -void lapackEig, float>( - char jobvl, char jobvr, int n, platform::complex *a, int lda, - platform::complex *w, platform::complex *vl, int ldvl, - platform::complex *vr, int ldvr, platform::complex *work, - int lwork, float *rwork, int *info) { - platform::dynload::cgeev_( - &jobvl, &jobvr, &n, reinterpret_cast *>(a), 
&lda, - reinterpret_cast *>(w), - reinterpret_cast *>(vl), &ldvl, - reinterpret_cast *>(vr), &ldvr, - reinterpret_cast *>(work), &lwork, rwork, info); -} - -template <> -void lapackGels(char trans, int m, int n, int nrhs, double *a, int lda, - double *b, int ldb, double *work, int lwork, - int *info) { - platform::dynload::dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, - &lwork, info); -} - -template <> -void lapackGels(char trans, int m, int n, int nrhs, float *a, int lda, - float *b, int ldb, float *work, int lwork, int *info) { - platform::dynload::sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, - &lwork, info); -} - -template <> -void lapackGelsd(int m, int n, int nrhs, double *a, int lda, double *b, - int ldb, double *s, double rcond, int *rank, - double *work, int lwork, double *rwork, int *iwork, - int *info) { - platform::dynload::dgelsd_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, - work, &lwork, iwork, info); -} - -template <> -void lapackGelsd(int m, int n, int nrhs, float *a, int lda, float *b, - int ldb, float *s, float rcond, int *rank, float *work, - int lwork, float *rwork, int *iwork, int *info) { - platform::dynload::sgelsd_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, - work, &lwork, iwork, info); -} - -template <> -void lapackGelsy(int m, int n, int nrhs, double *a, int lda, double *b, - int ldb, int *jpvt, double rcond, int *rank, - double *work, int lwork, double *rwork, int *info) { - platform::dynload::dgelsy_(&m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond, - rank, work, &lwork, info); -} - -template <> -void lapackGelsy(int m, int n, int nrhs, float *a, int lda, float *b, - int ldb, int *jpvt, float rcond, int *rank, float *work, - int lwork, float *rwork, int *info) { - platform::dynload::sgelsy_(&m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond, - rank, work, &lwork, info); -} - -template <> -void lapackGelss(int m, int n, int nrhs, double *a, int lda, double *b, - int ldb, double *s, double rcond, int *rank, - double *work, int lwork, double *rwork, int *info) { - platform::dynload::dgelss_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, - work, &lwork, info); -} - -template <> -void lapackGelss(int m, int n, int nrhs, float *a, int lda, float *b, - int ldb, float *s, float rcond, int *rank, float *work, - int lwork, float *rwork, int *info) { - platform::dynload::sgelss_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, - work, &lwork, info); -} - -template <> -void lapackCholeskySolve>( - char uplo, int n, int nrhs, platform::complex *a, int lda, - platform::complex *b, int ldb, int *info) { - platform::dynload::zpotrs_( - &uplo, &n, &nrhs, reinterpret_cast *>(a), &lda, - reinterpret_cast *>(b), &ldb, info); -} - -template <> -void lapackCholeskySolve>(char uplo, int n, int nrhs, - platform::complex *a, - int lda, - platform::complex *b, - int ldb, int *info) { - platform::dynload::cpotrs_( - &uplo, &n, &nrhs, reinterpret_cast *>(a), &lda, - reinterpret_cast *>(b), &ldb, info); -} - -template <> -void lapackCholeskySolve(char uplo, int n, int nrhs, double *a, int lda, - double *b, int ldb, int *info) { - platform::dynload::dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info); -} - -template <> -void lapackCholeskySolve(char uplo, int n, int nrhs, float *a, int lda, - float *b, int ldb, int *info) { - platform::dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info); -} - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/lapack_function.h 
b/paddle/fluid/operators/math/lapack_function.h deleted file mode 100644 index 488b225ef570e..0000000000000 --- a/paddle/fluid/operators/math/lapack_function.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -namespace paddle { -namespace operators { -namespace math { - -// LU (for example) -template -void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info); - -// Eigh -template -void lapackEigh(char jobz, char uplo, int n, T *a, int lda, ValueType *w, - T *work, int lwork, ValueType *rwork, int lrwork, int *iwork, - int liwork, int *info); - -// Eig -template -void lapackEig(char jobvl, char jobvr, int n, T1 *a, int lda, T1 *w, T1 *vl, - int ldvl, T1 *vr, int ldvr, T1 *work, int lwork, T2 *rwork, - int *info); - -// Gels -template -void lapackGels(char trans, int m, int n, int nrhs, T *a, int lda, T *b, - int ldb, T *work, int lwork, int *info); - -// Gelsd -template -void lapackGelsd(int m, int n, int nrhs, T1 *a, int lda, T1 *b, int ldb, T2 *s, - T2 rcond, int *rank, T1 *work, int lwork, T2 *rwork, - int *iwork, int *info); - -// Gelsy -template -void lapackGelsy(int m, int n, int nrhs, T1 *a, int lda, T1 *b, int ldb, - int *jpvt, T2 rcond, int *rank, T1 *work, int lwork, T2 *rwork, - int *info); - -// Gelss -template -void lapackGelss(int m, int n, int nrhs, T1 *a, int lda, T1 *b, int ldb, T2 *s, - T2 rcond, int *rank, T1 *work, int lwork, T2 *rwork, - int *info); - -template -void lapackCholeskySolve(char uplo, int n, int nrhs, T *a, int lda, T *b, - int ldb, int *info); - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index a94bb594be5f9..1c3acc14d3cb2 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -135,8 +135,8 @@ struct MatrixBitCodeFunctorMul : public boost::static_visitor { template void operator()(const CodeTable &code_table) { - auto blas = - GetBlas(platform::CPUDeviceContext()); + auto blas = pten::funcs::GetBlas( + platform::CPUDeviceContext()); size_t num_samples = tmat_->dims()[0]; size_t tmat_width = tmat_->dims()[1]; size_t input_width = input_.dims()[1]; @@ -183,8 +183,8 @@ struct MatrixBitCodeFunctorMulGradWeight : public boost::static_visitor { : tmat_(tmat), weight_(weight), input_(input) {} template void operator()(const CodeTable &code_table) { - auto blas = - GetBlas(platform::CPUDeviceContext()); + auto blas = pten::funcs::GetBlas( + platform::CPUDeviceContext()); size_t num_samples = tmat_.dims()[0]; size_t input_width = input_.dims()[1]; size_t tmat_width = tmat_.dims()[1]; @@ -237,8 +237,8 @@ struct MatrixBitCodeFunctorMulGradWeightSR template void operator()(const CodeTable &code_table) { - auto blas = - GetBlas(platform::CPUDeviceContext()); + auto blas = pten::funcs::GetBlas( + platform::CPUDeviceContext()); size_t num_samples = 
tmat_.dims()[0]; size_t input_width = input_.dims()[1]; size_t tmat_width = tmat_.dims()[1]; diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index 13ddd27cbf0d7..968bd4a47503e 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -21,9 +21,9 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/selected_rows_utils.h" #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/variant.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #if defined(_WIN32) #include diff --git a/paddle/fluid/operators/math/matrix_inverse.cc b/paddle/fluid/operators/math/matrix_inverse.cc index 60481491cb4b4..3a0d494eb88f3 100644 --- a/paddle/fluid/operators/math/matrix_inverse.cc +++ b/paddle/fluid/operators/math/matrix_inverse.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/matrix_inverse.h" #include "Eigen/Core" #include "Eigen/LU" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/matrix_inverse.cu.cc b/paddle/fluid/operators/math/matrix_inverse.cu.cc index 0b6a097d09d15..a9fe8af37b329 100644 --- a/paddle/fluid/operators/math/matrix_inverse.cu.cc +++ b/paddle/fluid/operators/math/matrix_inverse.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/matrix_inverse.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace platform { @@ -72,7 +72,7 @@ class MatrixInverseFunctor { memory::Alloc(context, num_ints * sizeof(int)); int* gpu_info_ptr = reinterpret_cast(tmp_gpu_info_data->ptr()); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); std::vector info; // only for singular checking info.resize(batch_size); diff --git a/paddle/fluid/operators/math/matrix_solve.cc b/paddle/fluid/operators/math/matrix_solve.cc index 95c84d83976f5..0e082226b32a7 100644 --- a/paddle/fluid/operators/math/matrix_solve.cc +++ b/paddle/fluid/operators/math/matrix_solve.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/matrix_solve.h" #include "Eigen/Core" #include "Eigen/LU" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -62,7 +62,7 @@ class TriangularSolveFunctor { batch_size *= a_dim[i]; } - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); for (int i = 0; i < batch_size; i++) { blas.TRSM(side, uplo, transA, diag, M, N, T(1), a_data + i * M * M, lda, b_data + i * N * M, ldb); diff --git a/paddle/fluid/operators/math/matrix_solve.cu.cc b/paddle/fluid/operators/math/matrix_solve.cu.cc index f23c3f14c5c9a..b6dd7e6d209ed 100644 --- a/paddle/fluid/operators/math/matrix_solve.cu.cc +++ b/paddle/fluid/operators/math/matrix_solve.cu.cc @@ -14,9 +14,9 @@ limitations under the License. 
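For reference, each blas.TRSM call in the TriangularSolveFunctor hunks above and below solves a triangular system against a block of right-hand sides in place. In the simplest configuration (left side, lower triangular, no transpose, non-unit diagonal) it computes X from L * X = alpha * B; a plain row-major C++ reference of just that case (not Paddle code):

// Solve L * X = alpha * B in place: L is n x n lower triangular (row-major),
// B is n x m (row-major) and is overwritten with X by forward substitution.
void trsm_lower_left_notrans(int n, int m, float alpha, const float* L,
                             float* B) {
  for (int j = 0; j < m; ++j) {      // one right-hand-side column at a time
    for (int i = 0; i < n; ++i) {
      float acc = alpha * B[i * m + j];
      for (int k = 0; k < i; ++k) acc -= L[i * n + k] * B[k * m + j];
      B[i * m + j] = acc / L[i * n + i];
    }
  }
}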
*/ #include "paddle/fluid/operators/math/matrix_solve.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/solve_op.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -105,7 +105,7 @@ class MatrixSolveFunctor { memory::Alloc(context, num_ints * sizeof(int)); int* gpu_info_ptr = reinterpret_cast(tmp_gpu_info_data->ptr()); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); // only for singular checking std::vector info; @@ -189,7 +189,7 @@ class TriangularSolveFunctor { batch_size *= a_dim[i]; } - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); if (batch_size <= 8 && M >= 64) { for (auto i = 0; i < batch_size; i++) { blas.TRSM(side, uplo, transA, diag, M, N, static_cast(1.0), diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index b921e844c9f21..e750949c566a0 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -224,7 +224,7 @@ struct SelectedRowsSumTo { auto* in2_value = input2->mutable_value(); auto* in2_data = in2_value->data(); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); size_t offset = 0u; for (size_t i = 0u; i != input1.size(); ++i) { auto& in_value = input1[i]->value(); @@ -295,15 +295,15 @@ namespace scatter { template typename std::enable_if::value>::type elementwise_add_to( - BlasT* blas, size_t data_len, const T* in, - T* out) { + pten::funcs::BlasT* blas, size_t data_len, + const T* in, T* out) { blas->AXPY(data_len, T(1.f), in, out); } template typename std::enable_if::value>::type elementwise_add_to( - BlasT* blas, size_t data_len, const T* in, - T* out) { + pten::funcs::BlasT* blas, size_t data_len, + const T* in, T* out) { for (size_t i = 0; i < data_len; i++) { out[i] += in[i]; } @@ -316,7 +316,7 @@ add_sparse_inputs(const std::vector& inputs, int64_t input_width, const platform::CPUDeviceContext& context, T* out_data) { #ifndef PADDLE_WITH_MKLDNN - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); #endif for (auto* input : inputs) { if (input->rows().size() == 0) { @@ -350,7 +350,7 @@ add_sparse_inputs(const std::vector& inputs, int64_t input_width, const platform::CPUDeviceContext& context, T* out_data) { VLOG(4) << "[CPU] add_sparse_inputs <" << typeid(T).name(); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); for (auto* input : inputs) { if (input->rows().size() == 0) { continue; @@ -697,7 +697,7 @@ struct MergeAverage { rows_to_id[merge_rows[i]] = i; } - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); for (auto* input : inputs) { if (input->rows().size() == 0) { continue; diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h index e0ac583f15b60..4f0798c0e8c99 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.h +++ b/paddle/fluid/operators/math/selected_rows_functor.h @@ -18,8 +18,8 @@ limitations under the License. 
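The pattern in these call-site hunks is purely mechanical: the header moves from paddle/fluid/operators/math/blas.h to paddle/pten/kernels/funcs/blas/blas.h, and GetBlas/BlasT move from the paddle::operators::math namespace to pten::funcs, while the Blas interface itself (GEMM, AXPY, TRSM, ...) is unchanged. A minimal call-site sketch after the migration (the surrounding kernel function is illustrative, not taken from the patch):

#include "paddle/pten/kernels/funcs/blas/blas.h"  // new location of blas.h

// y[i] += alpha * x[i], written the way the migrated call sites do it.
template <typename DeviceContext, typename T>
void AddScaled(const DeviceContext& dev_ctx, int n, T alpha, const T* x,
               T* y) {
  auto blas = pten::funcs::GetBlas<DeviceContext, T>(dev_ctx);
  blas.AXPY(n, alpha, x, y);
}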
*/ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/selected_rows_utils.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" #define INLINE_FOR2(sizei, sizej) \ diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 22cd435297341..f3439053ac088 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -15,8 +15,8 @@ limitations under the License. */ #include #include "paddle/fluid/operators/jit/kernels.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/sequence_pooling.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -289,7 +289,7 @@ class SumSeqPoolGradFunctor { in_w, out_w, in_w, out_w)); const T* out_g_data = out_grad.data(); T* in_g_data = in_grad->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { int64_t h = static_cast(lod[i + 1] - lod[i]); if (h == 0) continue; diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 9f0f8a96c7677..4bf59939dd8ff 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -25,7 +25,8 @@ namespace operators { /** * Printing shape information into a string is easy to use. */ -inline static std::string DumpMatrixShape(const math::MatDescriptor &desc) { +inline static std::string DumpMatrixShape( + const pten::funcs::MatDescriptor &desc) { std::stringstream buffer; buffer << "[" << desc.batch_size_ << ", " << desc.height_ << ", " << desc.width_ << "]"; @@ -65,10 +66,10 @@ class MatMulKernel : public framework::OpKernel { auto *out = context.Output("Out"); out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor( + auto blas = pten::funcs::GetBlas(context); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor( RowMatrixFromVector(x.dims()), 0, context.Attr("transpose_X")); - auto mat_dim_b = math::CreateMatrixDescriptor( + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor( ColumnMatrixFromVector(y.dims()), 0, context.Attr("transpose_Y")); auto scale = static_cast(context.Attr("alpha")); @@ -142,7 +143,7 @@ static framework::Tensor FoldHeadAndLastDims(const DeviceContext &context, * If transposed, `H,W` will be swapped. 
*/ static void ReshapeTensorIntoMatrixSequence( - framework::Tensor *x, const math::MatDescriptor &descriptor) { + framework::Tensor *x, const pten::funcs::MatDescriptor &descriptor) { int64_t h, w; h = descriptor.height_; w = descriptor.width_; @@ -176,8 +177,8 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x, bool trans_y) { auto x_dim = RowMatrixFromVector(x->dims()); auto y_dim = ColumnMatrixFromVector(y->dims()); - auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x); - auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y); + auto mat_dim_x = pten::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = pten::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y); if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { out->Resize({mat_dim_x.height_, mat_dim_y.width_}); } else { @@ -222,9 +223,9 @@ class MatMulGradKernel : public framework::OpKernel { const framework::Tensor &b, bool trans_b, framework::Tensor *out) const { out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + auto blas = pten::funcs::GetBlas(context); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b); int head_number = 1; #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ @@ -404,9 +405,9 @@ class MatMulDoubleGradKernel : public framework::OpKernel { const framework::Tensor &b, bool trans_b, bool flag, framework::Tensor *out) const { out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + auto blas = pten::funcs::GetBlas(context); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b); int head_number = 1; #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ @@ -584,12 +585,12 @@ class MatMulOp : public framework::OperatorWithKernel { auto dim_x = GetDimForInput(*context, "X"); auto dim_y = GetDimForInput(*context, "Y"); - auto mat_dim_x = - math::CreateMatrixDescriptor(RowMatrixFromVector(dim_x), 0, - context->Attrs().Get("transpose_X")); - auto mat_dim_y = - math::CreateMatrixDescriptor(ColumnMatrixFromVector(dim_y), 0, - context->Attrs().Get("transpose_Y")); + auto mat_dim_x = pten::funcs::CreateMatrixDescriptor( + RowMatrixFromVector(dim_x), 0, + context->Attrs().Get("transpose_X")); + auto mat_dim_y = pten::funcs::CreateMatrixDescriptor( + ColumnMatrixFromVector(dim_y), 0, + context->Attrs().Get("transpose_Y")); if (mat_dim_x.width_ == -1) { mat_dim_x.width_ = mat_dim_y.height_; diff --git a/paddle/fluid/operators/matmul_op_xpu.cc b/paddle/fluid/operators/matmul_op_xpu.cc index 53593d2db01f7..76bf4976176d3 100644 --- a/paddle/fluid/operators/matmul_op_xpu.cc +++ b/paddle/fluid/operators/matmul_op_xpu.cc @@ -19,8 +19,8 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/xpu_api_wrapper.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -53,7 +53,7 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim &y_dim) { } static void ReshapeTensorIntoMatrixSequence( - framework::Tensor *x, const math::MatDescriptor &descriptor) { + framework::Tensor *x, const pten::funcs::MatDescriptor &descriptor) { int64_t h, w; h = descriptor.height_; w = descriptor.width_; @@ -86,8 +86,8 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x, bool trans_y) { auto x_dim = RowMatrixFromVector(x->dims()); auto y_dim = ColumnMatrixFromVector(y->dims()); - auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x); - auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y); + auto mat_dim_x = pten::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = pten::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y); if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { out->Resize({mat_dim_x.height_, mat_dim_y.width_}); } else { @@ -109,10 +109,10 @@ static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out, auto &dev_ctx = ctx.template device_context(); - auto mat_dim_a = - math::CreateMatrixDescriptor(RowMatrixFromVector(x_dims), 0, trans_x); - auto mat_dim_b = - math::CreateMatrixDescriptor(ColumnMatrixFromVector(y_dims), 0, trans_y); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor( + RowMatrixFromVector(x_dims), 0, trans_x); + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor( + ColumnMatrixFromVector(y_dims), 0, trans_y); if (x_dims.size() == 3 && y_dims.size() <= 2) { // if transpose_X is true, the transpose cost much time diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 045f823b7b672..f26777781ef1e 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -21,8 +21,8 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/dot_op.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/complex_functors.h" // only can include the headers in paddle/pten/api dirs @@ -77,7 +77,7 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) { * If transposed, `H,W` will be swapped. 
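CreateMatrixDescriptor, now exposed from pten::funcs, folds a tensor's dims plus a transpose flag into a MatDescriptor whose batch_size_/height_/width_ drive the shape checks and reshapes in the matmul hunks above and below; height_ and width_ already account for the transpose flag. A small illustrative helper built on it (the helper itself is hypothetical, not from the patch):

#include "paddle/fluid/framework/ddim.h"
#include "paddle/pten/kernels/funcs/blas/blas.h"

// True when x (optionally transposed) can be matrix-multiplied by y:
// only the inner dimensions need to match, since the descriptors already
// reflect the transpose flags.
bool ShapesMultiplyable(const paddle::framework::DDim& x_dim,
                        const paddle::framework::DDim& y_dim, bool trans_x,
                        bool trans_y) {
  auto mat_dim_x = pten::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x);
  auto mat_dim_y = pten::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y);
  return mat_dim_x.width_ == mat_dim_y.height_;
}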
*/ static void ReshapeTensorIntoMatrixSequence( - framework::Tensor* x, const math::MatDescriptor& descriptor) { + framework::Tensor* x, const pten::funcs::MatDescriptor& descriptor) { int64_t h, w; h = descriptor.height_; w = descriptor.width_; @@ -97,8 +97,8 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, bool trans_y) { auto x_dim = RowMatrixFromVector(x->dims()); auto y_dim = ColumnMatrixFromVector(y->dims()); - auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x); - auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y); + auto mat_dim_x = pten::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = pten::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y); if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { out->Resize({mat_dim_x.height_, mat_dim_y.width_}); } else { diff --git a/paddle/fluid/operators/matmul_v2_op_xpu.cc b/paddle/fluid/operators/matmul_v2_op_xpu.cc index 908a23c4ecc63..e95d8e8d41afb 100644 --- a/paddle/fluid/operators/matmul_v2_op_xpu.cc +++ b/paddle/fluid/operators/matmul_v2_op_xpu.cc @@ -33,10 +33,10 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, auto& dev_ctx = ctx.template device_context(); - auto mat_dim_a = - math::CreateMatrixDescriptor(RowMatrixFromVector(x_dims), 0, trans_x); - auto mat_dim_b = - math::CreateMatrixDescriptor(ColumnMatrixFromVector(y_dims), 0, trans_y); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor( + RowMatrixFromVector(x_dims), 0, trans_x); + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor( + ColumnMatrixFromVector(y_dims), 0, trans_y); if (x_dims.size() == 3 && y_dims.size() <= 2) { // if transpose_X is true, the transpose cost much time diff --git a/paddle/fluid/operators/matrix_power_op.h b/paddle/fluid/operators/matrix_power_op.h index 58a8ef87628fe..c5285aae37efc 100644 --- a/paddle/fluid/operators/matrix_power_op.h +++ b/paddle/fluid/operators/matrix_power_op.h @@ -18,9 +18,9 @@ limitations under the License. 
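All of the matmul hunks above follow one caller-side pattern: the blas include moves from paddle/fluid/operators/math/blas.h to paddle/pten/kernels/funcs/blas/blas.h, and each math::GetBlas / math::CreateMatrixDescriptor call becomes its pten::funcs:: counterpart with unchanged arguments. A minimal sketch of the resulting call site; the function name and the a/b/out variables are illustrative only, and the <DeviceContext, T> template arguments follow the convention used throughout the patch (the hunks above elide them):

    #include "paddle/fluid/framework/operator.h"
    #include "paddle/pten/kernels/funcs/blas/blas.h"

    template <typename DeviceContext, typename T>
    void MatMulSketch(const paddle::framework::ExecutionContext &ctx,
                      const paddle::framework::Tensor &a, bool trans_a,
                      const paddle::framework::Tensor &b, bool trans_b,
                      paddle::framework::Tensor *out) {
      // Same descriptors and GEMM dispatch as before the move; only the
      // namespace changes.
      auto blas = pten::funcs::GetBlas<DeviceContext, T>(ctx);
      auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
      auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
      blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast<T>(1.0), out,
                  static_cast<T>(0.0));
    }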
*/ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/matrix_inverse.h" #include "paddle/fluid/platform/for_range.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -58,7 +58,7 @@ void MatrixPowerFunction(const Tensor* X, const int n, Tensor* Out, return; } - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); Tensor new_x = ctx.AllocateTmpTensor(X->dims(), dev_ctx); int new_n = n; @@ -77,7 +77,7 @@ void MatrixPowerFunction(const Tensor* X, const int n, Tensor* Out, return; } - auto no_trans_desc = math::CreateMatrixDescriptor(x_dims, 0, false); + auto no_trans_desc = pten::funcs::CreateMatrixDescriptor(x_dims, 0, false); if (new_n == 2) { // Out = newX * newX @@ -166,7 +166,7 @@ void MatrixPowerGradFunction(const Tensor* X, const Tensor* Out, const auto& x_dims = X->dims(); auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); if (n == 0) { // \nabla X = O @@ -179,8 +179,8 @@ void MatrixPowerGradFunction(const Tensor* X, const Tensor* Out, return; } - auto trans_desc = math::CreateMatrixDescriptor(x_dims, 0, true); - auto no_trans_desc = math::CreateMatrixDescriptor(x_dims, 0, false); + auto trans_desc = pten::funcs::CreateMatrixDescriptor(x_dims, 0, true); + auto no_trans_desc = pten::funcs::CreateMatrixDescriptor(x_dims, 0, false); if (n == -1) { // \nabla X = Out^{T} * \nabla Out * Out^{T} diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index 2c84218c48e0b..608d06789fc27 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -113,10 +113,8 @@ class MatMulMKLDNNHandler float scale) : paddle::platform::MKLDNNHandlerNoCachingT(engine, cpu_place) { - auto mat_dim_x = - paddle::operators::math::CreateMatrixDescriptor(x->dims(), 0, trans_x); - auto mat_dim_y = - paddle::operators::math::CreateMatrixDescriptor(y->dims(), 0, trans_y); + auto mat_dim_x = pten::funcs::CreateMatrixDescriptor(x->dims(), 0, trans_x); + auto mat_dim_y = pten::funcs::CreateMatrixDescriptor(y->dims(), 0, trans_y); memory::dim x_bs = mat_dim_x.batch_size_; memory::dim y_bs = mat_dim_y.batch_size_; @@ -237,8 +235,8 @@ class MatMulMKLDNNHandler out_strides; }; - std::pair - GetInputDimsAndStrides(const ExecutionContext& ctx, std::string input_name) { + std::pair GetInputDimsAndStrides( + const ExecutionContext& ctx, std::string input_name) { auto shape = ctx.Attr>("fused_reshape_" + input_name); auto axis = ctx.Attr>("fused_transpose_" + input_name); auto input_dims = ctx.Input(input_name)->dims(); @@ -279,10 +277,9 @@ class MatMulMKLDNNHandler auto& MatrixDimsFromVector = input_name == "X" ? 
RowMatrixDimsFromVector : ColumnMatrixDimsFromVector; - paddle::operators::math::MatDescriptor mat_dim = - paddle::operators::math::CreateMatrixDescriptor( - MatrixDimsFromVector(new_dims), 0, - ctx.Attr("transpose_" + input_name)); + pten::funcs::MatDescriptor mat_dim = pten::funcs::CreateMatrixDescriptor( + MatrixDimsFromVector(new_dims), 0, + ctx.Attr("transpose_" + input_name)); memory::dims strides; if (!shape.empty()) { @@ -324,10 +321,10 @@ class MatMulMKLDNNHandler } MatMulDims GetMatmulDims(const ExecutionContext& ctx) { - paddle::operators::math::MatDescriptor mat_dim_x; + pten::funcs::MatDescriptor mat_dim_x; memory::dims strides_x; std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X"); - paddle::operators::math::MatDescriptor mat_dim_y; + pten::funcs::MatDescriptor mat_dim_y; memory::dims strides_y; std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y"); @@ -431,7 +428,7 @@ class MatMulMKLDNNHandler * If transposed, `H,W` will be swapped. */ static void ReshapeTensorToMatrixSequence( - Tensor* x, const paddle::operators::math::MatDescriptor& descriptor) { + Tensor* x, const pten::funcs::MatDescriptor& descriptor) { int64_t h, w; h = descriptor.height_; w = descriptor.width_; @@ -463,10 +460,8 @@ static void ReshapeXYOutToMatrixSequence(Tensor* x, Tensor* y, Tensor* out, bool trans_x, bool trans_y) { auto x_dim = RowMatrixDimsFromVector(x->dims()); auto y_dim = ColumnMatrixDimsFromVector(y->dims()); - auto mat_dim_x = - paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x); - auto mat_dim_y = - paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y); + auto mat_dim_x = pten::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = pten::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y); if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { out->Resize({mat_dim_x.height_, mat_dim_y.width_}); } else { diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h index af4c154cd378f..1fdd84b7f33d2 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h @@ -16,8 +16,8 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index d3c7c1759641b..569b1325c5ea7 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -84,11 +84,10 @@ std::vector GetInputStrides(const ExecutionContext& ctx, auto& MatrixDimsFromVector = input_name == "X" ? 
RowMatrixDimsFromVector : ColumnMatrixDimsFromVector; - paddle::operators::math::MatDescriptor mat_dim = - paddle::operators::math::CreateMatrixDescriptor( - MatrixDimsFromVector(new_dims), 0, - ctx.Attr(std::string("trans_") + - static_cast(std::tolower(input_name[0])))); + pten::funcs::MatDescriptor mat_dim = pten::funcs::CreateMatrixDescriptor( + MatrixDimsFromVector(new_dims), 0, + ctx.Attr(std::string("trans_") + + static_cast(std::tolower(input_name[0])))); std::vector strides; if (!shape.empty()) { diff --git a/paddle/fluid/operators/mul_op.h b/paddle/fluid/operators/mul_op.h index 6ea154c25db5d..728773fead4da 100644 --- a/paddle/fluid/operators/mul_op.h +++ b/paddle/fluid/operators/mul_op.h @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -51,7 +51,7 @@ class MulKernel : public framework::OpKernel { z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); } - auto blas = math::GetBlas(context); + auto blas = pten::funcs::GetBlas(context); blas.MatMul(x_matrix, y_matrix, z); if (z_dim.size() != 2) { @@ -92,7 +92,7 @@ class MulGradKernel : public framework::OpKernel { } auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); if (dx) { dx->mutable_data(ctx.GetPlace()); Tensor dx_matrix = dx->dims().size() > 2 @@ -153,7 +153,7 @@ class MulDoubleGradKernel : public framework::OpKernel { } auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); // a flag to specify whether ddout value has been set, if flag // is false, MatMul beta should be 0 to set ddout, if flag is // true, MatMul beta should be 1 to add result to ddout. diff --git a/paddle/fluid/operators/multi_dot_op.cc b/paddle/fluid/operators/multi_dot_op.cc index 2d06170d34a91..784b4394ea70c 100644 --- a/paddle/fluid/operators/multi_dot_op.cc +++ b/paddle/fluid/operators/multi_dot_op.cc @@ -18,9 +18,9 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/utils.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -95,15 +95,15 @@ inline framework::Tensor MatMul(const framework::ExecutionContext& ctx, const framework::DDim& a_dim, const framework::DDim& b_dim) { auto place = ctx.GetPlace(); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); framework::Tensor matrix_c; framework::DDim c_dim = framework::make_ddim({a_dim[0], b_dim[1]}); matrix_c.Resize(c_dim); matrix_c.mutable_data(place); - auto mat_dim_a = math::CreateMatrixDescriptor(a_dim, 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(b_dim, 0, false); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(a_dim, 0, false); + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(b_dim, 0, false); const T alpha = static_cast(1.0); blas.MatMul(matrix_a, mat_dim_a, matrix_b, mat_dim_b, alpha, &matrix_c, T(0)); return matrix_c; @@ -269,7 +269,7 @@ class MultiDotKernel : public framework::OpKernel { auto place = ctx.GetPlace(); out->mutable_data(place); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto n = ins.size(); std::vector ins_dims(n); @@ -277,8 +277,10 @@ class MultiDotKernel : public framework::OpKernel { const T scale = static_cast(1.0); if (n == 2) { - auto mat_dim_a = math::CreateMatrixDescriptor(ins_dims[0], 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(ins_dims[1], 0, false); + auto mat_dim_a = + pten::funcs::CreateMatrixDescriptor(ins_dims[0], 0, false); + auto mat_dim_b = + pten::funcs::CreateMatrixDescriptor(ins_dims[1], 0, false); blas.MatMul(*ins[0], mat_dim_a, *ins[1], mat_dim_b, scale, out, T(0)); } else if (n == 3) { const auto Ma = ins_dims[0][0]; @@ -287,16 +289,20 @@ class MultiDotKernel : public framework::OpKernel { const auto Nc = ins_dims[2][1]; const uint64_t cost1 = Ma * Nb * (Ka + Nc); const uint64_t cost2 = Ka * Nc * (Nb + Ma); - auto mat_dim_a = math::CreateMatrixDescriptor(ins_dims[0], 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(ins_dims[1], 0, false); - auto mat_dim_c = math::CreateMatrixDescriptor(ins_dims[2], 0, false); + auto mat_dim_a = + pten::funcs::CreateMatrixDescriptor(ins_dims[0], 0, false); + auto mat_dim_b = + pten::funcs::CreateMatrixDescriptor(ins_dims[1], 0, false); + auto mat_dim_c = + pten::funcs::CreateMatrixDescriptor(ins_dims[2], 0, false); if (cost1 < cost2) { framework::Tensor tmp_out; tmp_out.mutable_data(place, Ma * Nb * sizeof(T)); framework::DDim tmp_dim = framework::make_ddim({Ma, Nb}); blas.MatMul(*ins[0], mat_dim_a, *ins[1], mat_dim_b, scale, &tmp_out, T(0)); - auto mat_dim_tmp = math::CreateMatrixDescriptor(tmp_dim, 0, false); + auto mat_dim_tmp = + pten::funcs::CreateMatrixDescriptor(tmp_dim, 0, false); blas.MatMul(tmp_out, mat_dim_tmp, *ins[2], mat_dim_c, scale, out, T(0)); } else { framework::Tensor tmp_out; @@ -304,7 +310,8 @@ class MultiDotKernel : public framework::OpKernel { framework::DDim tmp_dim = framework::make_ddim({Ka, Nc}); blas.MatMul(*ins[1], mat_dim_b, *ins[2], mat_dim_c, scale, &tmp_out, T(0)); - auto mat_dim_tmp = math::CreateMatrixDescriptor(tmp_dim, 0, false); + auto mat_dim_tmp = + pten::funcs::CreateMatrixDescriptor(tmp_dim, 0, false); blas.MatMul(*ins[0], mat_dim_a, tmp_out, mat_dim_tmp, scale, out, T(0)); } } else { @@ -348,11 +355,11 
@@ class MultiDotGradKernel : public framework::OpKernel { const framework::Tensor& B, const framework::DDim& dout_dim, const framework::DDim& a_dim, const framework::DDim& b_dim, framework::Tensor* dA, framework::Tensor* dB) const { - auto mat_dim_dout = math::CreateMatrixDescriptor(dout_dim, 0, false); - auto mat_dim_a = math::CreateMatrixDescriptor(a_dim, 0, true); - auto mat_dim_b = math::CreateMatrixDescriptor(b_dim, 0, true); + auto mat_dim_dout = pten::funcs::CreateMatrixDescriptor(dout_dim, 0, false); + auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(a_dim, 0, true); + auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(b_dim, 0, true); T alpha = static_cast(1.0); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); blas.MatMul(A, mat_dim_a, dout, mat_dim_dout, alpha, dB, T(0)); blas.MatMul(dout, mat_dim_dout, B, mat_dim_b, alpha, dA, T(0)); } @@ -433,7 +440,7 @@ class MultiDotGradKernel : public framework::OpKernel { auto dout = *ctx.Input(framework::GradVarName("Out")); auto dx = ctx.MultiOutput(framework::GradVarName("X")); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto place = ctx.GetPlace(); const auto n = ins.size(); @@ -458,7 +465,7 @@ class MultiDotGradKernel : public framework::OpKernel { } T alpha = static_cast(1); - auto mat_dim_dout = math::CreateMatrixDescriptor(dout_dim, 0, false); + auto mat_dim_dout = pten::funcs::CreateMatrixDescriptor(dout_dim, 0, false); if (n == 2) { CalcGrad(ctx, dout, *ins[0], *ins[1], dout_dim, ins_dims[0], ins_dims[1], dx[0], dx[1]); @@ -469,9 +476,12 @@ class MultiDotGradKernel : public framework::OpKernel { const auto Nc = ins_dims[2][1]; const uint64_t cost1 = Ma * Nb * (Ka + Nc); const uint64_t cost2 = Ka * Nc * (Nb + Ma); - auto mat_dim_a = math::CreateMatrixDescriptor(ins_dims[0], 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(ins_dims[1], 0, false); - auto mat_dim_c = math::CreateMatrixDescriptor(ins_dims[2], 0, false); + auto mat_dim_a = + pten::funcs::CreateMatrixDescriptor(ins_dims[0], 0, false); + auto mat_dim_b = + pten::funcs::CreateMatrixDescriptor(ins_dims[1], 0, false); + auto mat_dim_c = + pten::funcs::CreateMatrixDescriptor(ins_dims[2], 0, false); if (cost1 < cost2) { framework::Tensor tmp_out, tmp_dout; tmp_out.Resize({Ma, Nb}); diff --git a/paddle/fluid/operators/mv_op.cu b/paddle/fluid/operators/mv_op.cu index cec17f1324313..cdb99b4a77ee1 100644 --- a/paddle/fluid/operators/mv_op.cu +++ b/paddle/fluid/operators/mv_op.cu @@ -59,7 +59,7 @@ class MVGradKernel auto &dev_ctx = context.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); auto stream = context.cuda_device_context().stream(); auto config = GetGpuLaunchConfig1D(dev_ctx, m * n); diff --git a/paddle/fluid/operators/mv_op.h b/paddle/fluid/operators/mv_op.h index e29449962989f..2af79634a70ee 100644 --- a/paddle/fluid/operators/mv_op.h +++ b/paddle/fluid/operators/mv_op.h @@ -18,7 +18,7 @@ limitations under the License. 
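In both the forward and the gradient multi_dot kernels above, cost1 and cost2 are the multiply-add counts of the two ways to associate a three-matrix chain: with A of shape Ma x Ka, B of shape Ka x Nb and C of shape Nb x Nc, computing (A*B)*C costs Ma*Ka*Nb + Ma*Nb*Nc = Ma*Nb*(Ka + Nc) = cost1, while A*(B*C) costs Ka*Nb*Nc + Ma*Ka*Nc = Ka*Nc*(Nb + Ma) = cost2. A worked example with illustrative sizes (not taken from the patch): Ma = 10, Ka = 100, Nb = 5, Nc = 50 gives cost1 = 10 * 5 * (100 + 50) = 7,500 and cost2 = 100 * 50 * (5 + 10) = 75,000, so the cost1 < cost2 branch is taken and A*B is formed first.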
*/ #include #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -45,7 +45,7 @@ class MVKernel : public framework::OpKernel { T *out_data = out->mutable_data(context.GetPlace()); auto &dev_ctx = context.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); blas.GEMV(false, dim_x[0], dim_x[1], static_cast(1), x_data, vec_data, static_cast(0), out_data); @@ -93,7 +93,7 @@ class MVGradKernel : public framework::OpKernel { T *dvec_data = dvec->mutable_data(context.GetPlace()); auto &dev_ctx = context.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); blas.GEMV(true, dim_x[0], dim_x[1], static_cast(1), x_data, dout_data, static_cast(0), dvec_data); diff --git a/paddle/fluid/operators/rank_attention_op.cu b/paddle/fluid/operators/rank_attention_op.cu index 23b4475e1f7c1..e8273b40aa834 100644 --- a/paddle/fluid/operators/rank_attention_op.cu +++ b/paddle/fluid/operators/rank_attention_op.cu @@ -14,11 +14,11 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/rank_attention.cu.h" #include "paddle/fluid/operators/rank_attention_op.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -114,7 +114,7 @@ class RankAttentionCUDAKernel : public framework::OpKernel { int64_t strideA = block_matrix_row; int64_t strideB = block_matrix_row * para_col; - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); blas.BatchedGEMM(transA, transB, 1, para_col, block_matrix_row, alpha, input_help_data, param_help_data, beta, out_data, ins_num, strideA, strideB); @@ -170,7 +170,7 @@ class RankAttentionGradOpCUDAKernel : public framework::OpKernel { const T *ins_rank_data = ins_rank->data(); T *param_grad_data = param_grad.data(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); T alpha = 1; T beta = 0; diff --git a/paddle/fluid/operators/repeat_interleave_op.h b/paddle/fluid/operators/repeat_interleave_op.h index 4e47226358aa9..5a79c042d8034 100644 --- a/paddle/fluid/operators/repeat_interleave_op.h +++ b/paddle/fluid/operators/repeat_interleave_op.h @@ -15,7 +15,7 @@ #pragma once #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" #include "paddle/fluid/operators/index_select_op.h" diff --git a/paddle/fluid/operators/rnn_op.h b/paddle/fluid/operators/rnn_op.h index b2c1b8b9895d3..28733d951699b 100644 --- a/paddle/fluid/operators/rnn_op.h +++ b/paddle/fluid/operators/rnn_op.h @@ -19,7 +19,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/dropout_op.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/fc.h" @@ -27,6 +26,7 @@ limitations under the License. 
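In the mv_op hunks above only the origin of the Blas object changes; the GEMV calls keep the (trans, M, N, alpha, A, x, beta, y) layout declared by the migrated blas.h. A short sketch of the forward call, with dim_x, x_data, vec_data and out_data carrying the same meaning as in the hunk, T as the element type, and the <DeviceContext, T> template arguments elided in the flattened hunks restored:

    auto blas = pten::funcs::GetBlas<DeviceContext, T>(dev_ctx);
    // out = x * vec for a row-major dim_x[0] x dim_x[1] matrix x.
    blas.GEMV(false, dim_x[0], dim_x[1], static_cast<T>(1), x_data, vec_data,
              static_cast<T>(0), out_data);
    // The backward kernel reuses the same routine with trans = true to form
    // dvec = x^T * dout, as the second mv_op hunk shows.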
*/ #include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/unique_op.h" #include "paddle/fluid/operators/utils.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -108,9 +108,12 @@ struct SimpleRNNCell : Cell { const Tensor* init_c, Tensor* last_h, Tensor* last_c, Tensor* last_c_act, Tensor* output, const Tensor* bias_hh, Tensor* weight_hh_gru) const override { - auto blas = math::GetBlas(*device_ctx); - auto mat_dim_a = math::CreateMatrixDescriptor(init_h->dims(), 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(weight_hh->dims(), 0, true); + auto blas = + pten::funcs::GetBlas(*device_ctx); + auto mat_dim_a = + pten::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false); + auto mat_dim_b = + pten::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, true); mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0; // convert the batch matmul to matmul, this operator could be speed faster @@ -134,10 +137,12 @@ struct GRUCell : Cell { const Tensor* init_c, Tensor* last_h, Tensor* last_c, Tensor* last_c_act, Tensor* output, const Tensor* bias_hh, Tensor* weight_hh_gru) const override { - auto blas = math::GetBlas(*device_ctx); - auto mat_dim_a = math::CreateMatrixDescriptor(init_h->dims(), 0, false); + auto blas = + pten::funcs::GetBlas(*device_ctx); + auto mat_dim_a = + pten::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false); auto mat_dim_b = - math::CreateMatrixDescriptor(weight_hh_gru->dims(), 0, true); + pten::funcs::CreateMatrixDescriptor(weight_hh_gru->dims(), 0, true); mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0; // convert the batch matmul to matmul, this operator could be speed faster @@ -171,9 +176,12 @@ struct LSTMCell : Cell { const Tensor* init_c, Tensor* last_h, Tensor* last_c, Tensor* last_c_act, Tensor* output, const Tensor* bias_hh, Tensor* weight_hh_gru) const override { - auto blas = math::GetBlas(*device_ctx); - auto mat_dim_a = math::CreateMatrixDescriptor(init_h->dims(), 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(weight_hh->dims(), 0, true); + auto blas = + pten::funcs::GetBlas(*device_ctx); + auto mat_dim_a = + pten::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false); + auto mat_dim_b = + pten::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, true); mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0; // convert the batch matmul to matmul, this operator could be speed faster @@ -281,9 +289,11 @@ struct Layer { if (is_test) { cache_input->mutable_data(context.GetPlace()); } - auto blas = math::GetBlas(dev_ctx); - auto mat_dim_a = math::CreateMatrixDescriptor(input->dims(), 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(weight.dims(), 0, true); + auto blas = pten::funcs::GetBlas(dev_ctx); + auto mat_dim_a = + pten::funcs::CreateMatrixDescriptor(input->dims(), 0, false); + auto mat_dim_b = + pten::funcs::CreateMatrixDescriptor(weight.dims(), 0, true); // convert the batch matmul to matmul, this operator could be speed faster mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0; @@ -1268,12 +1278,13 @@ struct GradLayer { } auto& device_ctx = context.template device_context(); - auto blas = math::GetBlas(device_ctx); + auto blas = pten::funcs::GetBlas(device_ctx); // calc the gradient for the w_hi auto mat_dim_out_grad = - math::CreateMatrixDescriptor(grad_gate.dims(), 0, true); - auto mat_dim_input = math::CreateMatrixDescriptor(input.dims(), 
0, false); + pten::funcs::CreateMatrixDescriptor(grad_gate.dims(), 0, true); + auto mat_dim_input = + pten::funcs::CreateMatrixDescriptor(input.dims(), 0, false); mat_dim_out_grad.width_ *= mat_dim_out_grad.batch_size_; mat_dim_out_grad.batch_size_ = 0; mat_dim_input.height_ *= mat_dim_input.batch_size_; @@ -1284,11 +1295,11 @@ struct GradLayer { // calc the gradient for the X auto mat_dim_out_grad_new = - math::CreateMatrixDescriptor(grad_gate.dims(), 0, false); + pten::funcs::CreateMatrixDescriptor(grad_gate.dims(), 0, false); mat_dim_out_grad_new.height_ *= mat_dim_out_grad_new.batch_size_; mat_dim_out_grad_new.batch_size_ = 0; auto mat_dim_parameter = - math::CreateMatrixDescriptor(parameters[0].dims(), 0, false); + pten::funcs::CreateMatrixDescriptor(parameters[0].dims(), 0, false); blas.MatMul(grad_gate, mat_dim_out_grad_new, parameters[begin_idx + 0], mat_dim_parameter, static_cast(1.0), input_grad, T(1)); @@ -1583,13 +1594,14 @@ struct GradCell { bool has_sequence_length) const { auto& device_ctx = context.template device_context(); - auto blas = math::GetBlas(device_ctx); + auto blas = pten::funcs::GetBlas(device_ctx); Tensor* grad_gate_tmp = grad_gate; auto mat_dim_a = - math::CreateMatrixDescriptor(grad_gate_tmp->dims(), 0, false); + pten::funcs::CreateMatrixDescriptor(grad_gate_tmp->dims(), 0, false); mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0; - auto mat_dim_b = math::CreateMatrixDescriptor(weight_hh->dims(), 0, false); + auto mat_dim_b = + pten::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, false); blas.MatMul(*grad_gate_tmp, mat_dim_a, *weight_hh, mat_dim_b, static_cast(1.0), grad_pre_hidden, 0); postprocess_pre_hidden_grad(context, grad_pre_hidden, grad_pre_hidden_bak, @@ -1602,11 +1614,13 @@ struct GradCell { Tensor* grad_weight_hh) const { auto& device_ctx = context.template device_context(); - auto blas = math::GetBlas(device_ctx); - auto mat_dim_c = math::CreateMatrixDescriptor(grad_gate->dims(), 0, true); + auto blas = pten::funcs::GetBlas(device_ctx); + auto mat_dim_c = + pten::funcs::CreateMatrixDescriptor(grad_gate->dims(), 0, true); mat_dim_c.height_ *= mat_dim_c.batch_size_; mat_dim_c.batch_size_ = 0; - auto mat_dim_d = math::CreateMatrixDescriptor(pre_hidden->dims(), 0, false); + auto mat_dim_d = + pten::funcs::CreateMatrixDescriptor(pre_hidden->dims(), 0, false); mat_dim_d.height_ *= mat_dim_d.batch_size_; mat_dim_d.batch_size_ = 0; blas.MatMul(*grad_gate, mat_dim_c, *pre_hidden, mat_dim_d, diff --git a/paddle/fluid/operators/scatter.h b/paddle/fluid/operators/scatter.h index c4f9c628dbb04..7dd019447ddf7 100644 --- a/paddle/fluid/operators/scatter.h +++ b/paddle/fluid/operators/scatter.h @@ -19,8 +19,8 @@ limitations under the License. 
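The rnn_op hunks keep the existing trick of folding the batch dimension into the matrix descriptor so a batched matmul is issued as a single GEMM: height_ is multiplied by batch_size_ and batch_size_ is then zeroed. A sketch of that manipulation under the new namespace; init_h and weight_hh are the cell tensors from the code above, while dst and the beta value stand in for whichever output tensor and accumulation the particular cell uses:

    auto blas = pten::funcs::GetBlas<DeviceContext, T>(dev_ctx);
    auto mat_dim_a =
        pten::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false);
    auto mat_dim_b =
        pten::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, true);
    // Collapse the batch into the row dimension: one plain matmul instead of
    // a batched one, which the surrounding comments note runs faster here.
    mat_dim_a.height_ *= mat_dim_a.batch_size_;
    mat_dim_a.batch_size_ = 0;
    blas.MatMul(*init_h, mat_dim_a, *weight_hh, mat_dim_b,
                static_cast<T>(1.0), dst, static_cast<T>(1.0));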
*/ #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/place.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "unordered_set" namespace paddle { @@ -37,7 +37,7 @@ typename std::enable_if::value>::type elementwise_inner_add(const framework::ExecutionContext& ctx, const T* src_pointer, T* dst_pointer, size_t src_index, IndexT dst_index, size_t slice_size) { - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); blas.VADD(slice_size, src_pointer + src_index * slice_size, dst_pointer + dst_index * slice_size, dst_pointer + dst_index * slice_size); diff --git a/paddle/fluid/operators/search_compute.h b/paddle/fluid/operators/search_compute.h index 3e8d270ca4f06..8c3a044b98b07 100644 --- a/paddle/fluid/operators/search_compute.h +++ b/paddle/fluid/operators/search_compute.h @@ -22,7 +22,7 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -33,7 +33,7 @@ using LoDTensor = framework::LoDTensor; using LoD = framework::LoD; template -void call_gemm(const math::BlasT& blas, +void call_gemm(const pten::funcs::BlasT& blas, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, const T alpha, const T* A, const T* B, const T beta, T* C) { @@ -49,12 +49,12 @@ void call_gemm(const framework::ExecutionContext& ctx, const T* B, const T beta, T* C) { int lda = (TransA == CblasNoTrans) ? K : M; int ldb = (TransB == CblasNoTrans) ? N : K; - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N); } template -void call_gemm_with_lda(const math::BlasT& blas, +void call_gemm_with_lda(const pten::funcs::BlasT& blas, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, const T alpha, const T* A, const T* B, diff --git a/paddle/fluid/operators/sequence_ops/sequence_conv_op.h b/paddle/fluid/operators/sequence_ops/sequence_conv_op.h index b43254f91fde7..6065b4ef15b84 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_conv_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_conv_op.h @@ -66,7 +66,7 @@ class SequenceConvKernel : public framework::OpKernel { // Because if padding_trainable is false, padding data should be zeros. pten::funcs::SetConstant set_zero; auto& dev_ctx = context.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); set_zero(dev_ctx, &col, static_cast(0)); math::ContextProjectFunctor seq_project_functor; @@ -109,7 +109,7 @@ class SequenceConvGradKernel : public framework::OpKernel { pten::funcs::SetConstant set_zero; auto& dev_ctx = context.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); // use col_shape in the im2col calculation framework::DDim col_shape = {in->dims()[0], sequence_width * context_length}; diff --git a/paddle/fluid/operators/solve_op.h b/paddle/fluid/operators/solve_op.h index f75dd6243a2ee..c3a688c2a4c9c 100644 --- a/paddle/fluid/operators/solve_op.h +++ b/paddle/fluid/operators/solve_op.h @@ -20,10 +20,10 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/matrix_solve.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" #include "paddle/fluid/operators/squeeze_op.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" #if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" @@ -523,10 +523,12 @@ class SolveGradKernel : public framework::OpKernel { if (dx) { dx->mutable_data(ctx.GetPlace()); // to get dx - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); if (input->dims().size() == 2 && y->dims().size() == 2) { - auto mat_dim_a1 = math::CreateMatrixDescriptor(tmp_dy.dims(), 0, false); - auto mat_dim_b1 = math::CreateMatrixDescriptor(out->dims(), 0, true); + auto mat_dim_a1 = + pten::funcs::CreateMatrixDescriptor(tmp_dy.dims(), 0, false); + auto mat_dim_b1 = + pten::funcs::CreateMatrixDescriptor(out->dims(), 0, true); blas.MatMul(tmp_dy, mat_dim_a1, *out, mat_dim_b1, T(-1), &tmp_dx, T(0)); } else if (is_vector_rhs(*input, *y)) { Tensor tmp_dy_; @@ -538,14 +540,16 @@ class SolveGradKernel : public framework::OpKernel { to_unsqueeze(ctx, *out, &tmp_out_); auto mat_dim_a1 = - math::CreateMatrixDescriptor(tmp_dy_.dims(), 0, false); + pten::funcs::CreateMatrixDescriptor(tmp_dy_.dims(), 0, false); auto mat_dim_b1 = - math::CreateMatrixDescriptor(tmp_out_.dims(), 0, true); + pten::funcs::CreateMatrixDescriptor(tmp_out_.dims(), 0, true); blas.MatMul(tmp_dy_, mat_dim_a1, tmp_out_, mat_dim_b1, T(-1), &tmp_dx, T(0)); } else { - auto mat_dim_a1 = math::CreateMatrixDescriptor(tmp_dy.dims(), 0, false); - auto mat_dim_b1 = math::CreateMatrixDescriptor(out->dims(), 0, true); + auto mat_dim_a1 = + pten::funcs::CreateMatrixDescriptor(tmp_dy.dims(), 0, false); + auto mat_dim_b1 = + pten::funcs::CreateMatrixDescriptor(out->dims(), 0, true); blas.MatMul(tmp_dy, mat_dim_a1, *out, mat_dim_b1, T(-1), &tmp_dx, T(0)); } } diff --git a/paddle/fluid/operators/spectral_norm_op.h b/paddle/fluid/operators/spectral_norm_op.h index d0edcc169255e..0c0f23317171f 100644 --- a/paddle/fluid/operators/spectral_norm_op.h +++ b/paddle/fluid/operators/spectral_norm_op.h @@ -13,7 +13,7 @@ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -65,7 +65,7 @@ static inline void CalcMatrixSigmaAndNormWeight( Tensor* sigma, Tensor* u, Tensor* v, Tensor* weight, const int power_iters, const float eps, const framework::ExecutionContext& ctx) { auto& place = *ctx.template device_context().eigen_device(); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto sigma_t = EigenTensor::From(*sigma); auto weight_t = EigenTensor::From(*weight); auto u_t = EigenTensor::From(*u); @@ -179,7 +179,7 @@ class SpectralNormGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto& place = *ctx.template device_context().eigen_device(); auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); auto weight = ctx.Input("Weight"); auto u = ctx.Input("U"); auto v = ctx.Input("V"); diff --git 
a/paddle/fluid/operators/squeeze_op.h b/paddle/fluid/operators/squeeze_op.h index d86037fa03258..c2f6c5fb2cd19 100644 --- a/paddle/fluid/operators/squeeze_op.h +++ b/paddle/fluid/operators/squeeze_op.h @@ -17,9 +17,9 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/pooling.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index 4384e7152fa4e..365fc42b083ea 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -24,9 +24,9 @@ #include "paddle/fluid/operators/diag_op.h" #include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/for_range.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/complex_functors.h" #include "paddle/pten/kernels/funcs/math_function.h" @@ -318,8 +318,8 @@ struct DeviceIndependenceTensorOperations { ret.Resize(framework::make_ddim(x_vec)); ret.mutable_data(context.GetPlace()); auto blas = GetBlas(); - auto mat_a_discrib = math::CreateMatrixDescriptor(a_dim, 0, trans_a); - auto mat_b_discrib = math::CreateMatrixDescriptor(b_dim, 0, trans_b); + auto mat_a_discrib = pten::funcs::CreateMatrixDescriptor(a_dim, 0, trans_a); + auto mat_b_discrib = pten::funcs::CreateMatrixDescriptor(b_dim, 0, trans_b); blas.MatMul(mat_a, mat_a_discrib, mat_b, mat_b_discrib, T(1.0), &ret, T(0.0)); return ret; @@ -688,8 +688,8 @@ struct DeviceIndependenceTensorOperations { private: const framework::ExecutionContext& context; - BlasT GetBlas() { - return math::GetBlas(context); + pten::funcs::BlasT GetBlas() { + return pten::funcs::GetBlas(context); } platform::ForRange GetForRange(int numel) { auto& dev_ctx = context.template device_context(); diff --git a/paddle/fluid/operators/tree_conv_op.h b/paddle/fluid/operators/tree_conv_op.h index c2a6cfdd0d37c..5418e4d3046b5 100644 --- a/paddle/fluid/operators/tree_conv_op.h +++ b/paddle/fluid/operators/tree_conv_op.h @@ -16,8 +16,8 @@ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/tree2col.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { @@ -37,7 +37,7 @@ class TreeConvKernel : public framework::OpKernel { int max_depth = ctx.Attr("max_depth"); auto &dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); Tensor W; W.ShareDataWith(*Filter); @@ -88,7 +88,7 @@ class TreeConvGradKernel : public framework::OpKernel { math::Col2TreeFunctor col2tree; pten::funcs::SetConstant constant; auto &dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); Tensor W; W.ShareDataWith(*Filter); diff --git a/paddle/fluid/operators/triangular_solve_op.h b/paddle/fluid/operators/triangular_solve_op.h index e892d258f3b12..dc0906a83792e 100644 --- a/paddle/fluid/operators/triangular_solve_op.h +++ b/paddle/fluid/operators/triangular_solve_op.h @@ -18,10 +18,10 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h" #include "paddle/fluid/operators/solve_op.h" #include "paddle/fluid/operators/tril_triu_op.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/complex_functors.h" namespace paddle { @@ -184,16 +184,19 @@ class TriangularSolveGradKernel : public framework::OpKernel { out_conj.mutable_data(out->dims(), dev_ctx.GetPlace())); out_for_range(out_functor); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); if (transpose) { auto mat_dim_a = - math::CreateMatrixDescriptor(out_conj.dims(), 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(dy_bst.dims(), 0, true); + pten::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, false); + auto mat_dim_b = + pten::funcs::CreateMatrixDescriptor(dy_bst.dims(), 0, true); blas.MatMul(out_conj, mat_dim_a, dy_bst, mat_dim_b, static_cast(-1), &dx_bst, static_cast(0)); } else { - auto mat_dim_a = math::CreateMatrixDescriptor(dy_bst.dims(), 0, false); - auto mat_dim_b = math::CreateMatrixDescriptor(out_conj.dims(), 0, true); + auto mat_dim_a = + pten::funcs::CreateMatrixDescriptor(dy_bst.dims(), 0, false); + auto mat_dim_b = + pten::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, true); blas.MatMul(dy_bst, mat_dim_a, out_conj, mat_dim_b, static_cast(-1), &dx_bst, static_cast(0)); } diff --git a/paddle/fluid/operators/unsqueeze_op.h b/paddle/fluid/operators/unsqueeze_op.h index 649cc9de50e0d..ea0b698479b24 100644 --- a/paddle/fluid/operators/unsqueeze_op.h +++ b/paddle/fluid/operators/unsqueeze_op.h @@ -16,10 +16,10 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/pooling.h" #include "paddle/fluid/operators/utils.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/var_conv_2d_op.cc b/paddle/fluid/operators/var_conv_2d_op.cc index f67b969d4590a..d4e749bd087f9 100644 --- a/paddle/fluid/operators/var_conv_2d_op.cc +++ b/paddle/fluid/operators/var_conv_2d_op.cc @@ -15,8 +15,8 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/var_conv_2d_op.h" #include #include -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/dynload/mklml.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { @@ -300,7 +300,7 @@ class CPUVarConv2dOPKernel : public framework::OpKernel { auto* w_data = w->data(); auto* col_data = col->data(); - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); for (int b = 0; b < batch; ++b) { int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel; if (top_im_size == 0) { @@ -448,7 +448,7 @@ class CPUVarConv2dOPGradKernel : public framework::OpKernel { int batch = x->lod()[0].size() - 1; const auto& top_offset = out->lod()[0]; const auto& col_offset = col->lod()[0]; - auto blas = math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); for (int b = 0; b < batch; ++b) { int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel; if (top_im_size == 0) { diff --git a/paddle/pten/kernels/cpu/elementwise.h b/paddle/pten/kernels/cpu/elementwise.h index 0cd50be511add..1accc9994f540 100644 --- a/paddle/pten/kernels/cpu/elementwise.h +++ b/paddle/pten/kernels/cpu/elementwise.h @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/pten/kernels/funcs/common_shape.h" #include "paddle/pten/kernels/funcs/elementwise_base.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/eigen/common.h" namespace pten { @@ -44,7 +44,7 @@ struct SameDimsAddFunctor< const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { - auto blas = paddle::operators::math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); blas.VADD( x.numel(), x.data(), y.data(), dev_ctx.template Alloc(z)); } @@ -86,7 +86,7 @@ struct SameDimsSubtractFunctor< const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { - auto blas = paddle::operators::math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); blas.VSUB( x.numel(), x.data(), y.data(), dev_ctx.template Alloc(z)); } @@ -142,7 +142,7 @@ struct SameDimsDivideFunctor< const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { - auto blas = paddle::operators::math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); blas.VDIV( x.numel(), x.data(), y.data(), dev_ctx.template Alloc(z)); } @@ -166,7 +166,7 @@ struct SameDimsMultiplyFunctor< const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { - auto blas = paddle::operators::math::GetBlas(dev_ctx); + auto blas = pten::funcs::GetBlas(dev_ctx); blas.VMUL( x.numel(), x.data(), y.data(), dev_ctx.template Alloc(z)); } @@ -763,7 +763,7 @@ elementwise_add_grad(const CPUContext& ctx, DenseTensor* dx, DenseTensor* dy, int axis = -1) { - auto blas = paddle::operators::math::GetBlas(ctx); + auto blas = pten::funcs::GetBlas(ctx); if (dx) { blas.VCOPY( dout.numel(), dout.data(), dx->mutable_data(ctx.GetPlace())); diff --git a/paddle/pten/kernels/funcs/CMakeLists.txt b/paddle/pten/kernels/funcs/CMakeLists.txt index 844464a52dcbf..ba0c848df434e 100644 --- a/paddle/pten/kernels/funcs/CMakeLists.txt +++ b/paddle/pten/kernels/funcs/CMakeLists.txt @@ -1,42 +1,5 @@ add_subdirectory(eigen) - -function(math_library TARGET) - # math_library is a function to create math library. - # The interface is the same as cc_library. - # But it handle split GPU/CPU code and link some common library. 
- set(cc_srcs) - set(cu_srcs) - set(hip_srcs) - set(math_common_deps device_context framework_proto enforce) - if (WITH_GPU) - if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) - list(APPEND math_common_deps cub) - else() - list(APPEND math_common_deps) - endif() - endif() - set(multiValueArgs DEPS) - cmake_parse_arguments(math_library "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN}) - - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) - list(APPEND cc_srcs ${TARGET}.cc) - endif() - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) - list(APPEND cu_srcs ${TARGET}.cu) - endif() - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc) - list(APPEND cu_srcs ${TARGET}.cu.cc) - endif() - - list(LENGTH cc_srcs cc_srcs_len) - if (WITH_GPU) - nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) - elseif (WITH_ROCM) - hip_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) - elseif(${cc_srcs_len} GREATER 0) - cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) - endif() -endfunction() +add_subdirectory(blas) +add_subdirectory(lapack) math_library(math_function DEPS blas dense_tensor tensor) diff --git a/paddle/pten/kernels/funcs/blas/CMakeLists.txt b/paddle/pten/kernels/funcs/blas/CMakeLists.txt new file mode 100644 index 0000000000000..cb054cc76e1d7 --- /dev/null +++ b/paddle/pten/kernels/funcs/blas/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) diff --git a/paddle/fluid/operators/math/blas.cc b/paddle/pten/kernels/funcs/blas/blas.cc similarity index 75% rename from paddle/fluid/operators/math/blas.cc rename to paddle/pten/kernels/funcs/blas/blas.cc index 77122a5882d6a..8cf3d16044cac 100644 --- a/paddle/fluid/operators/math/blas.cc +++ b/paddle/pten/kernels/funcs/blas/blas.cc @@ -12,18 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" -namespace paddle { -namespace operators { -namespace math { +namespace pten { +namespace funcs { MatDescriptor CreateMatrixDescriptor(const framework::DDim &tensor_dim, - int num_flatten_cols, bool trans) { + int num_flatten_cols, + bool trans) { PADDLE_ENFORCE_GT( - tensor_dim.size(), 1, - platform::errors::InvalidArgument("The tensor dim size should be greater " - "than 1, but reveived dim size is %d", - tensor_dim.size())); + tensor_dim.size(), + 1, + pten::errors::InvalidArgument("The tensor dim size should be greater " + "than 1, but reveived dim size is %d", + tensor_dim.size())); MatDescriptor retv; if (num_flatten_cols > 1) { auto flatten_dim = framework::flatten_to_2d(tensor_dim, num_flatten_cols); @@ -50,6 +51,5 @@ MatDescriptor CreateMatrixDescriptor(const framework::DDim &tensor_dim, retv.trans_ = trans; return retv; } -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace pten diff --git a/paddle/fluid/operators/math/blas.h b/paddle/pten/kernels/funcs/blas/blas.h similarity index 59% rename from paddle/fluid/operators/math/blas.h rename to paddle/pten/kernels/funcs/blas/blas.h index f17cc3094f7fc..ba69775b2d0a5 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/pten/kernels/funcs/blas/blas.h @@ -15,13 +15,7 @@ #pragma once #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/tensor.h" - -namespace paddle { -namespace framework { -class ExecutionContext; -} // namespace framework -} // namespace paddle +#include "paddle/pten/core/dense_tensor.h" #ifdef PADDLE_WITH_MKLML #include "paddle/fluid/platform/dynload/mklml.h" @@ -35,9 +29,8 @@ class ExecutionContext; #include #endif -namespace paddle { -namespace operators { -namespace math { +namespace pten { +namespace funcs { /** * Matrix Descriptor of a memory buffer. @@ -81,7 +74,8 @@ struct MatDescriptor { * @param trans: True if the matrix is transposed. 
*/ extern MatDescriptor CreateMatrixDescriptor(const framework::DDim& tensor_dim, - int num_flatten_cols, bool trans); + int num_flatten_cols, + bool trans); template class Blas { @@ -89,73 +83,149 @@ class Blas { explicit Blas(const DeviceContext& context) : context_(context) {} template - void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T* A, const T* B, T beta, T* C) const; + void GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T* A, + const T* B, + T beta, + T* C) const; template - void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A, - int lda, const T* B, int ldb, T beta, T* C, int ldc) const; + void GEMM(bool transA, + bool transB, + int M, + int N, + int K, + T alpha, + const T* A, + int lda, + const T* B, + int ldb, + T beta, + T* C, + int ldc) const; template - void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, + void GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T* A, + int lda, + const T* B, + int ldb, + T beta, + T* C, int ldc) const; #ifdef PADDLE_WITH_MKLML // @{ Group MKLML: class Blas template - T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N, + T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, + const int M, + const int N, const int K) const; template - void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M, - int N, int K, const T alpha, const T* src, const int ld, + void GEMM_PACK(const CBLAS_IDENTIFIER id, + const CBLAS_TRANSPOSE trans, + int M, + int N, + int K, + const T alpha, + const T* src, + const int ld, T* dst) const; template - void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A, - const int lda, const T* B, const int ldb, T beta, T* C, + void GEMM_COMPUTE(int transA, + int transB, + int M, + int N, + int K, + const T* A, + const int lda, + const T* B, + const int ldb, + T beta, + T* C, const int ldc) const; template void GEMM_FREE(T* data) const; template - void CSRMM(const char* transa, const int* m, const int* n, const int* k, - const T* alpha, const char* matdescra, const T* val, - const int* indx, const int* pntrb, const int* pntre, const T* b, - const int* ldb, const T* beta, T* c, const int* ldc) const; + void CSRMM(const char* transa, + const int* m, + const int* n, + const int* k, + const T* alpha, + const char* matdescra, + const T* val, + const int* indx, + const int* pntrb, + const int* pntre, + const T* b, + const int* ldb, + const T* beta, + T* c, + const int* ldc) const; #if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) template - void MatMulWithHead(const framework::Tensor& mat_a, + void MatMulWithHead(const pten::DenseTensor& mat_a, const MatDescriptor& dim_a, - const framework::Tensor& mat_b, - const MatDescriptor& dim_b, T alpha, int head_number, - framework::Tensor* mat_out, T beta, + const pten::DenseTensor& mat_b, + const MatDescriptor& dim_b, + T alpha, + int head_number, + pten::DenseTensor* mat_out, + T beta, bool mat_y_split_vertical) const; #endif #endif // @} End Group MKLML: class Blas template - void MatMul(const int M, const int N, const int K, const T* A, const T* B, + void MatMul(const int M, + const int N, + const int K, + const T* A, + const T* B, T* C) const; template - void MatMul(const framework::Tensor& mat_a, bool trans_a, - const framework::Tensor& mat_b, bool trans_b, T alpha, - 
framework::Tensor* mat_out, T beta) const; - - template - void MatMul(const framework::Tensor& mat_a, bool trans_a, - const framework::Tensor& mat_b, bool trans_b, - framework::Tensor* mat_out) const { - MatMul(mat_a, trans_a, mat_b, trans_b, static_cast(1.0), mat_out, + void MatMul(const pten::DenseTensor& mat_a, + bool trans_a, + const pten::DenseTensor& mat_b, + bool trans_b, + T alpha, + pten::DenseTensor* mat_out, + T beta) const; + + template + void MatMul(const pten::DenseTensor& mat_a, + bool trans_a, + const pten::DenseTensor& mat_b, + bool trans_b, + pten::DenseTensor* mat_out) const { + MatMul(mat_a, + trans_a, + mat_b, + trans_b, + static_cast(1.0), + mat_out, static_cast(0.0)); } template - void MatMul(const framework::Tensor& mat_a, const framework::Tensor& mat_b, - framework::Tensor* mat_out) const { + void MatMul(const pten::DenseTensor& mat_a, + const pten::DenseTensor& mat_b, + pten::DenseTensor* mat_out) const { this->template MatMul(mat_a, false, mat_b, false, mat_out); } @@ -187,7 +257,13 @@ class Blas { void VPOW(int n, const T* x, T alpha, T* y) const; template - void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta, + void GEMV(bool trans_a, + int M, + int N, + T alpha, + const T* A, + const T* B, + T beta, T* C) const; template @@ -200,33 +276,71 @@ class Blas { T ASUM(int n, T* x, int inc) const; template - void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, - int K, T alpha, const T* A, const T* B, T beta, T* C, - int batchCount, int64_t strideA, int64_t strideB) const; - - template - void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, - int K, T alpha, const T** A, const T** B, T beta, T** C, + void BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T* A, + const T* B, + T beta, + T* C, + int batchCount, + int64_t strideA, + int64_t strideB) const; + + template + void BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T** A, + const T** B, + T beta, + T** C, int batchCount) const; #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ !defined(PADDLE_WITH_HIP) template - void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, - int W1, int H1, int W2, int H2, T alpha, const T* A, - const T* B, T beta, T* C, int batchCount, - int64_t strideA, int64_t strideB, - int64_t head_number, bool split_b_vertical) const; + void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int W1, + int H1, + int W2, + int H2, + T alpha, + const T* A, + const T* B, + T beta, + T* C, + int batchCount, + int64_t strideA, + int64_t strideB, + int64_t head_number, + bool split_b_vertical) const; #endif template - void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a, - const framework::Tensor& mat_b, const MatDescriptor& dim_b, - T alpha, framework::Tensor* mat_out, T beta) const; + void MatMul(const pten::DenseTensor& mat_a, + const MatDescriptor& dim_a, + const pten::DenseTensor& mat_b, + const MatDescriptor& dim_b, + T alpha, + pten::DenseTensor* mat_out, + T beta) const; template - void MatMul(const T* mat_a, const MatDescriptor& dim_a, const T* mat_b, - const MatDescriptor& dim_b, T alpha, T* mat_out, T beta) const; + void MatMul(const T* mat_a, + const MatDescriptor& dim_a, + const T* mat_b, + const MatDescriptor& dim_b, + T alpha, + T* mat_out, + T beta) const; template void VINV(int n, const T* a, T* y) const; @@ -235,8 +349,16 
@@ class Blas { void VMERF(int n, const T* a, T* y, int64_t mode) const; template - void TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, int M, int N, T alpha, const T* A, int lda, T* B, + void TRSM(CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T* A, + int lda, + T* B, int ldb) const; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -244,24 +366,44 @@ class Blas { void BatchedGETRF(int n, T** a, int* ipiv, int* info, int batch_size) const; template - void BatchedGETRI(int n, const T** a, const int* ipiv, T** a_inv, int* info, + void BatchedGETRI(int n, + const T** a, + const int* ipiv, + T** a_inv, + int* info, int batch_size) const; template - void BatchedMatInv(int n, const T** a, T** a_inv, int* info, - int batch_size) const; + void BatchedMatInv( + int n, const T** a, T** a_inv, int* info, int batch_size) const; // cuBlas solve template - void BatchedGETRS(CBLAS_TRANSPOSE trans, int n, int nrhs, const T** a, - int lda, int* ipiv, T** b, int ldb, int* info, + void BatchedGETRS(CBLAS_TRANSPOSE trans, + int n, + int nrhs, + const T** a, + int lda, + int* ipiv, + T** b, + int ldb, + int* info, int batch_size) const; // cuBlas triangular_solve template - void BatchedTRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, - CBLAS_DIAG diag, int M, int N, T alpha, const T** a, int lda, - T** b, int ldb, int batch_size) const; + void BatchedTRSM(CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T** a, + int lda, + T** b, + int ldb, + int batch_size) const; #endif private: @@ -439,7 +581,7 @@ class BlasT : private Blas { template inline BlasT GetBlas( - const framework::ExecutionContext& exe_ctx) { + const paddle::framework::ExecutionContext& exe_ctx) { return BlasT( exe_ctx.template device_context()); } @@ -449,14 +591,13 @@ inline BlasT GetBlas(const DeviceContext& dev_ctx) { return BlasT(dev_ctx); } -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace pten -#include "paddle/fluid/operators/math/blas_impl.h" +#include "paddle/pten/kernels/funcs/blas/blas_impl.h" #ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/operators/math/blas_impl.cu.h" +#include "paddle/pten/kernels/funcs/blas/blas_impl.cu.h" #endif #ifdef PADDLE_WITH_HIP -#include "paddle/fluid/operators/math/blas_impl.hip.h" +#include "paddle/pten/kernels/funcs/blas/blas_impl.hip.h" #endif diff --git a/paddle/pten/kernels/funcs/blas/blas_impl.cu.h b/paddle/pten/kernels/funcs/blas/blas_impl.cu.h new file mode 100644 index 0000000000000..23dec2543f8c0 --- /dev/null +++ b/paddle/pten/kernels/funcs/blas/blas_impl.cu.h @@ -0,0 +1,2941 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
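After the rename, blas.h still exposes the two GetBlas entry points visible above: one taking a paddle::framework::ExecutionContext (used by the fluid operator kernels patched earlier) and one taking a DeviceContext directly. A minimal sketch of the device-context form together with the CBLAS-style GEMM declared in the header; dev_ctx, M, N, K and the raw pointers A, B, C are illustrative names, not part of the patch:

    auto blas = pten::funcs::GetBlas<DeviceContext, T>(dev_ctx);
    // C = 1.0 * A * B + 0.0 * C, with no transposition applied to A or B.
    blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<T>(1.0), A, B,
              static_cast<T>(0.0), C);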
+ +#pragma once + +#include "paddle/fluid/platform/dynload/cublas.h" +#include "paddle/pten/kernels/funcs/math_function.h" + +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/pten/backends/gpu/gpu_context.h" + +DECLARE_bool(enable_cublas_tensor_op_math); + +namespace pten { +namespace funcs { + +template +struct CUBlas; + +template <> +struct CUBlas { + template + static void GEMM(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSgemm(args...)); + } + + template + static void AXPY(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSaxpy(args...)); + } + + template + static void SCAL(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSscal(args...)); + } + + template + static void VCOPY(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasScopy(args...)); + } + + template + static void GEMV(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSgemv(args...)); + } + + template + static void GEMM_STRIDED_BATCH(ARGS... args) { +#if CUDA_VERSION >= 8000 + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasSgemmStridedBatched(args...)); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "SgemmStridedBatched is not supported on cuda <= 7.5")); +#endif + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(paddle::platform::CUDADeviceContext *dev_ctx, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, + const void *A, + cudaDataType_t Atype, + int lda, + const void *B, + cudaDataType_t Btype, + int ldb, + const float *beta, + void *C, + cudaDataType_t Ctype, + int ldc) { +// Because the gcc 4.8 doesn't expand template parameter pack that +// appears in a lambda-expression, I can not use template parameter pack +// here. +#if CUDA_VERSION >= 8000 + VLOG(5) << "use_tensor_op_math: " + << (dev_ctx->tensor_core_available() ? "True" : "False"); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasSgemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc)); + }); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "cublasSgemmEx is not supported on cuda <= 7.5")); +#endif + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(pten::GPUContext *dev_ctx, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, + const void *A, + cudaDataType_t Atype, + int lda, + const void *B, + cudaDataType_t Btype, + int ldb, + const float *beta, + void *C, + cudaDataType_t Ctype, + int ldc) { +// Because the gcc 4.8 doesn't expand template parameter pack that +// appears in a lambda-expression, I can not use template parameter pack +// here. +#if CUDA_VERSION >= 8000 + VLOG(5) << "use_tensor_op_math: " + << (dev_ctx->tensor_core_available() ? 
"True" : "False"); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasSgemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc)); + }); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "cublasSgemmEx is not supported on cuda <= 7.5")); +#endif + } + + template + static void TRSM(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasStrsm(args...)); + } + + template + static void GETRF_BATCH(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasSgetrfBatched(args...)); + } + + template + static void GETRI_BATCH(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasSgetriBatched(args...)); + } + + template + static void MATINV_BATCH(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasSmatinvBatched(args...)); + } + + template + static void GETRS_BATCH(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasSgetrsBatched(args...)); + } + + template + static void TRSM_BATCH(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasStrsmBatched(args...)); + } +}; + +template <> +struct CUBlas { + template + static void GEMM(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDgemm(args...)); + } + + template + static void AXPY(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDaxpy(args...)); + } + + template + static void SCAL(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDscal(args...)); + } + + template + static void VCOPY(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDcopy(args...)); + } + + template + static void GEMV(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDgemv(args...)); + } + + template + static void GEMM_STRIDED_BATCH(ARGS... args) { +#if CUDA_VERSION >= 8000 + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasDgemmStridedBatched(args...)); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "DgemmStridedBatched is not supported on cuda <= 7.5")); +#endif + } + + template + static void GEMM_EX(ARGS... args) { + PADDLE_THROW( + pten::errors::Unimplemented("Currently there are not cublasDgemmEx.")); + } + + template + static void TRSM(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDtrsm(args...)); + } + + template + static void GETRF_BATCH(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasDgetrfBatched(args...)); + } + + template + static void GETRI_BATCH(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasDgetriBatched(args...)); + } + + template + static void MATINV_BATCH(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasDmatinvBatched(args...)); + } + + template + static void GETRS_BATCH(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasDgetrsBatched(args...)); + } + + template + static void TRSM_BATCH(ARGS... 
args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasDtrsmBatched(args...)); + } +}; + +template <> +struct CUBlas { + using float16 = pten::dtype::float16; + + static void GEMM(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float16 *alpha, + const float16 *A, + int lda, + const float16 *B, + int ldb, + const float16 *beta, + float16 *C, + int ldc) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasHgemm( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + reinterpret_cast(beta), + reinterpret_cast<__half *>(C), + ldc)); + } + + static void GEMM_STRIDED_BATCH(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float16 *alpha, + const float16 *A, + int lda, + long long int strideA, // NOLINT + const float16 *B, // NOLINT + int ldb, + long long int strideB, // NOLINT + const float16 *beta, + float16 *C, + int ldc, + long long int strideC, // NOLINT + int batchCount) { +#if CUDA_VERSION >= 8000 + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasHgemmStridedBatched( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + strideA, + reinterpret_cast(B), + ldb, + strideB, + reinterpret_cast(beta), + reinterpret_cast<__half *>(C), + ldc, + strideC, + batchCount)); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "HgemmStridedBatched is not supported on cuda <= 7.5")); +#endif + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(paddle::platform::CUDADeviceContext *dev_ctx, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + cudaDataType_t Atype, + int lda, + const void *B, + cudaDataType_t Btype, + int ldb, + const void *beta, + void *C, + cudaDataType_t Ctype, + int ldc, + cudaDataType_t computeType) { +#if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; +#if CUDA_VERSION >= 9000 + bool use_tensor_op_math = dev_ctx->tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); +#endif // CUDA_VERSION >= 9000 + + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + computeType, + algo)); + }); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "cublasGemmEx is not supported on cuda <= 7.5")); +#endif + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. 
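  // At the Blas GEMM call sites further down in this file, this fp16
  // GEMM_EX path is taken whenever CUDA_VERSION >= 8000; on CUDA >= 9000
  // the CUBLAS_GEMM_DFALT_TENSOR_OP algorithm is additionally selected
  // when the device reports tensor-core support.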
+ // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(pten::GPUContext *dev_ctx, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + cudaDataType_t Atype, + int lda, + const void *B, + cudaDataType_t Btype, + int ldb, + const void *beta, + void *C, + cudaDataType_t Ctype, + int ldc, + cudaDataType_t computeType) { +#if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; +#if CUDA_VERSION >= 9000 + bool use_tensor_op_math = dev_ctx->tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); +#endif // CUDA_VERSION >= 9000 + + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + computeType, + algo)); + }); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "cublasGemmEx is not supported on cuda <= 7.5")); +#endif + } +}; + +template <> +struct CUBlas> { + static void GEMV(cublasHandle_t handle, + cublasOperation_t transa, + int m, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + const pten::dtype::complex *B, + int ldb, + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCgemv( + handle, + transa, + m, + n, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc)); + } + + static void AXPY(cublasHandle_t handle, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex *X, + const int incX, + pten::dtype::complex *Y, + const int incY) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCaxpy( + handle, + n, + reinterpret_cast(alpha), + reinterpret_cast(X), + incX, + reinterpret_cast(Y), + incY)); + } + + static void GEMM_STRIDED_BATCH( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + long long int strideA, // NOLINT + const pten::dtype::complex *B, // NOLINT + int ldb, + long long int strideB, // NOLINT + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc, + long long int strideC, // NOLINT + int batchCount) { +#if CUDA_VERSION >= 8000 + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasCgemmStridedBatched( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + strideA, + reinterpret_cast(B), + ldb, + strideB, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc, + strideC, + batchCount)); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "CgemmStridedBatched is not supported on cuda <= 7.5")); +#endif + } + + static void GEMM(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + const pten::dtype::complex *B, + int ldb, + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCgemm( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + 
reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc)); + } + + static void TRSM(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t transa, + cublasDiagType_t diag, + int m, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + pten::dtype::complex *B, + int ldb) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCtrsm( + handle, + side, + uplo, + transa, + diag, + m, + n, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb)); + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(paddle::platform::CUDADeviceContext *dev_ctx, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + cudaDataType_t Atype, + int lda, + const void *B, + cudaDataType_t Btype, + int ldb, + const void *beta, + void *C, + cudaDataType_t Ctype, + int ldc, + cudaDataType_t computeType) { +#if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; +#if CUDA_VERSION >= 9000 + bool use_tensor_op_math = dev_ctx->tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); +#endif // CUDA_VERSION >= 9000 + + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + computeType, + algo)); + }); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "cublasGemmEx is not supported on cuda <= 7.5")); +#endif + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(pten::GPUContext *dev_ctx, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + cudaDataType_t Atype, + int lda, + const void *B, + cudaDataType_t Btype, + int ldb, + const void *beta, + void *C, + cudaDataType_t Ctype, + int ldc, + cudaDataType_t computeType) { +#if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; +#if CUDA_VERSION >= 9000 + bool use_tensor_op_math = dev_ctx->tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? 
"True" : "False"); +#endif // CUDA_VERSION >= 9000 + + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + computeType, + algo)); + }); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "cublasGemmEx is not supported on cuda <= 7.5")); +#endif + } + + static void TRSM_BATCH(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t transa, + cublasDiagType_t diag, + int m, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex **A, + int lda, + pten::dtype::complex **B, + int ldb, + int batch_size) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCtrsmBatched( + handle, + side, + uplo, + transa, + diag, + m, + n, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + batch_size)); + } +}; + +template <> +struct CUBlas> { + static void GEMV(cublasHandle_t handle, + cublasOperation_t transa, + int m, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + const pten::dtype::complex *B, + int ldb, + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZgemv( + handle, + transa, + m, + n, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc)); + } + + static void AXPY(cublasHandle_t handle, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex *X, + const int incX, + pten::dtype::complex *Y, + const int incY) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZaxpy( + handle, + n, + reinterpret_cast(alpha), + reinterpret_cast(X), + incX, + reinterpret_cast(Y), + incY)); + } + + static void GEMM_STRIDED_BATCH( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + long long int strideA, // NOLINT + const pten::dtype::complex *B, // NOLINT + int ldb, + long long int strideB, // NOLINT + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc, + long long int strideC, // NOLINT + int batchCount) { +#if CUDA_VERSION >= 8000 + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasZgemmStridedBatched( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + strideA, + reinterpret_cast(B), + ldb, + strideB, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc, + strideC, + batchCount)); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "CgemmStridedBatched is not supported on cuda <= 7.5")); +#endif + } + + static void GEMM(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + const pten::dtype::complex *B, + int ldb, + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZgemm( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc)); + } + + static void TRSM(cublasHandle_t handle, + 
cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t transa, + cublasDiagType_t diag, + int m, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + pten::dtype::complex *B, + int ldb) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZtrsm( + handle, + side, + uplo, + transa, + diag, + m, + n, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb)); + } + + static void TRSM_BATCH(cublasHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t transa, + cublasDiagType_t diag, + int m, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex **A, + int lda, + pten::dtype::complex **B, + int ldb, + int batch_size) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZtrsmBatched( + handle, + side, + uplo, + transa, + diag, + m, + n, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + batch_size)); + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(paddle::platform::CUDADeviceContext *dev_ctx, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + cudaDataType_t Atype, + int lda, + const void *B, + cudaDataType_t Btype, + int ldb, + const void *beta, + void *C, + cudaDataType_t Ctype, + int ldc, + cudaDataType_t computeType) { +#if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; +#if CUDA_VERSION >= 9000 + bool use_tensor_op_math = dev_ctx->tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); +#endif // CUDA_VERSION >= 9000 + + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + computeType, + algo)); + }); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "cublasGemmEx is not supported on cuda <= 7.5")); +#endif + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(pten::GPUContext *dev_ctx, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + cudaDataType_t Atype, + int lda, + const void *B, + cudaDataType_t Btype, + int ldb, + const void *beta, + void *C, + cudaDataType_t Ctype, + int ldc, + cudaDataType_t computeType) { +#if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; +#if CUDA_VERSION >= 9000 + bool use_tensor_op_math = dev_ctx->tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? 
"True" : "False"); +#endif // CUDA_VERSION >= 9000 + + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + computeType, + algo)); + }); +#else + PADDLE_THROW(pten::errors::Unimplemented( + "cublasGemmEx is not supported on cuda <= 7.5")); +#endif + } +}; + +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + +#if CUDA_VERSION >= 8000 + if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { + auto &cuda_ctx = + const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + CUDA_R_32F, + ldb, + A, + CUDA_R_32F, + lda, + &beta, + C, + CUDA_R_32F, + N); + } else { +#endif // CUDA_VERSION >= 8000 + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + N); + }); + +#if CUDA_VERSION >= 8000 + } +#endif // CUDA_VERSION >= 8000 +} + +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + +#if CUDA_VERSION >= 8000 + if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + CUDA_R_32F, + ldb, + A, + CUDA_R_32F, + lda, + &beta, + C, + CUDA_R_32F, + N); + } else { +#endif // CUDA_VERSION >= 8000 + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + N); + }); + +#if CUDA_VERSION >= 8000 + } +#endif // CUDA_VERSION >= 8000 +} + +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + const pten::dtype::float16 *B, + pten::dtype::float16 beta, + pten::dtype::float16 *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas fp16 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + +#if CUDA_VERSION >= 8000 + // cublasHgemm does true FP16 computation which is slow for non-Volta + // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: + // input/output in fp16, computation in fp32, which can also be accelerated + // using tensor cores in volta GPUs. + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16F, + ldb, + A, + CUDA_R_16F, + lda, + &h_beta, + C, + CUDA_R_16F, + N, + CUDA_R_32F); +#else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + h_B, + ldb, + h_A, + lda, + &h_beta, + h_C, + N); + }); +#endif // CUDA_VERSION >= 8000 +} + +template <> +template <> +inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + const pten::dtype::float16 *B, + pten::dtype::float16 beta, + pten::dtype::float16 *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas fp16 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + +#if CUDA_VERSION >= 8000 + // cublasHgemm does true FP16 computation which is slow for non-Volta + // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: + // input/output in fp16, computation in fp32, which can also be accelerated + // using tensor cores in volta GPUs. + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16F, + ldb, + A, + CUDA_R_16F, + lda, + &h_beta, + C, + CUDA_R_16F, + N, + CUDA_R_32F); +#else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + h_B, + ldb, + h_A, + lda, + &h_beta, + h_C, + N); + }); +#endif // CUDA_VERSION >= 8000 +} + +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C) const { +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? 
K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 80, + pten::errors::InvalidArgument( + "cublas fp16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16BF, + ldb, + A, + CUDA_R_16BF, + lda, + &h_beta, + C, + CUDA_R_16BF, + N, + CUDA_R_32F, + algo)); + }); +#else + // raise error + PADDLE_THROW(pten::errors::Unimplemented( + "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); + +#endif // CUDA_VERSION >= 11000 +} + +template <> +template <> +inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C) const { +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 80, + pten::errors::InvalidArgument( + "cublas bf16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16BF, + ldb, + A, + CUDA_R_16BF, + lda, + &h_beta, + C, + CUDA_R_16BF, + N, + CUDA_R_32F, + algo)); + }); +#else + // raise error + PADDLE_THROW(pten::errors::Unimplemented( + "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); + +#endif // CUDA_VERSION >= 11000 +} + +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + const pten::dtype::complex *B, + pten::dtype::complex beta, + pten::dtype::complex *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? 
K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas complex64 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + thrust::complex c_alpha = + thrust::complex(alpha.real, alpha.imag); + thrust::complex c_beta = thrust::complex(beta.real, beta.imag); + +#if CUDA_VERSION >= 8000 + // cublasHgemm does true FP16 computation which is slow for non-Volta + // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: + // input/output in fp16, computation in fp32, which can also be accelerated + // using tensor cores in volta GPUs. + auto &cuda_ctx = const_cast(context_); + CUBlas>::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + B, + CUDA_C_32F, + ldb, + A, + CUDA_C_32F, + lda, + &c_beta, + C, + CUDA_C_32F, + N, + CUDA_C_32F); +#else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas>::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + h_B, + ldb, + h_A, + lda, + &c_beta, + h_C, + N); + }); +#endif // CUDA_VERSION >= 8000 +} + +template <> +template <> +inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + const pten::dtype::complex *B, + pten::dtype::complex beta, + pten::dtype::complex *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas complex64 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + thrust::complex c_alpha = + thrust::complex(alpha.real, alpha.imag); + thrust::complex c_beta = thrust::complex(beta.real, beta.imag); + +#if CUDA_VERSION >= 8000 + // cublasHgemm does true FP16 computation which is slow for non-Volta + // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: + // input/output in fp16, computation in fp32, which can also be accelerated + // using tensor cores in volta GPUs. 
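  // Why B and A are swapped in the call below (the same reasoning applies
  // to every GEMM wrapper in this file): cuBLAS works on column-major
  // matrices, so a row-major M x N result C is handed to cuBLAS as its
  // transpose C^T (N x M). Since C^T = B^T * A^T, calling cuBLAS with the
  // operands exchanged, the dimensions given as (N, M, K), and the
  // row-major leading dimensions (ldb, lda, N) yields the desired
  // row-major C directly, with no explicit transpose or extra copy.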
+ auto &cuda_ctx = const_cast(context_); + CUBlas>::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + B, + CUDA_C_32F, + ldb, + A, + CUDA_C_32F, + lda, + &c_beta, + C, + CUDA_C_32F, + N, + CUDA_C_32F); +#else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas>::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + h_B, + ldb, + h_A, + lda, + &c_beta, + h_C, + N); + }); +#endif // CUDA_VERSION >= 8000 +} + +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + const pten::dtype::complex *B, + pten::dtype::complex beta, + pten::dtype::complex *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas complex128 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + thrust::complex c_alpha = + thrust::complex(alpha.real, alpha.imag); + thrust::complex c_beta = + thrust::complex(beta.real, beta.imag); + +#if CUDA_VERSION >= 8000 + // cublasHgemm does true FP16 computation which is slow for non-Volta + // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: + // input/output in fp16, computation in fp32, which can also be accelerated + // using tensor cores in volta GPUs. + auto &cuda_ctx = const_cast(context_); + CUBlas>::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + B, + CUDA_C_64F, + ldb, + A, + CUDA_C_64F, + lda, + &c_beta, + C, + CUDA_C_64F, + N, + CUDA_C_64F); +#else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas>::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + h_B, + ldb, + h_A, + lda, + &c_beta, + h_C, + N); + }); +#endif // CUDA_VERSION >= 8000 +} + +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + const pten::dtype::complex *B, + pten::dtype::complex beta, + pten::dtype::complex *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas complex128 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + thrust::complex c_alpha = + thrust::complex(alpha.real, alpha.imag); + thrust::complex c_beta = + thrust::complex(beta.real, beta.imag); + +#if CUDA_VERSION >= 8000 + // cublasHgemm does true FP16 computation which is slow for non-Volta + // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: + // input/output in fp16, computation in fp32, which can also be accelerated + // using tensor cores in volta GPUs. + auto &cuda_ctx = const_cast(context_); + CUBlas>::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + B, + CUDA_C_64F, + ldb, + A, + CUDA_C_64F, + lda, + &c_beta, + C, + CUDA_C_64F, + N, + CUDA_C_64F); +#else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas>::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + h_B, + ldb, + h_A, + lda, + &c_beta, + h_C, + N); + }); +#endif // CUDA_VERSION >= 8000 +} + +template <> +template +void Blas::GEMM(bool transA, + bool transB, + int M, + int N, + int K, + T alpha, + const T *A, + int lda, + const T *B, + int ldb, + T beta, + T *C, + int ldc) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; + +#if CUDA_VERSION >= 8000 + if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { + auto &cuda_ctx = + const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + CUDA_R_32F, + ldb, + A, + CUDA_R_32F, + lda, + &beta, + C, + CUDA_R_32F, + ldc); + } else { +#endif // CUDA_VERSION >= 8000 + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc); + }); + +#if CUDA_VERSION >= 8000 + } +#endif // CUDA_VERSION >= 8000 +} + +template <> +template +void Blas::GEMM(bool transA, + bool transB, + int M, + int N, + int K, + T alpha, + const T *A, + int lda, + const T *B, + int ldb, + T beta, + T *C, + int ldc) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cuTransB = transB ? 
CUBLAS_OP_T : CUBLAS_OP_N; + +#if CUDA_VERSION >= 8000 + if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + CUDA_R_32F, + ldb, + A, + CUDA_R_32F, + lda, + &beta, + C, + CUDA_R_32F, + ldc); + } else { +#endif // CUDA_VERSION >= 8000 + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc); + }); + +#if CUDA_VERSION >= 8000 + } +#endif // CUDA_VERSION >= 8000 +} + +template <> +template <> +inline void Blas::GEMM( + bool transA, + bool transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + int lda, + const pten::dtype::float16 *B, + int ldb, + pten::dtype::float16 beta, + pten::dtype::float16 *C, + int ldc) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc); + }); +} + +template <> +template <> +inline void Blas::GEMM(bool transA, + bool transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + int lda, + const pten::dtype::float16 *B, + int ldb, + pten::dtype::float16 beta, + pten::dtype::float16 *C, + int ldc) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc); + }); +} + +template <> +template +void Blas::AXPY(int n, + T alpha, + const T *x, + T *y) const { + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); + }); +} + +template <> +template +void Blas::AXPY(int n, T alpha, const T *x, T *y) const { + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); + }); +} + +template <> +template +void Blas::SCAL(int n, + const T alpha, + T *x) const { + context_.CublasCall( + [&](cublasHandle_t handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); +} + +template <> +template +void Blas::SCAL(int n, const T alpha, T *x) const { + context_.CublasCall( + [&](cublasHandle_t handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); +} + +template <> +template +void Blas::VCOPY(int n, + const T *x, + T *y) const { + context_.CublasCall( + [&](cublasHandle_t handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); +} + +template <> +template +void Blas::VCOPY(int n, const T *x, T *y) const { + context_.CublasCall( + [&](cublasHandle_t handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); +} + +template <> +template +void Blas::GEMV(bool trans_a, + int M, + int N, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + cublasOperation_t cuTransA = !trans_a ? 
CUBLAS_OP_T : CUBLAS_OP_N; + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); + }); +} + +template <> +template +void Blas::GEMV(bool trans_a, + int M, + int N, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); + }); +} + +template <> +template <> +inline void Blas::GEMV( + bool trans_a, + int M, + int N, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + const pten::dtype::float16 *B, + pten::dtype::float16 beta, + pten::dtype::float16 *C) const { + // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it. + if (trans_a) { + this->template GEMM( + CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); + } else { + this->template GEMM( + CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); + } +} + +template <> +template <> +inline void Blas::GEMV(bool trans_a, + int M, + int N, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + const pten::dtype::float16 *B, + pten::dtype::float16 beta, + pten::dtype::float16 *C) const { + // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it. + if (trans_a) { + this->template GEMM( + CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); + } else { + this->template GEMM( + CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); + } +} + +template <> +template <> +inline void Blas::GEMV( + bool trans_a, + int M, + int N, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C) const { + // Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve + // it. + if (trans_a) { + this->template GEMM( + CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); + } else { + this->template GEMM( + CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); + } +} + +template <> +template <> +inline void Blas::GEMV(bool trans_a, + int M, + int N, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C) const { + // Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve + // it. + if (trans_a) { + this->template GEMM( + CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); + } else { + this->template GEMM( + CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); + } +} + +template <> +template +void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t strideC = M * N; + +#if CUDA_VERSION >= 9010 + if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || + std::is_same::value) { + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); + + auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmStridedBatchedEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + fp, + ldb, + strideB, + A, + fp, + lda, + strideA, + &beta, + C, + fp, + ldc, + strideC, + batchCount, + fp, + algo)); + }); + } else { +#endif // CUDA_VERSION >= 9010 + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_STRIDED_BATCH(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + strideB, + A, + lda, + strideA, + &beta, + C, + ldc, + strideC, + batchCount); + }); + +#if CUDA_VERSION >= 9010 + } +#endif // CUDA_VERSION >= 9010 +} + +template <> +template +void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t strideC = M * N; + +#if CUDA_VERSION >= 9010 + if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || + std::is_same::value) { + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); + + auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmStridedBatchedEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + fp, + ldb, + strideB, + A, + fp, + lda, + strideA, + &beta, + C, + fp, + ldc, + strideC, + batchCount, + fp, + algo)); + }); + } else { +#endif // CUDA_VERSION >= 9010 + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_STRIDED_BATCH(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + strideB, + A, + lda, + strideA, + &beta, + C, + ldc, + strideC, + batchCount); + }); + +#if CUDA_VERSION >= 9010 + } +#endif // CUDA_VERSION >= 9010 +} + +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? 
K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t strideC = M * N; + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmStridedBatchedEx( + handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16BF, + ldb, + strideB, + A, + CUDA_R_16BF, + lda, + strideA, + &h_beta, + C, + CUDA_R_16BF, + ldc, + strideC, + batchCount, + CUBLAS_COMPUTE_32F, + algo)); + }); +#else + // raise error + PADDLE_THROW(pten::errors::Unimplemented( + "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " + "11")); +#endif // CUDA_VERSION >= 11000 +} + +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t strideC = M * N; + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? 
"True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmStridedBatchedEx( + handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16BF, + ldb, + strideB, + A, + CUDA_R_16BF, + lda, + strideA, + &h_beta, + C, + CUDA_R_16BF, + ldc, + strideC, + batchCount, + CUBLAS_COMPUTE_32F, + algo)); + }); +#else + // raise error + PADDLE_THROW(pten::errors::Unimplemented( + "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " + "11")); +#endif // CUDA_VERSION >= 11000 +} + +template <> +template +void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T **A, + const T **B, + T beta, + T **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} + +template <> +template +void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T **A, + const T **B, + T beta, + T **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} + +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 **A, + const pten::dtype::float16 **B, + pten::dtype::float16 beta, + pten::dtype::float16 **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} + +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 **A, + const pten::dtype::float16 **B, + pten::dtype::float16 beta, + pten::dtype::float16 **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} + +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 **A, + const pten::dtype::bfloat16 **B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} + +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 **A, + const pten::dtype::bfloat16 **B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} + +template <> +template +void Blas::TRSM(CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T *A, + int lda, + T *B, + int ldb) const { + // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` + // where ' stands for transpose + cublasSideMode_t cuSide = + (side == CblasLeft) ? 
CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; + cublasFillMode_t cuUplo = + (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; + // use CUBLAS_OP_C (conjugate transpose) for complex + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasDiagType_t cuDiag = + (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::TRSM( + handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, lda, B, ldb); + }); +} + +template <> +template +void Blas::TRSM(CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T *A, + int lda, + T *B, + int ldb) const { + // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` + // where ' stands for transpose + cublasSideMode_t cuSide = + (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; + cublasFillMode_t cuUplo = + (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; + // use CUBLAS_OP_C (conjugate transpose) for complex + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasDiagType_t cuDiag = + (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::TRSM( + handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, lda, B, ldb); + }); +} + +template <> +template +void Blas::BatchedGETRF( + int n, T **a, int *ipiv, int *info, int batch_size) const { + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedGETRF( + int n, T **a, int *ipiv, int *info, int batch_size) const { + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedGETRI( + int n, const T **a, const int *ipiv, T **a_inv, int *info, int batch_size) + const { + PADDLE_ENFORCE_NE( + a_inv, + a, + pten::errors::InvalidArgument( + "cuBLAS fuction 'cublasgetrfBatched' cannot be executed " + "in-place. The memory space of output matrix (address: %p) cannot " + "overlap memory space of input matrix (address: %p).", + a_inv, + a)); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedGETRI(int n, + const T **a, + const int *ipiv, + T **a_inv, + int *info, + int batch_size) const { + PADDLE_ENFORCE_NE( + a_inv, + a, + pten::errors::InvalidArgument( + "cuBLAS fuction 'cublasgetrfBatched' cannot be executed " + "in-place. 
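// Aside: BatchedGETRF / BatchedGETRI wrap cublas<S|D>getrfBatched and
// cublas<S|D>getriBatched; as the PADDLE_ENFORCE_NE check here enforces,
// getriBatched writes the inverses to a separate output array and cannot run
// in place. A minimal call-sequence sketch for inverting a batch of n x n
// matrices (the GetBlas helper is assumed from the rest of this patch; the
// pointer arrays and the ipiv/info buffers live on the device as the cuBLAS
// batched API requires):
void BatchedInverseSketch(const pten::GPUContext &ctx,
                          float **d_A,     // inputs, overwritten by LU factors
                          float **d_Ainv,  // outputs, must not alias d_A
                          int *d_ipiv, int *d_info,
                          int n, int batch) {
  auto blas = pten::funcs::GetBlas<pten::GPUContext, float>(ctx);
  blas.BatchedGETRF(n, d_A, d_ipiv, d_info, batch);
  blas.BatchedGETRI(n, const_cast<const float **>(d_A), d_ipiv,
                    d_Ainv, d_info, batch);
}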
The memory space of output matrix (address: %p) cannot " + "overlap memory space of input matrix (address: %p).", + a_inv, + a)); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedMatInv( + int n, const T **a, T **a_inv, int *info, int batch_size) const { + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedMatInv( + int n, const T **a, T **a_inv, int *info, int batch_size) const { + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedGETRS( + CBLAS_TRANSPOSE trans, + int n, + int nrhs, + const T **a, + int lda, + int *ipiv, + T **b, + int ldb, + int *info, + int batch_size) const { + // use CUBLAS_OP_C (conjugate transpose) for complex + cublasOperation_t cuTrans = + (trans == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRS_BATCH( + handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedGETRS(CBLAS_TRANSPOSE trans, + int n, + int nrhs, + const T **a, + int lda, + int *ipiv, + T **b, + int ldb, + int *info, + int batch_size) const { + // use CUBLAS_OP_C (conjugate transpose) for complex + cublasOperation_t cuTrans = + (trans == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRS_BATCH( + handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedTRSM( + CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T **A, + int lda, + T **B, + int ldb, + int batch_size) const { + // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` + // where ' stands for transpose + cublasSideMode_t cuSide = + (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; + cublasFillMode_t cuUplo = + (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; + // use CUBLAS_OP_C (conjugate transpose) for complex + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasDiagType_t cuDiag = + (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::TRSM_BATCH(handle, + cuSide, + cuUplo, + cuTransA, + cuDiag, + N, + M, + &alpha, + A, + lda, + B, + ldb, + batch_size); + }); +} + +template <> +template +void Blas::BatchedTRSM(CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T **A, + int lda, + T **B, + int ldb, + int batch_size) const { + // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` + // where ' stands for transpose + cublasSideMode_t cuSide = + (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; + cublasFillMode_t cuUplo = + (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; + // use CUBLAS_OP_C (conjugate transpose) for complex + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasDiagType_t cuDiag = + (diag == CblasUnit) ? 
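// Aside: a batch of linear systems A_i * X_i = B_i can be solved with the two
// calls above: LU-factorize once with BatchedGETRF, then back-substitute with
// BatchedGETRS. Call-sequence sketch only (GetBlas is assumed from the rest of
// this patch; as everywhere in this file, cuBLAS reads the buffers as
// column-major, and the pivot/info buffers must be placed where the underlying
// cuBLAS batched routines expect them):
void BatchedSolveSketch(const pten::GPUContext &ctx,
                        double **d_A,  // n x n, overwritten by LU factors
                        double **d_B,  // n x nrhs, overwritten by the solutions
                        int *ipiv, int *info,
                        int n, int nrhs, int batch) {
  auto blas = pten::funcs::GetBlas<pten::GPUContext, double>(ctx);
  blas.BatchedGETRF(n, d_A, ipiv, info, batch);
  blas.BatchedGETRS(CblasNoTrans, n, nrhs,
                    const_cast<const double **>(d_A), n,
                    ipiv, d_B, n, info, batch);
}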
CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::TRSM_BATCH(handle, + cuSide, + cuUplo, + cuTransA, + cuDiag, + N, + M, + &alpha, + A, + lda, + B, + ldb, + batch_size); + }); +} + +} // namespace funcs +} // namespace pten diff --git a/paddle/pten/kernels/funcs/blas/blas_impl.h b/paddle/pten/kernels/funcs/blas/blas_impl.h new file mode 100644 index 0000000000000..5c93011ab500a --- /dev/null +++ b/paddle/pten/kernels/funcs/blas/blas_impl.h @@ -0,0 +1,2530 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "paddle/pten/backends/cpu/cpu_context.h" +#ifdef PADDLE_WITH_MKLML +#include +#endif + +#include +#include +#include +#include + +#include "paddle/pten/common/bfloat16.h" +#include "paddle/pten/common/complex.h" +#include "paddle/pten/kernels/funcs/math_function.h" + +namespace pten { +namespace funcs { + +namespace detail { +template +static void axpy( + int n, const T alpha, const T *x, const int incx, T *y, const int incy) { + // Y = Y + alpha * X + while (n-- > 0) { + *y += alpha * *x; + y = y + incy; + x = x + incx; + } +} +} // namespace detail + +template +struct CBlas; + +template <> +struct CBlas { + template + static void VCOPY(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "Blas VCOPY do not supported on CPU, please check your code")); + } +}; + +template <> +struct CBlas { + template + static void VCOPY(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "Blas VCOPY do not supported on CPU, please check your code")); + } +}; + +template <> +struct CBlas { + template + static void AXPY(ARGS... args) { + detail::axpy(args...); + } + + template + static void VCOPY(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "Blas VCOPY do not supported on CPU with bfloat16," + " please check your code")); + } +}; + +#ifdef PADDLE_WITH_MKLML +template <> +struct CBlas { + template + static void GEMM(ARGS... args) { + paddle::platform::dynload::cblas_sgemm(args...); + } + + template + static float *GEMM_ALLOC(ARGS... args) { + return paddle::platform::dynload::cblas_sgemm_alloc(args...); + } + + template + static void GEMM_PACK(ARGS... args) { + paddle::platform::dynload::cblas_sgemm_pack(args...); + } + + template + static void GEMM_COMPUTE(ARGS... args) { + paddle::platform::dynload::cblas_sgemm_compute(args...); + } + + template + static void GEMM_FREE(ARGS... args) { + paddle::platform::dynload::cblas_sgemm_free(args...); + } + +#ifdef PADDLE_WITH_LIBXSMM + template + static void SMM_GEMM(ARGS... args) { + libxsmm_sgemm(args...); + } +#endif + + template + static void AXPY(ARGS... args) { + paddle::platform::dynload::cblas_saxpy(args...); + } + + template + static void VCOPY(ARGS... args) { + paddle::platform::dynload::cblas_scopy(args...); + } + + template + static void GEMV(ARGS... args) { + paddle::platform::dynload::cblas_sgemv(args...); + } + + template + static float DOT(ARGS... 
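// Aside: CBlas<T> is a compile-time dispatch table -- each specialization
// forwards to the matching typed CBLAS symbol (cblas_s* for float, cblas_d*
// for double, ...) and the type/op combinations without a CPU kernel simply
// throw. The pattern in miniature (names below are illustrative only, and a
// directly linked CBLAS is assumed instead of the dynload layer used here):
#include <cblas.h>
template <typename T> struct MiniCBlas;
template <> struct MiniCBlas<float> {
  static void AXPY(int n, float a, const float *x, float *y) {
    cblas_saxpy(n, a, x, 1, y, 1);  // y += a * x
  }
};
template <> struct MiniCBlas<double> {
  static void AXPY(int n, double a, const double *x, double *y) {
    cblas_daxpy(n, a, x, 1, y, 1);
  }
};
// A templated method such as Blas<CPUContext>::AXPY can then call
// MiniCBlas<T>::AXPY and the right BLAS symbol is selected at compile time.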
args) { + return paddle::platform::dynload::cblas_sdot(args...); + } + + template + static void SCAL(ARGS... args) { + paddle::platform::dynload::cblas_sscal(args...); + } + + template + static float ASUM(ARGS... args) { + return paddle::platform::dynload::cblas_sasum(args...); + } + + template + static void GEMM_BATCH(ARGS... args) { + paddle::platform::dynload::cblas_sgemm_batch(args...); + } + + template + static void VADD(ARGS... args) { + paddle::platform::dynload::vsAdd(args...); + } + + template + static void VSUB(ARGS... args) { + paddle::platform::dynload::vsSub(args...); + } + + template + static void VMUL(ARGS... args) { + paddle::platform::dynload::vsMul(args...); + } + + template + static void VDIV(ARGS... args) { + paddle::platform::dynload::vsDiv(args...); + } + + template + static void VEXP(ARGS... args) { + paddle::platform::dynload::vsExp(args...); + } + + template + static void VSQUARE(ARGS... args) { + paddle::platform::dynload::vsSqr(args...); + } + + template + static void VPOW(ARGS... args) { + paddle::platform::dynload::vsPowx(args...); + } + + template + static void VINV(ARGS... args) { + paddle::platform::dynload::vsInv(args...); + } + + template + static void VMERF(ARGS... args) { + paddle::platform::dynload::vmsErf(args...); + } +#if !defined(_WIN32) + template + static void CSRMM(ARGS... args) { + paddle::platform::dynload::mkl_scsrmm(args...); + } +#endif + + template + static void TRSM(ARGS... args) { + paddle::platform::dynload::cblas_strsm(args...); + } +}; + +template <> +struct CBlas { + template + static void GEMM(ARGS... args) { + paddle::platform::dynload::cblas_dgemm(args...); + } + + template + static double *GEMM_ALLOC(ARGS... args) { + return paddle::platform::dynload::cblas_dgemm_alloc(args...); + } + + template + static void GEMM_PACK(ARGS... args) { + paddle::platform::dynload::cblas_dgemm_pack(args...); + } + + template + static void GEMM_COMPUTE(ARGS... args) { + paddle::platform::dynload::cblas_dgemm_compute(args...); + } + + template + static void GEMM_FREE(ARGS... args) { + paddle::platform::dynload::cblas_dgemm_free(args...); + } + +#ifdef PADDLE_WITH_LIBXSMM + template + static void SMM_GEMM(ARGS... args) { + libxsmm_dgemm(args...); + } +#endif + + template + static void AXPY(ARGS... args) { + paddle::platform::dynload::cblas_daxpy(args...); + } + + template + static void VCOPY(ARGS... args) { + paddle::platform::dynload::cblas_dcopy(args...); + } + + template + static void GEMV(ARGS... args) { + paddle::platform::dynload::cblas_dgemv(args...); + } + + template + static double DOT(ARGS... args) { + return paddle::platform::dynload::cblas_ddot(args...); + } + + template + static void SCAL(ARGS... args) { + paddle::platform::dynload::cblas_dscal(args...); + } + + template + static double ASUM(ARGS... args) { + return paddle::platform::dynload::cblas_dasum(args...); + } + + template + static void GEMM_BATCH(ARGS... args) { + paddle::platform::dynload::cblas_dgemm_batch(args...); + } + + template + static void VADD(ARGS... args) { + paddle::platform::dynload::vdAdd(args...); + } + + template + static void VSUB(ARGS... args) { + paddle::platform::dynload::vdSub(args...); + } + + template + static void VMUL(ARGS... args) { + paddle::platform::dynload::vdMul(args...); + } + + template + static void VDIV(ARGS... args) { + paddle::platform::dynload::vdDiv(args...); + } + + template + static void VEXP(ARGS... args) { + paddle::platform::dynload::vdExp(args...); + } + + template + static void VSQUARE(ARGS... 
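// Aside: when MKL is linked, the element-wise helpers (VADD, VEXP, VMERF, ...)
// map directly onto MKL's vector-math (VML) routines; otherwise plain loops
// are used. Direct VML usage for the float variants looks like this
// (illustrative; the code in this file reaches the same symbols via dynload):
#include <mkl.h>
void VmlSketch(int n, const float *x, const float *y,
               float *sum, float *erf_of_x) {
  vsAdd(n, x, y, sum);             // sum[i] = x[i] + y[i]
  vmsErf(n, x, erf_of_x, VML_HA);  // erf(x[i]) with the high-accuracy mode
}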
args) { + paddle::platform::dynload::vdSqr(args...); + } + + template + static void VPOW(ARGS... args) { + paddle::platform::dynload::vdPowx(args...); + } + + template + static void VINV(ARGS... args) { + paddle::platform::dynload::vdInv(args...); + } + + template + static void VMERF(ARGS... args) { + paddle::platform::dynload::vmdErf(args...); + } +#if !defined(_WIN32) + template + static void CSRMM(ARGS... args) { + paddle::platform::dynload::mkl_dcsrmm(args...); + } +#endif + + template + static void TRSM(ARGS... args) { + paddle::platform::dynload::cblas_dtrsm(args...); + } +}; + +template <> +struct CBlas> { + template + static void AXPY(int n, + const pten::dtype::complex alpha, + const pten::dtype::complex *X, + const int incX, + pten::dtype::complex *Y, + const int incY) { + paddle::platform::dynload::cblas_caxpy(n, &alpha, X, incX, Y, incY); + } + + template + static void VCOPY(ARGS... args) { + paddle::platform::dynload::cblas_ccopy(args...); + } + + // the libmklml_intel.so paddle used has no vcAdd, vcSub, + // vcMul, vcDiv apis before rebuild from source + // so replace with the raw operator methods + /* + template + static void VADD(ARGS... args) { + paddle::platform::dynload::vcAdd(args...); + } + + template + static void VSUB(ARGS... args) { + paddle::platform::dynload::vcSub(args...); + } + + template + static void VMUL(ARGS... args) { + paddle::platform::dynload::vcMul(args...); + } + + template + static void VDIV(ARGS... args) { + paddle::platform::dynload::vcDiv(args...); + } + */ + + template + static void VADD(int n, + const pten::dtype::complex *a, + const pten::dtype::complex *b, + pten::dtype::complex *y) { + for (int i = 0; i < n; ++i) { + y[i] = a[i] + b[i]; + } + } + + template + static void VSUB(int n, + const pten::dtype::complex *a, + const pten::dtype::complex *b, + pten::dtype::complex *y) { + for (int i = 0; i < n; ++i) { + y[i] = a[i] - b[i]; + } + } + + template + static void VMUL(int n, + const pten::dtype::complex *a, + const pten::dtype::complex *b, + pten::dtype::complex *y) { + for (int i = 0; i < n; ++i) { + y[i] = a[i] * b[i]; + } + } + template + static void VDIV(int n, + const pten::dtype::complex *a, + const pten::dtype::complex *b, + pten::dtype::complex *y) { + for (int i = 0; i < n; ++i) { + y[i] = a[i] / b[i]; + } + } + + template + static void GEMV(CBLAS_LAYOUT layout, + CBLAS_TRANSPOSE trans, + int M, + int N, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + int lda, + const pten::dtype::complex *X, + int incx, + pten::dtype::complex beta, + pten::dtype::complex *Y, + int incy) { + const void *a_ = (const void *)(A); + const void *x_ = (const void *)(X); + void *y_ = static_cast(Y); + paddle::platform::dynload::cblas_cgemv( + layout, trans, M, N, &alpha, a_, lda, x_, incx, &beta, y_, incy); + } + + template + static void GEMM(CBLAS_LAYOUT layout, + CBLAS_TRANSPOSE trans_a, + CBLAS_TRANSPOSE trans_b, + int M, + int N, + int K, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + int lda, + const pten::dtype::complex *B, + int ldb, + pten::dtype::complex beta, + pten::dtype::complex *C, + int ldc) { + const void *a_ = (const void *)(A); + const void *b_ = (const void *)(B); + void *c_ = static_cast(C); + paddle::platform::dynload::cblas_cgemm(layout, + trans_a, + trans_b, + M, + N, + K, + &alpha, + a_, + lda, + b_, + ldb, + &beta, + c_, + ldc); + } + + static void TRSM(CBLAS_LAYOUT layout, + CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans_a, + CBLAS_DIAG diag, + int M, + int N, + pten::dtype::complex 
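// Aside: unlike the real-valued routines, the complex CBLAS entry points take
// alpha/beta by pointer and the matrices as void*, which is why the wrappers
// above pass &alpha and cast the pointers before calling cblas_cgemm. The same
// call made directly with std::complex<float>, which is layout-compatible with
// the BLAS complex type:
#include <cblas.h>
#include <complex>
void ComplexGemmSketch(const std::complex<float> *A,  // M x K, row-major
                       const std::complex<float> *B,  // K x N, row-major
                       std::complex<float> *C,        // M x N, row-major
                       int M, int N, int K) {
  const std::complex<float> alpha(1.0f, 0.0f), beta(0.0f, 0.0f);
  cblas_cgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
              M, N, K,
              &alpha,  // passed by address, not by value
              A, K, B, N,
              &beta, C, N);
}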
alpha, + const pten::dtype::complex *A, + int lda, + pten::dtype::complex *B, + int ldb) { + const void *a_ = (const void *)(A); + void *b_ = static_cast(B); + paddle::platform::dynload::cblas_ctrsm( + layout, side, uplo, trans_a, diag, M, N, &alpha, a_, lda, b_, ldb); + } + + template + static void GEMM_BATCH(CBLAS_LAYOUT layout, + CBLAS_TRANSPOSE *trans_a, + CBLAS_TRANSPOSE *trans_b, + int *M, + int *N, + int *K, + pten::dtype::complex *alpha, + const pten::dtype::complex **A, + const int *lda, + const pten::dtype::complex **B, + const int *ldb, + pten::dtype::complex *beta, + pten::dtype::complex **C, + const int *ldc, + int group_count, + int *group_size) { + const void **A_void = (const void **)(&(*A)); + const void **B_void = (const void **)(&(*B)); + void **C_void = reinterpret_cast(C); + + paddle::platform::dynload::cblas_cgemm_batch(layout, + trans_a, + trans_b, + M, + N, + K, + alpha, + A_void, + lda, + B_void, + ldb, + beta, + C_void, + ldc, + group_count, + group_size); + } + + template + static void GEMM_EX(ARGS... args) { + paddle::platform::dynload::cblas_cgemm_batch(args...); + } +}; + +template <> +struct CBlas> { + template + static void AXPY(int n, + const pten::dtype::complex alpha, + const pten::dtype::complex *X, + const int incX, + pten::dtype::complex *Y, + const int incY) { + paddle::platform::dynload::cblas_zaxpy(n, &alpha, X, incX, Y, incY); + } + + template + static void VCOPY(ARGS... args) { + paddle::platform::dynload::cblas_zcopy(args...); + } + + // the libmklml_intel.so paddle used has no vzAdd, vzSub, + // vzMul, vzDiv apis before rebuild from source + // so replace with the raw operator methods + /* + template + static void VADD(ARGS... args) { + paddle::platform::dynload::vzAdd(args...); + } + + template + static void VSUB(ARGS... args) { + paddle::platform::dynload::vzSub(args...); + } + + template + static void VMUL(ARGS... args) { + paddle::platform::dynload::vzMul(args...); + } + + template + static void VDIV(ARGS... 
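// Aside: cblas_*gemm_batch takes arrays of transposes/sizes so that several
// groups of differently shaped GEMMs can be launched in one call; this file
// always passes group_count = 1 with group_size = batchCount, i.e. a single
// homogeneous group. Minimal float sketch (LP64 MKL interface assumed, so
// MKL_INT is int, matching how the code above passes &M, &N, &K):
#include <mkl.h>
void GemmBatchSketch(const float **a,  // batch of M x K, row-major
                     const float **b,  // batch of K x N, row-major
                     float **c,        // batch of M x N, row-major
                     int M, int N, int K, int batch) {
  CBLAS_TRANSPOSE trans = CblasNoTrans;
  float alpha = 1.0f, beta = 0.0f;
  int lda = K, ldb = N, ldc = N;
  cblas_sgemm_batch(CblasRowMajor,
                    &trans, &trans,  // one entry per group
                    &M, &N, &K,
                    &alpha,
                    a, &lda,
                    b, &ldb,
                    &beta,
                    c, &ldc,
                    1 /* group_count */, &batch);
}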
args) { + paddle::platform::dynload::vzDiv(args...); + } + */ + + template + static void VADD(int n, + const pten::dtype::complex *a, + const pten::dtype::complex *b, + pten::dtype::complex *y) { + for (int i = 0; i < n; ++i) { + y[i] = a[i] + b[i]; + } + } + + template + static void VSUB(int n, + const pten::dtype::complex *a, + const pten::dtype::complex *b, + pten::dtype::complex *y) { + for (int i = 0; i < n; ++i) { + y[i] = a[i] - b[i]; + } + } + + template + static void VMUL(int n, + const pten::dtype::complex *a, + const pten::dtype::complex *b, + pten::dtype::complex *y) { + for (int i = 0; i < n; ++i) { + y[i] = a[i] * b[i]; + } + } + template + static void VDIV(int n, + const pten::dtype::complex *a, + const pten::dtype::complex *b, + pten::dtype::complex *y) { + for (int i = 0; i < n; ++i) { + y[i] = a[i] / b[i]; + } + } + + template + static void GEMV(CBLAS_LAYOUT layout, + CBLAS_TRANSPOSE trans, + int M, + int N, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + int lda, + const pten::dtype::complex *X, + int incx, + pten::dtype::complex beta, + pten::dtype::complex *Y, + int incy) { + const void *a_ = (const void *)(A); + const void *x_ = (const void *)(X); + void *y_ = static_cast(Y); + paddle::platform::dynload::cblas_zgemv( + layout, trans, M, N, &alpha, a_, lda, x_, incx, &beta, y_, incy); + } + + template + static void GEMM(CBLAS_LAYOUT layout, + CBLAS_TRANSPOSE trans_a, + CBLAS_TRANSPOSE trans_b, + int M, + int N, + int K, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + int lda, + const pten::dtype::complex *B, + int ldb, + pten::dtype::complex beta, + pten::dtype::complex *C, + int ldc) { + const void *a_ = (const void *)(A); + const void *b_ = (const void *)(B); + void *c_ = static_cast(C); + paddle::platform::dynload::cblas_zgemm(layout, + trans_a, + trans_b, + M, + N, + K, + &alpha, + a_, + lda, + b_, + ldb, + &beta, + c_, + ldc); + } + + static void TRSM(CBLAS_LAYOUT layout, + CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans_a, + CBLAS_DIAG diag, + int M, + int N, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + int lda, + pten::dtype::complex *B, + int ldb) { + const void *a_ = (const void *)(A); + void *b_ = static_cast(B); + paddle::platform::dynload::cblas_ztrsm( + layout, side, uplo, trans_a, diag, M, N, &alpha, a_, lda, b_, ldb); + } + + template + static void GEMM_BATCH(CBLAS_LAYOUT layout, + CBLAS_TRANSPOSE *trans_a, + CBLAS_TRANSPOSE *trans_b, + int *M, + int *N, + int *K, + pten::dtype::complex *alpha, + const pten::dtype::complex **A, + const int *lda, + const pten::dtype::complex **B, + const int *ldb, + pten::dtype::complex *beta, + pten::dtype::complex **C, + const int *ldc, + int group_count, + int *group_size) { + const void **A_void = (const void **)(&(*A)); + const void **B_void = (const void **)(&(*B)); + void **C_void = reinterpret_cast(C); + + paddle::platform::dynload::cblas_zgemm_batch(layout, + trans_a, + trans_b, + M, + N, + K, + alpha, + A_void, + lda, + B_void, + ldb, + beta, + C_void, + ldc, + group_count, + group_size); + } + + template + static void GEMM_EX(ARGS... args) { + paddle::platform::dynload::cblas_zgemm_batch(args...); + } +}; + +#else + +template <> +struct CBlas { + template + static void GEMM(ARGS... args) { + cblas_sgemm(args...); + } + + template + static void AXPY(ARGS... args) { + cblas_saxpy(args...); + } + + template + static void VCOPY(ARGS... args) { + cblas_scopy(args...); + } + + template + static void GEMV(ARGS... 
args) { + cblas_sgemv(args...); + } + + template + static void TRSM(ARGS... args) { + cblas_strsm(args...); + } +}; + +template <> +struct CBlas { + template + static void GEMM(ARGS... args) { + cblas_dgemm(args...); + } + + template + static void AXPY(ARGS... args) { + cblas_daxpy(args...); + } + + template + static void VCOPY(ARGS... args) { + cblas_dcopy(args...); + } + + template + static void GEMV(ARGS... args) { + cblas_dgemv(args...); + } + + template + static void TRSM(ARGS... args) { + cblas_dtrsm(args...); + } +}; + +template <> +struct CBlas> { + template + static void VCOPY(ARGS... args) { + cblas_ccopy(args...); + } + + template + static void AXPY(int n, + const pten::dtype::complex alpha, + const pten::dtype::complex *X, + const int incX, + pten::dtype::complex *Y, + const int incY) { + cblas_caxpy(n, &alpha, X, incX, Y, incY); + } + + template + static void GEMV(const CBLAS_LAYOUT layout, + const CBLAS_TRANSPOSE TransA, + const int M, + const int N, + const pten::dtype::complex alpha, + const pten::dtype::complex *A, + const int lda, + const pten::dtype::complex *X, + const int incX, + const pten::dtype::complex beta, + pten::dtype::complex *Y, + const int incY) { + cblas_cgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); + } + + template + static void GEMM(const CBLAS_LAYOUT layout, + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int M, + const int N, + const int K, + const pten::dtype::complex alpha, + const pten::dtype::complex *A, + const int lda, + const pten::dtype::complex *B, + const int ldb, + const pten::dtype::complex beta, + pten::dtype::complex *C, + const int ldc) { + cblas_cgemm( + layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); + } + + static void TRSM(const CBLAS_LAYOUT layout, + const CBLAS_SIDE side, + const CBLAS_UPLO uplo, + const CBLAS_TRANSPOSE transA, + const CBLAS_DIAG diag, + const int M, + const int N, + const pten::dtype::complex alpha, + const pten::dtype::complex *A, + const int lda, + pten::dtype::complex *B, + const int ldb) { + cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); + } +}; + +template <> +struct CBlas> { + template + static void VCOPY(ARGS... 
args) { + cblas_zcopy(args...); + } + + template + static void AXPY(int n, + const pten::dtype::complex alpha, + const pten::dtype::complex *X, + const int incX, + pten::dtype::complex *Y, + const int incY) { + cblas_zaxpy(n, &alpha, X, incX, Y, incY); + } + + template + static void GEMV(const CBLAS_LAYOUT layout, + const CBLAS_TRANSPOSE TransA, + const int M, + const int N, + const pten::dtype::complex alpha, + const pten::dtype::complex *A, + const int lda, + const pten::dtype::complex *X, + const int incX, + const pten::dtype::complex beta, + pten::dtype::complex *Y, + const int incY) { + cblas_zgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); + } + + template + static void GEMM(const CBLAS_LAYOUT layout, + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int M, + const int N, + const int K, + const pten::dtype::complex alpha, + const pten::dtype::complex *A, + const int lda, + const pten::dtype::complex *B, + const int ldb, + const pten::dtype::complex beta, + pten::dtype::complex *C, + const int ldc) { + cblas_zgemm( + layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); + } + + static void TRSM(const CBLAS_LAYOUT layout, + const CBLAS_SIDE side, + const CBLAS_UPLO uplo, + const CBLAS_TRANSPOSE transA, + const CBLAS_DIAG diag, + const int M, + const int N, + const pten::dtype::complex alpha, + const pten::dtype::complex *A, + const int lda, + pten::dtype::complex *B, + const int ldb) { + cblas_ztrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); + } +}; + +#endif + +template <> +struct CBlas { + static void GEMM(...) { + PADDLE_THROW(pten::errors::Unimplemented( + "float16 GEMM not supported on CPU, please check your code")); + } + + static void SMM_GEMM(...) { + PADDLE_THROW(pten::errors::Unimplemented( + "float16 SMM_GEMM not supported on CPU, please check your code")); + } + static void VMUL(...) { + PADDLE_THROW(pten::errors::Unimplemented( + "float16 VMUL not supported on CPU, please check your code")); + } + static void VEXP(...) { + PADDLE_THROW(pten::errors::Unimplemented( + "float16 VEXP not supported on CPU, please check your code")); + } + static void VSQUARE(...) { + PADDLE_THROW(pten::errors::Unimplemented( + "float16 VSQUARE not supported on CPU, please check your code")); + } + static void VPOW(...) { + PADDLE_THROW(pten::errors::Unimplemented( + "float16 VPOW not supported on CPU, please check your code")); + } + static void DOT(...) { + PADDLE_THROW(pten::errors::Unimplemented( + "float16 DOT not supported on CPU, please check your code")); + }; + static void SCAL(...) { + PADDLE_THROW(pten::errors::Unimplemented( + "float16 SCAL not supported on CPU, please check your code")); + }; + static void ASUM(...) { + PADDLE_THROW(pten::errors::Unimplemented( + "float16 ASUM not supported on CPU, please check your code")); + }; +#ifdef PADDLE_WITH_MKLML + static void GEMM_BATCH(...) 
{ + PADDLE_THROW(pten::errors::Unimplemented( + "float16 GEMM_BATCH not supported on CPU, please check your code")); + } +#endif +}; + +#ifdef PADDLE_WITH_MKLML +template <> +template +T *Blas::GEMM_ALLOC( + const CBLAS_IDENTIFIER id, const int M, const int N, const int K) const { + return CBlas::GEMM_ALLOC(id, M, N, K); +} +template <> +template +T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, + const int M, + const int N, + const int K) const { + return CBlas::GEMM_ALLOC(id, M, N, K); +} + +template <> +template +void Blas::GEMM_PACK( + const CBLAS_IDENTIFIER id, + const CBLAS_TRANSPOSE trans, + int M, + int N, + int K, + const T alpha, + const T *src, + const int ld, + T *dst) const { + CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); +} +template <> +template +void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, + const CBLAS_TRANSPOSE trans, + int M, + int N, + int K, + const T alpha, + const T *src, + const int ld, + T *dst) const { + CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); +} + +template <> +template +void Blas::GEMM_COMPUTE( + int transA, + int transB, + int M, + int N, + int K, + const T *A, + const int lda, + const T *B, + const int ldb, + T beta, + T *C, + const int ldc) const { + CBlas::GEMM_COMPUTE( + CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, beta, C, ldc); +} +template <> +template +void Blas::GEMM_COMPUTE(int transA, + int transB, + int M, + int N, + int K, + const T *A, + const int lda, + const T *B, + const int ldb, + T beta, + T *C, + const int ldc) const { + CBlas::GEMM_COMPUTE( + CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, beta, C, ldc); +} + +template <> +template +void Blas::GEMM_FREE(T *data) const { + CBlas::GEMM_FREE(data); +} +template <> +template +void Blas::GEMM_FREE(T *data) const { + CBlas::GEMM_FREE(data); +} +#endif + +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + CBlas::GEMM(CblasRowMajor, + transA, + transB, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +} +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + CBlas::GEMM(CblasRowMajor, + transA, + transB, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +} + +template <> +template +void Blas::GEMM(bool transA, + bool transB, + int M, + int N, + int K, + T alpha, + const T *A, + int lda, + const T *B, + int ldb, + T beta, + T *C, + int ldc) const { + CBlas::GEMM(CblasRowMajor, + transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +} +template <> +template +void Blas::GEMM(bool transA, + bool transB, + int M, + int N, + int K, + T alpha, + const T *A, + int lda, + const T *B, + int ldb, + T beta, + T *C, + int ldc) const { + CBlas::GEMM(CblasRowMajor, + transA == false ? CblasNoTrans : CblasTrans, + transB == false ? 
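// Aside: GEMM_ALLOC / GEMM_PACK / GEMM_COMPUTE / GEMM_FREE expose MKL's packed
// GEMM API. It pays off when one operand (typically a weight matrix) is reused
// across many multiplications: the operand is repacked once into MKL's
// internal layout and each subsequent GEMM skips that work. Sketch with the
// raw MKL entry points (LP64 interface assumed; the wrappers here reach the
// same symbols via dynload):
#include <mkl.h>
void PackedGemmSketch(const float *A,  // M x K, row-major, changes every call
                      const float *W,  // K x N, row-major, reused weight
                      float *C,        // M x N, row-major
                      int M, int N, int K, int iters) {
  float *packed_w = cblas_sgemm_alloc(CblasBMatrix, M, N, K);
  cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans,
                   M, N, K, 1.0f, W, N, packed_w);
  for (int i = 0; i < iters; ++i) {
    // CblasPacked in the B slot tells MKL this operand is already packed.
    cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked,
                        M, N, K, A, K, packed_w, N, 0.0f, C, N);
  }
  cblas_sgemm_free(packed_w);
}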
CblasNoTrans : CblasTrans, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +} + +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + int lda, + const T *B, + int ldb, + T beta, + T *C, + int ldc) const { + CBlas::GEMM(CblasRowMajor, + transA, + transB, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +} +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + int lda, + const T *B, + int ldb, + T beta, + T *C, + int ldc) const { + CBlas::GEMM(CblasRowMajor, + transA, + transB, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +} + +template +template +void Blas::MatMul(const pten::DenseTensor &mat_a, + bool trans_a, + const pten::DenseTensor &mat_b, + bool trans_b, + T alpha, + pten::DenseTensor *mat_out, + T beta) const { + auto dim_a = mat_a.dims(); + auto dim_b = mat_b.dims(); + auto dim_out = mat_out->dims(); + PADDLE_ENFORCE_EQ( + dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + true, + pten::errors::InvalidArgument( + "The input and output of matmul should be matrix, the dim size must " + "be 2," + "but received dim size input_a:%d, input_b:%d, output:%d", + dim_a.size(), + dim_b.size(), + dim_out.size())); + PADDLE_ENFORCE_EQ( + mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), + true, + pten::errors::InvalidArgument("The places of matrices in the matmul " + "should be same, please check your " + "code.")); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = !trans_a ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = !trans_b ? 
CblasNoTrans : CblasTrans; + + this->GEMM(transA, + transB, + M, + N, + K, + alpha, + mat_a.data(), + mat_b.data(), + beta, + mat_out->data()); +} + +template <> +template +void Blas::AXPY(int n, + T alpha, + const T *x, + T *y) const { + CBlas::AXPY(n, alpha, x, 1, y, 1); +} +template <> +template +void Blas::AXPY(int n, T alpha, const T *x, T *y) const { + CBlas::AXPY(n, alpha, x, 1, y, 1); +} + +template <> +template +void Blas::VCOPY(int n, + const T *x, + T *y) const { + CBlas::VCOPY(n, x, 1, y, 1); +} +template <> +template +void Blas::VCOPY(int n, const T *x, T *y) const { + CBlas::VCOPY(n, x, 1, y, 1); +} + +template <> +template +void Blas::VADD(int n, + const T *x, + const T *y, + T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VADD(n, x, y, z); +#else + if (x == z) { + this->template AXPY(n, (T)(1.), y, z); + } else { + this->template VCOPY(n, y, z); + this->template AXPY(n, (T)(1.), x, z); + } +#endif +} +template <> +template +void Blas::VADD(int n, const T *x, const T *y, T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VADD(n, x, y, z); +#else + if (x == z) { + this->template AXPY(n, (T)(1.), y, z); + } else { + this->template VCOPY(n, y, z); + this->template AXPY(n, (T)(1.), x, z); + } +#endif +} + +template <> +template +void Blas::VSUB(int n, + const T *x, + const T *y, + T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VSUB(n, x, y, z); +#else + // try to find if openblas support vsub + for (int i = 0; i < n; ++i) { + z[i] = x[i] - y[i]; + } +#endif +} +template <> +template +void Blas::VSUB(int n, const T *x, const T *y, T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VSUB(n, x, y, z); +#else + // try to find if openblas support vsub + for (int i = 0; i < n; ++i) { + z[i] = x[i] - y[i]; + } +#endif +} + +template <> +template +void Blas::VMUL(int n, + const T *x, + const T *y, + T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VMUL(n, x, y, z); +#else + // try to find if openblas support vmul + for (int i = 0; i < n; ++i) { + z[i] = x[i] * y[i]; + } +#endif +} +template <> +template +void Blas::VMUL(int n, const T *x, const T *y, T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VMUL(n, x, y, z); +#else + // try to find if openblas support vmul + for (int i = 0; i < n; ++i) { + z[i] = x[i] * y[i]; + } +#endif +} + +template <> +template +void Blas::VDIV(int n, + const T *x, + const T *y, + T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VDIV(n, x, y, z); +#else + // try to find if openblas support vdiv + for (int i = 0; i < n; ++i) { + z[i] = x[i] / y[i]; + } +#endif +} +template <> +template +void Blas::VDIV(int n, const T *x, const T *y, T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VDIV(n, x, y, z); +#else + // try to find if openblas support vdiv + for (int i = 0; i < n; ++i) { + z[i] = x[i] / y[i]; + } +#endif +} + +template <> +template +void Blas::VEXP(int n, + const T *x, + T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VEXP(n, x, y); +#else + // try to find if openblas support vexp + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } +#endif +} +template <> +template +void Blas::VEXP(int n, const T *x, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VEXP(n, x, y); +#else + // try to find if openblas support vexp + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } +#endif +} + +template <> +template +void Blas::VSQUARE(int n, + const T *x, + T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VSQUARE(n, x, y); +#else + for (int i = 0; i < n; ++i) { + y[i] = x[i] * x[i]; + } +#endif +} +template <> +template +void Blas::VSQUARE(int n, const 
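// Aside: without MKL's vsAdd/vdAdd, VADD above is assembled from two level-1
// calls -- copy y into z, then z += 1 * x via AXPY. When the output aliases x
// (x == z), the copy would clobber x first, so that branch does z += 1 * y
// directly instead. The same logic in miniature with a directly linked CBLAS:
#include <cblas.h>
void VaddSketch(int n, const float *x, const float *y, float *z) {
  if (x == z) {
    cblas_saxpy(n, 1.0f, y, 1, z, 1);  // z (== x) += y
  } else {
    cblas_scopy(n, y, 1, z, 1);        // z = y
    cblas_saxpy(n, 1.0f, x, 1, z, 1);  // z += x
  }
}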
T *x, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VSQUARE(n, x, y); +#else + for (int i = 0; i < n; ++i) { + y[i] = x[i] * x[i]; + } +#endif +} + +template <> +template +void Blas::VPOW(int n, + const T *x, + T a, + T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VPOW(n, x, a, y); +#else + for (int i = 0; i < n; ++i) { + y[i] = std::pow(x[i], a); + } +#endif +} +template <> +template +void Blas::VPOW(int n, const T *x, T a, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VPOW(n, x, a, y); +#else + for (int i = 0; i < n; ++i) { + y[i] = std::pow(x[i], a); + } +#endif +} + +template <> +template +T Blas::DOT(int n, + const T *x, + const T *y) const { +#ifdef PADDLE_WITH_MKLML + return CBlas::DOT(n, x, 1, y, 1); +#else + // try to find if openblas support cblas_dot + T sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] * y[i]; + } + return sum; +#endif +} +template <> +template +T Blas::DOT(int n, const T *x, const T *y) const { +#ifdef PADDLE_WITH_MKLML + return CBlas::DOT(n, x, 1, y, 1); +#else + // try to find if openblas support cblas_dot + T sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] * y[i]; + } + return sum; +#endif +} + +template <> +template +void Blas::SCAL(int n, + const T a, + T *x) const { +#ifdef PADDLE_WITH_MKLML + CBlas::SCAL(n, a, x, 1); +#else + // try to find if openblas support cblas_scal + for (int i = 0; i < n; ++i) { + x[i] = a * x[i]; + } +#endif +} +template <> +template +void Blas::SCAL(int n, const T a, T *x) const { +#ifdef PADDLE_WITH_MKLML + CBlas::SCAL(n, a, x, 1); +#else + // try to find if openblas support cblas_scal + for (int i = 0; i < n; ++i) { + x[i] = a * x[i]; + } +#endif +} + +template <> +template +T Blas::ASUM(int n, T *x, int inc) const { + auto sum = static_cast(0.0); +#ifdef PADDLE_WITH_MKLML + sum = CBlas::ASUM(n, x, inc); +#else + // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum + for (int c = 0; c < n; ++c) { + sum += x[c]; + } +#endif + return sum; +} +template <> +template +T Blas::ASUM(int n, T *x, int inc) const { + auto sum = static_cast(0.0); +#ifdef PADDLE_WITH_MKLML + sum = CBlas::ASUM(n, x, inc); +#else + // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum + for (int c = 0; c < n; ++c) { + sum += x[c]; + } +#endif + return sum; +} + +template <> +template +void Blas::GEMV(bool trans_a, + int M, + int N, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; + CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); +} +template <> +template +void Blas::GEMV(bool trans_a, + int M, + int N, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; + CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); +} + +template <> +template +void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + PADDLE_ENFORCE_NOT_NULL( + A, pten::errors::InvalidArgument("Pointer A should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + B, pten::errors::InvalidArgument("Pointer B should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + C, pten::errors::InvalidArgument("Pointer C should not be null.")); +#ifdef PADDLE_WITH_MKLML + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? 
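// Aside: GEMV above fixes both increments to 1 and the leading dimension to N
// (A is an M x N row-major matrix), so y = op(A) * x reduces to a single
// cblas_?gemv call. Direct float equivalent:
#include <cblas.h>
void GemvSketch(bool trans_a, int M, int N,
                const float *A, const float *x, float *y) {
  cblas_sgemv(CblasRowMajor,
              trans_a ? CblasTrans : CblasNoTrans,
              M, N,
              1.0f, A, N,  // lda is the row length N in row-major storage
              x, 1,
              0.0f, y, 1);
}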
N : K; + int ldc = N; + auto a_array = std::vector(batchCount); + auto b_array = std::vector(batchCount); + auto c_array = std::vector(batchCount); + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA]; + b_array[k] = &B[k * strideB]; + c_array[k] = &C[k * M * N]; + } + + CBlas::GEMM_BATCH(CblasRowMajor, + &transA, + &transB, + &M, + &N, + &K, + &alpha, + a_array.data(), + &lda, + b_array.data(), + &ldb, + &beta, + c_array.data(), + &ldc, + 1 /* group_count */, + &batchCount); +#else + for (int k = 0; k < batchCount; ++k) { + auto *Ak = &A[k * strideA]; + auto *Bk = &B[k * strideB]; + auto *Ck = &C[k * M * N]; + this->template GEMM(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck); + } +#endif +} +template <> +template +void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + PADDLE_ENFORCE_NOT_NULL( + A, pten::errors::InvalidArgument("Pointer A should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + B, pten::errors::InvalidArgument("Pointer B should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + C, pten::errors::InvalidArgument("Pointer C should not be null.")); +#ifdef PADDLE_WITH_MKLML + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + auto a_array = std::vector(batchCount); + auto b_array = std::vector(batchCount); + auto c_array = std::vector(batchCount); + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA]; + b_array[k] = &B[k * strideB]; + c_array[k] = &C[k * M * N]; + } + + CBlas::GEMM_BATCH(CblasRowMajor, + &transA, + &transB, + &M, + &N, + &K, + &alpha, + a_array.data(), + &lda, + b_array.data(), + &ldb, + &beta, + c_array.data(), + &ldc, + 1 /* group_count */, + &batchCount); +#else + for (int k = 0; k < batchCount; ++k) { + auto *Ak = &A[k * strideA]; + auto *Bk = &B[k * strideB]; + auto *Ck = &C[k * M * N]; + this->template GEMM(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck); + } +#endif +} + +template <> +template +void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T **A, + const T **B, + T beta, + T **C, + int batchCount) const { +#ifdef PADDLE_WITH_MKLML + const int lda = (std::max)((transA == CblasNoTrans) ? K : M, 1); + const int ldb = (std::max)((transB == CblasNoTrans) ? N : K, 1); + const int ldc = (std::max)(N, 1); + CBlas::GEMM_BATCH(CblasRowMajor, + &transA, + &transB, + &M, + &N, + &K, + &alpha, + A, + &lda, + B, + &ldb, + &beta, + C, + &ldc, + 1 /* group_count */, + &batchCount); +#else + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +#endif +} +template <> +template +void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T **A, + const T **B, + T beta, + T **C, + int batchCount) const { +#ifdef PADDLE_WITH_MKLML + const int lda = (std::max)((transA == CblasNoTrans) ? K : M, 1); + const int ldb = (std::max)((transB == CblasNoTrans) ? 
N : K, 1); + const int ldc = (std::max)(N, 1); + CBlas::GEMM_BATCH(CblasRowMajor, + &transA, + &transB, + &M, + &N, + &K, + &alpha, + A, + &lda, + B, + &ldb, + &beta, + C, + &ldc, + 1 /* group_count */, + &batchCount); +#else + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +#endif +} + +#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ + !defined(PADDLE_WITH_HIP) // @{ Group Blas MKLML: BatchedGEMMWithHead +template <> +template +void Blas::BatchedGEMMWithHead( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int W1, + int H1, + int W2, + int H2, + T alpha, + const T *A, + const T *B, + T beta, + T *C, + int batchCount, + int64_t strideA, + int64_t strideB, + int64_t head_number, + bool split_b_vertical) const { + int lda = (transA == CblasNoTrans) ? W1 : H1; + int ldb = (transB == CblasNoTrans) ? W2 : H2; + auto a_array = std::vector(batchCount); + auto b_array = std::vector(batchCount); + auto c_array = std::vector(batchCount); + + if (split_b_vertical) { + int ldc = W2; + int sub_width = W2 / head_number; + + for (int i = 0; i < head_number; i++) { + int sub_matA_offset = (transA == CblasNoTrans) + ? i * (W1 / head_number) + : i * (W1 / head_number) * H1; + int sub_matB_offset = (transB == CblasNoTrans) + ? i * (W2 / head_number) + : i * (W2 / head_number) * H2; + int sub_matC_offset = i * W2 / head_number; + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA] + sub_matA_offset; + b_array[k] = &B[k * strideB] + sub_matB_offset; + c_array[k] = &C[k * H1 * W2] + sub_matC_offset; + } + + CBlas::GEMM_BATCH(CblasRowMajor, + &transA, + &transB, + &H1, + &sub_width, + &H2, + &alpha, + a_array.data(), + &lda, + b_array.data(), + &ldb, + &beta, + c_array.data(), + &ldc, + 1 /* group_count */, + &batchCount); + } + + } else { + PADDLE_ENFORCE_EQ( + W1, + H2, + pten::errors::InvalidArgument( + "The fisrt matrix width should be same as second matrix height," + "but received fisrt matrix width %d" + ", second matrix height %d", + W1, + H2)); + int ldc = W2 * head_number; + int sub_width = W1 / head_number; + + for (int i = 0; i < head_number; i++) { + int sub_matA_offset = (transA == CblasNoTrans) + ? i * (W1 / head_number) + : i * (W1 / head_number) * H1; + int sub_matB_offset = (transB == CblasNoTrans) + ? i * (W1 / head_number) * W2 + : i * (W1 / head_number); + int sub_matC_offset = i * W2; + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA] + sub_matA_offset; + b_array[k] = &B[k * strideB] + sub_matB_offset; + c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset; + } + + CBlas::GEMM_BATCH(CblasRowMajor, + &transA, + &transB, + &H1, + &W2, + &sub_width, + &alpha, + a_array.data(), + &lda, + b_array.data(), + &ldb, + &beta, + c_array.data(), + &ldc, + 1 /* group_count */, + &batchCount); + } + } +} +template <> +template +void Blas::BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int W1, + int H1, + int W2, + int H2, + T alpha, + const T *A, + const T *B, + T beta, + T *C, + int batchCount, + int64_t strideA, + int64_t strideB, + int64_t head_number, + bool split_b_vertical) const { + int lda = (transA == CblasNoTrans) ? W1 : H1; + int ldb = (transB == CblasNoTrans) ? 
W2 : H2; + auto a_array = std::vector(batchCount); + auto b_array = std::vector(batchCount); + auto c_array = std::vector(batchCount); + + if (split_b_vertical) { + int ldc = W2; + int sub_width = W2 / head_number; + + for (int i = 0; i < head_number; i++) { + int sub_matA_offset = (transA == CblasNoTrans) + ? i * (W1 / head_number) + : i * (W1 / head_number) * H1; + int sub_matB_offset = (transB == CblasNoTrans) + ? i * (W2 / head_number) + : i * (W2 / head_number) * H2; + int sub_matC_offset = i * W2 / head_number; + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA] + sub_matA_offset; + b_array[k] = &B[k * strideB] + sub_matB_offset; + c_array[k] = &C[k * H1 * W2] + sub_matC_offset; + } + + CBlas::GEMM_BATCH(CblasRowMajor, + &transA, + &transB, + &H1, + &sub_width, + &H2, + &alpha, + a_array.data(), + &lda, + b_array.data(), + &ldb, + &beta, + c_array.data(), + &ldc, + 1 /* group_count */, + &batchCount); + } + + } else { + PADDLE_ENFORCE_EQ( + W1, + H2, + pten::errors::InvalidArgument( + "The fisrt matrix width should be same as second matrix height," + "but received fisrt matrix width %d" + ", second matrix height %d", + W1, + H2)); + int ldc = W2 * head_number; + int sub_width = W1 / head_number; + + for (int i = 0; i < head_number; i++) { + int sub_matA_offset = (transA == CblasNoTrans) + ? i * (W1 / head_number) + : i * (W1 / head_number) * H1; + int sub_matB_offset = (transB == CblasNoTrans) + ? i * (W1 / head_number) * W2 + : i * (W1 / head_number); + int sub_matC_offset = i * W2; + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA] + sub_matA_offset; + b_array[k] = &B[k * strideB] + sub_matB_offset; + c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset; + } + + CBlas::GEMM_BATCH(CblasRowMajor, + &transA, + &transB, + &H1, + &W2, + &sub_width, + &alpha, + a_array.data(), + &lda, + b_array.data(), + &ldb, + &beta, + c_array.data(), + &ldc, + 1 /* group_count */, + &batchCount); + } + } +} +#endif // @} End Group Blas MKLML: BatchedGEMMWithHead + +template +template +void Blas::MatMul( + const int M, const int N, const int K, const T *A, const T *B, T *C) const { + this->template GEMM(CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + M, + N, + K, + static_cast(1), + A, + K, + B, + N, + static_cast(0), + C, + N); +} + +template <> +template +void Blas::MatMul( + const int M, const int N, const int K, const T *A, const T *B, T *C) const { +#ifdef PADDLE_WITH_LIBXSMM + // Refer to https://github.com/hfp/libxsmm/blob/master/README.md + // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; + + // Since the matrix is very small, + // so the unit of calculation is already very fast, + // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, + // use xsmm directly. 
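// Aside: as the "SMM use ColMajor" note below says, libxsmm_sgemm works on
// column-major data like standard BLAS, so the row-major product is issued
// with the same operand swap used in the GPU paths: pass B where A would go
// and exchange the M/N extents. Stand-alone sketch (LIBXSMM linked directly;
// illustrative only):
#include <libxsmm.h>
void SmallGemmSketch(const float *A,  // M x K, row-major
                     const float *B,  // K x N, row-major
                     float *C,        // M x N, row-major
                     int M, int N, int K) {
  const char no_trans = 'N';
  const float alpha = 1.0f, beta = 0.0f;
  const libxsmm_blasint m = N, n = M, k = K;  // swapped extents
  libxsmm_sgemm(&no_trans, &no_trans, &m, &n, &k,
                &alpha, B, &m, A, &k,
                &beta, C, &m);
}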
+ // Note: SMM use ColMajor + const char transa = 'N'; + const char transb = 'N'; + const T alpha = static_cast(1); + const T beta = static_cast(0); + CBlas::SMM_GEMM( + &transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, C, &N); + return; +#endif + + CBlas::GEMM(CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + M, + N, + K, + static_cast(1), + A, + K, + B, + N, + static_cast(0), + C, + N); +} +template <> +template +void Blas::MatMul( + const int M, const int N, const int K, const T *A, const T *B, T *C) const { +#ifdef PADDLE_WITH_LIBXSMM + // Refer to https://github.com/hfp/libxsmm/blob/master/README.md + // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; + + // Since the matrix is very small, + // so the unit of calculation is already very fast, + // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, + // use xsmm directly. + // Note: SMM use ColMajor + const char transa = 'N'; + const char transb = 'N'; + const T alpha = static_cast(1); + const T beta = static_cast(0); + CBlas::SMM_GEMM( + &transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, C, &N); + return; +#endif + + CBlas::GEMM(CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + M, + N, + K, + static_cast(1), + A, + K, + B, + N, + static_cast(0), + C, + N); +} + +template +template +void Blas::MatMul(const pten::DenseTensor &mat_a, + const MatDescriptor &dim_a, + const pten::DenseTensor &mat_b, + const MatDescriptor &dim_b, + T alpha, + pten::DenseTensor *mat_out, + T beta) const { + MatMul(mat_a.data(), + dim_a, + mat_b.data(), + dim_b, + alpha, + mat_out->data(), + beta); +} + +template +template +void Blas::MatMul(const T *mat_a, + const MatDescriptor &dim_a, + const T *mat_b, + const MatDescriptor &dim_b, + T alpha, + T *mat_out, + T beta) const { + PADDLE_ENFORCE_EQ( + dim_a.width_, + dim_b.height_, + pten::errors::InvalidArgument( + "The fisrt matrix width should be same as second matrix height," + "but received fisrt matrix width %d" + ", second matrix height %d", + dim_a.width_, + dim_b.height_)); + + CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; + if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { + this->template GEMM(transA, + transB, + dim_a.height_, + dim_b.width_, + dim_a.width_, + alpha, + mat_a, + mat_b, + beta, + mat_out); + } else { + PADDLE_ENFORCE_EQ( + dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || + dim_b.batch_size_ == 0, + true, + pten::errors::InvalidArgument( + "dim_a.batch_size should be equal to dim_b.batch_size, or " + "one of dim_a.batch_size and dim_b.batch_size should be 0. " + "But got dim_a.batch_size = %d, dim_b.batch_size = %d.", + dim_a.batch_size_, + dim_b.batch_size_)); + this->template BatchedGEMM( + transA, + transB, + dim_a.height_, + dim_b.width_, + dim_a.width_, + alpha, + mat_a, + mat_b, + beta, + mat_out, + dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, + dim_a.stride_, + dim_b.stride_); + } +} + +#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ + !defined(PADDLE_WITH_HIP) +// @{ Group Blas MKLML: MatMulWithHead +/* + * Multiple two matrixes with multiple heads + * + * A new parameter, i.e head_number is added compared to normal MatMul. + * The head_number describes the number of heads a matrix is vertically + * split. + * + * When user calls this API, the multiplication of two big matrixes is split + * into multiplication of several (head_number_) small matrixes. e.g. 
if Mat A + * is [3, 24] and Mat B is [24, 4], when multiple A and B with head_number as + * 4, Mat A will be split as 4 matrix of [3, 6] and Mat B will be + * (horizontally) split as 4 matrix of [6, 4]. The result of final matrix + * will be 4 matrix of [3, 4], i.e. [3, 16]. + * Another example is A is [3, 8], B is [2, 16], head_number is 4. In this + * case, A will be split as [3, 2], B will be (vertically) split as + * [2, 4]. The final result will be 4 matrix of 4 matrix of [3,4], i.e. [3, 16] + */ +template +template +void Blas::MatMulWithHead(const pten::DenseTensor &mat_a, + const MatDescriptor &dim_a, + const pten::DenseTensor &mat_b, + const MatDescriptor &dim_b, + T alpha, + int head_number, + pten::DenseTensor *mat_out, + T beta, + bool mat_b_split_vertical) const { + PADDLE_ENFORCE_EQ( + dim_a.width_ % head_number, + 0, + pten::errors::InvalidArgument( + "The first input width must be some times the head number" + "but received first input width %d" + ", head_number %d", + dim_a.width_, + head_number)); + PADDLE_ENFORCE_GE( + head_number, + 1, + pten::errors::InvalidArgument("The head number should be greater equal 1," + "but received head number %d", + head_number)); + PADDLE_ENFORCE_LE( + head_number, + dim_a.width_, + pten::errors::InvalidArgument( + "The head number should be less equal first input width," + "but received first input width %d" + ", head_number %d", + dim_a.width_, + head_number)); + CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; + + if (mat_b_split_vertical) { + PADDLE_ENFORCE_EQ( + dim_b.height_, + dim_a.width_ / head_number, + pten::errors::InvalidArgument( + "The second input height should be equal than first input width," + "but received second input height %d, first input width %d", + dim_b.height_, + dim_a.width_ / head_number)); + PADDLE_ENFORCE_EQ( + dim_a.width_ % head_number, + 0, + pten::errors::InvalidArgument( + "The second input width should be some times the head number" + "but received second input width %d" + ", head_number %d", + dim_b.width_, + head_number)); + } + + if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { + int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_; + int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_; + int sub_matA_offset; + int sub_matB_offset; + int sub_matC_offset; + int sub_mat_M = dim_a.height_; + int sub_mat_N; + int sub_mat_K; + int ldc; + + for (int i = 0; i < head_number; i++) { + sub_matA_offset = dim_a.trans_ + ? i * (dim_a.width_ / head_number) * dim_a.height_ + : i * (dim_a.width_ / head_number); + if (mat_b_split_vertical) { + sub_matB_offset = dim_b.trans_ + ? i * (dim_b.width_ / head_number) * dim_b.height_ + : i * (dim_b.width_ / head_number); + sub_matC_offset = i * dim_b.width_ / head_number; + + sub_mat_N = dim_b.width_ / head_number; + sub_mat_K = dim_b.height_; + + ldc = dim_b.width_; + } else { + sub_matB_offset = + dim_b.trans_ ? 
i * (dim_b.height_ / head_number) + : i * (dim_b.height_ / head_number) * dim_b.width_; + sub_matC_offset = i * dim_b.width_; + + sub_mat_N = dim_b.width_; + sub_mat_K = dim_a.width_ / head_number; + + ldc = head_number * dim_b.width_; + } + + this->template GEMM(transA, + transB, + sub_mat_M, + sub_mat_N, + sub_mat_K, + alpha, + mat_a.data() + sub_matA_offset, + lda, + mat_b.data() + sub_matB_offset, + ldb, + beta, + mat_out->data() + sub_matC_offset, + ldc); + } + } else { + PADDLE_ENFORCE_EQ( + (dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || + dim_b.batch_size_ == 0), + true, + pten::errors::InvalidArgument( + "The first input batch size should be equal than second input," + "either two input batch size is 0, but received first input batch " + "size" + " %d, second input batch size %d", + dim_a.batch_size_, + dim_b.batch_size_)); + + this->template BatchedGEMMWithHead( + transA, + transB, + dim_a.width_, + dim_a.height_, + dim_b.width_, + dim_b.height_, + alpha, + mat_a.data(), + mat_b.data(), + beta, + mat_out->data(), + dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, + dim_a.stride_, + dim_b.stride_, + head_number, + mat_b_split_vertical); + } +} +#endif // @} End Group Blas MKLML: MatMulWithHead + +template +template +void Blas::VINV(int n, const T *a, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VINV(n, a, y); +#else + for (int i = 0; i < n; ++i) { + y[i] = 1.0 / a[i]; + } +#endif +} + +template <> +template +void Blas::VMERF(int n, + const T *a, + T *y, + int64_t mode) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VMERF(n, a, y, mode); +#else + for (int i = 0; i < n; ++i) { + y[i] = std::erf(a[i]); + } +#endif +} +template <> +template +void Blas::VMERF(int n, + const T *a, + T *y, + int64_t mode) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VMERF(n, a, y, mode); +#else + for (int i = 0; i < n; ++i) { + y[i] = std::erf(a[i]); + } +#endif +} + +#ifdef PADDLE_WITH_MKLML +template <> +template +void Blas::CSRMM(const char *transa, + const int *m, + const int *n, + const int *k, + const T *alpha, + const char *matdescra, + const T *val, + const int *indx, + const int *pntrb, + const int *pntre, + const T *b, + const int *ldb, + const T *beta, + T *c, + const int *ldc) const { + CBlas::CSRMM(transa, + m, + n, + k, + alpha, + matdescra, + val, + indx, + pntrb, + pntre, + b, + ldb, + beta, + c, + ldc); +} +template <> +template +void Blas::CSRMM(const char *transa, + const int *m, + const int *n, + const int *k, + const T *alpha, + const char *matdescra, + const T *val, + const int *indx, + const int *pntrb, + const int *pntre, + const T *b, + const int *ldb, + const T *beta, + T *c, + const int *ldc) const { + CBlas::CSRMM(transa, + m, + n, + k, + alpha, + matdescra, + val, + indx, + pntrb, + pntre, + b, + ldb, + beta, + c, + ldc); +} +#endif + +template <> +template +void Blas::TRSM(CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T *A, + int lda, + T *B, + int ldb) const { + CBlas::TRSM( + CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda, B, ldb); +} +template <> +template +void Blas::TRSM(CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T *A, + int lda, + T *B, + int ldb) const { + CBlas::TRSM( + CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda, B, ldb); +} + +} // namespace funcs +} // namespace pten diff --git a/paddle/pten/kernels/funcs/blas/blas_impl.hip.h 
b/paddle/pten/kernels/funcs/blas/blas_impl.hip.h new file mode 100644 index 0000000000000..e2d264e3f8e70 --- /dev/null +++ b/paddle/pten/kernels/funcs/blas/blas_impl.hip.h @@ -0,0 +1,2276 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/dynload/rocblas.h" +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/kernels/funcs/math_function.h" + +DECLARE_bool(enable_cublas_tensor_op_math); + +namespace pten { +namespace funcs { + +template +struct CUBlas; + +template <> +struct CUBlas { + template + static void GEMM(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_sgemm(args...)); + } + + template + static void AXPY(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_saxpy(args...)); + } + + template + static void SCAL(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_sscal(args...)); + } + + template + static void VCOPY(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_scopy(args...)); + } + + template + static void GEMV(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_sgemv(args...)); + } + + template + static void GEMM_STRIDED_BATCH(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_sgemm_strided_batched(args...)); + } + + // HIP not supportted, refer to the doc here: + // https://github.com/ROCm-Developer-Tools/HIP/blob/roc-3.5.x/docs/markdown/CUBLAS_API_supported_by_HIP.md + template + static void GEMM_EX(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "cublasSgemmEx is not supported on HIP platform.")); + } + + template + static void TRSM(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_strsm(args...)); + } + + template + static void GETRF_BATCH(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "cublasSgetrfBatched is not supported on HIP platform.")); + } + + template + static void GETRI_BATCH(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "cublasSgetriBatched is not supported on HIP platform.")); + } + + template + static void MATINV_BATCH(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "cublasSmatinvBatched is not supported on HIP platform.")); + } + + template + static void TRSM_BATCH(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "cublasStrsmBatched is not supported on HIP platform.")); + } +}; + +template <> +struct CUBlas { + template + static void GEMM(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_dgemm(args...)); + } + + template + static void AXPY(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_daxpy(args...)); + } + + template + static void SCAL(ARGS... 
args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_dscal(args...)); + } + + template + static void VCOPY(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_dcopy(args...)); + } + + template + static void GEMV(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_dgemv(args...)); + } + + template + static void GEMM_STRIDED_BATCH(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_dgemm_strided_batched(args...)); + } + + template + static void GEMM_EX(ARGS... args) { + PADDLE_THROW( + pten::errors::Unimplemented("Currently there are not cublasDgemmEx.")); + } + + template + static void TRSM(ARGS... args) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_dtrsm(args...)); + } + + template + static void GETRF_BATCH(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "cublasDgetrfBatched is not supported on HIP platform.")); + } + + template + static void GETRI_BATCH(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "cublasDgetriBatched is not supported on HIP platform.")); + } + + template + static void MATINV_BATCH(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "cublasDmatinvBatched is not supported on HIP platform.")); + } + + template + static void TRSM_BATCH(ARGS... args) { + PADDLE_THROW(pten::errors::Unimplemented( + "cublasDtrsmBatched is not supported on HIP platform.")); + } +}; + +template <> +struct CUBlas { + using float16 = pten::dtype::float16; + + static void GEMM(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const float16 *alpha, + const float16 *A, + int lda, + const float16 *B, + int ldb, + const float16 *beta, + float16 *C, + int ldc) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_hgemm( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc)); + } + + static void GEMM_STRIDED_BATCH(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const float16 *alpha, + const float16 *A, + int lda, + long long int strideA, // NOLINT + const float16 *B, // NOLINT + int ldb, + long long int strideB, // NOLINT + const float16 *beta, + float16 *C, + int ldc, + long long int strideC, // NOLINT + int batchCount) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_hgemm_strided_batched( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + strideA, + reinterpret_cast(B), + ldb, + strideB, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc, + strideC, + batchCount)); + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. 
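+  // Unlike the fixed-type GEMM above, rocblas_gemm_ex takes an explicit
+  // rocblas_datatype for each of A/B/C plus a separate computeType, so fp16
+  // inputs can be accumulated in fp32. Math-mode background: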
+ // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(paddle::platform::CUDADeviceContext *dev_ctx, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + rocblas_datatype Atype, + int lda, + const void *B, + rocblas_datatype Btype, + int ldb, + const void *beta, + void *C, + rocblas_datatype Ctype, + int ldc, + rocblas_datatype computeType) { + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_gemm_ex(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + C, + Ctype, + ldc, + computeType, + algo, + 0, + 0)); + }); + } + template + static void GEMM_EX(pten::GPUContext *dev_ctx, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + rocblas_datatype Atype, + int lda, + const void *B, + rocblas_datatype Btype, + int ldb, + const void *beta, + void *C, + rocblas_datatype Ctype, + int ldc, + rocblas_datatype computeType) { + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_gemm_ex(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + C, + Ctype, + ldc, + computeType, + algo, + 0, + 0)); + }); + } +}; + +template <> +struct CUBlas> { + static void GEMV(rocblas_handle handle, + rocblas_operation transa, + int m, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + const pten::dtype::complex *B, + int ldb, + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_cgemv( + handle, + transa, + m, + n, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc)); + } + + static void AXPY(rocblas_handle handle, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex *X, + const int incX, + pten::dtype::complex *Y, + const int incY) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_caxpy( + handle, + n, + reinterpret_cast(alpha), + reinterpret_cast(X), + incX, + reinterpret_cast(Y), + incY)); + } + + static void GEMM_STRIDED_BATCH( + rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + long long int strideA, // NOLINT + const pten::dtype::complex *B, // NOLINT + int ldb, + long long int strideB, // NOLINT + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc, + long long int strideC, // NOLINT + int batchCount) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_cgemm_strided_batched( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + strideA, + reinterpret_cast(B), + ldb, + strideB, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc, + strideC, + batchCount)); + } + + static void GEMM(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const pten::dtype::complex *alpha, + const 
pten::dtype::complex *A, + int lda, + const pten::dtype::complex *B, + int ldb, + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_cgemm( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc)); + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(paddle::platform::CUDADeviceContext *dev_ctx, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + rocblas_datatype Atype, + int lda, + const void *B, + rocblas_datatype Btype, + int ldb, + const void *beta, + void *C, + rocblas_datatype Ctype, + int ldc, + rocblas_datatype computeType) { + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_gemm_ex(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + C, + Ctype, + ldc, + computeType, + algo, + 0, + 0)); + }); + } + template + static void GEMM_EX(pten::GPUContext *dev_ctx, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + rocblas_datatype Atype, + int lda, + const void *B, + rocblas_datatype Btype, + int ldb, + const void *beta, + void *C, + rocblas_datatype Ctype, + int ldc, + rocblas_datatype computeType) { + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_gemm_ex(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + C, + Ctype, + ldc, + computeType, + algo, + 0, + 0)); + }); + } +}; + +template <> +struct CUBlas> { + static void GEMV(rocblas_handle handle, + rocblas_operation transa, + int m, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + const pten::dtype::complex *B, + int ldb, + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_zgemv( + handle, + transa, + m, + n, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc)); + } + + static void AXPY(rocblas_handle handle, + int n, + const pten::dtype::complex *alpha, + const pten::dtype::complex *X, + const int incX, + pten::dtype::complex *Y, + const int incY) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_zaxpy( + handle, + n, + reinterpret_cast(alpha), + reinterpret_cast(X), + incX, + reinterpret_cast(Y), + incY)); + } + + static void GEMM_STRIDED_BATCH( + rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + long long int strideA, // NOLINT + const pten::dtype::complex *B, // NOLINT + int ldb, + long long int strideB, // NOLINT + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc, + long long int strideC, // NOLINT + int batchCount) { + 
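+    // A single call launches batchCount GEMMs: the i-th one multiplies
+    // A + i * strideA by B + i * strideB into C + i * strideC, with every
+    // problem sharing the same m, n, k and leading dimensions.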
PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_zgemm_strided_batched( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + strideA, + reinterpret_cast(B), + ldb, + strideB, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc, + strideC, + batchCount)); + } + + static void GEMM(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const pten::dtype::complex *alpha, + const pten::dtype::complex *A, + int lda, + const pten::dtype::complex *B, + int ldb, + const pten::dtype::complex *beta, + pten::dtype::complex *C, + int ldc) { + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_zgemm( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(alpha), + reinterpret_cast(A), + lda, + reinterpret_cast(B), + ldb, + reinterpret_cast(beta), + reinterpret_cast(C), + ldc)); + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(paddle::platform::CUDADeviceContext *dev_ctx, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + rocblas_datatype Atype, + int lda, + const void *B, + rocblas_datatype Btype, + int ldb, + const void *beta, + void *C, + rocblas_datatype Ctype, + int ldc, + rocblas_datatype computeType) { + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_gemm_ex(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + C, + Ctype, + ldc, + computeType, + algo, + 0, + 0)); + }); + } + template + static void GEMM_EX(pten::GPUContext *dev_ctx, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + rocblas_datatype Atype, + int lda, + const void *B, + rocblas_datatype Btype, + int ldb, + const void *beta, + void *C, + rocblas_datatype Ctype, + int ldc, + rocblas_datatype computeType) { + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_gemm_ex(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + C, + Ctype, + ldc, + computeType, + algo, + 0, + 0)); + }); + } +}; + +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? 
rocblas_operation_none + : rocblas_operation_transpose; + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + N); + }); +} +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + N); + }); +} + +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + const pten::dtype::float16 *B, + pten::dtype::float16 beta, + pten::dtype::float16 *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas fp16 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + rocblas_datatype_f16_r, + ldb, + A, + rocblas_datatype_f16_r, + lda, + &h_beta, + C, + rocblas_datatype_f16_r, + N, + rocblas_datatype_f32_r); +} +template <> +template <> +inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + const pten::dtype::float16 *B, + pten::dtype::float16 beta, + pten::dtype::float16 *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? 
rocblas_operation_none + : rocblas_operation_transpose; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas fp16 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + rocblas_datatype_f16_r, + ldb, + A, + rocblas_datatype_f16_r, + lda, + &h_beta, + C, + rocblas_datatype_f16_r, + N, + rocblas_datatype_f32_r); +} + +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + // TODO(zhiqiu): 80 has the same meaning for rocm and cuda? + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 80, + pten::errors::InvalidArgument( + "rocblas fp16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + + context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_gemm_ex(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + rocblas_datatype_bf16_r, + ldb, + A, + rocblas_datatype_bf16_r, + lda, + &h_beta, + C, + rocblas_datatype_bf16_r, + N, + C, + rocblas_datatype_bf16_r, + N, + rocblas_datatype_f32_r, + algo, + 0, + 0)); + }); +} + +template <> +template <> +inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + // TODO(zhiqiu): 80 has the same meaning for rocm and cuda? 
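+  // bf16 GEMM is emulated through rocblas_gemm_ex with fp32 accumulation;
+  // the >= 80 capability guard below is carried over from the CUDA bf16 path
+  // (see the TODO above).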
+ PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 80, + pten::errors::InvalidArgument( + "rocblas fp16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + + context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_gemm_ex(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + rocblas_datatype_bf16_r, + ldb, + A, + rocblas_datatype_bf16_r, + lda, + &h_beta, + C, + rocblas_datatype_bf16_r, + N, + C, + rocblas_datatype_bf16_r, + N, + rocblas_datatype_f32_r, + algo, + 0, + 0)); + }); +} + +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + const pten::dtype::complex *B, + pten::dtype::complex beta, + pten::dtype::complex *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas complex64 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + thrust::complex c_alpha = + thrust::complex(alpha.real, alpha.imag); + thrust::complex c_beta = thrust::complex(beta.real, beta.imag); + + auto &cuda_ctx = const_cast(context_); + CUBlas>::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + B, + rocblas_datatype_f32_c, + ldb, + A, + rocblas_datatype_f32_c, + lda, + &c_beta, + C, + rocblas_datatype_f32_c, + N, + rocblas_datatype_f32_c); +} +template <> +template <> +inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + const pten::dtype::complex *B, + pten::dtype::complex beta, + pten::dtype::complex *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? 
rocblas_operation_none + : rocblas_operation_transpose; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas complex64 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + thrust::complex c_alpha = + thrust::complex(alpha.real, alpha.imag); + thrust::complex c_beta = thrust::complex(beta.real, beta.imag); + + auto &cuda_ctx = const_cast(context_); + CUBlas>::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + B, + rocblas_datatype_f32_c, + ldb, + A, + rocblas_datatype_f32_c, + lda, + &c_beta, + C, + rocblas_datatype_f32_c, + N, + rocblas_datatype_f32_c); +} + +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + const pten::dtype::complex *B, + pten::dtype::complex beta, + pten::dtype::complex *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas complex128 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + thrust::complex c_alpha = + thrust::complex(alpha.real, alpha.imag); + thrust::complex c_beta = + thrust::complex(beta.real, beta.imag); + + auto &cuda_ctx = const_cast(context_); + CUBlas>::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + B, + rocblas_datatype_f64_c, + ldb, + A, + rocblas_datatype_f64_c, + lda, + &c_beta, + C, + rocblas_datatype_f64_c, + N, + rocblas_datatype_f64_c); +} +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::complex alpha, + const pten::dtype::complex *A, + const pten::dtype::complex *B, + pten::dtype::complex beta, + pten::dtype::complex *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? 
rocblas_operation_none + : rocblas_operation_transpose; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + pten::errors::InvalidArgument( + "cublas complex128 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + thrust::complex c_alpha = + thrust::complex(alpha.real, alpha.imag); + thrust::complex c_beta = + thrust::complex(beta.real, beta.imag); + + auto &cuda_ctx = const_cast(context_); + CUBlas>::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &c_alpha, + B, + rocblas_datatype_f64_c, + ldb, + A, + rocblas_datatype_f64_c, + lda, + &c_beta, + C, + rocblas_datatype_f64_c, + N, + rocblas_datatype_f64_c); +} + +template <> +template +void Blas::GEMM(bool transA, + bool transB, + int M, + int N, + int K, + T alpha, + const T *A, + int lda, + const T *B, + int ldb, + T beta, + T *C, + int ldc) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + rocblas_operation cuTransA = + transA ? rocblas_operation_transpose : rocblas_operation_none; + rocblas_operation cuTransB = + transB ? rocblas_operation_transpose : rocblas_operation_none; + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc); + }); +} +template <> +template +void Blas::GEMM(bool transA, + bool transB, + int M, + int N, + int K, + T alpha, + const T *A, + int lda, + const T *B, + int ldb, + T beta, + T *C, + int ldc) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + rocblas_operation cuTransA = + transA ? rocblas_operation_transpose : rocblas_operation_none; + rocblas_operation cuTransB = + transB ? rocblas_operation_transpose : rocblas_operation_none; + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc); + }); +} + +template <> +template <> +inline void Blas::GEMM( + bool transA, + bool transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + int lda, + const pten::dtype::float16 *B, + int ldb, + pten::dtype::float16 beta, + pten::dtype::float16 *C, + int ldc) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + rocblas_operation cuTransA = + transA ? rocblas_operation_transpose : rocblas_operation_none; + rocblas_operation cuTransB = + transB ? rocblas_operation_transpose : rocblas_operation_none; + + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc); + }); +} +template <> +template <> +inline void Blas::GEMM(bool transA, + bool transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + int lda, + const pten::dtype::float16 *B, + int ldb, + pten::dtype::float16 beta, + pten::dtype::float16 *C, + int ldc) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + rocblas_operation cuTransA = + transA ? rocblas_operation_transpose : rocblas_operation_none; + rocblas_operation cuTransB = + transB ? 
rocblas_operation_transpose : rocblas_operation_none; + + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc); + }); +} + +template <> +template +void Blas::AXPY(int n, + T alpha, + const T *x, + T *y) const { + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); + }); +} +template <> +template +void Blas::AXPY(int n, T alpha, const T *x, T *y) const { + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); + }); +} + +template <> +template +void Blas::SCAL(int n, + const T alpha, + T *x) const { + context_.CublasCall( + [&](rocblas_handle handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); +} +template <> +template +void Blas::SCAL(int n, const T alpha, T *x) const { + context_.CublasCall( + [&](rocblas_handle handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); +} + +template <> +template +void Blas::VCOPY(int n, + const T *x, + T *y) const { + context_.CublasCall( + [&](rocblas_handle handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); +} +template <> +template +void Blas::VCOPY(int n, const T *x, T *y) const { + context_.CublasCall( + [&](rocblas_handle handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); +} + +template <> +template +void Blas::GEMV(bool trans_a, + int M, + int N, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + rocblas_operation cuTransA = + !trans_a ? rocblas_operation_transpose : rocblas_operation_none; + + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); + }); +} +template <> +template +void Blas::GEMV(bool trans_a, + int M, + int N, + T alpha, + const T *A, + const T *B, + T beta, + T *C) const { + rocblas_operation cuTransA = + !trans_a ? rocblas_operation_transpose : rocblas_operation_none; + + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); + }); +} + +template <> +template <> +inline void Blas::GEMV( + bool trans_a, + int M, + int N, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + const pten::dtype::float16 *B, + pten::dtype::float16 beta, + pten::dtype::float16 *C) const { + // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it. + if (trans_a) { + this->template GEMM( + CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); + } else { + this->template GEMM( + CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); + } +} +template <> +template <> +inline void Blas::GEMV(bool trans_a, + int M, + int N, + pten::dtype::float16 alpha, + const pten::dtype::float16 *A, + const pten::dtype::float16 *B, + pten::dtype::float16 beta, + pten::dtype::float16 *C) const { + // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it. + if (trans_a) { + this->template GEMM( + CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); + } else { + this->template GEMM( + CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); + } +} + +template <> +template <> +inline void Blas::GEMV( + bool trans_a, + int M, + int N, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C) const { + // Because rocblas doesn't support bfloat16 gemv, we use gemmex to achieve it. 
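+  // With trans_a the product is computed as a [1, N] = [1, M] x [M, N] GEMM
+  // with the operands swapped; otherwise as an [M, 1] = [M, N] x [N, 1] GEMM.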
+ if (trans_a) { + this->template GEMM( + CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); + } else { + this->template GEMM( + CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); + } +} +template <> +template <> +inline void Blas::GEMV(bool trans_a, + int M, + int N, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C) const { + // Because rocblas doesn't support bfloat16 gemv, we use gemmex to achieve it. + if (trans_a) { + this->template GEMM( + CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C); + } else { + this->template GEMM( + CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C); + } +} + +template <> +template +void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + const int64_t strideC = M * N; + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GEMM_STRIDED_BATCH(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + strideB, + A, + lda, + strideA, + &beta, + C, + ldc, + strideC, + batchCount); + }); +} + +template <> +template +void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T *A, + const T *B, + T beta, + T *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + const int64_t strideC = M * N; + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GEMM_STRIDED_BATCH(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + strideB, + A, + lda, + strideA, + &beta, + C, + ldc, + strideC, + batchCount); + }); +} + +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + const int64_t strideC = M * N; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? 
rocblas_operation_none + : rocblas_operation_transpose; + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + + context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_gemm_strided_batched_ex( + handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + rocblas_datatype_bf16_r, + ldb, + strideB, + A, + rocblas_datatype_bf16_r, + lda, + strideA, + &h_beta, + C, + rocblas_datatype_bf16_r, + ldc, + strideC, + C, + rocblas_datatype_bf16_r, + ldc, + strideC, + batchCount, + rocblas_datatype_f32_r, + algo, + 0, + 0)); + }); +} + +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 *A, + const pten::dtype::bfloat16 *B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + const int64_t strideC = M * N; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + + context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_gemm_strided_batched_ex( + handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + rocblas_datatype_bf16_r, + ldb, + strideB, + A, + rocblas_datatype_bf16_r, + lda, + strideA, + &h_beta, + C, + rocblas_datatype_bf16_r, + ldc, + strideC, + C, + rocblas_datatype_bf16_r, + ldc, + strideC, + batchCount, + rocblas_datatype_f32_r, + algo, + 0, + 0)); + }); +} + +template <> +template +void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T **A, + const T **B, + T beta, + T **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} + +template <> +template +void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + T alpha, + const T **A, + const T **B, + T beta, + T **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} + +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 **A, + const pten::dtype::float16 **B, + pten::dtype::float16 beta, + pten::dtype::float16 **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::float16 alpha, + const pten::dtype::float16 **A, + const pten::dtype::float16 **B, + pten::dtype::float16 beta, + pten::dtype::float16 **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + 
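+    // Each batch entry falls back to an individual fp16 GEMM; no
+    // pointer-array batched rocblas kernel is used on this path.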
this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} + +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 **A, + const pten::dtype::bfloat16 **B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} + +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + pten::dtype::bfloat16 alpha, + const pten::dtype::bfloat16 **A, + const pten::dtype::bfloat16 **B, + pten::dtype::bfloat16 beta, + pten::dtype::bfloat16 **C, + int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM( + transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); + } +} + +template <> +template +void Blas::TRSM(CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T *A, + int lda, + T *B, + int ldb) const { + // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` + // where ' stands for transpose + rocblas_side cuSide = + (side == CblasLeft) ? rocblas_side_right : rocblas_side_left; + rocblas_fill cuUplo = + (uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower; + // use CUBLAS_OP_C (conjugate transpose) for complex + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_diagonal cuDiag = + (diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit; + + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::TRSM( + handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, lda, B, ldb); + }); +} +template <> +template +void Blas::TRSM(CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T *A, + int lda, + T *B, + int ldb) const { + // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` + // where ' stands for transpose + rocblas_side cuSide = + (side == CblasLeft) ? rocblas_side_right : rocblas_side_left; + rocblas_fill cuUplo = + (uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower; + // use CUBLAS_OP_C (conjugate transpose) for complex + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_diagonal cuDiag = + (diag == CblasUnit) ? 
rocblas_diagonal_unit : rocblas_diagonal_non_unit; + + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::TRSM( + handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, lda, B, ldb); + }); +} + +template <> +template +void Blas::BatchedGETRF( + int n, T **a, int *ipiv, int *info, int batch_size) const { + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); + }); +} +template <> +template +void Blas::BatchedGETRF( + int n, T **a, int *ipiv, int *info, int batch_size) const { + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedGETRI( + int n, const T **a, const int *ipiv, T **a_inv, int *info, int batch_size) + const { + PADDLE_ENFORCE_NE( + a_inv, + a, + pten::errors::InvalidArgument( + "cuBLAS fuction 'cublasgetrfBatched' cannot be executed " + "in-place. The memory space of output matrix (address: %p) cannot " + "overlap memory space of input matrix (address: %p).", + a_inv, + a)); + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); + }); +} +template <> +template +void Blas::BatchedGETRI(int n, + const T **a, + const int *ipiv, + T **a_inv, + int *info, + int batch_size) const { + PADDLE_ENFORCE_NE( + a_inv, + a, + pten::errors::InvalidArgument( + "cuBLAS fuction 'cublasgetrfBatched' cannot be executed " + "in-place. The memory space of output matrix (address: %p) cannot " + "overlap memory space of input matrix (address: %p).", + a_inv, + a)); + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedMatInv( + int n, const T **a, T **a_inv, int *info, int batch_size) const { + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); + }); +} +template <> +template +void Blas::BatchedMatInv( + int n, const T **a, T **a_inv, int *info, int batch_size) const { + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedGETRS( + CBLAS_TRANSPOSE trans, + int n, + int nrhs, + const T **a, + int lda, + int *ipiv, + T **b, + int ldb, + int *info, + int batch_size) const { + rocblas_operation cuTrans = (trans == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GETRS_BATCH( + handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size); + }); +} +template <> +template +void Blas::BatchedGETRS(CBLAS_TRANSPOSE trans, + int n, + int nrhs, + const T **a, + int lda, + int *ipiv, + T **b, + int ldb, + int *info, + int batch_size) const { + rocblas_operation cuTrans = (trans == CblasNoTrans) + ? 
rocblas_operation_none + : rocblas_operation_transpose; + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GETRS_BATCH( + handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedTRSM( + CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T **A, + int lda, + T **B, + int ldb, + int batch_size) const { + // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` + // where ' stands for transpose + rocblas_side cuSide = + (side == CblasLeft) ? rocblas_side_right : rocblas_side_left; + rocblas_fill cuUplo = + (uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower; + // use CUBLAS_OP_C (conjugate transpose) for complex + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_diagonal cuDiag = + (diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit; + + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::TRSM_BATCH(handle, + cuSide, + cuUplo, + cuTransA, + cuDiag, + N, + M, + &alpha, + A, + lda, + B, + ldb, + batch_size); + }); +} +template <> +template +void Blas::BatchedTRSM(CBLAS_SIDE side, + CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, + int M, + int N, + T alpha, + const T **A, + int lda, + T **B, + int ldb, + int batch_size) const { + // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` + // where ' stands for transpose + rocblas_side cuSide = + (side == CblasLeft) ? rocblas_side_right : rocblas_side_left; + rocblas_fill cuUplo = + (uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower; + // use CUBLAS_OP_C (conjugate transpose) for complex + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_diagonal cuDiag = + (diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit; + + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::TRSM_BATCH(handle, + cuSide, + cuUplo, + cuTransA, + cuDiag, + N, + M, + &alpha, + A, + lda, + B, + ldb, + batch_size); + }); +} + +} // namespace funcs +} // namespace pten diff --git a/paddle/pten/kernels/funcs/functors.h b/paddle/pten/kernels/funcs/functors.h index 8b2bdfd0b1e32..1c170bb1f0f8d 100644 --- a/paddle/pten/kernels/funcs/functors.h +++ b/paddle/pten/kernels/funcs/functors.h @@ -19,30 +19,12 @@ limitations under the License. */ namespace pten { namespace funcs { - -// // MulFunctor -// // NOTE(chenfeiyu): IT IS NOLONGER USED, use pten::funcs::MultiplyFunctor -// instead -// template -// struct MulFunctor { -// // out = x * y; -// inline HOSTDEVICE T operator()(T x, T y) { return x * y; } -// }; - template struct MulGradFunctor { inline HOSTDEVICE T Dx(T x, T y) { return y; } inline HOSTDEVICE T Dy(T x, T y) { return x; } }; -// // AddFunctor -// // NOTE(chenfeiyu): IT IS NOLONGER USED, use pten::funcs::AddFunctor instead -// template -// struct AddFunctor { -// // out = x + y; -// inline HOSTDEVICE T operator()(T x, T y) { return x + y; } -// }; - template struct MaxFunctor { inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? 
b : a; } diff --git a/paddle/pten/kernels/funcs/lapack/CMakeLists.txt b/paddle/pten/kernels/funcs/lapack/CMakeLists.txt new file mode 100644 index 0000000000000..ffff5ae8abe2a --- /dev/null +++ b/paddle/pten/kernels/funcs/lapack/CMakeLists.txt @@ -0,0 +1 @@ +math_library(lapack_function DEPS dynload_lapack) diff --git a/paddle/pten/kernels/funcs/lapack/lapack_function.cc b/paddle/pten/kernels/funcs/lapack/lapack_function.cc new file mode 100644 index 0000000000000..45f6faaeee8bb --- /dev/null +++ b/paddle/pten/kernels/funcs/lapack/lapack_function.cc @@ -0,0 +1,509 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/funcs/lapack/lapack_function.h" +#include "paddle/fluid/platform/dynload/lapack.h" +#include "paddle/pten/common/complex.h" + +namespace pten { +namespace funcs { + +// LU (for example) +template <> +void lapackLu(int m, int n, double *a, int lda, int *ipiv, int *info) { + paddle::platform::dynload::dgetrf_(&m, &n, a, &lda, ipiv, info); +} + +template <> +void lapackLu(int m, int n, float *a, int lda, int *ipiv, int *info) { + paddle::platform::dynload::sgetrf_(&m, &n, a, &lda, ipiv, info); +} + +// eigh +template <> +void lapackEigh(char jobz, + char uplo, + int n, + float *a, + int lda, + float *w, + float *work, + int lwork, + float *rwork, + int lrwork, + int *iwork, + int liwork, + int *info) { + (void)rwork; // unused + (void)lrwork; // unused + paddle::platform::dynload::ssyevd_( + &jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); +} + +template <> +void lapackEigh(char jobz, + char uplo, + int n, + double *a, + int lda, + double *w, + double *work, + int lwork, + double *rwork, + int lrwork, + int *iwork, + int liwork, + int *info) { + (void)rwork; // unused + (void)lrwork; // unused + paddle::platform::dynload::dsyevd_( + &jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); +} + +template <> +void lapackEigh, float>( + char jobz, + char uplo, + int n, + pten::dtype::complex *a, + int lda, + float *w, + pten::dtype::complex *work, + int lwork, + float *rwork, + int lrwork, + int *iwork, + int liwork, + int *info) { + paddle::platform::dynload::cheevd_( + &jobz, + &uplo, + &n, + reinterpret_cast *>(a), + &lda, + w, + reinterpret_cast *>(work), + &lwork, + rwork, + &lrwork, + iwork, + &liwork, + info); +} + +template <> +void lapackEigh, double>( + char jobz, + char uplo, + int n, + pten::dtype::complex *a, + int lda, + double *w, + pten::dtype::complex *work, + int lwork, + double *rwork, + int lrwork, + int *iwork, + int liwork, + int *info) { + paddle::platform::dynload::zheevd_( + &jobz, + &uplo, + &n, + reinterpret_cast *>(a), + &lda, + w, + reinterpret_cast *>(work), + &lwork, + rwork, + &lrwork, + iwork, + &liwork, + info); +} + +// Eig +template <> +void lapackEig(char jobvl, + char jobvr, + int n, + double *a, + int lda, + double *w, + double *vl, + int ldvl, + double *vr, + int ldvr, + double *work, + int lwork, + double *rwork, 
+ int *info) { + double *wr = w; + double *wi = w + n; + (void)rwork; // unused + paddle::platform::dynload::dgeev_(&jobvl, + &jobvr, + &n, + a, + &lda, + wr, + wi, + vl, + &ldvl, + vr, + &ldvr, + work, + &lwork, + info); +} + +template <> +void lapackEig(char jobvl, + char jobvr, + int n, + float *a, + int lda, + float *w, + float *vl, + int ldvl, + float *vr, + int ldvr, + float *work, + int lwork, + float *rwork, + int *info) { + float *wr = w; + float *wi = w + n; + (void)rwork; // unused + paddle::platform::dynload::sgeev_(&jobvl, + &jobvr, + &n, + a, + &lda, + wr, + wi, + vl, + &ldvl, + vr, + &ldvr, + work, + &lwork, + info); +} + +template <> +void lapackEig, double>( + char jobvl, + char jobvr, + int n, + pten::dtype::complex *a, + int lda, + pten::dtype::complex *w, + pten::dtype::complex *vl, + int ldvl, + pten::dtype::complex *vr, + int ldvr, + pten::dtype::complex *work, + int lwork, + double *rwork, + int *info) { + paddle::platform::dynload::zgeev_( + &jobvl, + &jobvr, + &n, + reinterpret_cast *>(a), + &lda, + reinterpret_cast *>(w), + reinterpret_cast *>(vl), + &ldvl, + reinterpret_cast *>(vr), + &ldvr, + reinterpret_cast *>(work), + &lwork, + rwork, + info); +} + +template <> +void lapackEig, float>( + char jobvl, + char jobvr, + int n, + pten::dtype::complex *a, + int lda, + pten::dtype::complex *w, + pten::dtype::complex *vl, + int ldvl, + pten::dtype::complex *vr, + int ldvr, + pten::dtype::complex *work, + int lwork, + float *rwork, + int *info) { + paddle::platform::dynload::cgeev_( + &jobvl, + &jobvr, + &n, + reinterpret_cast *>(a), + &lda, + reinterpret_cast *>(w), + reinterpret_cast *>(vl), + &ldvl, + reinterpret_cast *>(vr), + &ldvr, + reinterpret_cast *>(work), + &lwork, + rwork, + info); +} + +template <> +void lapackGels(char trans, + int m, + int n, + int nrhs, + double *a, + int lda, + double *b, + int ldb, + double *work, + int lwork, + int *info) { + paddle::platform::dynload::dgels_( + &trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); +} + +template <> +void lapackGels(char trans, + int m, + int n, + int nrhs, + float *a, + int lda, + float *b, + int ldb, + float *work, + int lwork, + int *info) { + paddle::platform::dynload::sgels_( + &trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); +} + +template <> +void lapackGelsd(int m, + int n, + int nrhs, + double *a, + int lda, + double *b, + int ldb, + double *s, + double rcond, + int *rank, + double *work, + int lwork, + double *rwork, + int *iwork, + int *info) { + paddle::platform::dynload::dgelsd_(&m, + &n, + &nrhs, + a, + &lda, + b, + &ldb, + s, + &rcond, + rank, + work, + &lwork, + iwork, + info); +} + +template <> +void lapackGelsd(int m, + int n, + int nrhs, + float *a, + int lda, + float *b, + int ldb, + float *s, + float rcond, + int *rank, + float *work, + int lwork, + float *rwork, + int *iwork, + int *info) { + paddle::platform::dynload::sgelsd_(&m, + &n, + &nrhs, + a, + &lda, + b, + &ldb, + s, + &rcond, + rank, + work, + &lwork, + iwork, + info); +} + +template <> +void lapackGelsy(int m, + int n, + int nrhs, + double *a, + int lda, + double *b, + int ldb, + int *jpvt, + double rcond, + int *rank, + double *work, + int lwork, + double *rwork, + int *info) { + paddle::platform::dynload::dgelsy_( + &m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond, rank, work, &lwork, info); +} + +template <> +void lapackGelsy(int m, + int n, + int nrhs, + float *a, + int lda, + float *b, + int ldb, + int *jpvt, + float rcond, + int *rank, + float *work, + int lwork, + float *rwork, + int *info) { 
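+  // gelsy solves the least-squares problem through a complete orthogonal
+  // factorization with column pivoting (jpvt); rcond sets the cutoff for the
+  // effective rank returned in *rank. rwork is unused for real types.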
+  paddle::platform::dynload::sgelsy_(
+      &m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond, rank, work, &lwork, info);
+}
+
+template <>
+void lapackGelss<double, double>(int m,
+                                 int n,
+                                 int nrhs,
+                                 double *a,
+                                 int lda,
+                                 double *b,
+                                 int ldb,
+                                 double *s,
+                                 double rcond,
+                                 int *rank,
+                                 double *work,
+                                 int lwork,
+                                 double *rwork,
+                                 int *info) {
+  paddle::platform::dynload::dgelss_(
+      &m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, work, &lwork, info);
+}
+
+template <>
+void lapackGelss<float, float>(int m,
+                               int n,
+                               int nrhs,
+                               float *a,
+                               int lda,
+                               float *b,
+                               int ldb,
+                               float *s,
+                               float rcond,
+                               int *rank,
+                               float *work,
+                               int lwork,
+                               float *rwork,
+                               int *info) {
+  paddle::platform::dynload::sgelss_(
+      &m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, work, &lwork, info);
+}
+
+template <>
+void lapackCholeskySolve<pten::dtype::complex<double>>(
+    char uplo,
+    int n,
+    int nrhs,
+    pten::dtype::complex<double> *a,
+    int lda,
+    pten::dtype::complex<double> *b,
+    int ldb,
+    int *info) {
+  paddle::platform::dynload::zpotrs_(
+      &uplo,
+      &n,
+      &nrhs,
+      reinterpret_cast<std::complex<double> *>(a),
+      &lda,
+      reinterpret_cast<std::complex<double> *>(b),
+      &ldb,
+      info);
+}
+
+template <>
+void lapackCholeskySolve<pten::dtype::complex<float>>(
+    char uplo,
+    int n,
+    int nrhs,
+    pten::dtype::complex<float> *a,
+    int lda,
+    pten::dtype::complex<float> *b,
+    int ldb,
+    int *info) {
+  paddle::platform::dynload::cpotrs_(&uplo,
+                                     &n,
+                                     &nrhs,
+                                     reinterpret_cast<std::complex<float> *>(a),
+                                     &lda,
+                                     reinterpret_cast<std::complex<float> *>(b),
+                                     &ldb,
+                                     info);
+}
+
+template <>
+void lapackCholeskySolve<double>(char uplo,
+                                 int n,
+                                 int nrhs,
+                                 double *a,
+                                 int lda,
+                                 double *b,
+                                 int ldb,
+                                 int *info) {
+  paddle::platform::dynload::dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
+}
+
+template <>
+void lapackCholeskySolve<float>(char uplo,
+                                int n,
+                                int nrhs,
+                                float *a,
+                                int lda,
+                                float *b,
+                                int ldb,
+                                int *info) {
+  paddle::platform::dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
+}
+
+}  // namespace funcs
+}  // namespace pten
diff --git a/paddle/pten/kernels/funcs/lapack/lapack_function.h b/paddle/pten/kernels/funcs/lapack/lapack_function.h
new file mode 100644
index 0000000000000..10dcd0d88563d
--- /dev/null
+++ b/paddle/pten/kernels/funcs/lapack/lapack_function.h
@@ -0,0 +1,128 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+namespace pten {
+namespace funcs {
+
+// LU (for example)
+template <typename T>
+void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info);
+
+// Eigh
+template <typename T, typename ValueType>
+void lapackEigh(char jobz,
+                char uplo,
+                int n,
+                T *a,
+                int lda,
+                ValueType *w,
+                T *work,
+                int lwork,
+                ValueType *rwork,
+                int lrwork,
+                int *iwork,
+                int liwork,
+                int *info);
+
+// Eig
+template <typename T1, typename T2>
+void lapackEig(char jobvl,
+               char jobvr,
+               int n,
+               T1 *a,
+               int lda,
+               T1 *w,
+               T1 *vl,
+               int ldvl,
+               T1 *vr,
+               int ldvr,
+               T1 *work,
+               int lwork,
+               T2 *rwork,
+               int *info);
+
+// Gels
+template <typename T>
+void lapackGels(char trans,
+                int m,
+                int n,
+                int nrhs,
+                T *a,
+                int lda,
+                T *b,
+                int ldb,
+                T *work,
+                int lwork,
+                int *info);
+
+// Gelsd
+template <typename T1, typename T2>
+void lapackGelsd(int m,
+                 int n,
+                 int nrhs,
+                 T1 *a,
+                 int lda,
+                 T1 *b,
+                 int ldb,
+                 T2 *s,
+                 T2 rcond,
+                 int *rank,
+                 T1 *work,
+                 int lwork,
+                 T2 *rwork,
+                 int *iwork,
+                 int *info);
+
+// Gelsy
+template <typename T1, typename T2>
+void lapackGelsy(int m,
+                 int n,
+                 int nrhs,
+                 T1 *a,
+                 int lda,
+                 T1 *b,
+                 int ldb,
+                 int *jpvt,
+                 T2 rcond,
+                 int *rank,
+                 T1 *work,
+                 int lwork,
+                 T2 *rwork,
+                 int *info);
+
+// Gelss
+template <typename T1, typename T2>
+void lapackGelss(int m,
+                 int n,
+                 int nrhs,
+                 T1 *a,
+                 int lda,
+                 T1 *b,
+                 int ldb,
+                 T2 *s,
+                 T2 rcond,
+                 int *rank,
+                 T1 *work,
+                 int lwork,
+                 T2 *rwork,
+                 int *info);
+
+template <typename T>
+void lapackCholeskySolve(
+    char uplo, int n, int nrhs, T *a, int lda, T *b, int ldb, int *info);
+
+}  // namespace funcs
+}  // namespace pten
diff --git a/paddle/pten/kernels/funcs/math_function.cu b/paddle/pten/kernels/funcs/math_function.cu
index d202e46da8bd9..e61e5dc68289b 100644
--- a/paddle/pten/kernels/funcs/math_function.cu
+++ b/paddle/pten/kernels/funcs/math_function.cu
@@ -16,10 +16,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/memory/memcpy.h"
-#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/pten/backends/gpu/gpu_context.h"
+#include "paddle/pten/kernels/funcs/blas/blas.h"
 #include "paddle/pten/kernels/funcs/eigen/common.h"
 #include "paddle/pten/kernels/funcs/math_function.h"
 #include "paddle/pten/kernels/funcs/math_function_impl.h"
@@ -315,8 +315,7 @@ void ColwiseSum<paddle::platform::CUDADeviceContext, double>::operator()(
   one.mutable_data<double>({in_dims[0]}, context.GetPlace());
   SetConstant<paddle::platform::CUDADeviceContext, double> set;
   set(context, &one, static_cast<double>(1.0));
-  paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext, double>(
-      context)
+  pten::funcs::GetBlas<paddle::platform::CUDADeviceContext, double>(context)
       .GEMV(true,
             static_cast<int>(in_dims[0]),
             static_cast<int>(in_dims[1]),
@@ -352,8 +351,7 @@ void RowwiseSum<paddle::platform::CUDADeviceContext, double>::operator()(
   one.mutable_data<double>({size}, context.GetPlace());
   SetConstant<paddle::platform::CUDADeviceContext, double> set;
   set(context, &one, static_cast<double>(1.0));
-  paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext, double>(
-      context)
+  pten::funcs::GetBlas<paddle::platform::CUDADeviceContext, double>(context)
       .GEMV(true,
             static_cast<int>(in_dims[1]),
             static_cast<int>(in_dims[0]),
diff --git a/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h b/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h
index f84187484b194..56d32a27f11a8 100644
--- a/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h
+++ b/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h
@@ -105,11 +105,9 @@ void MatMul(const Context& dev_ctx,
             DenseTensor* out,
             bool flag = false) {
   dev_ctx.template Alloc<T>(out);
-  auto blas = paddle::operators::math::GetBlas<Context, T>(dev_ctx);
-  auto mat_dim_a =
-      paddle::operators::math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
-  auto mat_dim_b =
-      paddle::operators::math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
+  auto blas = pten::funcs::GetBlas<Context, T>(dev_ctx);
+  auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
+  auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
   if (a.dims().size() == 3 && b.dims().size() <= 2) {
     // the transpose_X must be false, if is true, the transpose cost much time
     if (!trans_a) {
@@ -155,7 +153,7 @@ static DDim ColumnMatrixFromVector(const DDim& y_dim) {
  * If transposed, `H,W` will be swapped.
  */
 static void ReshapeTensorIntoMatrixSequence(
-    DenseTensor* x, const paddle::operators::math::MatDescriptor& descriptor) {
+    DenseTensor* x, const pten::funcs::MatDescriptor& descriptor) {
   int64_t h, w;
   h = descriptor.height_;
   w = descriptor.width_;
@@ -176,10 +174,8 @@ static void ReshapeXYOutIntoMatrixSequence(DenseTensor* x,
                                            bool trans_y) {
   auto x_dim = RowMatrixFromVector(x->dims());
   auto y_dim = ColumnMatrixFromVector(y->dims());
-  auto mat_dim_x =
-      paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x);
-  auto mat_dim_y =
-      paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y);
+  auto mat_dim_x = pten::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x);
+  auto mat_dim_y = pten::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y);
   if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
     out->Resize({mat_dim_x.height_, mat_dim_y.width_});
   } else {
diff --git a/paddle/pten/kernels/impl/matmul_kernel_impl.h b/paddle/pten/kernels/impl/matmul_kernel_impl.h
index addea622f1402..c237271f242d2 100644
--- a/paddle/pten/kernels/impl/matmul_kernel_impl.h
+++ b/paddle/pten/kernels/impl/matmul_kernel_impl.h
@@ -14,7 +14,7 @@ limitations under the License. */

 #pragma once

-#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/pten/kernels/funcs/blas/blas.h"
 #include "paddle/pten/kernels/funcs/complex_functors.h"

 #include "paddle/pten/core/dense_tensor.h"
@@ -102,7 +102,7 @@ void MatMulFunction(const Context& dev_ctx,
   const T* x_data = X.data<T>();
   const T* y_data = Y.data<T>();

-  auto blas = paddle::operators::math::GetBlas<Context, T>(dev_ctx);
+  auto blas = pten::funcs::GetBlas<Context, T>(dev_ctx);

   if (x_ndim == 1 && y_ndim == 1) {
     const int M = X.numel();
diff --git a/paddle/pten/tests/kernels/test_math_function.cc b/paddle/pten/tests/kernels/test_math_function.cc
index 0d53ff6c637ba..ab0ca0e5a1734 100644
--- a/paddle/pten/tests/kernels/test_math_function.cc
+++ b/paddle/pten/tests/kernels/test_math_function.cc
@@ -13,17 +13,16 @@
 // limitations under the License.
#include "gtest/gtest.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace pten { namespace tests { template -inline paddle::operators::math::BlasT -GetBlas(const paddle::platform::CPUDeviceContext& context) { - return paddle::operators::math::GetBlas(context); +inline pten::funcs::BlasT GetBlas( + const paddle::platform::CPUDeviceContext& context) { + return pten::funcs::GetBlas(context); } TEST(math_function, gemm_notrans_cblas) { @@ -98,36 +97,36 @@ void MklSmmCompare(int m, int n, int k) { auto smm = [&, m, n, k, lda, ldb, ldc, alpha, beta]() { const char transa = 'N'; const char transb = 'N'; - paddle::operators::math::CBlas::SMM_GEMM(&transa, - &transb, - &n, - &m, - &k, - &alpha, - B, - &ldb, - A, - &lda, - &beta, - CSMM, - &ldc); + pten::funcs::CBlas::SMM_GEMM(&transa, + &transb, + &n, + &m, + &k, + &alpha, + B, + &ldb, + A, + &lda, + &beta, + CSMM, + &ldc); }; auto mkl = [&, m, n, k, lda, ldb, ldc, alpha, beta]() { - paddle::operators::math::CBlas::GEMM(CblasRowMajor, - CblasNoTrans, - CblasNoTrans, - m, - n, - k, - alpha, - A, - lda, - B, - ldb, - beta, - CMKL, - ldc); + pten::funcs::CBlas::GEMM(CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + m, + n, + k, + alpha, + A, + lda, + B, + ldb, + beta, + CMKL, + ldc); }; smm(); @@ -321,20 +320,20 @@ void GemmWarpTest(int m, int n, int k, T alpha, T beta) { int lda = k; int ldb = n; int ldc = n; - paddle::operators::math::CBlas::GEMM(CblasRowMajor, - CblasNoTrans, - CblasNoTrans, - m, - n, - k, - alpha, - A, - lda, - B, - ldb, - beta, - CMKL, - ldc); + pten::funcs::CBlas::GEMM(CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + m, + n, + k, + alpha, + A, + lda, + B, + ldb, + beta, + CMKL, + ldc); for (int i = 0; i < mat_c_mkl.numel(); ++i) { EXPECT_FLOAT_EQ(CREF[i], CMKL[i]); diff --git a/paddle/pten/tests/kernels/test_math_function.cu b/paddle/pten/tests/kernels/test_math_function.cu index 69ea874408ec0..d26f5583a1470 100644 --- a/paddle/pten/tests/kernels/test_math_function.cu +++ b/paddle/pten/tests/kernels/test_math_function.cu @@ -13,8 +13,8 @@ // limitations under the License. #include "gtest/gtest.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace pten { @@ -37,10 +37,9 @@ void fill_fp16_data(pten::dtype::float16* in_ptr, } template -inline paddle::operators::math::BlasT -GetBlas(const paddle::platform::CUDADeviceContext& context) { - return paddle::operators::math::GetBlas(context); +inline pten::funcs::BlasT GetBlas( + const paddle::platform::CUDADeviceContext& context) { + return pten::funcs::GetBlas(context); } TEST(math_function, notrans_mul_trans_fp32) {