
Commit

add api in compute_primitives.h and datamover_primitives.h for elementwise_add
AnnaTrainingG committed Aug 2, 2021
1 parent 0e2c73b commit 5065b23
Showing 4 changed files with 274 additions and 155 deletions.
78 changes: 41 additions & 37 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -465,43 +465,47 @@ void LaunchBroadcastElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary,
platform::errors::InvalidArgument(
"Currently, only Support binary calculation, "
"but received %d input tensors.\n",
static_cast<int>(ET)));
int in_vec_size = 4;
framework::Tensor *out = (*outs)[0];
for (auto *in : ins) {
auto temp_size = GetVectorizedSizeImpl<InT>(in->data<InT>());
in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
: in_vec_size;
}
int out_vec_size = GetVectorizedSizeImpl<OutT>(out->data<OutT>());
int vec_size = std::min(out_vec_size, in_vec_size);

switch (vec_size) {
case 4: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
axis, func);
break;
}
case 2: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out,
axis, func);
break;
}
case 1: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out,
axis, func);
break;
}
default: {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
}
/*
PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary,
platform::errors::InvalidArgument(
"Currently, only Support binary calculation, "
"but received %d input tensors.\n",
static_cast<int>(ET)));
int in_vec_size = 4;
framework::Tensor *out = (*outs)[0];
for (auto *in : ins) {
auto temp_size = GetVectorizedSizeImpl<InT>(in->data<InT>());
in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
: in_vec_size;
}
int out_vec_size = GetVectorizedSizeImpl<OutT>(out->data<OutT>());
int vec_size = std::min(out_vec_size, in_vec_size);
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
axis, func);
switch (vec_size) {
case 4: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
axis, func);
break;
}
case 2: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out,
axis, func);
break;
}
case 1: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out,
axis, func);
break;
}
default: {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
}
*/
}
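For context on the vec_size selection above: a GetVectorizedSizeImpl-style helper typically picks the widest vector width a pointer's alignment allows. A minimal sketch under that assumption (the helper name below is hypothetical, not code from this commit):

#include <cstdint>

// Minimal sketch, assuming alignment-based selection: return the widest
// vector width (4, 2 or 1 elements of T) the pointer supports.
// GetVectorizedSizeSketch is a hypothetical name, not part of this commit.
template <typename T>
inline int GetVectorizedSizeSketch(const T *pointer) {
  auto address = reinterpret_cast<std::uintptr_t>(pointer);
  if (address % (sizeof(T) * 4) == 0) return 4;  // e.g. float4-style loads
  if (address % (sizeof(T) * 2) == 0) return 2;  // e.g. float2-style loads
  return 1;                                      // fall back to scalar access
}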

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
143 changes: 27 additions & 116 deletions paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"

@@ -101,116 +101,42 @@ int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
return vec_size;
}

template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
struct ElementwiseDataWrapper {
OutT *out;
const InT *in0;
const InT *in1;
__device__ ElementwiseDataWrapper(OutT *out, const InT *in0,
const InT *in1 = nullptr)
: out(out), in0(in0), in1(in1) {}

using InVecType = CudaAlignedVector<InT, VecSize>;
using OutVecType = CudaAlignedVector<OutT, VecSize>;

inline __device__ void load_vector(InVecType args[], int idx) {
const InVecType *x_vec = reinterpret_cast<const InVecType *>(in0);
args[0] = x_vec[idx];
if (ET == ElementwiseType::kBinary) {
const InVecType *y_vec = reinterpret_cast<const InVecType *>(in1);
args[1] = y_vec[idx];
}
}

inline __device__ void load_scalar(InT args[], int idx) {
args[0] = in0[idx];
if (ET == ElementwiseType::kBinary) {
args[1] = in1[idx];
}
}

inline __device__ void store_vector(OutVecType res, int idx) {
OutVecType *out_vec = reinterpret_cast<OutVecType *>(out);
out_vec[idx] = res;
}

inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; }
};

template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor>
__device__ inline void VectorizedKernelImpl(
ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
int tid) {
using InVecType = CudaAlignedVector<InT, VecSize>;
__device__ inline void Compute(const InT *__restrict__ in0,
const InT *__restrict__ in1, OutT *out,
Functor func, int size) {
using OutVecType = CudaAlignedVector<OutT, VecSize>;
InVecType ins_vec[ET];
OutVecType out_vec;
InT *ins_ptr[ET];
InT ins[ET];
#pragma unroll
for (int i = 0; i < ET; ++i) {
ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
}
// load
data.load_vector(ins_vec, tid);

// compute
OutVecType *dst = reinterpret_cast<OutVecType *>(out);
InT args[ET][VecSize];
kernel_primitives::read_data<InT, VecSize, 1, 1>(&args[0][0], in0, size);
kernel_primitives::read_data<InT, VecSize, 1, 1>(&args[1][0], in1, size);
InT data[ET];
OutVecType result;
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
#pragma unroll
for (int j = 0; j < ET; ++j) {
ins[j] = ins_ptr[j][i];
data[j] = args[j][i];
}
out_vec.val[i] = func(ins);
result.val[i] = static_cast<OutT>(func(data));
}
// store
data.store_vector(out_vec, tid);
kernel_primitives::write_data<OutVecType, 1, 1, 1>(&result, dst, 0, 0);
}

template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor>
__device__ inline void ScalarKernelImpl(
ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
int start, int remain) {
InT ins[ET];
OutT out;

for (int i = 0; i < remain; ++i) {
int idx = start + i;
// load
data.load_scalar(ins, idx);
// compute
out = func(ins);
// store
data.store_scalar(out, idx);
}
}

template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor>
__global__ void VectorizedKernel(const InT *__restrict__ in0,
const InT *__restrict__ in1, OutT *out,
int size, Functor func) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int remain = size - VecSize * tid;
remain = remain > 0 ? remain : 0;
auto data = ElementwiseDataWrapper<ET, VecSize, InT, OutT>(out, in0, in1);
if (remain >= VecSize) {
VectorizedKernelImpl(data, func, tid);
} else {
ScalarKernelImpl(data, func, tid * VecSize, remain);
}
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
__global__ void ScalarKernel(const InT *__restrict__ in0,
const InT *__restrict__ in1, OutT *out, int size,
Functor func) {
auto data = ElementwiseDataWrapper<ET, 1, InT, OutT>(out, in0, in1);
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int remain = tid < size ? 1 : 0;
ScalarKernelImpl(data, func, tid, remain);
__global__ void ElementVectorized(const InT *__restrict__ in0,
const InT *__restrict__ in1, OutT *out,
int size, Functor func) {
int tid = blockIdx.x * blockDim.x;
int fix = VecSize * tid;
int max_size = blockDim.x * VecSize;
int remain = size - fix;
int num = remain > max_size ? max_size : remain;
num = num > 0 ? num : 0;
Compute<ET, VecSize, InT, OutT, Functor>(in0 + fix, in1 + fix, out + fix,
func, num);
}
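To make the indexing in ElementVectorized concrete (a worked illustration, not code from this commit): with blockDim.x = 256 and VecSize = 4, block 3 has tid = 3 * 256 = 768, so it starts at fix = 4 * 768 = 3072 and covers at most max_size = 256 * 4 = 1024 elements; the last block gets the smaller remainder num = size - fix.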

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
@@ -220,7 +220,7 @@ void LaunchSameDimsElementwiseCudaKernel(
std::vector<framework::Tensor *> *outs, Functor func) {
// calculate the max vec_size for all ins and outs
auto size = ins[0]->numel();
int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
const int vec_size = 4;
int block_size = GetThreadsConfig(ctx, size, vec_size);
int grid_size =
((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
@@ -231,24 +231,8 @@ void LaunchSameDimsElementwiseCudaKernel(
// cuda kernel
auto stream = ctx.stream();

switch (vec_size) {
case 4:
VectorizedKernel<ET, 4><<<grid_size, block_size, 0, stream>>>(
in0, in1, out, size, func);
break;
case 2:
VectorizedKernel<ET, 2><<<grid_size, block_size, 0, stream>>>(
in0, in1, out, size, func);
break;
case 1:
ScalarKernel<ET><<<grid_size, block_size, 0, stream>>>(in0, in1, out,
size, func);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
ElementVectorized<ET, vec_size><<<grid_size, block_size, 0, stream>>>(
in0, in1, out, size, func);
}
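A hedged usage sketch of the same-dims launch path above; CudaAddFunctor and the surrounding setup are assumptions for illustration, not code from this commit. With size = 10000, vec_size = 4 and a block_size of 256, the grid expression evaluates to ((10000 + 3) / 4 + 255) / 256 = 10 blocks.

// Hedged usage sketch -- the functor name and the tensors x, y, z are hypothetical.
template <typename T>
struct CudaAddFunctor {
  inline HOSTDEVICE T operator()(const T *args) const { return args[0] + args[1]; }
};

void AddExample(const platform::CUDADeviceContext &dev_ctx,
                const framework::Tensor &x, const framework::Tensor &y,
                framework::Tensor *z) {
  std::vector<const framework::Tensor *> ins = {&x, &y};
  std::vector<framework::Tensor *> outs = {z};
  LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, float, float>(
      dev_ctx, ins, &outs, CudaAddFunctor<float>());
}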

} // namespace operators
80 changes: 79 additions & 1 deletion paddle/fluid/operators/kernel_primitives/compute_primitives.h
@@ -16,6 +16,84 @@

namespace paddle {
namespace operators {
namespace kernel_primitives {}
namespace kernel_primitives {

/**
 * @brief binary elementwise compute functor; in1 and in2 have the same shape
 * @param:
 * T : the data type of in1 and in2
 * NX: the row size of in1 and in2
 * NY: the col size of in1 and in2
 * BlockSize: the stride of a col
 * OpFunc: compute functor, e.g. ADD, SUB, XOR, OR, MUL
 */
template <typename T, int NX, int NY, int BlockSize, class OpFunc>
__device__ void elementwise_binary(const T* __restrict__ in1,
const T* __restrict__ in2,
T* __restrict__ out) {
OpFunc compute;
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
compute(in1[idx], in2[idx], &out[idx]);
}
}
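As a hedged illustration of the OpFunc contract the loop above assumes (AddFunctor is a hypothetical name, not part of this commit), the functor receives the two operands and a pointer to the output slot:

// Hedged sketch of an OpFunc usable with elementwise_binary; AddFunctor is
// an assumption for illustration only.
template <typename T>
struct AddFunctor {
  __device__ inline void operator()(const T &a, const T &b, T *out) const {
    *out = a + b;
  }
};
// Per-thread call over NX * NY register values, e.g.:
//   elementwise_binary<float, 4, 1, 256, AddFunctor<float>>(in0_regs, in1_regs, out_regs);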

/**
 * @brief fused multiply-add, e.g. out = in1 * in2 + in3; in1, in2, in3 and out have the same shape
 * @param:
 * T : the data type of in1, in2 and in3
 * NX: the row size of in1, in2 and in3
 * NY: the col size of in1, in2 and in3
 * BlockSize: the stride of a col
 */
template <typename T, int NX, int NY, int BlockSize, class OpFunc>
__device__ void elementwise_fma(const T* in1, const T* in2, const T* in3,
T* out) {
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
out[idx] = in1[idx] * in2[idx] + in3[idx];
}
}
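Per the fma description above (out = in1 * in2 + in3), a worked illustration: with NX * NY = 4 and register values in1 = {1, 2, 3, 4}, in2 = {10, 10, 10, 10}, in3 = {5, 5, 5, 5}, the loop writes out = {15, 25, 35, 45}.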

/**
 * @brief binary compute functor with broadcast: in1 holds NX elements, each
 * reused across a column; in2 is [NX, NY]
 * @param:
 * T : the data type of in1 and in2
 * NX: the number of elements in in1, and the row size of in2
 * NY: the col size of in2
 * BlockSize: the stride of a col
 * OpFunc: compute functor, e.g. ADD, SUB, XOR, OR, MUL
 */
template <typename T, int NX, int NY, int BlockSize, class OpFunc>
__device__ void cycle_binary(const T* in1, const T* in2, T* out) {
OpFunc compute;
#pragma unroll
for (int idx = 0; idx < NX; idx++) {
#pragma unroll
for (int idy = 0; idy < NY; idy++) {
compute(in1[idx], in2[idx + idy * NX], out[idx + idy * NX]);
}
}
}
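A hedged functor sketch matching cycle_binary's call pattern (MulFunctor is a hypothetical name, not part of this commit); note the output is taken by reference here, unlike the pointer form used by elementwise_binary:

// Hedged sketch of an OpFunc matching
//   compute(in1[idx], in2[idx + idy * NX], out[idx + idy * NX]);
// MulFunctor is an assumption for illustration only.
template <typename T>
struct MulFunctor {
  __device__ inline void operator()(const T &a, const T &b, T &out) const {
    out = a * b;
  }
};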

/**
 * @brief unary elementwise compute functor; in is [NX, NY]
 * @param:
 * T : the data type of in
 * NX: the row size of in
 * NY: the col size of in
 * BlockSize: the stride of a col
 * OpFunc: compute functor, e.g. relu, sigmoid, exp
 */
template <typename T, int NX, int NY, int BlockSize, class OpFunc>
__device__ void elementwise_unary(const T* in, T* out) {
OpFunc compute;
#pragma unroll
for (int idx = 0; idx < NX * NY; idx++) {
compute(in[idx], out[idx]);
}
}
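A hedged sketch of a unary OpFunc matching the call above (ReluFunctor is a hypothetical name, not part of this commit); the output element is taken by reference because the loop passes out[idx] directly:

// Hedged sketch of an OpFunc usable with elementwise_unary; ReluFunctor is
// an assumption for illustration only.
template <typename T>
struct ReluFunctor {
  __device__ inline void operator()(const T &in, T &out) const {
    out = in > static_cast<T>(0) ? in : static_cast<T>(0);
  }
};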

} // namespace kernel_primitives
} // namespace operators
} // namespace paddle
