diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 0e476d663600f..289f799abb763 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -953,6 +953,16 @@ kernel : func : frame_grad +- backward_op : gammaincc_grad + forward : gammaincc(Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(y_grad) + infer_meta : + func : UnchangedInferMeta + param : [y] + kernel : + func : gammaincc_grad + - backward_op : gammaln_grad forward : gammaln(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 3b253dcad7ae6..e8d65233e2ca4 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1105,6 +1105,17 @@ backend : place interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : gammaincc + args : (Tensor x, Tensor y) + output : Tensor(out) + infer_meta : + func : ElementwiseInferMeta + param : [x, y] + kernel : + func : gammaincc + inplace: (x -> out) + backward : gammaincc_grad + - op : gammaln args : (Tensor x) output : Tensor(out) diff --git a/paddle/phi/kernels/cpu/gammaincc_grad_kernel.cc b/paddle/phi/kernels/cpu/gammaincc_grad_kernel.cc new file mode 100644 index 0000000000000..c6b3c83a6b906 --- /dev/null +++ b/paddle/phi/kernels/cpu/gammaincc_grad_kernel.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaincc_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gammaincc_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + gammaincc_grad, CPU, ALL_LAYOUT, phi::GammainccGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/gammaincc_kernel.cc b/paddle/phi/kernels/cpu/gammaincc_kernel.cc new file mode 100644 index 0000000000000..bfe21c24231b1 --- /dev/null +++ b/paddle/phi/kernels/cpu/gammaincc_kernel.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/gammaincc_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gammaincc_kernel_impl.h" + +PD_REGISTER_KERNEL( + gammaincc, CPU, ALL_LAYOUT, phi::GammainccKernel, float, double) {} diff --git a/paddle/phi/kernels/gammaincc_grad_kernel.h b/paddle/phi/kernels/gammaincc_grad_kernel.h new file mode 100644 index 0000000000000..30e046b057564 --- /dev/null +++ b/paddle/phi/kernels/gammaincc_grad_kernel.h @@ -0,0 +1,28 @@ + +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GammainccGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& d_out, + DenseTensor* d_y); +} // namespace phi diff --git a/paddle/phi/kernels/gammaincc_kernel.h b/paddle/phi/kernels/gammaincc_kernel.h new file mode 100644 index 0000000000000..a5960fb33bca2 --- /dev/null +++ b/paddle/phi/kernels/gammaincc_kernel.h @@ -0,0 +1,27 @@ + +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GammainccKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/kernels/gpu/gammaincc_grad_kernel.cu b/paddle/phi/kernels/gpu/gammaincc_grad_kernel.cu new file mode 100644 index 0000000000000..060806ddb1e22 --- /dev/null +++ b/paddle/phi/kernels/gpu/gammaincc_grad_kernel.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/gammaincc_grad_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gammaincc_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + gammaincc_grad, GPU, ALL_LAYOUT, phi::GammainccGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/gammaincc_kernel.cu b/paddle/phi/kernels/gpu/gammaincc_kernel.cu new file mode 100644 index 0000000000000..58f198af2b229 --- /dev/null +++ b/paddle/phi/kernels/gpu/gammaincc_kernel.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaincc_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gammaincc_kernel_impl.h" + +PD_REGISTER_KERNEL( + gammaincc, GPU, ALL_LAYOUT, phi::GammainccKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/gammaincc_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaincc_grad_kernel_impl.h new file mode 100644 index 0000000000000..5a32f7ea46a2b --- /dev/null +++ b/paddle/phi/kernels/impl/gammaincc_grad_kernel_impl.h @@ -0,0 +1,62 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+
+namespace phi {
+template <typename T>
+struct IgammaGradFunctor {
+  IgammaGradFunctor(
+      const T* dout, const T* x, const T* a, T* output, int64_t numel)
+      : dout_(dout), x_(x), a_(a), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_dout = static_cast<MT>(dout_[idx]);
+    const MT mp_x = static_cast<MT>(x_[idx]);
+    const MT mp_a = static_cast<MT>(a_[idx]);
+    const MT mp_a_1 = static_cast<MT>(a_[idx] - 1);
+    output_[idx] = static_cast<T>(mp_dout * -std::exp(-mp_x) *
+                                  std::pow(mp_x, mp_a_1) / std::tgamma(mp_a));
+  }
+
+ private:
+  const T* dout_;
+  const T* x_;
+  const T* a_;
+  T* output_;
+  int64_t numel_;
+};
+
+template <typename T, typename Context>
+void GammainccGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& y,
+                         const DenseTensor& d_out,
+                         DenseTensor* d_y) {
+  auto numel = d_out.numel();
+  auto* dout_data = d_out.data<T>();
+  auto* x_data = x.data<T>();
+  auto* y_data = y.data<T>();
+  auto* dy_data =
+      dev_ctx.template Alloc<T>(d_y, static_cast<size_t>(numel * sizeof(T)));
+  phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
+  IgammaGradFunctor<T> functor(dout_data, y_data, x_data, dy_data, numel);
+  for_range(functor);
+}
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/gammaincc_kernel_impl.h b/paddle/phi/kernels/impl/gammaincc_kernel_impl.h
new file mode 100644
index 0000000000000..db5d0e67d12e4
--- /dev/null
+++ b/paddle/phi/kernels/impl/gammaincc_kernel_impl.h
@@ -0,0 +1,143 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+
+#define MAXLOG 7.09782712893383996732E2
+#define MACHEP 1.11022302462515654042E-16
+
+namespace phi {
+template <typename T>
+HOSTDEVICE T igam(const T a, const T x);
+template <typename T>
+HOSTDEVICE T igamc(const T a, const T x);
+
+template <typename T>
+HOSTDEVICE T igam(const T a, const T x) {
+  if ((x <= T{0}) || (a <= T{0})) return (T{0.0});
+
+  if ((x > T{1.0}) && (x > a)) return (T{1.0} - igamc(a, x));
+
+  /* Compute x**a * exp(-x) / gamma(a) */
+  T ax = a * log(x) - x - std::lgamma(a);
+  if (ax < -MAXLOG) {
+    return (T{0.0});
+  }
+  ax = exp(ax);
+
+  /* power series */
+  T r = a;
+  T c = T{1.0};
+  T ans = T{1.0};
+
+  do {
+    r += T{1.0};
+    c *= x / r;
+    ans += c;
+  } while (c / ans > MACHEP);
+
+  return (ans * ax / a);
+}
+
+template <typename T>
+HOSTDEVICE T igamc(const T a, const T x) {
+  static T big = 4.503599627370496e15;
+  static T biginv = 2.22044604925031308085e-16;
+
+  if ((x <= T{0}) || (a <= T{0})) return (T{1.0});
+
+  if ((x < T{1.0}) || (x < a)) return (T{1.0} - igam(a, x));
+
+  T ax = a * log(x) - x - std::lgamma(a);
+  if (ax < -MAXLOG) {
+    return (T{0.0});
+  }
+  ax = exp(ax);
+
+  /* continued fraction */
+  T y = T{1.0} - a;
+  T z = x + y + T{1.0};
+  T c = T{0.0};
+  T pkm2 = T{1.0};
+  T qkm2 = x;
+  T pkm1 = x + T{1.0};
+  T qkm1 = z * x;
+  T ans = pkm1 / qkm1;
+  T t;
+  do {
+    c += T{1.0};
+    y += T{1.0};
+    z += T{2.0};
+    T yc = y * c;
+    T pk = pkm1 * z - pkm2 * yc;
+    T qk = qkm1 * z - qkm2 * yc;
+    if (qk != T{0}) {
+      T r = pk / qk;
+      t = fabs((ans - r) / r);
+      ans = r;
+    } else {
+      t = T{1.0};
+    }
+    pkm2 = pkm1;
+    pkm1 = pk;
+    qkm2 = qkm1;
+    qkm1 = qk;
+    if (fabs(pk) > big) {
+      pkm2 *= biginv;
+      pkm1 *= biginv;
+      qkm2 *= biginv;
+      qkm1 *= biginv;
+    }
+  } while (t > MACHEP);
+
+  return (ans * ax);
+}
+
+template <typename T>
+struct IgammaFunctor {
+  IgammaFunctor(const T* x, const T* a, T* output, int64_t numel)
+      : x_(x), a_(a), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_x = static_cast<MT>(x_[idx]);
+    const MT mp_a = static_cast<MT>(a_[idx]);
+    output_[idx] = static_cast<T>(igamc<MT>(mp_a, mp_x));
+  }
+
+ private:
+  const T* x_;
+  const T* a_;
+  T* output_;
+  int64_t numel_;
+};
+
+template <typename T, typename Context>
+void GammainccKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const DenseTensor& y,
+                     DenseTensor* out) {
+  auto numel = x.numel();
+  auto* x_data = x.data<T>();
+  auto* y_data = y.data<T>();
+  auto* out_data = dev_ctx.template Alloc<T>(out);
+  phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
+  IgammaFunctor<T> functor(y_data, x_data, out_data, numel);
+  for_range(functor);
+}
+}  // namespace phi
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 2666a1295adb5..4603248f0fd10 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -404,6 +404,10 @@
     frac,
     frac_,
     frexp,
+    gammainc,
+    gammainc_,
+    gammaincc,
+    gammaincc_,
     gammaln,
     gammaln_,
     gcd,
@@ -775,6 +779,10 @@
     'neg_',
     'lgamma',
     'lgamma_',
+    'gammaincc',
+    'gammaincc_',
+    'gammainc',
+    'gammainc_',
     'lerp',
     'erfinv',
     'inner',
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 0ab10993b8aa7..4513bcbdba8f8 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -284,6 +284,10 @@
     frac,
     frac_,
     frexp,
+    gammainc,
+    gammainc_,
+    gammaincc,
+    gammaincc_,
     gammaln,
     gammaln_,
     gcd,
@@ -578,6 +582,10 @@
     'neg_',
     'lgamma',
     'lgamma_',
+    'gammaincc',
+
'gammaincc_', + 'gammainc', + 'gammainc_', 'equal', 'equal_', 'equal_all', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 6e3150b5d1f6d..48f6843ba00c8 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5100,6 +5100,105 @@ def digamma_(x, name=None): return _C_ops.digamma_(x) +def gammaincc(x, y, name=None): + r""" + Computes the regularized upper incomplete gamma function. + + .. math:: Q(x, y) = \frac{1}{\Gamma(x)} \int_{y}^{\infty} t^{x-1} e^{-t} dt + + Args: + x (Tensor): The non-negative argument Tensor. Must be one of the following types: float32, float64. + y (Tensor): The positive parameter Tensor. Must be one of the following types: float32, float64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, the gammaincc of the input Tensor. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> x = paddle.to_tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype="float32") + >>> y = paddle.to_tensor([0, 1, 10, 100, 1000], dtype="float32") + >>> out = paddle.gammaincc(x, y) + >>> print(out) + Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True, + [1. , 0.15729916, 0.00000774, 0. , 0. ]) + """ + if not paddle.all(paddle.greater_equal(x, paddle.zeros_like(x))): + raise ValueError( + "The input argument x must be greater than or equal to 0." + ) + if not paddle.all(paddle.greater_equal(y, paddle.zeros_like(y))): + raise ValueError( + "The input argument y must be greater than or equal to 0." + ) + if in_dynamic_or_pir_mode(): + return _C_ops.gammaincc(x, y) + else: + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'gammaincc') + check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'gammaincc') + helper = LayerHelper('gammaincc', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='gammaincc', inputs={'x': x, 'y': y}, outputs={'out': out} + ) + return out + + +@inplace_apis_in_dygraph_only +def gammaincc_(x, y, name=None): + r""" + Inplace version of ``gammaincc`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_gammaincc`. + """ + if in_dynamic_mode(): + return _C_ops.gammaincc_(x, y) + + +def gammainc(x, y, name=None): + r""" + Computes the regularized lower incomplete gamma function. + + .. math:: P(x, y) = \frac{1}{\Gamma(x)} \int_{0}^{y} t^{x-1} e^{-t} dt + + Args: + x (Tensor): The non-negative argument Tensor. Must be one of the following types: float32, float64. + y (Tensor): The positive parameter Tensor. Must be one of the following types: float32, float64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, the gammainc of the input Tensor. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> x = paddle.to_tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype="float32") + >>> y = paddle.to_tensor([0, 1, 10, 100, 1000], dtype="float32") + >>> out = paddle.gammainc(x, y) + >>> print(out) + Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True, + [0. , 0.84270084, 0.99999225, 1. , 1. ]) + """ + return 1 - paddle.gammaincc(x, y) + + +@inplace_apis_in_dygraph_only +def gammainc_(x, y, name=None): + r""" + Inplace version of ``gammainc`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_gammainc`. 
+ """ + return ( + paddle.gammaincc_(x, y) + .multiply_(paddle.full_like(x, -1.0)) + .add_(paddle.full_like(x, 1.0)) + ) + + def lgamma(x, name=None): r""" Calculates the lgamma of the given input tensor, element-wise. diff --git a/test/legacy_test/test_gammainc.py b/test/legacy_test/test_gammainc.py new file mode 100644 index 0000000000000..1ffac938c1233 --- /dev/null +++ b/test/legacy_test/test_gammainc.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from scipy import special + +import paddle +from paddle.base import core + + +def ref_gammainc(x, y): + return special.gammainc(x, y) + + +class TestGammaincApi(unittest.TestCase): + def setUp(self): + self.shape = [2, 3, 4, 5] + self.init_dtype_type() + self.x_np = np.random.random(self.shape).astype(self.dtype) + 1 + self.y_np = np.random.random(self.shape).astype(self.dtype) + 1 + self.place = ( + paddle.CUDAPlace(0) + if core.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + + def init_dtype_type(self): + self.dtype = "float64" + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('x', self.x_np.shape, self.x_np.dtype) + y = paddle.static.data('y', self.y_np.shape, self.y_np.dtype) + out = paddle.gammainc(x, y) + exe = paddle.static.Executor(self.place) + (res,) = exe.run( + feed={'x': self.x_np, 'y': self.y_np}, fetch_list=[out] + ) + out_ref = ref_gammainc(self.x_np, self.y_np) + np.testing.assert_allclose(out_ref, res, rtol=1e-6, atol=1e-6) + self.assertEqual(out.dtype, x.dtype) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + y = paddle.to_tensor(self.y_np) + out = paddle.gammainc(x, y) + out_ref = ref_gammainc(self.x_np, self.y_np) + np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-6, atol=1e-6) + self.assertEqual(out.dtype, x.dtype) + paddle.enable_static() + + +class TestGammaincApiFp32(TestGammaincApi): + def init_dtype_type(self): + self.dtype = "float32" + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_gammaincc_op.py b/test/legacy_test/test_gammaincc_op.py new file mode 100644 index 0000000000000..1e22567d151ec --- /dev/null +++ b/test/legacy_test/test_gammaincc_op.py @@ -0,0 +1,134 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from op_test import OpTest +from scipy import special + +import paddle +from paddle.base import core + + +def ref_gammaincc(x, y): + return special.gammaincc(x, y) + + +class TestGammainccOp(OpTest): + def setUp(self): + self.op_type = 'gammaincc' + self.python_api = paddle.gammaincc + self.init_dtype_type() + self.shape = (3, 40) + self.x = np.random.random(self.shape).astype(self.dtype) + 1 + self.y = np.random.random(self.shape).astype(self.dtype) + 1 + self.inputs = {'x': self.x, 'y': self.y} + out = ref_gammaincc(self.x, self.y) + self.outputs = {'out': out} + + def init_dtype_type(self): + self.dtype = np.float64 + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad(self): + self.check_grad(['y'], 'out', check_pir=True) + + +class TestGammainccOpFp32(TestGammainccOp): + def init_dtype_type(self): + self.dtype = np.float32 + + +class TestGammainccOpApi(unittest.TestCase): + def setUp(self): + self.shape = [2, 3, 4, 5] + self.init_dtype_type() + self.x_np = np.random.random(self.shape).astype(self.dtype) + 1 + self.y_np = np.random.random(self.shape).astype(self.dtype) + 1 + self.place = ( + paddle.CUDAPlace(0) + if core.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + + def init_dtype_type(self): + self.dtype = "float64" + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('x', self.x_np.shape, self.x_np.dtype) + y = paddle.static.data('y', self.y_np.shape, self.y_np.dtype) + out = paddle.gammaincc(x, y) + exe = paddle.static.Executor(self.place) + (res,) = exe.run( + feed={'x': self.x_np, 'y': self.y_np}, fetch_list=[out] + ) + out_ref = ref_gammaincc(self.x_np, self.y_np) + np.testing.assert_allclose(out_ref, res, rtol=1e-6, atol=1e-6) + + def test_dygraph_api(self): + paddle.disable_static() + x = paddle.to_tensor(self.x_np) + y = paddle.to_tensor(self.y_np) + out = paddle.gammaincc(x, y) + out_ref = ref_gammaincc(self.x_np, self.y_np) + np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-6, atol=1e-6) + + def test_x_le_zero_error(self): + paddle.disable_static() + x = paddle.to_tensor(self.x_np) + y = paddle.to_tensor(self.y_np) + x[0] = -1 + self.assertRaises(ValueError, paddle.gammaincc, x, y) + + def test_a_le_zero_error(self): + paddle.disable_static() + x = paddle.to_tensor(self.x_np) + y = paddle.to_tensor(self.y_np) + y[0] = -1 + self.assertRaises(ValueError, paddle.gammaincc, x, y) + + def test_dtype_error(self): + paddle.enable_static() + # in static graph mode + with self.assertRaises(TypeError): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data( + name="x", shape=self.shape, dtype="int32" + ) + y = paddle.static.data( + name="y", shape=self.shape, dtype="int32" + ) + out = paddle.gammaincc(x, y) + + paddle.disable_static() + # in dynamic mode + with self.assertRaises(RuntimeError): + with paddle.base.dygraph.guard(): + x = paddle.to_tensor(self.x_np, dtype="int32") + y = paddle.to_tensor(self.y_np, dtype="int32") + res = paddle.gammaincc(x, y) + + +class TestGammainccOpFp32Api(TestGammainccOpApi): + def init_dtype_type(self): + self.dtype = "float32" + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index 5f9fcb7be1e64..6ef5ae9b5135c 100755 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -909,6 +909,70 @@ def non_inplace_api_processing(self, 
var): return paddle.neg(var) +class TestDygraphInplaceGammaincc(TestDygraphInplace): + def init_data(self): + self.shape = (3, 40) + self.dtype = "float32" + self.input_var_numpy = ( + np.random.random(self.shape).astype(self.dtype) + 1 + ) + self.y = paddle.rand(shape=self.shape, dtype=self.dtype) + 1 + + def inplace_api_processing(self, var): + return paddle.gammaincc_(var, y=self.y) + + def non_inplace_api_processing(self, var): + return paddle.gammaincc(var, y=self.y) + + def test_backward_error(self): + pass + + def test_backward_success_1(self): + pass + + def test_backward_success_2(self): + pass + + +class TestDygraphInplaceGammainc(TestDygraphInplace): + def init_data(self): + self.shape = (3, 40) + self.dtype = "float32" + self.input_var_numpy = ( + np.random.random(self.shape).astype(self.dtype) + 1 + ) + self.y = paddle.rand(shape=self.shape, dtype=self.dtype) + 1 + + def inplace_api_processing(self, var): + return paddle.gammainc_(var, y=self.y) + + def non_inplace_api_processing(self, var): + return paddle.gammainc(var, y=self.y) + + def test_forward_version(self): + with paddle.base.dygraph.guard(): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + self.assertEqual(var.inplace_version, 0) + + inplace_var = self.inplace_api_processing(var) + self.assertEqual(var.inplace_version, 3) + + inplace_var[0] = 2 + self.assertEqual(var.inplace_version, 4) + + inplace_var = self.inplace_api_processing(inplace_var) + self.assertEqual(var.inplace_version, 7) + + def test_backward_error(self): + pass + + def test_backward_success_1(self): + pass + + def test_backward_success_2(self): + pass + + class TestDygraphInplaceLgamma(TestDygraphInplaceWithContinuous): def inplace_api_processing(self, var): return paddle.lgamma_(var)
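
A quick usage sketch of the new APIs (reviewer note, not part of the patch; it relies only on behavior this PR documents: paddle.gammaincc computes the regularized upper incomplete gamma function Q(x, y), paddle.gammainc is defined as 1 - gammaincc, and the new tests use SciPy as the reference; assumes SciPy is installed):

    import numpy as np
    from scipy import special
    import paddle

    x = paddle.to_tensor([0.5, 1.0, 2.0], dtype="float64")
    y = paddle.to_tensor([0.5, 1.0, 4.0], dtype="float64")

    # Q(x, y): regularized upper incomplete gamma function
    upper = paddle.gammaincc(x, y)
    np.testing.assert_allclose(
        upper.numpy(), special.gammaincc(x.numpy(), y.numpy()), rtol=1e-6
    )

    # P(x, y) = 1 - Q(x, y): regularized lower incomplete gamma function
    lower = paddle.gammainc(x, y)
    np.testing.assert_allclose(lower.numpy(), 1.0 - upper.numpy(), rtol=1e-6)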