
Commit 2809f5f
rollback
humingqing authored and humingqing committed Jan 12, 2024
1 parent 9da652d commit 2809f5f
Showing 2 changed files with 0 additions and 91 deletions.
52 changes: 0 additions & 52 deletions paddle/phi/kernels/gpu/matmul_grad_kernel.cu
@@ -14,60 +14,8 @@ limitations under the License. */

#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h"
#include "paddle/phi/kernels/matmul_grad_kernel.h"
namespace phi {

template <>
void MatMul<phi::GPUContext, float>(const phi::GPUContext& dev_ctx,
                                    const DenseTensor& a,
                                    bool trans_a,
                                    const DenseTensor& b,
                                    bool trans_b,
                                    DenseTensor* out) {
  dev_ctx.template Alloc<float>(out);
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
  if (a.dims().size() == 2 && b.dims().size() == 2) {
    auto& x_dims = a.dims();  // M * K
    auto& y_dims = b.dims();  // K * N
    const int M = trans_a ? x_dims[1] : x_dims[0];
    const int K = trans_a ? x_dims[0] : x_dims[1];
    const int N = trans_b ? y_dims[0] : y_dims[1];
    phi::funcs::LinearWithCublasLt<float>::Run(
        dev_ctx,
        &a,       // x
        &b,       // y
        out,      // out
        nullptr,  // bias
        nullptr,
        M,        // M bsz_seqf
        N,        // N output_size
        K,        // K input_size
        trans_a,
        trans_b,
        phi::funcs::MatmulFusedType::kMatmul);
    return;
  }
#endif
  auto blas = phi::funcs::GetBlas<phi::GPUContext, float>(dev_ctx);
  auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a);
  auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b);
  if (a.dims().size() == 3 && b.dims().size() <= 2) {
    // transpose_X must be false; when it is true, the transpose costs much
    // time.
    if (!trans_a) {
      mat_dim_a.height_ *= mat_dim_a.batch_size_;
      mat_dim_a.batch_size_ = 0;
    }
  }
  blas.MatMul(a.data<float>(),
              mat_dim_a,
              b.data<float>(),
              mat_dim_b,
              static_cast<float>(1),
              dev_ctx.template Alloc<float>(out),
              static_cast<float>(false));
}

}  // namespace phi

PD_REGISTER_KERNEL(matmul_grad,
                   GPU,
                   ...
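For context, a minimal standalone sketch (an editor's illustration, not part of this commit or of Paddle) of the two shape tricks the deleted specialization above relies on: deriving the GEMM shape (M, N, K) from the transpose flags, and, in the BLAS fallback, folding the batch dimension of an untransposed row-major [B, M, K] operand into its rows so the batched matmul runs as a single 2-D GEMM. GemmShape, InferGemmShape, and FoldBatchIntoRows are hypothetical names, not Paddle symbols.

#include <array>
#include <cassert>
#include <cstdio>

struct GemmShape { int M; int N; int K; };

// Shape of out = op(a) * op(b) for 2-D operands, where op() transposes when
// the corresponding flag is set. Mirrors the M/K/N computation in the
// deleted code.
GemmShape InferGemmShape(std::array<int, 2> a, bool trans_a,
                         std::array<int, 2> b, bool trans_b) {
  GemmShape s{};
  s.M = trans_a ? a[1] : a[0];
  s.K = trans_a ? a[0] : a[1];
  s.N = trans_b ? b[0] : b[1];
  const int k_from_b = trans_b ? b[1] : b[0];
  assert(s.K == k_from_b && "inner dimensions must agree");
  (void)k_from_b;  // unused when NDEBUG disables the assert
  return s;
}

// A row-major [B, M, K] tensor is laid out identically to a [B*M, K] matrix,
// so a batched matmul against one shared weight collapses to a single GEMM.
// This only holds when the 3-D operand is not transposed, which is why the
// deleted code guards the fold with !trans_a.
GemmShape FoldBatchIntoRows(std::array<int, 3> a, std::array<int, 2> b,
                            bool trans_b) {
  return InferGemmShape({a[0] * a[1], a[2]}, /*trans_a=*/false, b, trans_b);
}

int main() {
  GemmShape s = InferGemmShape({4, 8}, false, {8, 16}, false);
  std::printf("M=%d N=%d K=%d\n", s.M, s.N, s.K);  // M=4 N=16 K=8
  s = FoldBatchIntoRows({2, 4, 8}, {8, 16}, false);
  std::printf("M=%d N=%d K=%d\n", s.M, s.N, s.K);  // M=8 N=16 K=8
  return 0;
}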
39 changes: 0 additions & 39 deletions paddle/phi/kernels/gpu/matmul_kernel.cu
@@ -15,48 +15,9 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/impl/matmul_kernel_impl.h"
#include "paddle/phi/kernels/matmul_kernel.h"

namespace phi {

template <>
void MatMulFunction<phi::GPUContext, float>(const phi::GPUContext& dev_ctx,
                                            const DenseTensor& X,
                                            const DenseTensor& Y,
                                            DenseTensor* Out,
                                            bool trans_x,
                                            bool trans_y) {
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
  if (X.dims().size() == 2 && Y.dims().size() == 2) {
    auto& x_dims = X.dims();  // M * K
    auto& y_dims = Y.dims();  // K * N
    const int M = trans_x ? x_dims[1] : x_dims[0];
    const int K = trans_x ? x_dims[0] : x_dims[1];
    const int N = trans_y ? y_dims[0] : y_dims[1];
    phi::funcs::LinearWithCublasLt<float>::Run(
        dev_ctx,
        &X,       // x
        &Y,       // y
        Out,      // out
        nullptr,  // bias
        nullptr,
        M,        // M bsz_seqf
        N,        // N output_size
        K,        // K input_size
        trans_x,
        trans_y,
        phi::funcs::MatmulFusedType::kMatmul);
    return;
  }
#endif
  const std::vector<std::int64_t> x_dims = vectorize(X.dims());
  const std::vector<std::int64_t> y_dims = vectorize(Y.dims());
  MatMulFunction<phi::GPUContext, float>(
      dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, false);
}

}  // namespace phi

PD_REGISTER_KERNEL(matmul,
                   GPU,
                   ALL_LAYOUT,
                   ...
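This second deleted specialization repeats the dispatch pattern from the first file: a compile-time gate on the CUDA toolkit version, a runtime gate on tensor rank, an early return into the fused cuBLASLt path, and a fall-through to the generic implementation. The sketch below (an editor's illustration, not Paddle code) reproduces only that control flow; WITH_FUSED_GEMM is an assumed stand-in for the real PADDLE_WITH_CUDA && CUDA_VERSION >= 11060 guard, and the function names are placeholders, not Paddle symbols.

#include <cstdio>

// Compile with -DWITH_FUSED_GEMM to enable the fast path; without it the
// preprocessor removes the fused branch entirely, as the original #if does
// on toolkits older than CUDA 11.6.
bool FusedPathApplies(int rank_x, int rank_y) {
#ifdef WITH_FUSED_GEMM
  // The cuBLASLt path handled only plain 2-D x 2-D GEMMs.
  return rank_x == 2 && rank_y == 2;
#else
  (void)rank_x;
  (void)rank_y;
  return false;
#endif
}

void MatMulDispatch(int rank_x, int rank_y) {
  if (FusedPathApplies(rank_x, rank_y)) {
    std::puts("fused cuBLASLt-style path");
    return;  // early return, as in the deleted code
  }
  std::puts("generic blas fallback");  // handles every other shape
}

int main() {
  MatMulDispatch(2, 2);  // fused path when built with -DWITH_FUSED_GEMM
  MatMulDispatch(3, 2);  // always falls back to the generic path
  return 0;
}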
