Skip to content

Commit

Permalink
[pten] update isnan registration (#39419)
Browse files Browse the repository at this point in the history
* update isnan registration

* fix compile
  • Loading branch information
zhiqiu committed Feb 10, 2022
1 parent c7c1db3 commit 14ed2f5
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 17 deletions.
31 changes: 30 additions & 1 deletion paddle/fluid/operators/isfinite_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,33 @@ namespace ops = paddle::operators;
REGISTER_OP_MAKER(isinf, "isinf(X)");
REGISTER_OP_MAKER(isnan, "isnan(X)");
REGISTER_OP_MAKER(isfinite, "isfinite(X)");
FOR_EACH_KERNEL_FUNCTOR(REGISTER_OVERFLOW_CPU_KERNEL);

REGISTER_OP_CPU_KERNEL(isinf,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int, ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
float, ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
double, ops::InfinityFunctor>);

REGISTER_OP_CPU_KERNEL(isnan,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
float, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
double, ops::NANFunctor>);

REGISTER_OP_CPU_KERNEL(isfinite,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int, ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
float, ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
double, ops::IsfiniteFunctor>);
39 changes: 28 additions & 11 deletions paddle/fluid/operators/isfinite_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,32 @@
namespace ops = paddle::operators;
namespace plat = paddle::platform;

#define REGISTER_OVERFLOW_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CUDADeviceContext, float, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CUDADeviceContext, double, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16, \
ops::functor>);
REGISTER_OP_CUDA_KERNEL(
isinf, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int,
ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, float,
ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, double,
ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16,
ops::InfinityFunctor>);

FOR_EACH_KERNEL_FUNCTOR(REGISTER_OVERFLOW_CUDA_KERNEL);
REGISTER_OP_CUDA_KERNEL(isnan,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
int, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
float, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
double, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
plat::float16, ops::NANFunctor>);

REGISTER_OP_CUDA_KERNEL(
isfinite, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int,
ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, float,
ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, double,
ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16,
ops::IsfiniteFunctor>);
5 changes: 0 additions & 5 deletions paddle/fluid/operators/isfinite_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,3 @@ class OverflowKernel : public framework::OpKernel<T> {

} // namespace operators
} // namespace paddle

#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(isinf, InfinityFunctor); \
__macro(isnan, NANFunctor); \
__macro(isfinite, IsfiniteFunctor);

0 comments on commit 14ed2f5

Please sign in to comment.