diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
index 541ff9aacfc46..e6074bc9aacef 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -465,43 +465,47 @@ void LaunchBroadcastElementwiseCudaKernel(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, int axis, Functor func) {
-  PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary,
-                    platform::errors::InvalidArgument(
-                        "Currently, only Support binary calculation, "
-                        "but received %d input tensors.\n",
-                        static_cast<int>(ET)));
-  int in_vec_size = 4;
-  framework::Tensor *out = (*outs)[0];
-  for (auto *in : ins) {
-    auto temp_size = GetVectorizedSizeImpl(in->data<InT>());
-    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
-                                            : in_vec_size;
-  }
-  int out_vec_size = GetVectorizedSizeImpl(out->data<OutT>());
-  int vec_size = std::min(out_vec_size, in_vec_size);
-
-  switch (vec_size) {
-    case 4: {
-      LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out,
-                                               axis, func);
-      break;
-    }
-    case 2: {
-      LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out,
-                                               axis, func);
-      break;
-    }
-    case 1: {
-      LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out,
-                                               axis, func);
-      break;
-    }
-    default: {
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "Unsupported vectorized size: %d !", vec_size));
-      break;
-    }
-  }
+  /*
+PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary,
+                  platform::errors::InvalidArgument(
+                      "Currently, only Support binary calculation, "
+                      "but received %d input tensors.\n",
+                      static_cast<int>(ET)));
+int in_vec_size = 4;
+framework::Tensor *out = (*outs)[0];
+for (auto *in : ins) {
+auto temp_size = GetVectorizedSizeImpl(in->data<InT>());
+in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
+                                        : in_vec_size;
+}
+int out_vec_size = GetVectorizedSizeImpl(out->data<OutT>());
+int vec_size = std::min(out_vec_size, in_vec_size);
+LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out,
+                                         axis, func);
+
+switch (vec_size) {
+case 4: {
+LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out,
+                                         axis, func);
+break;
+}
+case 2: {
+LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out,
+                                         axis, func);
+break;
+}
+case 1: {
+LaunchBroadcastKernelForDifferentDimSize(ctx, ins, out,
+                                         axis, func);
+break;
+}
+default: {
+PADDLE_THROW(platform::errors::Unimplemented(
+    "Unsupported vectorized size: %d !", vec_size));
+break;
+}
+}
+*/
 }
 
 template
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
index 101512e35fdcb..cb1049ec4fb58 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/fast_divmod.h"
 
@@ -101,116 +102,42 @@ int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
   return vec_size;
 }
 
-template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
-struct ElementwiseDataWrapper {
-  OutT *out;
-  const InT *in0;
-  const InT *in1;
-  __device__ ElementwiseDataWrapper(OutT *out, const InT *in0,
-                                    const InT *in1 = nullptr)
-      : out(out), in0(in0), in1(in1) {}
-
-  using InVecType = CudaAlignedVector<InT, VecSize>;
-  using OutVecType = CudaAlignedVector<OutT, VecSize>;
-
-  inline __device__ void load_vector(InVecType args[], int idx) {
-    const InVecType *x_vec = reinterpret_cast<const InVecType *>(in0);
-    args[0] = x_vec[idx];
-    if (ET == ElementwiseType::kBinary) {
-      const InVecType *y_vec = reinterpret_cast<const InVecType *>(in1);
-      args[1] = y_vec[idx];
-    }
-  }
-
-  inline __device__ void load_scalar(InT args[], int idx) {
-    args[0] = in0[idx];
-    if (ET == ElementwiseType::kBinary) {
-      args[1] = in1[idx];
-    }
-  }
-
-  inline __device__ void store_vector(OutVecType res, int idx) {
-    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out);
-    out_vec[idx] = res;
-  }
-
-  inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; }
-};
-
 template <ElementwiseType ET, int VecSize, typename InT, typename OutT, typename Functor>
-__device__ inline void VectorizedKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
-    int tid) {
-  using InVecType = CudaAlignedVector<InT, VecSize>;
+__device__ inline void Compute(const InT *__restrict__ in0,
+                               const InT *__restrict__ in1, OutT *out,
+                               Functor func, int size) {
   using OutVecType = CudaAlignedVector<OutT, VecSize>;
-  InVecType ins_vec[ET];
-  OutVecType out_vec;
-  InT *ins_ptr[ET];
-  InT ins[ET];
-#pragma unroll
-  for (int i = 0; i < ET; ++i) {
-    ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
-  }
-  // load
-  data.load_vector(ins_vec, tid);
-
-// compute
+  OutVecType *dst = reinterpret_cast<OutVecType *>(out);
+  InT args[ET][VecSize];
+  kernel_primitives::read_data<InT, VecSize, 1, 1>(&args[0][0], in0, size);
+  kernel_primitives::read_data<InT, VecSize, 1, 1>(&args[1][0], in1, size);
+  InT data[ET];
+  OutVecType result;
 #pragma unroll
   for (int i = 0; i < VecSize; ++i) {
 #pragma unroll
     for (int j = 0; j < ET; ++j) {
-      ins[j] = ins_ptr[j][i];
+      data[j] = args[j][i];
     }
-    out_vec.val[i] = func(ins);
+    result.val[i] = static_cast<OutT>(func(data));
   }
-  // store
-  data.store_vector(out_vec, tid);
+  kernel_primitives::write_data<OutVecType, 1, 1, 1>(dst, &result, 0, 0);
 }
 
 template <ElementwiseType ET, int VecSize, typename InT, typename OutT, typename Functor>
-__device__ inline void ScalarKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
-    int start, int remain) {
-  InT ins[ET];
-  OutT out;
-
-  for (int i = 0; i < remain; ++i) {
-    int idx = start + i;
-    // load
-    data.load_scalar(ins, idx);
-    // compute
-    out = func(ins);
-    // store
-    data.store_scalar(out, idx);
-  }
-}
-
-template <ElementwiseType ET, int VecSize, typename InT, typename OutT, typename Functor>
-__global__ void VectorizedKernel(const InT *__restrict__ in0,
-                                 const InT *__restrict__ in1, OutT *out,
-                                 int size, Functor func) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int remain = size - VecSize * tid;
-  remain = remain > 0 ? remain : 0;
-  auto data = ElementwiseDataWrapper<ET, VecSize, InT, OutT>(out, in0, in1);
-  if (remain >= VecSize) {
-    VectorizedKernelImpl(data, func, tid);
-  } else {
-    ScalarKernelImpl(data, func, tid * VecSize, remain);
-  }
-}
-
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
-__global__ void ScalarKernel(const InT *__restrict__ in0,
-                             const InT *__restrict__ in1, OutT *out, int size,
-                             Functor func) {
-  auto data = ElementwiseDataWrapper<ET, 1, InT, OutT>(out, in0, in1);
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int remain = tid < size ? 1 : 0;
-  ScalarKernelImpl(data, func, tid, remain);
+__global__ void ElementVectorized(const InT *__restrict__ in0,
+                                  const InT *__restrict__ in1, OutT *out,
+                                  int size, Functor func) {
+  int tid = blockIdx.x * blockDim.x;
+  int fix = VecSize * tid;
+  int max_size = blockDim.x * VecSize;
+  int remain = size - fix;
+  int num = remain > max_size ? max_size : remain;
+  num = num > 0 ? num : 0;
+  Compute<ET, VecSize>(in0 + fix, in1 + fix, out + fix,
+                       func, num);
 }
 
 template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
@@ -220,7 +147,7 @@ void LaunchSameDimsElementwiseCudaKernel(
     std::vector<framework::Tensor *> *outs, Functor func) {
   // calculate the max vec_size for all ins and outs
   auto size = ins[0]->numel();
-  int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
+  const int vec_size = 4;
   int block_size = GetThreadsConfig(ctx, size, vec_size);
   int grid_size =
       ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
@@ -231,24 +158,8 @@ void LaunchSameDimsElementwiseCudaKernel(
   // cuda kernel
   auto stream = ctx.stream();
 
-  switch (vec_size) {
-    case 4:
-      VectorizedKernel<ET, 4><<<grid_size, block_size, 0, stream>>>(
-          in0, in1, out, size, func);
-      break;
-    case 2:
-      VectorizedKernel<ET, 2><<<grid_size, block_size, 0, stream>>>(
-          in0, in1, out, size, func);
-      break;
-    case 1:
-      ScalarKernel<ET><<<grid_size, block_size, 0, stream>>>(in0, in1, out,
-                                                             size, func);
-      break;
-    default:
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "Unsupported vectorized size: %d !", vec_size));
-      break;
-  }
+  ElementVectorized<ET, 4><<<grid_size, block_size, 0, stream>>>(
+      in0, in1, out, size, func);
 }
 
 }  // namespace operators
diff --git a/paddle/fluid/operators/kernel_primitives/compute_primitives.h b/paddle/fluid/operators/kernel_primitives/compute_primitives.h
index 1d23cfe007558..67e946891feb3 100644
--- a/paddle/fluid/operators/kernel_primitives/compute_primitives.h
+++ b/paddle/fluid/operators/kernel_primitives/compute_primitives.h
@@ -16,6 +16,84 @@
 
 namespace paddle {
 namespace operators {
-namespace kernel_primitives {}
+namespace kernel_primitives {
+
+/**
+ * @brief compute functor for binary elementwise (two inputs); in1 and in2
+ * have the same shape
+ * @param:
+ * T : the type of in1 and in2
+ * NX: the number of rows of in1 and in2
+ * NY: the number of columns of in1 and in2
+ * BlockSize: the stride of a column
+ * OpFunc: compute functor, e.g. ADD, SUB, XOR, OR, MUL
+ */
+template <typename T, int NX, int NY, int BlockSize, class OpFunc>
+__device__ void elementwise_binary(const T* __restrict__ in1,
+                                   const T* __restrict__ in2,
+                                   T* __restrict__ out) {
+  OpFunc compute;
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; ++idx) {
+    compute(in1[idx], in2[idx], &out[idx]);
+  }
+}
+
+/**
+ * @brief fma, e.g. a * b + c; in1, in2, in3 and out have the same shape
+ * @param:
+ * T : the type of in1, in2 and in3
+ * NX: the number of rows of in1, in2 and in3
+ * NY: the number of columns of in1, in2 and in3
+ * BlockSize: the stride of a column
+ */
+template <typename T, int NX, int NY, int BlockSize>
+__device__ void elementwise_fma(const T* in1, const T* in2, const T* in3,
+                                T* out) {
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; ++idx) {
+    out[idx] = in1[idx] * in2[idx] + in3[idx];
+  }
 }
+
+/**
+ * @brief compute functor for binary elementwise where in1 is [1, NY] and
+ * in2 is [NX, NY]
+ * @param:
+ * T : the type of in1 and in2
+ * NX: the number of rows of in1 and in2
+ * NY: the number of columns of in2
+ * BlockSize: the stride of a column
+ * OpFunc: compute functor, e.g. ADD, SUB, XOR, OR, MUL
+ */
+template <typename T, int NX, int NY, int BlockSize, class OpFunc>
+__device__ void cycle_binary(const T* in1, const T* in2, T* out) {
+  OpFunc compute;
+#pragma unroll
+  for (int idx = 0; idx < NX; idx++) {
+#pragma unroll
+    for (int idy = 0; idy < NY; idy++) {
+      compute(in1[idx], in2[idx + idy * NX], out[idx + idy * NX]);
+    }
+  }
 }
+
+/**
+ * @brief compute functor for unary ops; in is [NX, NY]
+ * @param:
+ * T : the type of in
+ * NX: the number of rows of in
+ * NY: the number of columns of in
+ * BlockSize: the stride of a column
+ * OpFunc: compute functor, e.g. relu, sigmoid, exp
+ */
+template <typename T, int NX, int NY, int BlockSize, class OpFunc>
+__device__ void elementwise_unary(const T* in, T* out) {
+  OpFunc compute;
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; idx++) {
+    compute(in[idx], out[idx]);
+  }
+}
+
+}  // namespace kernel_primitives
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
index 1d23cfe007558..37104a0cee75e 100644
--- a/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
+++ b/paddle/fluid/operators/kernel_primitives/datamover_primitives.h
@@ -13,9 +13,135 @@
 // limitations under the License.
 
 #pragma once
+#include
+#include
+#include
+#include
+#include
+#include "paddle/fluid/platform/fast_divmod.h"
 
 namespace paddle {
 namespace operators {
-namespace kernel_primitives {}
+namespace kernel_primitives {
+
+template <typename T, int VecSize>
+struct alignas(sizeof(T) * VecSize) VectorType {
+  T val[VecSize];
+};
+
+template <typename T, int NX, int NY, int BlockSize>
+__device__ void read_data_base(T* dst, const T* __restrict__ src, int size) {
+  int dx = threadIdx.x * NX;
+#pragma unroll
+  for (int idx = 0; idx < NX; ++idx) {
+    if ((idx + dx) >= size) {
+      break;
+    }
+    dst[idx] = src[idx + dx];
+  }
+}
+
+template <typename T, int NX, int NY, int BlockSize>
+__device__ void read_data(T* dst, const T* __restrict__ src, int size) {
+  enum {
+    VECTOR_SIZE = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1,
+    VECTORS_PER_THREAD = NX / VECTOR_SIZE,
+  };
+
+  // Vector per thread
+  if (blockDim.x * NX > size) {
+    read_data_base<T, NX, NY, BlockSize>(dst, src, size);
+  } else {
+    // Vector type
+    using VecType = VectorType<T, VECTOR_SIZE>;
+    VecType vec_temp[VECTORS_PER_THREAD];
+    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
+    read_data_base<VecType, VECTORS_PER_THREAD, NY, BlockSize>(
+        vec_temp, vec_input, VECTORS_PER_THREAD * blockDim.x);
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
+    }
+  }
 }
+
+/** @brief: read_data
+ * read data from the src ptr
+ * @param:
+ * src: the source pointer
+ * dst: the dst pointer
+ * stride_nx: the stride of src along NX
+ * stride_ny: the stride of src along NY
+ * the shape of dst is [NY, NX];
+ */
+template <typename T, int NX, int NY, int BlockSize>
+__device__ void read_data(const T* __restrict__ src, T* dst, int stride_nx,
+                          int stride_ny) {
+  // out[NY][NX];
+  int base_offset = threadIdx.x * NX;
+#pragma unroll
+  for (int idy = 0; idy < NY; idy++) {
+#pragma unroll
+    for (int idx = 0; idx < NX; idx++) {
+      dst[idy * NX + idx] =
+          src[idy * stride_ny + stride_nx * idx + base_offset];
+    }
+  }
 }
+
+/** @brief: read_data_bc
+ * read data from the src ptr when the shapes of src and dst are different
+ * @param:
+ * src: the source pointer
+ * dst: the dst pointer
+ * stride_nx: the stride of src along NX
+ * stride_ny: the stride of src along NY
+ * the shape of dst is [NY, NX]
+ */
+template <typename T, int NX, int NY, int BlockSize, int Shape_Size>
+__device__ __forceinline__ void read_data_bc(
+    const T* __restrict__ src, T* dst, uint32_t fix, FastDivMod* divmoders,
+    uint32_t* strides, uint32_t stride_nx, uint32_t stride_ny) {
+  uint32_t base_offset = fix + threadIdx.x * NX;
+  uint32_t offset = 0;
+#pragma unroll
+  for (int ny = 0; ny < NY; ++ny) {
+#pragma unroll
+    for (uint32_t nx = 0; nx < NX; ++nx) {
+      uint32_t idx = base_offset + ny * stride_ny + nx * stride_nx;
+      offset = 0;
+#pragma unroll
+      for (int i = 0; i < Shape_Size; ++i) {
+        auto fast_divmoder = divmoders[i].Divmod(idx);
+        idx = fast_divmoder.val[0];
+        offset += fast_divmoder.val[1] * strides[i];
+      }
+      dst[nx + ny * NX] = src[offset];
+    }
+  }
+}
+
+/** @brief: write_data
+ * @param:
+ * src: the source pointer
+ * dst: the dst pointer
+ * stride_nx: the stride of dst along NX
+ * stride_ny: the stride of dst along NY
+ * the shape of src is [NY, NX];
+ */
+template <typename T, int NX, int NY, int BlockSize>
+__device__ void write_data(T* dst, const T* src, int stride_nx,
+                           int stride_ny) {
+  uint32_t base_offset = threadIdx.x * NX;
+#pragma unroll
+  for (int idy = 0; idy < NY; idy++) {
+#pragma unroll
+    for (int idx = 0; idx < NX; idx++) {
+      dst[idy * stride_ny + idx * stride_nx + base_offset] =
+          src[idx + idy * NX];
+    }
+  }
+}
+
+}  // namespace kernel_primitives
+}  // namespace operators
+}  // namespace paddle
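
A minimal usage sketch (not part of the patch): it shows how the data-mover and compute primitives introduced above are intended to compose inside a block-tiled elementwise kernel. Each block owns a tile of blockDim.x * NX elements; every thread reads its NX elements into registers with read_data, applies a compute primitive, and writes the tile back with write_data. The functor AddFunctor, the kernel TiledAddKernel and the launch parameters are hypothetical names for illustration, and the sketch assumes the kernel_primitives.h umbrella header added by this PR pulls in both compute_primitives.h and datamover_primitives.h.

// Hypothetical usage sketch, not part of this patch. Assumes the
// kernel_primitives.h umbrella header introduced by this PR re-exports
// compute_primitives.h and datamover_primitives.h.
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"

namespace kp = paddle::operators::kernel_primitives;

// OpFunc in the style expected by elementwise_binary: the result is written
// through the third (pointer) argument.
template <typename T>
struct AddFunctor {
  __device__ void operator()(const T &a, const T &b, T *out) const {
    *out = a + b;
  }
};

// Each thread handles NX consecutive elements of its block's tile. The
// per-thread offset (threadIdx.x * NX) is applied inside the primitives, so
// the kernel only computes the per-block offset. Assumes numel is an exact
// multiple of blockDim.x * NX, because write_data takes no size bound.
template <typename T, int NX>
__global__ void TiledAddKernel(const T *x, const T *y, T *z, int numel) {
  int block_offset = blockIdx.x * blockDim.x * NX;
  int remain = numel - block_offset;
  int num = remain > 0 ? remain : 0;

  T arg0[NX];
  T arg1[NX];
  T result[NX];

  // Global memory -> registers (vectorized when the block's tile is full).
  kp::read_data<T, NX, 1, 1>(arg0, x + block_offset, num);
  kp::read_data<T, NX, 1, 1>(arg1, y + block_offset, num);
  // Register-level compute primitive.
  kp::elementwise_binary<T, NX, 1, 1, AddFunctor<T>>(arg0, arg1, result);
  // Registers -> global memory: stride_nx = 1, stride_ny unused since NY = 1.
  kp::write_data<T, NX, 1, 1>(z + block_offset, result, 1, 0);
}

A launch of the form TiledAddKernel<float, 4><<<numel / (256 * 4), 256>>>(x, y, z, numel) then mirrors what ElementVectorized does for the same-dims elementwise path, with the read, compute and store steps spelled out through the individual primitives.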