From 40076158e0b967ef8bf71324174a0834c9e5dfa3 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 1 Apr 2022 13:59:43 +0800 Subject: [PATCH 01/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0logspace=E7=9A=84?= =?UTF-8?q?=E7=AE=97=E5=AD=90=E6=8F=8F=E8=BF=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/fluid/operators/logspace_op.cc | 77 +++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 paddle/fluid/operators/logspace_op.cc diff --git a/paddle/fluid/operators/logspace_op.cc b/paddle/fluid/operators/logspace_op.cc new file mode 100644 index 0000000000000..6d1154992ca8b --- /dev/null +++ b/paddle/fluid/operators/logspace_op.cc @@ -0,0 +1,77 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" + +namespace paddle { +namespace operators { + +class LogspaceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::proto::VarType::Type(ctx.Attr("dtype")), + ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return expected_kernel_type; + } +}; + +class LogspaceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Start", + "Exponent of first entry in the sequence. It is a tensor of " + "shape [1], should be of type float32 or float64."); + AddInput("Stop", + "Exponent of last entry in the sequence. It is a tensor of " + "shape [1], should be of type float32 or float64."); + AddInput("Num", + "Number of entry in the sequence. It is a tensor of shape [1], " + "should be of type int32."); + AddInput("Base", + "Base of the logarithm function. It is a tensor of shape [1], " + "should be of type int32."); + AddAttr("dtype", "The output data type."); + AddOutput("Out", "A sequence of numbers."); + AddComment(R"DOC( + Return fixed number of logarithmical-evenly spaced values within a given + interval. First entry is exponential of Start with base Base, and last + entry is exponential of Stop with base Base. In the case when Num is 1, + only exponential of Start with base Base is returned. + Like logspace function of numpy. + )DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(logspace, LogspaceInferShapeFunctor, + PD_INFER_META(phi::LogspaceInferMeta)); +REGISTER_OPERATOR( + logspace, ops::LogspaceOp, ops::LogspaceOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + LogspaceInferShapeFunctor); From abc377f0ad50434379381abab0b5e81629c09923 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 1 Apr 2022 14:00:14 +0800 Subject: [PATCH 02/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0logspace=E7=9A=84?= =?UTF-8?q?=E5=BD=A2=E7=8A=B6=E6=8E=A8=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/infermeta/multiary.cc | 37 ++++++++++++++++++++++++++++++++ paddle/phi/infermeta/multiary.h | 6 ++++++ 2 files changed, 43 insertions(+) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index c6940492ce696..44aefbf210dd6 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -890,6 +890,43 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void LogspaceInferMeta(const MetaTensor& start, + const MetaTensor& stop, + const MetaTensor& number, + const MetaTensor& base, + MetaTensor* out) { + auto s_dims = start.dims(); + PADDLE_ENFORCE_EQ( + (s_dims.size() == 1) && (s_dims[0] == 1), + true, + phi::errors::InvalidArgument("The shape of Input(Start) must be [1]," + "but received input shape is [%s].", + s_dims)); + auto e_dims = stop.dims(); + PADDLE_ENFORCE_EQ( + (e_dims.size() == 1) && (e_dims[0] == 1), + true, + phi::errors::InvalidArgument("The shape of Input(Stop) must be [1]," + "but received input shape is [%s].", + e_dims)); + auto num_dims = number.dims(); + PADDLE_ENFORCE_EQ( + (num_dims.size() == 1) && (num_dims[0] == 1), + true, + phi::errors::InvalidArgument("The shape of Input(Num) must be [1]," + "but received input shape is [%s].", + num_dims)); + auto b_dims = base.dims(); + PADDLE_ENFORCE_EQ( + (b_dims.size() == 1) && (b_dims[0] == 1), + true, + phi::errors::InvalidArgument("The shape of Input(Base) must be [1]," + "but received input shape is [%s].", + b_dims)); + out->set_dims(phi::make_ddim({-1})); + out->set_dtype(start.dtype()); +} + void MultiDotInferMeta(const std::vector& x, MetaTensor* out) { auto inputs_dims = GetMetaTensorsDim(x); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 4a8020aefca50..b4aa0b1908c13 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -199,6 +199,12 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x, MetaTensor* pre_out, MetaTensor* w_out); +void LogspaceInferMeta(const MetaTensor& start, + const MetaTensor& stop, + const MetaTensor& number, + const MetaTensor& base, + MetaTensor* out); + void MultiDotInferMeta(const std::vector& x, MetaTensor* out); void MultiplexInferMeta(const std::vector& ins, From 18f7ac310e6178907c3aa2ef74acefaa283a4045 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 1 Apr 2022 14:02:19 +0800 Subject: [PATCH 03/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0logspace=E6=A0=B8?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/cpu/logspace_kernel.cc | 76 ++++++++++++++++ paddle/phi/kernels/gpu/logspace_kernel.cu | 102 ++++++++++++++++++++++ paddle/phi/kernels/logspace_kernel.h | 27 ++++++ 3 files changed, 205 insertions(+) create mode 100644 paddle/phi/kernels/cpu/logspace_kernel.cc create mode 100644 paddle/phi/kernels/gpu/logspace_kernel.cu create mode 100644 paddle/phi/kernels/logspace_kernel.h diff --git a/paddle/phi/kernels/cpu/logspace_kernel.cc b/paddle/phi/kernels/cpu/logspace_kernel.cc new file mode 100644 index 0000000000000..be123483a3a89 --- /dev/null +++ b/paddle/phi/kernels/cpu/logspace_kernel.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/phi/kernels/logspace_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/data_type_transform.h" + +namespace phi { + +template +void LogspaceKernel(const Context& ctx, + const DenseTensor& start, + const DenseTensor& stop, + const DenseTensor& number, + const DenseTensor& base, + DataType dtype, + DenseTensor* out) { + int32_t num = number.data()[0]; + auto start_t = phi::funcs::TransDataType(ctx, start, dtype); + auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype); + auto base_t = phi::funcs::TransDataType(ctx, base, dtype); + + T start_data = start_t.template data()[0]; + T stop_data = stop_t.template data()[0]; + T base_data = base_t.template data()[0]; + PADDLE_ENFORCE_GT( + num, + 0, + phi::errors::InvalidArgument("The num of logspace op should be larger " + "than 0, but received num is %d", + num)); + + out->Resize(phi::make_ddim({num})); + T* out_data = ctx.template Alloc(out); + + if (num > 1) { + // step should be of double type for all types + double step = (static_cast(stop_data - start_data)) / (num - 1); + int half_num = num / 2; + for (int i = 0; i < num; ++i) { + if (i < half_num) { + out_data[i] = static_cast(std::pow( + base_data, start_data + step * i)); + } else { + out_data[i] = static_cast(std::pow( + base_data, stop_data - step * (num - i - 1))); + } + } + } else { + out_data[0] = static_cast(std::pow(base_data, start_data)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(logspace, + CPU, + ALL_LAYOUT, + phi::LogspaceKernel, + float, + int32_t, + int64_t, + double) {} diff --git a/paddle/phi/kernels/gpu/logspace_kernel.cu b/paddle/phi/kernels/gpu/logspace_kernel.cu new file mode 100644 index 0000000000000..49d264d275334 --- /dev/null +++ b/paddle/phi/kernels/gpu/logspace_kernel.cu @@ -0,0 +1,102 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/data_type_transform.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/logspace_kernel.h" + +namespace phi { + +template +__global__ void LogspaceKernelInner( + T start, T stop, double step, T base, int64_t size, T* out) { + int64_t index = blockIdx.x * blockDim.x + threadIdx.x; + + for (; index < size; index += blockDim.x * gridDim.x) { + if (index < size / 2) { + out[index] = static_cast(pow(base, start + step * index)); + } else { + out[index] = static_cast(pow(base, stop - step * (size - index - 1))); + } + } +} + +template +__global__ void LogspaceSpecialKernel(T start, T base, T* out) { + out[0] = static_cast(pow(base, start)); +} + +template +void LogspaceKernel(const Context& ctx, + const DenseTensor& start, + const DenseTensor& stop, + const DenseTensor& number, + const DenseTensor& base, + DataType dtype, + DenseTensor* out) { + auto start_t = phi::funcs::TransDataType(ctx, start, dtype); + auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype); + auto base_t = phi::funcs::TransDataType(ctx, base, dtype); + + DenseTensor n_start; + DenseTensor n_stop; + DenseTensor n_num; + DenseTensor n_base; + phi::Copy(ctx, start_t, phi::CPUPlace(), false, &n_start); + T start_data = n_start.data()[0]; + phi::Copy(ctx, stop_t, phi::CPUPlace(), false, &n_stop); + T stop_data = n_stop.data()[0]; + phi::Copy(ctx, number, phi::CPUPlace(), false, &n_num); + int64_t num = static_cast(n_num.data()[0]); + phi::Copy(ctx, base_t, phi::CPUPlace(), false, &n_base); + T base_data = n_base.data()[0]; + + PADDLE_ENFORCE_GT( + num, + 0, + phi::errors::InvalidArgument("The num of logspace op should be larger " + "than 0, but received num is %d", + num)); + + out->Resize(phi::make_ddim({num})); + T* out_data = ctx.template Alloc(out); + + double step = 0; + auto stream = ctx.stream(); + int block = 512; + int grid = (num + block - 1) / block; + if (num != 1) { + step = (static_cast(stop_data - start_data)) / (num - 1); + LogspaceKernelInner<<>>( + start_data, stop_data, step, base_data, num, out_data); + } else { + LogspaceSpecialKernel<<>>( + start_data, base_data, out_data); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(logspace, + GPU, + ALL_LAYOUT, + phi::LogspaceKernel, + float, + int32_t, + int64_t, + double) {} diff --git a/paddle/phi/kernels/logspace_kernel.h b/paddle/phi/kernels/logspace_kernel.h new file mode 100644 index 0000000000000..59862514e78ae --- /dev/null +++ b/paddle/phi/kernels/logspace_kernel.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LogspaceKernel(const Context& ctx, + const DenseTensor& start, + const DenseTensor& stop, + const DenseTensor& number, + const DenseTensor& base, + DataType dtype, + DenseTensor* out); + +} // namespace phi From 14823ed4c212daf926609b291b70f3359e8a9d38 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 1 Apr 2022 14:03:03 +0800 Subject: [PATCH 04/17] =?UTF-8?q?=E5=9C=A8python=E4=B8=AD=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0logspace=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/__init__.py | 2 + python/paddle/fluid/layers/tensor.py | 116 +++++++++++++++++++++++++++ python/paddle/tensor/creation.py | 1 + 3 files changed, 119 insertions(+) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index bba9c226dc07b..57b2c04869c8c 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -85,6 +85,7 @@ from .tensor.creation import diagflat # noqa: F401 from .tensor.creation import eye # noqa: F401 from .tensor.creation import linspace # noqa: F401 +from .tensor.creation import logspace # noqa: F401 from .tensor.creation import ones # noqa: F401 from .tensor.creation import ones_like # noqa: F401 from .tensor.creation import zeros # noqa: F401 @@ -582,6 +583,7 @@ 'sqrt', 'randperm', 'linspace', + 'logspace', 'reshape', 'reshape_', 'reverse', diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index f9f65ffb57f90..5f6b866de5b03 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -56,6 +56,7 @@ 'isfinite', 'range', 'linspace', + 'logspace', 'zeros_like', 'ones_like', 'diag', @@ -1557,6 +1558,121 @@ def linspace(start, stop, num, dtype=None, name=None): return out +def logspace(start, stop, num, base=10.0, dtype=None, name=None): + r""" + This OP return fixed number of logarithmical-evenly spaced values within a given interval. + Args: + start(int|float|Tensor): The input :attr:`start` is exponent of first entry in + the sequence. It is a scalar, or a Tensor of shape [1] with input data + type int32, int64, float32 or float64. + stop(int|float|Tensor): The input :attr:`stop` is exponent of last entry in the + sequence. It is a scalar, or a Tensor of shape [1] with input data + type int32, int64, float32 or float64. + num(int|Tensor): The input :attr:`num` is given number of items in the sequence. + It is an int scalar, or a Tensor of shape [1] with data type int32. + base(int|float|Tensor): The input :attr:`base` is base of the logarithm function. + It is a scalar, or a Tensor of shape [1] with input data type int32, int64, + float32 or float64. + dtype(np.dtype|str, optional): The data type of output tensor, it could be + int32, int64, float32 and float64. Default: if None, the data type is float32. + name(str, optional): Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name`.Default: None. + Returns: + Tensor: the output data type will be float32, float64. The 1-D tensor with + fixed number of evenly spaced values, the data shape of this tensor + is :math:`[num]` . If the :attr:`num` is set 1, the output tensor just + has the value with exponential of :attr:`start` with base :attr:`base`. + Examples: + .. code-block:: python + import paddle + data = paddle.logspace(0, 10, 5, 2, 'float32') + # [1. , 5.65685415 , 32. , 181.01933289, 1024. ] + data = paddle.logspace(0, 10, 1, 2, 'float32') + # [1.] + """ + if dtype is None: + dtype = 'float32' + tensor_num = num + tensor_start = start + tensor_stop = stop + tensor_base = base + if not isinstance(num, Variable): + check_type(num, 'num', (int), 'logspace') + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + if not isinstance(start, Variable): + with device_guard("cpu"): + tensor_start = fill_constant([1], dtype, start) + if not isinstance(stop, Variable): + with device_guard("cpu"): + tensor_stop = fill_constant([1], dtype, stop) + if not isinstance(num, Variable): + with device_guard("cpu"): + tensor_num = fill_constant([1], 'int32', num) + if not isinstance(base, Variable): + with device_guard("cpu"): + tensor_base = fill_constant([1], dtype, base) + if _non_static_mode(): + return _C_ops.logspace(tensor_start, tensor_stop, tensor_num, + tensor_base, 'dtype', dtype) + + helper = LayerHelper("logspace", **locals()) + + start_dtype = convert_dtype(tensor_start.dtype) + stop_dtype = convert_dtype(tensor_stop.dtype) + base_dtype = convert_dtype(tensor_base.dtype) + out_dtype = convert_dtype(dtype) + if isinstance(start, Variable): + check_dtype(start.dtype, 'start', + ['float32', 'float64', 'int32', 'int64'], 'logspace') + else: + check_type(start, 'start', (int, float), 'logspace') + + if isinstance(stop, Variable): + check_dtype(stop.dtype, 'stop', + ['float32', 'float64', 'int32', 'int64'], 'logspace') + else: + check_type(stop, 'stop', (int, float), 'logspace') + + if isinstance(num, Variable): + check_dtype(num.dtype, 'num', ['int32'], 'logspace') + + if isinstance(base, Variable): + check_dtype(base.dtype, 'base', + ['float32', 'float64', 'int32', 'int64'], 'logspace') + else: + check_type(base, 'base', (int, float), 'logspace') + + check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'], + 'logspace') + if ((stop_dtype == "float64" or start_dtype == "float64" + or base_dtype == "float64") + and out_dtype in ["float32", "int32"]) or \ + ((stop_dtype == "int64" or start_dtype == "int64" + or base_dtype == "int64") + and out_dtype == "int32"): + raise ValueError( + "The dtype of start/stop/base is {}/{}/{} but the attr(dtype) of logspace is {}, " + "which may cause data type overflows. Please reset attr(dtype) of logspace." + .format(start_dtype, stop_dtype, base_dtype, dtype)) + + out = helper.create_variable_for_type_inference(dtype=dtype) + + helper.append_op( + type='logspace', + inputs={ + 'Start': tensor_start, + 'Stop': tensor_stop, + 'Num': tensor_num, + 'Base': tensor_base + }, + attrs={'dtype': dtype}, + outputs={'Out': [out]}) + if isinstance(num, int): + out.desc.set_shape((num, )) + return out + + def zeros_like(x, out=None): """ This OP creates a zeros tensor which has identical shape and dtype diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 6e7e5678be0b0..e47869e3d0d6d 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -28,6 +28,7 @@ from paddle.tensor.attribute import _complex_to_real_dtype, _real_to_complex_dtype # TODO: define functions to get create a tensor from ..fluid.layers import linspace # noqa: F401 +from ..fluid.layers import logspace # noqa: F401 import paddle from paddle import _C_ops from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _in_eager_without_dygraph_check From 7e33b85d62bccdcbc9e4cbee93da1a0e4eea388e Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 1 Apr 2022 14:03:24 +0800 Subject: [PATCH 05/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0logspace=E5=8D=95?= =?UTF-8?q?=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../fluid/tests/unittests/test_logspace.py | 225 ++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_logspace.py diff --git a/python/paddle/fluid/tests/unittests/test_logspace.py b/python/paddle/fluid/tests/unittests/test_logspace.py new file mode 100644 index 0000000000000..3e9835efc62bd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_logspace.py @@ -0,0 +1,225 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +from paddle.fluid import core + + +class TestLogspaceOpCommonCase(OpTest): + def setUp(self): + self.op_type = "logspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([0]).astype(dtype), + 'Stop': np.array([10]).astype(dtype), + 'Num': np.array([11]).astype('int32'), + 'Base': np.array([2]).astype(dtype), + } + self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} + + self.outputs = {'Out': np.power(2, np.arange(0, 11)).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLogspaceOpReverseCase(OpTest): + def setUp(self): + self.op_type = "logspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([10]).astype(dtype), + 'Stop': np.array([0]).astype(dtype), + 'Num': np.array([11]).astype('int32'), + 'Base': np.array([2]).astype(dtype) + } + self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} + + self.outputs = {'Out': np.power(2, np.arange(10, -1, -1)).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLogspaceOpNumOneCase(OpTest): + def setUp(self): + self.op_type = "logspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([10]).astype(dtype), + 'Stop': np.array([0]).astype(dtype), + 'Num': np.array([1]).astype('int32'), + 'Base': np.array([2]).astype(dtype) + } + self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} + + self.outputs = {'Out': np.power(2, np.array(10)).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLogspaceOpMinusBaseCase(OpTest): + def setUp(self): + self.op_type = "logspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([0]).astype(dtype), + 'Stop': np.array([10]).astype(dtype), + 'Num': np.array([11]).astype('int32'), + 'Base': np.array([-2]).astype(dtype), + } + self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} + + self.outputs = {'Out': np.power(-2, np.arange(0, 11)).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLogspaceOpZeroBaseCase(OpTest): + def setUp(self): + self.op_type = "logspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([0]).astype(dtype), + 'Stop': np.array([10]).astype(dtype), + 'Num': np.array([11]).astype('int32'), + 'Base': np.array([0]).astype(dtype), + } + self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} + + self.outputs = {'Out': np.power(0, np.arange(0, 11)).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLogspaceAPI(unittest.TestCase): + def test_variable_input1(self): + start = paddle.full(shape=[1], fill_value=0, dtype='float32') + stop = paddle.full(shape=[1], fill_value=10, dtype='float32') + num = paddle.full(shape=[1], fill_value=5, dtype='int32') + base = paddle.full(shape=[1], fill_value=2, dtype='float32') + out = paddle.logspace(start, stop, num, base, dtype='float32') + exe = fluid.Executor(place=fluid.CPUPlace()) + res = exe.run(fluid.default_main_program(), fetch_list=[out]) + np_res = np.logspace(0, 10, 5, base=2, dtype='float32') + self.assertEqual((res == np_res).all(), True) + + def test_variable_input2(self): + paddle.disable_static() + start = paddle.full(shape=[1], fill_value=0, dtype='float32') + stop = paddle.full(shape=[1], fill_value=10, dtype='float32') + num = paddle.full(shape=[1], fill_value=5, dtype='int32') + base = paddle.full(shape=[1], fill_value=2, dtype='float32') + out = paddle.logspace(start, stop, num, base, dtype='float32') + np_res = np.logspace(0, 10, 5, base=2, dtype='float32') + self.assertEqual((out.numpy() == np_res).all(), True) + paddle.enable_static() + + def test_dtype(self): + paddle.enable_static() + out_1 = paddle.logspace(0, 10, 5, 2, dtype='float32') + out_2 = paddle.logspace(0, 10, 5, 2, dtype=np.float32) + out_3 = paddle.logspace(0, 10, 5, 2, dtype=core.VarDesc.VarType.FP32) + exe = fluid.Executor(place=fluid.CPUPlace()) + res_1, res_2, res_3 = exe.run(fluid.default_main_program(), + fetch_list=[out_1, out_2, out_3]) + assert np.array_equal(res_1, res_2) + paddle.disable_static() + + def test_name(self): + with paddle.static.program_guard(paddle.static.Program()): + out = paddle.logspace( + 0, 10, 5, 2, dtype='float32', name='logspace_res') + assert 'logspace_res' in out.name + + def test_imperative(self): + paddle.disable_static() + out1 = paddle.logspace(0, 10, 5, 2, dtype='float32') + np_out1 = np.logspace(0, 10, 5, base=2, dtype='float32') + out2 = paddle.logspace(0, 10, 5, 2, dtype='int32') + np_out2 = np.logspace(0, 10, 5, base=2, dtype='int32') + out3 = paddle.logspace(0, 10, 200, 2, dtype='int32') + np_out3 = np.logspace(0, 10, 200, base=2, dtype='int32') + paddle.enable_static() + self.assertEqual((out1.numpy() == np_out1).all(), True) + self.assertEqual((out2.numpy() == np_out2).all(), True) + self.assertEqual((out3.numpy() == np_out3).all(), True) + + +class TestLogspaceOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_dtype(): + fluid.layers.logspace(0, 10, 1, 2, dtype="int8") + + self.assertRaises(TypeError, test_dtype) + + def test_dtype1(): + fluid.layers.logspace(0, 10, 1.33, 2, dtype="int32") + + self.assertRaises(TypeError, test_dtype1) + + def test_start_type(): + fluid.layers.logspace([0], 10, 1, 2, dtype="float32") + + self.assertRaises(TypeError, test_start_type) + + def test_end_type(): + fluid.layers.logspace(0, [10], 1, 2, dtype="float32") + + self.assertRaises(TypeError, test_end_type) + + def test_num_type(): + fluid.layers.logspace(0, 10, [0], 2, dtype="float32") + + self.assertRaises(TypeError, test_num_type) + + def test_start_dtype(): + start = fluid.data(shape=[1], dtype="float64", name="start") + fluid.layers.logspace(start, 10, 1, 2, dtype="float32") + + self.assertRaises(ValueError, test_start_dtype) + + def test_end_dtype(): + end = fluid.data(shape=[1], dtype="float64", name="end") + fluid.layers.logspace(0, end, 1, 2, dtype="float32") + + self.assertRaises(ValueError, test_end_dtype) + + def test_num_dtype(): + num = fluid.data(shape=[1], dtype="float32", name="step") + fluid.layers.logspace(0, 10, num, 2, dtype="float32") + + self.assertRaises(TypeError, test_num_dtype) + + def test_base_dtype(): + base = fluid.data(shape=[1], dtype="float64", name="end") + fluid.layers.logspace(0, 10, 1, base, dtype="float32") + + self.assertRaises(ValueError, test_base_dtype) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 7fb1d2ea56c9c6c4dbf6a9d2bfc8fbe9de7a62ef Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 1 Apr 2022 14:03:37 +0800 Subject: [PATCH 06/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0logspace?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/static_mode_white_list.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 365047f7e8382..d87e52a4f49f0 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -306,6 +306,7 @@ 'test_linear_interp_op', 'test_linear_interp_v2_op', 'test_linspace', + 'test_logspace', 'test_load_op', 'test_load_vars_shape_check', 'test_locality_aware_nms_op', From efe557724a6352bfa6b1fc93e0578b7337fc8885 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 2 Apr 2022 10:42:19 +0800 Subject: [PATCH 07/17] Update logspace_kernel.cu --- paddle/phi/kernels/gpu/logspace_kernel.cu | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/gpu/logspace_kernel.cu b/paddle/phi/kernels/gpu/logspace_kernel.cu index 49d264d275334..b0d0e2523454a 100644 --- a/paddle/phi/kernels/gpu/logspace_kernel.cu +++ b/paddle/phi/kernels/gpu/logspace_kernel.cu @@ -29,16 +29,21 @@ __global__ void LogspaceKernelInner( for (; index < size; index += blockDim.x * gridDim.x) { if (index < size / 2) { - out[index] = static_cast(pow(base, start + step * index)); + out[index] = static_cast(pow( + static_cast(base), + static_cast(start + step * index))); } else { - out[index] = static_cast(pow(base, stop - step * (size - index - 1))); + out[index] = static_cast(pow( + static_cast(base), + static_cast(stop - step * (size - index - 1)))); } } } template __global__ void LogspaceSpecialKernel(T start, T base, T* out) { - out[0] = static_cast(pow(base, start)); + out[0] = static_cast(pow( + static_cast(base), static_cast(start))); } template From 7b9c8ed879c99ff80919df924c8a51c44c432567 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 7 Apr 2022 10:59:34 +0800 Subject: [PATCH 08/17] Update logspace_op.cc --- paddle/fluid/operators/logspace_op.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/logspace_op.cc b/paddle/fluid/operators/logspace_op.cc index 6d1154992ca8b..e762dd6b455ac 100644 --- a/paddle/fluid/operators/logspace_op.cc +++ b/paddle/fluid/operators/logspace_op.cc @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include + #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -34,7 +35,8 @@ class LogspaceOp : public framework::OperatorWithKernel { framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { - return expected_kernel_type; + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); } }; @@ -43,23 +45,24 @@ class LogspaceOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("Start", "Exponent of first entry in the sequence. It is a tensor of " - "shape [1], should be of type float32 or float64."); + "shape [1], should be of type int32, int64, float32 or float64."); AddInput("Stop", "Exponent of last entry in the sequence. It is a tensor of " - "shape [1], should be of type float32 or float64."); + "shape [1], should be of type int32, int64, float32 or float64."); AddInput("Num", "Number of entry in the sequence. It is a tensor of shape [1], " "should be of type int32."); AddInput("Base", "Base of the logarithm function. It is a tensor of shape [1], " - "should be of type int32."); + "should be of type int32, int64, float32 or float64."); AddAttr("dtype", "The output data type."); AddOutput("Out", "A sequence of numbers."); AddComment(R"DOC( Return fixed number of logarithmical-evenly spaced values within a given interval. First entry is exponential of Start with base Base, and last entry is exponential of Stop with base Base. In the case when Num is 1, - only exponential of Start with base Base is returned. + only exponential of Start with base Base is returned. If dtype is int32 + or int64, the decimal part of values will be truncated. Like logspace function of numpy. )DOC"); } From 3e2e62a24b26ca7263faeab79e20309ca1851b1b Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 7 Apr 2022 10:59:57 +0800 Subject: [PATCH 09/17] =?UTF-8?q?=E8=B0=83=E6=95=B4=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/cpu/logspace_kernel.cc | 11 ++++++----- paddle/phi/kernels/gpu/logspace_kernel.cu | 16 ++++++++-------- .../fluid/tests/unittests/test_logspace.py | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/paddle/phi/kernels/cpu/logspace_kernel.cc b/paddle/phi/kernels/cpu/logspace_kernel.cc index be123483a3a89..e429ae93c8a5e 100644 --- a/paddle/phi/kernels/cpu/logspace_kernel.cc +++ b/paddle/phi/kernels/cpu/logspace_kernel.cc @@ -13,10 +13,11 @@ // limitations under the License. #include -#include "paddle/phi/kernels/logspace_kernel.h" + #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/data_type_transform.h" +#include "paddle/phi/kernels/logspace_kernel.h" namespace phi { @@ -52,11 +53,11 @@ void LogspaceKernel(const Context& ctx, int half_num = num / 2; for (int i = 0; i < num; ++i) { if (i < half_num) { - out_data[i] = static_cast(std::pow( - base_data, start_data + step * i)); + out_data[i] = + static_cast(std::pow(base_data, start_data + step * i)); } else { - out_data[i] = static_cast(std::pow( - base_data, stop_data - step * (num - i - 1))); + out_data[i] = static_cast( + std::pow(base_data, stop_data - step * (num - i - 1))); } } } else { diff --git a/paddle/phi/kernels/gpu/logspace_kernel.cu b/paddle/phi/kernels/gpu/logspace_kernel.cu index b0d0e2523454a..84957536fd706 100644 --- a/paddle/phi/kernels/gpu/logspace_kernel.cu +++ b/paddle/phi/kernels/gpu/logspace_kernel.cu @@ -29,21 +29,21 @@ __global__ void LogspaceKernelInner( for (; index < size; index += blockDim.x * gridDim.x) { if (index < size / 2) { - out[index] = static_cast(pow( - static_cast(base), - static_cast(start + step * index))); + out[index] = + static_cast(pow(static_cast(base), + static_cast(start + step * index))); } else { - out[index] = static_cast(pow( - static_cast(base), - static_cast(stop - step * (size - index - 1)))); + out[index] = static_cast( + pow(static_cast(base), + static_cast(stop - step * (size - index - 1)))); } } } template __global__ void LogspaceSpecialKernel(T start, T base, T* out) { - out[0] = static_cast(pow( - static_cast(base), static_cast(start))); + out[0] = static_cast( + pow(static_cast(base), static_cast(start))); } template diff --git a/python/paddle/fluid/tests/unittests/test_logspace.py b/python/paddle/fluid/tests/unittests/test_logspace.py index 3e9835efc62bd..886f89c64378e 100644 --- a/python/paddle/fluid/tests/unittests/test_logspace.py +++ b/python/paddle/fluid/tests/unittests/test_logspace.py @@ -222,4 +222,4 @@ def test_base_dtype(): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From fee242f59c8ae64ddf7bcace1aaa24a3b88064ac Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Wed, 13 Apr 2022 15:28:11 +0800 Subject: [PATCH 10/17] Update doc of logspace --- python/paddle/fluid/layers/tensor.py | 50 ++++++++++++++++------------ 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 222583baa022f..1100034c9e750 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1586,35 +1586,43 @@ def linspace(start, stop, num, dtype=None, name=None): def logspace(start, stop, num, base=10.0, dtype=None, name=None): r""" - This OP return fixed number of logarithmical-evenly spaced values within a given interval. + Return fixed number of logarithmical-evenly spaced values within the interval \ + :math:`[base^{start}, base^{stop}]`. + + Notes: + This API does not compute the gradient. + Args: - start(int|float|Tensor): The input :attr:`start` is exponent of first entry in - the sequence. It is a scalar, or a Tensor of shape [1] with input data + start(int|float|Tensor): The input :attr:`start` is exponent of first entry in \ + the sequence. It is a scalar, or a Tensor of shape [1] with input data \ type int32, int64, float32 or float64. - stop(int|float|Tensor): The input :attr:`stop` is exponent of last entry in the - sequence. It is a scalar, or a Tensor of shape [1] with input data + stop(int|float|Tensor): The input :attr:`stop` is exponent of last entry in the \ + sequence. It is a scalar, or a Tensor of shape [1] with input data \ type int32, int64, float32 or float64. - num(int|Tensor): The input :attr:`num` is given number of items in the sequence. + num(int|Tensor): The input :attr:`num` is given number of items in the sequence. \ It is an int scalar, or a Tensor of shape [1] with data type int32. - base(int|float|Tensor): The input :attr:`base` is base of the logarithm function. - It is a scalar, or a Tensor of shape [1] with input data type int32, int64, + base(int|float|Tensor): The input :attr:`base` is base of the logarithm function. \ + It is a scalar, or a Tensor of shape [1] with input data type int32, int64, \ float32 or float64. - dtype(np.dtype|str, optional): The data type of output tensor, it could be - int32, int64, float32 and float64. Default: if None, the data type is float32. - name(str, optional): Normally there is no need for user to set this property. - For more information, please refer to :ref:`api_guide_Name`.Default: None. + dtype(np.dtype|str, optional): The data type of output tensor, it could be \ + int32, int64, float32 or float64. Default: if None, the data type is float32. \ + name(str, optional): Normally there is no need for user to set this property. \ + For more information, please refer to :ref:`api_guide_Name`. Default: None. + Returns: - Tensor: the output data type will be float32, float64. The 1-D tensor with - fixed number of evenly spaced values, the data shape of this tensor - is :math:`[num]` . If the :attr:`num` is set 1, the output tensor just - has the value with exponential of :attr:`start` with base :attr:`base`. + Tensor: The output data type will be float32, float64. The 1-D tensor with \ + fixed number of logarithmical-evenly spaced values, the data shape of this \ + tensor is :math:`[num]`. If the :attr:`num` is set 1, the output tensor \ + just has the value with exponential of :attr:`start` with base :attr:`base`. + Examples: .. code-block:: python - import paddle - data = paddle.logspace(0, 10, 5, 2, 'float32') - # [1. , 5.65685415 , 32. , 181.01933289, 1024. ] - data = paddle.logspace(0, 10, 1, 2, 'float32') - # [1.] + :name: logspace-example + import paddle + data = paddle.logspace(0, 10, 5, 2, 'float32') + # [1. , 5.65685415 , 32. , 181.01933289, 1024. ] + data = paddle.logspace(0, 10, 1, 2, 'float32') + # [1.] """ if dtype is None: dtype = 'float32' From f0862ad6a987cb7e91b8635b6af417fad68c0276 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Wed, 13 Apr 2022 22:51:28 +0800 Subject: [PATCH 11/17] Update tensor.py --- python/paddle/fluid/layers/tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 81d99332e186b..e617824764338 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1649,6 +1649,7 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None): Examples: .. code-block:: python :name: logspace-example + import paddle data = paddle.logspace(0, 10, 5, 2, 'float32') # [1. , 5.65685415 , 32. , 181.01933289, 1024. ] From ec204b66303fef93aa8e952228841db6df16d355 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 18:24:53 +0800 Subject: [PATCH 12/17] Update logspace_op.cc --- paddle/fluid/operators/logspace_op.cc | 30 ++++++++++++--------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/logspace_op.cc b/paddle/fluid/operators/logspace_op.cc index e762dd6b455ac..1d1653b053679 100644 --- a/paddle/fluid/operators/logspace_op.cc +++ b/paddle/fluid/operators/logspace_op.cc @@ -1,13 +1,16 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -31,13 +34,6 @@ class LogspaceOp : public framework::OperatorWithKernel { framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } - - framework::OpKernelType GetKernelTypeForVar( - const std::string &var_name, const framework::Tensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), tensor.layout()); - } }; class LogspaceOpMaker : public framework::OpProtoAndCheckerMaker { From 15ef37241b889003d8b6d7ab73c6d8e8c6fe7e5e Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 18:25:09 +0800 Subject: [PATCH 13/17] Update logspace_kernel.cc --- paddle/phi/kernels/cpu/logspace_kernel.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/cpu/logspace_kernel.cc b/paddle/phi/kernels/cpu/logspace_kernel.cc index e429ae93c8a5e..fbb31057a35ae 100644 --- a/paddle/phi/kernels/cpu/logspace_kernel.cc +++ b/paddle/phi/kernels/cpu/logspace_kernel.cc @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/logspace_kernel.h" + #include #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/data_type_transform.h" -#include "paddle/phi/kernels/logspace_kernel.h" namespace phi { From 88a744622a169840eb49a655560222f83dc88c1f Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 18:25:24 +0800 Subject: [PATCH 14/17] Update logspace_kernel.cu --- paddle/phi/kernels/gpu/logspace_kernel.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/logspace_kernel.cu b/paddle/phi/kernels/gpu/logspace_kernel.cu index 84957536fd706..f47b7d35cdcda 100644 --- a/paddle/phi/kernels/gpu/logspace_kernel.cu +++ b/paddle/phi/kernels/gpu/logspace_kernel.cu @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/kernels/logspace_kernel.h" + #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/data_type_transform.h" #include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/logspace_kernel.h" namespace phi { From 7cc079f0a51b46d47556fc1baedcfd590684225c Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 18:25:34 +0800 Subject: [PATCH 15/17] Update test_logspace.py --- .../fluid/tests/unittests/test_logspace.py | 73 ++++++++++--------- 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_logspace.py b/python/paddle/fluid/tests/unittests/test_logspace.py index 886f89c64378e..8f5aaf0aa5dd8 100644 --- a/python/paddle/fluid/tests/unittests/test_logspace.py +++ b/python/paddle/fluid/tests/unittests/test_logspace.py @@ -18,9 +18,6 @@ import numpy as np from op_test import OpTest import paddle -import paddle.fluid as fluid -from paddle.fluid import Program, program_guard -from paddle.fluid import core class TestLogspaceOpCommonCase(OpTest): @@ -33,7 +30,7 @@ def setUp(self): 'Num': np.array([11]).astype('int32'), 'Base': np.array([2]).astype(dtype), } - self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} + self.attrs = {'dtype': int(paddle.float32)} self.outputs = {'Out': np.power(2, np.arange(0, 11)).astype(dtype)} @@ -51,7 +48,7 @@ def setUp(self): 'Num': np.array([11]).astype('int32'), 'Base': np.array([2]).astype(dtype) } - self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} + self.attrs = {'dtype': int(paddle.float32)} self.outputs = {'Out': np.power(2, np.arange(10, -1, -1)).astype(dtype)} @@ -69,7 +66,7 @@ def setUp(self): 'Num': np.array([1]).astype('int32'), 'Base': np.array([2]).astype(dtype) } - self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} + self.attrs = {'dtype': int(paddle.float32)} self.outputs = {'Out': np.power(2, np.array(10)).astype(dtype)} @@ -87,7 +84,7 @@ def setUp(self): 'Num': np.array([11]).astype('int32'), 'Base': np.array([-2]).astype(dtype), } - self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} + self.attrs = {'dtype': int(paddle.float32)} self.outputs = {'Out': np.power(-2, np.arange(0, 11)).astype(dtype)} @@ -105,7 +102,7 @@ def setUp(self): 'Num': np.array([11]).astype('int32'), 'Base': np.array([0]).astype(dtype), } - self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} + self.attrs = {'dtype': int(paddle.float32)} self.outputs = {'Out': np.power(0, np.arange(0, 11)).astype(dtype)} @@ -115,15 +112,20 @@ def test_check_output(self): class TestLogspaceAPI(unittest.TestCase): def test_variable_input1(self): - start = paddle.full(shape=[1], fill_value=0, dtype='float32') - stop = paddle.full(shape=[1], fill_value=10, dtype='float32') - num = paddle.full(shape=[1], fill_value=5, dtype='int32') - base = paddle.full(shape=[1], fill_value=2, dtype='float32') - out = paddle.logspace(start, stop, num, base, dtype='float32') - exe = fluid.Executor(place=fluid.CPUPlace()) - res = exe.run(fluid.default_main_program(), fetch_list=[out]) + paddle.enable_static() + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + start = paddle.full(shape=[1], fill_value=0, dtype='float32') + stop = paddle.full(shape=[1], fill_value=10, dtype='float32') + num = paddle.full(shape=[1], fill_value=5, dtype='int32') + base = paddle.full(shape=[1], fill_value=2, dtype='float32') + out = paddle.logspace(start, stop, num, base, dtype='float32') + + exe = paddle.static.Executor() + res = exe.run(prog, fetch_list=[out]) np_res = np.logspace(0, 10, 5, base=2, dtype='float32') self.assertEqual((res == np_res).all(), True) + paddle.disable_static() def test_variable_input2(self): paddle.disable_static() @@ -138,12 +140,13 @@ def test_variable_input2(self): def test_dtype(self): paddle.enable_static() - out_1 = paddle.logspace(0, 10, 5, 2, dtype='float32') - out_2 = paddle.logspace(0, 10, 5, 2, dtype=np.float32) - out_3 = paddle.logspace(0, 10, 5, 2, dtype=core.VarDesc.VarType.FP32) - exe = fluid.Executor(place=fluid.CPUPlace()) - res_1, res_2, res_3 = exe.run(fluid.default_main_program(), - fetch_list=[out_1, out_2, out_3]) + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + out_1 = paddle.logspace(0, 10, 5, 2, dtype='float32') + out_2 = paddle.logspace(0, 10, 5, 2, dtype=np.float32) + + exe = paddle.static.Executor() + res_1, res_2 = exe.run(prog, fetch_list=[out_1, out_2]) assert np.array_equal(res_1, res_2) paddle.disable_static() @@ -169,54 +172,54 @@ def test_imperative(self): class TestLogspaceOpError(unittest.TestCase): def test_errors(self): - with program_guard(Program(), Program()): + with paddle.static.program_guard(paddle.static.Program()): def test_dtype(): - fluid.layers.logspace(0, 10, 1, 2, dtype="int8") + paddle.logspace(0, 10, 1, 2, dtype="int8") self.assertRaises(TypeError, test_dtype) def test_dtype1(): - fluid.layers.logspace(0, 10, 1.33, 2, dtype="int32") + paddle.logspace(0, 10, 1.33, 2, dtype="int32") self.assertRaises(TypeError, test_dtype1) def test_start_type(): - fluid.layers.logspace([0], 10, 1, 2, dtype="float32") + paddle.logspace([0], 10, 1, 2, dtype="float32") self.assertRaises(TypeError, test_start_type) def test_end_type(): - fluid.layers.logspace(0, [10], 1, 2, dtype="float32") + paddle.logspace(0, [10], 1, 2, dtype="float32") self.assertRaises(TypeError, test_end_type) def test_num_type(): - fluid.layers.logspace(0, 10, [0], 2, dtype="float32") + paddle.logspace(0, 10, [0], 2, dtype="float32") self.assertRaises(TypeError, test_num_type) def test_start_dtype(): - start = fluid.data(shape=[1], dtype="float64", name="start") - fluid.layers.logspace(start, 10, 1, 2, dtype="float32") + start = paddle.static.data(shape=[1], dtype="float64", name="start") + paddle.logspace(start, 10, 1, 2, dtype="float32") self.assertRaises(ValueError, test_start_dtype) def test_end_dtype(): - end = fluid.data(shape=[1], dtype="float64", name="end") - fluid.layers.logspace(0, end, 1, 2, dtype="float32") + end = paddle.static.data(shape=[1], dtype="float64", name="end") + paddle.logspace(0, end, 1, 2, dtype="float32") self.assertRaises(ValueError, test_end_dtype) def test_num_dtype(): - num = fluid.data(shape=[1], dtype="float32", name="step") - fluid.layers.logspace(0, 10, num, 2, dtype="float32") + num = paddle.static.data(shape=[1], dtype="float32", name="step") + paddle.logspace(0, 10, num, 2, dtype="float32") self.assertRaises(TypeError, test_num_dtype) def test_base_dtype(): - base = fluid.data(shape=[1], dtype="float64", name="end") - fluid.layers.logspace(0, 10, 1, base, dtype="float32") + base = paddle.static.data(shape=[1], dtype="float64", name="end") + paddle.logspace(0, 10, 1, base, dtype="float32") self.assertRaises(ValueError, test_base_dtype) From 3530acf1cea35aa1db8a46ce5d70e4426291fe65 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 18:28:03 +0800 Subject: [PATCH 16/17] =?UTF-8?q?=E8=B0=83=E6=95=B4=20logspace=20=E7=9A=84?= =?UTF-8?q?=E4=BD=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/fluid/layers/tensor.py | 125 --------------------------- python/paddle/tensor/creation.py | 124 +++++++++++++++++++++++++- 2 files changed, 123 insertions(+), 126 deletions(-) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index e617824764338..3a8dfdc858079 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -56,7 +56,6 @@ 'isfinite', 'range', 'linspace', - 'logspace', 'zeros_like', 'ones_like', 'diag', @@ -1615,130 +1614,6 @@ def linspace(start, stop, num, dtype=None, name=None): return out -def logspace(start, stop, num, base=10.0, dtype=None, name=None): - r""" - Return fixed number of logarithmical-evenly spaced values within the interval \ - :math:`[base^{start}, base^{stop}]`. - - Notes: - This API does not compute the gradient. - - Args: - start(int|float|Tensor): The input :attr:`start` is exponent of first entry in \ - the sequence. It is a scalar, or a Tensor of shape [1] with input data \ - type int32, int64, float32 or float64. - stop(int|float|Tensor): The input :attr:`stop` is exponent of last entry in the \ - sequence. It is a scalar, or a Tensor of shape [1] with input data \ - type int32, int64, float32 or float64. - num(int|Tensor): The input :attr:`num` is given number of items in the sequence. \ - It is an int scalar, or a Tensor of shape [1] with data type int32. - base(int|float|Tensor): The input :attr:`base` is base of the logarithm function. \ - It is a scalar, or a Tensor of shape [1] with input data type int32, int64, \ - float32 or float64. - dtype(np.dtype|str, optional): The data type of output tensor, it could be \ - int32, int64, float32 or float64. Default: if None, the data type is float32. \ - name(str, optional): Normally there is no need for user to set this property. \ - For more information, please refer to :ref:`api_guide_Name`. Default: None. - - Returns: - Tensor: The output data type will be float32, float64. The 1-D tensor with \ - fixed number of logarithmical-evenly spaced values, the data shape of this \ - tensor is :math:`[num]`. If the :attr:`num` is set 1, the output tensor \ - just has the value with exponential of :attr:`start` with base :attr:`base`. - - Examples: - .. code-block:: python - :name: logspace-example - - import paddle - data = paddle.logspace(0, 10, 5, 2, 'float32') - # [1. , 5.65685415 , 32. , 181.01933289, 1024. ] - data = paddle.logspace(0, 10, 1, 2, 'float32') - # [1.] - """ - if dtype is None: - dtype = 'float32' - tensor_num = num - tensor_start = start - tensor_stop = stop - tensor_base = base - if not isinstance(num, Variable): - check_type(num, 'num', (int), 'logspace') - if not isinstance(dtype, core.VarDesc.VarType): - dtype = convert_np_dtype_to_dtype_(dtype) - if not isinstance(start, Variable): - with device_guard("cpu"): - tensor_start = fill_constant([1], dtype, start) - if not isinstance(stop, Variable): - with device_guard("cpu"): - tensor_stop = fill_constant([1], dtype, stop) - if not isinstance(num, Variable): - with device_guard("cpu"): - tensor_num = fill_constant([1], 'int32', num) - if not isinstance(base, Variable): - with device_guard("cpu"): - tensor_base = fill_constant([1], dtype, base) - if _non_static_mode(): - return _C_ops.logspace(tensor_start, tensor_stop, tensor_num, - tensor_base, 'dtype', dtype) - - helper = LayerHelper("logspace", **locals()) - - start_dtype = convert_dtype(tensor_start.dtype) - stop_dtype = convert_dtype(tensor_stop.dtype) - base_dtype = convert_dtype(tensor_base.dtype) - out_dtype = convert_dtype(dtype) - if isinstance(start, Variable): - check_dtype(start.dtype, 'start', - ['float32', 'float64', 'int32', 'int64'], 'logspace') - else: - check_type(start, 'start', (int, float), 'logspace') - - if isinstance(stop, Variable): - check_dtype(stop.dtype, 'stop', - ['float32', 'float64', 'int32', 'int64'], 'logspace') - else: - check_type(stop, 'stop', (int, float), 'logspace') - - if isinstance(num, Variable): - check_dtype(num.dtype, 'num', ['int32'], 'logspace') - - if isinstance(base, Variable): - check_dtype(base.dtype, 'base', - ['float32', 'float64', 'int32', 'int64'], 'logspace') - else: - check_type(base, 'base', (int, float), 'logspace') - - check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'], - 'logspace') - if ((stop_dtype == "float64" or start_dtype == "float64" - or base_dtype == "float64") - and out_dtype in ["float32", "int32"]) or \ - ((stop_dtype == "int64" or start_dtype == "int64" - or base_dtype == "int64") - and out_dtype == "int32"): - raise ValueError( - "The dtype of start/stop/base is {}/{}/{} but the attr(dtype) of logspace is {}, " - "which may cause data type overflows. Please reset attr(dtype) of logspace." - .format(start_dtype, stop_dtype, base_dtype, dtype)) - - out = helper.create_variable_for_type_inference(dtype=dtype) - - helper.append_op( - type='logspace', - inputs={ - 'Start': tensor_start, - 'Stop': tensor_stop, - 'Num': tensor_num, - 'Base': tensor_base - }, - attrs={'dtype': dtype}, - outputs={'Out': [out]}) - if isinstance(num, int): - out.desc.set_shape((num, )) - return out - - def zeros_like(x, out=None): """ This OP creates a zeros tensor which has identical shape and dtype diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 623c802d0b202..c0dd80fea4236 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -26,7 +26,6 @@ from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype from ..framework import convert_np_dtype_to_dtype_, _varbase_creator, OpProtoHolder # TODO: define functions to get create a tensor -from ..fluid.layers import logspace # noqa: F401 import paddle from paddle import _C_ops from ..fluid.framework import _in_legacy_dygraph, _in_eager_without_dygraph_check @@ -146,6 +145,129 @@ def linspace(start, stop, num, dtype=None, name=None): out.desc.set_shape((num, )) return out +def logspace(start, stop, num, base=10.0, dtype=None, name=None): + r""" + Return fixed number of logarithmical-evenly spaced values within the interval \ + :math:`[base^{start}, base^{stop}]`. + + Notes: + This API does not compute the gradient. + + Args: + start(int|float|Tensor): The input :attr:`start` is exponent of first entry in \ + the sequence. It is a scalar, or a Tensor of shape [1] with input data \ + type int32, int64, float32 or float64. + stop(int|float|Tensor): The input :attr:`stop` is exponent of last entry in the \ + sequence. It is a scalar, or a Tensor of shape [1] with input data \ + type int32, int64, float32 or float64. + num(int|Tensor): The input :attr:`num` is given number of items in the sequence. \ + It is an int scalar, or a Tensor of shape [1] with data type int32. + base(int|float|Tensor): The input :attr:`base` is base of the logarithm function. \ + It is a scalar, or a Tensor of shape [1] with input data type int32, int64, \ + float32 or float64. + dtype(np.dtype|str, optional): The data type of output tensor, it could be \ + int32, int64, float32 or float64. Default: if None, the data type is float32. \ + name(str, optional): Normally there is no need for user to set this property. \ + For more information, please refer to :ref:`api_guide_Name`. Default: None. + + Returns: + Tensor: The output data type will be float32, float64. The 1-D tensor with \ + fixed number of logarithmical-evenly spaced values, the data shape of this \ + tensor is :math:`[num]`. If the :attr:`num` is set 1, the output tensor \ + just has the value with exponential of :attr:`start` with base :attr:`base`. + + Examples: + .. code-block:: python + :name: logspace-example + + import paddle + data = paddle.logspace(0, 10, 5, 2, 'float32') + # [1. , 5.65685415 , 32. , 181.01933289, 1024. ] + data = paddle.logspace(0, 10, 1, 2, 'float32') + # [1.] + """ + if dtype is None: + dtype = 'float32' + tensor_num = num + tensor_start = start + tensor_stop = stop + tensor_base = base + if not isinstance(num, Variable): + check_type(num, 'num', (int), 'logspace') + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + if not isinstance(start, Variable): + with device_guard("cpu"): + tensor_start = fill_constant([1], dtype, start) + if not isinstance(stop, Variable): + with device_guard("cpu"): + tensor_stop = fill_constant([1], dtype, stop) + if not isinstance(num, Variable): + with device_guard("cpu"): + tensor_num = fill_constant([1], 'int32', num) + if not isinstance(base, Variable): + with device_guard("cpu"): + tensor_base = fill_constant([1], dtype, base) + if _non_static_mode(): + return _C_ops.logspace(tensor_start, tensor_stop, tensor_num, + tensor_base, 'dtype', dtype) + + helper = LayerHelper("logspace", **locals()) + + start_dtype = convert_dtype(tensor_start.dtype) + stop_dtype = convert_dtype(tensor_stop.dtype) + base_dtype = convert_dtype(tensor_base.dtype) + out_dtype = convert_dtype(dtype) + if isinstance(start, Variable): + check_dtype(start.dtype, 'start', + ['float32', 'float64', 'int32', 'int64'], 'logspace') + else: + check_type(start, 'start', (int, float), 'logspace') + + if isinstance(stop, Variable): + check_dtype(stop.dtype, 'stop', + ['float32', 'float64', 'int32', 'int64'], 'logspace') + else: + check_type(stop, 'stop', (int, float), 'logspace') + + if isinstance(num, Variable): + check_dtype(num.dtype, 'num', ['int32'], 'logspace') + + if isinstance(base, Variable): + check_dtype(base.dtype, 'base', + ['float32', 'float64', 'int32', 'int64'], 'logspace') + else: + check_type(base, 'base', (int, float), 'logspace') + + check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'], + 'logspace') + if ((stop_dtype == "float64" or start_dtype == "float64" + or base_dtype == "float64") + and out_dtype in ["float32", "int32"]) or \ + ((stop_dtype == "int64" or start_dtype == "int64" + or base_dtype == "int64") + and out_dtype == "int32"): + raise ValueError( + "The dtype of start/stop/base is {}/{}/{} but the attr(dtype) of logspace is {}, " + "which may cause data type overflows. Please reset attr(dtype) of logspace." + .format(start_dtype, stop_dtype, base_dtype, dtype)) + + out = helper.create_variable_for_type_inference(dtype=dtype) + + helper.append_op( + type='logspace', + inputs={ + 'Start': tensor_start, + 'Stop': tensor_stop, + 'Num': tensor_num, + 'Base': tensor_base + }, + attrs={'dtype': dtype}, + outputs={'Out': [out]}) + if isinstance(num, int): + out.desc.set_shape((num, )) + return out + @dygraph_only def to_tensor(data, dtype=None, place=None, stop_gradient=True): From 407c2a6991ad6d21a9cb3d144f13f9f6934c3df2 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sun, 17 Apr 2022 18:28:05 +0800 Subject: [PATCH 17/17] =?UTF-8?q?=E8=B0=83=E6=95=B4=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/fluid/tests/unittests/test_logspace.py | 9 ++++++--- python/paddle/tensor/creation.py | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_logspace.py b/python/paddle/fluid/tests/unittests/test_logspace.py index 8f5aaf0aa5dd8..ffa9885e7671e 100644 --- a/python/paddle/fluid/tests/unittests/test_logspace.py +++ b/python/paddle/fluid/tests/unittests/test_logspace.py @@ -200,7 +200,8 @@ def test_num_type(): self.assertRaises(TypeError, test_num_type) def test_start_dtype(): - start = paddle.static.data(shape=[1], dtype="float64", name="start") + start = paddle.static.data( + shape=[1], dtype="float64", name="start") paddle.logspace(start, 10, 1, 2, dtype="float32") self.assertRaises(ValueError, test_start_dtype) @@ -212,13 +213,15 @@ def test_end_dtype(): self.assertRaises(ValueError, test_end_dtype) def test_num_dtype(): - num = paddle.static.data(shape=[1], dtype="float32", name="step") + num = paddle.static.data( + shape=[1], dtype="float32", name="step") paddle.logspace(0, 10, num, 2, dtype="float32") self.assertRaises(TypeError, test_num_dtype) def test_base_dtype(): - base = paddle.static.data(shape=[1], dtype="float64", name="end") + base = paddle.static.data( + shape=[1], dtype="float64", name="end") paddle.logspace(0, 10, 1, base, dtype="float32") self.assertRaises(ValueError, test_base_dtype) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index c0dd80fea4236..06d181011fcb1 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -145,6 +145,7 @@ def linspace(start, stop, num, dtype=None, name=None): out.desc.set_shape((num, )) return out + def logspace(start, stop, num, base=10.0, dtype=None, name=None): r""" Return fixed number of logarithmical-evenly spaced values within the interval \