diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index f35d8b6bbf89f..d465e77ea1886 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -147,16 +147,15 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel); +REGISTER_OP_CPU_KERNEL( + fill_constant, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, + ops::FillConstantKernel, + ops::FillConstantKernel, + ops::FillConstantKernel>, + ops::FillConstantKernel>); REGISTER_OP_VERSION(fill_constant) .AddCheckpoint( diff --git a/paddle/fluid/operators/fill_constant_op.cu.cc b/paddle/fluid/operators/fill_constant_op.cu.cc index e784c20b8b8b4..a862cda13888e 100644 --- a/paddle/fluid/operators/fill_constant_op.cu.cc +++ b/paddle/fluid/operators/fill_constant_op.cu.cc @@ -15,12 +15,11 @@ limitations under the License. */ #include "paddle/fluid/operators/fill_constant_op.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel); +REGISTER_OP_CUDA_KERNEL( + fill_constant, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, + ops::FillConstantKernel, + ops::FillConstantKernel>, + ops::FillConstantKernel>); diff --git a/paddle/fluid/operators/fill_constant_op_xpu.cc b/paddle/fluid/operators/fill_constant_op_xpu.cc index 16dd4c9292f89..d55b8e2b81b52 100644 --- a/paddle/fluid/operators/fill_constant_op_xpu.cc +++ b/paddle/fluid/operators/fill_constant_op_xpu.cc @@ -15,11 +15,10 @@ limitations under the License. */ namespace ops = paddle::operators; #ifdef PADDLE_WITH_XPU -REGISTER_OP_XPU_KERNEL(fill_constant, ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel, - ops::FillConstantKernel); +REGISTER_OP_XPU_KERNEL( + fill_constant, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel>, + ops::FillConstantKernel>); #endif diff --git a/paddle/fluid/operators/py_layer_op.cc b/paddle/fluid/operators/py_layer_op.cc index f91496eeab142..c2f68675beb62 100644 --- a/paddle/fluid/operators/py_layer_op.cc +++ b/paddle/fluid/operators/py_layer_op.cc @@ -199,9 +199,9 @@ REGISTER_OP_CPU_KERNEL( ops::PyLayerOpKernel, ops::PyLayerOpKernel, ops::PyLayerOpKernel, + ::paddle::platform::complex>, ops::PyLayerOpKernel); + ::paddle::platform::complex>); #ifdef PADDLE_WITH_CUDA REGISTER_OP_CUDA_KERNEL( py_layer, ops::PyLayerOpKernel, @@ -218,7 +218,7 @@ REGISTER_OP_CUDA_KERNEL( ops::PyLayerOpKernel, ops::PyLayerOpKernel, ops::PyLayerOpKernel, + ::paddle::platform::complex>, ops::PyLayerOpKernel); + ::paddle::platform::complex>); #endif // PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu index 67de8bb9a0c1a..230bae0cdd4b1 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu @@ -20,10 +20,9 @@ using CUDAReduceSumGradKernel = ops::ReduceGradKernel; -REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel); +REGISTER_OP_CUDA_KERNEL( + reduce_sum_grad, CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel>, + CUDAReduceSumGradKernel>); diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index e49476e4dc7d4..d71be60e1f5c2 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -329,9 +329,9 @@ REGISTER_OP_CPU_KERNEL( ops::StridedSliceKernel, ops::StridedSliceKernel, ops::StridedSliceKernel, + paddle::platform::complex>, ops::StridedSliceKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( strided_slice_grad, @@ -340,6 +340,6 @@ REGISTER_OP_CPU_KERNEL( ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, + paddle::platform::complex>, ops::StridedSliceGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/strided_slice_op.cu b/paddle/fluid/operators/strided_slice_op.cu index b85403b1c5bb8..68a8312f0818d 100644 --- a/paddle/fluid/operators/strided_slice_op.cu +++ b/paddle/fluid/operators/strided_slice_op.cu @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/strided_slice_op.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( @@ -24,9 +23,9 @@ REGISTER_OP_CUDA_KERNEL( ops::StridedSliceKernel, ops::StridedSliceKernel, ops::StridedSliceKernel, + paddle::platform::complex>, ops::StridedSliceKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( strided_slice_grad, @@ -35,6 +34,6 @@ REGISTER_OP_CUDA_KERNEL( ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, + paddle::platform::complex>, ops::StridedSliceGradKernel); + paddle::platform::complex>);