-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support multi-output feature for elementwise #38410
Changes from all commits
ec7d2a3
b073486
61426f1
85ec155
b085c89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -208,7 +208,7 @@ __device__ void ElementwiseBroadcastKernelImpl( | |
int block_offset, | ||
Functor func) { | ||
InT args[Arity][VecSize]; | ||
OutType<OutT, NumOuts> result[VecSize]; | ||
ConditionalT<OutT, NumOuts> result[VecSize]; | ||
|
||
#pragma unroll | ||
for (int i = 0; i < Arity; i++) { | ||
|
@@ -224,7 +224,7 @@ __device__ void ElementwiseBroadcastKernelImpl( | |
constexpr bool kCallElementwiseAny = | ||
paddle::platform::FunctionTraits<Functor>::has_pointer_args; | ||
ElementwisePrimitiveCaller<InT, | ||
OutType<OutT, NumOuts>, | ||
ConditionalT<OutT, NumOuts>, | ||
VecSize, | ||
Functor, | ||
Arity, | ||
|
@@ -455,20 +455,19 @@ void LaunchBroadcastElementwiseCudaKernel( | |
"is %d, the arity of functor is %d.", | ||
ins.size(), | ||
kArity)); | ||
PADDLE_ENFORCE_EQ(kArity, | ||
2, | ||
PADDLE_ENFORCE_LE(kArity, | ||
ElementwiseType::kTernary, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 其实这里就应该用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,和上一个删除 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
已经在PR38550中修改 |
||
paddle::platform::errors::InvalidArgument( | ||
"Currently only broadcast of binary is supported and " | ||
"verified, but received %d.", | ||
"Currently only broadcast of ternary is supported " | ||
"and verified, but received %d.", | ||
kArity)); | ||
PADDLE_ENFORCE_EQ( | ||
outs->size(), | ||
NumOuts, | ||
paddle::platform::errors::InvalidArgument( | ||
"Number of outputs shall equal to number of functions, " | ||
"but number of outputs is %d, number of functions is %d.", | ||
outs->size(), | ||
NumOuts)); | ||
PADDLE_ENFORCE_EQ(outs->size(), | ||
NumOuts, | ||
paddle::platform::errors::InvalidArgument( | ||
"Number of outputs shall equal to number of functions, " | ||
"but number of outputs is %d, of functions is %d.", | ||
outs->size(), | ||
NumOuts)); | ||
int in_vec_size = 4; | ||
int out_vec_size = 4; | ||
if (NumOuts > 1) { | ||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -55,16 +55,17 @@ template <typename InT, | |||||||
typename OutT, | ||||||||
typename Functor, | ||||||||
int Arity, | ||||||||
int NumOuts, | ||||||||
int VecSize, | ||||||||
bool IsBoundary> | ||||||||
__device__ void VectorizedElementwiseKernelImpl( | ||||||||
const paddle::framework::Array<const InT *__restrict__, Arity> &in, | ||||||||
OutT *out, | ||||||||
paddle::framework::Array<OutT *, NumOuts> outs, | ||||||||
int num, | ||||||||
int data_offset, | ||||||||
Functor func) { | ||||||||
InT args[Arity][VecSize]; | ||||||||
OutT result[VecSize]; | ||||||||
ConditionalT<OutT, NumOuts> result[VecSize]; | ||||||||
|
||||||||
#pragma unroll | ||||||||
for (int i = 0; i < Arity; i++) { | ||||||||
|
@@ -73,36 +74,53 @@ __device__ void VectorizedElementwiseKernelImpl( | |||||||
args[i], in[i] + data_offset, num); | ||||||||
} | ||||||||
|
||||||||
const bool kCallElementwiseAny = | ||||||||
constexpr bool kCallElementwiseAny = | ||||||||
paddle::platform::FunctionTraits<Functor>::has_pointer_args; | ||||||||
ElementwisePrimitiveCaller<InT, | ||||||||
OutT, | ||||||||
ConditionalT<OutT, NumOuts>, | ||||||||
VecSize, | ||||||||
Functor, | ||||||||
Arity, | ||||||||
kCallElementwiseAny>()(func, args, result); | ||||||||
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>( | ||||||||
out + data_offset, result, num); | ||||||||
|
||||||||
ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()( | ||||||||
outs, result, data_offset, num); | ||||||||
} | ||||||||
|
||||||||
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize> | ||||||||
template <typename InT, | ||||||||
typename OutT, | ||||||||
typename Functor, | ||||||||
int Arity, | ||||||||
int NumOuts, | ||||||||
int VecSize> | ||||||||
__global__ void VectorizedElementwiseKernel( | ||||||||
paddle::framework::Array<const InT *__restrict__, Arity> ins, | ||||||||
OutT *out, | ||||||||
paddle::framework::Array<OutT *, NumOuts> outs, | ||||||||
int size, | ||||||||
int main_offset, | ||||||||
Functor func) { | ||||||||
int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; | ||||||||
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; | ||||||||
for (; data_offset < main_offset; data_offset += stride) { | ||||||||
VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, false>( | ||||||||
ins, out, VecSize * BLOCK_NUM_X, data_offset, func); | ||||||||
VectorizedElementwiseKernelImpl<InT, | ||||||||
OutT, | ||||||||
Functor, | ||||||||
Arity, | ||||||||
NumOuts, | ||||||||
VecSize, | ||||||||
false>( | ||||||||
ins, outs, VecSize * BLOCK_NUM_X, data_offset, func); | ||||||||
} | ||||||||
|
||||||||
int num = size - data_offset; | ||||||||
if (num > 0) { | ||||||||
VectorizedElementwiseKernelImpl<InT, OutT, Functor, Arity, VecSize, true>( | ||||||||
ins, out, num, data_offset, func); | ||||||||
VectorizedElementwiseKernelImpl<InT, | ||||||||
OutT, | ||||||||
Functor, | ||||||||
Arity, | ||||||||
NumOuts, | ||||||||
VecSize, | ||||||||
true>(ins, outs, num, data_offset, func); | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
|
@@ -121,7 +139,12 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins, | |||||||
return vec_size; | ||||||||
} | ||||||||
|
||||||||
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize> | ||||||||
template <typename InT, | ||||||||
typename OutT, | ||||||||
typename Functor, | ||||||||
int Arity, | ||||||||
int NumOuts, | ||||||||
int VecSize> | ||||||||
void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx, | ||||||||
const std::vector<const DenseTensor *> &ins, | ||||||||
std::vector<DenseTensor *> *outs, | ||||||||
|
@@ -131,11 +154,15 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx, | |||||||
int grid_size = | ||||||||
((numel + VecSize - 1) / VecSize + block_size - 1) / block_size; | ||||||||
auto stream = ctx.stream(); | ||||||||
OutT *out_data = (*outs)[0]->mutable_data<OutT>(); | ||||||||
paddle::framework::Array<const InT *__restrict__, Arity> ins_data; | ||||||||
for (int i = 0; i < Arity; i++) { | ||||||||
paddle::framework::Array<OutT *, NumOuts> outs_data; | ||||||||
|
||||||||
for (int i = 0; i < Arity; ++i) { | ||||||||
ins_data[i] = ins[i]->data<InT>(); | ||||||||
} | ||||||||
for (int i = 0; i < NumOuts; ++i) { | ||||||||
outs_data[i] = (*outs)[i]->mutable_data<OutT>(); | ||||||||
} | ||||||||
#ifdef PADDLE_WITH_XPU2 | ||||||||
block_size = 128; | ||||||||
grid_size = 8; | ||||||||
|
@@ -144,20 +171,26 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx, | |||||||
OutT, | ||||||||
Functor, | ||||||||
Arity, | ||||||||
NumOuts, | ||||||||
VecSize><<<grid_size, block_size, 0, stream>>>( | ||||||||
ins_data, out_data, numel, main_offset, func); | ||||||||
ins_data, outs_data, numel, main_offset, func); | ||||||||
#else | ||||||||
int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; | ||||||||
VectorizedElementwiseKernel<InT, | ||||||||
OutT, | ||||||||
Functor, | ||||||||
Arity, | ||||||||
NumOuts, | ||||||||
VecSize><<<grid_size, block_size, 0, stream>>>( | ||||||||
ins_data, out_data, numel, main_offset, func); | ||||||||
ins_data, outs_data, numel, main_offset, func); | ||||||||
#endif | ||||||||
} | ||||||||
|
||||||||
template <ElementwiseType ET, typename InT, typename OutT, typename Functor> | ||||||||
template <ElementwiseType ET, | ||||||||
typename InT, | ||||||||
typename OutT, | ||||||||
typename Functor, | ||||||||
int NumOuts = 1> | ||||||||
void LaunchSameDimsElementwiseCudaKernel( | ||||||||
const paddle::platform::CUDADeviceContext &ctx, | ||||||||
const std::vector<const DenseTensor *> &ins, | ||||||||
|
@@ -174,19 +207,39 @@ void LaunchSameDimsElementwiseCudaKernel( | |||||||
"is %d, the arity of functor is %d.", | ||||||||
ins.size(), | ||||||||
kArity)); | ||||||||
PADDLE_ENFORCE_EQ(outs->size(), | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里不需要像broadcast一样判断下ET的值吗 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 由于 Paddle/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h Lines 166 to 168 in 706d2c0
所以ET的值就被 kArity 取代了,后面就用kArity 做判断了
|
||||||||
NumOuts, | ||||||||
paddle::platform::errors::InvalidArgument( | ||||||||
"Number of outputs shall equal to number of functions, " | ||||||||
"but number of outputs is %d, of functions is %d.", | ||||||||
outs->size(), | ||||||||
NumOuts)); | ||||||||
|
||||||||
if (NumOuts > 1) { | ||||||||
for (int i = 1; i < NumOuts; ++i) { | ||||||||
PADDLE_ENFORCE_EQ( | ||||||||
(*outs)[i]->dims(), | ||||||||
(*outs)[0]->dims(), | ||||||||
paddle::platform::errors::InvalidArgument( | ||||||||
"The shape of each output tensor shall be identical yet, " | ||||||||
"but %dth output tensor`s shape is not.", | ||||||||
i)); | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
// calculate the max vec_size for all ins and outs | ||||||||
int vec_size = GetVectorizedSizeForTensors<InT, OutT>(ins, *outs); | ||||||||
switch (vec_size) { | ||||||||
case 4: | ||||||||
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 4>( | ||||||||
ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 4>( | ||||||||
ctx, ins, outs, func); | ||||||||
break; | ||||||||
case 2: | ||||||||
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 2>( | ||||||||
ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 2>( | ||||||||
ctx, ins, outs, func); | ||||||||
break; | ||||||||
case 1: | ||||||||
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 1>( | ||||||||
ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 1>( | ||||||||
ctx, ins, outs, func); | ||||||||
break; | ||||||||
default: { | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件里面的
DimensionsTransform
是不是可以删掉了?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以删除掉,我提个PR把这里删除吧,本来以为pten那边会删除掉