Support multi-output feature for elementwise #38410

Merged
12 changes: 7 additions & 5 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -162,7 +162,8 @@ struct DimensionsTransform {
}
};

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
Contributor: Can the DimensionsTransform in this file be removed now?

Contributor (Author): Yes, it can be removed. I'll open a PR to delete it; I originally assumed the pten side would remove it.

int NumOuts = 1>
void LaunchBroadcastElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
@@ -190,11 +191,12 @@ void LaunchBroadcastElementwiseCudaKernel(
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
pt_outputs.push_back(pt_outputs_tmp[i].get());
}
pten::LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(
pten::LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
ctx, pt_inputs, &pt_outputs, axis, func);
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
int NumOuts = 1>
void LaunchElementwiseCudaKernel(
const platform::CUDADeviceContext &cuda_ctx,
const std::vector<const framework::Tensor *> &ins,
@@ -222,8 +224,8 @@ void LaunchElementwiseCudaKernel(
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
pt_outputs.push_back(pt_outputs_tmp[i].get());
}
pten::LaunchElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, pt_inputs,
&pt_outputs, axis, func);
pten::LaunchElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
cuda_ctx, pt_inputs, &pt_outputs, axis, func);
}

} // namespace operators
7 changes: 4 additions & 3 deletions paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -38,7 +38,8 @@ namespace kps = paddle::operators::kernel_primitives;

using ElementwiseType = pten::ElementwiseType;

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
int NumOuts = 1>
void LaunchSameDimsElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
@@ -66,8 +67,8 @@ void LaunchSameDimsElementwiseCudaKernel(
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
pt_outputs.push_back(pt_outputs_tmp[i].get());
}
pten::LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(ctx, pt_inputs,
&pt_outputs, func);
pten::LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
ctx, pt_inputs, &pt_outputs, func);
}

} // namespace operators
10 changes: 7 additions & 3 deletions paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h
@@ -19,7 +19,11 @@ limitations under the License. */

namespace pten {

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
template <ElementwiseType ET,
typename InT,
typename OutT,
typename Functor,
int NumOuts = 1>
void LaunchElementwiseCudaKernel(
const paddle::platform::CUDADeviceContext &cuda_ctx,
const std::vector<const DenseTensor *> &ins,
@@ -33,14 +37,14 @@ void LaunchElementwiseCudaKernel(
dims_size.emplace_back(in->dims().size());
}
if (no_broadcast_flag) {
LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(
LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
cuda_ctx, ins, outs, func);
} else {
axis = axis == -1
? *std::max_element(dims_size.begin(), dims_size.end()) -
*std::min_element(dims_size.begin(), dims_size.end())
: axis;
LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(
LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
cuda_ctx, ins, outs, axis, func);
}
}
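
A worked reading of the default-axis rule above (shapes illustrative, not from this PR):

// ins = { X with dims [2, 3, 4], Y with dims [4] }  =>  dims_size = {3, 1}.
// axis == -1 resolves to max(dims_size) - min(dims_size) = 3 - 1 = 2, so the
// lower-rank tensor is aligned with the trailing dimensions of the
// higher-rank one, matching numpy-style trailing broadcast.
int ResolveDefaultAxis(int max_rank, int min_rank) {
  return max_rank - min_rank;
}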
@@ -208,7 +208,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
int block_offset,
Functor func) {
InT args[Arity][VecSize];
OutType<OutT, NumOuts> result[VecSize];
ConditionalT<OutT, NumOuts> result[VecSize];

#pragma unroll
for (int i = 0; i < Arity; i++) {
Expand All @@ -224,7 +224,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
ElementwisePrimitiveCaller<InT,
OutType<OutT, NumOuts>,
ConditionalT<OutT, NumOuts>,
VecSize,
Functor,
Arity,
@@ -455,20 +455,19 @@ void LaunchBroadcastElementwiseCudaKernel(
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
PADDLE_ENFORCE_EQ(kArity,
2,
PADDLE_ENFORCE_LE(kArity,
ElementwiseType::kTernary,
Contributor: This should really use the literal 3, because ElementwiseType::kTernary is an enum constant and might be changed to a different value.

Contributor (Author): OK, I'll fix this together with the previous PR that removes DimensionsTransform.

Contributor (Author), quoting the comment above: Fixed in PR38550.
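
A minimal standalone sketch of the hazard being discussed (illustration only, not part of the diff): the guard's meaning silently tracks any renumbering of the enum.

// The check below means "kArity <= 3" only while kTernary keeps the value 3;
// writing the literal 3 would pin the intended upper bound.
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };

bool ArityIsSupported(int kArity) {
  return kArity <= static_cast<int>(ElementwiseType::kTernary);
}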

paddle::platform::errors::InvalidArgument(
"Currently only broadcast of binary is supported and "
"verified, but received %d.",
"Currently only broadcast of ternary is supported "
"and verified, but received %d.",
kArity));
PADDLE_ENFORCE_EQ(
outs->size(),
NumOuts,
paddle::platform::errors::InvalidArgument(
"Number of outputs shall equal to number of functions, "
"but number of outputs is %d, number of functions is %d.",
outs->size(),
NumOuts));
PADDLE_ENFORCE_EQ(outs->size(),
NumOuts,
paddle::platform::errors::InvalidArgument(
"Number of outputs shall equal to number of functions, "
"but number of outputs is %d, of functions is %d.",
outs->size(),
NumOuts));
int in_vec_size = 4;
int out_vec_size = 4;
if (NumOuts > 1) {
@@ -27,7 +27,7 @@ enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
for supporting multiple-output feature in elementwise system.*/
template <class T, int Num>
using OutType =
using ConditionalT =
typename std::conditional_t<Num == 1, T, paddle::framework::Array<T, Num>>;
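
A minimal standalone sketch of how this alias resolves (the Array below is a simplified stand-in for paddle::framework::Array):

#include <type_traits>

template <class T, int N>
struct Array { T data[N]; };  // simplified stand-in

template <class T, int Num>
using ConditionalT =
    typename std::conditional_t<Num == 1, T, Array<T, Num>>;

// One output keeps the plain scalar type; several outputs get packed.
static_assert(std::is_same_v<ConditionalT<float, 1>, float>, "");
static_assert(std::is_same_v<ConditionalT<float, 2>, Array<float, 2>>, "");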

template <typename InT,
Expand Down Expand Up @@ -86,7 +86,7 @@ template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct ElementwiseWriteDataCaller {
__device__ __forceinline__ void operator()(
paddle::framework::Array<OutT *, NumOuts> outs,
OutType<OutT, NumOuts> src[VecSize],
ConditionalT<OutT, NumOuts> src[VecSize],
int block_offset,
int num) {
OutT dst[NumOuts][VecSize];
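The operator's body is truncated in the diff above; conceptually it transposes the packed per-thread results into per-output buffers before the vectorized stores. A hedged sketch of that idea (simplified types, not the PR's exact body):

template <class T, int N>
struct Array { T data[N]; };  // simplified stand-in for paddle::framework::Array

// Unpack VecSize packed Array<OutT, NumOuts> values into NumOuts rows of
// VecSize scalars, so each output tensor can be written with one
// vectorized store.
template <typename OutT, int VecSize, int NumOuts>
void UnpackResults(const Array<OutT, NumOuts> (&src)[VecSize],
                   OutT (&dst)[NumOuts][VecSize]) {
  for (int j = 0; j < VecSize; ++j) {
    for (int k = 0; k < NumOuts; ++k) {
      dst[k][j] = src[j].data[k];
    }
  }
}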
@@ -55,16 +55,17 @@ template <typename InT,
typename OutT,
typename Functor,
int Arity,
int NumOuts,
int VecSize,
bool IsBoundary>
__device__ void VectorizedElementwiseKernelImpl(
const paddle::framework::Array<const InT *__restrict__, Arity> &in,
OutT *out,
paddle::framework::Array<OutT *, NumOuts> outs,
int num,
int data_offset,
Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];
ConditionalT<OutT, NumOuts> result[VecSize];

#pragma unroll
for (int i = 0; i < Arity; i++) {
@@ -73,36 +74,53 @@ __device__ void VectorizedElementwiseKernelImpl(
args[i], in[i] + data_offset, num);
}

const bool kCallElementwiseAny =
constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
ElementwisePrimitiveCaller<InT,
OutT,
ConditionalT<OutT, NumOuts>,
VecSize,
Functor,
Arity,
kCallElementwiseAny>()(func, args, result);
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
out + data_offset, result, num);

ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
outs, result, data_offset, num);
}

template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
template <typename InT,
typename OutT,
typename Functor,
int Arity,
int NumOuts,
int VecSize>
__global__ void VectorizedElementwiseKernel(
paddle::framework::Array<const InT *__restrict__, Arity> ins,
OutT *out,
paddle::framework::Array<OutT *, NumOuts> outs,
int size,
int main_offset,
Functor func) {
int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
for (; data_offset < main_offset; data_offset += stride) {
VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, false>(
ins, out, VecSize * BLOCK_NUM_X, data_offset, func);
VectorizedElementwiseKernelImpl<InT,
OutT,
Functor,
Arity,
NumOuts,
VecSize,
false>(
ins, outs, VecSize * BLOCK_NUM_X, data_offset, func);
}

int num = size - data_offset;
if (num > 0) {
VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, true>(
ins, out, num, data_offset, func);
VectorizedElementwiseKernelImpl<InT,
OutT,
Functor,
Arity,
NumOuts,
VecSize,
true>(ins, outs, num, data_offset, func);
}
}

@@ -121,7 +139,12 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
return vec_size;
}

template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
template <typename InT,
typename OutT,
typename Functor,
int Arity,
int NumOuts,
int VecSize>
void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
@@ -131,11 +154,15 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
int grid_size =
((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
auto stream = ctx.stream();
OutT *out_data = (*outs)[0]->mutable_data<OutT>();
paddle::framework::Array<const InT *__restrict__, Arity> ins_data;
for (int i = 0; i < Arity; i++) {
paddle::framework::Array<OutT *, NumOuts> outs_data;

for (int i = 0; i < Arity; ++i) {
ins_data[i] = ins[i]->data<InT>();
}
for (int i = 0; i < NumOuts; ++i) {
outs_data[i] = (*outs)[i]->mutable_data<OutT>();
}
#ifdef PADDLE_WITH_XPU2
block_size = 128;
grid_size = 8;
Expand All @@ -144,20 +171,26 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
OutT,
Functor,
Arity,
NumOuts,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, main_offset, func);
ins_data, outs_data, numel, main_offset, func);
#else
int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
VectorizedElementwiseKernel<InT,
OutT,
Functor,
Arity,
NumOuts,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, main_offset, func);
ins_data, outs_data, numel, main_offset, func);
#endif
}
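
A worked instance of the launch-shape arithmetic above (values illustrative):

#include <cassert>

int main() {
  const int numel = 10000, vec_size = 4, block_size = 512;
  // ceil(numel / vec_size) vectorized threads, then ceil(threads / block).
  const int grid_size =
      ((numel + vec_size - 1) / vec_size + block_size - 1) / block_size;
  // Largest multiple of vec_size * block_size not exceeding numel.
  const int main_offset =
      (numel / (vec_size * block_size)) * vec_size * block_size;
  assert(grid_size == 5);       // ceil(2500 / 512)
  assert(main_offset == 8192);  // tail of 1808 elements takes IsBoundary=true
  return 0;
}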

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
template <ElementwiseType ET,
typename InT,
typename OutT,
typename Functor,
int NumOuts = 1>
void LaunchSameDimsElementwiseCudaKernel(
const paddle::platform::CUDADeviceContext &ctx,
const std::vector<const DenseTensor *> &ins,
@@ -174,19 +207,39 @@ void LaunchSameDimsElementwiseCudaKernel(
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
PADDLE_ENFORCE_EQ(outs->size(),
Contributor: Doesn't this need to check the value of ET, the way the broadcast path does?

Contributor (Author): Because of this snippet in function_traits:

using Traits = paddle::platform::FunctionTraits<Functor>;
const int kArity =
    Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;

the value of ET is folded into kArity, so the later checks operate on kArity instead.
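
For readers unfamiliar with the traits trick, a minimal illustration of compile-time arity deduction (an assumption-level sketch, not Paddle's actual FunctionTraits):

// Deduce the argument count of a functor's operator() at compile time.
template <typename F>
struct SimpleTraits : SimpleTraits<decltype(&F::operator())> {};

template <typename C, typename Ret, typename... Args>
struct SimpleTraits<Ret (C::*)(Args...) const> {
  static constexpr int arity = sizeof...(Args);
};

struct AddFunctor {
  float operator()(float a, float b) const { return a + b; }
};

static_assert(SimpleTraits<AddFunctor>::arity == 2, "binary functor");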

NumOuts,
paddle::platform::errors::InvalidArgument(
"Number of outputs shall equal to number of functions, "
"but number of outputs is %d, of functions is %d.",
outs->size(),
NumOuts));

if (NumOuts > 1) {
for (int i = 1; i < NumOuts; ++i) {
PADDLE_ENFORCE_EQ(
(*outs)[i]->dims(),
(*outs)[0]->dims(),
paddle::platform::errors::InvalidArgument(
"The shape of each output tensor shall be identical yet, "
"but %dth output tensor`s shape is not.",
i));
}
}

// calculate the max vec_size for all ins and outs
int vec_size = GetVectorizedSizeForTensors<InT, OutT>(ins, *outs);
switch (vec_size) {
case 4:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 4>(
ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 4>(
ctx, ins, outs, func);
break;
case 2:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 2>(
ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 2>(
ctx, ins, outs, func);
break;
case 1:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 1>(
ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 1>(
ctx, ins, outs, func);
break;
default: {
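To close, a hedged sketch of how a caller might use the new NumOuts parameter; the functor and call site below are illustrative, not code from this PR.

// Illustrative only: a binary functor producing two outputs per element by
// returning the packed ConditionalT<T, 2>, i.e. paddle::framework::Array<T, 2>.
template <typename T>
struct MaxMinFunctor {
  inline HOSTDEVICE paddle::framework::Array<T, 2> operator()(T a, T b) const {
    paddle::framework::Array<T, 2> result;
    result[0] = a > b ? a : b;  // first output: elementwise max
    result[1] = a > b ? b : a;  // second output: elementwise min
    return result;
  }
};

// Hypothetical call site, with ins = {&x, &y} and outs = {out_max, out_min}
// prepared elsewhere:
//   LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T,
//                                       MaxMinFunctor<T>, 2>(
//       ctx, ins, &outs, MaxMinFunctor<T>());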