diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index d5748145ffe49..938ea9d500046 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -922,16 +922,6 @@
   kernel :
     func : frame_grad
 
-- backward_op : gammaln_grad
-  forward : gammaln(Tensor x) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : UnchangedInferMeta
-    param: [x]
-  kernel :
-    func : gammaln_grad
-
 - backward_op : gather_grad
   forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out)
   args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0)
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index dc545b7a2da54..de4d700cdf80e 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1042,16 +1042,6 @@
     data_type : dtype
     backend : place
 
-- op : gammaln
-  args : (Tensor x)
-  output : Tensor(out)
-  infer_meta :
-    func : UnchangedInferMeta
-  kernel :
-    func : gammaln
-  inplace: (x -> out)
-  backward : gammaln_grad
-
 - op : gather
   args : (Tensor x, Tensor index, Scalar axis=0)
   output : Tensor(out)
diff --git a/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc b/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc
deleted file mode 100644
index c52ee8b3848e9..0000000000000
--- a/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/kernels/gammaln_grad_kernel.h"
-
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h"
-
-PD_REGISTER_KERNEL(
-    gammaln_grad, CPU, ALL_LAYOUT, phi::GammalnGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/cpu/gammaln_kernel.cc b/paddle/phi/kernels/cpu/gammaln_kernel.cc
deleted file mode 100644
index ff62f86d2522f..0000000000000
--- a/paddle/phi/kernels/cpu/gammaln_kernel.cc
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "paddle/phi/kernels/gammaln_kernel.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h" - -PD_REGISTER_KERNEL( - gammaln, CPU, ALL_LAYOUT, phi::GammalnKernel, float, double) {} diff --git a/paddle/phi/kernels/gammaln_grad_kernel.h b/paddle/phi/kernels/gammaln_grad_kernel.h deleted file mode 100644 index 440dca72a9d46..0000000000000 --- a/paddle/phi/kernels/gammaln_grad_kernel.h +++ /dev/null @@ -1,27 +0,0 @@ - -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace phi { - -template -void GammalnGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& d_out, - DenseTensor* d_x); -} // namespace phi diff --git a/paddle/phi/kernels/gammaln_kernel.h b/paddle/phi/kernels/gammaln_kernel.h deleted file mode 100644 index db3015c4a747d..0000000000000 --- a/paddle/phi/kernels/gammaln_kernel.h +++ /dev/null @@ -1,26 +0,0 @@ - -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace phi { - -template -void GammalnKernel(const Context& dev_ctx, - const DenseTensor& x, - DenseTensor* out); -} // namespace phi diff --git a/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu b/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu deleted file mode 100644 index b2513d9e3f25c..0000000000000 --- a/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/phi/kernels/gammaln_grad_kernel.h" - -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/common/amp_type_traits.h" -#include "paddle/phi/core/kernel_registry.h" - -#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h" - -PD_REGISTER_KERNEL(gammaln_grad, - GPU, - ALL_LAYOUT, - phi::GammalnGradKernel, - float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/gammaln_kernel.cu b/paddle/phi/kernels/gpu/gammaln_kernel.cu deleted file mode 100644 index 3d57be7b27733..0000000000000 --- a/paddle/phi/kernels/gpu/gammaln_kernel.cu +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/gammaln_kernel.h" - -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/kernel_registry.h" - -#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h" - -PD_REGISTER_KERNEL(gammaln, - GPU, - ALL_LAYOUT, - phi::GammalnKernel, - float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h deleted file mode 100644 index 50c73cff27ce4..0000000000000 --- a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#pragma once
-
-#include "paddle/phi/common/amp_type_traits.h"
-#include "paddle/phi/kernels/funcs/for_range.h"
-
-namespace phi {
-template <typename T>
-HOSTDEVICE T digamma(T x) {
-  static T c = T{8.5};
-  static T euler_mascheroni = T{0.57721566490153286060};
-  T r;
-  T value;
-  T x2;
-
-  if (x <= T{0.0}) {
-    value = T{0.0};
-    return value;
-  }
-
-  if (x <= T{0.000001}) {
-    value = -euler_mascheroni - T{1.0} / x + T{1.6449340668482264365} * x;
-    return value;
-  }
-
-  value = T{0.0};
-  x2 = x;
-  while (x2 < c) {
-    value = value - T{1.0} / x2;
-    x2 = x2 + T{1.0};
-  }
-
-  r = T{1.0} / x2;
-  value = value + std::log(x2) - T{0.5} * r;
-
-  r = r * r;
-
-  value = value -
-          r * (T{1.0} / T{12.0} -
-               r * (T{1.0} / T{120.0} -
-                    r * (T{1.0} / T{252.0} -
-                         r * (T{1.0} / T{240.0} - r * (T{1.0} / T{132.0})))));
-
-  return value;
-}
-
-template <typename T>
-struct GammalnGradFunctor {
-  GammalnGradFunctor(const T* dout, const T* x, T* output, int64_t numel)
-      : dout_(dout), x_(x), output_(output), numel_(numel) {}
-
-  HOSTDEVICE void operator()(int64_t idx) const {
-    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
-    const MT mp_dout = static_cast<MT>(dout_[idx]);
-    const MT mp_x = static_cast<MT>(x_[idx]);
-    output_[idx] = static_cast<T>(mp_dout * digamma<MT>(mp_x));
-  }
-
- private:
-  const T* dout_;
-  const T* x_;
-  T* output_;
-  int64_t numel_;
-};
-template <typename T, typename Context>
-void GammalnGradKernel(const Context& dev_ctx,
-                       const DenseTensor& x,
-                       const DenseTensor& d_out,
-                       DenseTensor* d_x) {
-  auto numel = d_out.numel();
-  auto* dout_data = d_out.data<T>();
-  auto* x_data = x.data<T>();
-  auto* dx_data =
-      dev_ctx.template Alloc<T>(d_x, static_cast<size_t>(numel * sizeof(T)));
-  phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
-  GammalnGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
-  for_range(functor);
-}
-}  // namespace phi
diff --git a/paddle/phi/kernels/impl/gammaln_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_kernel_impl.h
deleted file mode 100644
index 38385610de0de..0000000000000
--- a/paddle/phi/kernels/impl/gammaln_kernel_impl.h
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/phi/common/amp_type_traits.h"
-#include "paddle/phi/kernels/funcs/for_range.h"
-
-namespace phi {
-template <typename T>
-struct GammalnFunctor {
-  GammalnFunctor(const T* x, T* output, int64_t numel)
-      : x_(x), output_(output), numel_(numel) {}
-
-  HOSTDEVICE void operator()(int64_t idx) const {
-    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
-    const MT mp_x = static_cast<MT>(x_[idx]);
-    output_[idx] = static_cast<T>(std::lgamma(mp_x));
-  }
-
- private:
-  const T* x_;
-  T* output_;
-  int64_t numel_;
-};
-
-template <typename T, typename Context>
-void GammalnKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   DenseTensor* out) {
-  auto numel = x.numel();
-  auto* x_data = x.data<T>();
-  auto* out_data = dev_ctx.template Alloc<T>(out);
-  phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
-  GammalnFunctor<T> functor(x_data, out_data, numel);
-  for_range(functor);
-}
-}  // namespace phi
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 1f0017562ebad..fc7b2a3533f89 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -398,8 +398,6 @@
     frac,
     frac_,
     frexp,
