Skip to content

Commit

Permalink
merge ET
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG committed Sep 3, 2021
1 parent a68841c commit 487f12a
Showing 1 changed file with 48 additions and 127 deletions.
175 changes: 48 additions & 127 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,128 +180,59 @@ __device__ __forceinline__ void LoadData(
}
}

// Processes one block-sized segment of a broadcast elementwise op: loads the
// (possibly broadcast) inputs into registers, applies the functor, and writes
// the results back to global memory.
//
// Template parameters:
//   ET         - arity of the op (kUnary/kBinary/kTernary); also the length
//                of the `in` pointer array and the first dim of `args`.
//   InT/OutT   - input / output element types.
//   ShapeSize  - rank used by the broadcast index computation.
//   VecSize    - elements handled per thread (vectorization width).
//   Functor    - elementwise functor applied across the ET inputs.
//   IsBoundary - true only for the tail block, where per-element bounds
//                checks are required to avoid out-of-range global accesses.
//
// Arguments:
//   in            - device pointers to the ET inputs.
//   out           - output pointer.
//   use_broadcast - per-input flag: true if that input needs index remapping.
//   numel         - total number of output elements.
//   configlists   - per-input broadcast configs (used when use_broadcast[i]).
//   num           - element count this block must process: blockDim.x *
//                   VecSize for full blocks, tail_tid for the last block.
//   func          - the elementwise functor.
template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
          int VecSize, typename Functor, bool IsBoundary = false>
__device__ void DealSegment(
    const framework::Array<const InT *__restrict__, ET> &in, OutT *out,
    const framework::Array<bool, MAX_INPUT_NUM> &use_broadcast, uint32_t numel,
    const framework::Array<kps::details::BroadcastConfig<ShapeSize>,
                           MAX_INPUT_NUM> &configlists,
    int num, Functor func) {
  InT args[ET][VecSize];
  OutT result[VecSize];
  int block_offset = blockIdx.x * blockDim.x * VecSize;  // this block's offset
  // Load every input. Registers are initialized first so that, in the tail
  // block, lanes past `num` hold a harmless neutral value instead of garbage.
#pragma unroll
  for (int i = 0; i < ET; i++) {
    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
    LoadData<InT, VecSize, ShapeSize, IsBoundary>(args[i], in[i], block_offset,
                                                  configlists[i], numel, num,
                                                  use_broadcast[i]);
  }
  // Compute: dispatch on arity. ET is a compile-time constant, so the dead
  // branches are eliminated by the compiler.
  if (ET == kUnary) {
    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
                                                             func);
  } else if (ET == kBinary) {
    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
                                                              args[1], func);
  } else {
    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], args[1], args[2], func);
  }
  // Store: IsBoundary guards the tail block's partial write.
  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
                                                  num);
}

// Kernel entry point for broadcast elementwise ops of any arity (ET).
//
// Grid layout: a 1-D launch where each block owns a contiguous segment of
// blockDim.x * VecSize output elements. Blocks [0, main_tid) are full;
// block main_tid (the last one, when tail_tid != 0) is a partial segment of
// tail_tid elements and therefore runs with boundary checks enabled so that
// loads and stores cannot run past `numel`.
//
// Arguments:
//   in            - device pointers to the ET inputs.
//   out           - output pointer.
//   use_broadcast - per-input flag: true if that input needs index remapping.
//   numel         - total number of output elements.
//   configlists   - per-input broadcast configs.
//   main_tid      - number of full blocks.
//   tail_tid      - element count of the final partial block (0 if none).
//   func          - the elementwise functor.
template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
          int VecSize, typename Functor>
__global__ void BroadcastKernel(
    framework::Array<const InT *__restrict__, ET> in, OutT *out,
    framework::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
    framework::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
        configlists,
    int main_tid, int tail_tid, Functor func) {
  if (blockIdx.x < main_tid) {
    // Full block: exactly blockDim.x * VecSize elements, so no boundary
    // checks are needed (IsBoundary = false).
    int num = blockDim.x * VecSize;
    DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, false>(
        in, out, use_broadcast, numel, configlists, num, func);
  } else {
    // Remainder: this is the last block and tail_tid != 0. IsBoundary = true
    // so boundary judgment is made when loading/storing data, avoiding
    // out-of-range global memory accesses.
    int num = tail_tid;
    DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, true>(
        in, out, use_broadcast, numel, configlists, num, func);
  }
}

Expand All @@ -323,9 +254,11 @@ void LaunchKernel(const platform::CUDADeviceContext &ctx,
framework::Array<kps::details::BroadcastConfig<Size>, MAX_INPUT_NUM>
configlists;
framework::Array<bool, MAX_INPUT_NUM> use_broadcast;
framework::Array<const InT *__restrict__, ET> ins_data;

for (int i = 0; i < ET; i++) {
use_broadcast[i] = (ins[i]->numel() != numel);
ins_data[i] = ins[i]->data<InT>();
if (use_broadcast[i]) {
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
Expand All @@ -335,22 +268,10 @@ void LaunchKernel(const platform::CUDADeviceContext &ctx,
}
}

if (ET == kUnary) { // for unary eg: relu
BroadcastKernelUnary<InT, OutT, Size, VecSize,
Functor><<<blocks, threads, 0, stream>>>(
ins[0]->data<InT>(), out_data, numel, configlists[0], main_tid,
tail_tid, func);
} else if (ET == kBinary) { // for binary eg: add: a + b
BroadcastKernelBinary<InT, OutT, Size, VecSize,
Functor><<<blocks, threads, 0, stream>>>(
ins[0]->data<InT>(), ins[1]->data<InT>(), out_data, use_broadcast,
numel, configlists, main_tid, tail_tid, func);
} else { // for ternary eg:fma : a * b + c
BroadcastKernelTernary<InT, OutT, Size, VecSize,
Functor><<<blocks, threads, 0, stream>>>(
ins[0]->data<InT>(), ins[1]->data<InT>(), ins[2]->data<InT>(), out_data,
use_broadcast, numel, configlists, main_tid, tail_tid, func);
}
BroadcastKernel<ET, InT, OutT, Size, VecSize,
Functor><<<blocks, threads, 0, stream>>>(
ins_data, out_data, use_broadcast, numel, configlists, main_tid, tail_tid,
func);
}

template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
Expand Down

1 comment on commit 487f12a

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.