diff --git a/paddle/phi/kernels/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/gpu/matmul_grad_kernel.cu
index 1591d86d8cf59..b6c13360cd404 100644
--- a/paddle/phi/kernels/gpu/matmul_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/matmul_grad_kernel.cu
@@ -14,60 +14,8 @@ limitations under the License. */
 
 #include "paddle/phi/common/complex.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
 #include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h"
 #include "paddle/phi/kernels/matmul_grad_kernel.h"
 
-namespace phi {
-template <>
-void MatMul(const phi::GPUContext& dev_ctx,
-            const DenseTensor& a,
-            bool trans_a,
-            const DenseTensor& b,
-            bool trans_b,
-            DenseTensor* out) {
-  dev_ctx.template Alloc(out);
-#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
-  if (a.dims().size() == 2 && b.dims().size() == 2) {
-    auto& x_dims = a.dims();  // M * K
-    auto& y_dims = b.dims();  // K * N
-    const int M = trans_a ? x_dims[1] : x_dims[0];
-    const int K = trans_a ? x_dims[0] : x_dims[1];
-    const int N = trans_b ? y_dims[0] : y_dims[1];
-    phi::funcs::LinearWithCublasLt::Run(
-        dev_ctx,
-        &a,       // x
-        &b,       // y
-        out,      // out
-        nullptr,  // bias
-        nullptr,
-        M,  // M   bsz_seqf
-        N,  // N   output_size
-        K,  // K   input_size
-        trans_a,
-        trans_b,
-        phi::funcs::MatmulFusedType::kMatmul);
-    return;
-  }
-#endif
-  auto blas = phi::funcs::GetBlas(dev_ctx);
-  auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
-  auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
-  if (a.dims().size() == 3 && b.dims().size() <= 2) {
-    // the transpose_X must be false, if is true, the transpose cost much time
-    if (!trans_a) {
-      mat_dim_a.height_ *= mat_dim_a.batch_size_;
-      mat_dim_a.batch_size_ = 0;
-    }
-  }
-  blas.MatMul(a.data(),
-              mat_dim_a,
-              b.data(),
-              mat_dim_b,
-              static_cast(1),
-              dev_ctx.template Alloc(out),
-              static_cast(false));
-}
-}  // namespace phi
 PD_REGISTER_KERNEL(matmul_grad,
                    GPU,
diff --git a/paddle/phi/kernels/gpu/matmul_kernel.cu b/paddle/phi/kernels/gpu/matmul_kernel.cu
index e96a76b1d1e7b..32d70ae0763f0 100644
--- a/paddle/phi/kernels/gpu/matmul_kernel.cu
+++ b/paddle/phi/kernels/gpu/matmul_kernel.cu
@@ -15,48 +15,9 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/complex.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
 #include "paddle/phi/kernels/impl/matmul_kernel_impl.h"
 #include "paddle/phi/kernels/matmul_kernel.h"
 
-namespace phi {
-template <>
-void MatMulFunction(const phi::GPUContext& dev_ctx,
-                    const DenseTensor& X,
-                    const DenseTensor& Y,
-                    DenseTensor* Out,
-                    bool trans_x,
-                    bool trans_y) {
-#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
-  if (X.dims().size() == 2 && Y.dims().size() == 2) {
-    auto& x_dims = X.dims();  // M * K
-    auto& y_dims = Y.dims();  // K * N
-    const int M = trans_x ? x_dims[1] : x_dims[0];
-    const int K = trans_x ? x_dims[0] : x_dims[1];
-    const int N = trans_y ? y_dims[0] : y_dims[1];
-    phi::funcs::LinearWithCublasLt::Run(
-        dev_ctx,
-        &X,       // x
-        &Y,       // y
-        Out,      // out
-        nullptr,  // bias
-        nullptr,
-        M,  // M   bsz_seqf
-        N,  // N   output_size
-        K,  // K   input_size
-        trans_x,
-        trans_y,
-        phi::funcs::MatmulFusedType::kMatmul);
-    return;
-  }
-#endif
-  const std::vector x_dims = vectorize(X.dims());
-  const std::vector y_dims = vectorize(Y.dims());
-  MatMulFunction(
-      dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, false);
-}
-}  // namespace phi
-
 PD_REGISTER_KERNEL(matmul,
                    GPU,
                    ALL_LAYOUT,
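For context, the deleted specializations only took the cublasLt route when both operands were rank-2: they derived M, N, and K from the operand shapes and the transpose flags, then called `LinearWithCublasLt::Run` with `MatmulFusedType::kMatmul`, and fell through to the generic blas path for every other shape. The sketch below restates that shape logic in isolation as plain C++; it is a minimal illustration, and the `MatmulDims` struct and `InferMatmulDims` helper are hypothetical names used only for this note, not identifiers from the Paddle codebase.

```cpp
// Standalone illustration of the M/N/K derivation used by the removed 2-D
// fast path. Hypothetical helper, not part of Paddle.
#include <array>
#include <cassert>
#include <cstdint>
#include <iostream>

struct MatmulDims {
  int64_t M;  // rows of op(x)
  int64_t N;  // cols of op(y)
  int64_t K;  // shared contraction dimension
};

// x stores M x K (or K x M when trans_x); y stores K x N (or N x K when trans_y).
MatmulDims InferMatmulDims(const std::array<int64_t, 2>& x_dims,
                           const std::array<int64_t, 2>& y_dims,
                           bool trans_x,
                           bool trans_y) {
  MatmulDims d{};
  d.M = trans_x ? x_dims[1] : x_dims[0];
  d.K = trans_x ? x_dims[0] : x_dims[1];
  d.N = trans_y ? y_dims[0] : y_dims[1];
  // Both operands must agree on K for the matmul to be well formed.
  assert(d.K == (trans_y ? y_dims[1] : y_dims[0]));
  return d;
}

int main() {
  // out = x(64x32) * y(32x128) with no transposes -> M=64, N=128, K=32.
  const MatmulDims d = InferMatmulDims({64, 32}, {32, 128}, false, false);
  std::cout << "M=" << d.M << " N=" << d.N << " K=" << d.K << "\n";
  return 0;
}
```

After this change, the rank-2 case is no longer special-cased in these kernels: it goes through the same generic `MatMul` / `MatMulFunction` blas path as every other shape.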