-    gammaln,
-    gammaln_,
     gcd,
     gcd_,
     heaviside,
@@ -775,8 +773,6 @@
     'square_',
     'divide',
     'divide_',
-    'gammaln',
-    'gammaln_',
     'ceil',
     'atan',
     'atan_',
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index b718910348d8f..b26798892a2b2 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -278,8 +278,6 @@
     frac,
     frac_,
     frexp,
-    gammaln,
-    gammaln_,
     gcd,
     gcd_,
     heaviside,
@@ -670,8 +668,6 @@
     'real',
     'imag',
     'is_floating_point',
-    'gammaln',
-    'gammaln_',
     'digamma',
     'digamma_',
     'diagonal',
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 6d75d41b4949c..acaa0905ce6f4 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -5003,51 +5003,6 @@ def conj(x, name=None):
     return out
 
 
-def gammaln(x, name=None):
-    r"""
-    Calculates the logarithm of the absolute value of the gamma function elementwisely.
-
-    Args:
-        x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, bfloat16.
-        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-
-    Returns:
-        Tensor, The values of the logarithm of the absolute value of the gamma at the given tensor x.
-
-    Examples:
-        .. code-block:: python
-
-            >>> import paddle
-
-            >>> x = paddle.arange(1.5, 4.5, 0.5)
-            >>> out = paddle.gammaln(x)
-            >>> print(out)
-            Tensor(shape=[6], dtype=float32, place=Place(cpu), stop_gradient=True,
-            [-0.12078224,  0.        ,  0.28468287,  0.69314718,  1.20097363,
-              1.79175949])
-    """
-    if in_dynamic_or_pir_mode():
-        return _C_ops.gammaln(x)
-    else:
-        check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64', 'bfloat16'], 'gammaln'
-        )
-        helper = LayerHelper('gammaln', **locals())
-        out = helper.create_variable_for_type_inference(x.dtype)
-        helper.append_op(type='gammaln', inputs={'x': x}, outputs={'out': out})
-        return out
-
-
-@inplace_apis_in_dygraph_only
-def gammaln_(x, name=None):
-    r"""
-    Inplace version of ``gammaln`` API, the output Tensor will be inplaced with input ``x``.
-    Please refer to :ref:`api_paddle_gammaln`.
-    """
-    if in_dynamic_mode():
-        return _C_ops.gammaln_(x)
-
-
 def digamma(x, name=None):
     r"""
     Calculates the digamma of the given input tensor, element-wise.
diff --git a/test/legacy_test/test_gammaln_op.py b/test/legacy_test/test_gammaln_op.py
deleted file mode 100644
index 50331af5c7a34..0000000000000
--- a/test/legacy_test/test_gammaln_op.py
+++ /dev/null
@@ -1,160 +0,0 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-from op_test import OpTest, convert_float_to_uint16
-from scipy import special
-
-import paddle
-from paddle.base import core
-
-
-def ref_gammaln(x):
-    return special.gammaln(x)
-
-
-def ref_gammaln_grad(x, dout):
-    return dout * special.polygamma(0, x)
-
-
-class TestGammalnOp(OpTest):
-    def setUp(self):
-        self.op_type = 'gammaln'
-        self.python_api = paddle.gammaln
-        self.init_dtype_type()
-        self.shape = (3, 40)
-        self.x = np.random.random(self.shape).astype(self.dtype) + 1
-        self.inputs = {'x': self.x}
-        out = ref_gammaln(self.x)
-        self.outputs = {'out': out}
-
-    def init_dtype_type(self):
-        self.dtype = np.float64
-
-    def test_check_output(self):
-        self.check_output(check_pir=True)
-
-    def test_check_grad(self):
-        self.check_grad(['x'], 'out', check_pir=True)
-
-
-class TestGammalnOpFp32(TestGammalnOp):
-    def init_dtype_type(self):
-        self.dtype = np.float32
-
-
-class TestGammalnFP16Op(TestGammalnOp):
-    def init_dtype_type(self):
-        self.dtype = np.float16
-
-
-class TestGammalnBigNumberOp(TestGammalnOp):
-    def setUp(self):
-        self.op_type = 'gammaln'
-        self.python_api = paddle.gammaln
-        self.init_dtype_type()
-        self.shape = (100, 1)
-        self.x = np.random.random(self.shape).astype(self.dtype) + 1
-        self.x[:5, 0] = np.array([1e5, 1e10, 1e20, 1e40, 1e80])
-        self.inputs = {'x': self.x}
-        out = ref_gammaln(self.x)
-        self.outputs = {'out': out}
-
-    def init_dtype_type(self):
-        self.dtype = np.float64
-
-    def test_check_grad(self):
-        d_out = self.outputs['out']
-        d_x = ref_gammaln_grad(self.x, d_out)
-        self.check_grad(
-            ['x'],
-            'out',
-            user_defined_grads=[
-                d_x,
-            ],
-            user_defined_grad_outputs=[
-                d_out,
-            ],
-            check_pir=True,
-        )
-
-
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
-    "core is not compiled with CUDA or not support bfloat16",
-)
-class TestGammalnBF16Op(OpTest):
-    def setUp(self):
-        self.op_type = 'gammaln'
-        self.python_api = paddle.gammaln
-        self.dtype = np.uint16
-        self.shape = (5, 30)
-        x = np.random.random(self.shape).astype("float32") + 1
-        self.inputs = {'x': convert_float_to_uint16(x)}
-        out = ref_gammaln(x)
-        self.outputs = {'out': convert_float_to_uint16(out)}
-
-    def test_check_output(self):
-        self.check_output_with_place(core.CUDAPlace(0), check_pir=True)
-
-    def test_check_grad(self):
-        self.check_grad_with_place(
-            core.CUDAPlace(0), ['x'], 'out', check_pir=True
-        )
-
-
-class TestGammalnOpApi(unittest.TestCase):
-    def setUp(self):
-        self.shape = [2, 3, 4, 5]
-        self.init_dtype_type()
-        self.x_np = np.random.random(self.shape).astype(self.dtype) + 1
-        self.place = (
-            paddle.CUDAPlace(0)
-            if core.is_compiled_with_cuda()
-            else paddle.CPUPlace()
-        )
-
-    def init_dtype_type(self):
-        self.dtype = "float64"
-
-    def test_static_api(self):
-        paddle.enable_static()
-        with paddle.static.program_guard(paddle.static.Program()):
-            x = paddle.static.data('x', self.x_np.shape, self.x_np.dtype)
-            out = paddle.gammaln(x)
-            exe = paddle.static.Executor(self.place)
-            (res,) = exe.run(feed={'x': self.x_np}, fetch_list=[out])
-        out_ref = ref_gammaln(self.x_np)
-        np.testing.assert_allclose(out_ref, res, rtol=1e-5, atol=1e-5)
-
-    def test_dygraph_api(self):
-        paddle.disable_static(self.place)
-        x = paddle.to_tensor(self.x_np)
-        out = paddle.gammaln(x)
-        out_ref = ref_gammaln(self.x_np)
-        np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5, atol=1e-5)
-        paddle.enable_static()
-
-
-class TestGammalnOpApiFp32(TestGammalnOpApi):
-    def init_dtype_type(self):
-        self.dtype = "float32"
-
-
-if __name__ == "__main__":
-    paddle.enable_static()
-    unittest.main()
diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py
index 38fbac0357d6d..42f9a46cfb910 100644
--- a/test/legacy_test/test_inplace.py
+++ b/test/legacy_test/test_inplace.py
@@ -869,14 +869,6 @@ def test_leaf_inplace_var_error(self):
         pass
 
 
-class TestDygraphInplaceGammaln(TestDygraphInplaceWithContinuous):
-    def inplace_api_processing(self, var):
-        return paddle.gammaln_(var)
-
-    def non_inplace_api_processing(self, var):
-        return paddle.gammaln(var)
-
-
 class TestDygraphInplaceNeg(TestDygraphInplaceWithContinuous):
     def inplace_api_processing(self, var):
         return paddle.neg_(var)