From f17ce98be62fa5bfe4dd0df25766621c09602ef6 Mon Sep 17 00:00:00 2001 From: lijiaqi Date: Fri, 3 Dec 2021 03:31:04 +0000 Subject: [PATCH 01/11] add fmax and fmin oper --- .../elementwise/elementwise_functor.h | 16 +++ .../elementwise/elementwise_max_op.cc | 46 +++++++ .../elementwise/elementwise_max_op.cu | 13 ++ .../elementwise/elementwise_max_op.h | 49 +++++++ .../elementwise/elementwise_min_op.cc | 46 +++++++ .../elementwise/elementwise_min_op.cu | 13 ++ .../elementwise/elementwise_min_op.h | 48 +++++++ python/paddle/__init__.py | 4 + .../fluid/tests/unittests/test_fmax_op.py | 113 +++++++++++++++++ .../fluid/tests/unittests/test_fmin_op.py | 113 +++++++++++++++++ python/paddle/tensor/__init__.py | 4 + python/paddle/tensor/math.py | 120 ++++++++++++++++++ 12 files changed, 585 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_fmax_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_fmin_op.py diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index abac43a2616f0..b34466619eb2f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -113,5 +113,21 @@ struct MinFunctor { } }; +// Fmax +template +struct FMaxFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + return std::fmax(a, b); + } +}; + +// Fmin +template +struct FMinFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + return std::fmin(a, b); + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cc b/paddle/fluid/operators/elementwise/elementwise_max_op.cc index dde65c8199626..16fa73ca40c4c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cc @@ -70,6 +70,23 @@ class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker { } }; +template +class ElementwiseFMaxGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("elementwise_fmax_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + op->SetAttrMap(this->Attrs()); + } +}; + } // namespace operators } // namespace paddle @@ -103,3 +120,32 @@ REGISTER_OP_VERSION(elementwise_max) "In order to support the function of scaling the input Y when " "using the operator of elementwise_max.", 1.0f)); + +REGISTER_OPERATOR(elementwise_fmax, ops::ElementwiseOp, + ops::ElementwiseMaxOpMaker, ops::ElementwiseOpInferVarType, + ops::ElementwiseFMaxGradOpMaker, + ops::ElementwiseFMaxGradOpMaker); + +REGISTER_OPERATOR(elementwise_fmax_grad, ops::ElementwiseOpGrad); + +REGISTER_OP_CPU_KERNEL( + elementwise_fmax, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel); +REGISTER_OP_CPU_KERNEL( + elementwise_fmax_grad, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel); + +REGISTER_OP_VERSION(elementwise_fmax) + .AddCheckpoint( + R"ROC(Register elementwise_fmax for adding the attribute of 
Scale_y)ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "Scale_y", + "In order to support the function of scaling the input Y when " + "using the operator of elementwise_fmax.", + 1.0f)); \ No newline at end of file diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index 65505381db174..d350491b06f83 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -56,3 +56,16 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMaxGradKernel, ops::ElementwiseMaxGradKernel); + +REGISTER_OP_CUDA_KERNEL( + elementwise_fmax, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel); +REGISTER_OP_CUDA_KERNEL( + elementwise_fmax_grad, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h index 06269b12e8e20..2995af05018ea 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h @@ -36,6 +36,22 @@ class ElementwiseMaxKernel : public framework::OpKernel { } }; +template +class ElementwiseFMaxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + + z->mutable_data(ctx.GetPlace()); + int axis = ctx.Attr("axis"); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + FMaxFunctor(), z); + } +}; + + template struct MaxGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { @@ -68,5 +84,38 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel { ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx(), MaxGradDy()); } }; + +template +struct FMaxGradDx { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast((x >= y) || isnan(y)); + } +}; + +template +struct FMaxGradDy { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast (!((x >= y) || isnan(y))); + } +}; + +template +class ElementwiseFMaxGradKernel : public ElemwiseGradKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + auto* out = dout; // Fake out, not used + int axis = ctx.Attr("axis"); + ElemwiseGradCompute, FMaxGradDy>( + ctx, *x, *y, *out, *dout, axis, dx, dy, FMaxGradDx(), FMaxGradDy()); + } +}; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cc b/paddle/fluid/operators/elementwise/elementwise_min_op.cc index 174684e3c8476..5ceece83d0372 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cc @@ -70,6 +70,23 @@ class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker { } }; +template +class ElementwiseFMinGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + 
protected: + void Apply(GradOpPtr op) const override { + op->SetType("elementwise_fmin_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + op->SetAttrMap(this->Attrs()); + } +}; + } // namespace operators } // namespace paddle @@ -103,3 +120,32 @@ REGISTER_OP_VERSION(elementwise_min) "In order to support the function of scaling the input Y when " "using the operator of elementwise_min.", 1.0f)); + +REGISTER_OPERATOR(elementwise_fmin, ops::ElementwiseOp, + ops::ElementwiseMinOpMaker, ops::ElementwiseOpInferVarType, + ops::ElementwiseFMinGradOpMaker, + ops::ElementwiseFMinGradOpMaker); + +REGISTER_OPERATOR(elementwise_fmin_grad, ops::ElementwiseOpGrad); + +REGISTER_OP_CPU_KERNEL( + elementwise_fmin, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel); +REGISTER_OP_CPU_KERNEL( + elementwise_fmin_grad, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel); + +REGISTER_OP_VERSION(elementwise_fmin) + .AddCheckpoint( + R"ROC(Register elementwise_min for adding the attribute of Scale_y)ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "Scale_y", + "In order to support the function of scaling the input Y when " + "using the operator of elementwise_fmin.", + 1.0f)); diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index eed6f72b04fb9..5c1adab4b1264 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -52,3 +52,16 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMinGradKernel, ops::ElementwiseMinGradKernel); + +REGISTER_OP_CUDA_KERNEL( + elementwise_fmin, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel); +REGISTER_OP_CUDA_KERNEL( + elementwise_fmin_grad, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h index 648691063c59b..151a48d31fd73 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -36,6 +36,21 @@ class ElementwiseMinKernel : public framework::OpKernel { } }; +template +class ElementwiseFMinKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + + z->mutable_data(ctx.GetPlace()); + int axis = ctx.Attr("axis"); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + MinFunctor(), z); + } +}; + template struct MinGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { @@ -68,5 +83,38 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel { ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx(), MinGradDy()); } }; + +template +struct FMinGradDx { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast((x <= y) || isnan(y)); + } +}; + +template +struct FMinGradDy { + HOSTDEVICE T 
operator()(T x, T y, T out, T dout) const { + return dout * static_cast (!((x <= y) || isnan(y))); + } +}; + +template +class ElementwiseFMinGradKernel : public ElemwiseGradKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + auto* out = dout; // Fake out, not used + int axis = ctx.Attr("axis"); + ElemwiseGradCompute, FMinGradDy>( + ctx, *x, *y, *out, *dout, axis, dx, dy, FMinGradDx(), FMinGradDy()); + } +}; } // namespace operators } // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 5823cf460ee9f..ae4e3cfbd5a20 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -223,6 +223,8 @@ from .tensor.math import digamma # noqa: F401 from .tensor.math import neg # noqa: F401 from .tensor.math import lgamma # noqa: F401 +from .tensor.math import fmax # noqa: F401 +from .tensor.math import fmin # noqa: F401 from .tensor.random import multinomial # noqa: F401 from .tensor.random import standard_normal # noqa: F401 @@ -433,6 +435,8 @@ 'pow', 'zeros_like', 'maximum', + 'fmax', + 'fmin', 'topk', 'index_select', 'CPUPlace', diff --git a/python/paddle/fluid/tests/unittests/test_fmax_op.py b/python/paddle/fluid/tests/unittests/test_fmax_op.py new file mode 100644 index 0000000000000..1bc549c7151a9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fmax_op.py @@ -0,0 +1,113 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core + + +class ApiFMaxTest(unittest.TestCase): + def setUp(self): + if core.is_compiled_with_cuda(): + self.place = core.CUDAPlace(0) + else: + self.place = core.CPUPlace() + + self.input_x = np.random.rand(10, 15).astype("float32") + self.input_y = np.random.rand(10, 15).astype("float32") + self.input_z = np.random.rand(15).astype("float32") + self.input_a = np.array([0, np.nan, np.nan]).astype('int64') + self.input_b = np.array([2, np.inf, -np.inf]).astype('int64') + self.input_c = np.array([4, 1, 3]).astype('int64') + + self.np_expected1 = np.fmax(self.input_x, self.input_y) + self.np_expected2 = np.fmax(self.input_x, self.input_z) + self.np_expected3 = np.fmax(self.input_a, self.input_c) + self.np_expected4 = np.fmax(self.input_b, self.input_c) + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_x = paddle.static.data("x", shape=[10, 15], dtype="float32") + data_y = paddle.static.data("y", shape=[10, 15], dtype="float32") + result_fmax = paddle.fmax(data_x, data_y) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"x": self.input_x, + "y": self.input_y}, + fetch_list=[result_fmax]) + self.assertTrue(np.allclose(res, self.np_expected1)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_x = paddle.static.data("x", shape=[10, 15], dtype="float32") + data_z = paddle.static.data("z", shape=[15], dtype="float32") + result_fmax = paddle.fmax(data_x, data_z) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"x": self.input_x, + "z": self.input_z}, + fetch_list=[result_fmax]) + self.assertTrue(np.allclose(res, self.np_expected2)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_a = paddle.static.data("a", shape=[3], dtype="int64") + data_c = paddle.static.data("c", shape=[3], dtype="int64") + result_fmax = paddle.fmax(data_a, data_c) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"a": self.input_a, + "c": self.input_c}, + fetch_list=[result_fmax]) + self.assertTrue(np.allclose(res, self.np_expected3)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_b = paddle.static.data("b", shape=[3], dtype="int64") + data_c = paddle.static.data("c", shape=[3], dtype="int64") + result_fmax = paddle.fmax(data_b, data_c) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"b": self.input_b, + "c": self.input_c}, + fetch_list=[result_fmax]) + self.assertTrue(np.allclose(res, self.np_expected4)) + + def test_dynamic_api(self): + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + y = paddle.to_tensor(self.input_y) + z = paddle.to_tensor(self.input_z) + + a = paddle.to_tensor(self.input_a) + b = paddle.to_tensor(self.input_b) + c = paddle.to_tensor(self.input_c) + + res = paddle.fmax(x, y) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected1)) + + # test broadcast + res = paddle.fmax(x, z) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected2)) + + res = paddle.fmax(a, c) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected3)) + + res = paddle.fmax(b, c) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected4)) diff --git a/python/paddle/fluid/tests/unittests/test_fmin_op.py 
b/python/paddle/fluid/tests/unittests/test_fmin_op.py new file mode 100644 index 0000000000000..a1a25cce29af9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fmin_op.py @@ -0,0 +1,113 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core + + +class ApiFMinTest(unittest.TestCase): + def setUp(self): + if core.is_compiled_with_cuda(): + self.place = core.CUDAPlace(0) + else: + self.place = core.CPUPlace() + + self.input_x = np.random.rand(10, 15).astype("float32") + self.input_y = np.random.rand(10, 15).astype("float32") + self.input_z = np.random.rand(15).astype("float32") + self.input_a = np.array([0, np.nan, np.nan]).astype('int64') + self.input_b = np.array([2, np.inf, -np.inf]).astype('int64') + self.input_c = np.array([4, 1, 3]).astype('int64') + + self.np_expected1 = np.fmin(self.input_x, self.input_y) + self.np_expected2 = np.fmin(self.input_x, self.input_z) + self.np_expected3 = np.fmin(self.input_a, self.input_c) + self.np_expected4 = np.fmin(self.input_b, self.input_c) + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_x = paddle.static.data("x", shape=[10, 15], dtype="float32") + data_y = paddle.static.data("y", shape=[10, 15], dtype="float32") + result_fmin = paddle.fmin(data_x, data_y) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"x": self.input_x, + "y": self.input_y}, + fetch_list=[result_fmin]) + self.assertTrue(np.allclose(res, self.np_expected1)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_x = paddle.static.data("x", shape=[10, 15], dtype="float32") + data_z = paddle.static.data("z", shape=[15], dtype="float32") + result_fmin = paddle.fmin(data_x, data_z) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"x": self.input_x, + "z": self.input_z}, + fetch_list=[result_fmin]) + self.assertTrue(np.allclose(res, self.np_expected2)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_a = paddle.static.data("a", shape=[3], dtype="int64") + data_c = paddle.static.data("c", shape=[3], dtype="int64") + result_fmin = paddle.fmin(data_a, data_c) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"a": self.input_a, + "c": self.input_c}, + fetch_list=[result_fmin]) + self.assertTrue(np.allclose(res, self.np_expected3)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_b = paddle.static.data("b", shape=[3], dtype="int64") + data_c = paddle.static.data("c", shape=[3], dtype="int64") + result_fmin = paddle.fmin(data_b, data_c) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"b": self.input_b, + "c": self.input_c}, + fetch_list=[result_fmin]) + self.assertTrue(np.allclose(res, 
self.np_expected4)) + + def test_dynamic_api(self): + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + y = paddle.to_tensor(self.input_y) + z = paddle.to_tensor(self.input_z) + + a = paddle.to_tensor(self.input_a) + b = paddle.to_tensor(self.input_b) + c = paddle.to_tensor(self.input_c) + + res = paddle.fmin(x, y) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected1)) + + # test broadcast + res = paddle.fmin(x, z) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected2)) + + res = paddle.fmin(a, c) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected3)) + + res = paddle.fmin(b, c) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected4)) diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 21d1dd1793b2c..1ac92b78490b1 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -189,6 +189,8 @@ from .math import neg # noqa: F401 from .math import lgamma # noqa: F401 from .math import diagonal # noqa: F401 +from .math import fmax # noqa: F401 +from .math import fmin # noqa: F401 from .random import multinomial # noqa: F401 from .random import standard_normal # noqa: F401 @@ -291,6 +293,8 @@ 'maximum', 'min', 'minimum', + 'fmax', + 'fmin', 'mm', 'divide', 'floor_divide', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index f5f0b5ed0873c..53a0e6b8f2133 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -630,6 +630,126 @@ def minimum(x, y, name=None): x, y, axis=axis, act=act, op_name=op_type) return _elementwise_op(LayerHelper(op_type, **locals())) +def fmax(x, y, name=None): + """ + Compares the elements at the corresponding positions of the two tensors and returns a new tensor containing the maximum value of the element. + The equation is: + + .. math:: + out = fmax(x, y) + + **Note**: + ``paddle.fmax`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` . + + Args: + x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = paddle.to_tensor([[1, 2], [7, 8]]) + y = paddle.to_tensor([[3, 4], [5, 6]]) + res = paddle.fmax(x, y) + print(res) + # [[3, 4], + # [7, 8]] + + x = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + y = paddle.to_tensor([3, 0, 4]) + res = paddle.fmax(x, y) + print(res) + # [[3, 2, 4], + # [3, 2, 4]] + + x = paddle.to_tensor([2, 3, 5], dtype='float32') + y = paddle.to_tensor([1, np.nan, np.nan], dtype='float32') + res = paddle.fmax(x, y) + print(res) + # [ 2., 3., 5.] + + x = paddle.to_tensor([5, 3, np.inf], dtype='float32') + y = paddle.to_tensor([1, -np.inf, 5], dtype='float32') + res = paddle.fmax(x, y) + print(res) + # [ 5., 3., inf.] 
+ """ + op_type = 'elementwise_fmax' + axis = -1 + act = None + if in_dygraph_mode(): + return _elementwise_op_in_dygraph( + x, y, axis=axis, act=act, op_name=op_type) + return _elementwise_op(LayerHelper(op_type, **locals())) + +def fmin(x, y, name=None): + """ + Compares the elements at the corresponding positions of the two tensors and returns a new tensor containing the minimum value of the element. + The equation is: + + .. math:: + out = fmin(x, y) + + **Note**: + ``paddle.fmin`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` . + + Args: + x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = paddle.to_tensor([[1, 2], [7, 8]]) + y = paddle.to_tensor([[3, 4], [5, 6]]) + res = paddle.fmin(x, y) + print(res) + # [[1, 2], + # [5, 6]] + + x = paddle.to_tensor([[[1, 2, 3], [1, 2, 3]]]) + y = paddle.to_tensor([3, 0, 4]) + res = paddle.fmin(x, y) + print(res) + # [[[1, 0, 3], + # [1, 0, 3]]] + + x = paddle.to_tensor([2, 3, 5], dtype='float32') + y = paddle.to_tensor([1, np.nan, np.nan], dtype='float32') + res = paddle.fmin(x, y) + print(res) + # [ 1., 3., 5.] + + x = paddle.to_tensor([5, 3, np.inf], dtype='float64') + y = paddle.to_tensor([1, -np.inf, 5], dtype='float64') + res = paddle.fmin(x, y) + print(res) + # [ 1., -inf., 5.] 
+ """ + op_type = 'elementwise_fmin' + axis = -1 + act = None + if in_dygraph_mode(): + return _elementwise_op_in_dygraph( + x, y, axis=axis, act=act, op_name=op_type) + return _elementwise_op(LayerHelper(op_type, **locals())) + for func in [ add, multiply From cc6b4906462fc0e59991d4cc40ad36c28c1cf0a7 Mon Sep 17 00:00:00 2001 From: lijiaqi Date: Fri, 3 Dec 2021 06:42:54 +0000 Subject: [PATCH 02/11] Resolve conflicts --- python/paddle/__init__.py | 4 ++++ python/paddle/tensor/__init__.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index c361e581c7bfb..49ab3101a8339 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -223,6 +223,10 @@ from .tensor.math import digamma # noqa: F401 from .tensor.math import neg # noqa: F401 from .tensor.math import lgamma # noqa: F401 +from .tensor.math import rad2deg # noqa: F401 +from .tensor.math import deg2rad # noqa: F401 +from .tensor.math import diff # noqa: F401 +from .tensor.math import angle # noqa: F401 from .tensor.math import fmax # noqa: F401 from .tensor.math import fmin # noqa: F401 diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index f29cdd6c8d1d6..3e7f319bd8a33 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -189,6 +189,10 @@ from .math import neg # noqa: F401 from .math import lgamma # noqa: F401 from .math import diagonal # noqa: F401 +from .math import rad2deg # noqa: F401 +from .math import deg2rad # noqa: F401 +from .math import diff # noqa: F401 +from .math import angle # noqa: F401 from .math import fmax # noqa: F401 from .math import fmin # noqa: F401 From 5cf555310ca451564099c95115827a46f8eadaa0 Mon Sep 17 00:00:00 2001 From: lijiaqi Date: Mon, 6 Dec 2021 08:16:37 +0000 Subject: [PATCH 03/11] add fp16 type to fmax and fmin --- .../elementwise/elementwise_functor.h | 23 ++++++ .../elementwise/elementwise_max_op.cc | 32 +++++--- .../elementwise/elementwise_max_op.cu | 2 + .../elementwise/elementwise_min_op.cc | 32 +++++--- .../elementwise/elementwise_min_op.cu | 2 + .../elementwise/elementwise_min_op.h | 4 +- .../fluid/tests/unittests/test_fmax_op.py | 72 ++++++++++++++++++ .../fluid/tests/unittests/test_fmin_op.py | 73 +++++++++++++++++++ .../white_list/no_grad_set_white_list.py | 2 + python/paddle/tensor/math.py | 2 + 10 files changed, 222 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index b34466619eb2f..91dae22fc5d50 100644 --- a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -17,6 +17,7 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/hostdevice.h" +#include "paddle/fluid/platform/bfloat16.h" namespace paddle { namespace operators { @@ -121,6 +122,17 @@ struct FMaxFunctor { } }; +template <> +struct FMaxFunctor { + inline HOSTDEVICE paddle::platform::float16 operator()(const paddle::platform::float16& a, + const paddle::platform::float16& b) const { + float float_a = static_cast(a); + float float_b = static_cast(b); + auto result = std::fmax(float_a, float_b); + return static_cast(result); + } +}; + // Fmin template struct FMinFunctor { @@ -129,5 +141,16 @@ struct FMinFunctor { } }; +template <> +struct FMinFunctor { + inline HOSTDEVICE paddle::platform::float16 operator()(const paddle::platform::float16& a, + const paddle::platform::float16& b) const { + float float_a = static_cast(a); + float float_b = static_cast(b); + auto result = std::fmin(float_a, float_b); + return static_cast(result); + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cc b/paddle/fluid/operators/elementwise/elementwise_max_op.cc index 16fa73ca40c4c..0cc864e8d6e37 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cc @@ -53,6 +53,25 @@ class ElementwiseMaxOpMaker : public ElementwiseOpMaker { } }; +class ElementwiseFMaxOpMaker : public ElementwiseOpMaker { + protected: + std::string GetName() const override { return "FMax"; } + std::string GetEquation() const override { return "Out = fmax(X, Y)"; } + + void AddInputX() override { + AddInput("X", "The first tensor holding the elements to be compared."); + } + + void AddInputY() override { + AddInput("Y", "The second tensor holding the elements to be compared."); + } + + std::string GetOpFuntionality() const override { + return "Compare two tensors and returns a new tensor containing the " + "element-wise maxima."; + } +}; + template class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker { public: @@ -122,7 +141,7 @@ REGISTER_OP_VERSION(elementwise_max) 1.0f)); REGISTER_OPERATOR(elementwise_fmax, ops::ElementwiseOp, - ops::ElementwiseMaxOpMaker, ops::ElementwiseOpInferVarType, + ops::ElementwiseFMaxOpMaker, ops::ElementwiseOpInferVarType, ops::ElementwiseFMaxGradOpMaker, ops::ElementwiseFMaxGradOpMaker); @@ -131,21 +150,14 @@ REGISTER_OPERATOR(elementwise_fmax_grad, ops::ElementwiseOpGrad); REGISTER_OP_CPU_KERNEL( elementwise_fmax, ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel); REGISTER_OP_CPU_KERNEL( elementwise_fmax_grad, ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, ops::ElementwiseFMaxGradKernel, ops::ElementwiseFMaxGradKernel, ops::ElementwiseFMaxGradKernel); - -REGISTER_OP_VERSION(elementwise_fmax) - .AddCheckpoint( - R"ROC(Register elementwise_fmax for adding the attribute of Scale_y)ROC", - paddle::framework::compatible::OpVersionDesc().NewAttr( - "Scale_y", - "In order to support the function of scaling the input Y when " - "using the operator of elementwise_fmax.", - 1.0f)); \ No newline at end of file diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index d350491b06f83..b0708acef1998 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ 
b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -60,12 +60,14 @@ REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL( elementwise_fmax, ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel); REGISTER_OP_CUDA_KERNEL( elementwise_fmax_grad, ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, ops::ElementwiseFMaxGradKernel, ops::ElementwiseFMaxGradKernel, ops::ElementwiseFMaxGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cc b/paddle/fluid/operators/elementwise/elementwise_min_op.cc index 5ceece83d0372..3e04622b02cd7 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cc @@ -53,6 +53,25 @@ class ElementwiseMinOpMaker : public ElementwiseOpMaker { } }; +class ElementwiseFMinOpMaker : public ElementwiseOpMaker { + protected: + std::string GetName() const override { return "FMin"; } + std::string GetEquation() const override { return "Out = fmin(X, Y)"; } + + void AddInputX() override { + AddInput("X", "The first tensor holding the elements to be compared."); + } + + void AddInputY() override { + AddInput("Y", "The second tensor holding the elements to be compared."); + } + + std::string GetOpFuntionality() const override { + return "Compare two tensors and returns a new tensor containing the " + "element-wise minima."; + } +}; + template class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker { public: @@ -122,7 +141,7 @@ REGISTER_OP_VERSION(elementwise_min) 1.0f)); REGISTER_OPERATOR(elementwise_fmin, ops::ElementwiseOp, - ops::ElementwiseMinOpMaker, ops::ElementwiseOpInferVarType, + ops::ElementwiseFMinOpMaker, ops::ElementwiseOpInferVarType, ops::ElementwiseFMinGradOpMaker, ops::ElementwiseFMinGradOpMaker); @@ -131,21 +150,14 @@ REGISTER_OPERATOR(elementwise_fmin_grad, ops::ElementwiseOpGrad); REGISTER_OP_CPU_KERNEL( elementwise_fmin, ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, ops::ElementwiseFMinKernel, ops::ElementwiseFMinKernel, ops::ElementwiseFMinKernel); REGISTER_OP_CPU_KERNEL( elementwise_fmin_grad, ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, ops::ElementwiseFMinGradKernel, ops::ElementwiseFMinGradKernel, ops::ElementwiseFMinGradKernel); - -REGISTER_OP_VERSION(elementwise_fmin) - .AddCheckpoint( - R"ROC(Register elementwise_min for adding the attribute of Scale_y)ROC", - paddle::framework::compatible::OpVersionDesc().NewAttr( - "Scale_y", - "In order to support the function of scaling the input Y when " - "using the operator of elementwise_fmin.", - 1.0f)); diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index 5c1adab4b1264..e9d8197fce36c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -56,12 +56,14 @@ REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL( elementwise_fmin, ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, ops::ElementwiseFMinKernel, ops::ElementwiseFMinKernel, ops::ElementwiseFMinKernel); REGISTER_OP_CUDA_KERNEL( elementwise_fmin_grad, ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, ops::ElementwiseFMinGradKernel, ops::ElementwiseFMinGradKernel, ops::ElementwiseFMinGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h index 
151a48d31fd73..a61d59dcd4e43 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -46,8 +46,8 @@ class ElementwiseFMinKernel : public framework::OpKernel { z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - MinFunctor(), z); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + FMinFunctor(), z); } }; diff --git a/python/paddle/fluid/tests/unittests/test_fmax_op.py b/python/paddle/fluid/tests/unittests/test_fmax_op.py index 1bc549c7151a9..40c5b03edd309 100644 --- a/python/paddle/fluid/tests/unittests/test_fmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_fmax_op.py @@ -18,10 +18,13 @@ import numpy as np import paddle import paddle.fluid.core as core +from op_test import OpTest class ApiFMaxTest(unittest.TestCase): + """ApiFMaxTest""" def setUp(self): + """setUp""" if core.is_compiled_with_cuda(): self.place = core.CUDAPlace(0) else: @@ -40,6 +43,7 @@ def setUp(self): self.np_expected4 = np.fmax(self.input_b, self.input_c) def test_static_api(self): + """test_static_api""" paddle.enable_static() with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): @@ -86,6 +90,7 @@ def test_static_api(self): self.assertTrue(np.allclose(res, self.np_expected4)) def test_dynamic_api(self): + """test_dynamic_api""" paddle.disable_static() x = paddle.to_tensor(self.input_x) y = paddle.to_tensor(self.input_y) @@ -111,3 +116,70 @@ def test_dynamic_api(self): res = paddle.fmax(b, c) res = res.numpy() self.assertTrue(np.allclose(res, self.np_expected4)) + + +class TestElementwiseFmaxOp(OpTest): + """TestElementwiseFmaxOp""" + def setUp(self): + """setUp""" + self.op_type = "elementwise_fmax" + # If x and y have the same value, the max() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. + x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + sgn = np.random.choice([-1, 1], [13, 17]).astype("float64") + y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.fmax(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + """test_check_output""" + self.check_output() + + def test_check_grad_normal(self): + """test_check_grad_normal""" + self.check_grad(['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + """test_check_grad_ingore_x""" + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + """test_check_grad_ingore_y""" + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + + +class TestElementwiseFmax2Op(OpTest): + """TestElementwiseFmax2Op""" + def setUp(self): + """setUp""" + self.op_type = "elementwise_fmax" + # If x and y have the same value, the max() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. 
+ x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + sgn = np.random.choice([-1, 1], [13, 17]).astype("float64") + y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64") + y[2, 10:] = np.nan + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.fmax(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + """test_check_output""" + self.check_output() + + def test_check_grad_normal(self): + """test_check_grad_normal""" + self.check_grad(['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + """test_check_grad_ingore_x""" + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + """test_check_grad_ingore_y""" + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) diff --git a/python/paddle/fluid/tests/unittests/test_fmin_op.py b/python/paddle/fluid/tests/unittests/test_fmin_op.py index a1a25cce29af9..3dead6d37f214 100644 --- a/python/paddle/fluid/tests/unittests/test_fmin_op.py +++ b/python/paddle/fluid/tests/unittests/test_fmin_op.py @@ -18,10 +18,14 @@ import numpy as np import paddle import paddle.fluid.core as core +from op_test import OpTest +paddle.enable_static() class ApiFMinTest(unittest.TestCase): + """ApiFMinTest""" def setUp(self): + """setUp""" if core.is_compiled_with_cuda(): self.place = core.CUDAPlace(0) else: @@ -40,6 +44,7 @@ def setUp(self): self.np_expected4 = np.fmin(self.input_b, self.input_c) def test_static_api(self): + """test_static_api""" paddle.enable_static() with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): @@ -86,6 +91,7 @@ def test_static_api(self): self.assertTrue(np.allclose(res, self.np_expected4)) def test_dynamic_api(self): + """test_dynamic_api""" paddle.disable_static() x = paddle.to_tensor(self.input_x) y = paddle.to_tensor(self.input_y) @@ -111,3 +117,70 @@ def test_dynamic_api(self): res = paddle.fmin(b, c) res = res.numpy() self.assertTrue(np.allclose(res, self.np_expected4)) + + +class TestElementwiseFminOp(OpTest): + """TestElementwiseFminOp""" + def setUp(self): + """setUp""" + self.op_type = "elementwise_fmin" + # If x and y have the same value, the min() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. + x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + sgn = np.random.choice([-1, 1], [13, 17]).astype("float64") + y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.fmin(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + """test_check_output""" + self.check_output() + + def test_check_grad_normal(self): + """test_check_grad_normal""" + self.check_grad(['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + """test_check_grad_ingore_x""" + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + """test_check_grad_ingore_y""" + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + + +class TestElementwiseFmin2Op(OpTest): + """TestElementwiseFmin2Op""" + def setUp(self): + """setUp""" + self.op_type = "elementwise_fmin" + # If x and y have the same value, the min() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. 
+ x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + sgn = np.random.choice([-1, 1], [13, 17]).astype("float64") + y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64") + y[2, 10:] = np.nan + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.fmin(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + """test_check_output""" + self.check_output() + + def test_check_grad_normal(self): + """test_check_grad_normal""" + self.check_grad(['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + """test_check_grad_ingore_x""" + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + """test_check_grad_ingore_y""" + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) diff --git a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py index 29374a9796504..725ad4e93824f 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py @@ -42,6 +42,8 @@ 'elementwise_mul', 'elementwise_sub', 'elementwise_pow', + 'elementwise_fmin', + 'elementwise_fmax', 'filter_by_instag', 'fused_elemwise_activation', 'fused_emb_seq_pool', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 6ed12e4b93a8e..1bd6aeb0dd446 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -634,6 +634,7 @@ def minimum(x, y, name=None): def fmax(x, y, name=None): """ Compares the elements at the corresponding positions of the two tensors and returns a new tensor containing the maximum value of the element. + If one of them is a nan value, the other value is directly returned, if both are nan values, then the first nan value is returned. The equation is: .. math:: @@ -694,6 +695,7 @@ def fmax(x, y, name=None): def fmin(x, y, name=None): """ Compares the elements at the corresponding positions of the two tensors and returns a new tensor containing the minimum value of the element. + If one of them is a nan value, the other value is directly returned, if both are nan values, then the first nan value is returned. The equation is: .. 
math:: From 3fff0fafd330365f54df651f8404cb6bd4732872 Mon Sep 17 00:00:00 2001 From: lijiaqi Date: Mon, 6 Dec 2021 13:39:37 +0000 Subject: [PATCH 04/11] fix bug --- .../fluid/operators/elementwise/elementwise_max_op.h | 10 +++++----- python/paddle/fluid/tests/unittests/test_fmin_op.py | 6 +++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h index 2995af05018ea..ba379745a2d29 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h @@ -47,11 +47,10 @@ class ElementwiseFMaxKernel : public framework::OpKernel { z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - FMaxFunctor(), z); + FMaxFunctor(), z); } }; - template struct MaxGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { @@ -88,14 +87,14 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel { template struct FMaxGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast((x >= y) || isnan(y)); + return dout * static_cast((x >= y) || std::isnan(y)); } }; template struct FMaxGradDy { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast (!((x >= y) || isnan(y))); + return dout * static_cast(!((x >= y) || std::isnan(y))); } }; @@ -114,7 +113,8 @@ class ElementwiseFMaxGradKernel : public ElemwiseGradKernel { auto* out = dout; // Fake out, not used int axis = ctx.Attr("axis"); ElemwiseGradCompute, FMaxGradDy>( - ctx, *x, *y, *out, *dout, axis, dx, dy, FMaxGradDx(), FMaxGradDy()); + ctx, *x, *y, *out, *dout, axis, dx, dy, FMaxGradDx(), + FMaxGradDy()); } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/test_fmin_op.py b/python/paddle/fluid/tests/unittests/test_fmin_op.py index 3dead6d37f214..3339bd3d4a64e 100644 --- a/python/paddle/fluid/tests/unittests/test_fmin_op.py +++ b/python/paddle/fluid/tests/unittests/test_fmin_op.py @@ -22,8 +22,10 @@ paddle.enable_static() + class ApiFMinTest(unittest.TestCase): """ApiFMinTest""" + def setUp(self): """setUp""" if core.is_compiled_with_cuda(): @@ -121,6 +123,7 @@ def test_dynamic_api(self): class TestElementwiseFminOp(OpTest): """TestElementwiseFminOp""" + def setUp(self): """setUp""" self.op_type = "elementwise_fmin" @@ -154,6 +157,7 @@ def test_check_grad_ingore_y(self): class TestElementwiseFmin2Op(OpTest): """TestElementwiseFmin2Op""" + def setUp(self): """setUp""" self.op_type = "elementwise_fmin" @@ -163,7 +167,7 @@ def setUp(self): x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") sgn = np.random.choice([-1, 1], [13, 17]).astype("float64") y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64") - y[2, 10:] = np.nan + y[2, 10:] = np.nan self.inputs = {'X': x, 'Y': y} self.outputs = {'Out': np.fmin(self.inputs['X'], self.inputs['Y'])} From 882dce886e53ab19040d182247eea261832e4694 Mon Sep 17 00:00:00 2001 From: lijiaqi Date: Tue, 7 Dec 2021 04:19:28 +0000 Subject: [PATCH 05/11] fix bugs --- .../elementwise/elementwise_functor.h | 35 ++++++++++--------- .../elementwise/elementwise_max_op.h | 4 +-- .../elementwise/elementwise_min_op.h | 9 ++--- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index 91dae22fc5d50..6e53af41b657c 100644 --- 
a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -17,7 +17,6 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/hostdevice.h" -#include "paddle/fluid/platform/bfloat16.h" namespace paddle { namespace operators { @@ -118,19 +117,20 @@ struct MinFunctor { template struct FMaxFunctor { inline HOSTDEVICE T operator()(const T& a, const T& b) const { - return std::fmax(a, b); - } + return std::fmax(a, b); + } }; template <> struct FMaxFunctor { - inline HOSTDEVICE paddle::platform::float16 operator()(const paddle::platform::float16& a, - const paddle::platform::float16& b) const { - float float_a = static_cast(a); - float float_b = static_cast(b); - auto result = std::fmax(float_a, float_b); - return static_cast(result); - } + inline HOSTDEVICE paddle::platform::float16 operator()( + const paddle::platform::float16& a, + const paddle::platform::float16& b) const { + float float_a = static_cast(a); + float float_b = static_cast(b); + auto result = std::fmax(float_a, float_b); + return static_cast(result); + } }; // Fmin @@ -143,13 +143,14 @@ struct FMinFunctor { template <> struct FMinFunctor { - inline HOSTDEVICE paddle::platform::float16 operator()(const paddle::platform::float16& a, - const paddle::platform::float16& b) const { - float float_a = static_cast(a); - float float_b = static_cast(b); - auto result = std::fmin(float_a, float_b); - return static_cast(result); - } + inline HOSTDEVICE paddle::platform::float16 operator()( + const paddle::platform::float16& a, + const paddle::platform::float16& b) const { + float float_a = static_cast(a); + float float_b = static_cast(b); + auto result = std::fmin(float_a, float_b); + return static_cast(result); + } }; } // namespace operators diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h index ba379745a2d29..0dc30b37cfad5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h @@ -87,14 +87,14 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel { template struct FMaxGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast((x >= y) || std::isnan(y)); + return dout * static_cast((x >= y) || ::isnan(y)); } }; template struct FMaxGradDy { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast(!((x >= y) || std::isnan(y))); + return dout * static_cast(!((x >= y) || ::isnan(y))); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h index a61d59dcd4e43..e86e3bb3e340f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -47,7 +47,7 @@ class ElementwiseFMinKernel : public framework::OpKernel { z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - FMinFunctor(), z); + FMinFunctor(), z); } }; @@ -87,14 +87,14 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel { template struct FMinGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast((x <= y) || isnan(y)); + return dout * static_cast((x <= y) || ::isnan(y)); } }; template struct FMinGradDy { HOSTDEVICE T operator()(T x, T y, T out, T 
dout) const { - return dout * static_cast (!((x <= y) || isnan(y))); + return dout * static_cast(!((x <= y) || ::isnan(y))); } }; @@ -113,7 +113,8 @@ class ElementwiseFMinGradKernel : public ElemwiseGradKernel { auto* out = dout; // Fake out, not used int axis = ctx.Attr("axis"); ElemwiseGradCompute, FMinGradDy>( - ctx, *x, *y, *out, *dout, axis, dx, dy, FMinGradDx(), FMinGradDy()); + ctx, *x, *y, *out, *dout, axis, dx, dy, FMinGradDx(), + FMinGradDy()); } }; } // namespace operators From 3869a11269325670bb14291b86288cf18acd697b Mon Sep 17 00:00:00 2001 From: lijiaqi Date: Tue, 7 Dec 2021 06:46:09 +0000 Subject: [PATCH 06/11] fix bugs second --- paddle/fluid/operators/elementwise/elementwise_max_op.cc | 9 ++++++--- paddle/fluid/operators/elementwise/elementwise_max_op.cu | 9 ++++++--- paddle/fluid/operators/elementwise/elementwise_max_op.h | 4 ++-- paddle/fluid/operators/elementwise/elementwise_min_op.h | 4 ++-- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cc b/paddle/fluid/operators/elementwise/elementwise_max_op.cc index 0cc864e8d6e37..23ffe8c5db0c1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cc @@ -150,14 +150,17 @@ REGISTER_OPERATOR(elementwise_fmax_grad, ops::ElementwiseOpGrad); REGISTER_OP_CPU_KERNEL( elementwise_fmax, ops::ElementwiseFMaxKernel, - ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel); REGISTER_OP_CPU_KERNEL( elementwise_fmax_grad, ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, ops::ElementwiseFMaxGradKernel, ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel); + ops::ElementwiseFMaxGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index b0708acef1998..eb6f78bf270ad 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -60,14 +60,17 @@ REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL( elementwise_fmax, ops::ElementwiseFMaxKernel, - ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel, ops::ElementwiseFMaxKernel); REGISTER_OP_CUDA_KERNEL( elementwise_fmax_grad, ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, ops::ElementwiseFMaxGradKernel, ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel); + ops::ElementwiseFMaxGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h index 0dc30b37cfad5..eaa4cdf770c34 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h @@ -87,14 +87,14 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel { template struct FMaxGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast((x >= y) || ::isnan(y)); + return dout * static_cast((x >= y) || (std::isnan(y))); } }; template struct FMaxGradDy { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast(!((x >= y) || ::isnan(y))); + return dout * static_cast(!((x >= y) || (std::isnan(y)))); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h 
b/paddle/fluid/operators/elementwise/elementwise_min_op.h index e86e3bb3e340f..69468ad52c4ad 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -87,14 +87,14 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel { template struct FMinGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast((x <= y) || ::isnan(y)); + return dout * static_cast((x <= y) || (std::isnan(y))); } }; template struct FMinGradDy { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast(!((x <= y) || ::isnan(y))); + return dout * static_cast(!((x <= y) || (std::isnan(y)))); } }; From 20e801ffa0225b203de9028f9708d9109e8dc636 Mon Sep 17 00:00:00 2001 From: lijiaqi Date: Tue, 7 Dec 2021 08:46:59 +0000 Subject: [PATCH 07/11] fix bugs third --- paddle/fluid/operators/elementwise/elementwise_max_op.h | 5 +++-- paddle/fluid/operators/elementwise/elementwise_min_op.h | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h index eaa4cdf770c34..561ec97e6621c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" @@ -87,14 +88,14 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel { template struct FMaxGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast((x >= y) || (std::isnan(y))); + return dout * static_cast((x >= y) || isnan(y)); } }; template struct FMaxGradDy { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast(!((x >= y) || (std::isnan(y)))); + return dout * static_cast(!((x >= y) || isnan(y))); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h index 69468ad52c4ad..228d6fa9c92ce 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -14,6 +14,7 @@ limitations under the License. 
@@ -14,6 +14,7 @@ limitations under the License. */

 #pragma once

+#include <cmath>
 #include "paddle/fluid/operators/elementwise/elementwise_functor.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
@@ -87,14 +88,14 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
 template <typename T>
 struct FMinGradDx {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-    return dout * static_cast<T>((x <= y) || (std::isnan(y)));
+    return dout * static_cast<T>((x <= y) || isnan(y));
   }
 };

 template <typename T>
 struct FMinGradDy {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-    return dout * static_cast<T>(!((x <= y) || (std::isnan(y))));
+    return dout * static_cast<T>(!((x <= y) || isnan(y)));
   }
 };

From 0b061189c60319b1b7e076e8e673e596501b2eeb Mon Sep 17 00:00:00 2001
From: lijiaqi
Date: Tue, 7 Dec 2021 09:40:30 +0000
Subject: [PATCH 08/11] fix bugs, fourth round: add template specializations
 for integer types

---
 .../elementwise/elementwise_max_op.h | 30 +++++++++++++++++++
 .../elementwise/elementwise_min_op.h | 30 +++++++++++++++++++
 2 files changed, 60 insertions(+)

diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h
index 561ec97e6621c..bba6dd80db153 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h
@@ -92,6 +92,21 @@ struct FMaxGradDx {
   }
 };

+template <>
+struct FMaxGradDx<int> {
+  HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
+    return dout * static_cast<int>((x >= y));
+  }
+};
+
+template <>
+struct FMaxGradDx<int64_t> {
+  HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
+                                int64_t dout) const {
+    return dout * static_cast<int64_t>((x >= y));
+  }
+};
+
 template <typename T>
 struct FMaxGradDy {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
@@ -99,6 +114,21 @@ struct FMaxGradDy {
   }
 };

+template <>
+struct FMaxGradDy<int64_t> {
+  HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
+                                int64_t dout) const {
+    return dout * static_cast<int64_t>(!((x >= y)));
+  }
+};
+
+template <>
+struct FMaxGradDy<int> {
+  HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
+    return dout * static_cast<int>(!((x >= y)));
+  }
+};
+
 template <typename DeviceContext, typename T>
 class ElementwiseFMaxGradKernel : public ElemwiseGradKernel<T> {
  public:
diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h
index 228d6fa9c92ce..a981a80fd1081 100644
--- a/paddle/fluid/operators/elementwise/elementwise_min_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h
@@ -92,6 +92,21 @@ struct FMinGradDx {
   }
 };

+template <>
+struct FMinGradDx<int> {
+  HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
+    return dout * static_cast<int>((x <= y));
+  }
+};
+
+template <>
+struct FMinGradDx<int64_t> {
+  HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
+                                int64_t dout) const {
+    return dout * static_cast<int64_t>((x <= y));
+  }
+};
+
 template <typename T>
 struct FMinGradDy {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
@@ -99,6 +114,21 @@ struct FMinGradDy {
   }
 };

+template <>
+struct FMinGradDy<int> {
+  HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
+    return dout * static_cast<int>(!((x <= y)));
+  }
+};
+
+template <>
+struct FMinGradDy<int64_t> {
+  HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
+                                int64_t dout) const {
+    return dout * static_cast<int64_t>(!((x <= y)));
+  }
+};
+
 template <typename DeviceContext, typename T>
 class ElementwiseFMinGradKernel : public ElemwiseGradKernel<T> {
  public:
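The integer specializations above drop the NaN test on purpose: int and int64_t tensors cannot hold NaN, so the gradient is routed purely by the (x >= y) / (x <= y) comparison, with dX taking the gradient where x wins and dY taking the remainder. A rough NumPy sketch of the floating-point gradient rule that the FMaxGradDx/FMaxGradDy functors implement (illustrative only; fmax_grad is a made-up helper, not part of the operator code):

    import numpy as np

    def fmax_grad(x, y, dout):
        # Mirrors FMaxGradDx / FMaxGradDy: x receives the gradient where
        # x >= y or y is NaN; y receives it everywhere else.
        mask = (x >= y) | np.isnan(y)
        return dout * mask, dout * ~mask

    x = np.array([1.0, np.nan, 3.0])
    y = np.array([2.0, 5.0, np.nan])
    dx, dy = fmax_grad(x, y, np.ones_like(x))
    # dx -> [0., 0., 1.], dy -> [1., 1., 0.]
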

From 165dffa678037f38a9b54a9bd06eacde73242847 Mon Sep 17 00:00:00 2001
From: lijiaqi
Date: Tue, 7 Dec 2021 12:15:37 +0000
Subject: [PATCH 09/11] fix bugs, fifth round: add float16 template
 specializations

---
 .../elementwise/elementwise_max_op.h      | 21 +++++++++++++++
 .../elementwise/elementwise_min_op.cc     |  9 +++++---
 .../elementwise/elementwise_min_op.cu     |  9 +++++---
 .../elementwise/elementwise_min_op.h      | 21 +++++++++++++++
 .../fluid/tests/unittests/test_fmax_op.py |  6 +++++-
 .../fluid/tests/unittests/test_fmin_op.py |  1 +
 6 files changed, 60 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h
index bba6dd80db153..acb212e992a1d 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_functor.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/fluid/platform/eigen_ext.h"

 namespace paddle {
 namespace operators {
@@ -92,6 +93,16 @@ struct FMaxGradDx {
   }
 };

+template <>
+struct FMaxGradDx<paddle::platform::float16> {
+  HOSTDEVICE paddle::platform::float16 operator()(
+      paddle::platform::float16 x, paddle::platform::float16 y,
+      paddle::platform::float16 out, paddle::platform::float16 dout) const {
+    return dout * static_cast<paddle::platform::float16>(
+                      (x >= y) || paddle::platform::isnan(y));
+  }
+};
+
 template <>
 struct FMaxGradDx<int> {
   HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
@@ -114,6 +125,16 @@ struct FMaxGradDy {
   }
 };

+template <>
+struct FMaxGradDy<paddle::platform::float16> {
+  HOSTDEVICE paddle::platform::float16 operator()(
+      paddle::platform::float16 x, paddle::platform::float16 y,
+      paddle::platform::float16 out, paddle::platform::float16 dout) const {
+    return dout * static_cast<paddle::platform::float16>(
+                      !((x >= y) || paddle::platform::isnan(y)));
+  }
+};
+
 template <>
 struct FMaxGradDy<int64_t> {
   HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cc b/paddle/fluid/operators/elementwise/elementwise_min_op.cc
index 3e04622b02cd7..157ad45a1759b 100644
--- a/paddle/fluid/operators/elementwise/elementwise_min_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cc
@@ -150,14 +150,17 @@ REGISTER_OPERATOR(elementwise_fmin_grad, ops::ElementwiseOpGrad);
 REGISTER_OP_CPU_KERNEL(
     elementwise_fmin,
     ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, paddle::platform::float16>,
+    ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext,
+                               paddle::platform::float16>,
     ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     elementwise_fmin_grad,
     ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, paddle::platform::float16>,
+    ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext,
+                                   paddle::platform::float16>,
     ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext,
+                                   int64_t>);
diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu
index e9d8197fce36c..a51398640579b 100644
--- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu
@@ -56,14 +56,17 @@ REGISTER_OP_CUDA_KERNEL(
 REGISTER_OP_CUDA_KERNEL(
     elementwise_fmin,
     ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, paddle::platform::float16>,
+    ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext,
+                               paddle::platform::float16>,
     ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_fmin_grad,
     ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, paddle::platform::float16>,
+    ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext,
+                                   paddle::platform::float16>,
     ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+    ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext,
+                                   int64_t>);
diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h
index a981a80fd1081..2f96ef747708b 100644
--- a/paddle/fluid/operators/elementwise/elementwise_min_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_functor.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/fluid/platform/eigen_ext.h"

 namespace paddle {
 namespace operators {
@@ -92,6 +93,16 @@ struct FMinGradDx {
   }
 };

+template <>
+struct FMinGradDx<paddle::platform::float16> {
+  HOSTDEVICE paddle::platform::float16 operator()(
+      paddle::platform::float16 x, paddle::platform::float16 y,
+      paddle::platform::float16 out, paddle::platform::float16 dout) const {
+    return dout * static_cast<paddle::platform::float16>(
+                      (x <= y) || paddle::platform::isnan(y));
+  }
+};
+
 template <>
 struct FMinGradDx<int> {
   HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
@@ -114,6 +125,16 @@ struct FMinGradDy {
   }
 };

+template <>
+struct FMinGradDy<paddle::platform::float16> {
+  HOSTDEVICE paddle::platform::float16 operator()(
+      paddle::platform::float16 x, paddle::platform::float16 y,
+      paddle::platform::float16 out, paddle::platform::float16 dout) const {
+    return dout * static_cast<paddle::platform::float16>(
+                      !((x <= y) || paddle::platform::isnan(y)));
+  }
+};
+
 template <>
 struct FMinGradDy<int> {
   HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
diff --git a/python/paddle/fluid/tests/unittests/test_fmax_op.py b/python/paddle/fluid/tests/unittests/test_fmax_op.py
index 40c5b03edd309..3981d63c00582 100644
--- a/python/paddle/fluid/tests/unittests/test_fmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fmax_op.py
@@ -23,6 +23,7 @@

 class ApiFMaxTest(unittest.TestCase):
     """ApiFMaxTest"""
+
     def setUp(self):
         """setUp"""
         if core.is_compiled_with_cuda():
@@ -120,6 +121,7 @@ def test_dynamic_api(self):

 class TestElementwiseFmaxOp(OpTest):
     """TestElementwiseFmaxOp"""
+
     def setUp(self):
         """setUp"""
         self.op_type = "elementwise_fmax"
@@ -153,6 +155,7 @@ def test_check_grad_ingore_y(self):

 class TestElementwiseFmax2Op(OpTest):
     """TestElementwiseFmax2Op"""
+
     def setUp(self):
         """setUp"""
         self.op_type = "elementwise_fmax"
@@ -162,7 +165,8 @@ def setUp(self):
         x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
         sgn = np.random.choice([-1, 1], [13, 17]).astype("float64")
         y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64")
-        y[2, 10:] = np.nan
+
+        y[2, 10:] = np.nan

         self.inputs = {'X': x, 'Y': y}
         self.outputs = {'Out': np.fmax(self.inputs['X'], self.inputs['Y'])}
diff --git a/python/paddle/fluid/tests/unittests/test_fmin_op.py b/python/paddle/fluid/tests/unittests/test_fmin_op.py
index 3339bd3d4a64e..5cdf096be6708 100644
--- a/python/paddle/fluid/tests/unittests/test_fmin_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fmin_op.py
@@ -167,6 +167,7 @@ def setUp(self):
         x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
         sgn = np.random.choice([-1, 1], [13, 17]).astype("float64")
         y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64")
+        y[2, 10:] = np.nan

         self.inputs = {'X': x, 'Y': y}
         self.outputs = {'Out': np.fmin(self.inputs['X'], self.inputs['Y'])}
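The patch above also seeds NaN values into the unit tests (y[2, 10:] = np.nan) so the forward results are checked against NumPy's NaN-ignoring fmax/fmin semantics. A small standalone check of the expectation those tests rely on (a sketch, independent of the test files):

    import numpy as np

    x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
    y = x + np.random.uniform(0.1, 1, [13, 17]).astype("float64")
    y[2, 10:] = np.nan

    # np.fmax / np.fmin ignore a NaN operand, so the NaN slots fall back to x.
    assert np.array_equal(np.fmax(x, y)[2, 10:], x[2, 10:])
    assert np.array_equal(np.fmin(x, y)[2, 10:], x[2, 10:])
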

From 91587cff60e22005cde445d2db079257a005e25e Mon Sep 17 00:00:00 2001
From: lijiaqi
Date: Fri, 10 Dec 2021 12:19:27 +0000
Subject: [PATCH 10/11] describe opmaker comments in more detail

---
 paddle/fluid/operators/elementwise/elementwise_max_op.cc | 4 +++-
 paddle/fluid/operators/elementwise/elementwise_min_op.cc | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cc b/paddle/fluid/operators/elementwise/elementwise_max_op.cc
index 23ffe8c5db0c1..e0686e815459a 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cc
@@ -68,7 +68,9 @@ class ElementwiseFMaxOpMaker : public ElementwiseOpMaker {

   std::string GetOpFuntionality() const override {
     return "Compare two tensors and returns a new tensor containing the "
-           "element-wise maxima.";
+           "element-wise maxima. If the element of one tensor is NaN, "
+           "the element of the other tensor is returned; if both are NaN, "
+           "the first NaN is returned.";
   }
 };

diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cc b/paddle/fluid/operators/elementwise/elementwise_min_op.cc
index 157ad45a1759b..1448520eca18f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_min_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cc
@@ -68,7 +68,9 @@ class ElementwiseFMinOpMaker : public ElementwiseOpMaker {

   std::string GetOpFuntionality() const override {
     return "Compare two tensors and returns a new tensor containing the "
-           "element-wise minima.";
+           "element-wise minima. If the element of one tensor is NaN, "
+           "the element of the other tensor is returned; if both are NaN, "
+           "the first NaN is returned.";
   }
 };

From f21bbffbb3c277ccec8c1260f96264bd0f0a5c9a Mon Sep 17 00:00:00 2001
From: lijiaqi
Date: Mon, 13 Dec 2021 03:13:59 +0000
Subject: [PATCH 11/11] move fmax and fmin to the end of the __all__ list

---
 python/paddle/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 99d92d60d5103..229207676a46c 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -460,8 +460,6 @@
     'pow',
     'zeros_like',
     'maximum',
-    'fmax',
-    'fmin',
     'topk',
     'index_select',
     'CPUPlace',
@@ -570,4 +568,6 @@
     'as_real',
     'diff',
     'angle',
+    'fmax',
+    'fmin',
 ]
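With fmax and fmin exported at the end of paddle's __all__ list, the new ops are callable from the top-level namespace. A minimal usage sketch, assuming a Paddle build that contains these patches and runs in dynamic graph mode:

    import paddle

    x = paddle.to_tensor([1.0, float("nan"), 5.0])
    y = paddle.to_tensor([2.0, 3.0, float("nan")])

    # A NaN in one input falls back to the other input's value.
    print(paddle.fmax(x, y).numpy())  # [2. 3. 5.]
    print(paddle.fmin(x, y).numpy())  # [1. 3. 5.]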