From 2783c7600d3eb124afdcd3b7b36ac86bfb49fa5e Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Tue, 31 Aug 2021 01:42:10 +0000 Subject: [PATCH 1/7] commit for pool higher preformance --- .../kernel_primitives/compute_primitives.h | 222 +++++++-- .../kernel_primitives/datamover_primitives.h | 350 ++++++++++++--- .../kernel_primitives/helper_primitives.h | 61 ++- .../operators/margin_cross_entropy_op.cu | 38 +- .../operators/reduce_ops/reduce_functor_op.h | 16 +- .../fluid/operators/reduce_ops/reduce_op.cu.h | 420 ++++++++---------- 6 files changed, 726 insertions(+), 381 deletions(-) diff --git a/paddle/fluid/operators/kernel_primitives/compute_primitives.h b/paddle/fluid/operators/kernel_primitives/compute_primitives.h index ccd301aa8ca3d..88c73797bbfb4 100644 --- a/paddle/fluid/operators/kernel_primitives/compute_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/compute_primitives.h @@ -21,7 +21,8 @@ #include #endif -#include +// #include +#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -29,6 +30,16 @@ namespace operators { namespace kernel_primitives { namespace details { +#ifdef __HIPCC__ +constexpr int kMaxThread = 256; +constexpr int kWarpSize = 64; +#else +constexpr int kMaxThread = 128; +constexpr int kWarpSize = 32; +#endif + +enum ReduceMode { kGlobalMode, kLocalMode }; + template class MPTypeTrait { public: @@ -41,37 +52,98 @@ class MPTypeTrait { using Type = float; }; -} // namespace details +/** + * @brief will be used in BlockYReduce, get the index of reduce_num in shared + * memory + */ +__device__ __forceinline__ int SharedMemoryIndex(int index) { + return (threadIdx.y + index) * blockDim.x + threadIdx.x; +} -/*************************** Compute Functor****************************/ -template -struct DivFunctor { - inline HOSTDEVICE T operator()(const T* args) const { - return args[0] / args[1]; +template +__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) { + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, true); + for (int stride = details::kWarpSize / 2; stride > 0; stride >>= 1) { + T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride); + val = reducer(val, temp); } -}; + return val; +} -template -struct DivFunctor::value>> { - inline HOSTDEVICE T operator()(const T* args) const { - PADDLE_ENFORCE(args[1] != 0, - platform::errors::InvalidArgument( - "Invalid Argument Error: Integer division by zero " - "encountered in divide. Please check the input value.")); - return args[0] / args[1]; +/* e.g. + * |---------block---------| + * |warp0|warp1|warp2|warp3| + * |0~31|32~63|64~95|96~127| ---->blockDim.x = 128 + * \|/ \|/ \|/ \|/ ---->1. First WarpReduce in each warp + * res0 res1 res2 res3 ---->2. Store result of each warp to shared memory + * \ \ / / ---->3. 
Load the result above from shared memory + * res to warp0 and process the second WarpReduce + */ + +/** + * @brief BlockXReduce reduce along blockDim.x + */ +template +__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) { + __syncthreads(); + using details::kWarpSize; + __shared__ T shared[2 * kWarpSize]; + int block_dim_x = blockDim.x; + if (blockDim.x > kWarpSize) { + block_dim_x = blockDim.x / kWarpSize; + int lane = threadIdx.x % kWarpSize; + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int wid = tid / kWarpSize; + int bid = threadIdx.y; + val = WarpReduce(val, reducer); + if (lane == 0) { + shared[wid] = val; + } + __syncthreads(); + val = shared[bid * block_dim_x + lane]; } -}; + + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, true); + for (int stride = 1; stride < block_dim_x; stride <<= 1) { + T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride); + val = reducer(val, temp); + } + return val; +} + +/** + * @brief BlockYReduce reduce along blockDim.y + */ +template +__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) { + __shared__ T shared_memory[details::kMaxThread]; + shared_memory[SharedMemoryIndex(0)] = val; + for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) { + __syncthreads(); + if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) { + T temp = shared_memory[SharedMemoryIndex(stride)]; + val = reducer(val, temp); + } + shared_memory[SharedMemoryIndex(0)] = val; + } + return val; +} + +} // namespace details /*************************** Compute Function****************************/ /** - * @brief compute functor for elementwise_two, in1 and in2 has the same shape + * @brief binary function, in1 and in2 have same shape * @param: - * T : the type of in1 and in2 - * NX: the row of in1 and in2 - * NY: the col of in1 and in2 - * BlockSize: the strid of col - * OpFunc: compute functor eg: ADD, SUB, XOR, OR, MUL + * T: data type of in1, in2 + * OutT: data type of out + * NX: the cols of in1, in2 + * NY: the rows of in1, in2 + * BlockSize: the config of this device + * OpFunc: compute functor eg: in1 + in2, in1 - in2 */ template @@ -88,32 +160,40 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1, } /** - * @brief fma eg: a * b + c, in1 in2, in3 and out has the same shape + * @brief ternary function, in1, in2 and in3 have same shape * @param: - * T : the type of in1 and in2, in3 - * NX: the row of in1, in2 and in3 - * NY: the col of in1, in2 and in3 - * BlockSize: the strid of col + * T: data type of in1, in2, in3 + * OutT: data type of out + * NX: the cols of in1, in2 + * NY: the rows of in1, in2 + * BlockSize: the config of this device + * OpFunc: compute functor eg: out = in1 * in2 + in3 */ template -__device__ __forceinline__ void ElementwiseFma(OutT* out, const T* in1, - const T* in2, const T* in3, - OpFunc compute) { +__device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1, + const T* in2, const T* in3, + OpFunc compute) { + T args[3]; #pragma unroll for (int idx = 0; idx < NX * NY; ++idx) { - out[idx] = static_cast(compute(in1[idx], in2[idx], in3[idx])); + args[0] = in1[idx]; + args[1] = in2[idx]; + args[2] = in3[idx]; + out[idx] = static_cast(compute(args)); } } /** - * @brief compute functor for elementwise_two, in1 is [1, NY], in2 is [NX, NY] + * @brief cycle binary function, in1's shape size is [1, NX], in2's shape size + * is [NY, NX], out's shape size is [NY, NX] * @param: - * T : the type of in1 and in2 - * NX: the row of in1 and in2 - * NY: the col of in2 - * 
BlockSize: the strid of col - * OpFunc: compute functor eg: ADD, SUB, XOR, OR, MUL + * T: data type of in1, in2 + * OutT: data type of out + * NX: the cols of in1, in2 + * NY: the rows of in1, in2 + * BlockSize: the config of this device + * OpFunc: compute functor eg: in1 + in2, in1 - in2 */ template @@ -130,13 +210,14 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1, } /** - * @brief compute functor for unary, in1 is [NX, NY] + * @brief unary function * @param: - * T : the type of in - * NX: the row of in - * NY: the col of in - * BlockSize: the strid of col - * OpFunc: compute functor eg: relu, sigmoid, exp + * T: data type of in + * OutT: data type of out + * NX: the cols of in + * NY: the rows of in + * BlockSize: the config of this device + * OpFunc: compute functor eg: relu, exp */ template @@ -148,6 +229,59 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in, } } +/** + * @brief reduce function, in's shape size is [NX, NY]. + * If ReduceMode == kLocalMode then reduce NX, the shape of out is [NY, 1], + * if ReduceMode == kGlobalMode then reduce between different threads, the + * shape of out is [NY, NX]. If reduce_last_dim is false and reduce_num was + * split, BlockYReduce will be called. If reduce_last_dim is true and + * reduce_num was split, BlockXReduce will be called + * @typename: + * T: data type of in + * NX: the cols of in + * NY: the rows of in + * BlockSize: the config of this device + * OpFunc: reduce functor, eg: CustomSum, CustomMean in reduce_functor_op.h + * @param: + * reducer: reduce functor, eg: CustomSum() + * reduce_last_dim: if in's last dim need to be reduce then reduce_last_dim = + * true + */ +template +__device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer, + bool reduce_last_dim) { + int block_index = blockDim.y; + + if (ReduceMode == details::ReduceMode::kGlobalMode) { + bool block_reduce_y = (!reduce_last_dim) && (block_index > 1); + // when reduce is not required for the last dim, and reduce num has been + // split into multiple threads + if (block_reduce_y) { +#pragma unroll + for (int i = 0; i < NY * NX; i++) { // reduce along blockdim.y + out[i] = details::BlockYReduce(out[i], reducer); + } + } + + // when last dimension need to be reduced + if (reduce_last_dim) { +#pragma unroll + for (int i = 0; i < NY * NX; i++) { // reduce along blockDim.x + out[i] = details::BlockXReduce(out[i], reducer); + } + } + } else { // else kLocalMode +#pragma unroll + for (int i = 0; i < NY; ++i) { +#pragma unroll + for (int j = 0; j < NX; ++j) { + out[i] = reducer(out[i], in[i * NX + j]); + } + } + } +} + } // namespace kernel_primitives } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h index d520c33ca9bcc..2313adf51b799 100644 --- a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h @@ -13,11 +13,13 @@ // limitations under the License. #pragma once +#ifdef PADDLE_WITH_CUDA #include #include -#include -#include -#include +#endif +#ifdef PADDLE_WITH_HIP +#include +#endif namespace paddle { namespace operators { @@ -104,52 +106,197 @@ struct BroadcastConfig { #undef INT_BITS } // namespace details -template -__device__ __forceinline__ void ReadDataBase(T* dst, const T* __restrict__ src, - int size) { +/** + * @brief load data from src to dst, src can be 1D data or 2D data. 
Note that + * you can use this function when you are sure that the data will not cross the + * boundary. + * @typename: + * Tx: data type of src + * Ty: data type of dstt + * NX: the cols of src, dst + * NY: the rows of src, dst + * BlockSize: the config of this device + * @param: + * stride_nx: the stride of cols + * stride_ny: the stride of rows + */ + +template +__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src, + int stride_nx, int stride_ny) { + if (NY == 1 && NX == 1) { + dst[0] = static_cast(src[threadIdx.x]); + } else if (NX == 1) { + int dx = threadIdx.x; +#pragma unroll + for (int idy = 0; idy < NY; ++idy) { + dst[idy] = static_cast(src[dx + idy * stride_ny]); + } + } else if (NY == 1) { +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { + dst[idx] = static_cast(src[idx * stride_nx]); + } + } else { + int dx = threadIdx.x * NX; +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { +#pragma unroll + for (int idy = 0; idy < NY; ++idy) { + dst[idy * NX + idx] = + static_cast(src[idx * stride_nx + dx + idy * stride_ny]); + } + } + } +} + +/** + * @brief load data from src to dst, src can be 1D data or 2D data. When + * boundary judgment is required, you need to set a to true, and a is false by + * default. + * @typename: + * Tx: data type of src + * Ty: data type of dstt + * NX: the cols of src, dst + * NY: the rows of src, dst + * BlockSize: the config of this device + * IsBoundary: whether to make boundary judgment + * @param: + * size_nx: number of columns to be processed by the current block + * size_ny: number of rows to be processed by the current block + * stride_nx: the stride of cols + * stride_ny: the stride of rows + */ +template +__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src, + int size_nx, int size_ny, + int stride_nx, int stride_ny) { int dx = threadIdx.x * NX; + int size = size_nx - dx; + + // Each branch is added for better performance + if (NX == 1 && NY == 1) { // for NX == 1 and NY == 1 + if (IsBoundary) { + if (dx < size_nx) { + dst[0] = static_cast(src[dx]); + } + } else { + dst[0] = static_cast(src[dx]); + } + } else if (NX == 1) { // for NX == 1 and NY != 1 #pragma unroll - for (int idx = 0; idx < NX; ++idx) { - if ((idx + dx) >= size) { - break; + for (int idy = 0; idy < NY; ++idy) { + if (IsBoundary) { + if (idy >= size_ny) { + break; + } + } + dst[idy] = static_cast(src[dx + idy * stride_ny]); + } + } else if (NY == 1) { // for NY == 1 and NX != 1 +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { + if (IsBoundary) { + if (idx >= size) { + break; + } + } + dst[idx] = static_cast(src[idx * stride_nx + dx]); + } + } else { // for NX != 1 and NY != 1 +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { + if (IsBoundary) { + if (idx >= size) { + break; + } + } +#pragma unroll + for (int idy = 0; idy < NY; ++idy) { + if (IsBoundary) { + if (idy >= size_ny) { + break; + } + } + dst[idy * NX + idx] = + static_cast(src[idx * stride_nx + dx + idy * stride_ny]); + } } - dst[idx] = src[idx + dx]; } } -template +template +__device__ __forceinline__ void Init(T* dst, T init_data) { +#pragma unroll + for (int i = 0; i < NX; i++) { + dst[i] = init_data; + } +} + +/** @brief: ReadData + * @brief load data from src to dst, src can be 1D data, you should set NY = 1. + * When boundary judgment is required, you need to set a to true, and a is false + * by default. 
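+ * For example (an illustrative sketch, assuming T = float, NX = 4, NY = 1,
+ * BlockSize = 128 and num >= blockDim.x * 4):
+ *   float reg[4];
+ *   ReadData<float, 4, 1, 128>(reg, src, num);
+ * IsBoundary keeps its default value false, so each thread loads one
+ * VectorType<float, 4>, i.e. the 4 consecutive floats starting at
+ * src + threadIdx.x * 4, into its private reg[0..3].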
+ * @typename: + * T : the data type of src + * NX: the cols of src, dst + * NY: in this function NY only can be 1 + * BlockSize: the config of this device + * IsBoundary: whether to make boundary judgment + * @param: + * num: number of columns to be processed by the current block + */ +template __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src, - int size) { - const int VECTOR_SIZE = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1; - const int VECTORS_PER_THREAD = NX / VECTOR_SIZE; + int num) { + if (IsBoundary) { // blockDim.x * NX > num + int dx = threadIdx.x * NX; +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { + if (idx + dx < num) { + dst[idx] = src[idx + dx]; + } + } + } else { // blockDim,x * NX < num + const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1; + const int kVectorsPerThread = NX / kVectorSize; + int tid = threadIdx.x * kVectorsPerThread; - // Vector per thread - if (blockDim.x * NX > size) { - ReadDataBase(dst, src, size); - } else { - // Vector type - using VecType = details::VectorType; - VecType vec_temp[VECTORS_PER_THREAD]; + using VecType = details::VectorType; const VecType* vec_input = reinterpret_cast(src); - ReadDataBase( - vec_temp, vec_input, VECTORS_PER_THREAD * blockDim.x); + VecType vec_temp[kVectorsPerThread]; + #pragma unroll - for (int idx = 0; idx < NX; ++idx) { - dst[idx] = *(reinterpret_cast(vec_temp) + idx); + for (int i = 0; i < kVectorsPerThread; ++i) { + vec_temp[i] = vec_input[i + tid]; +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { + dst[idx] = *(reinterpret_cast(vec_temp) + idx); + } } } } -/** @brief: ReadDataBc - * read data from src ptr when the shape of src and dst are different +/** + * @brief: read data for broadcast + * @typename: + * T : the data type of src + * NX: the cols of src, dst + * NY: in this function NY only can be 1 + * BlockSize: the config of this device + * ShapeSize: the shape size of out. 
eg in[1, 35], out[32, 35] then shape size + * is 2 + * IsBoundary: whether to make boundary judgment * @param: - * src: the source pointer - * dst: the dst pointer - * stride_nx: the stride of src - * stride_ny: the stride of src - * the shape of dst is [NY, NX] + * fix: data offset of this block, blockDim.x * blockIdx.x * NX; + * config: get the global index in src, attention config was declared in host; + * num: the num of out + * stride_nx: the stride of cols + * stride_ny: the stride of rows */ -template +template __device__ __forceinline__ void ReadDataBc( T* dst, const T* __restrict__ src, uint32_t fix, details::BroadcastConfig config, int num, int stride_nx, @@ -162,53 +309,130 @@ __device__ __forceinline__ void ReadDataBc( #pragma unroll for (uint32_t nx = 0; nx < NX; ++nx) { uint32_t idx = base_offset + ny * stride_ny + nx * stride_nx; - if (idx < num) { - offset = 0; -#pragma unroll - for (int i = 0; i < ShapeSize; ++i) { - auto fast_divmoder = config.divmoders[i].Divmod(idx); - idx = fast_divmoder.val[0]; - offset += fast_divmoder.val[1] * config.strides[i]; + if (IsBoundary) { + if (idx >= num) { + break; } - dst[nx + ny * NX] = src[offset]; } + offset = 0; +#pragma unroll + for (int i = 0; i < ShapeSize; ++i) { + auto fast_divmoder = config.divmoders[i].Divmod(idx); + idx = fast_divmoder.val[0]; + offset += fast_divmoder.val[1] * config.strides[i]; + } + dst[nx + ny * NX] = src[offset]; } } } -template -__device__ __forceinline__ void WriteDataBase(T* dst, const T* __restrict__ src, - int size) { - int dx = threadIdx.x * NX; +/** + * @brief: read data for broadcast + * @typename: + * T : the data type of src + * NX: the cols of src, dst + * NY: in this function NY only can be 1 + * BlockSize: the config of this device + * ShapeSize: the shape size of out. 
eg in[1, 35], out[32, 35] then shape size + * is 2 + * IndexCal: get the global index in src, attention config was declared in host; + * IsBoundary: whether to make boundary judgment + * @param: + * fix: data offset of this block, blockDim.x * blockIdx.x * NX; + * index_cal: get the global index in src, attention config was declared in + * host; + * size_nx: number of columns to be processed by the current block + * size_ny: number of rows to be processed by the current block + * stride_nx: the stride of cols + * stride_ny: the stride of rows + * reduce_last_dim: according to the block split set threadIdx + */ +template +__device__ __forceinline__ void ReadDataReduce( + T* dst, const T* __restrict__ src, int fix, const IndexCal& index_cal, + int size_nx, int size_ny, int stride_nx, int stride_ny, + bool reduce_last_dim) { + int base_offset = fix; + if (reduce_last_dim) { + base_offset += threadIdx.x; + } else { + base_offset += threadIdx.y; + } + + if (NX == 1) { #pragma unroll - for (int idx = 0; idx < NX; ++idx) { - if ((idx + dx) >= size) { - break; + for (int ny = 0; ny < NY; ++ny) { + if (IsBoundary) { + if (base_offset >= size_ny) { + break; + } + } + uint32_t offset = index_cal(base_offset); + dst[ny] = src[offset]; + base_offset += stride_ny; + } + } else { +#pragma unroll + for (int nx = 0; nx < NX; ++nx) { + if (IsBoundary) { + if (nx * stride_nx >= size_nx) { + break; + } + } +#pragma unroll + for (int ny = 0; ny < NY; ++ny) { + if (IsBoundary) { + if (nx * stride_nx >= size_nx) { + break; + } + } + uint32_t offset = index_cal(base_offset); + dst[nx + ny * NX] = src[offset]; + base_offset += stride_ny; + } } - dst[idx + dx] = src[idx]; } } -template +/** @brief: WriteData + * @brief store data from src to dst, src can be 1D data, you should set NY = 1. + * When boundary judgment is required, you need to set a to true, and a is false + * by default. + * @typename: + * T : the data type of src + * NX: the cols of src, dst + * NY: in this function NY only can be 1 + * BlockSize: the config of this device + * IsBoundary: whether to make boundary judgment + * @param: + * num: number of columns to be processed by the current block + */ +template __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src, - int size) { - const int VECTOR_SIZE = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1; - const int VECTORS_PER_THREAD = NX / VECTOR_SIZE; - - // Vector per thread - if (blockDim.x * NX > size) { - WriteDataBase(dst, src, size); + int num) { + if (IsBoundary) { + int dx = threadIdx.x * NX; +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { + if ((idx + dx) < num) { + dst[idx + dx] = src[idx]; + } + } } else { // Vector type - using VecType = details::VectorType; - VecType vec_temp[VECTORS_PER_THREAD]; + const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 
2 : 1; + const int kVectorsPerThread = NX / kVectorSize; + + int dx = threadIdx.x * kVectorsPerThread; + using VecType = details::VectorType; + VecType* vec_dst = reinterpret_cast(dst); + VecType vec_temp[kVectorsPerThread]; #pragma unroll - for (int idx = 0; idx < VECTORS_PER_THREAD; ++idx) { + for (int idx = 0; idx < kVectorsPerThread; ++idx) { vec_temp[idx] = *(reinterpret_cast(src) + idx); + vec_dst[dx + idx] = vec_temp[idx]; } - VecType* vec_dst = reinterpret_cast(dst); - WriteDataBase( - vec_dst, vec_temp, VECTORS_PER_THREAD * blockDim.x); } } diff --git a/paddle/fluid/operators/kernel_primitives/helper_primitives.h b/paddle/fluid/operators/kernel_primitives/helper_primitives.h index 1d23cfe007558..68fd48e97e31a 100644 --- a/paddle/fluid/operators/kernel_primitives/helper_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/helper_primitives.h @@ -16,6 +16,65 @@ namespace paddle { namespace operators { -namespace kernel_primitives {} +namespace kernel_primitives { +namespace details { + +static __device__ __forceinline__ platform::float16 ExpFunctor( + platform::float16 x) { + return ::Eigen::numext::exp(x); } +static __device__ __forceinline__ float ExpFunctor(float x) { return expf(x); } +static __device__ __forceinline__ double ExpFunctor(double x) { return exp(x); } +static __device__ __forceinline__ platform::float16 LogFunctor( + platform::float16 x) { + return ::Eigen::numext::log(x); } +static __device__ __forceinline__ float LogFunctor(float x) { return logf(x); } +static __device__ __forceinline__ double LogFunctor(double x) { return log(x); } + +} // namespace details +/*************************** Compute Functor****************************/ +// for margin_cross_entropy +template +struct ExpLogitTransformer { + HOSTDEVICE explicit inline ExpLogitTransformer(int n) {} + + HOSTDEVICE inline Ty operator()(const Tx* x) const { + return static_cast(details::ExpFunctor(x[0])); + } + + HOSTDEVICE inline Ty operator()(const Tx& x) const { + return static_cast(details::ExpFunctor(x)); + } +}; + +// Post processing function for sum, max, min, prod, any +template +struct IdentityFunctor { + HOSTDEVICE explicit inline IdentityFunctor(int n) {} + + HOSTDEVICE inline Ty operator()(const Tx* x) const { + return static_cast(x[0]); + } + + HOSTDEVICE inline Ty operator()(const Tx& x) const { + return static_cast(x); + } +}; + +// Post processing function for mean +template +struct DivideFunctor { + HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {} + + HOSTDEVICE inline T operator()(const T* x) const { return x[0] * n_inv; } + + HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; } + + private: + T n_inv; +}; + +} // namespace kernel_primitives +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/margin_cross_entropy_op.cu b/paddle/fluid/operators/margin_cross_entropy_op.cu index ccdba43b0542d..4b63dc5e8527e 100644 --- a/paddle/fluid/operators/margin_cross_entropy_op.cu +++ b/paddle/fluid/operators/margin_cross_entropy_op.cu @@ -128,39 +128,9 @@ __global__ void AddMarginToPositiveLogitsKernel( } } -static __device__ __forceinline__ platform::float16 exp_on_device( - platform::float16 x) { - return ::Eigen::numext::exp(x); -} -static __device__ __forceinline__ float exp_on_device(float x) { - return expf(x); -} -static __device__ __forceinline__ double exp_on_device(double x) { - return exp(x); -} -static __device__ __forceinline__ platform::float16 log_on_device( - platform::float16 x) { - return 
::Eigen::numext::log(x); -} -static __device__ __forceinline__ float log_on_device(float x) { - return logf(x); -} -static __device__ __forceinline__ double log_on_device(double x) { - return log(x); -} - -template -struct ExpLogitTransformer { - HOSTDEVICE explicit inline ExpLogitTransformer(int n) {} - - HOSTDEVICE inline Ty operator()(const Tx& x) const { - return static_cast(exp_on_device(x)); - } -}; - template struct ExpAndSum { - using Transformer = ExpLogitTransformer; + using Transformer = kpds::ExpLogitTransformer; inline Ty initial() { return static_cast(0.0f); } @@ -189,7 +159,7 @@ __global__ void LogitsMinusLogSumKernel(T* logits, const T* logits_sum_per_row, const int64_t N, const int64_t D) { CUDA_KERNEL_LOOP(i, N * D) { auto row = i / D; - logits[i] -= log_on_device(logits_sum_per_row[row]); + logits[i] -= kpds::details::LogFunctor(logits_sum_per_row[row]); } } @@ -204,9 +174,9 @@ __global__ void HardLabelSoftmaxWithCrossEntropyKernel( if ((col + start_index) == labels[row]) { auto softmax = log_softmax[i]; loss[row] = -softmax; - log_softmax[i] = exp_on_device(softmax); + log_softmax[i] = kpds::details::ExpFunctor(softmax); } else { - log_softmax[i] = exp_on_device(log_softmax[i]); + log_softmax[i] = kpds::details::ExpFunctor(log_softmax[i]); } } } diff --git a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h b/paddle/fluid/operators/reduce_ops/reduce_functor_op.h index 0f02be21cc907..637ed2dcc2bba 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_functor_op.h @@ -24,9 +24,11 @@ limitations under the License. */ namespace paddle { namespace operators { +namespace kpds = paddle::operators::kernel_primitives; + template struct CustomMin { - using Transformer = detail::IdentityFunctor; + using Transformer = kpds::IdentityFunctor; inline Ty initial() { return static_cast(std::numeric_limits::max()); @@ -39,7 +41,7 @@ struct CustomMin { template struct CustomMax { - using Transformer = detail::IdentityFunctor; + using Transformer = kpds::IdentityFunctor; inline Ty initial() { return static_cast(std::numeric_limits::lowest()); @@ -53,7 +55,7 @@ struct CustomMax { // for cub::Reduce template struct CustomSum { - using Transformer = detail::IdentityFunctor; + using Transformer = kpds::IdentityFunctor; inline Ty initial() { return static_cast(0.0f); } @@ -64,7 +66,7 @@ struct CustomSum { template struct CustomMean { - using Transformer = detail::DivideFunctor; + using Transformer = kpds::DivideFunctor; inline Ty initial() { return static_cast(0.0f); } @@ -75,7 +77,7 @@ struct CustomMean { template struct CustomMul { - using Transformer = detail::IdentityFunctor; + using Transformer = kpds::IdentityFunctor; inline Ty initial() { return static_cast(1.0f); } @@ -86,7 +88,7 @@ struct CustomMul { template struct CustomLogicalOr { - using Transformer = detail::IdentityFunctor; + using Transformer = kpds::IdentityFunctor; inline Ty initial() { return static_cast(false); } @@ -97,7 +99,7 @@ struct CustomLogicalOr { template struct CustomLogicalAnd { - using Transformer = detail::IdentityFunctor; + using Transformer = kpds::IdentityFunctor; inline Ty initial() { return static_cast(true); } diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h index 30b1cf5ac711d..a5f51896ad924 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h @@ -34,6 +34,7 @@ namespace cub = hipcub; #include 
"paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/fast_divmod.h" @@ -43,28 +44,10 @@ namespace cub = hipcub; namespace paddle { namespace operators { -namespace detail { - -// Post processing function for sum, max, min, prod, any -template -struct IdentityFunctor { - HOSTDEVICE explicit inline IdentityFunctor(int n) {} - - HOSTDEVICE inline Ty operator()(const Tx& x) const { - return static_cast(x); - } -}; -// Post processing function for mean -template -struct DivideFunctor { - HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {} +namespace kps = paddle::operators::kernel_primitives; - HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; } - - private: - T n_inv; -}; +namespace detail { static inline int GetLastPow2(int n) { n |= (n >> 1); @@ -90,17 +73,10 @@ static inline std::vector GetDimStrides(const std::vector& dims, return strides; } -#ifdef __HIPCC__ -constexpr int kMaxThread = 256; -constexpr int kWarpSize = 64; -#else -constexpr int kMaxThread = 128; -constexpr int kWarpSize = 32; -#endif - // get blockDim for reduceLastDim and reduceAny static inline int GetBlockDim(int block_dim) { - return block_dim >= kMaxThread ? kMaxThread : GetLastPow2(block_dim); + return block_dim >= kps::details::kMaxThread ? kps::details::kMaxThread + : GetLastPow2(block_dim); } // check reduce rand is valid @@ -167,7 +143,7 @@ struct IndexCalculator { detail::VectorToArray(cal_divmoders); } - __device__ inline int Get(int offset) const { + __device__ inline int operator()(int offset) const { int index = 0; #pragma unroll for (int i = 0; i < kMaxRank; ++i) { @@ -187,6 +163,15 @@ struct IndexCalculator { framework::Array divmoders; }; +// when reduce_type == kReduceLastDim this struct will be used +// for higher performance +struct LastDimIndexCal { + explicit LastDimIndexCal(int num) : stride(num) {} + + __device__ inline int operator()(int index) const { return index * stride; } + int stride; +}; + // reduce config template struct ReduceConfig { @@ -307,7 +292,7 @@ struct ReduceConfig { left_dim.assign(left_set.begin(), left_set.end()); // if the last dim gets involved in reduction - reduce_lastdim = (reduce_dim.back() == x_dim.size() - 1); + reduce_last_dim = (reduce_dim.back() == x_dim.size() - 1); } // set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny @@ -354,19 +339,19 @@ struct ReduceConfig { void SetBlockDimForReduceAny(dim3* block_dim, dim3* grid_dim) { constexpr int min_reduce_num_per_thread = 16; constexpr int max_reduce_num_per_thread = 256; - constexpr int max_num_threads = detail::kMaxThread; + constexpr int max_num_threads = kps::details::kMaxThread; // set block size. - // 1. If reduce_lastdim == true, all the threads whose threadIdx.y are same + // 1. If reduce_last_dim == true, all the threads whose threadIdx.y are same // will process the reduction for one output. // The number of output for one block is blockDim.y; - // 2. If reduce_lastdim == false, different threadIdx.x will process + // 2. If reduce_last_dim == false, different threadIdx.x will process // different reduction and gets the output separately. If it is // necessary, it should reduce in block y. 
// The number of output for one block is blockDim.x; int block_x, block_y; int grid_num, reduce_num_per_thread; - if (reduce_lastdim) { + if (reduce_last_dim) { block_x = detail::GetBlockDim(reduce_num); block_y = detail::GetBlockDim(left_num); block_dim->x = block_x; @@ -459,13 +444,13 @@ struct ReduceConfig { should_reduce_again = true; - block_dim.x = 32; + block_dim.x = detail::GetBlockDim(left_num); block_dim.y = 1; grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x; grid_dim.y = (reduce_num + blocking_size - 1) / blocking_size; } else { - block_dim.x = 32; + block_dim.x = detail::GetBlockDim(left_num); block_dim.y = 1; blocking_size = reduce_num; grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x; @@ -493,258 +478,197 @@ struct ReduceConfig { int left_num; int blocking_size; bool should_reduce_again; - bool reduce_lastdim; + bool reduce_last_dim; Ty* output_data; dim3 block; dim3 grid; }; - -static __device__ int SharedMemoryIndex(int index) { - return (threadIdx.y + index) * blockDim.x + threadIdx.x; -} - -template -static __device__ T WarpReduce(T val, ReduceOp reducer) { - unsigned mask = 0u; - CREATE_SHFL_MASK(mask, true); - for (int stride = detail::kWarpSize / 2; stride > 0; stride >>= 1) { - T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride); - val = reducer(val, temp); - } - return val; -} - -/* e.g. - * |---------block---------| - * |warp0|warp1|warp2|warp3| - * |0~31|32~63|64~95|96~127| ---->blockDim.x = 128 - * \|/ \|/ \|/ \|/ ---->1. First WarpReduce in each warp - * res0 res1 res2 res3 ---->2. Store result of each warp to shared memory - * \ \ / / ---->3. Load the result above from shared memory - * res to warp0 and process the second WarpReduce +/* size : how many colonms left have to be reduced + * loop : how many rows data have to be reduced + * block_size: max rows this block to reduce */ -template -static __device__ T BlockXReduce(T val, ReduceOp reducer) { - using detail::kWarpSize; - __shared__ T shared[2 * kWarpSize]; - int block_dim_x = blockDim.x; - if (blockDim.x > kWarpSize) { - block_dim_x = blockDim.x / kWarpSize; - int lane = threadIdx.x % kWarpSize; - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int wid = tid / kWarpSize; - int bid = threadIdx.y; - val = WarpReduce(val, reducer); - if (lane == 0) { - shared[wid] = val; - } - __syncthreads(); - val = shared[bid * block_dim_x + lane]; - } - - unsigned mask = 0u; - CREATE_SHFL_MASK(mask, true); - for (int stride = 1; stride < block_dim_x; stride <<= 1) { - T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride); - val = reducer(val, temp); - } - return val; -} - -template -static __device__ T BlockYReduce(T val, ReduceOp reducer) { - __shared__ T shared_memory[detail::kMaxThread]; - shared_memory[SharedMemoryIndex(0)] = val; - for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) { - __syncthreads(); - if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) { - T temp = shared_memory[SharedMemoryIndex(stride)]; - val = reducer(val, temp); - } - shared_memory[SharedMemoryIndex(0)] = val; - } - return val; -} - -// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this -// function will be used -// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1 -// if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32 -// else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32 template -__device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer, - TransformOp transformer, MPType init, - int reduce_num, int left_num, 
int block_size) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; + typename TransformOp, bool IsBoundary = false> +__device__ void HigherDimDealSegment(const Tx* x, Ty* y, ReduceOp reducer, + TransformOp transformer, MPType init, + int reduce_num, int left_num, + int block_size) { + const int NY = 1; + int idx = blockIdx.x * blockDim.x; int idy = blockIdx.y * block_size; - - MPType reduce_var = init; - - if (idx < left_num) { - int loop = reduce_num - idy; - loop = loop > block_size ? block_size : loop; - - for (int iy = 0; iy < loop; iy++) { - int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num; - reduce_var = reducer(reduce_var, static_cast(transformer(x[id]))); - } - - y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] = - static_cast(reduce_var); + // block_offset of rows + Tx reduce_input[NY]; + MPType reduce_compute[NY]; + MPType result = init; + // the offset of this block + int block_offset = idy * left_num + idx + blockIdx.z * reduce_num * left_num; + const Tx* input = x + block_offset; + int store_offset = + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num + idx; + // how many columns left + int size = left_num - idx; + // how many rows have to be reduced + int loop = reduce_num - idy; + loop = loop > block_size ? block_size : loop; + + for (int loop_index = 0; loop_index < loop; loop_index += NY) { + kps::ReadData( + &reduce_input[0], input + loop_index * left_num, size, NY, 1, left_num); + kps::ElementwiseUnary( + &reduce_compute[0], &reduce_input[0], transformer); + kps::Reduce( + &result, &reduce_compute[0], reducer, false); } + + Ty temp_data = static_cast(result); + kps::WriteData(y + store_offset, &temp_data, size); } // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this // function will be used template -__device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer, - TransformOp transformer, MPType init, int reduce_num, - int left_num, bool reduce_lastdim, - const IndexCalculator& reduce_index_calculator, - const IndexCalculator& left_index_calculator) { + typename TransformOp, typename Calculator> +__global__ void ReduceAnyKernel(const Tx* x, Ty* y, ReduceOp reducer, + TransformOp transformer, MPType init, + int reduce_num, int left_num, + bool reduce_last_dim, + const Calculator reduce_index_calculator, + const Calculator left_index_calculator) { int input_idx, left_idx, stride; + int block_size = 0; + bool need_store = true; + int tid = 0; // the last dim gets involved in reduction - if (reduce_lastdim) { - input_idx = blockIdx.y * blockDim.x + threadIdx.x; + if (reduce_last_dim) { + input_idx = blockIdx.y * blockDim.x; left_idx = blockIdx.x * blockDim.y + threadIdx.y; stride = gridDim.y * blockDim.x; + block_size = blockDim.x; + need_store = (threadIdx.x == 0) && (left_idx < left_num); + tid = threadIdx.x; } else { - input_idx = blockIdx.y * blockDim.y + threadIdx.y; + input_idx = blockIdx.y * blockDim.y; left_idx = blockIdx.x * blockDim.x + threadIdx.x; stride = gridDim.y * blockDim.y; + block_size = blockDim.y; + need_store = (threadIdx.y == 0) && (left_idx < left_num); + tid = threadIdx.y; } + int store_offset = blockIdx.y * left_num + left_idx; // calculate the offset, means the addr where each thread really start. - int input_offset = left_index_calculator.Get(left_idx); + int input_offset = left_index_calculator(left_idx); const Tx* input = x + input_offset; MPType reduce_var = init; + Ty store_data; // 1. 
reduce for each thread if (left_idx < left_num) { // load REDUCE_VEC_SIZE data once, and then compute Tx input_reg[REDUCE_VEC_SIZE]; + MPType input_compute[REDUCE_VEC_SIZE]; int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride; - while (input_idx < bound) { -#pragma unroll - for (int i = 0; i < REDUCE_VEC_SIZE; ++i) { - int reduce_idx = input_idx + i * stride; - int idx_x = reduce_index_calculator.Get(reduce_idx); - input_reg[i] = input[idx_x]; - } -#pragma unroll - for (int i = 0; i < REDUCE_VEC_SIZE; ++i) { - reduce_var = - reducer(reduce_var, static_cast(transformer(input_reg[i]))); - } - input_idx += REDUCE_VEC_SIZE * stride; + for (; input_idx + block_size < bound; + input_idx += REDUCE_VEC_SIZE * stride) { + kps::ReadDataReduce( + &input_reg[0], input, input_idx, reduce_index_calculator, 1, + reduce_num, 1, stride, reduce_last_dim); + kps::ElementwiseUnary( + &input_compute[0], &input_reg[0], transformer); + kps::Reduce( + &reduce_var, &input_compute[0], reducer, reduce_last_dim); } - // deal with the remain part - int input_idx_tmp = input_idx; -#pragma unroll - for (int i = 0; i < REDUCE_VEC_SIZE; ++i) { - if (input_idx >= reduce_num) { - break; - } - int reduce_idx = input_idx; - int idx_x = reduce_index_calculator.Get(reduce_idx); - input_reg[i] = input[idx_x]; - input_idx += stride; - } - input_idx = input_idx_tmp; + kps::Init(&input_compute[0], init); + kps::ReadDataReduce( + &input_reg[0], input, input_idx, reduce_index_calculator, 1, reduce_num, + 1, stride, reduce_last_dim); + input_idx += tid; #pragma unroll for (int i = 0; i < REDUCE_VEC_SIZE; ++i) { if (input_idx >= reduce_num) { break; } - reduce_var = - reducer(reduce_var, static_cast(transformer(input_reg[i]))); + input_compute[i] = static_cast(transformer(input_reg[i])); input_idx += stride; } + kps::Reduce( + &reduce_var, &input_compute[0], reducer, reduce_last_dim); } - // 2. reduce in block y - if (!reduce_lastdim && blockDim.y > 1) { - reduce_var = BlockYReduce(reduce_var, reducer); - } - __syncthreads(); - - if (reduce_lastdim) { - // 3. 
reduce in block x - reduce_var = BlockXReduce(reduce_var, reducer); - if (left_idx < left_num && threadIdx.x == 0) { - y[blockIdx.y * left_num + left_idx] = static_cast(reduce_var); - } - } else { - if (left_idx < left_num && threadIdx.y == 0) { - y[blockIdx.y * left_num + left_idx] = static_cast(reduce_var); - } + kps::Reduce( + &reduce_var, &reduce_var, reducer, reduce_last_dim); + if (need_store) { + y[store_offset] = static_cast(reduce_var); } } -// module function designed for global function template -__device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer, - TransformOp transformer, MPType init, - int reduce_num, int left_num, int blocking_size, - int reduce_type, bool reduce_lastdim, - const IndexCalculator& reduce_index_calculator, - const IndexCalculator& left_index_calculator) { - if (reduce_type == ReduceType::kReduceLastDim || - reduce_type == ReduceType::kReduceAny) { - ReduceAny( - x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim, - reduce_index_calculator, left_index_calculator); - // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1 - } else if (reduce_type == ReduceType::kReduceHigherDim) { - ReduceHigherDim( +__global__ void ReduceHigherDimKernel(const Tx* x, Ty* y, ReduceOp reducer, + TransformOp transformer, MPType init, + int reduce_num, int left_num, + int blocking_size) { + // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this + // function will be used + // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1 + // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / + // 32 + // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32 + int idx = blockIdx.x * blockDim.x; + int size = left_num - idx; + if (size >= blockDim.x) { // complete segment + HigherDimDealSegment( + x, y, reducer, transformer, init, reduce_num, left_num, blocking_size); + } else { + HigherDimDealSegment( x, y, reducer, transformer, init, reduce_num, left_num, blocking_size); } } -template -__global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer, - TransformOp transformer, MPType init, - int reduce_num, int left_num, - int blocking_size, int reduce_type, - bool reduce_lastdim, - IndexCalculator reduce_index_calculator, - IndexCalculator left_index_calculator) { - ReduceModule( - x, y, reducer, transformer, init, reduce_num, left_num, blocking_size, - reduce_type, reduce_lastdim, reduce_index_calculator, - left_index_calculator); -} - template static void LaunchReduceKernel(const Tx* x_data, Ty* y_data, const ReduceOp& reducer, MPType init, gpuStream_t stream, ReduceConfig config) { using TransformOp = typename ReduceOp::Transformer; - int reduce_rank = config.reduce_strides.size(); - int left_rank = config.left_strides.size(); - auto reduce_index_calculator = IndexCalculator( - reduce_rank, config.reduce_dim, config.reduce_strides, config.x_strides); - auto left_index_calculator = IndexCalculator( - left_rank, config.left_dim, config.left_strides, config.x_strides); - - ReduceKernelFunction<<>>( - x_data, config.output_data, reducer, TransformOp(config.reduce_num), init, - config.reduce_num, config.left_num, config.blocking_size, - config.reduce_type, config.reduce_lastdim, reduce_index_calculator, - left_index_calculator); + + if (config.reduce_type == kReduceLastDim) { + int stride_reduce = 1; + int stride_left = config.reduce_num; + // for higher performance + auto reduce_index_calculator = LastDimIndexCal(stride_reduce); + auto left_index_calculator = LastDimIndexCal(stride_left); + + 
ReduceAnyKernel<<>>( + x_data, config.output_data, reducer, TransformOp(config.reduce_num), + init, config.reduce_num, config.left_num, config.reduce_last_dim, + reduce_index_calculator, left_index_calculator); + + } else { + int reduce_rank = config.reduce_strides.size(); + int left_rank = config.left_strides.size(); + auto reduce_index_calculator = + IndexCalculator(reduce_rank, config.reduce_dim, config.reduce_strides, + config.x_strides); + auto left_index_calculator = IndexCalculator( + left_rank, config.left_dim, config.left_strides, config.x_strides); + ReduceAnyKernel<<>>( + x_data, config.output_data, reducer, TransformOp(config.reduce_num), + init, config.reduce_num, config.left_num, config.reduce_last_dim, + reduce_index_calculator, left_index_calculator); + } if (config.should_reduce_again) { dim3 block; dim3 grid; - if (config.reduce_lastdim) { + if (config.reduce_last_dim) { block = dim3(32, 1, 1); grid = dim3(detail::AlignUp(config.left_num, 32), 1, 1); } else { @@ -752,13 +676,12 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data, grid = dim3(config.grid.x, 1, config.grid.z); } - ReduceKernelFunction< + ReduceHigherDimKernel< Ty, Ty, MPType, ReduceOp, - detail::IdentityFunctor><<>>( + kps::IdentityFunctor><<>>( config.output_data, y_data, reducer, - detail::IdentityFunctor(config.grid.y), init, config.grid.y, - config.left_num, config.grid.y, ReduceType::kReduceHigherDim, - config.reduce_lastdim, reduce_index_calculator, left_index_calculator); + kps::IdentityFunctor(config.grid.y), init, config.grid.y, + config.left_num, config.grid.y); } } @@ -812,6 +735,39 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y, using MPType = typename details::MPTypeTrait::Type; auto reducer = ReduceOp(); + // launch ReduceHigherDimKernel + // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this + // function will be used + // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1 + // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / + // 32 + // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32 + if (config.reduce_type == ReduceType::kReduceHigherDim) { + using TransformOp = typename ReduceOp::Transformer; + + ReduceHigherDimKernel< + Tx, Ty, MPType, ReduceOp, + TransformOp><<>>( + x_data, config.output_data, reducer, TransformOp(config.reduce_num), + reducer.initial(), config.reduce_num, config.left_num, + config.blocking_size); + + if (config.should_reduce_again) { + dim3 block = dim3(config.block.x, 1, 1); + dim3 grid = dim3(config.grid.x, 1, config.grid.z); + ReduceHigherDimKernel< + Ty, Ty, MPType, ReduceOp, + kps::IdentityFunctor><<>>( + config.output_data, y_data, reducer, + kps::IdentityFunctor(config.grid.y), reducer.initial(), + config.grid.y, config.left_num, config.grid.y); + } + return; + } + + // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or + // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this + // function will be used LaunchReduceKernel>( x_data, y_data, reducer, reducer.initial(), stream, config); } From 73d13a3fdf3b28c3f105ad13244a078041571a96 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Tue, 31 Aug 2021 01:53:23 +0000 Subject: [PATCH 2/7] update ReduceMode --- paddle/fluid/operators/kernel_primitives/compute_primitives.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/kernel_primitives/compute_primitives.h b/paddle/fluid/operators/kernel_primitives/compute_primitives.h index 
88c73797bbfb4..f7297fab8a897 100644 --- a/paddle/fluid/operators/kernel_primitives/compute_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/compute_primitives.h @@ -248,12 +248,12 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in, * true */ template + ReduceMode Mode> __device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer, bool reduce_last_dim) { int block_index = blockDim.y; - if (ReduceMode == details::ReduceMode::kGlobalMode) { + if (Mode == details::ReduceMode::kGlobalMode) { bool block_reduce_y = (!reduce_last_dim) && (block_index > 1); // when reduce is not required for the last dim, and reduce num has been // split into multiple threads From b0e3fdbd30d46f217a2b89dd8bd153b9c38a64db Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Tue, 31 Aug 2021 02:44:41 +0000 Subject: [PATCH 3/7] update ReduceMode --- paddle/fluid/operators/kernel_primitives/compute_primitives.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/kernel_primitives/compute_primitives.h b/paddle/fluid/operators/kernel_primitives/compute_primitives.h index f7297fab8a897..2314aad029a05 100644 --- a/paddle/fluid/operators/kernel_primitives/compute_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/compute_primitives.h @@ -248,7 +248,7 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in, * true */ template + details::ReduceMode Mode> __device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer, bool reduce_last_dim) { int block_index = blockDim.y; From 2c32248193896e4bf307a6ed2cac768b116b646f Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Thu, 2 Sep 2021 02:59:40 +0000 Subject: [PATCH 4/7] Add API comments and specify variable names --- .../kernel_primitives/compute_primitives.h | 8 +- .../kernel_primitives/datamover_primitives.h | 116 +++++++++--------- .../fluid/operators/reduce_ops/reduce_op.cu.h | 7 +- 3 files changed, 68 insertions(+), 63 deletions(-) diff --git a/paddle/fluid/operators/kernel_primitives/compute_primitives.h b/paddle/fluid/operators/kernel_primitives/compute_primitives.h index 2314aad029a05..58642ef263156 100644 --- a/paddle/fluid/operators/kernel_primitives/compute_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/compute_primitives.h @@ -31,13 +31,15 @@ namespace kernel_primitives { namespace details { #ifdef __HIPCC__ -constexpr int kMaxThread = 256; +constexpr int kReduceMaxThread = 256; constexpr int kWarpSize = 64; #else -constexpr int kMaxThread = 128; +constexpr int kReduceMaxThread = 128; constexpr int kWarpSize = 32; #endif +// kGlobalMode: block reduce, each block gets an output; +// kLocalMode: thread reduce, each thread gets an output; enum ReduceMode { kGlobalMode, kLocalMode }; template @@ -118,7 +120,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) { */ template __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) { - __shared__ T shared_memory[details::kMaxThread]; + __shared__ T shared_memory[details::kReduceMaxThread]; shared_memory[SharedMemoryIndex(0)] = val; for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) { __syncthreads(); diff --git a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h index 2313adf51b799..96088cb9ca003 100644 --- a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h @@ -124,34 +124,35 @@ struct 
BroadcastConfig { template __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src, int stride_nx, int stride_ny) { + int thread_offset = threadIdx.x * NX; + if (NY == 1 && NX == 1) { - dst[0] = static_cast(src[threadIdx.x]); + dst[0] = static_cast(src[thread_offset]); } else if (NX == 1) { - int dx = threadIdx.x; #pragma unroll for (int idy = 0; idy < NY; ++idy) { - dst[idy] = static_cast(src[dx + idy * stride_ny]); + dst[idy] = static_cast(src[thread_offset + idy * stride_ny]); } } else if (NY == 1) { #pragma unroll for (int idx = 0; idx < NX; ++idx) { - dst[idx] = static_cast(src[idx * stride_nx]); + dst[idx] = static_cast(src[thread_offset + idx * stride_nx]); } } else { - int dx = threadIdx.x * NX; #pragma unroll for (int idx = 0; idx < NX; ++idx) { #pragma unroll for (int idy = 0; idy < NY; ++idy) { - dst[idy * NX + idx] = - static_cast(src[idx * stride_nx + dx + idy * stride_ny]); + dst[idy * NX + idx] = static_cast( + src[thread_offset + idx * stride_nx + idy * stride_ny]); } } } } /** - * @brief load data from src to dst, src can be 1D data or 2D data. When + * @brief load data from src to dst with strid, src can be 1D data or 2D data. + * When * boundary judgment is required, you need to set a to true, and a is false by * default. * @typename: @@ -172,17 +173,17 @@ template (src[dx]); + if (left_size_nx > 0) { + dst[0] = static_cast(src[thread_offset]); } } else { - dst[0] = static_cast(src[dx]); + dst[0] = static_cast(src[thread_offset]); } } else if (NX == 1) { // for NX == 1 and NY != 1 #pragma unroll @@ -192,23 +193,23 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src, break; } } - dst[idy] = static_cast(src[dx + idy * stride_ny]); + dst[idy] = static_cast(src[thread_offset + idy * stride_ny]); } } else if (NY == 1) { // for NY == 1 and NX != 1 #pragma unroll for (int idx = 0; idx < NX; ++idx) { if (IsBoundary) { - if (idx >= size) { + if (idx >= left_size_nx) { break; } } - dst[idx] = static_cast(src[idx * stride_nx + dx]); + dst[idx] = static_cast(src[thread_offset + idx * stride_nx]); } } else { // for NX != 1 and NY != 1 #pragma unroll for (int idx = 0; idx < NX; ++idx) { if (IsBoundary) { - if (idx >= size) { + if (idx >= left_size_nx) { break; } } @@ -219,8 +220,8 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src, break; } } - dst[idy * NX + idx] = - static_cast(src[idx * stride_nx + dx + idy * stride_ny]); + dst[idy * NX + idx] = static_cast( + src[thread_offset + idx * stride_nx + idy * stride_ny]); } } } @@ -251,17 +252,17 @@ template __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src, int num) { if (IsBoundary) { // blockDim.x * NX > num - int dx = threadIdx.x * NX; + int thread_offset = threadIdx.x * NX; #pragma unroll for (int idx = 0; idx < NX; ++idx) { - if (idx + dx < num) { - dst[idx] = src[idx + dx]; + if (idx + thread_offset < num) { + dst[idx] = src[thread_offset + idx]; } } } else { // blockDim,x * NX < num const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 
2 : 1;
     const int kVectorsPerThread = NX / kVectorSize;
-    int tid = threadIdx.x * kVectorsPerThread;
+    int thread_offset = threadIdx.x * kVectorsPerThread;
     using VecType = details::VectorType;
     const VecType* vec_input = reinterpret_cast(src);
@@ -269,7 +270,7 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
 #pragma unroll
     for (int i = 0; i < kVectorsPerThread; ++i) {
-      vec_temp[i] = vec_input[i + tid];
+      vec_temp[i] = vec_input[thread_offset + i];
 #pragma unroll
       for (int idx = 0; idx < NX; ++idx) {
         dst[idx] = *(reinterpret_cast(vec_temp) + idx);
@@ -289,39 +290,39 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
  * is 2
  * IsBoundary: whether to make boundary judgment
  * @param:
- * fix: data offset of this block, blockDim.x * blockIdx.x * NX;
+ * block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
  * config: get the global index in src, attention config was declared in host;
- * num: the num of out
+ * total_num_output: total num of output
  * stride_nx: the stride of cols
  * stride_ny: the stride of rows
  */
 template 
 __device__ __forceinline__ void ReadDataBc(
-    T* dst, const T* __restrict__ src, uint32_t fix,
-    details::BroadcastConfig config, int num, int stride_nx,
-    int stride_ny) {
-  uint32_t base_offset = fix + threadIdx.x * NX;
-  uint32_t offset = 0;
+    T* dst, const T* __restrict__ src, uint32_t block_offset,
+    details::BroadcastConfig config, int total_num_output,
+    int stride_nx, int stride_ny) {
+  uint32_t thread_offset = block_offset + threadIdx.x * NX;
+  uint32_t index_src = 0;
 #pragma unroll
   for (int ny = 0; ny < NY; ++ny) {
 #pragma unroll
     for (uint32_t nx = 0; nx < NX; ++nx) {
-      uint32_t idx = base_offset + ny * stride_ny + nx * stride_nx;
+      uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx;
+      index_src = 0;
       if (IsBoundary) {
-        if (idx >= num) {
+        if (index_output >= total_num_output) {
           break;
         }
       }
-      offset = 0;
 #pragma unroll
       for (int i = 0; i < ShapeSize; ++i) {
-        auto fast_divmoder = config.divmoders[i].Divmod(idx);
-        idx = fast_divmoder.val[0];
-        offset += fast_divmoder.val[1] * config.strides[i];
+        auto fast_divmoder = config.divmoders[i].Divmod(index_output);
+        index_output = fast_divmoder.val[0];
+        index_src += fast_divmoder.val[1] * config.strides[i];
       }
-      dst[nx + ny * NX] = src[offset];
+      dst[nx + ny * NX] = src[index_src];
     }
   }
 }
@@ -338,7 +339,7 @@ __device__ __forceinline__ void ReadDataBc(
  * IndexCal: get the global index in src, attention config was declared in host;
  * IsBoundary: whether to make boundary judgment
  * @param:
- * fix: data offset of this block, blockDim.x * blockIdx.x * NX;
+ * block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
  * index_cal: get the global index in src, attention config was declared in
  * host;
  * size_nx: number of columns to be processed by the current block
@@ -350,27 +351,27 @@ __device__ __forceinline__ void ReadDataBc(
 template 
 __device__ __forceinline__ void ReadDataReduce(
-    T* dst, const T* __restrict__ src, int fix, const IndexCal& index_cal,
-    int size_nx, int size_ny, int stride_nx, int stride_ny,
-    bool reduce_last_dim) {
-  int base_offset = fix;
+    T* dst, const T* __restrict__ src, int block_offset,
+    const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
+    int stride_ny, bool reduce_last_dim) {
+  int thread_offset = 0;
   if (reduce_last_dim) {
-    base_offset += threadIdx.x;
+    thread_offset = block_offset + threadIdx.x;
   } else {
-    base_offset += threadIdx.y;
+    thread_offset = block_offset + threadIdx.y;
   }
   if (NX == 1) {
 #pragma unroll
     for (int ny = 0; ny < NY; ++ny) {
       if (IsBoundary) {
-        if (base_offset >= size_ny) {
+        if (thread_offset >= size_ny) {
           break;
         }
       }
-      uint32_t offset = index_cal(base_offset);
-      dst[ny] = src[offset];
-      base_offset += stride_ny;
+      uint32_t index_src = index_cal(thread_offset);
+      dst[ny] = src[index_src];
+      thread_offset += stride_ny;
     }
   } else {
 #pragma unroll
     for 
@@ -387,10 +388,11 @@ __device__ __forceinline__ void ReadDataReduce(
           break;
         }
       }
-      uint32_t offset = index_cal(base_offset);
-      dst[nx + ny * NX] = src[offset];
-      base_offset += stride_ny;
+      uint32_t index_src = index_cal(thread_offset);
+      dst[nx + ny * NX] = src[index_src];
+      thread_offset += stride_ny;
       }
+      thread_offset += stride_nx;
     }
   }
 }
@@ -412,11 +414,11 @@
 template 
 __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
                                           int num) {
   if (IsBoundary) {
-    int dx = threadIdx.x * NX;
+    int thread_offset = threadIdx.x * NX;
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
-      if ((idx + dx) < num) {
-        dst[idx + dx] = src[idx];
+      if ((thread_offset + idx) < num) {
+        dst[thread_offset + idx] = src[idx];
       }
     }
   } else {
@@ -424,14 +426,14 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
     const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
     const int kVectorsPerThread = NX / kVectorSize;
-    int dx = threadIdx.x * kVectorsPerThread;
+    int thread_offset = threadIdx.x * kVectorsPerThread;
     using VecType = details::VectorType;
     VecType* vec_dst = reinterpret_cast(dst);
     VecType vec_temp[kVectorsPerThread];
 #pragma unroll
     for (int idx = 0; idx < kVectorsPerThread; ++idx) {
       vec_temp[idx] = *(reinterpret_cast(src) + idx);
-      vec_dst[dx + idx] = vec_temp[idx];
+      vec_dst[thread_offset + idx] = vec_temp[idx];
     }
   }
 }
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
index a5f51896ad924..efa3c2c59da36 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -75,8 +75,9 @@ static inline std::vector GetDimStrides(const std::vector& dims,
 // get blockDim for reduceLastDim and reduceAny
 static inline int GetBlockDim(int block_dim) {
-  return block_dim >= kps::details::kMaxThread ? kps::details::kMaxThread
-                                               : GetLastPow2(block_dim);
+  return block_dim >= kps::details::kReduceMaxThread
+             ? kps::details::kReduceMaxThread
+             : GetLastPow2(block_dim);
 }
 // check reduce rand is valid
@@ -339,7 +340,7 @@ struct ReduceConfig {
   void SetBlockDimForReduceAny(dim3* block_dim, dim3* grid_dim) {
     constexpr int min_reduce_num_per_thread = 16;
     constexpr int max_reduce_num_per_thread = 256;
-    constexpr int max_num_threads = kps::details::kMaxThread;
+    constexpr int max_num_threads = kps::details::kReduceMaxThread;
     // set block size.
     // 1. If reduce_last_dim == true, all the threads whose threadIdx.y are same

From 840a65222026fc0a5b5226d08c04eea046a0647e Mon Sep 17 00:00:00 2001
From: niuliling123
Date: Sat, 4 Sep 2021 13:57:34 +0000
Subject: [PATCH 5/7] update

---
 .../kernel_primitives/datamover_primitives.h  |  9 ++-
 .../kernel_primitives/helper_primitives.h     |  2 +-
 .../operators/reduce_ops/reduce_functor_op.h  |  2 +-
 .../fluid/operators/reduce_ops/reduce_op.cu.h | 66 +++++++++----------
 4 files changed, 39 insertions(+), 40 deletions(-)

diff --git a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
index 96088cb9ca003..3932ba1502ecb 100644
--- a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
+++ b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
@@ -151,10 +151,9 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
 }
 /**
- * @brief load data from src to dst with strid, src can be 1D data or 2D data.
- * When
- * boundary judgment is required, you need to set a to true, and a is false by
- * default.
+ * @brief load data from src to dst with stride, src can be 1D data or 2D data.
+ * When boundary judgment is required, you need to set a to true, and a is false
+ * by default.
  * @typename:
  * Tx: data type of src
  * Ty: data type of dstt
@@ -397,7 +396,7 @@ __device__ __forceinline__ void ReadDataReduce(
   }
 }
-/** @brief: WriteData
+/**
  * @brief store data from src to dst, src can be 1D data, you should set NY = 1.
  * When boundary judgment is required, you need to set a to true, and a is false
  * by default.
diff --git a/paddle/fluid/operators/kernel_primitives/helper_primitives.h b/paddle/fluid/operators/kernel_primitives/helper_primitives.h
index 68fd48e97e31a..28c226d77ee14 100644
--- a/paddle/fluid/operators/kernel_primitives/helper_primitives.h
+++ b/paddle/fluid/operators/kernel_primitives/helper_primitives.h
@@ -32,7 +32,6 @@ static __device__ __forceinline__ platform::float16 LogFunctor(
 static __device__ __forceinline__ float LogFunctor(float x) { return logf(x); }
 static __device__ __forceinline__ double LogFunctor(double x) { return log(x); }
-}  // namespace details
 /*************************** Compute Functor****************************/
 // for margin_cross_entropy
 template 
@@ -75,6 +74,7 @@ struct DivideFunctor {
   T n_inv;
 };
+}  // namespace details
 }  // namespace kernel_primitives
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h b/paddle/fluid/operators/reduce_ops/reduce_functor_op.h
index 637ed2dcc2bba..bdd84ca153a23 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_functor_op.h
@@ -24,7 +24,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-namespace kpds = paddle::operators::kernel_primitives;
+namespace kpds = paddle::operators::kernel_primitives::details;
 template 
 struct CustomMin {
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
index efa3c2c59da36..4760270caa3c6 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -47,7 +47,7 @@ namespace operators {
 namespace kps = paddle::operators::kernel_primitives;
-namespace detail {
+namespace details {
 static inline int GetLastPow2(int n) {
   n |= (n >> 1);
@@ -117,7 +117,7 @@ static inline paddle::framework::Array VectorToArray(
   return ret;
 }
-}  // namespace detail
+}  // namespace details
 using Tensor = framework::Tensor;
 constexpr int kMaxRank = framework::DDim::kMaxRank;
@@ -133,15 +133,15 @@ struct IndexCalculator {
                   const std::vector& cal_strides,
                   const std::vector& full_strides)
       : dim(dim) {
-    dims = detail::VectorToArray(cal_dims);
-    strides = detail::VectorToArray(full_strides);
+    dims = details::VectorToArray(cal_dims);
+    strides = details::VectorToArray(full_strides);
     std::vector cal_divmoders;
     // fast divmod
     for (auto i : cal_strides) {
       cal_divmoders.push_back(platform::FastDivMod(i));
     }
     divmoders =
-        detail::VectorToArray(cal_divmoders);
+        details::VectorToArray(cal_divmoders);
   }
   __device__ inline int operator()(int offset) const {
@@ -306,9 +306,9 @@ struct ReduceConfig {
       idx_dim.push_back(i);
     }
-    x_strides = detail::GetDimStrides(x_dim, idx_dim);
-    reduce_strides = detail::GetDimStrides(x_dim, reduce_dim);
-    left_strides = detail::GetDimStrides(x_dim, left_dim);
+    x_strides = details::GetDimStrides(x_dim, idx_dim);
+    reduce_strides = details::GetDimStrides(x_dim, reduce_dim);
+    left_strides = details::GetDimStrides(x_dim, left_dim);
     reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];
     left_num = 1;
@@ -353,23 +353,23 @@ struct ReduceConfig {
     int block_x, block_y;
     int grid_num, reduce_num_per_thread;
     if (reduce_last_dim) {
-      block_x = detail::GetBlockDim(reduce_num);
-      block_y = detail::GetBlockDim(left_num);
+      block_x = details::GetBlockDim(reduce_num);
+      block_y = details::GetBlockDim(left_num);
       block_dim->x = block_x;
       block_dim->y =
           std::min(block_y, static_cast(max_num_threads / block_dim->x));
-      grid_num = detail::AlignUp(left_num, block_dim->y);
-      reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->x);
+      grid_num = details::AlignUp(left_num, block_dim->y);
+      reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->x);
     } else {
-      block_x = detail::GetBlockDim(left_num);
-      block_y = detail::GetBlockDim(reduce_num);
+      block_x = details::GetBlockDim(left_num);
+      block_y = details::GetBlockDim(reduce_num);
       block_dim->x = std::min(block_x, 32);
       block_dim->y =
          std::min(block_y, static_cast(max_num_threads / block_dim->x));
       block_dim->x =
          std::min(block_x, static_cast(max_num_threads / block_dim->y));
-      grid_num = detail::AlignUp(left_num, block_dim->x);
-      reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->y);
+      grid_num = details::AlignUp(left_num, block_dim->x);
+      reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->y);
     }
     int device_id = platform::GetCurrentDeviceId();
     int max_mp = platform::GetCUDAMultiProcessors(device_id);
@@ -389,10 +389,10 @@ struct ReduceConfig {
     // the number cannot be larger than max_reduce_num_per_thread, so we
    // choose the maximum between the result above and input_split_num_2.
     int input_split_num_1 =
-        detail::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread);
+        details::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread);
     int input_split_num_2 =
-        detail::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread);
-    int input_split_num_3 = detail::AlignUp(max_num_blocks, grid_num);
+        details::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread);
+    int input_split_num_3 = details::AlignUp(max_num_blocks, grid_num);
     grid_dim->x = grid_num;
     grid_dim->y = std::max(std::min(input_split_num_1, input_split_num_3),
@@ -409,7 +409,7 @@ struct ReduceConfig {
   // for others: block(block_num, 1) , grid(left_num, 1)
   void SetBlockDim() {
     // init
-    int block_num = detail::GetBlockDim(reduce_num);
+    int block_num = details::GetBlockDim(reduce_num);
     should_reduce_again = false;
     dim3 block_dim(block_num, 1);
@@ -435,23 +435,23 @@ struct ReduceConfig {
       int num_block = (max_threads / left_num);
       if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
-        blocking_size = detail::GetLastPow2(reduce_num / num_block);
+        blocking_size = details::GetLastPow2(reduce_num / num_block);
         if (blocking_size <= 1) {
-          blocking_size = detail::GetLastPow2(sqrt(reduce_num));
+          blocking_size = details::GetLastPow2(sqrt(reduce_num));
         } else if (blocking_size * 2 < reduce_num) {
           blocking_size *= 2;
         }
         should_reduce_again = true;
-        block_dim.x = detail::GetBlockDim(left_num);
+        block_dim.x = details::GetBlockDim(left_num);
         block_dim.y = 1;
         grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
         grid_dim.y = (reduce_num + blocking_size - 1) / blocking_size;
       } else {
-        block_dim.x = detail::GetBlockDim(left_num);
+        block_dim.x = details::GetBlockDim(left_num);
         block_dim.y = 1;
         blocking_size = reduce_num;
         grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
@@ -671,7 +671,7 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
     dim3 grid;
     if (config.reduce_last_dim) {
       block = dim3(32, 1, 1);
-      grid = dim3(detail::AlignUp(config.left_num, 32), 1, 1);
+      grid = dim3(details::AlignUp(config.left_num, 32), 1, 1);
     } else {
       block = dim3(config.block.x, 1, 1);
       grid = dim3(config.grid.x, 1, config.grid.z);
@@ -679,10 +679,10 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
     ReduceHigherDimKernel<
         Ty, Ty, MPType, ReduceOp,
-        kps::IdentityFunctor><<>>(
+        kps::details::IdentityFunctor><<>>(
         config.output_data, y_data, reducer,
-        kps::IdentityFunctor(config.grid.y), init, config.grid.y,
-        config.left_num, config.grid.y);
+        kps::details::IdentityFunctor(config.grid.y), init,
+        config.grid.y, config.left_num, config.grid.y);
   }
 }
@@ -756,12 +756,12 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     if (config.should_reduce_again) {
       dim3 block = dim3(config.block.x, 1, 1);
       dim3 grid = dim3(config.grid.x, 1, config.grid.z);
-      ReduceHigherDimKernel<
-          Ty, Ty, MPType, ReduceOp,
-          kps::IdentityFunctor><<>>(
+      ReduceHigherDimKernel,
+                            kps::details::IdentityFunctor<
+                                Ty, MPType>><<>>(
           config.output_data, y_data, reducer,
-          kps::IdentityFunctor(config.grid.y), reducer.initial(),
-          config.grid.y, config.left_num, config.grid.y);
+          kps::details::IdentityFunctor(config.grid.y),
+          reducer.initial(), config.grid.y, config.left_num, config.grid.y);
     }
     return;
   }

From fc36b459e40fb1188d09dac527cb0f521b4106b2 Mon Sep 17 00:00:00 2001
From: niuliling123
Date: Sat, 4 Sep 2021 14:30:03 +0000
Subject: [PATCH 6/7] update

---
 paddle/fluid/operators/margin_cross_entropy_op.cu | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/margin_cross_entropy_op.cu b/paddle/fluid/operators/margin_cross_entropy_op.cu
index 4b63dc5e8527e..93c6232ef4ade 100644
--- a/paddle/fluid/operators/margin_cross_entropy_op.cu
+++ b/paddle/fluid/operators/margin_cross_entropy_op.cu
@@ -159,7 +159,7 @@ __global__ void LogitsMinusLogSumKernel(T* logits, const T* logits_sum_per_row,
                                         const int64_t N, const int64_t D) {
   CUDA_KERNEL_LOOP(i, N * D) {
     auto row = i / D;
-    logits[i] -= kpds::details::LogFunctor(logits_sum_per_row[row]);
+    logits[i] -= kpds::LogFunctor(logits_sum_per_row[row]);
   }
 }
@@ -174,9 +174,9 @@ __global__ void HardLabelSoftmaxWithCrossEntropyKernel(
     if ((col + start_index) == labels[row]) {
       auto softmax = log_softmax[i];
       loss[row] = -softmax;
-      log_softmax[i] = kpds::details::ExpFunctor(softmax);
+      log_softmax[i] = kpds::ExpFunctor(softmax);
     } else {
-      log_softmax[i] = kpds::details::ExpFunctor(log_softmax[i]);
+      log_softmax[i] = kpds::ExpFunctor(log_softmax[i]);
     }
   }
 }

From 51f2f77b2fa575a10d38ea7815f2797eae4cf169 Mon Sep 17 00:00:00 2001
From: niuliling123
Date: Mon, 6 Sep 2021 02:43:13 +0000
Subject: [PATCH 7/7] update detail to details

---
 paddle/fluid/operators/fused/attn_bias_add.cu.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h
index ddefd8964af97..37e7bd9caa67e 100644
--- a/paddle/fluid/operators/fused/attn_bias_add.cu.h
+++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -202,9 +202,9 @@ void SetConfigForColumnReduce(const int max_threads, const int reduce_num,
   int num_block = (max_threads / left_num);
   if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
-    *blocking_size = detail::GetLastPow2(reduce_num / num_block);
+    *blocking_size = details::GetLastPow2(reduce_num / num_block);
     if (*blocking_size <= 1) {
-      *blocking_size = detail::GetLastPow2(sqrt(reduce_num));
+      *blocking_size = details::GetLastPow2(sqrt(reduce_num));
    } else if (*blocking_size * 2 < reduce_num) {
       *blocking_size *= 2;
    }
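
Note on the block sizing used in the reduce_op.cu.h and attn_bias_add.cu.h hunks above: both column-reduce paths pick a tile size with GetLastPow2 and a grid extent with AlignUp. The standalone host-side sketch below illustrates how those two helpers behave; it is not part of the patch. The AlignUp body is an assumption (ceiling division), since its definition is not shown in this part of the series, and only the first smear step of GetLastPow2 is visible in the hunks, so the remaining shifts follow the usual round-down-to-power-of-two pattern.

// Standalone illustration only, not part of the patch series.
#include <algorithm>
#include <cstdio>

// Assumed helper: smallest number of size-b tiles that covers a (ceiling division).
static inline int AlignUp(int a, int b) { return (a + b - 1) / b; }

// Round n down to the largest power of two not exceeding n (clamped to >= 1);
// the first smear step matches the hunk context, the rest is the standard pattern.
static inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

int main() {
  // Example: 1000 elements to reduce, split across 4 blocks, mirroring the
  // column-reduce branch where num_block = max_threads / left_num == 4.
  int reduce_num = 1000;
  int num_block = 4;
  int blocking_size = GetLastPow2(reduce_num / num_block);  // 250 -> 128
  if (blocking_size * 2 < reduce_num) blocking_size *= 2;   // 128 -> 256
  // grid.y is the number of blocking_size tiles needed to cover reduce_num.
  printf("blocking_size = %d, grid.y = %d\n", blocking_size,
         AlignUp(reduce_num, blocking_size));  // prints 256 and 4
  return 0;
}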