From 16dab98743d5ecf4df3f6d159e9dc4735e0baf23 Mon Sep 17 00:00:00 2001 From: humingqing Date: Mon, 18 Dec 2023 12:08:27 +0800 Subject: [PATCH 1/9] fix tensor core --- paddle/phi/backends/gpu/gpu_context.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index c3f038bb1b953..3066368a98e4d 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ b/paddle/phi/backends/gpu/gpu_context.cc @@ -66,6 +66,7 @@ PADDLE_DEFINE_EXPORTED_bool(enable_cublas_tf32_op_math, true, "enable tf32 for cublas."); #endif +DECLARE_bool(enable_cublas_tensor_op_math); namespace phi { @@ -382,7 +383,7 @@ struct GPUContext::Impl { } #ifdef PADDLE_WITH_CUDA #if CUDA_VERSION >= 9000 - if (!blas_tensor_core_handle_) { + if (FLAGS_enable_cublas_tensor_op_math && !blas_tensor_core_handle_) { if (!blas_tensor_core_handle_creator_) { phi::InitBlasHandle(&blas_tensor_core_handle_, stream()); } else { @@ -393,7 +394,7 @@ struct GPUContext::Impl { } #endif #if CUDA_VERSION >= 11000 - if (!blas_tf32_tensor_core_handle_) { + if (FLAGS_enable_cublas_tf32_op_math && !blas_tf32_tensor_core_handle_) { if (!blas_tf32_tensor_core_handle_creator_) { phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream()); } else { @@ -574,8 +575,7 @@ struct GPUContext::Impl { } inline void CublasCall(const std::function& callback) { - if (FLAGS_enable_cublas_tf32_op_math && - blas_tf32_tensor_core_handle_ != nullptr) { + if (blas_tf32_tensor_core_handle_ != nullptr) { std::lock_guard guard(blas_tf32_mtx_); callback(blas_tf32_tensor_core_handle_); } else { From d841ecc15305a883f429fb2bab9a772256ef6d4e Mon Sep 17 00:00:00 2001 From: humingqing Date: Mon, 18 Dec 2023 19:10:44 +0800 Subject: [PATCH 2/9] fix normal mode error --- paddle/fluid/framework/boxps_worker.cc | 8 ++++---- paddle/phi/backends/gpu/gpu_resources.cc | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/boxps_worker.cc b/paddle/fluid/framework/boxps_worker.cc index d4432cc162a77..4875cacf9ce03 100644 --- a/paddle/fluid/framework/boxps_worker.cc +++ b/paddle/fluid/framework/boxps_worker.cc @@ -968,10 +968,6 @@ void BoxPSWorker::CreateThreadScopeForNorm(const ProgramDesc& program) { auto dim = root_tensor->dims(); param_sync_.share(gpu_tensor, len).Resize(dim); skip_vars_.push_back(name); - // add copy back to root scope - if (device_id_ == 0) { - need_copy_vars_.push_back(name); - } } } // data norm copy and learning rate @@ -985,6 +981,10 @@ void BoxPSWorker::CreateThreadScopeForNorm(const ProgramDesc& program) { place_, static_cast(gpu_tensor)); ++copy_persist_num; + // add copy back to root scope + if (device_id_ == 0) { + need_copy_vars_.push_back(name); + } } } else { auto* ptr = thread_scope_->Var(name); diff --git a/paddle/phi/backends/gpu/gpu_resources.cc b/paddle/phi/backends/gpu/gpu_resources.cc index 0257139914384..22b6d497851a6 100644 --- a/paddle/phi/backends/gpu/gpu_resources.cc +++ b/paddle/phi/backends/gpu/gpu_resources.cc @@ -67,7 +67,9 @@ void InitGpuProperties(Place place, << "." << (*driver_version % 100) / 10 << ", Runtime API Version: " << *runtime_version / 1000 << "." 
- << (*runtime_version % 100) / 10; + << (*runtime_version % 100) / 10 + << ", build date " + << __DATE__ << " time " << __TIME__; #ifdef PADDLE_WITH_HIP size_t miopen_major, miopen_minor, miopen_patch; PADDLE_ENFORCE_GPU_SUCCESS( From fad65fdb6da506f404d71cac3b2c602fa82f8293 Mon Sep 17 00:00:00 2001 From: humingqing Date: Mon, 18 Dec 2023 19:14:38 +0800 Subject: [PATCH 3/9] add BatchedGEMM optimize --- paddle/phi/backends/gpu/gpu_resources.cc | 5 +- paddle/phi/kernels/funcs/blas/blas_impl.cu.h | 484 ++++++++++++++++-- paddle/phi/kernels/funcs/blas/blas_impl.hip.h | 215 ++++++++ 3 files changed, 656 insertions(+), 48 deletions(-) diff --git a/paddle/phi/backends/gpu/gpu_resources.cc b/paddle/phi/backends/gpu/gpu_resources.cc index 22b6d497851a6..22d2d51ac367b 100644 --- a/paddle/phi/backends/gpu/gpu_resources.cc +++ b/paddle/phi/backends/gpu/gpu_resources.cc @@ -67,9 +67,8 @@ void InitGpuProperties(Place place, << "." << (*driver_version % 100) / 10 << ", Runtime API Version: " << *runtime_version / 1000 << "." - << (*runtime_version % 100) / 10 - << ", build date " - << __DATE__ << " time " << __TIME__; + << (*runtime_version % 100) / 10 << ", build date " + << __DATE__ << " time " << __TIME__; #ifdef PADDLE_WITH_HIP size_t miopen_major, miopen_minor, miopen_patch; PADDLE_ENFORCE_GPU_SUCCESS( diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index 509d824ca0553..6d54e466f3240 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -14,6 +14,9 @@ #pragma once +#if defined(__NVCC__) +#include +#endif #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/phi/backends/gpu/gpu_context.h" @@ -55,6 +58,17 @@ struct CUBlas { PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSgemv(args...)); } + template + static void GEMM_BATCH(ARGS... args) { +#if CUDA_VERSION >= 8000 + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasSgemmBatched(args...)); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "SgemmBatched is not supported on cuda <= 7.5")); +#endif + } + template static void GEMM_STRIDED_BATCH(ARGS... args) { #if CUDA_VERSION >= 8000 @@ -181,6 +195,17 @@ struct CUBlas { PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDgemv(args...)); } + template + static void GEMM_BATCH(ARGS... args) { +#if CUDA_VERSION >= 8000 + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasDgemmBatched(args...)); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "DgemmBatched is not supported on cuda <= 7.5")); +#endif + } + template static void GEMM_STRIDED_BATCH(ARGS... args) { #if CUDA_VERSION >= 8000 @@ -235,40 +260,69 @@ struct CUBlas { }; template <> struct CUBlas { - //int8_t call func: - //CUBlas::GEMM_EX( - // &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_8I, ldb, A, - // CUDA_R_8I, lda, &h_beta, C, CUDA_R_32F, N, CUDA_R_32F); + // int8_t call func: + // CUBlas::GEMM_EX( + // &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_8I, ldb, A, + // CUDA_R_8I, lda, &h_beta, C, CUDA_R_32F, N, CUDA_R_32F); // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. 
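  // In the call sketched above, the B and A operands are passed in swapped
  // order with N as the output leading dimension, so that a row-major
  // C = A * B maps onto cuBLAS's column-major (Fortran) convention, the same
  // trick used by the other GEMM wrappers in this file.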
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode template static void GEMM_EX(phi::GPUContext *dev_ctx, - cublasOperation_t transa, cublasOperation_t transb, int m, - int n, int k, const void *alpha, const void *A, - cudaDataType_t Atype, int lda, const void *B, - cudaDataType_t Btype, int ldb, const void *beta, void *C, - cudaDataType_t Ctype, int ldc, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + cudaDataType_t Atype, + int lda, + const void *B, + cudaDataType_t Btype, + int ldb, + const void *beta, + void *C, + cudaDataType_t Ctype, + int ldc, cudaDataType_t computeType) { #if CUDA_VERSION >= 8000 cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; #if CUDA_VERSION >= 9000 bool use_tensor_op_math = dev_ctx->tensor_core_available(); if (use_tensor_op_math) { - //algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - //VLOG(5) << "2. CUBlas int8_t, algo is CUBLAS_GEMM_DFALT_TENSOR_OP."; - algo = CUBLAS_GEMM_DFALT; // only for int8 gemm + // algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + // VLOG(5) << "2. CUBlas int8_t, algo is CUBLAS_GEMM_DFALT_TENSOR_OP."; + algo = CUBLAS_GEMM_DFALT; // only for int8 gemm } VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - VLOG(5) << "3. use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); + VLOG(5) << "3. use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); algo = CUBLAS_GEMM_DFALT; #endif // CUDA_VERSION >= 9000 dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasGemmEx( - handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, - beta, C, Ctype, ldc, computeType, algo)); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + B, + Btype, + ldb, + beta, + C, + Ctype, + ldc, + computeType, + algo)); }); #else PADDLE_THROW(platform::errors::Unimplemented( @@ -311,6 +365,69 @@ struct CUBlas { ldc)); } +#if defined(__NVCC__) + static void GEMM_BATCH(phi::GPUContext *dev_ctx, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, + const float16 **A, + cudaDataType_t Atype, + int lda, + const float16 **B, + cudaDataType_t Btype, + int ldb, + const float *beta, + float16 **C, + cudaDataType_t Ctype, + int ldc, + int batchCount, + cudaDataType_t computeType) { +#if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; +#if CUDA_VERSION >= 9000 + bool use_tensor_op_math = dev_ctx->tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? 
"True" : "False"); +#endif // CUDA_VERSION >= 9000 + thrust::device_vector A_ptr(A, A + batchCount); + thrust::device_vector B_ptr(B, B + batchCount); + thrust::device_vector C_ptr(C, C + batchCount); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmBatchedEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A_ptr.data().get(), + Atype, + lda, + B_ptr.data().get(), + Btype, + ldb, + beta, + C_ptr.data().get(), + Ctype, + ldc, + batchCount, + computeType, + algo)); + }); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "cublasGemmBatchedEx is not supported on cuda <= 7.5")); +#endif + } +#endif + static void GEMM_STRIDED_BATCH(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, @@ -961,20 +1078,20 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 8000 } -//int8_t matmul +// int8_t matmul template <> template <> -inline void Blas::GEMM( - CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int M, - int N, - int K, - float alpha, - const int8_t *A, - const int8_t *B, - float beta, - float *C, int flag) const { +inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + float alpha, + const int8_t *A, + const int8_t *B, + float beta, + float *C, + int flag) const { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -985,7 +1102,8 @@ inline void Blas::GEMM( (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE_GE( - context_.GetComputeCapability(), 53, + context_.GetComputeCapability(), + 53, phi::errors::InvalidArgument( "cublas int8_t gemm requires GPU compute capability >= 53," "but received %d", @@ -1001,17 +1119,32 @@ inline void Blas::GEMM( // using tensor cores in volta GPUs. auto &cuda_ctx = const_cast(context_); VLOG(3) << "1. call int8_t GEMM_EX."; - CUBlas::GEMM_EX( - &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_8I, ldb, A, - CUDA_R_8I, lda, &h_beta, C, CUDA_R_32F, N, CUDA_R_32F); + CUBlas::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_8I, + ldb, + A, + CUDA_R_8I, + lda, + &h_beta, + C, + CUDA_R_32F, + N, + CUDA_R_32F); #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - //context_.CublasCall([&](cublasHandle_t handle) { - // CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, - // &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, - // N); - //}); + // context_.CublasCall([&](cublasHandle_t handle) { + // CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, + // &h_alpha, h_B, ldb, h_A, lda, &h_beta, + // h_C, N); + // }); #endif // CUDA_VERSION >= 8000 } @@ -1428,6 +1561,75 @@ inline void Blas::GEMM(bool transA, }); } +template <> +template <> +inline void Blas::GEMM(bool transA, + bool transB, + int M, + int N, + int K, + phi::dtype::bfloat16 alpha, + const phi::dtype::bfloat16 *A, + int lda, + const phi::dtype::bfloat16 *B, + int ldb, + phi::dtype::bfloat16 beta, + phi::dtype::bfloat16 *C, + int ldc) const { +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cuTransB = transB ? 
CUBLAS_OP_T : CUBLAS_OP_N; + + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 80, + phi::errors::InvalidArgument( + "cublas bf16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16BF, + ldb, + A, + CUDA_R_16BF, + lda, + &h_beta, + C, + CUDA_R_16BF, + ldc, + CUDA_R_32F, + algo)); + }); +#else + // raise error + PADDLE_THROW(phi::errors::Unimplemented( + "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); + +#endif // CUDA_VERSION >= 11000 +} + template <> template void Blas::AXPY(int n, T alpha, const T *x, T *y) const { @@ -1548,7 +1750,11 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, << FLAGS_gemm_use_half_precision_compute_type; auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; - cudaDataType_t compute_type = CUDA_R_32F; +#if CUDA_VERSION >= 11000 + auto compute_type = CUBLAS_COMPUTE_32F; +#else + auto compute_type = CUDA_R_32F; +#endif float h_alpha = static_cast(alpha); float h_beta = static_cast(beta); @@ -1559,7 +1765,11 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, std::is_same::value) { a = static_cast(&alpha); b = static_cast(&beta); +#if CUDA_VERSION >= 11000 + compute_type = CUBLAS_COMPUTE_16F; +#else compute_type = CUDA_R_16F; +#endif } context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { @@ -1708,6 +1918,97 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, } } +#if defined(__NVCC__) +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + double alpha, + const double **A, + const double **B, + double beta, + double **C, + int batchCount) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + thrust::device_vector A_ptr(A, A + batchCount); + thrust::device_vector B_ptr(B, B + batchCount); + thrust::device_vector C_ptr(C, C + batchCount); + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_BATCH(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B_ptr.data().get(), + ldb, + A_ptr.data().get(), + lda, + &beta, + C_ptr.data().get(), + ldc, + batchCount); + }); +} + +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + float alpha, + const float **A, + const float **B, + float beta, + float **C, + int batchCount) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + thrust::device_vector A_ptr(A, A + batchCount); + thrust::device_vector B_ptr(B, B + batchCount); + thrust::device_vector C_ptr(C, C + batchCount); + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_BATCH(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B_ptr.data().get(), + ldb, + A_ptr.data().get(), + lda, + &beta, + C_ptr.data().get(), + ldc, + batchCount); + }); +} + template <> template <> inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, @@ -1721,10 +2022,45 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, phi::dtype::float16 beta, phi::dtype::float16 **C, int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM( - transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); - } + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + phi::errors::InvalidArgument( + "cublas fp16 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + float f_alpha = static_cast(alpha); + float f_beta = static_cast(beta); + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_BATCH(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &f_alpha, + B, + CUDA_R_16F, + ldb, + A, + CUDA_R_16F, + lda, + &f_beta, + C, + CUDA_R_16F, + ldc, + batchCount, + CUDA_R_32F); } template <> @@ -1740,11 +2076,69 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, phi::dtype::bfloat16 beta, phi::dtype::bfloat16 **C, int batchCount) const { - for (int k = 0; k < batchCount; ++k) { - this->template GEMM( - transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]); +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 80, + phi::errors::InvalidArgument( + "cublas bf16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float f_alpha = static_cast(alpha); + float f_beta = static_cast(beta); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? 
"True" : "False"); + + thrust::device_vector A_ptr(A, A + batchCount); + thrust::device_vector B_ptr(B, B + batchCount); + thrust::device_vector C_ptr(C, C + batchCount); + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cublasGemmBatchedEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + &f_alpha, + B_ptr.data().get(), + CUDA_R_16BF, + ldb, + A_ptr.data().get(), + CUDA_R_16BF, + lda, + &f_beta, + C_ptr.data().get(), + CUDA_R_16BF, + ldc, + batchCount, + CUDA_R_32F, + algo)); + }); +#else + // raise error + PADDLE_THROW(phi::errors::Unimplemented( + "cublasGemmBatchedEx with bfloat16 is not supported on cuda <= 11")); + +#endif // CUDA_VERSION >= 11000 } +#endif template <> template diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.hip.h b/paddle/phi/kernels/funcs/blas/blas_impl.hip.h index e322fba39a481..60d0b4ff3c0ef 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.hip.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.hip.h @@ -999,6 +999,68 @@ inline void Blas::GEMM(bool transA, }); } +template <> +template <> +inline void Blas::GEMM(bool transA, + bool transB, + int M, + int N, + int K, + phi::dtype::bfloat16 alpha, + const phi::dtype::bfloat16 *A, + int lda, + const phi::dtype::bfloat16 *B, + int ldb, + phi::dtype::bfloat16 beta, + phi::dtype::bfloat16 *C, + int ldc) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + rocblas_operation cuTransA = + transA ? rocblas_operation_none : rocblas_operation_transpose; + rocblas_operation cuTransB = + transB ? rocblas_operation_none : rocblas_operation_transpose; + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 80, + phi::errors::InvalidArgument( + "rocblas bf16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + + context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_gemm_ex(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + rocblas_datatype_bf16_r, + ldb, + A, + rocblas_datatype_bf16_r, + lda, + &h_beta, + C, + rocblas_datatype_bf16_r, + ldc, + C, + rocblas_datatype_bf16_r, + ldc, + rocblas_datatype_f32_r, + algo, + 0, + 0)); + }); +} + template <> template void Blas::AXPY(int n, T alpha, const T *x, T *y) const { @@ -1128,6 +1190,159 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, }); } +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + float16 alpha, + const float16 *A, + const float16 *B, + float16 beta, + float16 *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? 
rocblas_operation_none + : rocblas_operation_transpose; + const int64_t strideC = M * N; + context_.CublasCall([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_hgemm_strided_batched( + handle, + cuTransB, + cuTransA, + N, + M, + K, + reinterpret_cast(&alpha), + reinterpret_cast(B), + ldb, + strideB, + reinterpret_cast(A), + lda, + strideA, + reinterpret_cast(&beta), + reinterpret_cast(C), + ldc, + strideC, + batchCount)); + }); +} + +// note(wangran16): unknown bug. parameters dislocation when calling +// GEMM_STRIDED_BATCH and GEMM_STRIDED_BATCH +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + float alpha, + const float *A, + const float *B, + float beta, + float *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + const int64_t strideC = M * N; + context_.CublasCall([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_sgemm_strided_batched(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + strideB, + A, + lda, + strideA, + &beta, + C, + ldc, + strideC, + batchCount)); + }); +} + +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + double alpha, + const double *A, + const double *B, + double beta, + double *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? 
rocblas_operation_none + : rocblas_operation_transpose; + const int64_t strideC = M * N; + context_.CublasCall([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::rocblas_dgemm_strided_batched(handle, + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + strideB, + A, + lda, + strideA, + &beta, + C, + ldc, + strideC, + batchCount)); + }); +} + template <> template <> inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, From d7bd48e7eb31ce7faa5b193a77de1718566c2cad Mon Sep 17 00:00:00 2001 From: humingqing Date: Mon, 18 Dec 2023 19:15:48 +0800 Subject: [PATCH 4/9] format --- paddle/fluid/framework/boxps_worker.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/boxps_worker.cc b/paddle/fluid/framework/boxps_worker.cc index 4875cacf9ce03..7643758ec641e 100644 --- a/paddle/fluid/framework/boxps_worker.cc +++ b/paddle/fluid/framework/boxps_worker.cc @@ -983,8 +983,8 @@ void BoxPSWorker::CreateThreadScopeForNorm(const ProgramDesc& program) { ++copy_persist_num; // add copy back to root scope if (device_id_ == 0) { - need_copy_vars_.push_back(name); - } + need_copy_vars_.push_back(name); + } } } else { auto* ptr = thread_scope_->Var(name); From c297081315d9e0249a3f822acd530d8cc8cec559 Mon Sep 17 00:00:00 2001 From: humingqing Date: Tue, 19 Dec 2023 10:52:43 +0800 Subject: [PATCH 5/9] add paddlebox git version --- cmake/version.cmake | 24 +++++++++++++++++++++++ paddle/phi/backends/gpu/gpu_resources.cc | 25 +++++++++++++----------- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/cmake/version.cmake b/cmake/version.cmake index 83bd3f1b1bc4a..88e767b968bd9 100644 --- a/cmake/version.cmake +++ b/cmake/version.cmake @@ -71,3 +71,27 @@ math(EXPR PADDLE_VERSION_INTEGER "${PADDLE_MAJOR_VER} * 1000000 add_definitions(-DPADDLE_VERSION=${PADDLE_VERSION}) add_definitions(-DPADDLE_VERSION_INTEGER=${PADDLE_VERSION_INTEGER}) message(STATUS "Paddle version is ${PADDLE_VERSION}") + +#add git version +set(COMMIT_HASH "") +set(BRANCH_NAME "") +find_package(Git QUIET) +if(GIT_FOUND) +execute_process( + COMMAND ${GIT_EXECUTABLE} log -1 --pretty=format:%H + OUTPUT_VARIABLE COMMIT_HASH + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) +execute_process( + COMMAND ${GIT_EXECUTABLE} symbolic-ref --short -q HEAD + OUTPUT_VARIABLE BRANCH_NAME + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) +endif() +message(STATUS "Git version is ${BRANCH_NAME}:${COMMIT_HASH}") +add_definitions(-DPADDLE_BRANCH_NAME="${BRANCH_NAME}") +add_definitions(-DPADDLE_COMMIT_HASH="${COMMIT_HASH}") diff --git a/paddle/phi/backends/gpu/gpu_resources.cc b/paddle/phi/backends/gpu/gpu_resources.cc index 22d2d51ac367b..2a8dbb85e8035 100644 --- a/paddle/phi/backends/gpu/gpu_resources.cc +++ b/paddle/phi/backends/gpu/gpu_resources.cc @@ -58,17 +58,20 @@ void InitGpuProperties(Place place, *runtime_version = backends::gpu::GetGPURuntimeVersion(place.GetDeviceId()); // TODO(wilber): glog may be replaced in the future? - LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " - << static_cast(place.device) - << ", GPU Compute Capability: " - << *compute_capability / 10 << "." - << *compute_capability % 10 - << ", Driver API Version: " << *driver_version / 1000 - << "." << (*driver_version % 100) / 10 - << ", Runtime API Version: " - << *runtime_version / 1000 << "." 
- << (*runtime_version % 100) / 10 << ", build date " - << __DATE__ << " time " << __TIME__; + LOG_FIRST_N(WARNING, 1) + << "Please NOTE: device: " << static_cast(place.device) + << ", GPU Compute Capability: " << *compute_capability / 10 << "." + << *compute_capability % 10 + << ", Driver API Version: " << *driver_version / 1000 << "." + << (*driver_version % 100) / 10 + << ", Runtime API Version: " << *runtime_version / 1000 << "." + << (*runtime_version % 100) / 10 << ", Build Date " +#ifdef PADDLE_BRANCH_NAME + << __DATE__ << " Time " << __TIME__ + << ", Git Version: " PADDLE_BRANCH_NAME ":" PADDLE_COMMIT_HASH; +#else + << __DATE__ << " Time " << __TIME__; +#endif #ifdef PADDLE_WITH_HIP size_t miopen_major, miopen_minor, miopen_patch; PADDLE_ENFORCE_GPU_SUCCESS( From 723a4661dafdae3b2a44226b95c2be1cfeeebd1f Mon Sep 17 00:00:00 2001 From: humingqing Date: Wed, 20 Dec 2023 12:09:05 +0800 Subject: [PATCH 6/9] fix auc monitor var gc bug --- paddle/fluid/framework/boxps_worker.cc | 2 ++ paddle/fluid/framework/fleet/box_wrapper.h | 20 ++++++++++++----- paddle/fluid/framework/operator.cc | 25 ++++++++++++++-------- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/boxps_worker.cc b/paddle/fluid/framework/boxps_worker.cc index 7643758ec641e..2958afeae7bb5 100644 --- a/paddle/fluid/framework/boxps_worker.cc +++ b/paddle/fluid/framework/boxps_worker.cc @@ -984,6 +984,7 @@ void BoxPSWorker::CreateThreadScopeForNorm(const ProgramDesc& program) { // add copy back to root scope if (device_id_ == 0) { need_copy_vars_.push_back(name); + skip_vars_.push_back(name); } } } else { @@ -1104,6 +1105,7 @@ void BoxPSWorker::CreateThreadScopeForSharding(const ProgramDesc& program) { // device 0 need sync datanorm and learning rate to root scope if (device_id_ == 0) { need_copy_vars_.push_back(name); + skip_vars_.push_back(name); } } } else { diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index d094729e25bfb..29cfb18fdc4a5 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -55,7 +55,7 @@ DECLARE_int32(padbox_dataset_shuffle_thread_num); namespace paddle { namespace framework { -extern int make_day_id(const int &y, const int &m, const int &d); +extern int make_day_id(const int& y, const int& m, const int& d); #ifdef PADDLE_WITH_BOX_PS #define MAX_GPU_NUM 16 @@ -322,6 +322,11 @@ class MetricMsg { platform::errors::NotFound("Error: var %s is not found in scope.", varname.c_str())); auto& gpu_tensor = var->Get(); + PADDLE_ENFORCE_EQ( + gpu_tensor.IsInitialized(), + true, + platform::errors::InvalidArgument( + "Error: monitor var `%s` uninitialized Tensor.", varname.c_str())); *data = gpu_tensor.data(); *len = gpu_tensor.numel(); } @@ -335,6 +340,11 @@ class MetricMsg { platform::errors::NotFound("Error: var %s is not found in scope.", varname.c_str())); auto& gpu_tensor = var->Get(); + PADDLE_ENFORCE_EQ( + gpu_tensor.IsInitialized(), + true, + platform::errors::InvalidArgument( + "Error: monitor var `%s` uninitialized Tensor.", varname.c_str())); auto* gpu_data = gpu_tensor.data(); auto len = gpu_tensor.numel(); data->resize(len); @@ -424,7 +434,7 @@ class BoxWrapper { } int GetMpiSize() { return boxps::MPICluster::Ins().size(); } int GetMpiRank() { return boxps::MPICluster::Ins().rank(); } - int GetNCCLRankId(const int &device_id) { + int GetNCCLRankId(const int& device_id) { return (GetMpiRank() * gpu_num_ + device_id); } int GetGpuNum() { return gpu_num_; } @@ -832,7 
+842,7 @@ class BoxWrapper { for (auto& name : var_names) { auto it = std::find(skip_gc_vars_.begin(), skip_gc_vars_.end(), name); if (it != skip_gc_vars_.end()) { - return; + continue; } skip_gc_vars_.push_back(name); } @@ -1026,8 +1036,8 @@ class BoxHelper { void SetDate(int year, int month, int day) { day_id_ = make_day_id(year, month, day); - VLOG(0) << "BoxHelpler set year=" << year << ", month=" - << month << ", day=" << day << ", day id=" << day_id_; + VLOG(0) << "BoxHelpler set year=" << year << ", month=" << month + << ", day=" << day << ", day id=" << day_id_; } void BeginPass() { #ifdef PADDLE_WITH_BOX_PS diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index b5cab12122da0..99fe1ea3bf5ff 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -60,6 +60,9 @@ DECLARE_bool(check_nan_inf); DECLARE_bool(enable_unused_var_check); DECLARE_bool(run_kp_kernel); DECLARE_bool(enable_host_event_recorder_hook); +PADDLE_DEFINE_EXPORTED_bool(enable_check_input_var, + false, + "enable check input var"); namespace paddle { namespace framework { @@ -1773,7 +1776,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, os << "\n"; printf("%s", os.str().c_str()); } - PADDLE_ENFORCE(false, "ERROR: check INF and NAN: %s", + PADDLE_ENFORCE(false, + "ERROR: check INF and NAN: %s", DebugStringEx(&exec_scope).c_str()); } #else @@ -1938,7 +1942,8 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { << ", fallbacking to CPU one!"; expected_kernel_key.place_ = platform::CPUPlace(); kernel_iter = kernels.find(expected_kernel_key); - } else if (!paddle::platform::is_xpu_support_op(type_, expected_kernel_key)) { + } else if (!paddle::platform::is_xpu_support_op(type_, + expected_kernel_key)) { VLOG(3) << "fluid XPU not support kernel: " << type_ << ", expected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; @@ -2419,13 +2424,15 @@ void OperatorWithKernel::ParseInputDataType( } } if (t != nullptr) { -// PADDLE_ENFORCE_EQ( -// t->IsInitialized(), -// true, -// platform::errors::InvalidArgument("The %s Op's Input Variable `%s` " -// "contains uninitialized Tensor.", -// Type(), -// name)); + if (FLAGS_enable_check_input_var) { + PADDLE_ENFORCE_EQ( + t->IsInitialized(), + true, + platform::errors::InvalidArgument("The %s Op's Input Variable `%s` " + "contains uninitialized Tensor.", + Type(), + name)); + } *data_type = paddle::framework::TransToProtoVarType(t->dtype()); } } From 25f741e0fa1a0421c0865370ad5e45b6ed5dad75 Mon Sep 17 00:00:00 2001 From: humingqing Date: Tue, 2 Jan 2024 10:14:56 +0800 Subject: [PATCH 7/9] support h800 --- cmake/cuda.cmake | 56 +- cmake/external/gloo.cmake | 9 +- cmake/external/warpctc.cmake | 6 +- paddle/fluid/operators/conv_base_helper.h | 52 +- paddle/fluid/operators/conv_cudnn_helper.h | 293 ++--- .../fluid/operators/cross_norm_hadamard_op.cu | 1 - paddle/fluid/operators/scaled_fc_op.cu | 1 - paddle/fluid/operators/scaled_int8fc_op.cu | 1 - .../platform/device/gpu/cuda/cuda_helper.h | 6 +- .../platform/device/gpu/cuda/cuda_profiler.cc | 20 +- .../platform/device/gpu/cuda/cuda_profiler.h | 18 +- paddle/fluid/platform/dynload/cublasLt.h | 30 +- paddle/fluid/platform/dynload/nvtx.h | 12 +- paddle/phi/backends/dynload/cublasLt.h | 29 + paddle/phi/backends/dynload/cuda_driver.cc | 1 + paddle/phi/backends/dynload/cuda_driver.h | 6 + paddle/phi/backends/dynload/cudnn.cc | 4 + paddle/phi/backends/dynload/cudnn.h | 13 + paddle/phi/backends/dynload/nvtx.h | 1 + 
paddle/phi/backends/gpu/cuda/cuda_helper.h | 44 +- paddle/phi/backends/gpu/gpu_context.cc | 18 + paddle/phi/backends/gpu/gpu_context.h | 3 + paddle/phi/backends/gpu/gpu_launch_config.h | 42 +- paddle/phi/common/memory_utils.h | 107 ++ paddle/phi/kernels/autotune/cache.cc | 28 +- paddle/phi/kernels/autotune/cache.h | 239 +++- paddle/phi/kernels/autotune/cache_test.cc | 23 +- paddle/phi/kernels/funcs/blas/blas_impl.cu.h | 10 +- .../phi/kernels/funcs/blas/blaslt_impl.cu.h | 1149 +++++++++++++++++ .../funcs/sparse/sparse_blas_impl.cu.h | 20 +- .../phi/kernels/gpu/graph_send_recv_kernel.cu | 16 +- paddle/phi/kernels/gpu/matmul_grad_kernel.cu | 55 +- paddle/phi/kernels/gpu/matmul_kernel.cu | 39 + paddle/phi/kernels/gpu/unique_kernel.cu | 97 +- .../kernels/gpudnn/conv_grad_grad_kernel.cu | 29 +- paddle/phi/kernels/gpudnn/conv_grad_kernel.cu | 28 +- paddle/phi/kernels/gpudnn/conv_kernel.cu | 11 +- .../gpudnn/conv_transpose_grad_kernel.cu | 39 +- .../kernels/gpudnn/conv_transpose_kernel.cu | 6 +- .../kernels/impl/matmul_grad_kernel_impl.h | 28 +- paddle/phi/kernels/impl/matmul_kernel_impl.h | 14 +- .../phi/kernels/sparse/gpu/coalesce_kernel.cu | 3 +- 42 files changed, 2209 insertions(+), 398 deletions(-) create mode 100644 paddle/phi/common/memory_utils.h create mode 100644 paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 87b943abd0106..e9cb7d325f711 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -6,28 +6,34 @@ if(WITH_NV_JETSON) add_definitions(-DWITH_NV_JETSON) set(paddle_known_gpu_archs "53 62 72") set(paddle_known_gpu_archs10 "53 62 72") + set(paddle_known_gpu_archs11 "53 62 72 87") + set(paddle_known_gpu_archs12 "53 62 72 87 90") elseif(NEW_RELEASE_ALL) message("Using New Release Strategy - All Arches Packge") add_definitions(-DNEW_RELEASE_ALL) - set(paddle_known_gpu_archs "35 50 52 60 61 70 75 80 86") - set(paddle_known_gpu_archs10 "35 50 52 60 61 70 75") + set(paddle_known_gpu_archs "50 52 60 61 70 75 80 86 90") + set(paddle_known_gpu_archs10 "50 52 60 61 70 75") set(paddle_known_gpu_archs11 "50 60 61 70 75 80") + set(paddle_known_gpu_archs12 "50 60 61 70 75 80 90") elseif(NEW_RELEASE_PYPI) message("Using New Release Strategy - Cubin Packge") add_definitions(-DNEW_RELEASE_PYPI) - set(paddle_known_gpu_archs "35 50 52 60 61 70 75 80 86") + set(paddle_known_gpu_archs "50 52 60 61 70 75 80 86 90") set(paddle_known_gpu_archs10 "") - set(paddle_known_gpu_archs11 "60 61 70 75 80") + set(paddle_known_gpu_archs11 "61 70 75 80") + set(paddle_known_gpu_archs12 "61 70 75 80 90") elseif(NEW_RELEASE_JIT) message("Using New Release Strategy - JIT Packge") add_definitions(-DNEW_RELEASE_JIT) - set(paddle_known_gpu_archs "35 50 52 60 61 70 75 80 86") - set(paddle_known_gpu_archs10 "35 50 60 70 75") - set(paddle_known_gpu_archs11 "35 50 60 70 75 80") + set(paddle_known_gpu_archs "50 52 60 61 70 75 80 86 90") + set(paddle_known_gpu_archs10 "50 60 70 75") + set(paddle_known_gpu_archs11 "50 60 70 75 80") + set(paddle_known_gpu_archs12 "50 60 70 75 80 90") else() - set(paddle_known_gpu_archs "35 50 52 60 61 70 75 80") - set(paddle_known_gpu_archs10 "35 50 52 60 61 70 75") + set(paddle_known_gpu_archs "70 80") + set(paddle_known_gpu_archs10 "50 52 60 61 70 75") set(paddle_known_gpu_archs11 "52 60 61 70 75 80") + set(paddle_known_gpu_archs12 "70 80") endif() ###################################################################################### @@ -98,12 +104,12 @@ endfunction() function(select_nvcc_arch_flags out_variable) # List of arch names 
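  # ("Hopper" below corresponds to compute capability 9.0 / sm_90, the
  # architecture of the H800/H100 cards this patch targets.)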
set(archs_names - "Kepler" "Maxwell" "Pascal" "Volta" "Turing" "Ampere" + "Hopper" "All" "Manual") set(archs_name_default "Auto") @@ -142,9 +148,7 @@ function(select_nvcc_arch_flags out_variable) unset(CUDA_ARCH_PTX CACHE) endif() - if(${CUDA_ARCH_NAME} STREQUAL "Kepler") - set(cuda_arch_bin "30 35") - elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell") + if(${CUDA_ARCH_NAME} STREQUAL "Maxwell") if(WITH_NV_JETSON) set(cuda_arch_bin "53") else() @@ -165,11 +169,17 @@ function(select_nvcc_arch_flags out_variable) elseif(${CUDA_ARCH_NAME} STREQUAL "Turing") set(cuda_arch_bin "75") elseif(${CUDA_ARCH_NAME} STREQUAL "Ampere") - if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.1) # CUDA 11.0 - set(cuda_arch_bin "80") - elseif(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) # CUDA 11.1+ - set(cuda_arch_bin "80 86") + if(WITH_NV_JETSON) + set(cuda_arch_bin "87") + else() + if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.1) # CUDA 11.0 + set(cuda_arch_bin "80") + else() + set(cuda_arch_bin "80 86") + endif() endif() + elseif(${CUDA_ARCH_NAME} STREQUAL "Hopper") + set(cuda_arch_bin "90") elseif(${CUDA_ARCH_NAME} STREQUAL "All") set(cuda_arch_bin ${paddle_known_gpu_archs}) elseif(${CUDA_ARCH_NAME} STREQUAL "Auto") @@ -186,6 +196,13 @@ function(select_nvcc_arch_flags out_variable) set(cuda_arch_bin ${CUDA_ARCH_BIN}) endif() + # cuda11.4 + if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.6) + set(cuda_arch_bin "70 80") + else() + set(cuda_arch_bin "70 80 90") + endif() + if(NEW_RELEASE_JIT) set(cuda_arch_ptx "${cuda_arch_ptx}${cuda_arch_bin}") set(cuda_arch_bin "") @@ -249,6 +266,11 @@ elseif(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) # CUDA 11.2+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wno-deprecated-gpu-targets") +elseif(${CMAKE_CUDA_COMPILER_VERSION} LESS 13.0) # CUDA 12.0+ + set(paddle_known_gpu_archs "${paddle_known_gpu_archs12} 90") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wno-deprecated-gpu-targets") endif() if(NOT ${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0) diff --git a/cmake/external/gloo.cmake b/cmake/external/gloo.cmake index cd7b254892ed1..7fabe000b6058 100644 --- a/cmake/external/gloo.cmake +++ b/cmake/external/gloo.cmake @@ -25,8 +25,13 @@ set(GLOO_LIBRARY_DIR "${GLOO_INSTALL_DIR}/lib" CACHE PATH "gloo library directory." FORCE) # As we add extra features for gloo, we use the non-official repo -set(GLOO_REPOSITORY ${GIT_URL}/sandyhouse/gloo.git) -set(GLOO_TAG v0.0.2) +if(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) + set(GLOO_REPOSITORY ${GIT_URL}/sandyhouse/gloo.git) + set(GLOO_TAG v0.0.2) +else() + set(GLOO_REPOSITORY ${GIT_URL}/ziyoujiyi/gloo.git) + set(GLOO_TAG v0.0.3) +endif() set(GLOO_LIBRARIES "${GLOO_INSTALL_DIR}/lib/libgloo.a" CACHE FILEPATH "gloo library." 
FORCE) diff --git a/cmake/external/warpctc.cmake b/cmake/external/warpctc.cmake index c7a4e1d99bff1..8164c86f765d7 100644 --- a/cmake/external/warpctc.cmake +++ b/cmake/external/warpctc.cmake @@ -23,7 +23,11 @@ set(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc) # in case of low internet speed #set(WARPCTC_REPOSITORY https://gitee.com/tianjianhe/warp-ctc.git) set(WARPCTC_REPOSITORY ${GIT_URL}/baidu-research/warp-ctc.git) -set(WARPCTC_TAG 37ece0e1bbe8a0019a63ac7e6462c36591c66a5b) +if(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) + set(WARPCTC_TAG 37ece0e1bbe8a0019a63ac7e6462c36591c66a5b) +else() + set(WARPCTC_TAG bdc2b4550453e0ef2d3b5190f9c6103a84eff184) +endif() set(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include" diff --git a/paddle/fluid/operators/conv_base_helper.h b/paddle/fluid/operators/conv_base_helper.h index 8425dcb521ab6..4dc83c9717ae7 100644 --- a/paddle/fluid/operators/conv_base_helper.h +++ b/paddle/fluid/operators/conv_base_helper.h @@ -36,25 +36,33 @@ using framework::ConvSearchCache; template using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; -// As the basic for SearchAlgorithm struct. -template -struct SearchAlgorithm {}; - // As the container of searchAlgorithm::Find() result. template struct SearchResult { SearchResult() {} + explicit SearchResult(AlgoT a) : algo(a) {} + explicit SearchResult(AlgoT a, float t, size_t size) + : algo(a), time(t), workspace_size(size) {} AlgoT algo = static_cast(0); float time = -1.f; size_t workspace_size = 0; + bool exhaustive_search = false; }; template static std::ostream& operator<<(std::ostream& out, const std::vector& v) { out << "["; - for (auto const& tmp : v) out << tmp << ","; + bool is_first = true; + for (auto const& tmp : v) { + if (is_first) { + out << tmp; + is_first = false; + } else { + out << ", " << tmp; + } + } out << "]"; return out; } @@ -76,28 +84,50 @@ struct ConvArgsBase { // dilations std::vector d; + // groups + int group; + + // data foramt + DataLayout data_layout; + ConvArgsBase(const framework::Tensor* x, const framework::Tensor* w, const framework::Tensor* o, const std::vector s, const std::vector p, const std::vector d, - DataT dtype) - : x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {} + DataT dtype, + int g, + DataLayout layout) + : x(x), + w(w), + o(o), + s(s), + p(p), + d(d), + cudnn_dtype(dtype), + group(g), + data_layout(layout) {} template - size_t GetCacheKey() const { + phi::autotune::ConvCacheKey Convert2ConvCacheKey() const { auto x_shape = phi::vectorize(x->dims()); auto w_shape = phi::vectorize(w->dims()); VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape - << ", strides=" << s << ", paddings=" << p << ", dilations=" << d; - return phi::autotune::ConvKey( + << ", strides=" << s << ", paddings=" << p << ", dilations=" << d + << ", data=" << paddle::experimental::CppTypeToDataType::Type() + << ", group=" << group + << ", data layout=" << static_cast(data_layout); + + return phi::autotune::ConvCacheKey( x_shape, w_shape, p, s, d, - paddle::experimental::CppTypeToDataType::Type()); + paddle::experimental::CppTypeToDataType::Type(), + group, + static_cast(data_layout)); } }; diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h index 1b8d421d133f1..2fa1683833c33 100644 --- a/paddle/fluid/operators/conv_cudnn_helper.h +++ b/paddle/fluid/operators/conv_cudnn_helper.h @@ -146,79 +146,21 @@ void ChooseAlgoByWorkspace(const std::vector& perf_results, } } -static void SetConvMathType(const 
phi::GPUContext& ctx, - cudnnDataType_t dtype, - const platform::ConvolutionDescriptor& cdesc) { -#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) - if (ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( - cdesc.desc(), CUDNN_TENSOR_OP_MATH)); - VLOG(5) << "use cudnn_tensor_op_math"; -#if CUDA_VERSION >= 11000 -#if CUDNN_VERSION_MIN(8, 1, 0) - } else if (ctx.GetComputeCapability() >= 80 && dtype == CUDNN_DATA_BFLOAT16) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( - cdesc.desc(), CUDNN_TENSOR_OP_MATH)); -#endif // CUDNN_VERSION_MIN(8, 1, 0) - } else if (dtype == CUDNN_DATA_FLOAT && !cdesc.allow_tf32_) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( - cdesc.desc(), CUDNN_FMA_MATH)); -#endif // CUDA_VERSION >= 11000 - } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( - cdesc.desc(), CUDNN_DEFAULT_MATH)); - VLOG(5) << "NOT use cudnn_tensor_op_math"; - } -#endif -} +template +struct SearchAlgorithmBase {}; // cuDNN convolution forward algorithm searcher, consisted of three searching // modes, namely: deterministic, heuristic and exhaustive_search mode. // As well as one workspace size acquirsition function with respect to // the chosen alogrithm. template <> -struct SearchAlgorithm { +struct SearchAlgorithmBase { using PerfT = cudnnConvolutionFwdAlgoPerf_t; using AlgoT = cudnnConvolutionFwdAlgo_t; + constexpr static phi::autotune::AlgorithmType kAlgoType = + phi::autotune::AlgorithmType::kConvForward; - template - static SearchResult Find(const ConvArgs& args, - bool exhaustive_search, - bool deterministic, - const phi::GPUContext& ctx) { - SearchResult result; - auto dtype = platform::CudnnDataType::type; - SetConvMathType(ctx, dtype, args.cdesc); - - if (deterministic) { - result = FindAlgoDeterministic(); - } else { - // 1. Once turning on exhaustive FLAGS, always get exhaustive_search. - // 2. Once turning on auto-tune, runn heuristic search(default) before - // auto-tune process, run exhaustive_search during mentioned process. - // 3. After auto-tune process, run cached algorithm if cached, run - // default mode for the rest. 
- size_t key = args.GetCacheKey(); - auto& cache = phi::autotune::AutoTuneCache::Instance().GetConvForward(); - if (cache.Find(key)) { - result.algo = static_cast(cache.Get(key)); - } else { - bool use_autotune = - phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); - if (exhaustive_search || use_autotune) { - result = FindAlgoExhaustiveSearch(args, ctx); - cache.Set(key, static_cast(result.algo)); - } else { - result = FindAlgoHeuristic(args, ctx); - } - } - } - VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search - << ", deterministic=" << deterministic - << ", choose algo=" << result.algo << ", workspace=" - << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB"; - return result; - } + static const std::string GetPerfName() { return "ConvForward"; } static size_t GetWorkspaceSize(const ConvArgs& args, cudnnConvolutionFwdAlgo_t algo) { @@ -235,9 +177,10 @@ struct SearchAlgorithm { return workspace_size; } - private: - static SearchResult FindAlgoDeterministic() { - return SearchResult(static_cast(1)); + protected: + static SearchResult FindAlgoDeterministic(const ConvArgs& args) { + auto workspace_size = GetWorkspaceSize(args, static_cast(1)); + return SearchResult(static_cast(1), -1.0, workspace_size); } // Heuristic search mode, calling the cudnnGetXxxAlgorithm. @@ -266,6 +209,10 @@ struct SearchAlgorithm { if (result.workspace_size > workspace_size_limit) { #if CUDNN_VERSION >= 8000 + VLOG(4) << GetPerfResultString("[Heuristic] FwdAlgo Perf result", + perf_results, + actual_perf_count, + workspace_size_limit); // cudnnGetConvolutionForwardAlgorithm is removed in CUDNN-8 ChooseAlgoByWorkspace( perf_results, workspace_size_limit, &result); @@ -298,6 +245,7 @@ struct SearchAlgorithm { workspace_size_limit, &(result.algo))); #endif + result.workspace_size = GetWorkspaceSize(args, result.algo); return result; } @@ -343,6 +291,7 @@ struct SearchAlgorithm { ChooseAlgoByWorkspace( perf_results, workspace_size_limit, &result); + result.workspace_size = GetWorkspaceSize(args, result.algo); return result; } @@ -380,49 +329,13 @@ struct SearchAlgorithm { // As well as one workspace size acquirsition function with // respect to the chosen alogrithm. template <> -struct SearchAlgorithm { +struct SearchAlgorithmBase { using PerfT = cudnnConvolutionBwdDataAlgoPerf_t; using AlgoT = cudnnConvolutionBwdDataAlgo_t; + constexpr static phi::autotune::AlgorithmType kAlgoType = + phi::autotune::AlgorithmType::kConvBackwardData; - template - static SearchResult Find(const ConvArgs& args, - bool exhaustive_search, - bool deterministic, - const phi::GPUContext& ctx) { - SearchResult result; - auto dtype = platform::CudnnDataType::type; - SetConvMathType(ctx, dtype, args.cdesc); - - if (deterministic) { - result = FindAlgoDeterministic(); - } else { - // 1. Once turning on exhaustive FLAGS, always get exhaustive_search. - // 2. Once turning on auto-tune, runn heuristic search(default) before - // auto-tune process, run exhaustive_search during mentioned process. - // 3. After auto-tune process, run cached algorithm if cached, run - // default mode for the rest. 
- size_t key = args.GetCacheKey(); - auto& cache = - phi::autotune::AutoTuneCache::Instance().GetConvBackwardData(); - if (cache.Find(key)) { - result.algo = static_cast(cache.Get(key)); - } else { - bool use_autotune = - phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); - if (exhaustive_search || use_autotune) { - result = FindAlgoExhaustiveSearch(args, ctx); - cache.Set(key, static_cast(result.algo)); - } else { - result = FindAlgoHeuristic(args, ctx); - } - } - } - VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search - << ", deterministic=" << deterministic - << ", choose algo=" << result.algo << ", workspace=" - << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB"; - return result; - } + static const std::string GetPerfName() { return "ConvBackwardData"; } static size_t GetWorkspaceSize(const ConvArgs& args, cudnnConvolutionBwdDataAlgo_t algo) { @@ -439,9 +352,12 @@ struct SearchAlgorithm { return workspace_size; } - private: - static SearchResult FindAlgoDeterministic() { - return SearchResult(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1); + protected: + static SearchResult FindAlgoDeterministic(const ConvArgs& args) { + auto workspace_size = + GetWorkspaceSize(args, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1); + return SearchResult( + CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, -1.0, workspace_size); } static SearchResult FindAlgoHeuristic(const ConvArgs& args, @@ -513,7 +429,7 @@ struct SearchAlgorithm { workspace_size_limit, &(result.algo))); #endif - + result.workspace_size = GetWorkspaceSize(args, result.algo); return result; } @@ -559,6 +475,7 @@ struct SearchAlgorithm { ChooseAlgoByWorkspace( perf_results, workspace_size_limit, &result); + result.workspace_size = GetWorkspaceSize(args, result.algo); return result; } @@ -594,50 +511,13 @@ struct SearchAlgorithm { // exhaustive_search mode. As well as one workspace size acquirsition function // with respect to the chosen alogrithm. template <> -struct SearchAlgorithm { +struct SearchAlgorithmBase { using PerfT = cudnnConvolutionBwdFilterAlgoPerf_t; using AlgoT = cudnnConvolutionBwdFilterAlgo_t; + constexpr static phi::autotune::AlgorithmType kAlgoType = + phi::autotune::AlgorithmType::kConvBackwardFilter; - template - static SearchResult Find(const ConvArgs& args, - bool exhaustive_search, - bool deterministic, - const phi::GPUContext& ctx) { - platform::CUDAGraphCaptureModeGuard guard; - SearchResult result; - auto dtype = platform::CudnnDataType::type; - SetConvMathType(ctx, dtype, args.cdesc); - - if (deterministic) { - result = FindAlgoDeterministic(); - } else { - // 1. Once turning on exhaustive FLAGS, always get exhaustive_search. - // 2. Once turning on auto-tune, runn heuristic search(default) before - // auto-tune process, run exhaustive_search during mentioned process. - // 3. After auto-tune process, run cached algorithm if cached, run - // default mode for the rest. 
- size_t key = args.GetCacheKey(); - auto& cache = - phi::autotune::AutoTuneCache::Instance().GetConvBackwardFilter(); - if (cache.Find(key)) { - result.algo = static_cast(cache.Get(key)); - } else { - bool use_autotune = - phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); - if (exhaustive_search || use_autotune) { - result = FindAlgoExhaustiveSearch(args, ctx); - cache.Set(key, static_cast(result.algo)); - } else { - result = FindAlgoHeuristic(args, ctx); - } - } - } - VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search - << ", deterministic=" << deterministic - << ", choose algo=" << result.algo << ", workspace=" - << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB"; - return result; - } + static const std::string GetPerfName() { return "ConvBackwardFilter"; } static size_t GetWorkspaceSize(const ConvArgs& args, cudnnConvolutionBwdFilterAlgo_t algo) { @@ -655,9 +535,12 @@ struct SearchAlgorithm { return workspace_size; } - private: - static SearchResult FindAlgoDeterministic() { - return SearchResult(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1); + protected: + static SearchResult FindAlgoDeterministic(const ConvArgs& args) { + auto workspace_size = + GetWorkspaceSize(args, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1); + return SearchResult( + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, -1.0, workspace_size); } static SearchResult FindAlgoHeuristic(const ConvArgs& args, @@ -718,6 +601,7 @@ struct SearchAlgorithm { &(result.algo))); #endif + result.workspace_size = GetWorkspaceSize(args, result.algo); return result; } @@ -786,6 +670,7 @@ struct SearchAlgorithm { ChooseAlgo(perf_results, workspace_size_limit, &result); } + result.workspace_size = GetWorkspaceSize(args, result.algo); return result; } @@ -867,5 +752,103 @@ struct SearchAlgorithm { } }; +template +struct SearchAlgorithm : public SearchAlgorithmBase { + using AlgoT = typename SearchAlgorithmBase::AlgoT; + + template + static SearchResult Find(const phi::GPUContext& ctx, + const ConvArgs& args, + bool exhaustive_search, + bool deterministic, + bool enable_autotune = true) { + SearchResult result; + bool use_autotune = false; + auto dtype = platform::CudnnDataType::type; + SetConvMathType(ctx, dtype, args.cdesc); + + if (deterministic) { + result = SearchAlgorithmBase::FindAlgoDeterministic(args); + } else { + // 1. Once turning on exhaustive FLAGS, always get exhaustive_search. + // 2. Once turning on auto-tune, run heuristic (default) before + // auto-tune process, run exhaustive_search during mentioned process. + // Auto tune is only enabled between specified range. + // 3. After auto-tune process, run cached algorithm if cached, run + // default mode for the rest. + auto key = args.Convert2ConvCacheKey(); + auto& cache = phi::autotune::AutoTuneCache::Instance().GetConv( + SearchAlgorithmBase::kAlgoType); + bool find_in_cache = cache.Find(key); + if (find_in_cache) { + auto t = cache.Get(key); + result.algo = static_cast(t.algo); + result.workspace_size = t.workspace_size; + result.exhaustive_search = t.exhaustive_search; + } + if (!result.exhaustive_search) { + bool need_update_cache = false; + // In conv2d_tranpose, enable_autotune is set to false because some + // algorithm picked by exhaustive search method produce wrong result. + use_autotune = enable_autotune && + phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); + if (exhaustive_search || use_autotune) { + // Once autotune is enabled, the autotuned result can rewrite the + // previous result in cache found by heuristic method. 
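+          // (The exhaustive_search flag stored with the cache entry is what
+          // keeps a later heuristic-only lookup from overwriting this result.)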
+ result = + SearchAlgorithmBase::template FindAlgoExhaustiveSearch( + args, ctx); + need_update_cache = true; + } else if (!find_in_cache) { + result = SearchAlgorithmBase::FindAlgoHeuristic(args, ctx); + need_update_cache = true; + } + if (need_update_cache) { + phi::autotune::ConvAutoTuneResult node( + static_cast(result.algo), + result.workspace_size, + exhaustive_search || use_autotune); + cache.Set(key, node); + } + } + } + VLOG(3) << "[cuDNN " << SearchAlgorithmBase::GetPerfName() + << "] exhaustive_search=" << exhaustive_search + << ", use_autotune=" << use_autotune + << ", deterministic=" << deterministic + << ", choose algo=" << result.algo + << ", workspace=" << ToMegaBytes(result.workspace_size) << " MB"; + return result; + } + + static void SetConvMathType(const phi::GPUContext& ctx, + cudnnDataType_t dtype, + const platform::ConvolutionDescriptor& cdesc) { +#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) + if (ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( + cdesc.desc(), CUDNN_TENSOR_OP_MATH)); + VLOG(5) << "Enable Tensor Core for FLOAT16"; +#if CUDA_VERSION >= 11000 +#if CUDNN_VERSION_MIN(8, 1, 0) + } else if (ctx.GetComputeCapability() >= 80 && + dtype == CUDNN_DATA_BFLOAT16) { + VLOG(5) << "Enable Tensor Core for BFLOAT16"; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( + cdesc.desc(), CUDNN_TENSOR_OP_MATH)); +#endif // CUDNN_VERSION_MIN(8, 1, 0) + } else if (dtype == CUDNN_DATA_FLOAT && !cdesc.allow_tf32_) { + VLOG(5) << "Disable TensorFloat (Tensor Core) for FLOAT"; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( + cdesc.desc(), CUDNN_FMA_MATH)); +#endif // CUDA_VERSION >= 11000 + } else { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( + cdesc.desc(), CUDNN_DEFAULT_MATH)); + } +#endif + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/cross_norm_hadamard_op.cu b/paddle/fluid/operators/cross_norm_hadamard_op.cu index df643de164ffe..4594421565770 100644 --- a/paddle/fluid/operators/cross_norm_hadamard_op.cu +++ b/paddle/fluid/operators/cross_norm_hadamard_op.cu @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include #include #include "paddle/fluid/framework/eigen.h" diff --git a/paddle/fluid/operators/scaled_fc_op.cu b/paddle/fluid/operators/scaled_fc_op.cu index 20bd9dbf07361..bf920093ff794 100644 --- a/paddle/fluid/operators/scaled_fc_op.cu +++ b/paddle/fluid/operators/scaled_fc_op.cu @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/scaled_fc_op.h" diff --git a/paddle/fluid/operators/scaled_int8fc_op.cu b/paddle/fluid/operators/scaled_int8fc_op.cu index c03bbf61d67fb..347640fadd68f 100644 --- a/paddle/fluid/operators/scaled_int8fc_op.cu +++ b/paddle/fluid/operators/scaled_int8fc_op.cu @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/scaled_int8fc_op.h" diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h b/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h index 7185d2356aae5..4ff874c3e89f5 100644 --- a/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h +++ b/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h @@ -22,6 +22,7 @@ #include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" +#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" namespace paddle { namespace platform { @@ -70,11 +71,6 @@ namespace platform { * */ -#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \ - int64_t __index__ = blockIdx.x * blockDim.x + threadIdx.x; \ - for (index_type i = __index__; __index__ < (num); \ - __index__ += blockDim.x * gridDim.x, i = __index__) - class CublasHandleHolder { public: CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) { diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_profiler.cc b/paddle/fluid/platform/device/gpu/cuda/cuda_profiler.cc index 291dd6c7ce1c7..a49d9013fb6d0 100644 --- a/paddle/fluid/platform/device/gpu/cuda/cuda_profiler.cc +++ b/paddle/fluid/platform/device/gpu/cuda/cuda_profiler.cc @@ -17,9 +17,10 @@ namespace paddle { namespace platform { -void CudaProfilerInit(std::string output_file, - std::string output_mode, - std::string config_file) { +void CudaProfilerInit(const std::string& output_file, + const std::string& output_mode, + const std::string& config_file) { +#if CUDA_VERSION < 11000 PADDLE_ENFORCE(output_mode == "kvp" || output_mode == "csv", platform::errors::InvalidArgument( "Unsupported cuda profiler output mode, expect `kvp` or " @@ -28,6 +29,7 @@ void CudaProfilerInit(std::string output_file, cudaOutputMode_t mode = output_mode == "csv" ? cudaCSV : cudaKeyValuePair; PADDLE_ENFORCE_GPU_SUCCESS( cudaProfilerInitialize(config_file.c_str(), output_file.c_str(), mode)); +#endif } void CudaProfilerStart() { PADDLE_ENFORCE_GPU_SUCCESS(cudaProfilerStart()); } @@ -35,8 +37,16 @@ void CudaProfilerStart() { PADDLE_ENFORCE_GPU_SUCCESS(cudaProfilerStart()); } void CudaProfilerStop() { PADDLE_ENFORCE_GPU_SUCCESS(cudaProfilerStop()); } #ifndef _WIN32 -void CudaNvtxRangePush(std::string name) { - dynload::nvtxRangePushA(name.c_str()); +void CudaNvtxRangePush(const std::string& name, const NvtxRangeColor color) { + nvtxEventAttributes_t eventAttrib; + eventAttrib.version = NVTX_VERSION; + eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; + eventAttrib.colorType = NVTX_COLOR_ARGB; + eventAttrib.color = static_cast(color); + eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII; + eventAttrib.message.ascii = name.c_str(); + + dynload::nvtxRangePushEx(&eventAttrib); } void CudaNvtxRangePop() { dynload::nvtxRangePop(); } diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_profiler.h b/paddle/fluid/platform/device/gpu/cuda/cuda_profiler.h index 6c7cf0fd8dd94..555a83a0210f2 100644 --- a/paddle/fluid/platform/device/gpu/cuda/cuda_profiler.h +++ b/paddle/fluid/platform/device/gpu/cuda/cuda_profiler.h @@ -23,16 +23,26 @@ limitations under the License. 
*/ namespace paddle { namespace platform { -void CudaProfilerInit(std::string output_file, - std::string output_mode, - std::string config_file); +void CudaProfilerInit(const std::string& output_file, + const std::string& output_mode, + const std::string& config_file); void CudaProfilerStart(); void CudaProfilerStop(); #ifndef _WIN32 -void CudaNvtxRangePush(std::string name); +enum class NvtxRangeColor : uint32_t { + Black = 0x00000000, + Red = 0x00ff0000, + Green = 0x0000ff00, + Blue = 0x000000ff, + White = 0x00ffffff, + Yellow = 0x00ffff00, +}; + +void CudaNvtxRangePush(const std::string& name, + const NvtxRangeColor color = NvtxRangeColor::Green); void CudaNvtxRangePop(); #endif diff --git a/paddle/fluid/platform/dynload/cublasLt.h b/paddle/fluid/platform/dynload/cublasLt.h index 3a1d28072c591..5bf92876f4fd0 100644 --- a/paddle/fluid/platform/dynload/cublasLt.h +++ b/paddle/fluid/platform/dynload/cublasLt.h @@ -39,7 +39,34 @@ namespace dynload { extern DynLoad__##__name __name // APIs available after CUDA 10.1 -// #if CUDA_VERSION >= 10100 +#if CUDA_VERSION >= 11010 +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); \ + __macro(cublasLtMatmulAlgoInit); \ + __macro(cublasLtMatmulAlgoConfigSetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ + __macro(cublasLtMatmulAlgoCheck); \ + __macro(cublasLtGetCudartVersion); +#else #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasLtCreate); \ __macro(cublasLtDestroy); \ @@ -60,6 +87,7 @@ namespace dynload { __macro(cublasLtMatrixTransformDescCreate); \ __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); +#endif CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif diff --git a/paddle/fluid/platform/dynload/nvtx.h b/paddle/fluid/platform/dynload/nvtx.h index c3dc9e31df354..e5816e240e6d2 100644 --- a/paddle/fluid/platform/dynload/nvtx.h +++ b/paddle/fluid/platform/dynload/nvtx.h @@ -13,11 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once #ifndef _WIN32 -#include -#include - -#include // NOLINT - #include "paddle/phi/backends/dynload/nvtx.h" namespace paddle { @@ -28,11 +23,12 @@ namespace dynload { using DynLoad__##__name = phi::dynload::DynLoad__##__name; \ extern DynLoad__##__name __name -#define NVTX_ROUTINE_EACH(__macro) \ - __macro(nvtxRangePushA); \ +#define PLATFORM_NVTX_ROUTINE_EACH(__macro) \ + __macro(nvtxRangePushA); \ + __macro(nvtxRangePushEx); \ __macro(nvtxRangePop); -NVTX_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_NVTX_WRAP); +PLATFORM_NVTX_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_NVTX_WRAP); #undef PLATFORM_DECLARE_DYNAMIC_LOAD_NVTX_WRAP } // namespace dynload diff --git a/paddle/phi/backends/dynload/cublasLt.h b/paddle/phi/backends/dynload/cublasLt.h index 1e2a20ebdf440..8a005cb93b7d4 100644 --- a/paddle/phi/backends/dynload/cublasLt.h +++ b/paddle/phi/backends/dynload/cublasLt.h @@ -54,6 +54,34 @@ extern void *cublasLt_dso_handle; // APIs available after CUDA 10.1 // #if CUDA_VERSION >= 10100 +#if CUDA_VERSION >= 11010 +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); \ + __macro(cublasLtMatmulAlgoInit); \ + __macro(cublasLtMatmulAlgoConfigSetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ + __macro(cublasLtMatmulAlgoCheck); \ + __macro(cublasLtGetCudartVersion); +#else #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasLtCreate); \ __macro(cublasLtDestroy); \ @@ -74,6 +102,7 @@ extern void *cublasLt_dso_handle; __macro(cublasLtMatrixTransformDescCreate); \ __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); +#endif CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif diff --git a/paddle/phi/backends/dynload/cuda_driver.cc b/paddle/phi/backends/dynload/cuda_driver.cc index 2bd0a7bfea5c1..d9fd89a0c65a6 100644 --- a/paddle/phi/backends/dynload/cuda_driver.cc +++ b/paddle/phi/backends/dynload/cuda_driver.cc @@ -24,6 +24,7 @@ void* cuda_dso_handle = nullptr; #if CUDA_VERSION >= 10020 CUDA_ROUTINE_EACH_VVM(DEFINE_WRAP); +CUDA_ROUTINE_EACH_CUDA_GRAPH(DEFINE_WRAP); #endif CUDA_ROUTINE_EACH(DEFINE_WRAP); diff --git a/paddle/phi/backends/dynload/cuda_driver.h b/paddle/phi/backends/dynload/cuda_driver.h index f743a33a1866f..ba771afe09023 100644 --- a/paddle/phi/backends/dynload/cuda_driver.h +++ b/paddle/phi/backends/dynload/cuda_driver.h @@ -72,7 +72,13 @@ extern bool HasCUDADriver(); __macro(cuMemRelease); \ __macro(cuMemAddressFree) +#define CUDA_ROUTINE_EACH_CUDA_GRAPH(__macro) \ + __macro(cuGraphNodeGetType); \ + __macro(cuGraphKernelNodeGetParams); \ + __macro(cuGraphExecKernelNodeSetParams) + CUDA_ROUTINE_EACH_VVM(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); 
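All of these routine lists (cublasLt, NVTX, and the CUDA graph driver entries added here) feed the same declare/define X-macro pattern: each listed name expands into a small functor that lazily opens the DSO and resolves the symbol by name on first call. A minimal, self-contained sketch of that pattern follows, with purely illustrative names (libm stands in for the CUDA library, the dyn_ prefix stands in for Paddle's generated wrappers; this is not the Paddle implementation itself).

    // demo_dynload.cc -- illustrative sketch only (hypothetical names).
    // Build on Linux: g++ demo_dynload.cc -ldl && ./a.out
    #include <dlfcn.h>
    #include <cstdio>
    #include <mutex>

    static std::once_flag demo_dso_flag;
    static void* demo_dso_handle = nullptr;

    // Each wrapped routine becomes a functor that dlopen()s the DSO once and
    // resolves the symbol on first use, mirroring DECLARE_DYNAMIC_LOAD_*_WRAP.
    #define DEMO_DECLARE_DYNAMIC_LOAD_WRAP(__name)                \
      struct DynLoad__##__name {                                  \
        template <typename... Args>                               \
        double operator()(Args... args) {                         \
          std::call_once(demo_dso_flag, [] {                      \
            demo_dso_handle = dlopen("libm.so.6", RTLD_LAZY);     \
          });                                                     \
          using FuncT = double (*)(double);                       \
          auto fn = reinterpret_cast<FuncT>(                      \
              dlsym(demo_dso_handle, #__name));                   \
          return fn(args...);                                     \
        }                                                         \
      };                                                          \
      DynLoad__##__name dyn_##__name

    // One routine list drives every declaration.
    #define DEMO_ROUTINE_EACH(__macro) \
      __macro(cos);                    \
      __macro(sqrt)

    DEMO_ROUTINE_EACH(DEMO_DECLARE_DYNAMIC_LOAD_WRAP);

    int main() {
      std::printf("cos(0)=%f sqrt(2)=%f\n", dyn_cos(0.0), dyn_sqrt(2.0));
      return 0;
    }

With this layout, exposing another driver API, as the patch does for cuGraphNodeGetType and friends, reduces to appending one name to the routine list.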
+CUDA_ROUTINE_EACH_CUDA_GRAPH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); #endif CUDA_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); diff --git a/paddle/phi/backends/dynload/cudnn.cc b/paddle/phi/backends/dynload/cudnn.cc index 8aa3b623273d7..9bd38a89ab177 100644 --- a/paddle/phi/backends/dynload/cudnn.cc +++ b/paddle/phi/backends/dynload/cudnn.cc @@ -46,6 +46,10 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DEFINE_WRAP); CUDNN_DNN_ROUTINE_EACH_R8(DEFINE_WRAP); #endif +#ifdef CUDNN_DNN_ROUTINE_EACH_FRONTEND +CUDNN_DNN_ROUTINE_EACH_FRONTEND(DEFINE_WRAP); +#endif + bool HasCUDNN() { std::call_once(cudnn_dso_flag, []() { cudnn_dso_handle = GetCUDNNDsoHandle(); }); diff --git a/paddle/phi/backends/dynload/cudnn.h b/paddle/phi/backends/dynload/cudnn.h index 7b9004308e95b..3292beb037110 100644 --- a/paddle/phi/backends/dynload/cudnn.h +++ b/paddle/phi/backends/dynload/cudnn.h @@ -194,6 +194,19 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif +#ifdef PADDLE_WITH_CUDNN_FRONTEND +#define CUDNN_DNN_ROUTINE_EACH_FRONTEND(__macro) \ + __macro(cudnnBackendCreateDescriptor); \ + __macro(cudnnBackendDestroyDescriptor); \ + __macro(cudnnBackendExecute); \ + __macro(cudnnBackendFinalize); \ + __macro(cudnnBackendGetAttribute); \ + __macro(cudnnBackendSetAttribute); \ + __macro(cudnnGetStream); \ + __macro(cudnnReorderFilterAndBias); +CUDNN_DNN_ROUTINE_EACH_FRONTEND(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) +#endif + } // namespace dynload } // namespace phi diff --git a/paddle/phi/backends/dynload/nvtx.h b/paddle/phi/backends/dynload/nvtx.h index a9a166b289e33..e51bbf2154a17 100644 --- a/paddle/phi/backends/dynload/nvtx.h +++ b/paddle/phi/backends/dynload/nvtx.h @@ -42,6 +42,7 @@ extern void *nvtx_dso_handle; #define NVTX_ROUTINE_EACH(__macro) \ __macro(nvtxRangePushA); \ + __macro(nvtxRangePushEx); \ __macro(nvtxRangePop); NVTX_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NVTX_WRAP); diff --git a/paddle/phi/backends/gpu/cuda/cuda_helper.h b/paddle/phi/backends/gpu/cuda/cuda_helper.h index c62addfd257ab..2d527dd526a0e 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_helper.h +++ b/paddle/phi/backends/gpu/cuda/cuda_helper.h @@ -14,6 +14,13 @@ #pragma once +#include // NOLINT + +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" + namespace phi { namespace backends { namespace gpu { @@ -24,7 +31,7 @@ namespace gpu { * [ Why need this macro? 
] * * The original looping in CUDA kernel is: - * + *p * `for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ * i += blockDim.x * gridDim.x)` * @@ -62,10 +69,37 @@ namespace gpu { * */ -#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \ - int64_t __index__ = blockIdx.x * blockDim.x + threadIdx.x; \ - for (index_type i = __index__; __index__ < (num); \ - __index__ += blockDim.x * gridDim.x, i = __index__) +#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \ + int64_t __index__ = \ + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; \ + int64_t __stride__ = static_cast(blockDim.x) * gridDim.x; \ + for (index_type i = __index__; __index__ < (num); \ + __index__ += __stride__, i = __index__) + +template +cudaDataType_t ToCudaDataType() { + if (std::is_same::value) { + return CUDA_R_32F; + } else if (std::is_same::value) { + return CUDA_R_64F; + } else if (std::is_same::value) { + return CUDA_R_16F; +#if CUDA_VERSION >= 11000 + } else if (std::is_same::value) { + return CUDA_R_16BF; +#endif +#if CUDA_VERSION >= 11040 + } else if (std::is_same::value) { + return CUDA_R_8I; + } else if (std::is_same::value) { + return CUDA_R_32I; +#endif + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "DataType %d is unsupported for CUDA.", + paddle::experimental::CppTypeToDataType::Type())); + } +} } // namespace gpu } // namespace backends diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index 3066368a98e4d..62082beac13a3 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ b/paddle/phi/backends/gpu/gpu_context.cc @@ -227,6 +227,7 @@ struct GPUContext::Impl { stream_ = new CUDAStream(place_); InitEigenDevice(); InitDnnWorkspace(); + GetDnnHandle(); GetBlasHandle(); } @@ -243,6 +244,7 @@ struct GPUContext::Impl { &max_threads_per_block_, &max_grid_dim_size_); stream_ = new CUDAStream(place_); + GetDnnHandle(); GetBlasHandle(); } @@ -251,6 +253,7 @@ struct GPUContext::Impl { stream_owned_ = true; backends::gpu::GPUDeviceGuard guard(place_.device); InitDnnWorkspace(); + GetDnnHandle(); GetBlasHandle(); } @@ -662,6 +665,14 @@ struct GPUContext::Impl { } } } + // get workspace ptr + void* GetWorkSpacePtr(const size_t& len) { + if (workspace_ptr_ == nullptr || len > workspace_ptr_->size()) { + workspace_ptr_.reset(); + workspace_ptr_ = allocator_->Allocate(len); + } + return workspace_ptr_->ptr(); + } // use one flag for all handles? // they should be accessed consistently @@ -726,6 +737,8 @@ struct GPUContext::Impl { Allocator* allocator_{nullptr}; // external resource. // A internal resouce to initinalize eigen_device. 
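The GetWorkSpacePtr helper added above is a grow-only cache: the context keeps a single allocation and only reallocates when a request exceeds the cached size, so the cublasLt workspace requested later can be reused across calls on the same stream. A stripped-down sketch of the same idea, with plain malloc standing in for phi::Allocator (illustrative only, not the GPUContext code):

    #include <cstdlib>
    #include <memory>

    class WorkspaceCache {
     public:
      // Returns a buffer of at least `len` bytes; reallocates only when the
      // request outgrows the cached allocation, so repeated kernels reuse it.
      // Previous contents are not preserved across a grow.
      void* Get(std::size_t len) {
        if (ptr_ == nullptr || len > size_) {
          ptr_.reset(static_cast<char*>(std::malloc(len)));
          size_ = len;
        }
        return ptr_.get();
      }

     private:
      struct FreeDeleter { void operator()(char* p) const { std::free(p); } };
      std::unique_ptr<char, FreeDeleter> ptr_;
      std::size_t size_ = 0;
    };

Note that such a shared buffer is only safe because work submitted to the context's stream executes in order; host threads sharing one context, or separate contexts, would each need their own workspace.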
std::unique_ptr eigen_stream_{nullptr}; + // work space + phi::Allocator::AllocationPtr workspace_ptr_{nullptr}; }; GPUContext::GPUContext(GPUContext&&) = default; @@ -946,4 +959,9 @@ void GPUContext::SetDriverVersion(int val) { impl_->driver_version_ = val; } void GPUContext::SetRuntimeVersion(int val) { impl_->runtime_version_ = val; } +// Get Work Space +void* GPUContext::GetWorkSpacePtr(const size_t& len) const { + return impl_->GetWorkSpacePtr(len); +} + } // namespace phi diff --git a/paddle/phi/backends/gpu/gpu_context.h b/paddle/phi/backends/gpu/gpu_context.h index 989bbbcbbf5f8..c76d8549c284c 100644 --- a/paddle/phi/backends/gpu/gpu_context.h +++ b/paddle/phi/backends/gpu/gpu_context.h @@ -199,6 +199,9 @@ class PADDLE_API GPUContext : public DeviceContext { // clear: whether clear the original CUDAStream or not void SetCUDAStream(CUDAStream*, bool clear = true); + // Get Work Space + void* GetWorkSpacePtr(const size_t& len) const; + protected: // NOTE: External users manage resources. Used in inference scenarios. // The Set interface is for inference only, DeviceContext will mark the diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h index 552f60783c8b2..fd712baf75480 100644 --- a/paddle/phi/backends/gpu/gpu_launch_config.h +++ b/paddle/phi/backends/gpu/gpu_launch_config.h @@ -34,18 +34,16 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/enforce.h" -#ifdef __HIPCC__ -// HIP results in error or nan if > 256 -#define PREDEFINED_BLOCK_SIZE 256 -#else // CUDA performs better when thread_per_block is between [64, 512] #define PREDEFINED_BLOCK_SIZE 512 -#endif namespace phi { namespace backends { namespace gpu { +// Limitation of the setting in one dimension of cuda grid. +constexpr int kMultiDimslimit = 65536; + template inline T DivUp(T a, T b) { return (a + b - 1) / b; @@ -53,20 +51,21 @@ inline T DivUp(T a, T b) { // https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 // for round integer value into next highest power of 2. -inline int64_t RoundToPowerOfTwo(int64_t n) { +inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val = 1) { n--; n |= (n >> 1); n |= (n >> 2); n |= (n >> 4); n |= (n >> 8); n |= (n >> 16); - int64_t min_val = 32; -#ifdef __HIPCC__ - int64_t max_val = 256; -#else + return std::max(min_val, (n + 1)); +} + +inline int64_t RoundToPowerOfTwo(int64_t n) { + constexpr int64_t min_val = 32; + int64_t num = RoundToNextHighPowOfTwo(n, min_val); int64_t max_val = 1024; -#endif - return std::min(max_val, std::max(min_val, (n + 1))); + return std::min(max_val, num); } #ifdef WITH_NV_JETSON @@ -162,8 +161,8 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context, } inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context, - int x_dim, - int y_dim) { + int64_t x_dim, + int64_t y_dim) { PADDLE_ENFORCE_GT( x_dim, 0, @@ -178,7 +177,7 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context, y_dim)); const int kThreadsPerBlock = 256; - int block_cols = std::min(x_dim, kThreadsPerBlock); + int block_cols = std::min(x_dim, kThreadsPerBlock); int block_rows = std::max(kThreadsPerBlock / block_cols, 1); int max_physical_threads = context.GetMaxPhysicalThreadCount(); @@ -188,8 +187,9 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context, // Noticed, block size is not align to 32, if needed do it yourself. 
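The switch of x_dim/y_dim to int64_t lets callers pass element counts beyond INT_MAX through the same DivUp and min/max arithmetic. A host-side sketch of that block/grid computation with the widened types (a standalone demo with an assumed, illustrative device thread limit, not the Paddle function):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t x_dim = 3000000000LL;  // larger than INT_MAX
      const int64_t y_dim = 4;
      const int kThreadsPerBlock = 256;
      const int max_physical_threads = 81920;  // illustrative device limit
      int block_cols = std::min<int64_t>(x_dim, kThreadsPerBlock);
      int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
      int max_blocks =
          std::max(max_physical_threads / (block_cols * block_rows), 1);
      // DivUp on int64_t, then clamp to the block budget.
      int grid_x = std::min<int64_t>((x_dim + block_cols - 1) / block_cols,
                                     max_blocks);
      int grid_y = std::min<int64_t>(max_blocks / grid_x,
                                     std::max<int64_t>(y_dim / block_rows, 1));
      std::printf("block=(%d,%d) grid=(%d,%d)\n",
                  block_cols, block_rows, grid_x, grid_y);
      return 0;
    }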
config.thread_per_block = dim3(block_cols, block_rows, 1); - int grid_x = std::min(DivUp(x_dim, block_cols), max_blocks); - int grid_y = std::min(max_blocks / grid_x, std::max(y_dim / block_rows, 1)); + int grid_x = std::min(DivUp(x_dim, block_cols), max_blocks); + int grid_y = std::min(max_blocks / grid_x, + std::max(y_dim / block_rows, 1)); config.block_per_grid = dim3(grid_x, grid_y, 1); return config; @@ -229,6 +229,14 @@ inline GpuLaunchConfig GetGpuLaunchConfig3D(const phi::GPUContext& context, return config; } +template +void LimitGridDim(const Context& ctx, dim3* grid_dim) { + auto max_grid_dim = + reinterpret_cast(ctx).GetCUDAMaxGridDimSize(); + grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0]; + grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1]; + grid_dim->z = grid_dim->z < max_grid_dim[2] ? grid_dim->z : max_grid_dim[2]; +} } // namespace gpu } // namespace backends } // namespace phi diff --git a/paddle/phi/common/memory_utils.h b/paddle/phi/common/memory_utils.h new file mode 100644 index 0000000000000..045fdf9daa568 --- /dev/null +++ b/paddle/phi/common/memory_utils.h @@ -0,0 +1,107 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include // NOLINT +#include +#include "paddle/fluid/memory/malloc.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/allocator.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/stream.h" + +namespace phi { + +/* + NOTE(YuanRisheng) Why should we add the following code? + We need this because MemoryUtils::instance() is a singleton object and we + don't recommend using singleton object in kernels. So, we wrap it using a + function and if we delete this singleton object in future, it will be easy to + change code. +*/ + +namespace memory_utils { +class Buffer { + public: + explicit Buffer(const phi::Place& place) : place_(place) {} + + template + T* Alloc(size_t size) { + using AllocT = typename std:: + conditional::value, uint8_t, T>::type; + if (UNLIKELY(size == 0)) return nullptr; + size *= sizeof(AllocT); + if (allocation_ == nullptr || allocation_->size() < size) { + allocation_ = paddle::memory::Alloc(place_, size); + } + return reinterpret_cast(allocation_->ptr()); + } + + template + const T* Get() const { + return reinterpret_cast( + allocation_ && allocation_->size() > 0 ? allocation_->ptr() : nullptr); + } + + template + T* GetMutable() { + return reinterpret_cast( + allocation_ && allocation_->size() > 0 ? allocation_->ptr() : nullptr); + } + + size_t Size() const { return allocation_ ? 
allocation_->size() : 0; } + + phi::Place GetPlace() const { return place_; } + + private: + Allocator::AllocationPtr allocation_; + phi::Place place_; +}; + +template +struct ThrustAllocator { + typedef char value_type; + ThrustAllocator(phi::Place place, StreamType stream) { + place_ = place; + stream_ = stream; + } + ~ThrustAllocator() {} + char* allocate(std::ptrdiff_t num_bytes) { + auto storage = + paddle::memory::AllocShared(place_, + num_bytes, + phi::Stream(reinterpret_cast(stream_))); + char* ptr = reinterpret_cast(storage->ptr()); + busy_allocation_.emplace(std::make_pair(ptr, storage)); + return ptr; + } + void deallocate(char* ptr, size_t) { + allocation_map_type::iterator iter = busy_allocation_.find(ptr); + // CHECK(iter != busy_allocation_.end()); + busy_allocation_.erase(iter); + } + + private: + typedef std::unordered_map> + allocation_map_type; + allocation_map_type busy_allocation_; + phi::Place place_; + StreamType stream_; +}; + +} // namespace memory_utils + +} // namespace phi diff --git a/paddle/phi/kernels/autotune/cache.cc b/paddle/phi/kernels/autotune/cache.cc index 838f2dd265eb3..ad7a2b134a20c 100644 --- a/paddle/phi/kernels/autotune/cache.cc +++ b/paddle/phi/kernels/autotune/cache.cc @@ -21,21 +21,6 @@ namespace phi { namespace autotune { -// Define the cache key of operator -size_t ConvKey(const std::vector& x_dims, - const std::vector& w_dims, - const std::vector& strides, - const std::vector& paddings, - const std::vector& dilations, - phi::DataType dtype) { - return GetKey(x_dims, - w_dims, - strides, - paddings, - dilations, - static_cast(dtype)); -} - size_t TransposeKey(const std::vector& x_dims, const std::vector& perm, phi::DataType dtype) { @@ -73,6 +58,19 @@ void AutoTuneCache::UpdateStatus() { cache_hits += v.second.CacheHits(); cache_misses += v.second.CacheMisses(); } + + for (auto& v : cudnn_auto_tune_map_) { + VLOG(4) << "AlgoType: " << std::setfill(' ') << std::setw(name_width) + << AlgorithmTypeString(v.first) + << " Cache Size: " << v.second.Size() + << " Hits: " << v.second.CacheHits() + << " Misses: " << v.second.CacheMisses() + << " Hit Rate: " << v.second.CacheHitRate(); + size += v.second.Size(); + cache_hits += v.second.CacheHits(); + cache_misses += v.second.CacheMisses(); + } + total_size_ = size; total_cache_hits_ = cache_hits; total_cache_misses_ = cache_misses; diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h index 1263cf40e567e..54c9508571c69 100644 --- a/paddle/phi/kernels/autotune/cache.h +++ b/paddle/phi/kernels/autotune/cache.h @@ -32,6 +32,7 @@ template inline void HashCombine(std::size_t* seed, const T& v, Rest... rest) { std::hash hasher; *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); + *seed *= 0x00000100000001B3; HashCombine(seed, rest...); } @@ -41,7 +42,7 @@ namespace std { template struct hash> { std::size_t operator()(std::vector const& vec) const noexcept { - std::size_t seed = 0; + std::size_t seed = 0xcbf29ce484222325; for (auto val : vec) { HashCombine(&seed, val); } @@ -53,6 +54,16 @@ struct hash> { namespace phi { namespace autotune { +struct ConvAutoTuneResult { + ConvAutoTuneResult() {} + ConvAutoTuneResult(int64_t a, size_t size, bool search) + : algo(a), workspace_size(size), exhaustive_search(search) {} + + int64_t algo; + size_t workspace_size = 0; + bool exhaustive_search = false; +}; + template size_t GetKey(Args&&... args) { size_t seed = 0; @@ -60,24 +71,147 @@ size_t GetKey(Args&&... 
args) { return seed; } -// Define the cache key of operator -size_t ConvKey(const std::vector& x_dims, - const std::vector& w_dims, - const std::vector& strides, - const std::vector& paddings, - const std::vector& dilations, - phi::DataType dtype); +struct ConvCacheKey { + ConvCacheKey() {} + ConvCacheKey(const std::vector& arg_x_dims, + const std::vector& arg_w_dims, + const std::vector& arg_strides, + const std::vector& arg_paddings, + const std::vector& arg_dilations, + phi::DataType arg_dtype, + int arg_groups, + int64_t arg_data_layout) + : x_dims(arg_x_dims), + w_dims(arg_w_dims), + strides(arg_strides), + paddings(arg_paddings), + dilations(arg_dilations), + dtype(arg_dtype), + groups(arg_groups), + data_layout(arg_data_layout) {} + size_t hash_value() const { + return GetKey(x_dims, + w_dims, + strides, + paddings, + dilations, + static_cast(dtype), + groups, + data_layout); + } + + std::vector x_dims; + std::vector w_dims; + std::vector strides; + std::vector paddings; + std::vector dilations; + phi::DataType dtype; + int groups; + int64_t data_layout; +}; + +struct ConvCacheKeyHash { + size_t operator()(const ConvCacheKey& cache) const { + return cache.hash_value(); + } +}; + +struct ConvCacheKeyEqual { + size_t operator()(const ConvCacheKey& first, + const ConvCacheKey& second) const { + if (first.x_dims != second.x_dims) return false; + if (first.w_dims != second.w_dims) return false; + if (first.strides != second.strides) return false; + if (first.paddings != second.paddings) return false; + if (first.dilations != second.dilations) return false; + if (first.dtype != second.dtype) return false; + if (first.groups != second.groups) return false; + if (first.data_layout != second.data_layout) return false; + + return true; + } +}; + +class CudnnAlgorithmsCacheMap { + public: + CudnnAlgorithmsCacheMap() : cache_mutex_(new std::mutex()) { hash_.clear(); } + + ConvAutoTuneResult Get(const ConvCacheKey& key) { + std::lock_guard lock(*cache_mutex_); + PADDLE_ENFORCE_NE( + hash_.find(key), + hash_.end(), + phi::errors::PreconditionNotMet("The key does not exist.")); + return hash_[key]; + } + + bool Find(const ConvCacheKey& key) { + bool ret = false; + std::lock_guard lock(*cache_mutex_); + if (hash_.find(key) != hash_.end()) { + cache_hits_++; + ret = true; + } else { + cache_misses_++; + } + return ret; + } + + void Clean() { + std::lock_guard lock(*cache_mutex_); + hash_.clear(); + cache_hits_ = 0; + cache_misses_ = 0; + } + + void Set(const ConvCacheKey& key, ConvAutoTuneResult algo) { + std::lock_guard lock(*cache_mutex_); + if (hash_.size() > static_cast(1000000)) { + hash_.clear(); + } + hash_[key] = algo; + } + + int64_t CacheMisses() const { return cache_misses_; } + + int64_t CacheHits() const { return cache_hits_; } + + float CacheHitRate() const { + int64_t num_accesses = cache_hits_ + cache_misses_; + float cache_hit_rate = 0.; + if (num_accesses != 0) { + cache_hit_rate = + static_cast(cache_hits_) / static_cast(num_accesses); + } + return cache_hit_rate; + } + + int64_t Size() const { return hash_.size(); } + + private: + std::unordered_map + hash_; + std::shared_ptr cache_mutex_; + + int64_t cache_hits_{0}; + int64_t cache_misses_{0}; +}; size_t TransposeKey(const std::vector& x_dims, const std::vector& perm, phi::DataType dtype); - -template +template , + typename KeyEqualT = std::equal_to> class AlgorithmsCache { public: - AlgorithmsCache() : cache_mutex_(new std::mutex()) { hash_.clear(); } + AlgorithmsCache() : cache_mutex_(new std::mutex()) {} - AlgorithmT 
Get(size_t key) { + AlgorithmT Get(const KeyT& key) { std::lock_guard lock(*cache_mutex_); PADDLE_ENFORCE_NE( hash_.find(key), @@ -86,7 +220,7 @@ class AlgorithmsCache { return hash_[key]; } - bool Find(size_t key) { + bool Find(const KeyT& key) { bool ret = false; std::lock_guard lock(*cache_mutex_); if (hash_.find(key) != hash_.end()) { @@ -105,7 +239,7 @@ class AlgorithmsCache { cache_misses_ = 0; } - void Set(size_t key, AlgorithmT algo) { + void Set(const KeyT& key, AlgorithmT algo) { std::lock_guard lock(*cache_mutex_); hash_[key] = algo; } @@ -126,14 +260,43 @@ class AlgorithmsCache { int64_t Size() const { return hash_.size(); } - private: - std::unordered_map hash_; + protected: + std::unordered_map hash_; std::shared_ptr cache_mutex_; int64_t cache_hits_{0}; int64_t cache_misses_{0}; }; +template +class MatmulAlgorithmsCache : public AlgorithmsCache { + public: + MatmulAlgorithmsCache() : AlgorithmsCache() {} + + bool FindSubKey(const KeyT& sub_key) { + std::lock_guard lock(*(this->cache_mutex_)); + bool ret = (sub_hash_.find(sub_key) != sub_hash_.end()) ? true : false; + return ret; + } + + void SetSubKey(const KeyT& sub_key, void* algo) { + std::lock_guard lock(*(this->cache_mutex_)); + sub_hash_[sub_key] = algo; + } + + void* GetSubKey(const KeyT& sub_key) { + std::lock_guard lock(*(this->cache_mutex_)); + PADDLE_ENFORCE_NE( + sub_hash_.find(sub_key), + sub_hash_.end(), + phi::errors::PreconditionNotMet("The key does not exist.")); + return sub_hash_[sub_key]; + } + + private: + std::unordered_map sub_hash_; +}; + enum class AlgorithmType { kConvForward = 1, kConvBackwardData = 2, @@ -143,9 +306,13 @@ enum class AlgorithmType { }; // AlgorithmsConfigKey -> AlgorithmsID -using AlgorithmsCacheMap = AlgorithmsCache; +// (todo. hong) use cudnnConvolutionFwdAlgo_t +using AlgorithmsCacheMap = AlgorithmsCache; // AlgorithmType -> AlgorithmsCache using AlgorithmsTypeMap = std::unordered_map; +using CudnnAlgorithmsTypeMap = + std::unordered_map; +using MatmulAlgorithmsCacheMap = MatmulAlgorithmsCache; class AutoTuneCache { public: @@ -158,24 +325,22 @@ class AutoTuneCache { return auto_tune_map_[static_cast(algo_type)]; } - AlgorithmsCacheMap& GetConvForward() { - return Get(AlgorithmType::kConvForward); - } - - AlgorithmsCacheMap& GetConvBackwardData() { - return Get(AlgorithmType::kConvBackwardData); - } - - AlgorithmsCacheMap& GetConvBackwardFilter() { - return Get(AlgorithmType::kConvBackwardFilter); + CudnnAlgorithmsCacheMap& GetConv(const AlgorithmType& algo_type) { + return cudnn_auto_tune_map_[static_cast(algo_type)]; } AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); } + MatmulAlgorithmsCacheMap& GetMatmul() { return matmul_auto_tune_map_; } + void Clean() { for (auto& v : auto_tune_map_) { v.second.Clean(); } + + for (auto& v : cudnn_auto_tune_map_) { + v.second.Clean(); + } } void UpdateStatus(); @@ -206,14 +371,26 @@ class AutoTuneCache { void Register(const AlgorithmType& algo_type) { std::lock_guard lock(*autotune_cache_mutex_); - int64_t key = static_cast(algo_type); - if (auto_tune_map_.find(key) == auto_tune_map_.end()) { - AlgorithmsCacheMap cache; - auto_tune_map_[key] = cache; + if (algo_type == AlgorithmType::kConvForward || + algo_type == AlgorithmType::kConvBackwardData || + algo_type == AlgorithmType::kConvBackwardFilter) { + int64_t key = static_cast(algo_type); + if (auto_tune_map_.find(key) == auto_tune_map_.end()) { + CudnnAlgorithmsCacheMap cache; + cudnn_auto_tune_map_[key] = cache; + } + } else { + int64_t key = 
static_cast(algo_type); + if (auto_tune_map_.find(key) == auto_tune_map_.end()) { + AlgorithmsCacheMap cache; + auto_tune_map_[key] = cache; + } } } AlgorithmsTypeMap auto_tune_map_; + CudnnAlgorithmsTypeMap cudnn_auto_tune_map_; + MatmulAlgorithmsCacheMap matmul_auto_tune_map_; std::shared_ptr autotune_cache_mutex_; int64_t total_cache_hits_{0}; int64_t total_cache_misses_{0}; diff --git a/paddle/phi/kernels/autotune/cache_test.cc b/paddle/phi/kernels/autotune/cache_test.cc index 53574c3d0c9ac..18454ad3e1997 100644 --- a/paddle/phi/kernels/autotune/cache_test.cc +++ b/paddle/phi/kernels/autotune/cache_test.cc @@ -25,7 +25,8 @@ enum ConvAlgos { GEMMKernel = 0, CuDNNKernel_1 = 1, CuDNNKernel_2 = 2 }; TEST(AlgosCache, AlgosCache) { auto autotune_cache = phi::autotune::AutoTuneCache::Instance(); - auto& cache = autotune_cache.GetConvForward(); + auto& cache = + autotune_cache.GetConv(phi::autotune::AlgorithmType::kConvForward); std::vector x_shape = {4, 224, 224, 3}; std::vector w_shape = {32, 3, 3, 3}; @@ -34,20 +35,24 @@ TEST(AlgosCache, AlgosCache) { std::vector dilations = {1, 1}; phi::DataType dtype = paddle::experimental::CppTypeToDataType::Type(); - auto key = phi::autotune::ConvKey( - x_shape, w_shape, paddings, strides, dilations, dtype); + phi::autotune::ConvCacheKey key( + x_shape, w_shape, paddings, strides, dilations, dtype, 0, 0); EXPECT_EQ(cache.Find(key), false); - cache.Set(key, ConvAlgos::GEMMKernel); + phi::autotune::ConvAutoTuneResult node( + static_cast(ConvAlgos::GEMMKernel), 0, false); + cache.Set(key, node); EXPECT_EQ(cache.Size(), 1); EXPECT_EQ(cache.Find(key), true); auto algo = cache.Get(key); - EXPECT_EQ(algo, ConvAlgos::GEMMKernel); + EXPECT_EQ(algo.algo, ConvAlgos::GEMMKernel); x_shape = {4, 128, 128, 3}; - key = phi::autotune::ConvKey( - x_shape, w_shape, paddings, strides, dilations, dtype); - EXPECT_EQ(cache.Find(key), false); - cache.Set(key, ConvAlgos::CuDNNKernel_1); + phi::autotune::ConvCacheKey key1( + x_shape, w_shape, paddings, strides, dilations, dtype, 0, 1); + EXPECT_EQ(cache.Find(key1), false); + phi::autotune::ConvAutoTuneResult node1( + static_cast(ConvAlgos::CuDNNKernel_1), 0, false); + cache.Set(key1, node1); EXPECT_EQ(cache.Size(), 2); EXPECT_EQ(cache.CacheHits(), 1); EXPECT_EQ(cache.CacheMisses(), 2); diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index 6d54e466f3240..459a701b5115b 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -1750,11 +1750,7 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, << FLAGS_gemm_use_half_precision_compute_type; auto fp = std::is_same::value ? 
CUDA_R_32F : CUDA_R_16F; -#if CUDA_VERSION >= 11000 - auto compute_type = CUBLAS_COMPUTE_32F; -#else - auto compute_type = CUDA_R_32F; -#endif + cudaDataType_t compute_type = CUDA_R_32F; float h_alpha = static_cast(alpha); float h_beta = static_cast(beta); @@ -1765,11 +1761,7 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, std::is_same::value) { a = static_cast(&alpha); b = static_cast(&beta); -#if CUDA_VERSION >= 11000 - compute_type = CUBLAS_COMPUTE_16F; -#else compute_type = CUDA_R_16F; -#endif } context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { diff --git a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h new file mode 100644 index 0000000000000..37229fc0daff1 --- /dev/null +++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h @@ -0,0 +1,1149 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 + +#include "glog/logging.h" + +#include // NOLINT +#include "cuda.h" // NOLINT +#include "paddle/phi/backends/dynload/cublasLt.h" +#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/platform/flags.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/autotune/cache.h" +#include "paddle/phi/kernels/autotune/gpu_timer.h" +#include "paddle/phi/kernels/autotune/switch_autotune.h" + +DECLARE_int64(cublaslt_exhaustive_search_times); +#endif + +namespace phi { +namespace funcs { + +#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) + +// Set this enum according to +// https://docs.nvidia.com/cuda/cublas/index.html#cublasltepilogue-t +// While kMatmul, kMatmulGrad, kMatmulGradWithoutBias share the same +// enum value, but if all elements for MatmulPlanner->GetKey() is same, +// no matter forward or backward, they could share the same descriptor +// cache, in that the descriptor is for description of matmul operation. 
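The descriptor sharing described above works because MatmulPlanner::GetKey() folds every attribute that affects the required descriptor (shapes, transposes, dtype, epilogue, the addto flag) into a single size_t. A standalone sketch of that key composition using the usual boost-style hash combining; this illustrates the idea only and is not the phi::autotune implementation:

    #include <cstdint>
    #include <functional>
    #include <vector>

    template <typename T>
    void HashCombine(std::size_t* seed, const T& v) {
      // Mix each attribute into the running seed.
      *seed ^= std::hash<T>()(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
    }

    std::size_t MatmulKey(const std::vector<int64_t>& x_dims,
                          const std::vector<int64_t>& y_dims,
                          bool trans_x, bool trans_y,
                          int dtype, int fused_type) {
      std::size_t seed = 0;
      for (auto d : x_dims) HashCombine(&seed, d);
      for (auto d : y_dims) HashCombine(&seed, d);
      HashCombine(&seed, trans_x);
      HashCombine(&seed, trans_y);
      HashCombine(&seed, dtype);
      HashCombine(&seed, fused_type);
      return seed;
    }

Two calls that agree on all of these fields hash to the same key and therefore reuse one cached cuBLASLt descriptor, which is why forward and backward matmuls with identical layouts can share an entry.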
+enum MatmulFusedType { + kMatmul = 0, + kMatmulGrad = 1, + kMatmulGradWithoutBias = 2, + kMatmulBias = 3, + kMatmulRelu = 4, + kMatmulGelu = 5, + kMatmulBiasRelu = 6, + kMatmulBiasGelu = 7, + kMatmulBiasReluWithReservedData = 8, + kMatmulBiasGeluWithReservedData = 9, + kMatmulReluGrad = 10, + kMatmulGeluGrad = 11, + kMatmulBiasGradToA = 12, + kMatmulBiasGradToB = 13, +}; + +static cublasLtEpilogue_t ConvertFusedType(MatmulFusedType fused_type) { + static std::map fused_type_map = { + {MatmulFusedType::kMatmul, CUBLASLT_EPILOGUE_DEFAULT}, + {MatmulFusedType::kMatmulGrad, CUBLASLT_EPILOGUE_DEFAULT}, + {MatmulFusedType::kMatmulGradWithoutBias, CUBLASLT_EPILOGUE_DEFAULT}, + {MatmulFusedType::kMatmulBias, CUBLASLT_EPILOGUE_BIAS}, + {MatmulFusedType::kMatmulRelu, CUBLASLT_EPILOGUE_RELU}, + {MatmulFusedType::kMatmulGelu, CUBLASLT_EPILOGUE_GELU}, + {MatmulFusedType::kMatmulBiasRelu, CUBLASLT_EPILOGUE_RELU_BIAS}, + {MatmulFusedType::kMatmulBiasGelu, CUBLASLT_EPILOGUE_GELU_BIAS}, + {MatmulFusedType::kMatmulBiasReluWithReservedData, + CUBLASLT_EPILOGUE_RELU_AUX_BIAS}, + {MatmulFusedType::kMatmulBiasGeluWithReservedData, + CUBLASLT_EPILOGUE_GELU_AUX_BIAS}, +#if CUDA_VERSION >= 11060 + {MatmulFusedType::kMatmulReluGrad, CUBLASLT_EPILOGUE_DRELU}, + {MatmulFusedType::kMatmulGeluGrad, CUBLASLT_EPILOGUE_DGELU}, + {MatmulFusedType::kMatmulBiasGradToA, CUBLASLT_EPILOGUE_BGRADA}, + {MatmulFusedType::kMatmulBiasGradToB, CUBLASLT_EPILOGUE_BGRADB} +#endif + }; + + return fused_type_map[fused_type]; +} + +enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 }; + +template +struct FusedGEMMGradTrait; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradB = FusedGEMMGradInType::kDY; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDX; + static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = false; +}; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDY; + static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDX; + static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradATrans = false; + static constexpr auto kYGradBTrans = false; +}; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradB = FusedGEMMGradInType::kDY; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = false; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradB = FusedGEMMGradInType::kDX; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = false; +}; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDY; + static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradATrans = true; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradB = FusedGEMMGradInType::kDX; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = true; +}; + +// To tell any matmul or fused matmul operation from each 
other. +struct MatmulPlanner { + public: + const void* bias{nullptr}; + void* aux_data{nullptr}; + + MatmulPlanner() {} + MatmulPlanner(const std::vector& x_dims, + const std::vector& y_dims, + const bool trans_x, + const bool trans_y, + phi::DataType dtype, + MatmulFusedType fused_type, + const void* bias_data = nullptr, + void* reserve_data = nullptr, // Commonly for ReLu bit-mask. + bool use_addto = false, + bool no_exchange = true) + : bias(bias_data), aux_data(reserve_data), fused_type_(fused_type) { + use_addto_ = use_addto; + key_ = phi::autotune::GetKey(x_dims, + y_dims, + static_cast(trans_x), + static_cast(trans_y), + static_cast(dtype), + static_cast(fused_type_), + static_cast(use_addto_), + static_cast(no_exchange)); + } + + bool UseAddTo() const { return use_addto_; } + size_t GetKey() const { return key_; } + MatmulFusedType GetFusedType() const { return fused_type_; } + + size_t GenSubKey() const { return key_; } + + private: + MatmulFusedType fused_type_; + bool use_addto_; + size_t key_; +}; + +template +cublasComputeType_t GetCudaComputeType() { + if (std::is_same::value) { + return CUBLAS_COMPUTE_64F; + } else if (std::is_same::value) { + return CUBLAS_COMPUTE_32I; + } else { + return CUBLAS_COMPUTE_32F; + } +} + +struct MatmulDescriptor { + public: + cublasLtMatmulDesc_t op_desc{nullptr}; + cublasLtMatrixLayout_t x_desc{nullptr}; + cublasLtMatrixLayout_t y_desc{nullptr}; + cublasLtMatrixLayout_t out_desc{nullptr}; + cublasLtMatmulAlgo_t* algo{nullptr}; + bool is_cached{false}; + + MatmulDescriptor() {} + MatmulDescriptor(const MatmulDescriptor& obj) { + algo = obj.algo; + x_desc = obj.x_desc; + y_desc = obj.y_desc; + op_desc = obj.op_desc; + out_desc = obj.out_desc; + is_cached = obj.is_cached; + } + + MatmulDescriptor& operator=(const MatmulDescriptor& obj) { + algo = obj.algo; + x_desc = obj.x_desc; + y_desc = obj.y_desc; + op_desc = obj.op_desc; + out_desc = obj.out_desc; + is_cached = obj.is_cached; + + return *this; + } + + ~MatmulDescriptor() PADDLE_MAY_THROW { + if (!is_cached) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescDestroy(op_desc)); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(y_desc)); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(x_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatrixLayoutDestroy(out_desc)); + delete algo; + + op_desc = nullptr; + x_desc = nullptr; + y_desc = nullptr; + out_desc = nullptr; + algo = nullptr; + } + } + + // x_desc, y_desc, op_desc are allocated in heap memory. 
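Create() below fills exactly the handles a bare cuBLASLt call needs. Stripped of caching and autotuning, the underlying sequence is: create the matmul and layout descriptors, call cublasLtMatmul (here left to the library's default heuristic, with no explicit algo or workspace), and destroy the descriptors. A minimal column-major SGEMM sketch of that sequence, assuming CUDA 11+ and with error handling reduced to asserts (illustration only, not this wrapper):

    // Build: nvcc demo_lt_gemm.cu -lcublasLt && ./a.out
    #include <cublasLt.h>
    #include <cuda_runtime.h>
    #include <cassert>
    #include <cstdio>

    int main() {
      const int m = 2, n = 2, k = 2;
      float hA[] = {1, 2, 3, 4}, hB[] = {5, 6, 7, 8}, hC[4] = {};
      float *dA, *dB, *dC;
      cudaMalloc(&dA, sizeof(hA)); cudaMalloc(&dB, sizeof(hB)); cudaMalloc(&dC, sizeof(hC));
      cudaMemcpy(dA, hA, sizeof(hA), cudaMemcpyHostToDevice);
      cudaMemcpy(dB, hB, sizeof(hB), cudaMemcpyHostToDevice);

      cublasLtHandle_t lt;
      assert(cublasLtCreate(&lt) == CUBLAS_STATUS_SUCCESS);
      cublasLtMatmulDesc_t op;  // FP32 compute, FP32 alpha/beta scale type
      assert(cublasLtMatmulDescCreate(&op, CUBLAS_COMPUTE_32F, CUDA_R_32F) ==
             CUBLAS_STATUS_SUCCESS);
      cublasLtMatrixLayout_t a, b, c;  // column-major m x k, k x n, m x n
      cublasLtMatrixLayoutCreate(&a, CUDA_R_32F, m, k, m);
      cublasLtMatrixLayoutCreate(&b, CUDA_R_32F, k, n, k);
      cublasLtMatrixLayoutCreate(&c, CUDA_R_32F, m, n, m);

      float alpha = 1.f, beta = 0.f;
      // No algo supplied: cuBLASLt falls back to its own heuristic choice.
      assert(cublasLtMatmul(lt, op, &alpha, dA, a, dB, b, &beta, dC, c, dC, c,
                            /*algo=*/nullptr, /*workspace=*/nullptr, 0,
                            /*stream=*/0) == CUBLAS_STATUS_SUCCESS);

      cudaMemcpy(hC, dC, sizeof(hC), cudaMemcpyDeviceToHost);
      std::printf("C = [%g %g; %g %g]\n", hC[0], hC[2], hC[1], hC[3]);  // 23 31; 34 46
      cublasLtMatrixLayoutDestroy(a); cublasLtMatrixLayoutDestroy(b);
      cublasLtMatrixLayoutDestroy(c); cublasLtMatmulDescDestroy(op);
      cublasLtDestroy(lt); cudaFree(dA); cudaFree(dB); cudaFree(dC);
      return 0;
    }

On top of this sequence, the wrapper below also swaps the X/Y operands and their layout dimensions when calling cublasLtMatmul, so that row-major tensors can pass through the column-major cuBLASLt interface unchanged.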
+ template + void Create(const int64_t M, + const int64_t N, + const int64_t K, + const bool trans_x, + const bool trans_y, + phi::funcs::MatmulPlanner* planner, + const int batch_size = 1, + const int64_t stride_x = 0, + const int64_t stride_y = 0, + const int64_t stride_out = 0, + bool grad_for_dx = true) { + using MT = typename phi::dtype::MPTypeTrait::Type; + cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); + cudaDataType_t out_mat_type = phi::backends::gpu::ToCudaDataType(); + cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); + cublasComputeType_t compute_type = GetCudaComputeType(); + + if (std::is_same::value) { + out_mat_type = phi::backends::gpu::ToCudaDataType(); + scale_type = phi::backends::gpu::ToCudaDataType(); + } + + // Create operation descriptor; see cublasLtMatmulDescAttributes_t for + // details about defaults; just need to set the transforms for A and B + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type)); + SetFusedEpilogueOpDescriptor(planner, trans_x, trans_y, N); + + // Create matrix descriptors + CreateMatrixLayout(&x_desc, mat_type, M, K, trans_x); + CreateMatrixLayout(&y_desc, mat_type, K, N, trans_y); + CreateMatrixLayout(&out_desc, out_mat_type, M, N, false); + + // Config batch size and stride. + if (batch_size > 1) { + SetBatchAndStride(x_desc, batch_size, stride_x); + SetBatchAndStride(y_desc, batch_size, stride_y); + SetBatchAndStride(out_desc, batch_size, stride_out); + } + } + + cublasLtMatmulAlgo_t* SetAlgo() { + // while entering this function, the desc shall be cached. + is_cached = true; + algo = new cublasLtMatmulAlgo_t; + return algo; + } + + template + void SetFusedEpiloguePtr(phi::funcs::MatmulPlanner* planner) { + if (planner->bias != nullptr) { + const T* bias_data = static_cast(planner->bias); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( + op_desc, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, + &bias_data, + sizeof(bias_data))); + } + if (planner->aux_data != nullptr) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( + op_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &(planner->aux_data), + sizeof(planner->aux_data))); + } + } + + std::string GetDescResultString(std::string prefix, + bool has_algo = true) const { + std::ostringstream out; + out << prefix << " \n"; +#define GET_DESC_DATA_STRING(src) \ + do { \ + out << " " << #src << " = ["; \ + int num = sizeof((*src)) / sizeof(src->data[0]); \ + for (int i = 0; i < num; ++i) { \ + if (i == 0) { \ + out << src->data[i]; \ + } else { \ + out << ", " << src->data[i]; \ + } \ + } \ + out << "]\n"; \ + } while (0); + + if (has_algo) { + GET_DESC_DATA_STRING(algo); + } + GET_DESC_DATA_STRING(x_desc); + GET_DESC_DATA_STRING(y_desc); + GET_DESC_DATA_STRING(out_desc); + GET_DESC_DATA_STRING(op_desc); +#undef GET_DESC_DATA_STRING + return out.str(); + } + + void ExchangeXYDesc(bool no_exchange) {} + + protected: + void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner, + const bool trans_x, + const bool trans_y, + int64_t lead_dim) { + cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cublas_trans_y = trans_y ? 
CUBLAS_OP_T : CUBLAS_OP_N; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulDescSetAttribute(op_desc, + CUBLASLT_MATMUL_DESC_TRANSB, + &cublas_trans_x, + sizeof(cublas_trans_x))); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulDescSetAttribute(op_desc, + CUBLASLT_MATMUL_DESC_TRANSA, + &cublas_trans_y, + sizeof(cublas_trans_y))); + MatmulFusedType fused_type = planner->GetFusedType(); + if (fused_type != MatmulFusedType::kMatmul) { + cublasLtEpilogue_t cublaslt_fused_type = ConvertFusedType(fused_type); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulDescSetAttribute(op_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &cublaslt_fused_type, + sizeof(fused_type))); + } + if (planner->aux_data) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( + op_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + &lead_dim, + sizeof(lead_dim))); + } + } + + void CreateMatrixLayout(cublasLtMatrixLayout_t* desc, + cudaDataType type, + uint64_t rows, + uint64_t cols, + bool trans) { + if (trans) { + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatrixLayoutCreate(desc, type, rows, cols, rows)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatrixLayoutCreate(desc, type, cols, rows, cols)); + } + } + + void SetBatchAndStride(cublasLtMatrixLayout_t desc, + int batch_size, + int64_t stride) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batch_size, + sizeof(batch_size))); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &stride, + sizeof(stride))); + } +}; + +struct MatmulGradDescriptor : MatmulDescriptor { + public: + MatmulGradDescriptor() {} + + template + void Create(const int64_t M, + const int64_t N, + const int64_t K, + const bool trans_x, + const bool trans_y, + phi::funcs::MatmulPlanner* planner, + const int batch_size = 1, + int64_t stride_x = 0, + int64_t stride_y = 0, + int64_t stride_out = 0, + bool grad_for_dx = true) { + using MT = typename phi::dtype::MPTypeTrait::Type; + cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); + cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); + cublasComputeType_t compute_type = GetCudaComputeType(); + + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type)); + this->SetFusedEpilogueOpDescriptor( + planner, trans_x, trans_y, TransX ? 
M : K); + + // Create operation desciriptor; see cublasLtMatmulDescAttributes_t for + // details about defaults; just need to set the transforms for A and B + this->CreateMatrixLayout(&x_desc, mat_type, N, M, true); + if (grad_for_dx) { + this->CreateMatrixLayout(&y_desc, mat_type, K, N, TransY); + this->CreateMatrixLayout( + &out_desc, phi::backends::gpu::ToCudaDataType(), M, K, TransX); + } else { + this->CreateMatrixLayout(&y_desc, mat_type, M, K, TransX); + this->CreateMatrixLayout( + &out_desc, phi::backends::gpu::ToCudaDataType(), K, N, TransY); + } + } + + void ExchangeXYDesc(bool no_exchange) { + if (no_exchange) { + return; + } + auto* temp = y_desc; + y_desc = x_desc; + x_desc = temp; + } +}; + +template +struct CublasLtBase { + public: + using MT = typename phi::dtype::MPTypeTrait::Type; + static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx, + size_t workspace_size) { + return paddle::memory::Alloc( + ctx.GetPlace(), + workspace_size, + phi::Stream(reinterpret_cast(ctx.stream()))); + } + + static void RunImpl(const phi::GPUContext& ctx, + MatmulDescT* desc, + const size_t sub_key, + const T* x_ptr, + const T* y_ptr, + OutT* out_ptr, + phi::funcs::MatmulPlanner* planner) { + MT alpha = static_cast(1); + MT beta = planner->UseAddTo() ? static_cast(1) : static_cast(0); + cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle(); + + // NOTE(limingshu): As workspace_size varies from different DL framework, + // I wonder is there any smarter idea for workspace setting, currently I + // just followed the settings from the NVIDIA colleague`s setting. + size_t workspace_size = static_cast(4) * 1024 * 1024; + // phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, + // workspace_size); + void* workspace_ptr = ctx.GetWorkSpacePtr(workspace_size); + + if (planner != nullptr) { + if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune() && + (!desc->is_cached)) { + SearchBestAlgo(ctx, + cublaslt_handle, + desc, + static_cast(&alpha), + static_cast(&beta), + y_ptr, + x_ptr, + out_ptr, + workspace_ptr, + workspace_size); + MatmulDescT* best_desc = new MatmulDescT(*desc); + VLOG(6) << best_desc->GetDescResultString( + "[Searched CublasltDescriptor] "); + + auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); + cache.SetSubKey(sub_key, reinterpret_cast(best_desc)); + } + } + + VLOG(7) << desc->GetDescResultString("[Impl CublasltDescriptor] "); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmul(cublaslt_handle, + desc->op_desc, + static_cast(&alpha), + y_ptr, + desc->y_desc, + x_ptr, + desc->x_desc, + static_cast(&beta), + out_ptr, + desc->out_desc, + out_ptr, + desc->out_desc, + desc->algo, + workspace_ptr, + workspace_size, + ctx.stream())); + } + + static void SearchBestAlgo(const phi::GPUContext& ctx, + const cublasLtHandle_t& lt_handle, + MatmulDescT* desc, + const void* alpha, + const void* beta, + const void* y_data, + const void* x_data, + void* out_data, + void* workspace_ptr, + size_t workspace_size) { + cublasLtMatmulPreference_t preference; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulPreferenceCreate(&preference)); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, + sizeof(workspace_size))); + + int returned_results = 0; + constexpr int requested_algo_count = 10; + std::vector heuristic_results( + requested_algo_count); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle, + desc->op_desc, + 
desc->y_desc, + desc->x_desc, + desc->out_desc, + desc->out_desc, + preference, + requested_algo_count, + heuristic_results.data(), + &returned_results)); + PADDLE_ENFORCE_GT(returned_results, + 0, + phi::errors::Unavailable("No GEMM algorithm avaliable.")); + int best_algo_idx = -1; + if (returned_results == 1 || FLAGS_cublaslt_exhaustive_search_times <= 0) { + best_algo_idx = 0; + } else { + float min_time_cost = std::numeric_limits::max(); + for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) { + float cur_time_cost = + RunAndMeasureAlgo(ctx, + lt_handle, + desc, + alpha, + beta, + y_data, + x_data, + out_data, + workspace_ptr, + workspace_size, + &(heuristic_results[algo_idx].algo)); + VLOG(6) << "[MatmulWithCublaslt] algo[" << algo_idx + << "] time: " << cur_time_cost << " s"; + + if ((best_algo_idx == 0 && (1.05 * cur_time_cost < min_time_cost)) || + (cur_time_cost < min_time_cost)) { + best_algo_idx = algo_idx; + min_time_cost = cur_time_cost; + } + } + } + VLOG(6) << "[MatmulWithCublaslt] best_algo_idx: " << best_algo_idx; + + cublasLtMatmulAlgo_t* best_algo = desc->SetAlgo(); + *best_algo = heuristic_results[best_algo_idx].algo; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulPreferenceDestroy(preference)); + } + + static float RunAndMeasureAlgo(const phi::GPUContext& ctx, + const cublasLtHandle_t& lt_handle, + MatmulDescT* desc, + const void* alpha, + const void* beta, + const void* y_data, + const void* x_data, + void* out_data, + void* workspace_ptr, + size_t workspace_size, + cublasLtMatmulAlgo_t* algo) { + int repeats = FLAGS_cublaslt_exhaustive_search_times; + if (repeats <= 0) { + return std::numeric_limits::max(); + } + + phi::GpuTimer timer; + float time_cost = 0.f; + const auto& stream = ctx.stream(); + + for (int i = 0; i < repeats; ++i) { + timer.Start(stream); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul(lt_handle, + desc->op_desc, + alpha, + y_data, + desc->y_desc, + x_data, + desc->x_desc, + beta, + out_data, + desc->out_desc, + out_data, + desc->out_desc, + algo, + workspace_ptr, + workspace_size, + stream)); + timer.Stop(stream); + ctx.Wait(); + auto time = timer.ElapsedTime(); + if (i > 0) { + // Exclude the warmup runtime. + time_cost += time; + } + } + return (time_cost / (repeats - 1)); + } +}; + +template <> +struct CublasLtBase { + public: + static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx, + size_t workspace_size) { + return paddle::memory::Alloc( + ctx.GetPlace(), + workspace_size, + phi::Stream(reinterpret_cast(ctx.stream()))); + } + + static void RunImpl(const phi::GPUContext& ctx, + MatmulDescriptor* desc, + const size_t sub_key, + const int8_t* x_ptr, + const int8_t* y_ptr, + int32_t* out_ptr, + phi::funcs::MatmulPlanner* planner) { + int32_t alpha = 1; + int32_t beta = + planner->UseAddTo() ? 
static_cast(1) : static_cast(0); + cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle(); + + size_t workspace_size = static_cast(4) * 1024 * 1024; + // phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, + // workspace_size); + void* workspace_ptr = ctx.GetWorkSpacePtr(workspace_size); + + if (planner != nullptr) { + if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune() && + (!desc->is_cached)) { + SearchBestAlgo(ctx, + cublaslt_handle, + desc, + static_cast(&alpha), + static_cast(&beta), + y_ptr, + x_ptr, + out_ptr, + workspace_ptr, + workspace_size); + MatmulDescriptor* best_desc = new MatmulDescriptor(*desc); + VLOG(6) << best_desc->GetDescResultString( + "[Searched CublasltDescriptor] "); + + auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); + cache.SetSubKey(sub_key, reinterpret_cast(best_desc)); + } + } + + VLOG(7) << desc->GetDescResultString("[Impl CublasltDescriptor] "); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmul(cublaslt_handle, + desc->op_desc, + static_cast(&alpha), + y_ptr, + desc->y_desc, + x_ptr, + desc->x_desc, + static_cast(&beta), + out_ptr, + desc->out_desc, + out_ptr, + desc->out_desc, + desc->algo, + workspace_ptr, + workspace_size, + ctx.stream())); + } + + static void SearchBestAlgo(const phi::GPUContext& ctx, + const cublasLtHandle_t& lt_handle, + MatmulDescriptor* desc, + const void* alpha, + const void* beta, + const void* y_data, + const void* x_data, + void* out_data, + void* workspace_ptr, + size_t workspace_size) { + cublasLtMatmulPreference_t preference; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulPreferenceCreate(&preference)); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, + sizeof(workspace_size))); + + int returned_results = 0; + constexpr int requested_algo_count = 10; + std::vector heuristic_results( + requested_algo_count); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle, + desc->op_desc, + desc->y_desc, + desc->x_desc, + desc->out_desc, + desc->out_desc, + preference, + requested_algo_count, + heuristic_results.data(), + &returned_results)); + PADDLE_ENFORCE_GT(returned_results, + 0, + phi::errors::Unavailable("No GEMM algorithm avaliable.")); + int best_algo_idx = -1; + if (returned_results == 1 || FLAGS_cublaslt_exhaustive_search_times <= 0) { + best_algo_idx = 0; + } else { + float min_time_cost = std::numeric_limits::max(); + for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) { + float cur_time_cost = + RunAndMeasureAlgo(ctx, + lt_handle, + desc, + alpha, + beta, + y_data, + x_data, + out_data, + workspace_ptr, + workspace_size, + &(heuristic_results[algo_idx].algo)); + VLOG(6) << "[MatmulWithCublaslt] algo[" << algo_idx + << "] time: " << cur_time_cost << " s"; + + if ((best_algo_idx == 0 && (1.05 * cur_time_cost < min_time_cost)) || + (cur_time_cost < min_time_cost)) { + best_algo_idx = algo_idx; + min_time_cost = cur_time_cost; + } + } + } + VLOG(6) << "[MatmulWithCublaslt] best_algo_idx: " << best_algo_idx; + + cublasLtMatmulAlgo_t* best_algo = desc->SetAlgo(); + *best_algo = heuristic_results[best_algo_idx].algo; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulPreferenceDestroy(preference)); + } + + static float RunAndMeasureAlgo(const phi::GPUContext& ctx, + const cublasLtHandle_t& lt_handle, + MatmulDescriptor* desc, + const void* alpha, + const void* beta, + const void* y_data, + const void* x_data, + void* out_data, + 
void* workspace_ptr, + size_t workspace_size, + cublasLtMatmulAlgo_t* algo) { + int repeats = FLAGS_cublaslt_exhaustive_search_times; + if (repeats <= 0) { + return std::numeric_limits::max(); + } + + phi::GpuTimer timer; + float time_cost = 0.f; + const auto& stream = ctx.stream(); + + for (int i = 0; i < repeats; ++i) { + timer.Start(stream); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul(lt_handle, + desc->op_desc, + alpha, + y_data, + desc->y_desc, + x_data, + desc->x_desc, + beta, + out_data, + desc->out_desc, + out_data, + desc->out_desc, + algo, + workspace_ptr, + workspace_size, + stream)); + timer.Stop(stream); + ctx.Wait(); + auto time = timer.ElapsedTime(); + if (i > 0) { + // Exclude the warmup runtime. + time_cost += time; + } + } + return (time_cost / (repeats - 1)); + } +}; + +// To judge if desc is cached or not. +template +struct DescriptorSetter { + public: + DescT desc; + size_t sub_key{std::numeric_limits::min()}; + + DescriptorSetter(phi::funcs::MatmulPlanner* planner, + const int64_t M, + const int64_t N, + const int64_t K, + const bool trans_x, + const bool trans_y, + const int batch_size = 1, + int64_t stride_x = 0, + int64_t stride_y = 0, + int64_t stride_out = 0, + const bool no_exchange = true, + bool grad_for_dx = true) { + if (std::is_same::value) { + if (!trans_x && !trans_y) { + PADDLE_ENFORCE_EQ( + (N % 4 == 0 || N == 1), + true, + phi::errors::InvalidArgument( + "The dimension size N used in int8 matmul must be 1 or a " + "multiple of 4 does not " + "match the size (%d) currently contained in the container.", + N)); + PADDLE_ENFORCE_EQ( + (K % 4 == 0), + true, + phi::errors::InvalidArgument( + "The dimension size K used in int8 matmul must be a multiple " + "of 4 does not " + "match the size (%d) currently contained in the container.", + K)); + } else if (!trans_x && trans_y) { + PADDLE_ENFORCE_EQ( + (K % 4 == 0), + true, + phi::errors::InvalidArgument( + "The dimension size K used in int8 matmul must be a multiple " + "of 4 does not " + "match the size (%d) currently contained in the container.", + K)); + } else if (trans_x && !trans_y) { + PADDLE_ENFORCE_EQ( + (M % 4 == 0 || M == 1), + true, + phi::errors::InvalidArgument( + "The dimension size M used in int8 matmul must be 1 or a " + "multiple of 4 does not " + "match the size (%d) currently contained in the container.", + M)); + PADDLE_ENFORCE_EQ( + (N % 4 == 0 || N == 1), + true, + phi::errors::InvalidArgument( + "The dimension size N used in int8 matmul must be 1 or a " + "multiple of 4 does not " + "match the size (%d) currently contained in the container.", + N)); + } else { + PADDLE_ENFORCE_EQ( + (M % 4 == 0 || M == 1), + true, + phi::errors::InvalidArgument( + "The dimension size M used in int8 matmul must be 1 or a " + "multiple of 4 does not " + "match the size (%d) currently contained in the container.", + M)); + PADDLE_ENFORCE_EQ( + (K % 4 == 0), + true, + phi::errors::InvalidArgument( + "The dimension size K used in int8 matmul must be a multiple " + "of 4 does not " + "match the size (%d) currently contained in the container.", + K)); + } + } + + if (planner != nullptr) { + sub_key = planner->GenSubKey(); + } + + auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); + if (mamtul_cache.FindSubKey(sub_key)) { + desc = *(reinterpret_cast(mamtul_cache.GetSubKey(sub_key))); + desc.template SetFusedEpiloguePtr(planner); + VLOG(7) << desc.GetDescResultString("[Heap CublasltDescriptor] "); + } else { + desc.template Create(M, + N, + K, + trans_x, + trans_y, + planner, + 
batch_size, + stride_x, + stride_y, + stride_out, + grad_for_dx); + desc.ExchangeXYDesc(no_exchange); + if (planner != nullptr) { + desc.template SetFusedEpiloguePtr(planner); + } + VLOG(7) << desc.GetDescResultString("[Stack CublasltDescriptor] ", false); + } + } +}; + +// For matmul with kernels autotune +template +struct MatmulWithCublasLt : public CublasLtBase { + public: + static void Run(const phi::GPUContext& ctx, + const T* x_data, + const T* y_data, + OutT* out_data, + const int64_t M, + const int64_t N, + const int64_t K, + const bool trans_x, + const bool trans_y, + phi::funcs::MatmulPlanner* planner = nullptr) { + auto setter = DescriptorSetter( + planner, M, N, K, trans_x, trans_y); + CublasLtBase::RunImpl( + ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); + } + + static void RunWithBatch(const phi::GPUContext& ctx, + const T* x_data, + const T* y_data, + OutT* out_data, + const int64_t M, + const int64_t N, + const int64_t K, + bool trans_x, + bool trans_y, + int batch_size, + int64_t stride_x, + int64_t stride_y, + int64_t stride_out, + phi::funcs::MatmulPlanner* planner = nullptr) { + auto setter = DescriptorSetter(planner, + M, + N, + K, + trans_x, + trans_y, + batch_size, + stride_x, + stride_y, + stride_out); + CublasLtBase::RunImpl( + ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); + } + + static void RunWithBatch(const phi::GPUContext& ctx, + const T** x_data, + const T** y_data, + OutT** out_data, + const int64_t M, + const int64_t N, + const int64_t K, + bool trans_x, + bool trans_y, + int batch_size, + phi::funcs::MatmulPlanner* planner = nullptr) { + for (int i = 0; i < batch_size; ++i) { + Run(ctx, + x_data[i], + y_data[i], + out_data[i], + M, + N, + K, + trans_x, + trans_y, + planner); + } + } +}; + +// As for just Linear fused ephilogue below: out = matmul(x, y) + bias. +template +struct LinearWithCublasLt : public CublasLtBase { + static void Run(const phi::GPUContext& ctx, + const phi::DenseTensor* x, + const phi::DenseTensor* y, + phi::DenseTensor* out, + const void* bias_data, + void* reserve_data, + const int64_t M, + const int64_t N, + const int64_t K, + const bool trans_x, + const bool trans_y, + const MatmulFusedType fused_type) { + auto planner = phi::funcs::MatmulPlanner( + vectorize(x->dims()), + vectorize(y->dims()), + trans_x, + trans_y, + paddle::experimental::CppTypeToDataType::Type(), + fused_type, + bias_data, + reserve_data); + auto setter = DescriptorSetter( + &planner, M, N, K, trans_x, trans_y); + CublasLtBase::RunImpl(ctx, + &setter.desc, + setter.sub_key, + x->data(), + y->data(), + out->data(), + &planner); + } +}; + +template +struct LinearGradWithCublasLt : public CublasLtBase { + static void Run( + const phi::GPUContext& ctx, + const phi::DenseTensor* x, + const phi::DenseTensor* y, + phi::DenseTensor* out, + const void* bias_data, + void* reserve_data, + const int64_t M, + const int64_t N, + const int64_t K, + const MatmulFusedType fused_type, + const bool trans_x, + const bool trans_y, + const bool use_addto, + const bool no_exchange, // exchange x_desc and y_desc for grad. 
+ bool grad_for_dx = true) { + auto planner = phi::funcs::MatmulPlanner( + vectorize(x->dims()), + vectorize(y->dims()), + trans_x, + trans_y, + paddle::experimental::CppTypeToDataType::Type(), + fused_type, + bias_data, + reserve_data, + use_addto, + no_exchange); + auto setter = + DescriptorSetter( + &planner, + M, + N, + K, + trans_x, + trans_y, + /*batch_size=*/1, + /*stride_x=*/0, + /*stride_y=*/0, + /*stride_out=*/0, + /*exchange_x_y_desc=*/no_exchange, + /*grad_for_dx=*/grad_for_dx); + + // To setting data type for different kinda out_data. + if (grad_for_dx) { + CublasLtBase::RunImpl( + ctx, + &setter.desc, + setter.sub_key, + no_exchange ? x->data() : y->data(), + no_exchange ? y->data() : x->data(), + out->data(), + &planner); + } else { + CublasLtBase::RunImpl( + ctx, + &setter.desc, + setter.sub_key, + no_exchange ? x->data() : y->data(), + no_exchange ? y->data() : x->data(), + out->data(), + &planner); + } + } +}; +#else +// A void structure just for successfully compile. +struct MatmulPlanner {}; +#endif // (PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h index 0458f0d83ed1a..1b1814ec0ae2b 100644 --- a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h @@ -101,7 +101,7 @@ inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x, gpu_type); }); if (batch_size > 1) { -#if CUDA_VERSION >= 11070 +#if CUDA_VERSION >= 11080 dev_ctx.CusparseCall([&](cusparseHandle_t handle) { phi::dynload::cusparseCsrSetStridedBatch( *descriptor, batch_size, M + 1, batch_nnz); @@ -109,7 +109,7 @@ inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x, #else PADDLE_THROW(phi::errors::Unimplemented( "Batch Sparse matmul use 'cusparseCsrSetStridedBatch', which is " - "supported from CUDA 11.7")); + "supported from CUDA 11.8")); #endif } } @@ -155,7 +155,7 @@ inline void CreateCooDescriptor(const phi::SparseCooTensor& x, }); if (batch_size > 1) { -#if CUDA_VERSION >= 11070 +#if CUDA_VERSION >= 11080 dev_ctx.CusparseCall([&](cusparseHandle_t handle) { phi::dynload::cusparseCooSetStridedBatch( *descriptor, batch_size, batch_nnz); @@ -163,7 +163,7 @@ inline void CreateCooDescriptor(const phi::SparseCooTensor& x, #else PADDLE_THROW(phi::errors::Unimplemented( "Batch Sparse matmul use 'cusparseCooSetStridedBatch', which is " - "supported from CUDA 11.7")); + "supported from CUDA 11.8")); #endif } } @@ -241,7 +241,7 @@ class CuSparseDnMatDescriptor { PADDLE_ENFORCE_EQ(x.numel(), batch_size * M * N); if (batch_size > 1) { -#if CUDA_VERSION >= 11070 +#if CUDA_VERSION >= 11080 dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { phi::dynload::cusparseDnMatSetStridedBatch( descriptor_, batch_size, M * N); @@ -249,7 +249,7 @@ class CuSparseDnMatDescriptor { #else PADDLE_THROW(phi::errors::Unimplemented( "Batch Sparse matmul use 'cusparseDnMatSetStridedBatch', which is " - "supported from CUDA 11.7")); + "supported from CUDA 11.8")); #endif } VLOG(6) << "Create cusparseDnMatDescr_t " << &descriptor_; @@ -379,7 +379,11 @@ void SparseBlas::SPMV(bool transa, &beta, out_descriptor.descriptor(), gpu_type, +#if CUDA_VERSION >= 11040 + CUSPARSE_SPMV_ALG_DEFAULT, +#else CUSPARSE_MV_ALG_DEFAULT, +#endif &buffer_size); }); @@ -395,7 +399,11 @@ void SparseBlas::SPMV(bool transa, &beta, out_descriptor.descriptor(), gpu_type, +#if CUDA_VERSION >= 11040 + CUSPARSE_SPMV_ALG_DEFAULT, +#else 
CUSPARSE_MV_ALG_DEFAULT, +#endif tmp_buffer_ptr); }); } diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index 7ecf352ffe996..700ce21caf2ba 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -14,9 +14,7 @@ #include "paddle/phi/kernels/graph_send_recv_kernel.h" -#include -#include - +#include "paddle/phi/kernels/funcs/math_function.h" #include #include @@ -59,17 +57,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, cudaMemset(p_output, 0, memset_bytes); #endif } else if (pool_type == "MAX") { - thrust::device_ptr p_output_ptr(p_output); - thrust::fill(thrust::device, - p_output_ptr, - p_output_ptr + memset_size, - std::numeric_limits::min()); + phi::funcs::set_constant(ctx, out, std::numeric_limits::min()); } else if (pool_type == "MIN") { - thrust::device_ptr p_output_ptr(p_output); - thrust::fill(thrust::device, - p_output_ptr, - p_output_ptr + memset_size, - std::numeric_limits::max()); + phi::funcs::set_constant(ctx, out, std::numeric_limits::max()); } if (index_size == 0) return; diff --git a/paddle/phi/kernels/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/gpu/matmul_grad_kernel.cu index 2753937eb7142..1591d86d8cf59 100644 --- a/paddle/phi/kernels/gpu/matmul_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/matmul_grad_kernel.cu @@ -12,11 +12,62 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/kernels/matmul_grad_kernel.h" - #include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" +#include "paddle/phi/kernels/matmul_grad_kernel.h" +namespace phi { +template <> +void MatMul(const phi::GPUContext& dev_ctx, + const DenseTensor& a, + bool trans_a, + const DenseTensor& b, + bool trans_b, + DenseTensor* out) { + dev_ctx.template Alloc(out); +#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) + if (a.dims().size() == 2 && b.dims().size() == 2) { + auto& x_dims = a.dims(); // M * K + auto& y_dims = b.dims(); // K * N + const int M = trans_a ? x_dims[1] : x_dims[0]; + const int K = trans_a ? x_dims[0] : x_dims[1]; + const int N = trans_b ? 
y_dims[0] : y_dims[1]; + phi::funcs::LinearWithCublasLt::Run( + dev_ctx, + &a, // x + &b, // y + out, // out + nullptr, // bias + nullptr, + M, // M bsz_seqf + N, // N output_size + K, // K input_size + trans_a, + trans_b, + phi::funcs::MatmulFusedType::kMatmul); + return; + } +#endif + auto blas = phi::funcs::GetBlas(dev_ctx); + auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b); + if (a.dims().size() == 3 && b.dims().size() <= 2) { + // the transpose_X must be false, if is true, the transpose cost much time + if (!trans_a) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } + } + blas.MatMul(a.data(), + mat_dim_a, + b.data(), + mat_dim_b, + static_cast(1), + dev_ctx.template Alloc(out), + static_cast(false)); +} +} // namespace phi PD_REGISTER_KERNEL(matmul_grad, GPU, diff --git a/paddle/phi/kernels/gpu/matmul_kernel.cu b/paddle/phi/kernels/gpu/matmul_kernel.cu index 32d70ae0763f0..e96a76b1d1e7b 100644 --- a/paddle/phi/kernels/gpu/matmul_kernel.cu +++ b/paddle/phi/kernels/gpu/matmul_kernel.cu @@ -15,9 +15,48 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/impl/matmul_kernel_impl.h" #include "paddle/phi/kernels/matmul_kernel.h" +namespace phi { +template <> +void MatMulFunction(const phi::GPUContext& dev_ctx, + const DenseTensor& X, + const DenseTensor& Y, + DenseTensor* Out, + bool trans_x, + bool trans_y) { +#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) + if (X.dims().size() == 2 && Y.dims().size() == 2) { + auto& x_dims = X.dims(); // M * K + auto& y_dims = Y.dims(); // K * N + const int M = trans_x ? x_dims[1] : x_dims[0]; + const int K = trans_x ? x_dims[0] : x_dims[1]; + const int N = trans_y ? 
y_dims[0] : y_dims[1]; + phi::funcs::LinearWithCublasLt::Run( + dev_ctx, + &X, // x + &Y, // y + Out, // out + nullptr, // bias + nullptr, + M, // M bsz_seqf + N, // N output_size + K, // K input_size + trans_x, + trans_y, + phi::funcs::MatmulFusedType::kMatmul); + return; + } +#endif + const std::vector x_dims = vectorize(X.dims()); + const std::vector y_dims = vectorize(Y.dims()); + MatMulFunction( + dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, false); +} +} // namespace phi + PD_REGISTER_KERNEL(matmul, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/gpu/unique_kernel.cu b/paddle/phi/kernels/gpu/unique_kernel.cu index 3d44c9af03c07..c52555c38e5a3 100644 --- a/paddle/phi/kernels/gpu/unique_kernel.cu +++ b/paddle/phi/kernels/gpu/unique_kernel.cu @@ -21,12 +21,13 @@ #include #include #include - +#include #include #include -#include "paddle/fluid/framework/tensor_util.h" // TensorToVector() +#include "cub/cub.cuh" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/unique_functor.h" @@ -194,22 +195,29 @@ static void UniqueFlattendCUDATensor(const Context& context, indices->Resize(phi::make_ddim({num_input})); auto* indices_data = context.template Alloc(indices); - thrust::sequence(thrust::device, indices_data, indices_data + num_input); +#ifdef PADDLE_WITH_CUDA + phi::memory_utils::ThrustAllocator allocator(context.GetPlace(), + context.stream()); + const auto& exec_policy = thrust::cuda::par(allocator).on(context.stream()); +#else + const auto& exec_policy = thrust::hip::par.on(context.stream()); +#endif + + thrust::sequence(exec_policy, indices_data, indices_data + num_input); thrust::sort_by_key( - thrust::device, in_data_hat, in_data_hat + num_input, indices_data); + exec_policy, in_data_hat, in_data_hat + num_input, indices_data); // 1. 
Calculate op result: 'out' DenseTensor range; range.Resize(phi::make_ddim({num_input + 1})); auto* range_data_ptr = context.template Alloc(&range); - thrust::sequence( - thrust::device, range_data_ptr, range_data_ptr + num_input + 1); + thrust::sequence(exec_policy, range_data_ptr, range_data_ptr + num_input + 1); phi::Copy(context, in_hat, context.GetPlace(), false, out); int num_out; auto out_data = context.template Alloc(out); num_out = thrust::unique_by_key( - thrust::device, out_data, out_data + num_input, range_data_ptr, equal) + exec_policy, out_data, out_data + num_input, range_data_ptr, equal) .first - out_data; out->Resize(phi::make_ddim({num_out})); @@ -221,18 +229,32 @@ static void UniqueFlattendCUDATensor(const Context& context, DenseTensor inv_loc; inv_loc.Resize(phi::make_ddim({num_input})); auto inv_loc_data_ptr = context.template Alloc(&inv_loc); - thrust::adjacent_difference(thrust::device, + thrust::adjacent_difference(exec_policy, in_data_hat, in_data_hat + num_input, inv_loc_data_ptr, not_equal); - thrust::device_ptr inv_loc_data_dev(inv_loc_data_ptr); - inv_loc_data_dev[0] = 0; // without device_ptr, segmentation fault - thrust::inclusive_scan(thrust::device, - inv_loc_data_ptr, - inv_loc_data_ptr + num_input, - inv_loc_data_ptr); - thrust::scatter(thrust::device, +#ifdef PADDLE_WITH_HIP + hipMemset(inv_loc_data_ptr, 0, sizeof(IndexT)); +#else + cudaMemsetAsync(inv_loc_data_ptr, 0, sizeof(IndexT), context.stream()); +#endif + size_t temp_storage_bytes = 0; + cub::DeviceScan::InclusiveSum(NULL, + temp_storage_bytes, + inv_loc_data_ptr, + inv_loc_data_ptr, + num_input, + context.stream()); + auto d_temp_storage = + paddle::memory::Alloc(context.GetPlace(), temp_storage_bytes); + cub::DeviceScan::InclusiveSum(d_temp_storage->ptr(), + temp_storage_bytes, + inv_loc_data_ptr, + inv_loc_data_ptr, + num_input, + context.stream()); + thrust::scatter(exec_policy, inv_loc_data_ptr, inv_loc_data_ptr + num_input, indices_data, @@ -244,11 +266,11 @@ static void UniqueFlattendCUDATensor(const Context& context, DenseTensor tmp_indices; tmp_indices.Resize(phi::make_ddim({num_input})); auto* tmp_indices_data_ptr = context.template Alloc(&tmp_indices); - thrust::copy(thrust::device, + thrust::copy(exec_policy, in_data_hat, in_data_hat + num_input, tmp_indices_data_ptr); - thrust::unique_by_key(thrust::device, + thrust::unique_by_key(exec_policy, tmp_indices_data_ptr, tmp_indices_data_ptr + num_input, indices_data, @@ -261,10 +283,10 @@ static void UniqueFlattendCUDATensor(const Context& context, counts->Resize(phi::make_ddim({num_out})); auto count_data = context.template Alloc(counts); // init 'count_data' as 0 - thrust::fill(thrust::device, count_data, count_data + num_out, 0); + thrust::fill(exec_policy, count_data, count_data + num_out, 0); thrust::device_ptr range_data_ptr_dev(range_data_ptr); range_data_ptr_dev[num_out] = num_input; - thrust::adjacent_difference(thrust::device, + thrust::adjacent_difference(exec_policy, range_data_ptr + 1, range_data_ptr + num_out + 1, count_data); @@ -290,24 +312,29 @@ static void ComputeUniqueDims(const Context& context, equal_T equal, not_equal_T not_equal, int64_t row) { +#ifdef PADDLE_WITH_CUDA + phi::memory_utils::ThrustAllocator allocator(context.GetPlace(), + context.stream()); + const auto& exec_policy = thrust::cuda::par(allocator).on(context.stream()); +#else + const auto& exec_policy = thrust::hip::par.on(context.stream()); +#endif // 1. 
inverse indices: 'inverse' inverse->Resize(phi::make_ddim({row})); auto* inverse_data = context.template Alloc(inverse); DenseTensor inv_loc; inv_loc.Resize(phi::make_ddim({row})); auto inv_loc_data_ptr = context.template Alloc(&inv_loc); - thrust::adjacent_difference(thrust::device, + thrust::adjacent_difference(exec_policy, sorted_indices_data, sorted_indices_data + row, inv_loc_data_ptr, not_equal); thrust::device_ptr inv_loc_data_dev(inv_loc_data_ptr); inv_loc_data_dev[0] = 0; - thrust::inclusive_scan(thrust::device, - inv_loc_data_ptr, - inv_loc_data_ptr + row, - inv_loc_data_ptr); - thrust::scatter(thrust::device, + thrust::inclusive_scan( + exec_policy, inv_loc_data_ptr, inv_loc_data_ptr + row, inv_loc_data_ptr); + thrust::scatter(exec_policy, inv_loc_data_ptr, inv_loc_data_ptr + row, sorted_indices_data, @@ -317,9 +344,9 @@ static void ComputeUniqueDims(const Context& context, DenseTensor range; range.Resize(phi::make_ddim({row + 1})); auto range_data_ptr = context.template Alloc(&range); - thrust::sequence(thrust::device, range_data_ptr, range_data_ptr + row + 1); + thrust::sequence(exec_policy, range_data_ptr, range_data_ptr + row + 1); int num_out; - num_out = thrust::unique_by_key(thrust::device, + num_out = thrust::unique_by_key(exec_policy, sorted_indices_data, sorted_indices_data + row, range_data_ptr, @@ -333,9 +360,9 @@ static void ComputeUniqueDims(const Context& context, // 3. counts: 'counts' counts->Resize(phi::make_ddim({num_out})); auto* count_data = context.template Alloc(counts); - thrust::fill(thrust::device, count_data, count_data + row, 0); + thrust::fill(exec_policy, count_data, count_data + row, 0); thrust::adjacent_difference( - thrust::device, range_data_ptr + 1, range_data_ptr + row + 1, count_data); + exec_policy, range_data_ptr + 1, range_data_ptr + row + 1, count_data); } // Calculate unique when 'axis' is set @@ -384,9 +411,15 @@ static void UniqueDimsCUDATensor(const Context& context, // 2. 
Calculate 'indices', 'inverse', 'counts' // Init index and sort - thrust::sequence( - thrust::device, sorted_indices_data, sorted_indices_data + row); - thrust::sort(thrust::device, +#ifdef PADDLE_WITH_CUDA + phi::memory_utils::ThrustAllocator allocator(context.GetPlace(), + context.stream()); + const auto& exec_policy = thrust::cuda::par(allocator).on(context.stream()); +#else + const auto& exec_policy = thrust::hip::par.on(context.stream()); +#endif + thrust::sequence(exec_policy, sorted_indices_data, sorted_indices_data + row); + thrust::sort(exec_policy, sorted_indices_data, sorted_indices_data + row, LessThan(col, in_trans_data)); diff --git a/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu index ef70907b59a61..e61f58450b34f 100644 --- a/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu @@ -254,6 +254,8 @@ void ConvCudnnGradGradKernel( auto dtype = paddle::platform::CudnnDataType::type; auto handle = ctx.cudnn_handle(); + auto layout = paddle::platform::GetCudnnTensorFormat( + paddle::platform::DataLayout::kNCHW); paddle::operators::ConvArgs args1{&transformed_ddX, W, @@ -261,28 +263,36 @@ void ConvCudnnGradGradKernel( strides, padding_common, dilations, - dtype}; + dtype, + groups, + paddle::platform::DataLayout::kNCHW}; paddle::operators::ConvArgs args2{&transformed_X, ddW, &transformed_ddO_channel, strides, padding_common, dilations, - dtype}; + dtype, + groups, + paddle::platform::DataLayout::kNCHW}; paddle::operators::ConvArgs args3{&transformed_ddX, dW, &transformed_dO_channel, strides, padding_common, dilations, - dtype}; + dtype, + groups, + paddle::platform::DataLayout::kNCHW}; paddle::operators::ConvArgs args4{&transformed_dX, ddW, &transformed_dO_channel, strides, padding_common, dilations, - dtype}; + dtype, + groups, + paddle::platform::DataLayout::kNCHW}; #ifdef PADDLE_WITH_HIP paddle::operators::SearchResult fwd_result1; @@ -298,9 +308,6 @@ void ConvCudnnGradGradKernel( filter_result; #endif - auto layout = paddle::platform::GetCudnnTensorFormat( - paddle::platform::DataLayout::kNCHW); - // ddo = conv(ddI, W) + conv(I, ddW) size_t workspace_size = 0; @@ -329,7 +336,7 @@ void ConvCudnnGradGradKernel( #else using search1 = paddle::operators::SearchAlgorithm; - fwd_result1 = search1::Find(args1, exhaustive_search, false, ctx); + fwd_result1 = search1::Find(ctx, args1, exhaustive_search, false); workspace_size = search1::GetWorkspaceSize(args1, fwd_result1.algo); #endif } @@ -357,7 +364,7 @@ void ConvCudnnGradGradKernel( #else using search2 = paddle::operators::SearchAlgorithm; - fwd_result2 = search2::Find(args2, exhaustive_search, false, ctx); + fwd_result2 = search2::Find(ctx, args2, exhaustive_search, false); workspace_size = std::max( workspace_size, search2::GetWorkspaceSize(args2, fwd_result2.algo)); #endif @@ -387,7 +394,7 @@ void ConvCudnnGradGradKernel( using search3 = paddle::operators::SearchAlgorithm; filter_result = - search3::Find(args3, exhaustive_search, deterministic, ctx); + search3::Find(ctx, args3, exhaustive_search, deterministic); workspace_size = std::max( workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo)); #endif @@ -417,7 +424,7 @@ void ConvCudnnGradGradKernel( using search4 = paddle::operators::SearchAlgorithm; data_result = - search4::Find(args4, exhaustive_search, deterministic, ctx); + search4::Find(ctx, args4, exhaustive_search, deterministic); workspace_size = std::max( workspace_size, 
search4::GetWorkspaceSize(args4, data_result.algo)); #endif diff --git a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu index 4e9c37879c002..2d61ec6e62c9c 100644 --- a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu @@ -251,27 +251,33 @@ void ConvCudnnGradKernel(const Context& ctx, T* input_grad_data = nullptr; T* transformed_input_grad_data = nullptr; + paddle::platform::DataLayout layout = + compute_format == paddle::platform::DataLayout::kNHWC + ? paddle::platform::DataLayout::kNHWC + : paddle::platform::DataLayout::kNCHW; + paddle::operators::ConvArgs args1{&transformed_input_grad, &transformed_filter_channel, &transformed_output_grad_channel, strides, padding_common, dilations, - dtype}; + dtype, + groups, + layout}; paddle::operators::ConvArgs args2{&transformed_input, &transformed_filter_grad_channel, &transformed_output_grad_channel, strides, padding_common, dilations, - dtype}; + dtype, + groups, + layout}; auto handle = ctx.cudnn_handle(); // TODO(phlrain): replace paddle::platform::DataLaytout to phi::DataLayout - paddle::platform::DataLayout layout = - compute_format == paddle::platform::DataLayout::kNHWC - ? paddle::platform::DataLayout::kNHWC - : paddle::platform::DataLayout::kNCHW; + if (transformed_input.dims().size() == 5) { layout = compute_format == paddle::platform::DataLayout::kNHWC ? paddle::platform::DataLayout::kNDHWC @@ -367,9 +373,8 @@ void ConvCudnnGradKernel(const Context& ctx, #else using search1 = paddle::operators::SearchAlgorithm; - bwd_result = search1::Find(args1, exhaustive_search, deterministic, ctx); - workspace_size_d = std::max( - workspace_size_d, search1::GetWorkspaceSize(args1, bwd_result.algo)); + bwd_result = search1::Find(ctx, args1, exhaustive_search, deterministic); + workspace_size_d = std::max(workspace_size_d, bwd_result.workspace_size); #endif } @@ -397,11 +402,10 @@ void ConvCudnnGradKernel(const Context& ctx, using search2 = paddle::operators::SearchAlgorithm; filter_result = - search2::Find(args2, exhaustive_search, deterministic, ctx); + search2::Find(ctx, args2, exhaustive_search, deterministic); VLOG(3) << "filter algo: " << filter_result.algo << ", time " << filter_result.time; - workspace_size_w = std::max( - workspace_size_w, search2::GetWorkspaceSize(args2, filter_result.algo)); + workspace_size_w = std::max(workspace_size_w, filter_result.workspace_size); #endif } diff --git a/paddle/phi/kernels/gpudnn/conv_kernel.cu b/paddle/phi/kernels/gpudnn/conv_kernel.cu index bd95a32bc724f..80544025ff738 100644 --- a/paddle/phi/kernels/gpudnn/conv_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_kernel.cu @@ -56,8 +56,7 @@ void ConvCudnnKernel(const Context& ctx, bool exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search_t; bool deterministic = FLAGS_cudnn_deterministic; - auto exhaustive_deterministic = exhaustive_search && deterministic; - PADDLE_ENFORCE_EQ(exhaustive_deterministic, + PADDLE_ENFORCE_EQ(exhaustive_search && deterministic, false, phi::errors::InvalidArgument( "Cann't set exhaustive_search True and " @@ -213,7 +212,9 @@ void ConvCudnnKernel(const Context& ctx, strides, padding_common, dilations, - dtype}; + dtype, + groups, + compute_format}; auto handle = ctx.cudnn_handle(); auto workspace_handle = ctx.cudnn_workspace_handle(); @@ -313,8 +314,8 @@ void ConvCudnnKernel(const Context& ctx, paddle::operators::SearchResult fwd_result; using search = paddle::operators::SearchAlgorithm; - fwd_result = 
search::Find(args, exhaustive_search, deterministic, ctx); - workspace_size = search::GetWorkspaceSize(args, fwd_result.algo); + fwd_result = search::Find(ctx, args, exhaustive_search, deterministic); + workspace_size = fwd_result.workspace_size; #endif #if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1) diff --git a/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu index 0ce16f66becfa..36a3caf97eb94 100644 --- a/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu @@ -179,14 +179,18 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, strides, padding_common, dilations_, - dtype}; + dtype, + groups, + layout}; paddle::operators::ConvArgs args2{&transformed_dout, &filter, &x_transpose, strides, padding_common, dilations_, - dtype}; + dtype, + groups, + layout}; #ifdef PADDLE_WITH_HIP paddle::operators::SearchResult fwd_result; @@ -226,7 +230,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, #else using search1 = paddle::operators::SearchAlgorithm; - fwd_result = search1::Find(args1, false, deterministic, ctx); + fwd_result = search1::Find(ctx, args1, false, deterministic, false); workspace_size = std::max( workspace_size, search1::GetWorkspaceSize(args1, fwd_result.algo)); #endif @@ -253,7 +257,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, #else using search2 = paddle::operators::SearchAlgorithm; - filter_result = search2::Find(args2, false, deterministic, ctx); + filter_result = search2::Find(ctx, args2, false, deterministic, false); workspace_size = std::max( workspace_size, search2::GetWorkspaceSize(args2, filter_result.algo)); #endif @@ -625,6 +629,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( auto dtype = paddle::platform::CudnnDataType::type; auto handle = ctx.cudnn_handle(); + auto layout = paddle::platform::GetCudnnTensorFormat(GPUDNNDataLayout::kNCHW); paddle::operators::ConvArgs args1{&transformed_ddout_channel, &filter, @@ -632,14 +637,18 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( strides, padding_common, dilations_, - dtype}; + dtype, + groups, + GPUDNNDataLayout::kNCHW}; paddle::operators::ConvArgs args2{&transformed_ddout_channel, &ddfilter, &transformed_x, strides, padding_common, dilations_, - dtype}; + dtype, + groups, + GPUDNNDataLayout::kNCHW}; paddle::operators::ConvArgs args3{&transformed_dout, dfilter, @@ -647,14 +656,18 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( strides, padding_common, dilations_, - dtype}; + dtype, + groups, + GPUDNNDataLayout::kNCHW}; paddle::operators::ConvArgs args4{&transformed_dout, &ddfilter, &transformed_dx_channel, strides, padding_common, dilations_, - dtype}; + dtype, + groups, + GPUDNNDataLayout::kNCHW}; #ifdef PADDLE_WITH_HIP paddle::operators::SearchResult bwd_result1; paddle::operators::SearchResult bwd_result2; @@ -669,8 +682,6 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( paddle::operators::SearchResult fwd_result; #endif - auto layout = paddle::platform::GetCudnnTensorFormat(GPUDNNDataLayout::kNCHW); - // ddo = conv(ddI, filter) + conv(I, ddfilter) size_t workspace_size = 0; @@ -699,7 +710,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( #else using search1 = paddle::operators::SearchAlgorithm; - bwd_result1 = search1::Find(args1, false, deterministic, ctx); + bwd_result1 = search1::Find(ctx, args1, false, deterministic, false); workspace_size = search1::GetWorkspaceSize(args1, bwd_result1.algo); #endif @@ -723,7 +734,7 @@ void 
Conv2dTransposeDoubleGradGPUDNNKernel( #else using search2 = paddle::operators::SearchAlgorithm; - bwd_result2 = search2::Find(args2, false, deterministic, ctx); + bwd_result2 = search2::Find(ctx, args2, false, deterministic, false); workspace_size = std::max( workspace_size, search2::GetWorkspaceSize(args2, bwd_result2.algo)); #endif @@ -750,7 +761,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( #else using search3 = paddle::operators::SearchAlgorithm; - filter_result = search3::Find(args3, false, deterministic, ctx); + filter_result = search3::Find(ctx, args3, false, deterministic, false); workspace_size = std::max( workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo)); #endif @@ -778,7 +789,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( #else using search4 = paddle::operators::SearchAlgorithm; - fwd_result = search4::Find(args4, false, deterministic, ctx); + fwd_result = search4::Find(ctx, args4, false, deterministic, false); workspace_size = std::max( workspace_size, search4::GetWorkspaceSize(args4, fwd_result.algo)); #endif diff --git a/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu b/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu index 58ead4c3287f8..5aa7bd60a0aa8 100644 --- a/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu @@ -205,7 +205,9 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx, strides, padding_common, dilations_, - dtype}; + dtype, + groups, + data_layout}; args.handle = handle; args.idesc.set(transformed_out, iwo_groups); args.wdesc.set(filter, layout_tensor, iwo_groups); @@ -228,7 +230,7 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx, paddle::operators::SearchResult bwd_result; using search = paddle::operators::SearchAlgorithm; - bwd_result = search::Find(args, false, deterministic, ctx); + bwd_result = search::Find(ctx, args, false, deterministic, false); workspace_size = std::max(workspace_size, search::GetWorkspaceSize(args, bwd_result.algo)); #endif diff --git a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h index be32f85fe99a4..6c75ab86d7c4c 100644 --- a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h @@ -100,7 +100,7 @@ void MatMul(const Context& dev_ctx, const DenseTensor& b, bool trans_b, DenseTensor* out, - bool flag = false) { + bool flag) { dev_ctx.template Alloc(out); auto blas = phi::funcs::GetBlas(dev_ctx); auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); @@ -120,6 +120,32 @@ void MatMul(const Context& dev_ctx, dev_ctx.template Alloc(out), static_cast(flag)); } +template +void MatMul(const Context& dev_ctx, + const DenseTensor& a, + bool trans_a, + const DenseTensor& b, + bool trans_b, + DenseTensor* out) { + dev_ctx.template Alloc(out); + auto blas = phi::funcs::GetBlas(dev_ctx); + auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b); + if (a.dims().size() == 3 && b.dims().size() <= 2) { + // the transpose_X must be false, if is true, the transpose cost much time + if (!trans_a) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } + } + blas.MatMul(a.data(), + mat_dim_a, + b.data(), + mat_dim_b, + static_cast(1), + dev_ctx.template Alloc(out), + static_cast(false)); +} /** * Get row matrix shape from a vector shape. 
If the rank of x_dim > 1, the diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index 99257ce4a6adf..6e2e8e3634c6e 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -478,12 +478,24 @@ void MatMulFunction(const Context& dev_ctx, DenseTensor* Out, bool trans_x, bool trans_y, - bool flag = false) { + bool flag) { const std::vector x_dims = vectorize(X.dims()); const std::vector y_dims = vectorize(Y.dims()); MatMulFunction( dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag); } +template +void MatMulFunction(const Context& dev_ctx, + const DenseTensor& X, + const DenseTensor& Y, + DenseTensor* Out, + bool trans_x, + bool trans_y) { + const std::vector x_dims = vectorize(X.dims()); + const std::vector y_dims = vectorize(Y.dims()); + MatMulFunction( + dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, false); +} template void MatmulKernel(const Context& dev_ctx, diff --git a/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu index a8e88f351ccbc..389737037a38e 100644 --- a/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/kernels/sparse/coalesce_kernel.h" - +#include +#include #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" From 9da652d10037c8525bc34299779e653183297746 Mon Sep 17 00:00:00 2001 From: qingshui Date: Tue, 9 Jan 2024 11:27:54 +0800 Subject: [PATCH 8/9] merge master (#106) * fused_seqpool_cvm_with_conv support filter by threshold * add fill zero in fused_seqpool_cvm * add fused seq tensor && support transpose batch fc weight --------- Co-authored-by: mojingcj Co-authored-by: jiaoxuewu Co-authored-by: yuandong1998 <1377526365@qq.com> Co-authored-by: shangzhongbin --- paddle/fluid/operators/batch_fc_op.cc | 56 ++++ paddle/fluid/operators/batch_fc_op.cu | 85 +++++ .../operators/fused/fused_seq_tensor_op.cc | 132 ++++++++ .../operators/fused/fused_seq_tensor_op.cu | 290 ++++++++++++++++++ .../operators/fused/fused_seq_tensor_op.h | 16 + .../operators/fused/fused_seqpool_cvm_op.cc | 2 + .../operators/fused/fused_seqpool_cvm_op.cu | 36 ++- .../fused/fused_seqpool_cvm_with_conv_op.cc | 4 + .../fused/fused_seqpool_cvm_with_conv_op.cu | 55 +++- python/paddle/fluid/contrib/layers/nn.py | 85 ++++- 10 files changed, 745 insertions(+), 16 deletions(-) create mode 100644 paddle/fluid/operators/fused/fused_seq_tensor_op.cc create mode 100644 paddle/fluid/operators/fused/fused_seq_tensor_op.cu create mode 100644 paddle/fluid/operators/fused/fused_seq_tensor_op.h diff --git a/paddle/fluid/operators/batch_fc_op.cc b/paddle/fluid/operators/batch_fc_op.cc index 9ee4bad1d73b7..7cc1844393b03 100644 --- a/paddle/fluid/operators/batch_fc_op.cc +++ b/paddle/fluid/operators/batch_fc_op.cc @@ -44,6 +44,61 @@ class BatchFCOp : public framework::OperatorWithKernel { auto w_dims = ctx->GetInputDim("W"); int batchcount = ctx->Attrs().Get("batchcount"); + int transpose_weight = ctx->Attrs().Get("transpose_weight"); + + if (transpose_weight) { + // Input_dim: [batch_count, ?, in_dim] + // W_dim: [in_dim, batch_count * out_dim] + // Bias_dim: [1, batch_count * out_dim] + // Out_dim: [batch_count, ?, out_dim] + PADDLE_ENFORCE_GT( + batchcount, + 0, + 
platform::errors::PreconditionNotMet( + "with transpose weight, batchcount should > 0")); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2, + platform::errors::InvalidArgument( + "W of BatchFCOp should have 2D.")); + + int out_dim = w_dims[1] / batchcount; + PADDLE_ENFORCE_EQ( + input_dims.size(), + 3, + platform::errors::InvalidArgument( + "Input of BatchFCOp should have 3D.")); + PADDLE_ENFORCE_EQ( + input_dims[2], + w_dims[0], + platform::errors::InvalidArgument( + "Input.dim[2] and w_dims[0] of BatchFCOp should be same.")); + PADDLE_ENFORCE_EQ( + input_dims[0], + batchcount, + platform::errors::InvalidArgument( + "Input.dim[0] and batchcount of BatchFCOp should be same.")); + PADDLE_ENFORCE_EQ( + input_dims[2], + w_dims[0], + platform::errors::InvalidArgument( + "Input.dim[2] and W.dim[1] of BatchFCOp should be same.")); + + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ( + bias_dims.size(), + 2, + platform::errors::InvalidArgument("Bias of BatchFCOp should have 2D.")); + PADDLE_ENFORCE_EQ( + bias_dims[1], + w_dims[1], + platform::errors::InvalidArgument( + "Bias.dim[1] should be same as input.dim[2].")); + + ctx->SetOutputDim("Out", {input_dims[0], input_dims[1], out_dim}); + ctx->ShareLoD("Input", /*->*/ "Out"); + return; + } if (batchcount > 0) { int feature_dim = input_dims[1] / batchcount; PADDLE_ENFORCE_EQ(feature_dim, w_dims[0], @@ -139,6 +194,7 @@ class BatchFCOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Bias", "(Tensor) Input tensor of batch_fc_op operator."); AddOutput("Out", "Output tensor of batch_fc_op operator."); AddAttr("batchcount", "(int64_t) the batchcount").SetDefault(0); + AddAttr("transpose_weight", "(bool) the transpose_weight").SetDefault(false); AddComment(R"DOC( BatchFC Operator. Notice: It currently supports GPU device. 
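Note on the transpose_weight path above: W is stored fused as [in_dim, batch_count * out_dim] and is regrouped on device into batch-major [batch_count, in_dim, out_dim] before the strided BatchedGEMM (see transpose_weight_kernel in batch_fc_op.cu below). The following host-side sketch is not part of the patch; it only mirrors that kernel's index mapping, and the helper name RegroupBatchFCWeight plus the use of std::vector are illustrative assumptions, intended as a CPU reference when checking the CUDA kernel.

#include <cstddef>
#include <vector>

// Regroup a fused weight of shape [in_dim, batch_count * out_dim] into
// batch-major [batch_count, in_dim, out_dim]; mirrors the index math of
// transpose_weight_kernel (dst_coln = coln / batch_count, dst_x = x + (y / dst_coln) * rown).
std::vector<float> RegroupBatchFCWeight(const std::vector<float>& w,
                                        size_t in_dim,
                                        size_t out_dim,
                                        size_t batch_count) {
  std::vector<float> dst(batch_count * in_dim * out_dim);
  for (size_t x = 0; x < in_dim; ++x) {                   // row of the fused weight
    for (size_t y = 0; y < batch_count * out_dim; ++y) {  // column of the fused weight
      size_t b = y / out_dim;   // batch slice this column belongs to
      size_t yb = y % out_dim;  // column index inside that slice
      dst[(b * in_dim + x) * out_dim + yb] = w[x * (batch_count * out_dim) + y];
    }
  }
  return dst;  // slice b starts at offset b * in_dim * out_dim
}

With this layout each batch slice is a contiguous [in_dim, out_dim] block, which matches the strideB = in_dim * out_dim passed to BatchedGEMM in the CUDA implementation below.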
diff --git a/paddle/fluid/operators/batch_fc_op.cu b/paddle/fluid/operators/batch_fc_op.cu index f9fac45ef6e5e..652eddb560099 100644 --- a/paddle/fluid/operators/batch_fc_op.cu +++ b/paddle/fluid/operators/batch_fc_op.cu @@ -171,11 +171,96 @@ void transpose_split_row(cudaStream_t stream, const unsigned int rown, stream>>>(rown, coln, num_block, source, dest); } +template +__global__ void transpose_weight_kernel(const T* source, T* dest, + const unsigned int rown, const unsigned int coln, const int64_t batch_count) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + if (x < rown && y < coln) { + int dst_coln = coln / batch_count; + int dst_x = x + y / dst_coln * rown; + int dst_y = y % dst_coln; + dest[dst_x * dst_coln + dst_y] = source[x * coln + y]; + } +} + +template +void transpose_weight_impl(cudaStream_t stream, const T* source, T* dest, + const unsigned int rown, const unsigned int coln, const int64_t batch_count) { + dim3 grid((rown + 15) / 16, (coln + 15) / 16); + dim3 block(16, 16); + transpose_weight_kernel<<>>(source, dest, rown, coln, batch_count); +} + template class BatchFCCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { int batchcount = ctx.Attr("batchcount"); + auto transpose_weight = ctx.Attr("transpose_weight"); + if (transpose_weight) { + // Input_dim: [batch_count, ?, in_dim] + // W_dim: [in_dim, batch_count * out_dim] + // Bias_dim: [1, batch_count * out_dim] + // Out_dim: [batch_count, ?, out_dim] + auto* input = ctx.Input("Input"); + auto* w = ctx.Input("W"); + auto* bias = ctx.Input("Bias"); + auto* output = ctx.Output("Out"); + auto input_dims = input->dims(); + auto w_dims = w->dims(); + auto slot_pairs_num = input_dims[0]; + auto ins_num = input_dims[1]; + auto in_dim = input_dims[2]; + auto out_dim = w_dims[1] / batchcount; + + // get data ptr + const T* in_data = input->data(); + const T* w_data = w->data(); + const T* bias_data = bias->data(); + + output->Resize({slot_pairs_num, ins_num, out_dim}); + T* out_data = output->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + Tensor w_help; + w_help = + ctx.AllocateTmpTensor({batchcount, w_dims[0], w_dims[1] / batchcount}, dev_ctx); + T* w_help_data = w_help.data(); + + transpose_weight_impl(ctx.cuda_device_context().stream(), w_data, w_help_data, w_dims[0], w_dims[1], batchcount); + + CBLAS_TRANSPOSE transA = CblasNoTrans; + CBLAS_TRANSPOSE transB = CblasNoTrans; + + T alpha = 1; + T beta = 0; + int64_t strideA = ins_num * in_dim; + int64_t strideB = in_dim * out_dim; + + auto blas = phi::funcs::GetBlas(dev_ctx); + blas.BatchedGEMM(transA, + transB, + ins_num, + out_dim, + in_dim, + alpha, + in_data, + w_help_data, + beta, + out_data, + slot_pairs_num, + strideA, + strideB); + add_bias(ctx.cuda_device_context().stream(), + out_data, + slot_pairs_num, + ins_num, + out_dim, + bias_data); + return; + } if (batchcount > 0) { auto* input = ctx.Input("Input"); auto* w = ctx.Input("W"); diff --git a/paddle/fluid/operators/fused/fused_seq_tensor_op.cc b/paddle/fluid/operators/fused/fused_seq_tensor_op.cc new file mode 100644 index 0000000000000..5ca2ec345f10e --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seq_tensor_op.cc @@ -0,0 +1,132 @@ +#include "paddle/fluid/operators/fused/fused_seq_tensor_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include + +namespace paddle { +namespace operators { + +class FusedSeqTensorOp : public 
framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "FusedSeqTensorOp"); + OP_INOUT_CHECK(ctx->HasInput("ADInput"), "ADInput", "ADInput", "FusedSeqTensorOp"); + + OP_INOUT_CHECK(ctx->HasOutput("DINOut"), "DINOut", "DINOut", "FusedSeqTensorOp"); + OP_INOUT_CHECK(ctx->HasOutput("MaskOut"), "MaskOut", "MaskOut", "FusedSeqTensorOp"); + OP_INOUT_CHECK(ctx->HasOutput("SideInfoOut"), "SideInfoOut", "SideInfoOut", "FusedSeqTensorOp"); + OP_INOUT_CHECK(ctx->HasOutput("ADSlotSessionOut"), "ADSlotSessionOut", "ADSlotSessionOut", "FusedSeqTensorOp"); + + const framework::DDim input_dims = ctx->GetInputDim("Input"); + const framework::DDim ad_input_dims = ctx->GetInputDim("ADInput"); + + auto ad_slot_num = ctx->Attrs().Get("ad_slot_num"); + auto batch_count = ctx->Attrs().Get("batch_count"); + auto max_length = ctx->Attrs().Get("max_length"); + auto slot_num = ctx->Attrs().Get("slot_num"); + auto fea_emb_dim = ctx->Attrs().Get("fea_emb_dim"); + auto ad_slot_offset = ctx->Attrs().Get("ad_slot_offset"); + + int64_t one_ins_dim = batch_count * max_length * slot_num * fea_emb_dim; + PADDLE_ENFORCE_EQ( + input_dims[1], one_ins_dim, + platform::errors::InvalidArgument( + "input dims error, %ld != %ld", input_dims[1], one_ins_dim)); + + int64_t one_ins_ad_dim = batch_count * 1 * ad_slot_num * fea_emb_dim; + PADDLE_ENFORCE_EQ( + ad_input_dims[1], one_ins_ad_dim, + platform::errors::InvalidArgument( + "input dims error, %ld != %ld", ad_input_dims[1], one_ins_ad_dim)); + PADDLE_ENFORCE_LT( + ad_slot_num, slot_num, + platform::errors::InvalidArgument( + "ad_slot_num [%ld] > slot_num [%ld]", ad_slot_num, slot_num)); + PADDLE_ENFORCE_GT( + ad_slot_num, 0, + platform::errors::InvalidArgument( + "ad_slot_num [%ld] <= 0", ad_slot_num)); + PADDLE_ENFORCE_LT( + ad_slot_offset, slot_num - 1, + platform::errors::InvalidArgument( + "ad_slot_num [%ld] > slot_num - 1 [%ld]", ad_slot_offset, slot_num)); + PADDLE_ENFORCE_GE( + ad_slot_offset, 0, + platform::errors::InvalidArgument( + "ad_slot_offset [%ld] < 0", ad_slot_offset)); + if (ad_slot_offset != 0) { + PADDLE_ENFORCE_EQ( + ad_slot_num + ad_slot_offset, slot_num, + platform::errors::InvalidArgument( + "ad_slot_num [%ld] + ad_slot_offset [%ld] != slot_num [%ld]", ad_slot_num, ad_slot_offset, slot_num)); + } + + auto ins_num = input_dims[0]; + if (batch_count > 1) { + ctx->SetOutputDim("DINOut", {batch_count, ins_num * max_length, ad_slot_num * fea_emb_dim * 4}); + ctx->SetOutputDim("MaskOut", {batch_count, ins_num, max_length}); + ctx->SetOutputDim("SideInfoOut", {batch_count, ins_num * max_length, (slot_num - ad_slot_num) * fea_emb_dim}); + ctx->SetOutputDim("ADSlotSessionOut", {batch_count, ins_num * max_length, ad_slot_num, fea_emb_dim}); + } else { + ctx->SetOutputDim("DINOut", {ins_num, max_length, ad_slot_num * fea_emb_dim * 4}); + ctx->SetOutputDim("MaskOut", {ins_num, max_length}); + ctx->SetOutputDim("SideInfoOut", {ins_num, max_length, (slot_num - ad_slot_num) * fea_emb_dim}); + ctx->SetOutputDim("ADSlotSessionOut", {ins_num, max_length, ad_slot_num * fea_emb_dim}); + } + ctx->ShareLoD("Input", "DINOut"); + ctx->ShareLoD("Input", "MaskOut"); + ctx->ShareLoD("Input", "SideInfoOut"); + ctx->ShareLoD("Input", "ADSlotSessionOut"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( 
+ OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); + } +}; + +class FusedSeqTensorOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", + "The input tensors of operator."); + AddInput("ADInput", + "The input ad tensors of operator. "); + AddOutput("DINOut", + "DINOut"); + AddOutput("MaskOut", + "MaskOut"); + AddOutput("SideInfoOut", + "SideInfoOut"); + AddOutput("ADSlotSessionOut", + "ADSlotSessionOut"); + + AddAttr("batch_count", "(int, default 1)"); + AddAttr("max_length", "(int, default 1)"); + AddAttr("slot_num", "(int, default 1)"); + AddAttr("fea_emb_dim", "(int, default 1)"); + AddAttr("ad_slot_num", "(int, default 1)"); + AddAttr("ad_slot_offset", "(int, default 1)"); + + AddComment(R"DOC( +Fuse seq tensor. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(fused_seq_tensor, + ops::FusedSeqTensorOp, ops::FusedSeqTensorOpMaker); + +REGISTER_OP_CPU_KERNEL( + fused_seq_tensor, + ops::FusedSeqTensorCPUKernel); diff --git a/paddle/fluid/operators/fused/fused_seq_tensor_op.cu b/paddle/fluid/operators/fused/fused_seq_tensor_op.cu new file mode 100644 index 0000000000000..d2fdf364d731d --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seq_tensor_op.cu @@ -0,0 +1,290 @@ +#include +#include +#include +#include "paddle/fluid/operators/fused/fused_seq_tensor_op.h" // don't remove this +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { + +template +__global__ void cal_ad_slot_session_kernel(const T* input, + const T* ad_input, + T* din_output, + T* ad_slot_session_output, + const size_t batch_num, + const size_t ins_num, + const size_t slot_num, + const size_t max_length, + const size_t fea_emb_dim, + const size_t ad_slot_num, + const size_t ad_slot_offset) { + + size_t batch_idx = blockIdx.x; + size_t ins_idx = blockIdx.y; + size_t fea_idx = blockIdx.z; + + const size_t one_slot_dim = max_length * fea_emb_dim; + const size_t one_seq_dim = slot_num * one_slot_dim; + const size_t ad_seq_dim = ad_slot_num * one_slot_dim; + + const size_t piece_of_ad_seq_dim = ad_slot_num * fea_emb_dim; + for (size_t idx = threadIdx.x; idx < piece_of_ad_seq_dim; idx += blockDim.x) { + size_t slot_idx = idx / fea_emb_dim + ad_slot_offset; + size_t out_slot_idx = idx / fea_emb_dim; + size_t fea_dim_idx = idx % fea_emb_dim; + + size_t input_fea_begin_idx = ins_idx * (batch_num * one_seq_dim) + batch_idx * one_seq_dim + + slot_idx * one_slot_dim + fea_idx * fea_emb_dim; + + size_t ad_fea_begin_idx = + ins_idx * (1 * batch_num * piece_of_ad_seq_dim) + batch_idx * piece_of_ad_seq_dim + + out_slot_idx * fea_emb_dim; + + const T input_val = input[input_fea_begin_idx + fea_dim_idx]; + const T ad_val = ad_input[ad_fea_begin_idx + fea_dim_idx]; + + size_t fea_concat_start_idx = + batch_idx * (ins_num * ad_seq_dim * 4) + ins_idx * (ad_seq_dim * 4) + + fea_idx * (piece_of_ad_seq_dim * 4) + out_slot_idx * fea_emb_dim; + + din_output[fea_concat_start_idx + fea_dim_idx] = input_val; + din_output[fea_concat_start_idx + fea_dim_idx + piece_of_ad_seq_dim] = ad_val; + din_output[fea_concat_start_idx + fea_dim_idx + piece_of_ad_seq_dim * 2] = input_val - ad_val; + din_output[fea_concat_start_idx + fea_dim_idx + 
piece_of_ad_seq_dim * 3] = input_val * ad_val; + + size_t ad_slot_session_out_start_idx = + batch_idx * (ins_num * ad_seq_dim) + ins_idx * ad_seq_dim + + fea_idx * piece_of_ad_seq_dim + out_slot_idx * fea_emb_dim; + ad_slot_session_output[ad_slot_session_out_start_idx + fea_dim_idx] = input_val; + } +} + +template +__global__ void cal_sideinfo_kernel(const T* input, + T* side_info_output, + const size_t batch_num, + const size_t ins_num, + const size_t slot_num, + const size_t max_length, + const size_t fea_emb_dim, + const size_t sideinfo_slot_num, + const size_t sideinfo_slot_offset) { + + size_t batch_idx = blockIdx.x; + size_t ins_idx = blockIdx.y; + size_t fea_idx = blockIdx.z; + + const size_t one_slot_dim = max_length * fea_emb_dim; + const size_t input_one_seq_dim = slot_num * one_slot_dim; + const size_t sideinfo_seq_dim = sideinfo_slot_num * one_slot_dim; + + const size_t piece_of_sideinfo_seq_dim = sideinfo_slot_num * fea_emb_dim; + for (size_t idx = threadIdx.x; idx < piece_of_sideinfo_seq_dim; idx += blockDim.x) { + size_t out_slot_idx = idx / fea_emb_dim; + size_t slot_idx = out_slot_idx + sideinfo_slot_offset; + size_t fea_dim_idx = idx % fea_emb_dim; + + size_t input_fea_begin_idx = ins_idx * (batch_num * input_one_seq_dim) + batch_idx * input_one_seq_dim + + slot_idx * one_slot_dim + fea_idx * fea_emb_dim; + + size_t fea_transpose_start_idx = + batch_idx * (ins_num * sideinfo_seq_dim) + ins_idx * sideinfo_seq_dim + + fea_idx * (sideinfo_slot_num * fea_emb_dim) + out_slot_idx * fea_emb_dim; + + side_info_output[fea_transpose_start_idx + fea_dim_idx] = input[input_fea_begin_idx + fea_dim_idx]; + } +} + +template +__global__ void cal_sideinfo_kernel_without_loop(const T* input, + T* side_info_output, + const size_t batch_num, + const size_t ins_num, + const size_t slot_num, + const size_t max_length, + const size_t fea_emb_dim, + const size_t sideinfo_slot_num, + const size_t sideinfo_slot_offset) { + + size_t batch_idx = blockIdx.x; + size_t ins_idx = blockIdx.y; + size_t fea_idx = blockIdx.z; + + size_t slot_idx = threadIdx.y + sideinfo_slot_offset; + size_t out_slot_idx = threadIdx.y; + size_t fea_dim_idx = threadIdx.x; + + const size_t one_slot_dim = max_length * fea_emb_dim; + size_t input_one_seq_dim = slot_num * one_slot_dim; + size_t out_one_seq_dim = sideinfo_slot_num * one_slot_dim; + + size_t input_fea_begin_idx = ins_idx * (batch_num * input_one_seq_dim) + batch_idx * (input_one_seq_dim) + + slot_idx * one_slot_dim + fea_idx * fea_emb_dim; + + size_t fea_transpose_start_idx = + batch_idx * (ins_num * out_one_seq_dim) + ins_idx * out_one_seq_dim + + fea_idx * (sideinfo_slot_num * fea_emb_dim) + out_slot_idx * fea_emb_dim; + + side_info_output[fea_transpose_start_idx + fea_dim_idx] = input[input_fea_begin_idx + fea_dim_idx]; +} + +template +__device__ void warpReduce(volatile T* cache, int tid) { + cache[tid] += cache[tid+32]; + cache[tid] += cache[tid+16]; + cache[tid] += cache[tid+8]; + cache[tid] += cache[tid+4]; + cache[tid] += cache[tid+2]; + cache[tid] += cache[tid+1]; +} + +#define THREAD_PER_BLOCK 128 +template +__global__ void reduce_sum_max_length(const T* input, // 1 + T* mask_output, // mask + const size_t batch_count, + const size_t ins_num, + const size_t slot_num, + const size_t max_length, + const size_t fea_emb_dim) { + size_t batch_idx = blockIdx.x; + size_t ins_idx = blockIdx.y; + size_t fea_idx = blockIdx.z; + + size_t data_len_per_block = slot_num * fea_emb_dim; + + __shared__ T sdata[THREAD_PER_BLOCK]; + //each thread loads one element from 
global memory to shared mem + size_t input_start_idx = ins_idx * (batch_count * slot_num * max_length * fea_emb_dim) + + batch_idx * (slot_num * max_length * fea_emb_dim); + + size_t tid = threadIdx.x; + // memset shared mem + sdata[tid] = 0; + for (size_t idx = tid; idx < data_len_per_block; idx += blockDim.x) { + size_t slot_idx = idx / fea_emb_dim; + size_t fea_dim_idx = idx % fea_emb_dim; + size_t offset = slot_idx * (max_length * fea_emb_dim) + fea_idx * fea_emb_dim + fea_dim_idx; + sdata[tid] += input[input_start_idx + offset]; + } + __syncthreads(); + + for(size_t s = blockDim.x / 2; s > 32; s >>= 1) { + if (tid < s) { + sdata[tid] += sdata[tid + s]; + } + __syncthreads(); + } + // When s < 32, we have only one warp left, no need to sync threads, no need to if (tid < s) + if(tid < 32) { + warpReduce(sdata, tid); + } + + if(tid == 0) { + // [batch_count, ins_num, max_length] + size_t out_idx = batch_idx * (ins_num * max_length) + + ins_idx * (max_length) + + fea_idx; + if (fabs(sdata[tid]) > 1e-8) { + mask_output[out_idx] = 1; + } else { + mask_output[out_idx] = 0; + } + } +} + +template +class FusedSeqTensorCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto input = ctx.Input("Input"); + PADDLE_ENFORCE_NOT_NULL(input, platform::errors::NotFound("Input not found")); + auto ad_input = ctx.Input("ADInput"); + PADDLE_ENFORCE_NOT_NULL(ad_input, platform::errors::NotFound("Input not found")); + + auto din_output = ctx.Output("DINOut"); + PADDLE_ENFORCE_NOT_NULL(din_output, + platform::errors::NotFound("DINOut not found")); + T* din_output_data = din_output->mutable_data(ctx.GetPlace()); + auto mask_output = ctx.Output("MaskOut"); + PADDLE_ENFORCE_NOT_NULL(mask_output, + platform::errors::NotFound("MaskOut not found")); + T* mask_output_output_data = mask_output->mutable_data(ctx.GetPlace()); + auto side_info_output = ctx.Output("SideInfoOut"); + PADDLE_ENFORCE_NOT_NULL(side_info_output, + platform::errors::NotFound("Output not found")); + T* side_info_output_data = + side_info_output->mutable_data(ctx.GetPlace()); + auto ad_slot_session_output = + ctx.Output("ADSlotSessionOut"); + PADDLE_ENFORCE_NOT_NULL(ad_slot_session_output, + platform::errors::NotFound("Output not found")); + T* ad_slot_session_output_data = + ad_slot_session_output->mutable_data(ctx.GetPlace()); + + auto batch_count = ctx.Attr("batch_count"); + auto max_length = ctx.Attr("max_length"); + auto slot_num = ctx.Attr("slot_num"); + auto fea_emb_dim = ctx.Attr("fea_emb_dim"); + auto ad_slot_num = ctx.Attr("ad_slot_num"); + auto ad_slot_offset = ctx.Attr("ad_slot_offset"); + + auto& dev_ctx = ctx.template device_context(); + auto stream = ctx.cuda_device_context().stream(); + + auto input_dims = input->dims(); + size_t ins_num = input_dims[0]; + + dim3 ad_grid(batch_count, ins_num, max_length); + dim3 ad_block(std::min(static_cast(1024), static_cast(ad_slot_num * fea_emb_dim))); + + cal_ad_slot_session_kernel<<>>( + input->data(), ad_input->data(), din_output_data, + ad_slot_session_output_data, + batch_count, ins_num, slot_num, max_length, fea_emb_dim, + ad_slot_num, ad_slot_offset); + + size_t sideinfo_slot_offset = 0; + if (ad_slot_offset == 0) { + sideinfo_slot_offset = ad_slot_num; + } + size_t fea_padding_dim = ((fea_emb_dim + 31) / 32) * 32; + size_t sideinfo_slot_num = slot_num - ad_slot_num; + + if (sideinfo_slot_num * fea_emb_dim < 1024) { + dim3 sideinfo_grid(batch_count, ins_num, max_length); + dim3 sideinfo_block(fea_emb_dim, 
sideinfo_slot_num); + cal_sideinfo_kernel_without_loop<<>>( + input->data(), side_info_output_data, batch_count, ins_num, + slot_num, max_length, fea_emb_dim, + sideinfo_slot_num, sideinfo_slot_offset); + } else { + dim3 sideinfo_grid(batch_count, ins_num, max_length); + dim3 sideinfo_block(sideinfo_slot_num * fea_emb_dim); + cal_sideinfo_kernel<<>>( + input->data(), side_info_output_data, batch_count, ins_num, + slot_num, max_length, fea_emb_dim, + sideinfo_slot_num, sideinfo_slot_offset); + } + + dim3 reduce_grid(batch_count, ins_num, max_length); + dim3 reduce_block(THREAD_PER_BLOCK); + reduce_sum_max_length<<>>( + input->data(), mask_output_output_data, batch_count, + ins_num, slot_num, max_length, fea_emb_dim); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + fused_seq_tensor, + ops::FusedSeqTensorCUDAKernel); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_seq_tensor_op.h b/paddle/fluid/operators/fused/fused_seq_tensor_op.h new file mode 100644 index 0000000000000..d7bbadd72e3b5 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_seq_tensor_op.h @@ -0,0 +1,16 @@ +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class FusedSeqTensorCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext&) const override { + PADDLE_THROW(platform::errors::Unimplemented("fused_seq_tensor supports only GPU")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc index 1dfbc30d06606..3945c027364e7 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc @@ -48,6 +48,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel { bool clk_filter = ctx->Attrs().Get("clk_filter"); const int embed_thres_size = ctx->Attrs().Get("embed_thres_size"); const int embedx_concate_size = ctx->Attrs().Get("embedx_concate_size"); + //const bool fill_zero = ctx->Attrs().Get("fill_zero"); // need filter quant_ratio more than zero if (ctx->Attrs().Get("need_filter")) { @@ -142,6 +143,7 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("embed_thres_size", "(int, default 0)").SetDefault(0); AddAttr("embedx_concate_size", "(int, default 1)").SetDefault(1); AddAttr("embedx_concate_filter", "(bool, default false)").SetDefault(false); + AddAttr("fill_zero", "(bool, default true)").SetDefault(true); AddComment(R"DOC( Fuse multiple pairs of Sequence Pool and CVM Operator. 
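(Not part of the patch.) As a reading aid for the fused_seq_tensor kernels added above, the following NumPy sketch restates the semantics implied by their index arithmetic. The tensor layouts ([ins_num, batch_count, slot_num, max_length, fea_emb_dim] for Input, [ins_num, batch_count, ad_slot_num, fea_emb_dim] for ADInput), the output axis order, and the function name are inferred for illustration only, not taken from the source.

import numpy as np

def fused_seq_tensor_reference(x, ad, batch_count, max_length, slot_num,
                               ad_slot_num, fea_emb_dim, ad_slot_offset):
    # Layouts inferred from the kernels' index arithmetic (illustrative only).
    ins_num = x.shape[0]
    x = x.reshape(ins_num, batch_count, slot_num, max_length, fea_emb_dim)
    ad = ad.reshape(ins_num, batch_count, ad_slot_num, fea_emb_dim)

    # cal_ad_slot_session_kernel: take the ad slots of the sequence and pair
    # them with the length-free ad embedding, broadcast over max_length.
    seq = x[:, :, ad_slot_offset:ad_slot_offset + ad_slot_num]
    seq = seq.transpose(1, 0, 3, 2, 4)                 # [b, ins, len, ad_slot, emb]
    ad_b = np.broadcast_to(ad.transpose(1, 0, 2, 3)[:, :, None], seq.shape)

    # DINOut keeps the four blocks [seq, ad, seq - ad, seq * ad] per time step.
    din_out = np.stack([seq, ad_b, seq - ad_b, seq * ad_b], axis=3)
    ad_slot_session_out = seq

    # cal_sideinfo_kernel(_without_loop): transpose the remaining slots.
    side_off = ad_slot_num if ad_slot_offset == 0 else 0
    side_num = slot_num - ad_slot_num
    side_info_out = x[:, :, side_off:side_off + side_num].transpose(1, 0, 3, 2, 4)

    # reduce_sum_max_length: a step counts as "real" if the sum over all slots
    # and embedding dims is non-zero (|sum| > 1e-8).
    step_sum = x.sum(axis=(2, 4))                      # [ins, b, len]
    mask_out = (np.abs(step_sum) > 1e-8).astype(x.dtype).transpose(1, 0, 2)

    return din_out, mask_out, side_info_out, ad_slot_session_out

The four arrays correspond to DINOut, MaskOut, SideInfoOut and ADSlotSessionOut as wired up in FusedSeqTensorCUDAKernel::Compute.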
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu index 93cd61f6df2b4..45dc840d28995 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu @@ -177,7 +177,7 @@ __global__ void FusedSeqpoolKernelEmbedQuantFilterEmbedxConcate( size_t **lods_values, const int batch_size, const int embedding_size, const float pad_value, const int cvm_offset, const float show_coeff, const float clk_coeff, const float threshold, const int quant_ratio, - const float embed_threshold, const int embedx_concate_size, bool embedx_concate_filter) { + const float embed_threshold, const int embedx_concate_size, bool embedx_concate_filter, bool fill_zero) { CUDA_KERNEL_LOOP(i, N) { int key = i / embedding_size; int offset = i % embedding_size; // embedx id @@ -188,11 +188,17 @@ __global__ void FusedSeqpoolKernelEmbedQuantFilterEmbedxConcate( double val = pad_value; int concate_index = 0; + bool val_use_zero = false; for (auto k = start; k < end; ++k) { + val_use_zero = false; T &show = *(input_values[x] + k * embedding_size); T &click = *(input_values[x] + k * embedding_size + 1); if (embedx_concate_filter && (show - click) * show_coeff + click * clk_coeff < threshold) { - continue; + if (fill_zero) { + val_use_zero = true; + } else { + continue; + } } T &embedw = *(input_values[x] + k * embedding_size + cvm_offset); T embedx_weight_score = 0.0; @@ -202,16 +208,28 @@ __global__ void FusedSeqpoolKernelEmbedQuantFilterEmbedxConcate( } embedx_weight_score = std::sqrt(embedx_weight_score) + std::abs(embedw); if (embedx_concate_filter && embedx_weight_score < embed_threshold) { - continue; + if (fill_zero) { + val_use_zero = true; + } else { + continue; + } } if (offset < cvm_offset) { // show & click - val = *(input_values[x] + k * embedding_size + offset); + if (val_use_zero) { + val = pad_value; + } else { + val = *(input_values[x] + k * embedding_size + offset); + } } else { - val = ((static_cast( + if (val_use_zero) { + val = pad_value; + } else { + val = ((static_cast( *(input_values[x] + k * embedding_size + offset) * quant_ratio + 0.5)) / static_cast(quant_ratio)); + } } if (concate_index == embedx_concate_size) { *(seqpool_output_values[x] + y * embedding_size * embedx_concate_size + (embedx_concate_size-1) * embedding_size + offset) += val; @@ -352,7 +370,8 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place, float clk_coeff, float threshold, float embed_threshold, const int quant_ratio, const bool clk_filter, const int embed_thres_size, const int embedx_concate_size, - bool embedx_concate_filter) { + bool embedx_concate_filter, + bool fill_zero) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get(place)) ->stream(); @@ -395,7 +414,7 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place, 0, stream>>>( N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, embedding_size, padding_value, cvm_offset, show_coeff, clk_coeff, - threshold, quant_ratio, embed_threshold, embedx_concate_size, embedx_concate_filter); + threshold, quant_ratio, embed_threshold, embedx_concate_size, embedx_concate_filter, fill_zero); } } else if (need_filter) { // quant need filter FusedSeqpoolKernelQuantFilter<< { const int embed_thres_size = ctx.Attr("embed_thres_size"); const int embedx_concate_size = ctx.Attr("embedx_concate_size"); bool embedx_concate_filter = ctx.Attr("embedx_concate_filter"); + bool fill_zero = ctx.Attr("fill_zero"); 
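+    // fill_zero (consumed by FusedSeqpoolKernelEmbedQuantFilterEmbedxConcate
+    // above): when an item fails the show/click or embed-threshold filter,
+    // write the pad value for it instead of skipping it with `continue`.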
framework::GPULodVector gpu_lods[slot_size]; auto place = ctx.GetPlace(); @@ -737,7 +757,7 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel { embedding_size, padding_value, use_cvm, cvm_offset, need_filter, embed_threshold_filter, show_coeff, clk_coeff, threshold, embed_threshold, quant_ratio, clk_filter, - embed_thres_size, embedx_concate_size, embedx_concate_filter); + embed_thres_size, embedx_concate_size, embedx_concate_filter, fill_zero); } }; diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc index 2cb1a0caf30ea..66bb9afdde8c6 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc @@ -109,6 +109,10 @@ class FusedSeqpoolCVMOpWithConvMaker : public framework::OpProtoAndCheckerMaker "(float, default 0.0) The value to pad for empty sequence.") .SetDefault(0.0); AddAttr("use_cvm", "bool, use cvm or not").SetDefault(true); + AddAttr("need_filter", "(bool, default false)").SetDefault(false); + AddAttr("show_coeff", "(float, default 0.2)").SetDefault(0.2); + AddAttr("clk_coeff", "(float, default 1)").SetDefault(1); + AddAttr("threshold", "(float, default 0.96)").SetDefault(0.96); AddAttr("cvm_offset", "(int, default 3)").SetDefault(3); AddAttr("show_filter", "(bool, default false)").SetDefault(false); AddAttr("embedx_concate_size", "(int, default 1)").SetDefault(1); diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cu index cb56a9109e6c7..0e01eb1785132 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cu +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cu @@ -53,6 +53,38 @@ __global__ void FusedSeqpoolWithConvKernelNormal(const size_t N, T **input_value } } +// Filter +template +__global__ void FusedSeqpoolWithConvKernelFilter(const size_t N, T **input_values, + T **seqpool_output_values, + size_t **lods_values, + const int batch_size, + const int embedding_size, + const float pad_value, + const float show_coeff, + const float clk_coeff, + const float threshold) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + + double val = pad_value; + for (auto k = start; k < end; ++k) { + T &show = *(input_values[x] + k * embedding_size); + T &click = *(input_values[x] + k * embedding_size + 1); + if ((show - click) * show_coeff + click * clk_coeff < threshold) { + continue; + } + val += *(input_values[x] + k * embedding_size + offset); + } + *(seqpool_output_values[x] + y * embedding_size + offset) = val; + } +} + // normal & expand slot's feasign template __global__ void FusedSeqpoolWithConvKernelNormalEmbedxConcate(const size_t N, T **input_values, @@ -257,6 +289,8 @@ void FusedSeqpoolCVMWithConv(const paddle::platform::Place &place, std::vector lods, const int batch_size, const int slot_num, const int embedding_size, const float padding_value, const bool use_cvm, + float need_filter, float show_coeff, + float clk_coeff, float threshold, const int cvm_offset, bool show_filter, const int embedx_concate_size) { auto stream = dynamic_cast( @@ -290,10 +324,17 @@ void FusedSeqpoolCVMWithConv(const paddle::platform::Place &place, size_t N = static_cast(batch_size * slot_num * 
embedding_size); // first sum pool if (embedx_concate_size == 1){ + if (need_filter) { //filter + FusedSeqpoolWithConvKernelFilter<<>>( + N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, + embedding_size, padding_value, show_coeff, clk_coeff, threshold); + } else { //normal FusedSeqpoolWithConvKernelNormal<<>>( - N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, - embedding_size, padding_value); + stream>>>( + N, gpu_input_values, gpu_seqpool_output_values, lods_values, batch_size, + embedding_size, padding_value); + } } else { FusedSeqpoolWithConvKernelNormalEmbedxConcate<<>>( @@ -595,6 +636,10 @@ class FusedSeqpoolCVMWithConvCUDAKernel : public framework::OpKernel { auto padding_value = ctx.Attr("pad_value"); auto use_cvm = ctx.Attr("use_cvm"); + bool need_filter = ctx.Attr("need_filter"); + float show_coeff = ctx.Attr("show_coeff"); + float clk_coeff = ctx.Attr("clk_coeff"); + float threshold = ctx.Attr("threshold"); const int cvm_offset = ctx.Attr("cvm_offset"); bool show_filter = ctx.Attr("show_filter"); const int embedx_concate_size = ctx.Attr("embedx_concate_size"); @@ -638,7 +683,9 @@ class FusedSeqpoolCVMWithConvCUDAKernel : public framework::OpKernel { } FusedSeqpoolCVMWithConv(ctx.GetPlace(), input_data, output_data, seqpool_output_data, lods_data, batch_size, slot_size, - embedding_size, padding_value, use_cvm, cvm_offset, show_filter, embedx_concate_size); + embedding_size, padding_value, use_cvm, + need_filter, show_coeff, clk_coeff, threshold, + cvm_offset, show_filter, embedx_concate_size); } }; diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index 5e3eb92f2d401..5143303a286b1 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -71,6 +71,7 @@ 'fused_seqpool_concat', 'fused_concat', 'rank_attention2', + 'fused_seq_tensor', ] @@ -1601,7 +1602,7 @@ def rank_attention2(input, return output -def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None, batchcount=0): +def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None, batchcount=0, transpose_weight=False): """ **Batch FC layer** This Op can calculate BatchFC. This is similar to matmul op, @@ -1666,7 +1667,10 @@ def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None, batc "W": w, "Bias": b }, - attrs={'batchcount': batchcount}, + attrs={ + 'batchcount': batchcount, + 'transpose_weight': transpose_weight + }, outputs={"Out": pre_act}) return helper.append_activation(pre_act) @@ -1759,7 +1763,8 @@ def fused_seqpool_cvm(input, clk_filter=False, embed_thres_size=0, embedx_concate_size=1, - embedx_concate_filter=False): + embedx_concate_filter=False, + fill_zero=True): """ **Notes: The Op only receives List of LoDTensor as input, only support SUM pooling now. :attr:`input`. 
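(Not part of the patch.) The Python wrapper changes around this point expose the new need_filter, show_coeff, clk_coeff, threshold and fill_zero options. As an illustration only, this NumPy sketch (the function name and the 2-D item layout are made up for the example) shows the per-slot, per-instance rule applied by FusedSeqpoolWithConvKernelFilter when need_filter is set:

import numpy as np

def seqpool_with_show_click_filter(items, pad_value, show_coeff, clk_coeff, threshold):
    # items: [seq_len, embedding_size]; columns 0 and 1 hold show and click,
    # exactly as the CUDA kernel reads them from input_values.
    keep = (items[:, 0] - items[:, 1]) * show_coeff + items[:, 1] * clk_coeff >= threshold
    # Items failing the test are dropped from the sum; the result starts from pad_value.
    return pad_value + items[keep].sum(axis=0)

fused_seqpool_cvm's new fill_zero flag addresses the related embedx-concate path, where a filtered item contributes the pad value instead of being dropped (see the kernel hunk earlier in this patch).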
@@ -1818,7 +1823,8 @@ def fused_seqpool_cvm(input, "clk_filter": clk_filter, "embed_thres_size": embed_thres_size, "embedx_concate_size": embedx_concate_size, - "embedx_concate_filter": embedx_concate_filter + "embedx_concate_filter": embedx_concate_filter, + "fill_zero": fill_zero }) return outs @@ -1908,6 +1914,10 @@ def fused_seqpool_cvm_with_conv(input, cvm, pad_value=0.0, use_cvm=True, + need_filter=False, + show_coeff=0.2, + clk_coeff=1.0, + threshold=0.96, show_filter=False, cvm_offset=3, embedx_concate_size=1): @@ -1955,6 +1965,10 @@ def fused_seqpool_cvm_with_conv(input, "pad_value": pad_value, "use_cvm": use_cvm, "cvm_offset": cvm_offset, + "need_filter": need_filter, + "show_coeff": show_coeff, + "clk_coeff": clk_coeff, + "threshold": threshold, "show_filter": show_filter, "embedx_concate_size": embedx_concate_size, }) @@ -2817,3 +2831,66 @@ def fused_concat(input, start_index=0, length=-1, axis=1): "length": length}) return out +def fused_seq_tensor(input, + batch_count, + max_length, + slot_num, + ad_slot_num, + fea_emb_dim, + ad_slot_offset): + """ + **fused seq tensor** + Notice: It currently only supports GPU device. + + Args: + input: [input, ad_input], input tensor list with data type float32. + batch_count: parrellel num. + max_length: max_length. + slot_num: slot_num, sum of ad_slot_num and side info slot. + ad_slot_num: ad slot num. + fea_emb_dim: embding dim. + ad_slot_offset: ad slot offset. + + Returns: + Variable: + din_out, mask_out, side_info_out, ad_slot_session_out + """ + + helper = LayerHelper("fused_seq_tensor", **locals()) + + check_type(input, "input", list, 'fused_seq_tensor') + + dtype = helper.input_dtype() + check_dtype(dtype, 'input', ['float32', 'float64'], 'fused_seq_tensor') + + check_type(batch_count, 'batch_count', (int, Variable), 'fused_seq_tensor') + check_type(max_length, 'max_length', (int, Variable), 'fused_seq_tensor') + check_type(slot_num, 'slot_num', (int, Variable), 'fused_seq_tensor') + check_type(fea_emb_dim, 'fea_emb_dim', (int, Variable), 'fused_seq_tensor') + + din_out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) + mask_out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) + side_info_out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) + ad_slot_session_out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) + + helper.append_op( + type="fused_seq_tensor", + inputs={"Input": input[0], + "ADInput": input[1] + }, + attrs={ + 'batch_count': batch_count, + 'max_length': max_length, + 'slot_num': slot_num, + 'fea_emb_dim': fea_emb_dim, + 'ad_slot_num': ad_slot_num, + 'ad_slot_offset': ad_slot_offset + }, + outputs={ + "DINOut": din_out, + "MaskOut": mask_out, + "SideInfoOut": side_info_out, + "ADSlotSessionOut": ad_slot_session_out + }) + + return din_out, mask_out, side_info_out, ad_slot_session_out From 2809f5f485e78c4501bcd7b5520542ad317fcaa1 Mon Sep 17 00:00:00 2001 From: humingqing Date: Fri, 12 Jan 2024 14:48:57 +0800 Subject: [PATCH 9/9] rollback --- paddle/phi/kernels/gpu/matmul_grad_kernel.cu | 52 -------------------- paddle/phi/kernels/gpu/matmul_kernel.cu | 39 --------------- 2 files changed, 91 deletions(-) diff --git a/paddle/phi/kernels/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/gpu/matmul_grad_kernel.cu index 1591d86d8cf59..b6c13360cd404 100644 --- a/paddle/phi/kernels/gpu/matmul_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/matmul_grad_kernel.cu @@ -14,60 +14,8 @@ limitations under the License. 
*/ #include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/phi/kernels/matmul_grad_kernel.h" -namespace phi { -template <> -void MatMul(const phi::GPUContext& dev_ctx, - const DenseTensor& a, - bool trans_a, - const DenseTensor& b, - bool trans_b, - DenseTensor* out) { - dev_ctx.template Alloc(out); -#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) - if (a.dims().size() == 2 && b.dims().size() == 2) { - auto& x_dims = a.dims(); // M * K - auto& y_dims = b.dims(); // K * N - const int M = trans_a ? x_dims[1] : x_dims[0]; - const int K = trans_a ? x_dims[0] : x_dims[1]; - const int N = trans_b ? y_dims[0] : y_dims[1]; - phi::funcs::LinearWithCublasLt::Run( - dev_ctx, - &a, // x - &b, // y - out, // out - nullptr, // bias - nullptr, - M, // M bsz_seqf - N, // N output_size - K, // K input_size - trans_a, - trans_b, - phi::funcs::MatmulFusedType::kMatmul); - return; - } -#endif - auto blas = phi::funcs::GetBlas(dev_ctx); - auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b.dims(), 0, trans_b); - if (a.dims().size() == 3 && b.dims().size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!trans_a) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } - } - blas.MatMul(a.data(), - mat_dim_a, - b.data(), - mat_dim_b, - static_cast(1), - dev_ctx.template Alloc(out), - static_cast(false)); -} -} // namespace phi PD_REGISTER_KERNEL(matmul_grad, GPU, diff --git a/paddle/phi/kernels/gpu/matmul_kernel.cu b/paddle/phi/kernels/gpu/matmul_kernel.cu index e96a76b1d1e7b..32d70ae0763f0 100644 --- a/paddle/phi/kernels/gpu/matmul_kernel.cu +++ b/paddle/phi/kernels/gpu/matmul_kernel.cu @@ -15,48 +15,9 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/impl/matmul_kernel_impl.h" #include "paddle/phi/kernels/matmul_kernel.h" -namespace phi { -template <> -void MatMulFunction(const phi::GPUContext& dev_ctx, - const DenseTensor& X, - const DenseTensor& Y, - DenseTensor* Out, - bool trans_x, - bool trans_y) { -#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) - if (X.dims().size() == 2 && Y.dims().size() == 2) { - auto& x_dims = X.dims(); // M * K - auto& y_dims = Y.dims(); // K * N - const int M = trans_x ? x_dims[1] : x_dims[0]; - const int K = trans_x ? x_dims[0] : x_dims[1]; - const int N = trans_y ? y_dims[0] : y_dims[1]; - phi::funcs::LinearWithCublasLt::Run( - dev_ctx, - &X, // x - &Y, // y - Out, // out - nullptr, // bias - nullptr, - M, // M bsz_seqf - N, // N output_size - K, // K input_size - trans_x, - trans_y, - phi::funcs::MatmulFusedType::kMatmul); - return; - } -#endif - const std::vector x_dims = vectorize(X.dims()); - const std::vector y_dims = vectorize(Y.dims()); - MatMulFunction( - dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, false); -} -} // namespace phi - PD_REGISTER_KERNEL(matmul, GPU, ALL_LAYOUT,