Skip to content

Commit

Permalink
pass ctest of elementwise_div_grad
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesLim-sy committed Dec 28, 2021
1 parent 61426f1 commit 85ec155
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
int block_offset,
Functor func) {
InT args[Arity][VecSize];
OutType<OutT, NumOuts> result[VecSize];
ConditionalT<OutT, NumOuts> result[VecSize];

#pragma unroll
for (int i = 0; i < Arity; i++) {
Expand All @@ -224,7 +224,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
ElementwisePrimitiveCaller<InT,
OutType<OutT, NumOuts>,
ConditionalT<OutT, NumOuts>,
VecSize,
Functor,
Arity,
Expand Down Expand Up @@ -455,7 +455,7 @@ void LaunchBroadcastElementwiseCudaKernel(
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
PADDLE_ENFORCE_EQ(kArity,
PADDLE_ENFORCE_LE(kArity,
ElementwiseType::kTernary,
paddle::platform::errors::InvalidArgument(
"Currently only broadcast of ternary is supported "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
for supporting multiple-output feature in elementwise system.*/
template <class T, int Num>
using OutType =
using ConditionalT =
typename std::conditional_t<Num == 1, T, paddle::framework::Array<T, Num>>;

template <typename InT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ __device__ void VectorizedElementwiseKernelImpl(
int data_offset,
Functor func) {
InT args[Arity][VecSize];
OutType<OutT, NumOuts> result[VecSize];
ConditionalT<OutT, NumOuts> result[VecSize];

#pragma unroll
for (int i = 0; i < Arity; i++) {
Expand All @@ -77,7 +77,7 @@ __device__ void VectorizedElementwiseKernelImpl(
constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
ElementwisePrimitiveCaller<InT,
OutType<OutT, NumOuts>,
ConditionalT<OutT, NumOuts>,
VecSize,
Functor,
Arity,
Expand Down

0 comments on commit 85ec155

Please sign in to comment.