From b00e7a3311468952abde6f77cdfd9abf60216a38 Mon Sep 17 00:00:00 2001 From: zhiboniu <31800336+zhiboniu@users.noreply.github.com> Date: Sat, 9 Oct 2021 16:21:39 +0800 Subject: [PATCH 01/16] update fft api path (#36219) * update fft api path * add sample code for ihfft2 Co-authored-by: chenfeiyu --- python/paddle/__init__.py | 2 +- python/paddle/fft.py | 61 +++++++++++++++++++++++++++++++++++++ python/paddle/tensor/fft.py | 44 ++++++++++++-------------- 3 files changed, 81 insertions(+), 26 deletions(-) create mode 100644 python/paddle/fft.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ad8640f6f5584..decffa66f4174 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -64,7 +64,6 @@ import paddle.static # noqa: F401 import paddle.vision # noqa: F401 -from .tensor import fft from .tensor.random import bernoulli # noqa: F401 from .tensor.attribute import rank # noqa: F401 @@ -294,6 +293,7 @@ from .hapi import flops # noqa: F401 from . import hub # noqa: F401 from . import linalg # noqa: F401 +from . import fft # noqa: F401 import paddle.text # noqa: F401 import paddle.vision # noqa: F401 diff --git a/python/paddle/fft.py b/python/paddle/fft.py new file mode 100644 index 0000000000000..3ac02c9c8dc18 --- /dev/null +++ b/python/paddle/fft.py @@ -0,0 +1,61 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .tensor.fft import fft # noqa: F401 +from .tensor.fft import fft2 # noqa: F401 +from .tensor.fft import fftn # noqa: F401 +from .tensor.fft import ifft # noqa: F401 +from .tensor.fft import ifft2 # noqa: F401 +from .tensor.fft import ifftn # noqa: F401 +from .tensor.fft import rfft # noqa: F401 +from .tensor.fft import rfft2 # noqa: F401 +from .tensor.fft import rfftn # noqa: F401 +from .tensor.fft import irfft # noqa: F401 +from .tensor.fft import irfft2 # noqa: F401 +from .tensor.fft import irfftn # noqa: F401 +from .tensor.fft import hfft # noqa: F401 +from .tensor.fft import hfft2 # noqa: F401 +from .tensor.fft import hfftn # noqa: F401 +from .tensor.fft import ihfft # noqa: F401 +from .tensor.fft import ihfft2 # noqa: F401 +from .tensor.fft import ihfftn # noqa: F401 +from .tensor.fft import fftfreq # noqa: F401 +from .tensor.fft import rfftfreq # noqa: F401 +from .tensor.fft import fftshift # noqa: F401 +from .tensor.fft import ifftshift # noqa: F401 + +__all__ = [ # noqa + 'fft', + 'fft2', + 'fftn', + 'ifft', + 'ifft2', + 'ifftn', + 'rfft', + 'rfft2', + 'rfftn', + 'irfft', + 'irfft2', + 'irfftn', + 'hfft', + 'hfft2', + 'hfftn', + 'ihfft', + 'ihfft2', + 'ihfftn', + 'fftfreq', + 'rfftfreq', + 'fftshift', + 'ifftshift' +] diff --git a/python/paddle/tensor/fft.py b/python/paddle/tensor/fft.py index 98ca858c0eb85..829399d14eaa0 100644 --- a/python/paddle/tensor/fft.py +++ b/python/paddle/tensor/fft.py @@ -21,30 +21,7 @@ from ..fluid.data_feeder import check_variable_and_dtype from ..fluid.layer_helper import LayerHelper -__all__ = [ - 'fft', - 'fft2', - 'fftn', - 'ifft', - 'ifft2', - 'ifftn', - 'rfft', - 'rfft2', - 'rfftn', - 'irfft', - 'irfft2', - 'irfftn', - 'hfft', - 'hfft2', - 'hfftn', - 'ihfft', - 'ihfft2', - 'ihfftn', - 'fftfreq', - 'rfftfreq', - 'fftshift', - 'ifftshift', -] +__all__ = [] def _check_normalization(norm): @@ -1135,7 +1112,24 @@ def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): refer to :ref:`api_guide_Name` . Returns: - out(Tensor) : The result of the inverse real 2-D FFT. + out(Tensor) : The result of the inverse hermitian 2-D FFT. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.mgrid[:5, :5][0].astype(np.float64) + xp = paddle.to_tensor(x) + ihfft2_xp = paddle.fft.ihfft2(xp).numpy() + print(ihfft2_xp) + # [[ 2. +0.j 0. +0.j 0. +0.j ] + # [-0.5-0.68819096j 0. +0.j 0. +0.j ] + # [-0.5-0.16245985j 0. +0.j 0. +0.j ] + # [-0.5+0.16245985j 0. +0.j 0. +0.j ] + # [-0.5+0.68819096j 0. +0.j 0. 
+0.j ]] """ _check_at_least_ndim(x, 2) if s is not None: From b01cbdbd497f8f05708b1d99f5268d195cf2d8cf Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Mon, 11 Oct 2021 11:24:40 +0800 Subject: [PATCH 02/16] fix fft axis (#36321) fix: `-1` is used when fft's axis is `0` --- python/paddle/tensor/fft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensor/fft.py b/python/paddle/tensor/fft.py index 829399d14eaa0..f7990e3f89107 100644 --- a/python/paddle/tensor/fft.py +++ b/python/paddle/tensor/fft.py @@ -1340,7 +1340,7 @@ def fft_c2c(x, n, axis, norm, forward, name): x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) _check_normalization(norm) - axis = axis or -1 + axis = axis if axis is not None else -1 _check_fft_axis(x, axis) axes = [axis] axes = _normalize_axes(x, axes) @@ -1370,7 +1370,7 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name): if is_interger(x): x = paddle.cast(x, paddle.get_default_dtype()) _check_normalization(norm) - axis = axis or -1 + axis = axis if axis is not None else -1 _check_fft_axis(x, axis) axes = [axis] axes = _normalize_axes(x, axes) @@ -1409,7 +1409,7 @@ def fft_c2r(x, n, axis, norm, forward, name): elif is_floating_point(x): x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) _check_normalization(norm) - axis = axis or -1 + axis = axis if axis is not None else -1 _check_fft_axis(x, axis) axes = [axis] axes = _normalize_axes(x, axes) From 5d054d076a21a53bed368f523cd0425da6c11138 Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Mon, 11 Oct 2021 11:30:12 +0800 Subject: [PATCH 03/16] use unified external error message for cufft api (#36114) --- cmake/third_party.cmake | 4 +-- paddle/fluid/operators/spectral_op.cu | 5 ++-- paddle/fluid/platform/enforce.h | 14 ++++++++++ paddle/fluid/platform/enforce_test.cc | 22 +++++++++++++++- paddle/fluid/platform/external_error.proto | 1 + tools/externalError/README.md | 30 +++++++++++++++++----- tools/externalError/spider.py | 29 ++++++++++++++++++++- tools/externalError/start.sh | 2 +- 8 files changed, 92 insertions(+), 15 deletions(-) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 892ae270267a7..b3260ba27b072 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -251,8 +251,8 @@ if(WITH_GPU) include(external/cub) # download cub list(APPEND third_party_deps extern_cub) endif() - set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg.tar.gz" CACHE STRING "" FORCE) - file_download_and_uncompress(${URL} "externalError" MD5 061f3b7895aadcbe2c3ed592590f8b10) # download file externalErrorMsg.tar.gz + set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg_20210928.tar.gz" CACHE STRING "" FORCE) + file_download_and_uncompress(${URL} "externalError" MD5 a712a49384e77ca216ad866712f7cafa) # download file externalErrorMsg.tar.gz if(WITH_TESTING) # copy externalErrorMsg.pb, just for unittest can get error message correctly. 
set(SRC_DIR ${THIRD_PARTY_PATH}/externalError/data) diff --git a/paddle/fluid/operators/spectral_op.cu b/paddle/fluid/operators/spectral_op.cu index 9aa5ca39d737e..24dffaad41b5f 100644 --- a/paddle/fluid/operators/spectral_op.cu +++ b/paddle/fluid/operators/spectral_op.cu @@ -83,9 +83,7 @@ static inline std::string get_cufft_error_info(cufftResult error) { } static inline void CUFFT_CHECK(cufftResult error) { - if (error != CUFFT_SUCCESS) { - PADDLE_THROW(platform::errors::External(get_cufft_error_info(error))); - } + PADDLE_ENFORCE_CUDA_SUCCESS(error); } // This struct is used to easily compute hashes of the @@ -413,6 +411,7 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, ? framework::ToRealType(input.type()) : input.type(); auto fft_type = GetFFTTransformType(input.type(), output.type()); + PlanKey Key(framework::vectorize(input.dims()), framework::vectorize(output.dims()), signal_size, fft_type, value_type); diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index c420a5a64be06..7427060add8b1 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -31,6 +31,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_CUDA #include #include +#include #include #include #include @@ -714,6 +715,7 @@ DEFINE_EXTERNAL_API_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS, CURAND); DEFINE_EXTERNAL_API_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS, CUDNN); DEFINE_EXTERNAL_API_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS, CUBLAS); DEFINE_EXTERNAL_API_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS, CUSOLVER); +DEFINE_EXTERNAL_API_TYPE(cufftResult_t, CUFFT_SUCCESS, CUFFT); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess, NCCL); @@ -751,6 +753,8 @@ inline const char* GetErrorMsgUrl(T status) { return "https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/" "types.html#ncclresult-t"; break; + case platform::proto::ApiType::CUFFT: + return "https://docs.nvidia.com/cuda/cufft/index.html#cufftresult"; default: return "Unknown type of External API, can't get error message URL!"; break; @@ -839,6 +843,7 @@ template std::string GetExternalErrorMsg(curandStatus_t); template std::string GetExternalErrorMsg(cudnnStatus_t); template std::string GetExternalErrorMsg(cublasStatus_t); template std::string GetExternalErrorMsg(cusolverStatus_t); +template std::string GetExternalErrorMsg(cufftResult_t); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) template std::string GetExternalErrorMsg(ncclResult_t); #endif @@ -899,6 +904,15 @@ inline std::string build_nvidia_error_msg(cusolverStatus_t stat) { return sout.str(); } +/*************** CUFFT ERROR ***************/ +inline bool is_error(cufftResult_t stat) { return stat != CUFFT_SUCCESS; } + +inline std::string build_nvidia_error_msg(cufftResult_t stat) { + std::ostringstream sout; + sout << "CUFFT error(" << stat << "). " << GetExternalErrorMsg(stat); + return sout.str(); +} + /**************** NCCL ERROR ****************/ #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) inline bool is_error(ncclResult_t nccl_result) { diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index 95a852ad6e92a..c6d5f171ddce4 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -9,10 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "paddle/fluid/platform/enforce.h" + #include #include "gtest/gtest.h" -#include "paddle/fluid/platform/enforce.h" TEST(ENFORCE, OK) { PADDLE_ENFORCE(true, paddle::platform::errors::Unavailable( @@ -418,6 +419,25 @@ TEST(enforce, cuda_success) { "negative vector size, for example).To correct: ensure that all the " "parameters being passed have valid values")); + EXPECT_TRUE(CheckCudaStatusSuccess(CUFFT_SUCCESS)); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_PLAN, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_ALLOC_FAILED, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_TYPE, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_VALUE, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INTERNAL_ERROR, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_EXEC_FAILED, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_SETUP_FAILED, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_SIZE, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_UNALIGNED_DATA, "CUFFT error")); + EXPECT_TRUE( + CheckCudaStatusFailure(CUFFT_INCOMPLETE_PARAMETER_LIST, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_INVALID_DEVICE, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_PARSE_ERROR, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_NO_WORKSPACE, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_NOT_IMPLEMENTED, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_LICENSE_ERROR, "CUFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(CUFFT_NOT_SUPPORTED, "CUFFT error")); + #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess)); EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError, "NCCL error")); diff --git a/paddle/fluid/platform/external_error.proto b/paddle/fluid/platform/external_error.proto index 2094de7e10f69..cbbf803492e64 100644 --- a/paddle/fluid/platform/external_error.proto +++ b/paddle/fluid/platform/external_error.proto @@ -24,6 +24,7 @@ enum ApiType { CUBLAS = 3; CUSOLVER = 4; NCCL = 5; + CUFFT = 6; } message MessageDesc { diff --git a/tools/externalError/README.md b/tools/externalError/README.md index 029efd8cb9491..0c2ac626991da 100644 --- a/tools/externalError/README.md +++ b/tools/externalError/README.md @@ -1,9 +1,25 @@ -Usage: +#### **Introduction for crawling new error message:** -Please run: -``` -bash start.sh -``` -If you want to update all external error message, you need to run command `bash start.sh` in current directory, -and upload the generated file `externalErrorMsg.tar.gz` to https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg.tar.gz + +1. add new spider code in spider.py for crawling error message from website. + +2. run `bash start.sh` in current directory to generate new externalErrorMsg_${date}.tar.gz file, for example `externalErrorMsg_20210928.tar.gz`. + +3. upload above tar file into bos https://paddlepaddledeps.bj.bcebos.com **paddlepaddledeps** bucket, and copy download link `${download_url}`. ***\*Be careful not to delete original tar file\****. + +4. 
compute md5 value of above tar file `${md5}`, and modify cmake/third_party.cmake file + + ``` + set(URL "${download_url}" CACHE STRING "" FORCE) + file_download_and_uncompress(${URL} "externalError" MD5 ${md5}) + ``` + + for example: + + ``` + set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg_20210928.tar.gz" CACHE STRING "" FORCE) + file_download_and_uncompress(${URL} "externalError" MD5 a712a49384e77ca216ad866712f7cafa) + ``` + +5. commit your changes, and create pull request. diff --git a/tools/externalError/spider.py b/tools/externalError/spider.py index a74d82f40ebeb..e07f05f561cb5 100644 --- a/tools/externalError/spider.py +++ b/tools/externalError/spider.py @@ -17,8 +17,10 @@ import urllib.request import json import collections -import sys, getopt +import sys +import getopt import external_error_pb2 +from html.parser import HTMLParser def parsing(externalErrorDesc): @@ -335,6 +337,31 @@ def parsing(externalErrorDesc): _Messages.message = "'%s'. %s" % (error[0], m_message) print("End crawling errorMessage for nvidia NCCL API!\n") + #*************************************************************************************************# + #*********************************** CUFFT Error Message **************************************# + print("start crawling errorMessage for nvidia CUFFT API--->") + url = 'https://docs.nvidia.com/cuda/cufft/index.html#cufftresult' + + allMessageDesc = externalErrorDesc.errors.add() + allMessageDesc.type = external_error_pb2.CUFFT + + html = urllib.request.urlopen(url).read().decode('utf-8') + + class CUFFTHTMLParser(HTMLParser): + '''CUFFTHTML Parser + ''' + + def handle_data(self, data): + if 'typedef enum cufftResult_t' in data: + for line in data.strip().splitlines()[1:-1]: + status, code, desc = re.split('=|//', line.strip()) + _Messages = allMessageDesc.messages.add() + _Messages.code = int(code.strip(' ,')) + _Messages.message = "'%s'. %s" % (status.strip(), + desc.strip()) + + CUFFTHTMLParser().feed(html) + def main(argv): try: diff --git a/tools/externalError/start.sh b/tools/externalError/start.sh index 32ef63c261268..82715dd47326c 100644 --- a/tools/externalError/start.sh +++ b/tools/externalError/start.sh @@ -32,4 +32,4 @@ fi protobuf/bin/protoc -I../../paddle/fluid/platform/ --python_out . ../../paddle/fluid/platform/external_error.proto python3.7 spider.py -tar czvf externalErrorMsg.tar.gz externalErrorMsg.pb +tar czvf externalErrorMsg_$(date +'%Y%m%d').tar.gz externalErrorMsg.pb From 14e553e57186ccfe142c5453c4c5974c0934c824 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?LJQ=E2=9D=A4=EF=B8=8F?= <33169170+lijiaqi0612@users.noreply.github.com> Date: Tue, 12 Oct 2021 10:29:03 +0800 Subject: [PATCH 04/16] fft: modify sample code result (#36325) --- python/paddle/tensor/fft.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/python/paddle/tensor/fft.py b/python/paddle/tensor/fft.py index f7990e3f89107..20fd143589fa4 100644 --- a/python/paddle/tensor/fft.py +++ b/python/paddle/tensor/fft.py @@ -339,7 +339,7 @@ def irfft(x, n=None, axis=-1, norm="backward", name=None): xp = paddle.to_tensor(x) irfft_xp = paddle.fft.irfft(xp).numpy() print(irfft_xp) - # [0. 0. 0. 4.] + # [0. 1. 0. 0.] 
""" return fft_c2r(x, n, axis, norm, forward=False, name=name) @@ -477,7 +477,7 @@ def fftn(x, s=None, axes=None, norm="backward", name=None): import numpy as np import paddle - x = x = np.mgrid[:4, :4, :4][1] + x = np.mgrid[:4, :4, :4][1] xp = paddle.to_tensor(x) fftn_xp = paddle.fft.fftn(xp, axes=(1, 2)).numpy() print(fftn_xp) @@ -631,9 +631,9 @@ def rfftn(x, s=None, axes=None, norm="backward", name=None): # use axes(2, 0) print(paddle.fft.rfftn(x, axes=(2, 0))) # Tensor(shape=[2, 3, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, - # [[[(24+0j), 0j , 0j ], - # [0j , 0j , 0j ], - # [0j , 0j , 0j ]], + # [[[(8+0j), 0j , 0j ], + # [(8+0j), 0j , 0j ], + # [(8+0j), 0j , 0j ]], # # [[0j , 0j , 0j ], # [0j , 0j , 0j ], @@ -1267,9 +1267,8 @@ def fftshift(x, axes=None, name=None): import paddle x = np.array([3, 1, 2, 2, 3], dtype=float) - scalar_temp = 0.3 n = x.size - fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp) + fftfreq_xp = paddle.fft.fftfreq(n, d=0.3) res = paddle.fft.fftshift(fftfreq_xp).numpy() print(res) # [-1.3333334 -0.6666667 0. 0.6666667 1.3333334] @@ -1311,9 +1310,8 @@ def ifftshift(x, axes=None, name=None): import paddle x = np.array([3, 1, 2, 2, 3], dtype=float) - scalar_temp = 0.3 n = x.size - fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp) + fftfreq_xp = paddle.fft.fftfreq(n, d=0.3) res = paddle.fft.ifftshift(fftfreq_xp).numpy() print(res) # [ 1.3333334 -1.3333334 -0.6666667 0. 0.6666667] From 12cbd4e6b1e2c7d64aa811662528924d0a2445db Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Fri, 15 Oct 2021 12:46:24 +0800 Subject: [PATCH 05/16] dynamic load mkl as a fft backend when it is avaialble and requested (#36414) --- paddle/fluid/operators/CMakeLists.txt | 15 ++- paddle/fluid/operators/spectral_op.cc | 113 +++++++++--------- paddle/fluid/platform/dynload/CMakeLists.txt | 6 + .../fluid/platform/dynload/dynamic_loader.cc | 16 +++ .../fluid/platform/dynload/dynamic_loader.h | 1 + paddle/fluid/platform/dynload/mklrt.cc | 51 ++++++++ paddle/fluid/platform/dynload/mklrt.h | 80 +++++++++++++ 7 files changed, 221 insertions(+), 61 deletions(-) create mode 100644 paddle/fluid/platform/dynload/mklrt.cc create mode 100644 paddle/fluid/platform/dynload/mklrt.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index c487313f91c58..a4fed3bb0fd76 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -102,10 +102,21 @@ else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() + if (WITH_GPU AND (NOT WITH_ROCM)) - op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS}) + if (MKL_FOUND AND WITH_ONEMKL) + op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda dynload_mklrt ${OP_HEADER_DEPS}) + target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE}) + else() + op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS}) + endif() else() - op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS}) + if (MKL_FOUND AND WITH_ONEMKL) + op_library(spectral_op SRCS spectral_op.cc DEPS dynload_mklrt ${OP_HEADER_DEPS}) + target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE}) + else() + op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS}) + endif() endif() op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) diff --git a/paddle/fluid/operators/spectral_op.cc b/paddle/fluid/operators/spectral_op.cc index fb50702233b3b..b5edc1dda533b 100644 --- 
a/paddle/fluid/operators/spectral_op.cc +++ b/paddle/fluid/operators/spectral_op.cc @@ -27,7 +27,7 @@ #include "paddle/fluid/platform/complex.h" #if defined(PADDLE_WITH_ONEMKL) -#include +#include "paddle/fluid/platform/dynload/mklrt.h" #elif defined(PADDLE_WITH_POCKETFFT) #include "extern_pocketfft/pocketfft_hdronly.h" #endif @@ -357,46 +357,45 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) { // FFT Functors #if defined(PADDLE_WITH_ONEMKL) +#define MKL_DFTI_CHECK(expr) \ + do { \ + MKL_LONG status = (expr); \ + if (!platform::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) \ + PADDLE_THROW(platform::errors::External( \ + platform::dynload::DftiErrorMessage(status))); \ + } while (0); + namespace { -static inline void MKL_DFTI_CHECK(MKL_INT status) { - if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) { - PADDLE_THROW(platform::errors::External(DftiErrorMessage(status))); - } -} struct DftiDescriptorDeleter { void operator()(DFTI_DESCRIPTOR_HANDLE handle) { if (handle != nullptr) { - MKL_DFTI_CHECK(DftiFreeDescriptor(&handle)); + MKL_DFTI_CHECK(platform::dynload::DftiFreeDescriptor(&handle)); } } }; +// A RAII wrapper for MKL_DESCRIPTOR* class DftiDescriptor { public: void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG* sizes) { - if (desc_ != nullptr) { - PADDLE_THROW(platform::errors::AlreadyExists( - "DFT DESCRIPTOR can only be initialized once.")); - } + PADDLE_ENFORCE_EQ(desc_.get(), nullptr, + platform::errors::AlreadyExists( + "DftiDescriptor has already been initialized.")); + DFTI_DESCRIPTOR* raw_desc; - if (signal_ndim == 1) { - MKL_DFTI_CHECK( - DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0])); - } else { - MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, - signal_ndim, sizes)); - } + MKL_DFTI_CHECK(platform::dynload::DftiCreateDescriptorX( + &raw_desc, precision, signal_type, signal_ndim, sizes)); desc_.reset(raw_desc); } DFTI_DESCRIPTOR* get() const { - if (desc_ == nullptr) { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "DFTI DESCRIPTOR has not been initialized.")); - } - return desc_.get(); + DFTI_DESCRIPTOR* raw_desc = desc_.get(); + PADDLE_ENFORCE_NOT_NULL(raw_desc, + platform::errors::PreconditionNotMet( + "DFTI DESCRIPTOR has not been initialized.")); + return raw_desc; } private: @@ -421,7 +420,9 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, return DFTI_DOUBLE; default: PADDLE_THROW(platform::errors::InvalidArgument( - "Input data type should be FP32, FP64, COMPLEX64 or COMPLEX128.")); + "Invalid input datatype (%s), input data type should be FP32, " + "FP64, COMPLEX64 or COMPLEX128.", + framework::DataTypeToString(in_dtype))); } }(); @@ -430,35 +431,27 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, const DFTI_CONFIG_VALUE domain = (fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL; - // const bool complex_input = framework::IsComplexType(in_dtype); - // const bool complex_output = framework::IsComplexType(out_dtype); - // const DFTI_CONFIG_VALUE domain = [&] { - // if (forward) { - // return complex_input ? DFTI_COMPLEX : DFTI_REAL; - // } else { - // return complex_output ? 
DFTI_COMPLEX : DFTI_REAL; - // } - // }(); - DftiDescriptor descriptor; std::vector fft_sizes(signal_sizes.cbegin(), signal_sizes.cend()); const MKL_LONG signal_ndim = fft_sizes.size() - 1; descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1); // placement inplace or not inplace - MKL_DFTI_CHECK( - DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); // number of transformations const MKL_LONG batch_size = fft_sizes[0]; - MKL_DFTI_CHECK( - DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size)); // input & output distance const MKL_LONG idist = in_strides[0]; const MKL_LONG odist = out_strides[0]; - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist)); - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), + DFTI_INPUT_DISTANCE, idist)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), + DFTI_OUTPUT_DISTANCE, odist)); // input & output stride std::vector mkl_in_stride(1 + signal_ndim, 0); @@ -467,15 +460,15 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, mkl_in_stride[i] = in_strides[i]; mkl_out_stride[i] = out_strides[i]; } - MKL_DFTI_CHECK( - DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data())); - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, - mkl_out_stride.data())); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data())); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data())); // conjugate even storage if (!(fft_type == FFTTransformType::C2C)) { - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, - DFTI_COMPLEX_COMPLEX)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue( + descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX)); } MKL_LONG signal_numel = @@ -496,11 +489,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype, return DFTI_BACKWARD_SCALE; } }(); - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale)); + MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(), + scale_direction, scale)); } // commit the descriptor - MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get())); + MKL_DFTI_CHECK(platform::dynload::DftiCommitDescriptor(descriptor.get())); return descriptor; } @@ -592,15 +586,16 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, collapsed_input.numel(), collapsed_input_conj.data()); for_range(functor); - MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), - collapsed_input_conj.data(), - collapsed_output.data())); + MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( + desc.get(), collapsed_input_conj.data(), + collapsed_output.data())); } else if (fft_type == FFTTransformType::R2C && !forward) { framework::Tensor collapsed_output_conj(collapsed_output.type()); collapsed_output_conj.mutable_data(collapsed_output.dims(), ctx.GetPlace()); - MKL_DFTI_CHECK(DftiComputeForward(desc.get(), collapsed_input.data(), - collapsed_output_conj.data())); + MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( + desc.get(), collapsed_input.data(), + collapsed_output_conj.data())); // conjugate 
the output platform::ForRange for_range(ctx, collapsed_output.numel()); math::ConjFunctor functor(collapsed_output_conj.data(), @@ -609,13 +604,13 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, for_range(functor); } else { if (forward) { - MKL_DFTI_CHECK(DftiComputeForward(desc.get(), - collapsed_input.data(), - collapsed_output.data())); + MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( + desc.get(), collapsed_input.data(), + collapsed_output.data())); } else { - MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), - collapsed_input.data(), - collapsed_output.data())); + MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( + desc.get(), collapsed_input.data(), + collapsed_output.data())); } } diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index c0d4b349a9e09..8c64aad46cfc8 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -49,3 +49,9 @@ endif() cc_library(dynload_lapack SRCS lapack.cc DEPS dynamic_loader) add_dependencies(dynload_lapack extern_lapack) # TODO(TJ): add iomp, mkldnn? + +if (MKL_FOUND AND WITH_ONEMKL) + message("ONEMKL INCLUDE directory is ${MKL_INCLUDE}") + cc_library(dynload_mklrt SRCS mklrt.cc DEPS dynamic_loader) + target_include_directories(dynload_mklrt PRIVATE ${MKL_INCLUDE}) +endif() diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index a83f085f7d2d8..0c5c47e38f85e 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -53,6 +53,12 @@ DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so."); DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so."); +DEFINE_string(mkl_dir, "", + "Specify path for loading libmkl_rt.so. " + "For insrance, /opt/intel/oneapi/mkl/latest/lib/intel64/." + "If default, " + "dlopen will search mkl from LD_LIBRARY_PATH"); + DEFINE_string(op_dir, "", "Specify path for loading user-defined op library."); #ifdef PADDLE_WITH_HIP @@ -518,6 +524,16 @@ void* GetCUFFTDsoHandle() { #endif } +void* GetMKLRTDsoHandle() { +#if defined(__APPLE__) || defined(__OSX__) + return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.dylib"); +#elif defined(_WIN32) + return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "mkl_rt.dll"); +#else + return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, "libmkl_rt.so"); +#endif +} + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h index 82c36d9e224f4..6260efdf71c59 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.h +++ b/paddle/fluid/platform/dynload/dynamic_loader.h @@ -43,6 +43,7 @@ void* GetLAPACKDsoHandle(); void* GetOpDsoHandle(const std::string& dso_name); void* GetNvtxDsoHandle(); void* GetCUFFTDsoHandle(); +void* GetMKLRTDsoHandle(); void SetPaddleLibPath(const std::string&); } // namespace dynload diff --git a/paddle/fluid/platform/dynload/mklrt.cc b/paddle/fluid/platform/dynload/mklrt.cc new file mode 100644 index 0000000000000..45fad15fb583e --- /dev/null +++ b/paddle/fluid/platform/dynload/mklrt.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/dynload/mklrt.h" + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag mklrt_dso_flag; +void* mklrt_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +MKLDFTI_ROUTINE_EACH(DEFINE_WRAP); + +DFTI_EXTERN MKL_LONG DftiCreateDescriptorX(DFTI_DESCRIPTOR_HANDLE* desc, + enum DFTI_CONFIG_VALUE prec, + enum DFTI_CONFIG_VALUE domain, + MKL_LONG dim, MKL_LONG* sizes) { + if (prec == DFTI_SINGLE) { + if (dim == 1) { + return DftiCreateDescriptor_s_1d(desc, domain, sizes[0]); + } else { + return DftiCreateDescriptor_s_md(desc, domain, dim, sizes); + } + } else if (prec == DFTI_DOUBLE) { + if (dim == 1) { + return DftiCreateDescriptor_d_1d(desc, domain, sizes[0]); + } else { + return DftiCreateDescriptor_d_md(desc, domain, dim, sizes); + } + } else { + return DftiCreateDescriptor(desc, prec, domain, dim, sizes); + } +} + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/mklrt.h b/paddle/fluid/platform/dynload/mklrt.h new file mode 100644 index 0000000000000..423cd4d0a254c --- /dev/null +++ b/paddle/fluid/platform/dynload/mklrt.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include // NOLINT + +#include "paddle/fluid/platform/dynload/dynamic_loader.h" +#include "paddle/fluid/platform/port.h" + +namespace paddle { +namespace platform { +namespace dynload { + +extern std::once_flag mklrt_dso_flag; +extern void* mklrt_dso_handle; + +/** + * The following macro definition can generate structs + * (for each function) to dynamic load mkldfti routine + * via operator overloading. + */ +#define DYNAMIC_LOAD_MKLRT_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) 
{ \ + using mklrtFunc = decltype(&::__name); \ + std::call_once(mklrt_dso_flag, []() { \ + mklrt_dso_handle = paddle::platform::dynload::GetMKLRTDsoHandle(); \ + }); \ + static void* p_##__name = dlsym(mklrt_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name + +// mkl_dfti.h has a macro that shadows the function with the same name +// un-defeine this macro so as to export that function +#undef DftiCreateDescriptor + +#define MKLDFTI_ROUTINE_EACH(__macro) \ + __macro(DftiCreateDescriptor); \ + __macro(DftiCreateDescriptor_s_1d); \ + __macro(DftiCreateDescriptor_d_1d); \ + __macro(DftiCreateDescriptor_s_md); \ + __macro(DftiCreateDescriptor_d_md); \ + __macro(DftiSetValue); \ + __macro(DftiGetValue); \ + __macro(DftiCommitDescriptor); \ + __macro(DftiComputeForward); \ + __macro(DftiComputeBackward); \ + __macro(DftiFreeDescriptor); \ + __macro(DftiErrorClass); \ + __macro(DftiErrorMessage); + +MKLDFTI_ROUTINE_EACH(DYNAMIC_LOAD_MKLRT_WRAP) + +#undef DYNAMIC_LOAD_MKLRT_WRAP + +// define another function to avoid naming conflict +DFTI_EXTERN MKL_LONG DftiCreateDescriptorX(DFTI_DESCRIPTOR_HANDLE* desc, + enum DFTI_CONFIG_VALUE prec, + enum DFTI_CONFIG_VALUE domain, + MKL_LONG dim, MKL_LONG* sizes); + +} // namespace dynload +} // namespace platform +} // namespace paddle From 01d9d0730b1930d933d96668f096ad8cd8c6e42d Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Tue, 19 Oct 2021 13:13:16 +0800 Subject: [PATCH 06/16] add rocm support for fft api (#36415) --- paddle/fluid/operators/CMakeLists.txt | 3 +- paddle/fluid/operators/spectral_helper.h | 261 ++++++++ paddle/fluid/operators/spectral_op.cu | 614 +++++++----------- paddle/fluid/platform/dynload/CMakeLists.txt | 2 +- .../fluid/platform/dynload/dynamic_loader.cc | 10 + .../fluid/platform/dynload/dynamic_loader.h | 1 + paddle/fluid/platform/dynload/hipfft.cc | 30 + paddle/fluid/platform/dynload/hipfft.h | 124 ++++ paddle/fluid/platform/enforce.h | 10 + paddle/fluid/platform/enforce_test.cc | 4 + 10 files changed, 679 insertions(+), 380 deletions(-) create mode 100644 paddle/fluid/operators/spectral_helper.h create mode 100644 paddle/fluid/platform/dynload/hipfft.cc create mode 100644 paddle/fluid/platform/dynload/hipfft.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index a4fed3bb0fd76..66be7d73b54ac 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -102,8 +102,7 @@ else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() - -if (WITH_GPU AND (NOT WITH_ROCM)) +if (WITH_GPU OR WITH_ROCM) if (MKL_FOUND AND WITH_ONEMKL) op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda dynload_mklrt ${OP_HEADER_DEPS}) target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE}) diff --git a/paddle/fluid/operators/spectral_helper.h b/paddle/fluid/operators/spectral_helper.h new file mode 100644 index 0000000000000..9c34d500eac92 --- /dev/null +++ b/paddle/fluid/operators/spectral_helper.h @@ -0,0 +1,261 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/spectral_op.h" + +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/dynload/hipfft.h" +#endif + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/dynload/cufft.h" +#endif + +namespace paddle { +namespace operators { +using ScalarType = framework::proto::VarType::Type; +const int64_t kMaxCUFFTNdim = 3; +const int64_t kMaxDataNdim = kMaxCUFFTNdim + 1; +// This struct is used to easily compute hashes of the +// parameters. It will be the **key** to the plan cache. +struct PlanKey { + // between 1 and kMaxCUFFTNdim, i.e., 1 <= signal_ndim <= 3 + int64_t signal_ndim_; + // These include additional batch dimension as well. + int64_t sizes_[kMaxDataNdim]; + int64_t input_shape_[kMaxDataNdim]; + int64_t output_shape_[kMaxDataNdim]; + FFTTransformType fft_type_; + ScalarType value_type_; + + PlanKey() = default; + + PlanKey(const std::vector& in_shape, + const std::vector& out_shape, + const std::vector& signal_size, FFTTransformType fft_type, + ScalarType value_type) { + // Padding bits must be zeroed for hashing + memset(this, 0, sizeof(*this)); + signal_ndim_ = signal_size.size() - 1; + fft_type_ = fft_type; + value_type_ = value_type; + + std::copy(signal_size.cbegin(), signal_size.cend(), sizes_); + std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_); + std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_); + } +}; + +#if defined(PADDLE_WITH_CUDA) +// An RAII encapsulation of cuFFTHandle +class CuFFTHandle { + ::cufftHandle handle_; + + public: + CuFFTHandle() { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftCreate(&handle_)); + } + + ::cufftHandle& get() { return handle_; } + const ::cufftHandle& get() const { return handle_; } + + ~CuFFTHandle() { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftDestroy(handle_)); + } +}; + +using plan_size_type = long long int; // NOLINT +// This class contains all the information needed to execute a cuFFT plan: +// 1. the plan +// 2. the workspace size needed +class CuFFTConfig { + public: + // Only move semantics is enought for this class. Although we already use + // unique_ptr for the plan, still remove copy constructor and assignment op so + // we don't accidentally copy and take perf hit. 
+ explicit CuFFTConfig(const PlanKey& plan_key) + : CuFFTConfig( + std::vector(plan_key.sizes_, + plan_key.sizes_ + plan_key.signal_ndim_ + 1), + plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {} + + // sizes are full signal, including batch size and always two-sided + CuFFTConfig(const std::vector& sizes, const int64_t signal_ndim, + FFTTransformType fft_type, ScalarType dtype) + : fft_type_(fft_type), value_type_(dtype) { + // signal sizes (excluding batch dim) + std::vector signal_sizes(sizes.begin() + 1, sizes.end()); + + // input batch size + const auto batch = static_cast(sizes[0]); + // const int64_t signal_ndim = sizes.size() - 1; + PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1, + platform::errors::InvalidArgument( + "The signal_ndim must be equal to sizes.size() - 1," + "But signal_ndim is: [%d], sizes.size() - 1 is: [%d]", + signal_ndim, sizes.size() - 1)); + + cudaDataType itype, otype, exec_type; + const auto complex_input = has_complex_input(fft_type); + const auto complex_output = has_complex_output(fft_type); + if (dtype == framework::proto::VarType::FP32) { + itype = complex_input ? CUDA_C_32F : CUDA_R_32F; + otype = complex_output ? CUDA_C_32F : CUDA_R_32F; + exec_type = CUDA_C_32F; + } else if (dtype == framework::proto::VarType::FP64) { + itype = complex_input ? CUDA_C_64F : CUDA_R_64F; + otype = complex_output ? CUDA_C_64F : CUDA_R_64F; + exec_type = CUDA_C_64F; + } else if (dtype == framework::proto::VarType::FP16) { + itype = complex_input ? CUDA_C_16F : CUDA_R_16F; + otype = complex_output ? CUDA_C_16F : CUDA_R_16F; + exec_type = CUDA_C_16F; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "cuFFT only support transforms of type float16, float32 and " + "float64")); + } + + // disable auto allocation of workspace to use allocator from the framework + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftSetAutoAllocation( + plan(), /* autoAllocate */ 0)); + + size_t ws_size_t; + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftXtMakePlanMany( + plan(), signal_ndim, signal_sizes.data(), + /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype, + /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype, + batch, &ws_size_t, exec_type)); + + ws_size = ws_size_t; + } + + const cufftHandle& plan() const { return plan_ptr.get(); } + + FFTTransformType transform_type() const { return fft_type_; } + ScalarType data_type() const { return value_type_; } + size_t workspace_size() const { return ws_size; } + + private: + CuFFTHandle plan_ptr; + size_t ws_size; + FFTTransformType fft_type_; + ScalarType value_type_; +}; + +#elif defined(PADDLE_WITH_HIP) +// An RAII encapsulation of cuFFTHandle +class HIPFFTHandle { + ::hipfftHandle handle_; + + public: + HIPFFTHandle() { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftCreate(&handle_)); + } + + ::hipfftHandle& get() { return handle_; } + const ::hipfftHandle& get() const { return handle_; } + + ~HIPFFTHandle() { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftDestroy(handle_)); + } +}; +using plan_size_type = int; +// This class contains all the information needed to execute a cuFFT plan: +// 1. the plan +// 2. the workspace size needed +class HIPFFTConfig { + public: + // Only move semantics is enought for this class. Although we already use + // unique_ptr for the plan, still remove copy constructor and assignment op so + // we don't accidentally copy and take perf hit. 
+ explicit HIPFFTConfig(const PlanKey& plan_key) + : HIPFFTConfig( + std::vector(plan_key.sizes_, + plan_key.sizes_ + plan_key.signal_ndim_ + 1), + plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {} + + // sizes are full signal, including batch size and always two-sided + HIPFFTConfig(const std::vector& sizes, const int64_t signal_ndim, + FFTTransformType fft_type, ScalarType dtype) + : fft_type_(fft_type), value_type_(dtype) { + // signal sizes (excluding batch dim) + std::vector signal_sizes(sizes.begin() + 1, sizes.end()); + + // input batch size + const auto batch = static_cast(sizes[0]); + // const int64_t signal_ndim = sizes.size() - 1; + PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1, + platform::errors::InvalidArgument( + "The signal_ndim must be equal to sizes.size() - 1," + "But signal_ndim is: [%d], sizes.size() - 1 is: [%d]", + signal_ndim, sizes.size() - 1)); + + hipfftType exec_type = [&] { + if (dtype == framework::proto::VarType::FP32) { + switch (fft_type) { + case FFTTransformType::C2C: + return HIPFFT_C2C; + case FFTTransformType::R2C: + return HIPFFT_R2C; + case FFTTransformType::C2R: + return HIPFFT_C2R; + } + } else if (dtype == framework::proto::VarType::FP64) { + switch (fft_type) { + case FFTTransformType::C2C: + return HIPFFT_Z2Z; + case FFTTransformType::R2C: + return HIPFFT_D2Z; + case FFTTransformType::C2R: + return HIPFFT_Z2D; + } + } + PADDLE_THROW(platform::errors::InvalidArgument( + "hipFFT only support transforms of type float32 and float64")); + }(); + + // disable auto allocation of workspace to use allocator from the framework + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftSetAutoAllocation( + plan(), /* autoAllocate */ 0)); + + size_t ws_size_t; + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftMakePlanMany( + plan(), signal_ndim, signal_sizes.data(), + /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, + /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, exec_type, + batch, &ws_size_t)); + + ws_size = ws_size_t; + } + + const hipfftHandle& plan() const { return plan_ptr.get(); } + + FFTTransformType transform_type() const { return fft_type_; } + ScalarType data_type() const { return value_type_; } + size_t workspace_size() const { return ws_size; } + + private: + HIPFFTHandle plan_ptr; + size_t ws_size; + FFTTransformType fft_type_; + ScalarType value_type_; +}; +#endif +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/spectral_op.cu b/paddle/fluid/operators/spectral_op.cu index 24dffaad41b5f..e8a4fac2915d7 100644 --- a/paddle/fluid/operators/spectral_op.cu +++ b/paddle/fluid/operators/spectral_op.cu @@ -8,10 +8,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ - -#include -#include - #include #include #include @@ -24,311 +20,246 @@ #include #include "paddle/fluid/operators/conj_op.h" +#include "paddle/fluid/operators/spectral_helper.h" #include "paddle/fluid/operators/spectral_op.h" #include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/dynload/cufft.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { namespace { -using ScalarType = framework::proto::VarType::Type; -const int64_t kMaxCUFFTNdim = 3; -const int64_t kMaxDataNdim = kMaxCUFFTNdim + 1; - -static inline std::string get_cufft_error_info(cufftResult error) { - switch (error) { - case CUFFT_SUCCESS: - return "CUFFT_SUCCESS"; - case CUFFT_INVALID_PLAN: - return "CUFFT_INVALID_PLAN"; - case CUFFT_ALLOC_FAILED: - return "CUFFT_ALLOC_FAILED"; - case CUFFT_INVALID_TYPE: - return "CUFFT_INVALID_TYPE"; - case CUFFT_INVALID_VALUE: - return "CUFFT_INVALID_VALUE"; - case CUFFT_INTERNAL_ERROR: - return "CUFFT_INTERNAL_ERROR"; - case CUFFT_EXEC_FAILED: - return "CUFFT_EXEC_FAILED"; - case CUFFT_SETUP_FAILED: - return "CUFFT_SETUP_FAILED"; - case CUFFT_INVALID_SIZE: - return "CUFFT_INVALID_SIZE"; - case CUFFT_UNALIGNED_DATA: - return "CUFFT_UNALIGNED_DATA"; - case CUFFT_INCOMPLETE_PARAMETER_LIST: - return "CUFFT_INCOMPLETE_PARAMETER_LIST"; - case CUFFT_INVALID_DEVICE: - return "CUFFT_INVALID_DEVICE"; - case CUFFT_PARSE_ERROR: - return "CUFFT_PARSE_ERROR"; - case CUFFT_NO_WORKSPACE: - return "CUFFT_NO_WORKSPACE"; - case CUFFT_NOT_IMPLEMENTED: - return "CUFFT_NOT_IMPLEMENTED"; -#ifndef __HIPCC__ - case CUFFT_LICENSE_ERROR: - return "CUFFT_LICENSE_ERROR"; -#endif - case CUFFT_NOT_SUPPORTED: - return "CUFFT_NOT_SUPPORTED"; - default: - std::ostringstream ss; - ss << "unknown error " << error; - return ss.str(); +// Calculates the normalization constant +double fft_normalization_scale(FFTNormMode normalization, + const std::vector& sizes, + const std::vector& dims) { + // auto norm = static_cast(normalization); + if (normalization == FFTNormMode::none) { + return static_cast(1.0); } -} -static inline void CUFFT_CHECK(cufftResult error) { - PADDLE_ENFORCE_CUDA_SUCCESS(error); + int64_t signal_numel = 1; + for (auto dim : dims) { + signal_numel *= sizes[dim]; + } + const double scale_denom = (normalization == FFTNormMode::by_sqrt_n) + ? std::sqrt(signal_numel) + : static_cast(signal_numel); + return static_cast(1.0 / scale_denom); } -// This struct is used to easily compute hashes of the -// parameters. It will be the **key** to the plan cache. -struct PlanKey { - // between 1 and kMaxCUFFTNdim, i.e., 1 <= signal_ndim <= 3 - int64_t signal_ndim_; - // These include additional batch dimension as well. 
- int64_t sizes_[kMaxDataNdim]; - int64_t input_shape_[kMaxDataNdim]; - int64_t output_shape_[kMaxDataNdim]; - FFTTransformType fft_type_; - ScalarType value_type_; - - PlanKey() = default; - - PlanKey(const std::vector& in_shape, - const std::vector& out_shape, - const std::vector& signal_size, FFTTransformType fft_type, - ScalarType value_type) { - // Padding bits must be zeroed for hashing - memset(this, 0, sizeof(*this)); - signal_ndim_ = signal_size.size() - 1; - fft_type_ = fft_type; - value_type_ = value_type; - - std::copy(signal_size.cbegin(), signal_size.cend(), sizes_); - std::copy(in_shape.cbegin(), in_shape.cend(), input_shape_); - std::copy(out_shape.cbegin(), out_shape.cend(), output_shape_); +template +void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out, + FFTNormMode normalization, + const std::vector& sizes, + const std::vector& axes) { + double scale = fft_normalization_scale(normalization, sizes, axes); + if (scale != 1.0) { + auto eigen_out = framework::EigenVector::Flatten(*out); + auto eigen_in = framework::EigenVector::Flatten(*in); + auto dev = ctx.eigen_device(); + EigenScale::Eval(*dev, eigen_out, eigen_in, + static_cast(scale), + static_cast(0), false); + } else { + framework::TensorCopy(*in, ctx.GetPlace(), out); } -}; - -// An RAII encapsulation of cuFFTHandle -class CuFFTHandle { - ::cufftHandle handle_; - - public: - CuFFTHandle() { CUFFT_CHECK(platform::dynload::cufftCreate(&handle_)); } +} - ::cufftHandle& get() { return handle_; } - const ::cufftHandle& get() const { return handle_; } +#if defined(PADDLE_WITH_CUDA) +CuFFTConfig create_cufft_config(const framework::Tensor& input, + const framework::Tensor& output, + int signal_ndim) { + // Create the transform plan (either from cache or locally) + const auto value_type = framework::IsComplexType(input.type()) + ? framework::ToRealType(input.type()) + : input.type(); + auto fft_type = GetFFTTransformType(input.type(), output.type()); + // signal sizes + std::vector signal_size(signal_ndim + 1); - ~CuFFTHandle() { -// Not using fftDestroy() for rocFFT to work around double freeing of handles -#ifndef __HIPCC__ - CUFFT_CHECK(platform::dynload::cufftDestroy(handle_)); -#endif + signal_size[0] = input.dims()[0]; + for (int64_t i = 1; i <= signal_ndim; ++i) { + auto in_size = input.dims()[i]; + auto out_size = output.dims()[i]; + signal_size[i] = std::max(in_size, out_size); } -}; + PlanKey key(framework::vectorize(input.dims()), + framework::vectorize(output.dims()), signal_size, fft_type, + value_type); -#ifdef __HIPCC__ -using plan_size_type = int; -#else -using plan_size_type = long long int; // NOLINT -#endif + return CuFFTConfig(key); +} -// This class contains all the information needed to execute a cuFFT plan: -// 1. the plan -// 2. the workspace size needed -class CuFFTConfig { - public: - // Only move semantics is enought for this class. Although we already use - // unique_ptr for the plan, still remove copy constructor and assignment op so - // we don't accidentally copy and take perf hit. 
- CuFFTConfig(const CuFFTConfig&) = delete; - CuFFTConfig& operator=(CuFFTConfig const&) = delete; - - explicit CuFFTConfig(const PlanKey& plan_key) - : CuFFTConfig( - std::vector(plan_key.sizes_, - plan_key.sizes_ + plan_key.signal_ndim_ + 1), - plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {} - - // sizes are full signal, including batch size and always two-sided - CuFFTConfig(const std::vector& sizes, const int64_t signal_ndim, - FFTTransformType fft_type, ScalarType dtype) - : fft_type_(fft_type), value_type_(dtype) { - // signal sizes (excluding batch dim) - std::vector signal_sizes(sizes.begin() + 1, sizes.end()); - - // input batch size - const auto batch = static_cast(sizes[0]); - // const int64_t signal_ndim = sizes.size() - 1; - PADDLE_ENFORCE_EQ(signal_ndim, sizes.size() - 1, - platform::errors::InvalidArgument( - "The signal_ndim must be equal to sizes.size() - 1," - "But signal_ndim is: [%d], sizes.size() - 1 is: [%d]", - signal_ndim, sizes.size() - 1)); - -#ifdef __HIPCC__ - hipfftType exec_type = [&] { - if (dtype == framework::proto::VarType::FP32) { - switch (fft_type) { - case FFTTransformType::C2C: - return HIPFFT_C2C; - case FFTTransformType::R2C: - return HIPFFT_R2C; - case FFTTransformType::C2R: - return HIPFFT_C2R; - } - } else if (dtype == framework::proto::VarType::FP64) { - switch (fft_type) { - case FFTTransformType::C2C: - return HIPFFT_Z2Z; - case FFTTransformType::R2C: - return HIPFFT_D2Z; - case FFTTransformType::C2R: - return HIPFFT_Z2D; - } - } - PADDLE_THROW(platform::errors::InvalidArgument( - "hipFFT only support transforms of type float32 and float64")); - }(); -#else - cudaDataType itype, otype, exec_type; - const auto complex_input = has_complex_input(fft_type); - const auto complex_output = has_complex_output(fft_type); - if (dtype == framework::proto::VarType::FP32) { - itype = complex_input ? CUDA_C_32F : CUDA_R_32F; - otype = complex_output ? CUDA_C_32F : CUDA_R_32F; - exec_type = CUDA_C_32F; - } else if (dtype == framework::proto::VarType::FP64) { - itype = complex_input ? CUDA_C_64F : CUDA_R_64F; - otype = complex_output ? CUDA_C_64F : CUDA_R_64F; - exec_type = CUDA_C_64F; - } else if (dtype == framework::proto::VarType::FP16) { - itype = complex_input ? CUDA_C_16F : CUDA_R_16F; - otype = complex_output ? CUDA_C_16F : CUDA_R_16F; - exec_type = CUDA_C_16F; - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "cuFFT only support transforms of type float16, float32 and " - "float64")); - } -#endif +// Execute a pre-planned transform +static void exec_cufft_plan_raw(const CuFFTConfig& config, void* in_data, + void* out_data, bool forward) { + auto& plan = config.plan(); - // disable auto allocation of workspace to use allocator from the framework - CUFFT_CHECK(platform::dynload::cufftSetAutoAllocation( - plan(), /* autoAllocate */ 0)); - - size_t ws_size_t; - -// make plan -#ifdef __HIPCC__ - CUFFT_CHECK(hipfftMakePlanMany( - plan(), signal_ndim, signal_sizes.data(), - /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, - /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, exec_type, - batch, &ws_size_t)); -#else - - CUFFT_CHECK(platform::dynload::cufftXtMakePlanMany( - plan(), signal_ndim, signal_sizes.data(), - /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype, - /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype, - batch, &ws_size_t, exec_type)); -#endif + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftXtExec( + plan, in_data, out_data, forward ? 
CUFFT_FORWARD : CUFFT_INVERSE)); +} - ws_size = ws_size_t; +template +void exec_cufft_plan(const DeviceContext& ctx, const CuFFTConfig& config, + framework::Tensor* input, framework::Tensor* output, + bool forward) { + // execute transform plan + auto fft_type = config.transform_type(); + if (fft_type == FFTTransformType::C2R && forward) { + forward = false; + framework::Tensor input_conj(input->type()); + input_conj.mutable_data(input->dims(), ctx.GetPlace()); + platform::ForRange for_range(ctx, input->numel()); + math::ConjFunctor functor(input->data(), input->numel(), + input_conj.data()); + for_range(functor); + exec_cufft_plan_raw(config, input_conj.data(), output->data(), + forward); + } else if (fft_type == FFTTransformType::R2C && !forward) { + forward = true; + framework::Tensor out_conj(output->type()); + out_conj.mutable_data(output->dims(), ctx.GetPlace()); + exec_cufft_plan_raw(config, input->data(), out_conj.data(), + forward); + + platform::ForRange for_range(ctx, output->numel()); + math::ConjFunctor functor(out_conj.data(), output->numel(), + output->data()); + for_range(functor); + } else { + exec_cufft_plan_raw(config, input->data(), output->data(), + forward); } +} - const cufftHandle& plan() const { return plan_ptr.get(); } +#elif defined(PADDLE_WITH_HIP) - FFTTransformType transform_type() const { return fft_type_; } - ScalarType data_type() const { return value_type_; } - size_t workspace_size() const { return ws_size; } +HIPFFTConfig create_hipfft_config(const framework::Tensor& input, + const framework::Tensor& output, + int signal_ndim) { + // Create the transform plan (either from cache or locally) + const auto value_type = framework::IsComplexType(input.type()) + ? framework::ToRealType(input.type()) + : input.type(); + auto fft_type = GetFFTTransformType(input.type(), output.type()); + // signal sizes + std::vector signal_size(signal_ndim + 1); - private: - CuFFTHandle plan_ptr; - size_t ws_size; - FFTTransformType fft_type_; - ScalarType value_type_; -}; + signal_size[0] = input.dims()[0]; + for (int64_t i = 1; i <= signal_ndim; ++i) { + auto in_size = input.dims()[i]; + auto out_size = output.dims()[i]; + signal_size[i] = std::max(in_size, out_size); + } + PlanKey key(framework::vectorize(input.dims()), + framework::vectorize(output.dims()), signal_size, fft_type, + value_type); + + return HIPFFTConfig(key); +} // Execute a pre-planned transform -static void exec_cufft_plan(const CuFFTConfig& config, void* in_data, - void* out_data, bool forward) { +static void exec_hipfft_plan_raw(const HIPFFTConfig& config, void* in_data, + void* out_data, bool forward) { auto& plan = config.plan(); -#ifdef __HIPCC__ + auto value_type = config.data_type(); if (value_type == framework::proto::VarType::FP32) { switch (config.transform_type()) { case FFTTransformType::C2C: { - CUFFT_CHECK(hipfftExecC2C(plan, static_cast(in_data), - static_cast(out_data), - forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecC2C( + plan, static_cast(in_data), + static_cast(out_data), + forward ? 
HIPFFT_FORWARD : HIPFFT_BACKWARD)); return; } case FFTTransformType::R2C: { - CUFFT_CHECK(hipfftExecR2C(plan, static_cast(in_data), - static_cast(out_data))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecR2C( + plan, static_cast(in_data), + static_cast(out_data))); return; } case FFTTransformType::C2R: { - CUFFT_CHECK(hipfftExecC2R(plan, static_cast(in_data), - static_cast(out_data))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecC2R( + plan, static_cast(in_data), + static_cast(out_data))); return; } } } else if (value_type == framework::proto::VarType::FP64) { switch (config.transform_type()) { case FFTTransformType::C2C: { - CUFFT_CHECK(hipfftExecZ2Z(plan, - static_cast(in_data), - static_cast(out_data), - forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecZ2Z( + plan, static_cast(in_data), + static_cast(out_data), + forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); return; } case FFTTransformType::R2C: { - CUFFT_CHECK(hipfftExecD2Z(plan, static_cast(in_data), - static_cast(out_data))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecD2Z( + plan, static_cast(in_data), + static_cast(out_data))); return; } case FFTTransformType::C2R: { - CUFFT_CHECK(hipfftExecZ2D(plan, - static_cast(in_data), - static_cast(out_data))); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftExecZ2D( + plan, static_cast(in_data), + static_cast(out_data))); return; } } } PADDLE_THROW(platform::errors::InvalidArgument( "hipFFT only support transforms of type float32 and float64")); -#else - CUFFT_CHECK(platform::dynload::cufftXtExec( - plan, in_data, out_data, forward ? CUFFT_FORWARD : CUFFT_INVERSE)); -#endif } +template +void exec_hipfft_plan(const DeviceContext& ctx, const HIPFFTConfig& config, + framework::Tensor* input, framework::Tensor* output, + bool forward) { + auto fft_type = config.transform_type(); + if (fft_type == FFTTransformType::C2R && forward) { + forward = false; + framework::Tensor input_conj(input->type()); + input_conj.mutable_data(input->dims(), ctx.GetPlace()); + platform::ForRange for_range(ctx, input->numel()); + math::ConjFunctor functor(input->data(), input->numel(), + input_conj.data()); + for_range(functor); + exec_hipfft_plan_raw(config, input_conj.data(), output->data(), + forward); + } else if (fft_type == FFTTransformType::R2C && !forward) { + forward = true; + framework::Tensor out_conj(output->type()); + out_conj.mutable_data(output->dims(), ctx.GetPlace()); + exec_hipfft_plan_raw(config, input->data(), out_conj.data(), + forward); + + platform::ForRange for_range(ctx, output->numel()); + math::ConjFunctor functor(out_conj.data(), output->numel(), + output->data()); + for_range(functor); + } else { + exec_hipfft_plan_raw(config, input->data(), output->data(), + forward); + } +} + +#endif + // Execute a general unnormalized fft operation (can be c2c, onesided r2c or // onesided c2r) template void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, const std::vector& dim, bool forward) { const auto x_dims = framework::vectorize(X->dims()); - const auto out_dims = framework::vectorize(out->dims()); const int64_t ndim = static_cast(X->dims().size()); - const int64_t signal_ndim = static_cast(dim.size()); - const int64_t batch_dims = ndim - signal_ndim; auto tensor_place = ctx.GetPlace(); - // Transpose batch dimensions first, then with transforming dims + // make a dim permutation std::vector dim_permute(ndim); - std::vector reverse_dim_permute(ndim); - std::vector 
trans_dims(ndim); std::iota(dim_permute.begin(), dim_permute.end(), int{0}); std::vector is_transformed_dim(ndim); for (const auto& d : dim) { @@ -340,160 +271,89 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, std::sort(dim_permute.begin(), batch_end); std::copy(dim.cbegin(), dim.cend(), batch_end); - for (size_t i = 0; i < ndim; i++) { - trans_dims[i] = x_dims[dim_permute[i]]; // shape of input transpose - reverse_dim_permute[dim_permute[i]] = - static_cast(i); // reverse of dim permute - } - framework::Tensor input; - input.Resize(framework::make_ddim(trans_dims)); - input.mutable_data(tensor_place); - /* - auto in_ret = TransposeSimple::run(ctx, *X, dim_permute, input); - if (!in_ret) { - TransCompute(ndim, ctx, *X, input, dim_permute); - } - */ - TransCompute(ndim, ctx, *X, &input, dim_permute); + // transpose input according to dim permutation + auto transposed_input_shape = X->dims().transpose(dim_permute); + framework::Tensor transposed_input; + transposed_input.Resize(transposed_input_shape); + transposed_input.mutable_data(tensor_place); + TransCompute(ndim, ctx, *X, &transposed_input, + dim_permute); // Reshape batch dimensions into a single dimension - std::vector batched_sizes(signal_ndim + 1); + const int64_t signal_ndim = static_cast(dim.size()); + std::vector collapsed_input_shape(signal_ndim + 1); + + auto transposed_input_shape_ = framework::vectorize(transposed_input_shape); + const int64_t batch_dims = ndim - signal_ndim; auto batch_size = - std::accumulate(trans_dims.begin(), trans_dims.begin() + batch_dims, + std::accumulate(transposed_input_shape_.begin(), + transposed_input_shape_.begin() + batch_dims, static_cast(1), std::multiplies()); - batched_sizes[0] = batch_size; - std::copy(trans_dims.begin() + batch_dims, trans_dims.end(), - batched_sizes.begin() + 1); - input.Resize(framework::make_ddim(batched_sizes)); + collapsed_input_shape[0] = batch_size; - // Check the shape of transforming dims with input and output - std::vector signal_size(signal_ndim + 1); - signal_size[0] = batch_size; - for (int64_t i = 0; i < signal_ndim; ++i) { - auto in_size = input.dims()[i + 1]; - auto out_size = out_dims[dim[i]]; - signal_size[i + 1] = std::max(in_size, out_size); - PADDLE_ENFORCE_EQ( - (in_size == signal_size[i + 1] || - in_size == (signal_size[i + 1] / 2) + 1), - true, - platform::errors::InvalidArgument( - "The dimension[%d] of Input size: [%d] must be equal or half to " - "The dimension[%d] of Output size: [%d]", - dim[i], in_size, dim[i], out_size)); - PADDLE_ENFORCE_EQ( - (out_size == signal_size[i + 1] || - out_size == (signal_size[i + 1] / 2) + 1), - true, - platform::errors::InvalidArgument( - "The dimension[%d] of Output size: [%d] must be equal or half to " - "The dimension[%d] of Input size: [%d]", - dim[i], out_size, dim[i], in_size)); - } + std::copy(transposed_input_shape_.begin() + batch_dims, + transposed_input_shape_.end(), collapsed_input_shape.begin() + 1); - std::vector reshape_out_sizes(ndim); - for (size_t i = 0; i < ndim; ++i) { - reshape_out_sizes[i] = out_dims[dim_permute[i]]; - } - std::vector batched_out_sizes(batched_sizes.begin(), - batched_sizes.end()); + framework::Tensor& collapsed_input = transposed_input; + collapsed_input.Resize(framework::make_ddim(collapsed_input_shape)); + + // make a collpased output + const auto out_dims = framework::vectorize(out->dims()); + std::vector collapsed_output_shape(1 + signal_ndim); + collapsed_output_shape[0] = batch_size; for (size_t i = 0; i < dim.size(); ++i) { - 
batched_out_sizes[i + 1] = out_dims[dim[i]]; + collapsed_output_shape[i + 1] = out_dims[dim[i]]; } - - // output - framework::Tensor output; - output.Resize(framework::make_ddim(batched_out_sizes)); - output.mutable_data(tensor_place); - - // Create the transform plan (either from cache or locally) - const auto value_type = framework::IsComplexType(input.type()) - ? framework::ToRealType(input.type()) - : input.type(); - auto fft_type = GetFFTTransformType(input.type(), output.type()); - - PlanKey Key(framework::vectorize(input.dims()), - framework::vectorize(output.dims()), signal_size, fft_type, - value_type); - CuFFTConfig uncached_plan(Key); - CuFFTConfig* config = &uncached_plan; - auto& plan = config->plan(); - + framework::Tensor collapsed_output; + collapsed_output.Resize(framework::make_ddim(collapsed_output_shape)); + collapsed_output.mutable_data(tensor_place); + +#if defined(PADDLE_WITH_CUDA) + // create plan + CuFFTConfig config = + create_cufft_config(collapsed_input, collapsed_output, signal_ndim); // prepare cufft for execution - CUFFT_CHECK(platform::dynload::cufftSetStream(plan, ctx.stream())); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cufftSetStream(config.plan(), ctx.stream())); framework::Tensor workspace_tensor; - workspace_tensor.mutable_data(tensor_place, config->workspace_size()); - CUFFT_CHECK( - platform::dynload::cufftSetWorkArea(plan, workspace_tensor.data())); + workspace_tensor.mutable_data(tensor_place, config.workspace_size()); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftSetWorkArea( + config.plan(), workspace_tensor.data())); + // execute transform plan + exec_cufft_plan(ctx, config, &collapsed_input, + &collapsed_output, forward); +#elif defined(PADDLE_WITH_HIP) + // create plan + HIPFFTConfig config = + create_hipfft_config(collapsed_input, collapsed_output, signal_ndim); + // prepare cufft for execution + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::hipfftSetStream(config.plan(), ctx.stream())); + framework::Tensor workspace_tensor; + workspace_tensor.mutable_data(tensor_place, config.workspace_size()); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftSetWorkArea( + config.plan(), workspace_tensor.data())); // execute transform plan - if (fft_type == FFTTransformType::C2R && forward) { - forward = false; - framework::Tensor input_conj(input.type()); - input_conj.mutable_data(input.dims(), ctx.GetPlace()); - platform::ForRange for_range(ctx, input.numel()); - math::ConjFunctor functor(input.data(), input.numel(), - input_conj.data()); - for_range(functor); - exec_cufft_plan(*config, input_conj.data(), output.data(), - forward); - } else if (fft_type == FFTTransformType::R2C && !forward) { - forward = true; - framework::Tensor out_conj(output.type()); - out_conj.mutable_data(output.dims(), ctx.GetPlace()); - exec_cufft_plan(*config, input.data(), out_conj.data(), - forward); - - platform::ForRange for_range(ctx, output.numel()); - math::ConjFunctor functor(out_conj.data(), output.numel(), - output.data()); - for_range(functor); - } else { - exec_cufft_plan(*config, input.data(), output.data(), forward); - } + exec_hipfft_plan(ctx, config, &collapsed_input, + &collapsed_output, forward); +#endif // Inverting output by reshape and transpose to original batch and dimension - output.Resize(framework::make_ddim(reshape_out_sizes)); - out->Resize(framework::make_ddim(out_dims)); - TransCompute(ndim, ctx, output, out, reverse_dim_permute); -} + auto transposed_out_shape = out->dims().transpose(dim_permute); -// Calculates the 
normalization constant -double fft_normalization_scale(FFTNormMode normalization, - const std::vector& sizes, - const std::vector& dims) { - // auto norm = static_cast(normalization); - if (normalization == FFTNormMode::none) { - return static_cast(1.0); - } + collapsed_output.Resize(transposed_out_shape); + auto& transposed_output = collapsed_output; - int64_t signal_numel = 1; - for (auto dim : dims) { - signal_numel *= sizes[dim]; + std::vector reverse_dim_permute(ndim); + for (size_t i = 0; i < ndim; i++) { + reverse_dim_permute[dim_permute[i]] = i; } - const double scale_denom = (normalization == FFTNormMode::by_sqrt_n) - ? std::sqrt(signal_numel) - : static_cast(signal_numel); - return static_cast(1.0 / scale_denom); -} -template -void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out, - FFTNormMode normalization, - const std::vector& sizes, - const std::vector& axes) { - double scale = fft_normalization_scale(normalization, sizes, axes); - if (scale != 1.0) { - auto eigen_out = framework::EigenVector::Flatten(*out); - auto eigen_in = framework::EigenVector::Flatten(*in); - auto dev = ctx.eigen_device(); - EigenScale::Eval(*dev, eigen_out, eigen_in, - static_cast(scale), - static_cast(0), false); - } else { - framework::TensorCopy(*in, ctx.GetPlace(), out); - } + TransCompute(ndim, ctx, transposed_output, out, + reverse_dim_permute); } + } // anonymous namespace // Use the optimized path to perform single R2C or C2R if transformation dim is diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 8c64aad46cfc8..6e90ccfc51e1b 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -7,7 +7,7 @@ if (NOT WITH_NV_JETSON) endif() if (WITH_ROCM) - list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc) + list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc hipfft.cc) endif() # There is no macOS version of NCCL. diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 0c5c47e38f85e..1bfd48b133907 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -356,6 +356,16 @@ void* GetCurandDsoHandle() { #endif } +#ifdef PADDLE_WITH_HIP +void* GetROCFFTDsoHandle() { +#if defined(__APPLE__) || defined(__OSX__) + return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocfft.dylib"); +#else + return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocfft.so"); +#endif +} +#endif + void* GetNvjpegDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvjpeg.dylib"); diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h index 6260efdf71c59..1a66f4b979207 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.h +++ b/paddle/fluid/platform/dynload/dynamic_loader.h @@ -44,6 +44,7 @@ void* GetOpDsoHandle(const std::string& dso_name); void* GetNvtxDsoHandle(); void* GetCUFFTDsoHandle(); void* GetMKLRTDsoHandle(); +void* GetROCFFTDsoHandle(); void SetPaddleLibPath(const std::string&); } // namespace dynload diff --git a/paddle/fluid/platform/dynload/hipfft.cc b/paddle/fluid/platform/dynload/hipfft.cc new file mode 100644 index 0000000000000..767d2161be9d8 --- /dev/null +++ b/paddle/fluid/platform/dynload/hipfft.cc @@ -0,0 +1,30 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/dynload/hipfft.h" + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag hipfft_dso_flag; +void *hipfft_dso_handle; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +HIPFFT_FFT_ROUTINE_EACH(DEFINE_WRAP); + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/hipfft.h b/paddle/fluid/platform/dynload/hipfft.h new file mode 100644 index 0000000000000..50c25935e41b7 --- /dev/null +++ b/paddle/fluid/platform/dynload/hipfft.h @@ -0,0 +1,124 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once +#ifdef PADDLE_WITH_HIP +#include + +#include // NOLINT + +#include "paddle/fluid/platform/dynload/dynamic_loader.h" +#include "paddle/fluid/platform/port.h" + +namespace paddle { +namespace platform { +namespace dynload { +extern std::once_flag hipfft_dso_flag; +extern void *hipfft_dso_handle; + +#define DECLARE_DYNAMIC_LOAD_HIPFFT_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) 
{ \ + using hipfftFunc = decltype(&::__name); \ + std::call_once(hipfft_dso_flag, []() { \ + hipfft_dso_handle = paddle::platform::dynload::GetROCFFTDsoHandle(); \ + }); \ + static void *p_##__name = dlsym(hipfft_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name + +#define HIPFFT_FFT_ROUTINE_EACH(__macro) \ + __macro(hipfftPlan1d); \ + __macro(hipfftPlan2d); \ + __macro(hipfftPlan3d); \ + __macro(hipfftPlanMany); \ + __macro(hipfftMakePlan1d); \ + __macro(hipfftMakePlanMany); \ + __macro(hipfftMakePlanMany64); \ + __macro(hipfftGetSizeMany64); \ + __macro(hipfftEstimate1d); \ + __macro(hipfftEstimate2d); \ + __macro(hipfftEstimate3d); \ + __macro(hipfftEstimateMany); \ + __macro(hipfftCreate); \ + __macro(hipfftGetSize1d); \ + __macro(hipfftGetSizeMany); \ + __macro(hipfftGetSize); \ + __macro(hipfftSetWorkArea); \ + __macro(hipfftSetAutoAllocation); \ + __macro(hipfftExecC2C); \ + __macro(hipfftExecR2C); \ + __macro(hipfftExecC2R); \ + __macro(hipfftExecZ2Z); \ + __macro(hipfftExecD2Z); \ + __macro(hipfftExecZ2D); \ + __macro(hipfftSetStream); \ + __macro(hipfftDestroy); \ + __macro(hipfftGetVersion); \ + __macro(hipfftGetProperty); + +HIPFFT_FFT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_HIPFFT_WRAP); + +inline const char *hipfftGetErrorString(hipfftResult_t status) { + switch (status) { + case HIPFFT_SUCCESS: + return "'HIPFFT_SUCCESS'. The hipFFT operation was successful."; + case HIPFFT_INVALID_PLAN: + return "'HIPFFT_INVALID_PLAN'. hipFFT was passed an invalid plan handle."; + case HIPFFT_ALLOC_FAILED: + return "'HIPFFT_ALLOC_FAILED'. hipFFT failed to allocate GPU or CPU " + "memory."; + case HIPFFT_INVALID_TYPE: + return "'HIPFFT_INVALID_TYPE'. No longer used."; + case HIPFFT_INVALID_VALUE: + return "'HIPFFT_INVALID_VALUE'. User specified an invalid pointer or " + "parameter."; + case HIPFFT_INTERNAL_ERROR: + return "'HIPFFT_INTERNAL_ERROR'. Driver or internal hipFFT library " + "error."; + case HIPFFT_EXEC_FAILED: + return "'HIPFFT_EXEC_FAILED'. Failed to execute an FFT on the GPU."; + case HIPFFT_SETUP_FAILED: + return "'HIPFFT_SETUP_FAILED'. The hipFFT library failed to initialize."; + case HIPFFT_INVALID_SIZE: + return "'HIPFFT_INVALID_SIZE'. User specified an invalid transform size."; + case HIPFFT_UNALIGNED_DATA: + return "'HIPFFT_UNALIGNED_DATA'. No longer used."; + case HIPFFT_INCOMPLETE_PARAMETER_LIST: + return "'HIPFFT_INCOMPLETE_PARAMETER_LIST'. Missing parameters in call."; + case HIPFFT_INVALID_DEVICE: + return "'HIPFFT_INVALID_DEVICE'. Execution of a plan was on different " + "GPU than plan creation."; + case HIPFFT_PARSE_ERROR: + return "'HIPFFT_PARSE_ERROR'. Internal plan database error."; + case HIPFFT_NO_WORKSPACE: + return "'HIPFFT_NO_WORKSPACE'. No workspace has been provided prior to " + "plan execution."; + case HIPFFT_NOT_IMPLEMENTED: + return "'HIPFFT_NOT_IMPLEMENTED'. Function does not implement " + "functionality for parameters given."; + case HIPFFT_NOT_SUPPORTED: + return "'HIPFFT_NOT_SUPPORTED'. Operation is not supported for " + "parameters given."; + default: + return "HIPFFT_STATUS_UNKNOWN_ERROR"; + } +} +} // namespace dynload +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 7427060add8b1..caa495bb7f8c5 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -86,6 +86,7 @@ limitations under the License. 
*/ #endif // PADDLE_WITH_CUDA #ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/dynload/hipfft.h" #include "paddle/fluid/platform/dynload/hiprand.h" #include "paddle/fluid/platform/dynload/miopen.h" #include "paddle/fluid/platform/dynload/rocblas.h" @@ -1113,6 +1114,14 @@ inline std::string build_rocm_error_msg(ncclResult_t nccl_result) { } #endif // not(__APPLE__) and PADDLE_WITH_NCCL +/***** HIPFFT ERROR *****/ +inline bool is_error(hipfftResult_t stat) { return stat != HIPFFT_SUCCESS; } + +inline std::string build_rocm_error_msg(hipfftResult_t stat) { + std::string msg(" HIPFFT error, "); + return msg + platform::dynload::hipfftGetErrorString(stat) + " "; +} + namespace details { template @@ -1129,6 +1138,7 @@ DEFINE_EXTERNAL_API_TYPE(hipError_t, hipSuccess); DEFINE_EXTERNAL_API_TYPE(hiprandStatus_t, HIPRAND_STATUS_SUCCESS); DEFINE_EXTERNAL_API_TYPE(miopenStatus_t, miopenStatusSuccess); DEFINE_EXTERNAL_API_TYPE(rocblas_status, rocblas_status_success); +DEFINE_EXTERNAL_API_TYPE(hipfftResult_t, HIPFFT_SUCCESS); #if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL) DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess); diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index c6d5f171ddce4..6ff9e6ea903cd 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -331,6 +331,10 @@ TEST(enforce, hip_success) { CheckCudaStatusFailure(rocblas_status_invalid_handle, "Rocblas error")); EXPECT_TRUE( CheckCudaStatusFailure(rocblas_status_invalid_value, "Rocblas error")); + EXPECT_TRUE(CheckCudaStatusSuccess(HIPFFT_SUCCESS)); + EXPECT_TRUE(CheckCudaStatusFailure(HIPFFT_INVALID_PLAN, "HIPFFT error")); + EXPECT_TRUE(CheckCudaStatusFailure(HIPFFT_ALLOC_FAILED, "HIPFFT error")); + #if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL) EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess)); EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError, "Rccl error")); From 5244c94a991ee49869091b873e2adb695021c8b7 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Tue, 19 Oct 2021 16:33:21 +0800 Subject: [PATCH 07/16] move signal apis --- python/paddle/__init__.py | 1 + python/paddle/signal.py | 20 ++++++++++++++++++++ python/paddle/tensor/signal.py | 11 +++-------- 3 files changed, 24 insertions(+), 8 deletions(-) create mode 100644 python/paddle/signal.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index decffa66f4174..3e30c94f8a265 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -294,6 +294,7 @@ from . import hub # noqa: F401 from . import linalg # noqa: F401 from . import fft # noqa: F401 +from . import signal # noqa: F401 import paddle.text # noqa: F401 import paddle.vision # noqa: F401 diff --git a/python/paddle/signal.py b/python/paddle/signal.py new file mode 100644 index 0000000000000..c97f7fd5fae25 --- /dev/null +++ b/python/paddle/signal.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .tensor.signal import stft # noqa: F401 +from .tensor.signal import istft # noqa: F401 + +__all__ = [ # noqa + 'stft', 'istft' +] diff --git a/python/paddle/tensor/signal.py b/python/paddle/tensor/signal.py index 86022a1748356..d5cb703dd09d7 100644 --- a/python/paddle/tensor/signal.py +++ b/python/paddle/tensor/signal.py @@ -23,12 +23,7 @@ from ..fluid.layer_helper import LayerHelper from .. import _C_ops -__all__ = [ - 'frame', - 'overlap_add', - 'stft', - 'istft', -] +__all__ = [] def frame(x, frame_length, hop_length, axis=-1, name=None): @@ -295,7 +290,7 @@ def stft(x, .. code-block:: python import paddle - from paddle.tensor.signal import stft + from paddle.signal import stft # real-valued input x = paddle.randn([8, 48000], dtype=paddle.float64) @@ -459,7 +454,7 @@ def istft(x, import numpy as np import paddle - from paddle.tensor.signal import stft, istft + from paddle.signal import stft, istft paddle.seed(0) From 3f605d8422eba26b47a89f72bfefa3aefe4e7cac Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Wed, 20 Oct 2021 13:01:44 +0800 Subject: [PATCH 08/16] move fft and signal API path (#2) * move signal apis * move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py --- python/paddle/fft.py | 1633 ++++++++++++++++- .../fluid/tests/unittests/test_signal.py | 20 +- python/paddle/signal.py | 568 +++++- python/paddle/tensor/__init__.py | 2 - python/paddle/tensor/fft.py | 1601 ---------------- python/paddle/tensor/signal.py | 571 ------ 6 files changed, 2169 insertions(+), 2226 deletions(-) delete mode 100644 python/paddle/tensor/fft.py delete mode 100644 python/paddle/tensor/signal.py diff --git a/python/paddle/fft.py b/python/paddle/fft.py index 3ac02c9c8dc18..de15eba0feffa 100644 --- a/python/paddle/fft.py +++ b/python/paddle/fft.py @@ -12,50 +12,1613 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .tensor.fft import fft # noqa: F401 -from .tensor.fft import fft2 # noqa: F401 -from .tensor.fft import fftn # noqa: F401 -from .tensor.fft import ifft # noqa: F401 -from .tensor.fft import ifft2 # noqa: F401 -from .tensor.fft import ifftn # noqa: F401 -from .tensor.fft import rfft # noqa: F401 -from .tensor.fft import rfft2 # noqa: F401 -from .tensor.fft import rfftn # noqa: F401 -from .tensor.fft import irfft # noqa: F401 -from .tensor.fft import irfft2 # noqa: F401 -from .tensor.fft import irfftn # noqa: F401 -from .tensor.fft import hfft # noqa: F401 -from .tensor.fft import hfft2 # noqa: F401 -from .tensor.fft import hfftn # noqa: F401 -from .tensor.fft import ihfft # noqa: F401 -from .tensor.fft import ihfft2 # noqa: F401 -from .tensor.fft import ihfftn # noqa: F401 -from .tensor.fft import fftfreq # noqa: F401 -from .tensor.fft import rfftfreq # noqa: F401 -from .tensor.fft import fftshift # noqa: F401 -from .tensor.fft import ifftshift # noqa: F401 - -__all__ = [ # noqa +from typing import Sequence +import numpy as np +import paddle +from .tensor.attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype +from .fluid.framework import in_dygraph_mode +from . 
import _C_ops +from .fluid.data_feeder import check_variable_and_dtype +from .fluid.layer_helper import LayerHelper + +__all__ = [ 'fft', - 'fft2', - 'fftn', 'ifft', - 'ifft2', - 'ifftn', 'rfft', - 'rfft2', - 'rfftn', 'irfft', - 'irfft2', - 'irfftn', 'hfft', - 'hfft2', - 'hfftn', 'ihfft', + 'fft2', + 'ifft2', + 'rfft2', + 'irfft2', + 'hfft2', 'ihfft2', + 'fftn', + 'ifftn', + 'rfftn', + 'irfftn', + 'hfftn', 'ihfftn', 'fftfreq', 'rfftfreq', 'fftshift', - 'ifftshift' + 'ifftshift', ] + + +def _check_normalization(norm): + if norm not in ['forward', 'backward', 'ortho']: + raise ValueError( + "Unexpected norm: {}. Norm should be forward, backward or ortho". + format(norm)) + + +def _check_fft_n(n): + if not isinstance(n, int): + raise ValueError( + "Invalid FFT argument n({}), it shoule be an integer.".format(n)) + if n <= 0: + raise ValueError( + "Invalid FFT argument n({}), it should be positive.".format(n)) + + +def _check_fft_shape(x, s): + ndim = x.ndim + if not isinstance(s, Sequence): + raise ValueError( + "Invaid FFT argument s({}), it should be a sequence of integers.") + + if len(s) > ndim: + raise ValueError( + "Length of FFT argument s should not be larger than the rank of input. " + "Received s: {}, rank of x: {}".format(s, ndim)) + for size in s: + if not isinstance(size, int) or size <= 0: + raise ValueError("FFT sizes {} contains invalid value ({})".format( + s, size)) + + +def _check_fft_axis(x, axis): + ndim = x.ndim + if not isinstance(axis, int): + raise ValueError( + "Invalid FFT axis ({}), it shoule be an integer.".format(axis)) + if axis < -ndim or axis >= ndim: + raise ValueError( + "Invalid FFT axis ({}), it should be in range [-{}, {})".format( + axis, ndim, ndim)) + + +def _check_fft_axes(x, axes): + ndim = x.ndim + if not isinstance(axes, Sequence): + raise ValueError( + "Invalid FFT axes ({}), it should be a sequence of integers.". + format(axes)) + if len(axes) > ndim: + raise ValueError( + "Length of fft axes should not be larger than the rank of input. " + "Received, len of axes: {}, rank of x: {}".format(len(axes), ndim)) + for axis in axes: + if not isinstance(axis, int) or axis < -ndim or axis >= ndim: + raise ValueError( + "FFT axes {} contains invalid value ({}), it should be in range [-{}, {})". + format(axes, axis, ndim, ndim)) + + +def _resize_fft_input(x, s, axes): + if len(s) != len(axes): + raise ValueError("length of `s` should equals length of `axes`.") + shape = x.shape + ndim = x.ndim + + axes_to_pad = [] + paddings = [] + axes_to_slice = [] + slices = [] + for i, axis in enumerate(axes): + if shape[axis] < s[i]: + axes_to_pad.append(axis) + paddings.append(s[i] - shape[axis]) + elif shape[axis] > s[i]: + axes_to_slice.append(axis) + slices.append((0, s[i])) + + if axes_to_slice: + x = paddle.slice( + x, + axes_to_slice, + starts=[item[0] for item in slices], + ends=[item[1] for item in slices]) + if axes_to_pad: + padding_widths = [0] * (2 * ndim) + for axis, pad in zip(axes_to_pad, paddings): + padding_widths[2 * axis + 1] = pad + x = paddle.nn.functional.pad(x, padding_widths) + return x + + +def _normalize_axes(x, axes): + ndim = x.ndim + return [item if item >= 0 else (item + ndim) for item in axes] + + +def _check_at_least_ndim(x, rank): + if x.ndim < rank: + raise ValueError("The rank of the input ({}) should >= {}".format( + x.ndim, rank)) + + +# public APIs 1d +def fft(x, n=None, axis=-1, norm="backward", name=None): + """ + Calculate one-dimensional discrete Fourier transform. 
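The ``_resize_fft_input`` helper above crops or zero-pads the signal along the transform axes before the transform kernels run. A minimal, shape-level sketch of that behaviour through the public ``n`` argument (assuming the usual crop/zero-pad semantics):

.. code-block:: python

    import paddle

    x = paddle.rand([6])
    # n smaller than the signal length crops the input before the transform
    assert paddle.fft.fft(x, n=4).shape == [4]
    # n larger than the signal length zero-pads the input
    assert paddle.fft.fft(x, n=8).shape == [8]
    # without n, the full length of the transform axis is used
    assert paddle.fft.fft(x).shape == [6]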
+ + This function uses the efficient fast Fourier transform (FFT) algorithm [1] to + calculate the 1-D * n * point discrete Fourier transform (DFT). + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + n (int, optional): The length of the output transform axis. If `n` is less than + the length input, the input will be cropped. If larger, the input is filled + with zeros. If `n` is not given, the input length along the axis specified + by `axis` is used. + axis (int, optional): Axis used to calculate FFT. If not specified, the last axis + is used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on + the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies + the ``1/n`` factor on the forward tranform. For ``norm="ortho"``, both directions are + scaled by ``1/sqrt(n)``. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + complex tensor. The truncated or zero-padded input, transformed along the axis indicated + by `axis`, or the last one if `axis` is not specified. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.exp(3j * np.pi * np.arange(7) / 7) + xp = paddle.to_tensor(x) + fft_xp = paddle.fft.fft(xp).numpy() + print(fft_xp) + # [1.+1.25396034e+00j 1.+4.38128627e+00j 1.-4.38128627e+00j + # 1.-1.25396034e+00j 1.-4.81574619e-01j 1.+8.88178420e-16j + # 1.+4.81574619e-01j] + + + """ + if is_interger(x) or is_floating_point(x): + return fft_r2c( + x, n, axis, norm, forward=True, onesided=False, name=name) + else: + return fft_c2c(x, n, axis, norm, forward=True, name=name) + + +def ifft(x, n=None, axis=-1, norm="backward", name=None): + """ + Compute the 1-D inverse discrete Fourier Transform. + + This function computes the inverse of the 1-D *n*-point discrete Fourier transform + computed by `fft`. In other words, ``ifft(fft(x)) == x`` to within numerical accuracy. + + The input should be ordered in the same way as is returned by `fft`, + i.e., + + * ``x[0]`` should contain the zero frequency term, + * ``x[1:n//2]`` should contain the positive-frequency terms, + * ``x[n//2 + 1:]`` should contain the negative-frequency terms, in + increasing order starting from the most negative frequency. + + For an even number of input points, ``x[n//2]`` represents the sum of + the values at the positive and negative Nyquist frequencies, as the two + are aliased together. + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + n (int, optional): The length of the output transform axis. If `n` is less than + the length input, the input will be cropped. If larger, the input is filled + with zeros. If `n` is not given, the input length along the axis specified + by `axis` is used. + axis (int, optional): Axis used to calculate FFT. If not specified, the last axis + is used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on + the forward transforms and scaling by ``1/n`` on the `ifft`. 
"forward" instead applies + the ``1/n`` factor on the forward tranform. For ``norm="ortho"``, both directions are + scaled by ``1/sqrt(n)``. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + complex tensor. The truncated or zero-padded input, transformed along the axis indicated + by `axis`, or the last one if `axis` is not specified. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.exp(3j * np.pi * np.arange(7) / 7) + xp = paddle.to_tensor(x) + ifft_xp = paddle.fft.ifft(xp).numpy() + print(ifft_xp) + # [0.14285714+1.79137191e-01j 0.14285714+6.87963741e-02j + # 0.14285714+1.26882631e-16j 0.14285714-6.87963741e-02j + # 0.14285714-1.79137191e-01j 0.14285714-6.25898038e-01j + # 0.14285714+6.25898038e-01j] + + """ + if is_interger(x) or is_floating_point(x): + return fft_r2c( + x, n, axis, norm, forward=False, onesided=False, name=name) + else: + return fft_c2c(x, n, axis, norm, forward=False, name=name) + + +def rfft(x, n=None, axis=-1, norm="backward", name=None): + """ + The one dimensional FFT for real input. + + This function computes the one dimensional *n*-point discrete Fourier + Transform (DFT) of a real-valued tensor by means of an efficient algorithm + called the Fast Fourier Transform (FFT). + + When the DFT is computed for purely real input, the output is + Hermitian-symmetric. This function does not compute the negative frequency + terms, and the length of the transformed axis of the output is therefore + ``n//2 + 1``. + + Args: + x(Tensor) : Real-valued input tensor + n(int, optional): Number of points along transformation axis in the + input to use. If `n` is smaller than the length of the input, the + input is cropped. If it is larger, the input is padded with zeros. + If `n` is not given, the length of the input along the axis + specified by `axis` is used. + axis(int, optional): Axis over which to compute the FFT. Default value + is last axis. + norm(str, optional) : Normalization mode, indicates which direction of + the forward/backward pair of transforms is scaled and with what + normalization factor. Include {"backward", "ortho", "forward"}, + default value is "backward". + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor) : complex tensor + + Raises: + + + Examples: + .. code-block:: python + import paddle + + x = paddle.to_tensor([0.0, 1.0, 0.0, 0.0]) + print(paddle.fft.rfft(x)) + # Tensor(shape=[3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [ (1+0j), -1j , (-1+0j)]) + """ + return fft_r2c(x, n, axis, norm, forward=True, onesided=True, name=name) + + +def irfft(x, n=None, axis=-1, norm="backward", name=None): + """ + Computes the inverse of `rfft`. + + This function calculates the inverse of the one-dimensional *n* point discrete + Fourier transform of the actual input calculated by "rfft". In other words, + ``irfft(rfft(a),len(a)) == a`` is within the numerical accuracy range. + + The input shall be in the form of "rfft", i.e. the actual zero frequency term, + followed by the complex positive frequency term, in the order of increasing frequency. 
+ Because the discrete Fourier transform of the actual input is Hermite symmetric, + the negative frequency term is regarded as the complex conjugate term of the corresponding + positive frequency term. + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + n (int, optional): The length of the output transform axis. For `n` output + points, ``n//2 + 1``input points are necessary. If the length of the input tensor is greater + than `n`, it will be cropped, if it is shorter than this, fill in zero. If `n` is not given, + it is considered to be ``2 * (k-1)``, where ``k`` is the length of the input axis specified + along the ` axis'. + axis (int, optional): Axis used to calculate FFT. If not specified, the last axis + is used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. Truncated or zero fill input for the transformation along the axis indicated by + `axis`, or the last input if `axis` is not specified. The length of the conversion axis + is `n`, or ``2 * k-2``, if `k` is None, where `k` is the length of the input conversion axis. + If the output is an odd number, you need to specify the value of 'n', such as ``2 * k-1`` + in some cases. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([1, -1j, -1]) + xp = paddle.to_tensor(x) + irfft_xp = paddle.fft.irfft(xp).numpy() + print(irfft_xp) + # [0. 1. 0. 0.] + + """ + return fft_c2r(x, n, axis, norm, forward=False, name=name) + + +def hfft(x, n=None, axis=-1, norm="backward", name=None): + """ + Compute the FFT of a signal that has Hermitian symmetry, a real + spectrum. + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + n (int, optional): The length of the output transform axis. For `n` output + points, ``n//2 + 1`` input points are necessary. If the length of the input tensor is greater + than `n`, it will be cropped, if it is shorter than this, fill in zero. If `n` is not given, + it is considered to be ``2 * (k-1)``, where ``k`` is the length of the input axis specified + along the ` axis'. + axis (int,optional): Axis used to calculate FFT. If not specified, the last axis + is used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. Truncated or zero fill input for the transformation along the axis indicated by + `axis`, or the last input if `axis` is not specified. The length of the conversion axis + is `n`, or ``2 * k-2``, if `k` is None, where `k` is the length of the input conversion axis. + If the output is an odd number, you need to specify the value of 'n', such as ``2 * k-1`` in + some cases. + + Examples: + + .. 
code-block:: python + + import numpy as np + import paddle + + x = np.array([1, -1j, -1]) + xp = paddle.to_tensor(x) + hfft_xp = paddle.fft.hfft(xp).numpy() + print(hfft_xp) + # [0. 0. 0. 4.] + """ + + return fft_c2r(x, n, axis, norm, forward=True, name=name) + + +def ihfft(x, n=None, axis=-1, norm="backward", name=None): + """ + The inverse FFT of a signal that has Hermitian symmetry. + + This function computes the one dimensional *n*-point inverse FFT of a signal + that has Hermitian symmetry by means of an efficient algorithm called + the Fast Fourier Transform (FFT). + + When the DFT is computed for purely real input, the output is + Hermitian-symmetric. This function does not compute the negative frequency + terms, and the length of the transformed axis of the output is therefore + ``n//2 + 1``. + + Args: + x(Tensor): Input tensor. + n(int, optional): The number of points along transformation axis in the + input to use. If `n` is smaller than the length of the input, the + input is cropped. If it is larger, the input is padded with zeros. + If `n` is not given, the length of the input along the axis + specified by `axis` is used. + axis(int, optional) : Axis over which to compute the inverse FFT. If not + given, the last axis is used. + norm(str, optional) : Normalization mode, indicates which direction of + the forward/backward pair of transforms is scaled and with what + normalization factor. Include {"backward", "ortho", "forward"}, + default value is "backward". + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor) : complex tensor. + + Examples: + .. code-block:: python + import paddle + + spectrum = paddle.to_tensor([10.0, -5.0, 0.0, -1.0, 0.0, -5.0]) + print(paddle.fft.ifft(spectrum)) + # Tensor(shape=[6], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j), (2.3333334922790527+1.9868215517249155e-08j), (1+1.9868215517249155e-08j)]) + print(paddle.fft.ihfft(spectrum)) + # Tensor(shape = [4], dtype = complex64, place = CUDAPlace(0), stop_gradient = True, + # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j)]) + + """ + return fft_r2c(x, n, axis, norm, forward=False, onesided=True, name=name) + + +# public APIs nd +def fftn(x, s=None, axes=None, norm="backward", name=None): + """ + Compute the N-D discrete Fourier Transform. + + This function calculates the n-D discrete Fourier transform on any number of axes + in the M-D array by fast Fourier transform (FFT). + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + s (sequence of ints, optional): Shape (length of each transformed axis) of the output + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). + This corresponds to ``n`` for ``fft(x, n)``. + Along any axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + if `s` is not given, the shape of the input along the axes specified + by `axes` is used. + axes (sequence of ints, optional): Axes used to calculate FFT. If not given, the last ``len(s)`` + axes are used, or all axes if `s` is also not specified. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. 
The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on + the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies + the ``1/n`` factor on the forward tranform. For ``norm="ortho"``, both directions are + scaled by ``1/sqrt(n)``. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + complex tensor. The truncated or zero-padded input, transformed along the axes indicated by + `axes`, or by a combination of `s` and `x`, as explained in the parameters section above. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.mgrid[:4, :4, :4][1] + xp = paddle.to_tensor(x) + fftn_xp = paddle.fft.fftn(xp, axes=(1, 2)).numpy() + print(fftn_xp) + # [[[24.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]] + # [[24.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]] + # [[24.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]] + # [[24.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] + # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]]] + """ + if is_interger(x) or is_floating_point(x): + return fftn_r2c( + x, s, axes, norm, forward=True, onesided=False, name=name) + else: + return fftn_c2c(x, s, axes, norm, forward=True, name=name) + + +def ifftn(x, s=None, axes=None, norm="backward", name=None): + """ + Compute the N-D inverse discrete Fourier Transform. + + This function computes the inverse of the N-D discrete + Fourier Transform over any number of axes in an M-D array by + means of the Fast Fourier Transform (FFT). In other words, + ``ifftn(fftn(x)) == x`` to within numerical accuracy. + + The input, analogously to `ifft`, should be ordered in the same way as is + returned by `fftn`, i.e., it should have the term for zero frequency + in all axes in the low-order corner, the positive frequency terms in the + first half of all axes, the term for the Nyquist frequency in the middle + of all axes and the negative frequency terms in the second half of all + axes, in order of decreasingly negative frequency. + + Args: + x (Tensor): The input data. It's a Tensor type. It's a complex. + s (sequence of ints, optional): Shape (length of each transformed axis) of the output + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). + This corresponds to ``n`` for ``fft(x, n)``. + Along any axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + if `s` is not given, the shape of the input along the axes specified + by `axes` is used. + axes (sequence of ints, optional): Axes used to calculate FFT. If not given, the last ``len(s)`` + axes are used, or all axes if `s` is also not specified. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on + the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies + the ``1/n`` factor on the forward tranform. 
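A quick numerical sketch of the three normalization modes (the same convention applies to every transform in this module; the expected values follow directly from the scale factors described here):

.. code-block:: python

    import numpy as np
    import paddle

    x = paddle.ones([4], dtype="float64")
    # the DC bin is sum(x) = 4, multiplied by the scale applied to the forward transform
    for norm, expected in [("backward", 4.0), ("forward", 1.0), ("ortho", 2.0)]:
        dc = paddle.fft.fft(x, norm=norm).numpy()[0]
        np.testing.assert_allclose(dc, expected)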
For ``norm="ortho"``, both directions are + scaled by ``1/sqrt(n)``. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + complex tensor. The truncated or zero-padded input, transformed along the axes indicated by + `axes`, or by a combination of `s` and `x`, as explained in the parameters section above. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.eye(3) + xp = paddle.to_tensor(x) + ifftn_xp = paddle.fft.ifftn(xp, axes=(1,)).numpy() + print(ifftn_xp) + + # [[ 0.33333333+0.j 0.33333333+0.j 0.33333333-0.j ] + # [ 0.33333333+0.j -0.16666667+0.28867513j -0.16666667-0.28867513j] + # [ 0.33333333+0.j -0.16666667-0.28867513j -0.16666667+0.28867513j]] + + """ + if is_interger(x) or is_floating_point(x): + return fftn_r2c( + x, s, axes, norm, forward=False, onesided=False, name=name) + else: + return fftn_c2c(x, s, axes, norm, forward=False, name=name) + + +def rfftn(x, s=None, axes=None, norm="backward", name=None): + """ + The N dimensional FFT for real input. + + This function computes the N-dimensional discrete Fourier Transform over + any number of axes in an M-dimensional real array by means of the Fast + Fourier Transform (FFT). By default, all axes are transformed, with the + real transform performed over the last axis, while the remaining + transforms are complex. + + The transform for real input is performed over the last transformation + axis, as by `rfft`, then the transform over the remaining axes is + performed as by `fftn`. The order of the output is as for `rfft` for the + final transformation axis, and as for `fftn` for the remaining + transformation axes. + + Args: + x(Tensor) : Input tensor, taken to be real. + s(Sequence[int]) : Shape to use from the exec fft. The final element of + `s` corresponds to `n` for ``rfft(x, n)``, while for the remaining + axes, it corresponds to `n` for ``fft(x, n)``. Along any axis, if + the given shape is smaller than that of the input, the input is + cropped. If it is larger, the input is padded with zeros. if `s` is + not given, the shape of the input along the axes specified by `axes` + is used. + axes(Sequence[int]) : Axes over which to compute the FFT. If not given, + the last ``len(s)`` axes are used, or all axes if `s` is also not + specified. + norm(str, optional) : Normalization mode, indicates which direction of + the forward/backward pair of transforms is scaled and with what + normalization factor. Include {"backward", "ortho", "forward"}, + default value is "backward". + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor): complex tensor + + + Raises: + ValueError: If `s` and `axes` have different length. + + Examples: + .. 
code-block:: python + import paddle + + # default, all axis will be used to exec fft + x = paddle.ones((2, 3, 4)) + print(paddle.fft.rfftn(x)) + # Tensor(shape=[2, 3, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [[[(24+0j), 0j , 0j ], + # [0j , 0j , 0j ], + # [0j , 0j , 0j ]], + # + # [[0j , 0j , 0j ], + # [0j , 0j , 0j ], + # [0j , 0j , 0j ]]]) + + # use axes(2, 0) + print(paddle.fft.rfftn(x, axes=(2, 0))) + # Tensor(shape=[2, 3, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [[[(8+0j), 0j , 0j ], + # [(8+0j), 0j , 0j ], + # [(8+0j), 0j , 0j ]], + # + # [[0j , 0j , 0j ], + # [0j , 0j , 0j ], + # [0j , 0j , 0j ]]]) + + """ + return fftn_r2c(x, s, axes, norm, forward=True, onesided=True, name=name) + + +def irfftn(x, s=None, axes=None, norm="backward", name=None): + """ + Computes the inverse of `rfftn`. + + This function computes the inverse of the N-D discrete + Fourier Transform for real input over any number of axes in an + M-D array by means of the Fast Fourier Transform (FFT). In + other words, ``irfftn(rfftn(x), x.shape) == x`` to within numerical + accuracy. (The ``a.shape`` is necessary like ``len(a)`` is for `irfft`, + and for the same reason.) + + The input should be ordered in the same way as is returned by `rfftn`, + i.e., as for `irfft` for the final transformation axis, and as for `ifftn` + along all the other axes. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): The length of the output transform axis. + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the + number of input points used along this axis, except for the last axis, + where ``s[-1]//2+1`` points of the input are used. Along any axis, if + the shape indicated by `s` is smaller than that of the input, the input + is cropped. If it is larger, the input is padded with zeros. + If `s` is not given, the shape of the input along the axes specified by axes + is used. Except for the last axis which is taken to be ``2*(k-1)`` where + ``k`` is the length of the input along that axis. + axes (sequence of ints, optional): Axes over which to compute the inverse FFT. If not given, the last + `len(s)` axes are used, or all axes if `s` is also not specified. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Real tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`, + or by a combination of `s` or `x`, as explained in the parameters section above. The length of + each transformed axis is as given by the corresponding element of `s`, or the length of the input + in every axis except for the last one if `s` is not given. In the final transformed axis the length + of the output when `s` is not given is ``2*(m-1)``, where ``m`` is the length of the final + transformed axis of the input. To get an odd number of output points in the final axis, + `s` must be specified. + + Examples: + + .. 
code-block:: python + + import numpy as np + import paddle + + x = (np.array([2, 2, 3]) + 1j * np.array([2, 2, 3])).astype(np.complex128) + xp = paddle.to_tensor(x) + irfftn_xp = paddle.fft.irfftn(xp).numpy() + print(irfftn_xp) + # [ 2.25 -1.25 0.25 0.75] + + """ + return fftn_c2r(x, s, axes, norm, forward=False, name=name) + + +def hfftn(x, s=None, axes=None, norm="backward", name=None): + """ + Compute the N-D FFT of Hermitian symmetric complex input, i.e., a + signal with a real spectrum. + + This function calculates the n-D discrete Fourier transform of Hermite symmetric + complex input on any axis in M-D array by fast Fourier transform (FFT). + In other words, ``ihfftn(hfftn(x, s)) == x is within the numerical accuracy range. + (``s`` here are ``x.shape`` and ``s[-1] = x.shape[- 1] * 2 - 1``. This is necessary + for the same reason that ``irfft` requires ``x.shape``.) + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): The length of the output transform axis. + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the + number of input points used along this axis, except for the last axis, + where ``s[-1]//2+1`` points of the input are used. Along any axis, if + the shape indicated by `s` is smaller than that of the input, the input + is cropped. If it is larger, the input is padded with zeros. + If `s` is not given, the shape of the input along the axes specified by axes + is used. Except for the last axis which is taken to be ``2*(k-1)`` where + ``k`` is the length of the input along that axis. + axes (sequence of ints, optional): Axes over which to compute the inverse FFT. If not given, the last + `len(s)` axes are used, or all axes if `s` is also not specified. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Real tensor. Truncate or zero fill input, transforming along the axis indicated by axis or + a combination of `s` or `X`. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = (np.array([2, 2, 3]) + 1j * np.array([2, 2, 3])).astype(np.complex128) + xp = paddle.to_tensor(x) + hfftn_xp = paddle.fft.hfftn(xp).numpy() + print(hfftn_xp) + # [ 9. 3. 1. -5.] + + + """ + return fftn_c2r(x, s, axes, norm, forward=True, name=name) + + +def ihfftn(x, s=None, axes=None, norm="backward", name=None): + """ + The n dimensional inverse FFT of a signal that has Hermitian symmetry. + + This function computes the n dimensional inverse FFT over any number of axes + in an M-dimensional of a signal that has Hermitian symmetry by means of an + efficient algorithm called the Fast Fourier Transform (FFT). + + Args: + x(Tensor): Input tensor. + s(Sequence[int], optional) : Shape (length along each transformed axis) + to use from the input. (``s[0]`` refers to axis 0, ``s[1]`` to axis + 1, etc.). Along any axis, if the given shape is smaller than that + of the input, the input is cropped. If it is larger, the input is + padded with zeros. if `s` is not given, the shape of the input + along the axes specified by `axes` is used. + axis(Sequence[int], optional) : Axis over which to compute the inverse FFT. If not + given, the last axis is used. 
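A shape-level sketch of how the one-sided transforms interact with ``axes`` (assuming, as for ``rfftn``, that the last axis listed in ``axes`` is the one halved to ``n//2 + 1``):

.. code-block:: python

    import paddle

    x = paddle.randn([4, 6], dtype="float64")
    # by default every axis is transformed; the last one is halved to 6//2 + 1
    assert paddle.fft.ihfftn(x).shape == [4, 4]
    assert paddle.fft.rfftn(x).shape == [4, 4]
    # transforming only axis 0 halves that axis instead
    assert paddle.fft.ihfftn(x, axes=(0,)).shape == [3, 6]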
+ norm(str, optional) : Normalization mode, indicates which direction of + the forward/backward pair of transforms is scaled and with what + normalization factor. Include {"backward", "ortho", "forward"}, + default value is "backward". + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor) : complex tensor. + + Examples: + .. code-block:: python + import paddle + + spectrum = paddle.to_tensor([10.0, -5.0, 0.0, -1.0, 0.0, -5.0]) + print(paddle.fft.ifft(spectrum)) + # Tensor(shape=[6], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j), (2.3333334922790527+1.9868215517249155e-08j), (1+1.9868215517249155e-08j)]) + print(paddle.fft.ihfft(spectrum)) + # Tensor(shape = [4], dtype = complex64, place = CUDAPlace(0), stop_gradient = True, + # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j)]) + + """ + return fftn_r2c(x, s, axes, norm, forward=False, onesided=True, name=name) + + +# public APIs 2d +def fft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Compute the 2-D discrete Fourier Transform + + This function computes the N-D discrete Fourier Transform + over any axes in an M-D array by means of the + Fast Fourier Transform (FFT). By default, the transform is computed over + the last two axes of the input array, i.e., a 2-dimensional FFT. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): Shape (length of each transformed axis) of the output. + It should be a sequence of 2 integers. This corresponds to ``n`` for ``fft(x, n)``. + Along each axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + if `s` is not given, the shape of the input along the axes specified + by `axes` is used. Default is None. + axes (sequence of ints, optional): Axes over which to compute the FFT. It should be a + sequence of 2 integers. If not specified, the last two axes are used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Complex tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`, + or the last two axes if `axes` is not given. + + Raises: + ValueError: if `s` not be a sequence of 2 integers or None. + ValueError: if `axes` not be a sequence of 2 integers or None. + ValueError: If the input dimension is smaller than 2. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.mgrid[:2, :2][1] + xp = paddle.to_tensor(x) + fft2_xp = paddle.fft.fft2(xp).numpy() + print(fft2_xp) + # [[ 2.+0.j -2.+0.j] + # [ 0.+0.j 0.+0.j]] + + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". 
+ format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) + return fftn(x, s, axes, norm, name) + + +def ifft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Compute the 2-D inverse discrete Fourier Transform. + + This function computes the inverse of the 2-D discrete Fourier + Transform over any number of axes in an M-D array by means of + the Fast Fourier Transform (FFT). In other words, ``ifft2(fft2(x)) == x`` + to within numerical accuracy. By default, the inverse transform is + computed over the last two axes of the input array. + + The input, analogously to `ifft`, should be ordered in the same way as is + returned by `fft2`, i.e., it should have the term for zero frequency + in the low-order corner of the two axes, the positive frequency terms in + the first half of these axes, the term for the Nyquist frequency in the + middle of the axes and the negative frequency terms in the second half of + both axes, in order of decreasingly negative frequency. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): Shape (length of each transformed axis) of the output. + It should be a sequence of 2 integers. This corresponds to ``n`` for ``fft(x, n)``. + Along each axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + if `s` is not given, the shape of the input along the axes specified + by `axes` is used. Default is None. + axes (sequence of ints, optional): Axes over which to compute the FFT. It should be a + sequence of 2 integers. If not specified, the last two axes are used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Complex tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`, + or the last two axes if `axes` is not given. + + Raises: + ValueError: if `s` not be a sequence of 2 integers or None. + ValueError: if `axes` not be a sequence of 2 integers or None. + ValueError: If the input dimension is smaller than 2. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.mgrid[:2, :2][1] + xp = paddle.to_tensor(x) + ifft2_xp = paddle.fft.ifft2(xp).numpy() + print(ifft2_xp) + # [[ 0.5+0.j -0.5+0.j] + # [ 0. +0.j 0. +0.j]] + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) + return ifftn(x, s, axes, norm, name) + + +def rfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + The two dimensional FFT with real tensor input. + + This is really just `rfftn` with different default behavior. + For more details see `rfftn`. + + Args: + x(Tensor): Input tensor, taken to be real. + s(Sequence[int]) : Shape of the FFT. 
+ axes(Sequence[int], optional): Axes over which to compute the FFT. + norm(str, optional) : {"backward", "ortho", "forward"}, + default is "backward". Indicates which direction of the + forward/backward pair of transforms is scaled and with what + normalization factor. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor): The result of the real 2-D FFT. + + Raises: + + + Examples: + + .. code-block:: python + import paddle + import numpy as np + + x = paddle.to_tensor(np.mgrid[:5, :5][0].astype(np.float32)) + print(paddle.fft.rfft2(x)) + # Tensor(shape=[5, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, + # [[ (50+0j) , (1.1920928955078125e-07+0j) , 0j ], + # [(-12.5+17.204774856567383j) , (-9.644234211236835e-08+7.006946134424652e-08j) , 0j ], + # [(-12.500000953674316+4.061495304107666j) , (3.6837697336977726e-08-1.1337477445749755e-07j), 0j ], + # [(-12.500000953674316-4.061495304107666j) , (3.6837697336977726e-08+1.1337477445749755e-07j), 0j ], + # [(-12.5-17.204774856567383j) , (-9.644234211236835e-08-7.006946134424652e-08j) , 0j ]]) + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) + return rfftn(x, s, axes, norm, name) + + +def irfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Computes the inverse of `rfft2`. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): Shape of the real output to the inverse FFT. Default is None. + axes (sequence of ints, optional): The axes over which to compute the inverse FFT. Axes + must be two-dimensional. If not specified, the last two axes are used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Real tensor. The result of the inverse real 2-D FFT. + + Raises: + ValueError: if `s` not be a sequence of 2 integers or None. + ValueError: if `axes` not be a sequence of 2 integers or None. + ValueError: If the input dimension is smaller than 2. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = (np.array([[3,2,3],[2, 2, 3]]) + 1j * np.array([[3,2,3],[2, 2, 3]])).astype(np.complex128) + xp = paddle.to_tensor(x) + irfft2_xp = paddle.fft.irfft2(xp).numpy() + print(irfft2_xp) + # [[ 2.375 -1.125 0.375 0.875] + # [ 0.125 0.125 0.125 0.125]] + + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". 
+ format(axes)) + return irfftn(x, s, axes, norm, name) + + +def hfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Compute the 2-D FFT of a Hermitian complex array. + + Args: + x (Tensor): The input data. It's a Tensor type. + s (sequence of ints, optional): Shape of the real output. Default is None. + axes (sequence of ints, optional): Axes over which to compute the FFT. Axes must be + two-dimensional. If not specified, the last two axes are used by default. + norm (str): Indicates which direction to scale the `forward` or `backward` transform + pair and what normalization factor to use. The parameter value must be one + of "forward" or "backward" or "ortho". Default is "backward". + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Real tensor. The real result of the 2-D Hermitian complex real FFT. + + Raises: + ValueError: if `s` not be a sequence of 2 integers or None. + ValueError: if `axes` not be a sequence of 2 integers or None. + ValueError: If the input dimension is smaller than 2. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = (np.array([[3,2,3],[2, 2, 3]]) + 1j * np.array([[3,2,3],[2, 2, 3]])).astype(np.complex128) + xp = paddle.to_tensor(x) + hfft2_xp = paddle.fft.hfft2(xp).numpy() + print(hfft2_xp) + # [[19. 7. 3. -9.] + # [ 1. 1. 1. 1.]] + + + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". + format(axes)) + return hfftn(x, s, axes, norm, name) + + +def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + """ + Compute the two dimensional inverse FFT of a real spectrum. + + This is really `ihfftn` with different defaults. + For more details see `ihfftn`. + + Args: + x(Tensor): Input tensor + s(Sequence[int], optional): Shape of the real input to the inverse FFT. + axes(Sequance[int], optional): The axes over which to compute the + inverse fft. Default is the last two axes. + norm(str, optional): {"backward", "ortho", "forward"}. Default is + "backward". + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + out(Tensor) : The result of the inverse hermitian 2-D FFT. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.mgrid[:5, :5][0].astype(np.float64) + xp = paddle.to_tensor(x) + ihfft2_xp = paddle.fft.ihfft2(xp).numpy() + print(ihfft2_xp) + # [[ 2. +0.j 0. +0.j 0. +0.j ] + # [-0.5-0.68819096j 0. +0.j 0. +0.j ] + # [-0.5-0.16245985j 0. +0.j 0. +0.j ] + # [-0.5+0.16245985j 0. +0.j 0. +0.j ] + # [-0.5+0.68819096j 0. +0.j 0. +0.j ]] + """ + _check_at_least_ndim(x, 2) + if s is not None: + if not isinstance(s, Sequence) or len(s) != 2: + raise ValueError( + "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". + format(s)) + if axes is not None: + if not isinstance(axes, Sequence) or len(axes) != 2: + raise ValueError( + "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". 
+ format(axes)) + return ihfftn(x, s, axes, norm, name) + + +# public APIs utilities +def fftfreq(n, d=1.0, dtype=None, name=None): + """ + Return the Discrete Fourier Transform sample frequencies. + + The returned float array `f` contains the frequency bin centers in cycles + per unit of the sample spacing (with zero at the start). For instance, if + the sample spacing is in seconds, then the frequency unit is cycles/second. + + Given input length `n` and a sample spacing `d`:: + + f = [0, 1, ..., n/2-1, -n/2, ..., -1] / (d*n) if n is even + f = [0, 1, ..., (n-1)/2, -(n-1)/2, ..., -1] / (d*n) if n is odd + + Args: + n (int): Dimension inputed. + d (scalar, optional): Sample spacing (inverse of the sampling rate). Defaults is 1. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor. A tensor of length 'n' containing the sampling frequency. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([3, 1, 2, 2, 3], dtype=float) + scalar_temp = 0.5 + n = x.size + fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp) + print(fftfreq_xp) + + # Tensor(shape=[5], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [ 0. , 0.40000001, 0.80000001, -0.80000001, -0.40000001]) + """ + + dtype = paddle.framework.get_default_dtype() + val = 1.0 / (n * d) + pos_max = (n + 1) // 2 + neg_max = n // 2 + indices = paddle.arange(-neg_max, pos_max, dtype=dtype, name=name) + indices = paddle.roll(indices, -neg_max, name=name) + return indices * val + + +def rfftfreq(n, d=1.0, dtype=None, name=None): + """ + Return the Discrete Fourier Transform sample frequencies. + + The returned floating-point array "F" contains the center of the frequency unit, + and the unit is the number of cycles of the sampling interval (the starting point is zero). + + Given input length `n` and a sample spacing `d`:: + + f = [0, 1, ..., n/2-1, n/2] / (d*n) if n is even + f = [0, 1, ..., (n-1)/2-1, (n-1)/2] / (d*n) if n is odd + + the Nyquist frequency component is considered to be positive. + + Args: + n (int): Dimension inputed. + d (scalar, optional): Sample spacing (inverse of the sampling rate). Defaults is 1. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor. A tensor of length ``n//2 + 1`` containing the sample frequencies. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([3, 1, 2, 2, 3], dtype=float) + scalar_temp = 0.3 + n = x.size + rfftfreq_xp = paddle.fft.rfftfreq(n, d=scalar_temp) + print(rfftfreq_xp) + + # Tensor(shape=[3], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [0. , 0.66666669, 1.33333337]) + + """ + + dtype = paddle.framework.get_default_dtype() + val = 1.0 / (n * d) + pos_max = 1 + n // 2 + indices = paddle.arange(0, pos_max, dtype=dtype, name=name) + return indices * val + + +def fftshift(x, axes=None, name=None): + """ + Shift the zero-frequency component to the center of the spectrum. + + This function swaps half spaces for all the axes listed (all by default). + Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even. + + Args: + n (int): Dimension inputed. + axes (int|tuple, optional): The axis on which to move. The default is none, which moves all axes. + Default is None. 
+ name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor. The shifted tensor. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([3, 1, 2, 2, 3], dtype=float) + n = x.size + fftfreq_xp = paddle.fft.fftfreq(n, d=0.3) + res = paddle.fft.fftshift(fftfreq_xp).numpy() + print(res) + # [-1.3333334 -0.6666667 0. 0.6666667 1.3333334] + + """ + shape = paddle.shape(x) + if axes is None: + # shift all axes + rank = paddle.rank(x).reshape([1]) + axes = axes or paddle.arange(0, rank) + shifts = [size // 2 for size in shape] + elif isinstance(axes, int): + shifts = shape[axes] // 2 + else: + shifts = [shape[ax] // 2 for ax in axes] + return paddle.roll(x, shifts, axes, name=name) + + +def ifftshift(x, axes=None, name=None): + """ + The inverse of `fftshift`. Although the even length 'x' is the same, the function of the + odd length 'x' is different. An example. + + Args: + n (int): Dimension inputed. + axes (int|tuple, optional): The axis on which to move. The default is none, which moves all axes. + Default is None. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor. The shifted tensor. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = np.array([3, 1, 2, 2, 3], dtype=float) + n = x.size + fftfreq_xp = paddle.fft.fftfreq(n, d=0.3) + res = paddle.fft.ifftshift(fftfreq_xp).numpy() + print(res) + # [ 1.3333334 -1.3333334 -0.6666667 0. 0.6666667] + + """ + shape = paddle.shape(x) + if axes is None: + # shift all axes + rank = paddle.rank(x).reshape([1]) + axes = axes or paddle.arange(0, rank) + shifts = [-size // 2 for size in shape] + elif isinstance(axes, int): + shifts = -shape[axes] // 2 + else: + shifts = [-shape[ax] // 2 for ax in axes] + return paddle.roll(x, shifts, axes, name=name) + + +# internal functions +def fft_c2c(x, n, axis, norm, forward, name): + if is_interger(x): + x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) + elif is_floating_point(x): + x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) + _check_normalization(norm) + + axis = axis if axis is not None else -1 + _check_fft_axis(x, axis) + axes = [axis] + axes = _normalize_axes(x, axes) + if n is not None: + _check_fft_n(n) + s = [n] + x = _resize_fft_input(x, s, axes) + op_type = 'fft_c2c' + + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) + if in_dygraph_mode(): + attrs = ('axes', axes, 'normalization', norm, 'forward', forward) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = {'axes': axes, 'normalization': norm, 'forward': forward} + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + return out + + +def fft_r2c(x, n, axis, norm, forward, onesided, name): + if is_interger(x): + x = paddle.cast(x, paddle.get_default_dtype()) + _check_normalization(norm) + axis = axis if axis is not None else -1 + _check_fft_axis(x, axis) + axes = [axis] + axes = _normalize_axes(x, axes) + if n is not None: + _check_fft_n(n) + s = [n] + x = _resize_fft_input(x, s, axes) + op_type = 'fft_r2c' + 
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type) + + if in_dygraph_mode(): + attrs = ('axes', axes, 'normalization', norm, 'forward', forward, + 'onesided', onesided) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = { + 'axes': axes, + 'normalization': norm, + 'forward': forward, + 'onesided': onesided, + } + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference( + _real_to_complex_dtype(dtype)) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + return out + + +def fft_c2r(x, n, axis, norm, forward, name): + if is_interger(x): + x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) + elif is_floating_point(x): + x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) + _check_normalization(norm) + axis = axis if axis is not None else -1 + _check_fft_axis(x, axis) + axes = [axis] + axes = _normalize_axes(x, axes) + if n is not None: + _check_fft_n(n) + s = [n // 2 + 1] + x = _resize_fft_input(x, s, axes) + op_type = 'fft_c2r' + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) + + if in_dygraph_mode(): + if n is not None: + attrs = ('axes', axes, 'normalization', norm, 'forward', forward, + 'last_dim_size', n) + else: + attrs = ('axes', axes, 'normalization', norm, 'forward', forward) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = {'axes': axes, 'normalization': norm, 'forward': forward} + if n is not None: + attrs['last_dim_size'] = n + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference( + _complex_to_real_dtype(dtype)) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + return out + + +def fftn_c2c(x, s, axes, norm, forward, name): + if is_interger(x): + x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) + elif is_floating_point(x): + x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) + _check_normalization(norm) + if s is not None: + _check_fft_shape(x, s) + + rank = x.ndim + if axes is None: + if s is None: + axes = list(range(rank)) + else: + fft_ndims = len(s) + axes = list(range(rank - fft_ndims, rank)) + else: + _check_fft_axes(x, axes) + axes = _normalize_axes(x, axes) + axes_argsoft = np.argsort(axes).tolist() + axes = [axes[i] for i in axes_argsoft] + if s is not None: + if len(s) != len(axes): + raise ValueError( + "Length of s ({}) and length of axes ({}) does not match.". 
+ format(len(s), len(axes))) + s = [s[i] for i in axes_argsoft] + + if s is not None: + x = _resize_fft_input(x, s, axes) + op_type = 'fft_c2c' + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) + + if in_dygraph_mode(): + attrs = ('axes', axes, 'normalization', norm, 'forward', forward) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = {'axes': axes, 'normalization': norm, 'forward': forward} + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + return out + + +def fftn_r2c(x, s, axes, norm, forward, onesided, name): + if is_interger(x): + x = paddle.cast(x, paddle.get_default_dtype()) + _check_normalization(norm) + if s is not None: + _check_fft_shape(x, s) + + rank = x.ndim + if axes is None: + if s is None: + axes = list(range(rank)) + else: + fft_ndims = len(s) + axes = list(range(rank - fft_ndims, rank)) + else: + _check_fft_axes(x, axes) + axes = _normalize_axes(x, axes) + axes_argsoft = np.argsort(axes[:-1]).tolist() + axes = [axes[i] for i in axes_argsoft] + [axes[-1]] + if s is not None: + if len(s) != len(axes): + raise ValueError( + "Length of s ({}) and length of axes ({}) does not match.". + format(len(s), len(axes))) + s = [s[i] for i in axes_argsoft] + [s[-1]] + + if s is not None: + x = _resize_fft_input(x, s, axes) + + op_type = 'fft_r2c' + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type) + + if in_dygraph_mode(): + attrs = ('axes', axes, 'normalization', norm, 'forward', forward, + 'onesided', onesided) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = { + 'axes': axes, + 'normalization': norm, + 'forward': forward, + 'onesided': onesided, + } + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference( + _real_to_complex_dtype(dtype)) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + + return out + + +def fftn_c2r(x, s, axes, norm, forward, name): + if is_interger(x): + x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) + elif is_floating_point(x): + x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) + _check_normalization(norm) + if s is not None: + _check_fft_shape(x, s) + + rank = x.ndim + if axes is None: + if s is None: + axes = list(range(rank)) + else: + fft_ndims = len(s) + axes = list(range(rank - fft_ndims, rank)) + else: + _check_fft_axes(x, axes) + axes = _normalize_axes(x, axes) + axes_argsoft = np.argsort(axes[:-1]).tolist() + axes = [axes[i] for i in axes_argsoft] + [axes[-1]] + if s is not None: + if len(s) != len(axes): + raise ValueError( + "Length of s ({}) and length of axes ({}) does not match.". 
+ format(len(s), len(axes))) + s = [s[i] for i in axes_argsoft] + [s[-1]] + + if s is not None: + fft_input_shape = list(s) + fft_input_shape[-1] = fft_input_shape[-1] // 2 + 1 + x = _resize_fft_input(x, fft_input_shape, axes) + + op_type = 'fft_c2r' + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) + + if in_dygraph_mode(): + if s: + attrs = ('axes', axes, 'normalization', norm, 'forward', forward, + 'last_dim_size', s[-1]) + else: + attrs = ('axes', axes, 'normalization', norm, 'forward', forward) + out = getattr(_C_ops, op_type)(x, *attrs) + else: + inputs = {'X': [x], } + attrs = {'axes': axes, 'normalization': norm, 'forward': forward} + if s: + attrs["last_dim_size"] = s[-1] + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference( + _complex_to_real_dtype(dtype)) + outputs = {"Out": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + return out diff --git a/python/paddle/fluid/tests/unittests/test_signal.py b/python/paddle/fluid/tests/unittests/test_signal.py index a109a5aa5d1a6..ecbbd8f52db9b 100644 --- a/python/paddle/fluid/tests/unittests/test_signal.py +++ b/python/paddle/fluid/tests/unittests/test_signal.py @@ -652,7 +652,7 @@ def test_frame(self): self.assertTrue( np.allclose( frame_for_api_test(self.x, self.frame_length, self.hop_length, self.axis), - paddle.tensor.signal.frame( + paddle.signal.frame( paddle.to_tensor(self.x), self.frame_length, self.hop_length, @@ -678,7 +678,7 @@ def test_frame_static(self): mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype) - output = paddle.tensor.signal.frame( + output = paddle.signal.frame( input, self.frame_length, self.hop_length, @@ -708,7 +708,7 @@ def test_frame_static(self): class TestFrameException(unittest.TestCase): def test_frame(self): with self.assertRaises(self.expect_exception): - paddle.tensor.signal.frame( + paddle.signal.frame( paddle.to_tensor(self.x), self.frame_length, self.hop_length, @@ -731,7 +731,7 @@ def test_overlap_add(self): self.assertTrue( np.allclose( overlap_add_for_api_test(self.x, self.hop_length, self.axis), - paddle.tensor.signal.overlap_add( + paddle.signal.overlap_add( paddle.to_tensor(self.x), self.hop_length, self.axis), @@ -756,7 +756,7 @@ def test_overlap_add_static(self): mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): input = paddle.static.data('input', self.x.shape, dtype=self.x.dtype) - output = paddle.tensor.signal.overlap_add( + output = paddle.signal.overlap_add( input, self.hop_length, self.axis), @@ -783,7 +783,7 @@ def test_overlap_add_static(self): class TestOverlapAddException(unittest.TestCase): def test_overlap_add(self): with self.assertRaises(self.expect_exception): - paddle.tensor.signal.overlap_add( + paddle.signal.overlap_add( paddle.to_tensor(self.x), self.hop_length, self.axis) @@ -848,7 +848,7 @@ def test_stft(self): self.assertTrue( np.allclose( stft(self.x, self.n_fft, self.hop_length, self.win_length, win_l, self.center, self.pad_mode), - paddle.tensor.signal.stft( + paddle.signal.stft( paddle.to_tensor(self.x), self.n_fft, self.hop_length, @@ -891,7 +891,7 @@ def test_stft(self): win_p = paddle.to_tensor(self.window) with self.assertRaises(self.expect_exception): - paddle.tensor.signal.stft( + paddle.signal.stft( paddle.to_tensor(self.x), self.n_fft, 
self.hop_length, @@ -934,7 +934,7 @@ def test_istft(self): self.assertTrue( np.allclose( istft(self.x, self.hop_length, self.win_length, win_l, self.center, self.length), - paddle.tensor.signal.istft( + paddle.signal.istft( paddle.to_tensor(self.x), self.n_fft, self.hop_length, @@ -986,7 +986,7 @@ def test_istft(self): win_p = paddle.to_tensor(self.window) with self.assertRaises(self.expect_exception): - paddle.tensor.signal.istft( + paddle.signal.istft( paddle.to_tensor(self.x), self.n_fft, self.hop_length, diff --git a/python/paddle/signal.py b/python/paddle/signal.py index c97f7fd5fae25..d550fa677d929 100644 --- a/python/paddle/signal.py +++ b/python/paddle/signal.py @@ -1,20 +1,574 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .tensor.signal import stft # noqa: F401 -from .tensor.signal import istft # noqa: F401 +from typing import Optional -__all__ = [ # noqa - 'stft', 'istft' +import paddle + +from .tensor.attribute import is_complex, is_floating_point +from .fft import fft_r2c, fft_c2r, fft_c2c +from .fluid.data_feeder import check_variable_and_dtype +from .fluid.framework import in_dygraph_mode +from .fluid.layer_helper import LayerHelper +from . import _C_ops + +__all__ = [ + 'stft', + 'istft', ] + + +def frame(x, frame_length, hop_length, axis=-1, name=None): + """ + Slice the N-dimensional (where N >= 1) input into (overlapping) frames. + + Args: + x (Tensor): The input data which is a N-dimensional (where N >= 1) Tensor + with shape `[..., seq_length]` or `[seq_length, ...]`. + frame_length (int): Length of the frame and `0 < frame_length <= x.shape[axis]`. + hop_length (int): Number of steps to advance between adjacent frames + and `0 < hop_length`. + axis (int, optional): Specify the axis to operate on the input Tensors. Its + value should be 0(the first dimension) or -1(the last dimension). If not + specified, the last axis is used by default. + + Returns: + The output frames tensor with shape `[..., frame_length, num_frames]` if `axis==-1`, + otherwise `[num_frames, frame_length, ...]` where + + `num_framse = 1 + (x.shape[axis] - frame_length) // hop_length` + + Examples: + + .. 
code-block:: python + + import paddle + from paddle.signal import frame + + # 1D + x = paddle.arange(8) + y0 = frame(x, frame_length=4, hop_length=2, axis=-1) # [4, 3] + # [[0, 2, 4], + # [1, 3, 5], + # [2, 4, 6], + # [3, 5, 7]] + + y1 = frame(x, frame_length=4, hop_length=2, axis=0) # [3, 4] + # [[0, 1, 2, 3], + # [2, 3, 4, 5], + # [4, 5, 6, 7]] + + # 2D + x0 = paddle.arange(16).reshape([2, 8]) + y0 = frame(x0, frame_length=4, hop_length=2, axis=-1) # [2, 4, 3] + # [[[0, 2, 4], + # [1, 3, 5], + # [2, 4, 6], + # [3, 5, 7]], + # + # [[8 , 10, 12], + # [9 , 11, 13], + # [10, 12, 14], + # [11, 13, 15]]] + + x1 = paddle.arange(16).reshape([8, 2]) + y1 = frame(x1, frame_length=4, hop_length=2, axis=0) # [3, 4, 2] + # [[[0 , 1 ], + # [2 , 3 ], + # [4 , 5 ], + # [6 , 7 ]], + # + # [4 , 5 ], + # [6 , 7 ], + # [8 , 9 ], + # [10, 11]], + # + # [8 , 9 ], + # [10, 11], + # [12, 13], + # [14, 15]]] + + # > 2D + x0 = paddle.arange(32).reshape([2, 2, 8]) + y0 = frame(x0, frame_length=4, hop_length=2, axis=-1) # [2, 2, 4, 3] + + x1 = paddle.arange(32).reshape([8, 2, 2]) + y1 = frame(x1, frame_length=4, hop_length=2, axis=0) # [3, 4, 2, 2] + """ + if axis not in [0, -1]: + raise ValueError(f'Unexpected axis: {axis}. It should be 0 or -1.') + + if not isinstance(frame_length, int) or frame_length <= 0: + raise ValueError( + f'Unexpected frame_length: {frame_length}. It should be an positive integer.' + ) + + if not isinstance(hop_length, int) or hop_length <= 0: + raise ValueError( + f'Unexpected hop_length: {hop_length}. It should be an positive integer.' + ) + + if frame_length > x.shape[axis]: + raise ValueError( + f'Attribute frame_length should be less equal than sequence length, ' + f'but got ({frame_length}) > ({x.shape[axis]}).') + + op_type = 'frame' + + if in_dygraph_mode(): + attrs = ('frame_length', frame_length, 'hop_length', hop_length, 'axis', + axis) + op = getattr(_C_ops, op_type) + out = op(x, *attrs) + else: + check_variable_and_dtype( + x, 'x', ['int32', 'int64', 'float16', 'float32', + 'float64'], op_type) + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype=dtype) + helper.append_op( + type=op_type, + inputs={'X': x}, + attrs={ + 'frame_length': frame_length, + 'hop_length': hop_length, + 'axis': axis + }, + outputs={'Out': out}) + return out + + +def overlap_add(x, hop_length, axis=-1, name=None): + """ + Reconstructs a tensor consisted of overlap added sequences from input frames. + + Args: + x (Tensor): The input data which is a N-dimensional (where N >= 2) Tensor + with shape `[..., frame_length, num_frames]` or + `[num_frames, frame_length ...]`. + hop_length (int): Number of steps to advance between adjacent frames and + `0 < hop_length <= frame_length`. + axis (int, optional): Specify the axis to operate on the input Tensors. Its + value should be 0(the first dimension) or -1(the last dimension). If not + specified, the last axis is used by default. + + Returns: + The output frames tensor with shape `[..., seq_length]` if `axis==-1`, + otherwise `[seq_length, ...]` where + + `seq_length = (n_frames - 1) * hop_length + frame_length` + + Examples: + + .. 
code-block:: python + + import paddle + from paddle.signal import overlap_add + + # 2D + x0 = paddle.arange(16).reshape([8, 2]) + # [[0 , 1 ], + # [2 , 3 ], + # [4 , 5 ], + # [6 , 7 ], + # [8 , 9 ], + # [10, 11], + # [12, 13], + # [14, 15]] + y0 = overlap_add(x0, hop_length=2, axis=-1) # [10] + # [0 , 2 , 5 , 9 , 13, 17, 21, 25, 13, 15] + + x1 = paddle.arange(16).reshape([2, 8]) + # [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ], + # [8 , 9 , 10, 11, 12, 13, 14, 15]] + y1 = overlap_add(x1, hop_length=2, axis=0) # [10] + # [0 , 1 , 10, 12, 14, 16, 18, 20, 14, 15] + + # > 2D + x0 = paddle.arange(32).reshape([2, 1, 8, 2]) + y0 = overlap_add(x0, hop_length=2, axis=-1) # [2, 1, 10] + + x1 = paddle.arange(32).reshape([2, 8, 1, 2]) + y1 = overlap_add(x1, hop_length=2, axis=0) # [10, 1, 2] + """ + if axis not in [0, -1]: + raise ValueError(f'Unexpected axis: {axis}. It should be 0 or -1.') + + if not isinstance(hop_length, int) or hop_length <= 0: + raise ValueError( + f'Unexpected hop_length: {hop_length}. It should be an positive integer.' + ) + + op_type = 'overlap_add' + + if in_dygraph_mode(): + attrs = ('hop_length', hop_length, 'axis', axis) + op = getattr(_C_ops, op_type) + out = op(x, *attrs) + else: + check_variable_and_dtype( + x, 'x', ['int32', 'int64', 'float16', 'float32', + 'float64'], op_type) + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype=dtype) + helper.append_op( + type=op_type, + inputs={'X': x}, + attrs={'hop_length': hop_length, + 'axis': axis}, + outputs={'Out': out}) + return out + + +def stft(x, + n_fft, + hop_length=None, + win_length=None, + window=None, + center=True, + pad_mode='reflect', + normalized=False, + onesided=True, + name=None): + """ + Short-time Fourier transform (STFT). + + The STFT computes the discrete Fourier transforms (DFT) of short overlapping + windows of the input using this formula: + + .. math:: + X_t[\omega] = \sum_{n = 0}^{N-1}% + \text{window}[n]\ x[t \times H + n]\ % + e^{-{2 \pi j \omega n}/{N}} + + Where: + - :math:`t`: The :math:`t`-th input window. + - :math:`\omega`: Frequency :math:`0 \leq \omega < \text{n\_fft}` for `onesided=False`, + or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for `onesided=True`. + - :math:`N`: Value of `n_fft`. + - :math:`H`: Value of `hop_length`. + + Args: + x (Tensor): The input data which is a 1-dimensional or 2-dimensional Tensor with + shape `[..., seq_length]`. It can be a real-valued or a complex Tensor. + n_fft (int): The number of input samples to perform Fourier transform. + hop_length (int, optional): Number of steps to advance between adjacent windows + and `0 < hop_length`. Default: `None`(treated as equal to `n_fft//4`) + win_length (int, optional): The size of window. Default: `None`(treated as equal + to `n_fft`) + window (Tensor, optional): A 1-dimensional tensor of size `win_length`. It will + be center padded to length `n_fft` if `win_length < n_fft`. Default: `None`( + treated as a rectangle window with value equal to 1 of size `win_length`). + center (bool, optional): Whether to pad `x` to make that the + :math:`t \times hop\_length` at the center of :math:`t`-th frame. Default: `True`. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. See + `paddle.nn.functional.pad` for all padding options. Default: `"reflect"` + normalized (bool, optional): Control whether to scale the output by `1/sqrt(n_fft)`. 
+ Default: `False` + onesided (bool, optional): Control whether to return half of the Fourier transform + output that satisfies the conjugate symmetry condition when input is a real-valued + tensor. It can not be `True` if input is a complex tensor. Default: `True` + name (str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + The complex STFT output tensor with shape `[..., n_fft//2 + 1, num_frames]`( + real-valued input and `onesided` is `True`) or `[..., n_fft, num_frames]`( + `onesided` is `False`) + + Exampels: + .. code-block:: python + + import paddle + from paddle.signal import stft + + # real-valued input + x = paddle.randn([8, 48000], dtype=paddle.float64) + y1 = stft(x, n_fft=512) # [8, 257, 376] + y2 = stft(x, n_fft=512, onesided=False) # [8, 512, 376] + + # complex input + x = paddle.randn([8, 48000], dtype=paddle.float64) + \ + paddle.randn([8, 48000], dtype=paddle.float64)*1j # [8, 48000] complex128 + y1 = stft(x, n_fft=512, center=False, onesided=False) # [8, 512, 372] + """ + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'complex64', 'complex128'], + 'stft') + + x_rank = len(x.shape) + assert x_rank in [1, 2], \ + f'x should be a 1D or 2D real tensor, but got rank of x is {x_rank}' + + if x_rank == 1: # (batch, seq_length) + x = x.unsqueeze(0) + + if hop_length is None: + hop_length = int(n_fft // 4) + + assert hop_length > 0, \ + f'hop_length should be > 0, but got {hop_length}.' + + if win_length is None: + win_length = n_fft + + assert 0 < n_fft <= x.shape[-1], \ + f'n_fft should be in (0, seq_length({x.shape[-1]})], but got {n_fft}.' + + assert 0 < win_length <= n_fft, \ + f'win_length should be in (0, n_fft({n_fft})], but got {win_length}.' + + if window is not None: + assert len(window.shape) == 1 and len(window) == win_length, \ + f'expected a 1D window tensor of size equal to win_length({win_length}), but got window with shape {window.shape}.' + else: + window = paddle.ones(shape=(win_length, ), dtype=x.dtype) + + if win_length < n_fft: + pad_left = (n_fft - win_length) // 2 + pad_right = n_fft - win_length - pad_left + window = paddle.nn.functional.pad(window, + pad=[pad_left, pad_right], + mode='constant') + + if center: + assert pad_mode in ['constant', 'reflect'], \ + 'pad_mode should be "reflect" or "constant", but got "{}".'.format(pad_mode) + + pad_length = n_fft // 2 + # FIXME: Input `x` can be a complex tensor but pad does not supprt complex input. + x = paddle.nn.functional.pad(x.unsqueeze(-1), + pad=[pad_length, pad_length], + mode=pad_mode, + data_format="NLC").squeeze(-1) + + x_frames = frame(x=x, frame_length=n_fft, hop_length=hop_length, axis=-1) + x_frames = x_frames.transpose( + perm=[0, 2, + 1]) # switch n_fft to last dim, egs: (batch, num_frames, n_fft) + x_frames = x_frames * window + + norm = 'ortho' if normalized else 'backward' + if is_complex(x_frames): + assert not onesided, \ + 'onesided should be False when input or window is a complex Tensor.' 
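+    # Choose the FFT kernel by input dtype: a real-valued signal is transformed
+    # with the real-to-complex FFT (honoring `onesided`), while a complex signal
+    # falls through to the complex-to-complex FFT, matching the assert above.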
+ + if not is_complex(x): + out = fft_r2c( + x=x_frames, + n=None, + axis=-1, + norm=norm, + forward=True, + onesided=onesided, + name=name) + else: + out = fft_c2c( + x=x_frames, n=None, axis=-1, norm=norm, forward=True, name=name) + + out = out.transpose(perm=[0, 2, 1]) # (batch, n_fft, num_frames) + + if x_rank == 1: + out.squeeze_(0) + + return out + + +def istft(x, + n_fft, + hop_length=None, + win_length=None, + window=None, + center=True, + normalized=False, + onesided=True, + length=None, + return_complex=False, + name=None): + """ + Inverse short-time Fourier transform (ISTFT). + + Reconstruct time-domain signal from the giving complex input and window tensor when + nonzero overlap-add (NOLA) condition is met: + + .. math:: + \sum_{t = -\infty}^{\infty}% + \text{window}^2[n - t \times H]\ \neq \ 0, \ \text{for } all \ n + + Where: + - :math:`t`: The :math:`t`-th input window. + - :math:`N`: Value of `n_fft`. + - :math:`H`: Value of `hop_length`. + + Result of `istft` expected to be the inverse of `paddle.signal.stft`, but it is + not guaranteed to reconstruct a exactly realizible time-domain signal from a STFT + complex tensor which has been modified (via masking or otherwise). Therefore, `istft` + gives the [Griffin-Lim optimal estimate](https://ieeexplore.ieee.org/document/1164317) + (optimal in a least-squares sense) for the corresponding signal. + + Args: + x (Tensor): The input data which is a 2-dimensional or 3-dimensional **complesx** + Tensor with shape `[..., n_fft, num_frames]`. + n_fft (int): The size of Fourier transform. + hop_length (int, optional): Number of steps to advance between adjacent windows + from time-domain signal and `0 < hop_length < win_length`. Default: `None`( + treated as equal to `n_fft//4`) + win_length (int, optional): The size of window. Default: `None`(treated as equal + to `n_fft`) + window (Tensor, optional): A 1-dimensional tensor of size `win_length`. It will + be center padded to length `n_fft` if `win_length < n_fft`. It should be a + real-valued tensor if `return_complex` is False. Default: `None`(treated as + a rectangle window with value equal to 1 of size `win_length`). + center (bool, optional): It means that whether the time-domain signal has been + center padded. Default: `True`. + normalized (bool, optional): Control whether to scale the output by `1/sqrt(n_fft)`. + Default: `False` + onesided (bool, optional): It means that whether the input STFT tensor is a half + of the conjugate symmetry STFT tensor transformed from a real-valued signal + and `istft` will return a real-valued tensor when it is set to `True`. + Default: `True`. + length (int, optional): Specify the length of time-domain signal. Default: `None`( + treated as the whole length of signal). + return_complex (bool, optional): It means that whether the time-domain signal is + real-valued. If `return_complex` is set to `True`, `onesided` should be set to + `False` cause the output is complex. + name (str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A tensor of least squares estimation of the reconstructed signal(s) with shape + `[..., seq_length]` + + Exampels: + .. 
code-block:: python + + import numpy as np + import paddle + from paddle.signal import stft, istft + + paddle.seed(0) + + # STFT + x = paddle.randn([8, 48000], dtype=paddle.float64) + y = stft(x, n_fft=512) # [8, 257, 376] + + # ISTFT + x_ = istft(y, n_fft=512) # [8, 48000] + + np.allclose(x, x_) # True + """ + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'istft') + + x_rank = len(x.shape) + assert x_rank in [2, 3], \ + 'x should be a 2D or 3D complex tensor, but got rank of x is {}'.format(x_rank) + + if x_rank == 2: # (batch, n_fft, n_frames) + x = x.unsqueeze(0) + + if hop_length is None: + hop_length = int(n_fft // 4) + + if win_length is None: + win_length = n_fft + + # Assure no gaps between frames. + assert 0 < hop_length <= win_length, \ + 'hop_length should be in (0, win_length({})], but got {}.'.format(win_length, hop_length) + + assert 0 < win_length <= n_fft, \ + 'win_length should be in (0, n_fft({})], but got {}.'.format(n_fft, win_length) + + n_frames = x.shape[-1] + fft_size = x.shape[-2] + + if onesided: + assert (fft_size == n_fft // 2 + 1), \ + 'fft_size should be equal to n_fft // 2 + 1({}) when onesided is True, but got {}.'.format(n_fft // 2 + 1, fft_size) + else: + assert (fft_size == n_fft), \ + 'fft_size should be equal to n_fft({}) when onesided is False, but got {}.'.format(n_fft, fft_size) + + if window is not None: + assert len(window.shape) == 1 and len(window) == win_length, \ + 'expected a 1D window tensor of size equal to win_length({}), but got window with shape {}.'.format(win_length, window.shape) + else: + window = paddle.ones(shape=(win_length, )) + + if win_length < n_fft: + pad_left = (n_fft - win_length) // 2 + pad_right = n_fft - win_length - pad_left + # FIXME: Input `window` can be a complex tensor but pad does not supprt complex input. + window = paddle.nn.functional.pad(window, + pad=[pad_left, pad_right], + mode='constant') + + x = x.transpose( + perm=[0, 2, + 1]) # switch n_fft to last dim, egs: (batch, num_frames, n_fft) + norm = 'ortho' if normalized else 'backward' + + if return_complex: + assert not onesided, \ + 'onesided should be False when input(output of istft) or window is a complex Tensor.' + + out = fft_c2c(x=x, n=None, axis=-1, norm=norm, forward=False, name=None) + else: + assert not is_complex(window), \ + 'Data type of window should not be complex when return_complex is False.' + + if onesided is False: + x = x[:, :, :n_fft // 2 + 1] + out = fft_c2r(x=x, n=None, axis=-1, norm=norm, forward=False, name=None) + + out = overlap_add( + x=(out * window).transpose( + perm=[0, 2, 1]), # (batch, n_fft, num_frames) + hop_length=hop_length, + axis=-1) # (batch, seq_length) + + window_envelop = overlap_add( + x=paddle.tile( + x=window * window, repeat_times=[n_frames, 1]).transpose( + perm=[1, 0]), # (n_fft, num_frames) + hop_length=hop_length, + axis=-1) # (seq_length, ) + + if length is None: + if center: + out = out[:, (n_fft // 2):-(n_fft // 2)] + window_envelop = window_envelop[(n_fft // 2):-(n_fft // 2)] + else: + if center: + start = n_fft // 2 + else: + start = 0 + + out = out[:, start:start + length] + window_envelop = window_envelop[start:start + length] + + # Check whether the Nonzero Overlap Add (NOLA) constraint is met. + if window_envelop.abs().min().item() < 1e-11: + raise ValueError( + 'Abort istft because Nonzero Overlap Add (NOLA) condition failed. 
For more information about NOLA constraint please see `scipy.signal.check_NOLA`(https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.check_NOLA.html).' + ) + + out = out / window_envelop + + if x_rank == 2: + out.squeeze_(0) + + return out diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b5d79b6039320..4868a23ea0355 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -218,8 +218,6 @@ from .array import create_array # noqa: F401 from .einsum import einsum # noqa: F401 -from . import fft -from . import signal #this list used in math_op_patch.py for _binary_creator_ tensor_method_func = [ #noqa diff --git a/python/paddle/tensor/fft.py b/python/paddle/tensor/fft.py deleted file mode 100644 index 20fd143589fa4..0000000000000 --- a/python/paddle/tensor/fft.py +++ /dev/null @@ -1,1601 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Sequence -import numpy as np -import paddle -from .attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype -from ..fluid.framework import in_dygraph_mode -from .. import _C_ops -from ..fluid.data_feeder import check_variable_and_dtype -from ..fluid.layer_helper import LayerHelper - -__all__ = [] - - -def _check_normalization(norm): - if norm not in ['forward', 'backward', 'ortho']: - raise ValueError( - "Unexpected norm: {}. Norm should be forward, backward or ortho". - format(norm)) - - -def _check_fft_n(n): - if not isinstance(n, int): - raise ValueError( - "Invalid FFT argument n({}), it shoule be an integer.".format(n)) - if n <= 0: - raise ValueError( - "Invalid FFT argument n({}), it should be positive.".format(n)) - - -def _check_fft_shape(x, s): - ndim = x.ndim - if not isinstance(s, Sequence): - raise ValueError( - "Invaid FFT argument s({}), it should be a sequence of integers.") - - if len(s) > ndim: - raise ValueError( - "Length of FFT argument s should not be larger than the rank of input. " - "Received s: {}, rank of x: {}".format(s, ndim)) - for size in s: - if not isinstance(size, int) or size <= 0: - raise ValueError("FFT sizes {} contains invalid value ({})".format( - s, size)) - - -def _check_fft_axis(x, axis): - ndim = x.ndim - if not isinstance(axis, int): - raise ValueError( - "Invalid FFT axis ({}), it shoule be an integer.".format(axis)) - if axis < -ndim or axis >= ndim: - raise ValueError( - "Invalid FFT axis ({}), it should be in range [-{}, {})".format( - axis, ndim, ndim)) - - -def _check_fft_axes(x, axes): - ndim = x.ndim - if not isinstance(axes, Sequence): - raise ValueError( - "Invalid FFT axes ({}), it should be a sequence of integers.". - format(axes)) - if len(axes) > ndim: - raise ValueError( - "Length of fft axes should not be larger than the rank of input. 
" - "Received, len of axes: {}, rank of x: {}".format(len(axes), ndim)) - for axis in axes: - if not isinstance(axis, int) or axis < -ndim or axis >= ndim: - raise ValueError( - "FFT axes {} contains invalid value ({}), it should be in range [-{}, {})". - format(axes, axis, ndim, ndim)) - - -def _resize_fft_input(x, s, axes): - if len(s) != len(axes): - raise ValueError("length of `s` should equals length of `axes`.") - shape = x.shape - ndim = x.ndim - - axes_to_pad = [] - paddings = [] - axes_to_slice = [] - slices = [] - for i, axis in enumerate(axes): - if shape[axis] < s[i]: - axes_to_pad.append(axis) - paddings.append(s[i] - shape[axis]) - elif shape[axis] > s[i]: - axes_to_slice.append(axis) - slices.append((0, s[i])) - - if axes_to_slice: - x = paddle.slice( - x, - axes_to_slice, - starts=[item[0] for item in slices], - ends=[item[1] for item in slices]) - if axes_to_pad: - padding_widths = [0] * (2 * ndim) - for axis, pad in zip(axes_to_pad, paddings): - padding_widths[2 * axis + 1] = pad - x = paddle.nn.functional.pad(x, padding_widths) - return x - - -def _normalize_axes(x, axes): - ndim = x.ndim - return [item if item >= 0 else (item + ndim) for item in axes] - - -def _check_at_least_ndim(x, rank): - if x.ndim < rank: - raise ValueError("The rank of the input ({}) should >= {}".format( - x.ndim, rank)) - - -# public APIs 1d -def fft(x, n=None, axis=-1, norm="backward", name=None): - """ - Calculate one-dimensional discrete Fourier transform. - - This function uses the efficient fast Fourier transform (FFT) algorithm [1] to - calculate the 1-D * n * point discrete Fourier transform (DFT). - - Args: - x (Tensor): The input data. It's a Tensor type. It's a complex. - n (int, optional): The length of the output transform axis. If `n` is less than - the length input, the input will be cropped. If larger, the input is filled - with zeros. If `n` is not given, the input length along the axis specified - by `axis` is used. - axis (int, optional): Axis used to calculate FFT. If not specified, the last axis - is used by default. - norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on - the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies - the ``1/n`` factor on the forward tranform. For ``norm="ortho"``, both directions are - scaled by ``1/sqrt(n)``. - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - complex tensor. The truncated or zero-padded input, transformed along the axis indicated - by `axis`, or the last one if `axis` is not specified. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = np.exp(3j * np.pi * np.arange(7) / 7) - xp = paddle.to_tensor(x) - fft_xp = paddle.fft.fft(xp).numpy() - print(fft_xp) - # [1.+1.25396034e+00j 1.+4.38128627e+00j 1.-4.38128627e+00j - # 1.-1.25396034e+00j 1.-4.81574619e-01j 1.+8.88178420e-16j - # 1.+4.81574619e-01j] - - - """ - if is_interger(x) or is_floating_point(x): - return fft_r2c( - x, n, axis, norm, forward=True, onesided=False, name=name) - else: - return fft_c2c(x, n, axis, norm, forward=True, name=name) - - -def ifft(x, n=None, axis=-1, norm="backward", name=None): - """ - Compute the 1-D inverse discrete Fourier Transform. 
- - This function computes the inverse of the 1-D *n*-point discrete Fourier transform - computed by `fft`. In other words, ``ifft(fft(x)) == x`` to within numerical accuracy. - - The input should be ordered in the same way as is returned by `fft`, - i.e., - - * ``x[0]`` should contain the zero frequency term, - * ``x[1:n//2]`` should contain the positive-frequency terms, - * ``x[n//2 + 1:]`` should contain the negative-frequency terms, in - increasing order starting from the most negative frequency. - - For an even number of input points, ``x[n//2]`` represents the sum of - the values at the positive and negative Nyquist frequencies, as the two - are aliased together. - - Args: - x (Tensor): The input data. It's a Tensor type. It's a complex. - n (int, optional): The length of the output transform axis. If `n` is less than - the length input, the input will be cropped. If larger, the input is filled - with zeros. If `n` is not given, the input length along the axis specified - by `axis` is used. - axis (int, optional): Axis used to calculate FFT. If not specified, the last axis - is used by default. - norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on - the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies - the ``1/n`` factor on the forward tranform. For ``norm="ortho"``, both directions are - scaled by ``1/sqrt(n)``. - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - complex tensor. The truncated or zero-padded input, transformed along the axis indicated - by `axis`, or the last one if `axis` is not specified. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = np.exp(3j * np.pi * np.arange(7) / 7) - xp = paddle.to_tensor(x) - ifft_xp = paddle.fft.ifft(xp).numpy() - print(ifft_xp) - # [0.14285714+1.79137191e-01j 0.14285714+6.87963741e-02j - # 0.14285714+1.26882631e-16j 0.14285714-6.87963741e-02j - # 0.14285714-1.79137191e-01j 0.14285714-6.25898038e-01j - # 0.14285714+6.25898038e-01j] - - """ - if is_interger(x) or is_floating_point(x): - return fft_r2c( - x, n, axis, norm, forward=False, onesided=False, name=name) - else: - return fft_c2c(x, n, axis, norm, forward=False, name=name) - - -def rfft(x, n=None, axis=-1, norm="backward", name=None): - """ - The one dimensional FFT for real input. - - This function computes the one dimensional *n*-point discrete Fourier - Transform (DFT) of a real-valued tensor by means of an efficient algorithm - called the Fast Fourier Transform (FFT). - - When the DFT is computed for purely real input, the output is - Hermitian-symmetric. This function does not compute the negative frequency - terms, and the length of the transformed axis of the output is therefore - ``n//2 + 1``. - - Args: - x(Tensor) : Real-valued input tensor - n(int, optional): Number of points along transformation axis in the - input to use. If `n` is smaller than the length of the input, the - input is cropped. If it is larger, the input is padded with zeros. - If `n` is not given, the length of the input along the axis - specified by `axis` is used. - axis(int, optional): Axis over which to compute the FFT. Default value - is last axis. 
- norm(str, optional) : Normalization mode, indicates which direction of - the forward/backward pair of transforms is scaled and with what - normalization factor. Include {"backward", "ortho", "forward"}, - default value is "backward". - name(str, optional): The default value is None. Normally there is no - need for user to set this property. For more information, please - refer to :ref:`api_guide_Name` . - - Returns: - out(Tensor) : complex tensor - - Raises: - - - Examples: - .. code-block:: python - import paddle - - x = paddle.to_tensor([0.0, 1.0, 0.0, 0.0]) - print(paddle.fft.rfft(x)) - # Tensor(shape=[3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, - # [ (1+0j), -1j , (-1+0j)]) - """ - return fft_r2c(x, n, axis, norm, forward=True, onesided=True, name=name) - - -def irfft(x, n=None, axis=-1, norm="backward", name=None): - """ - Computes the inverse of `rfft`. - - This function calculates the inverse of the one-dimensional *n* point discrete - Fourier transform of the actual input calculated by "rfft". In other words, - ``irfft(rfft(a),len(a)) == a`` is within the numerical accuracy range. - - The input shall be in the form of "rfft", i.e. the actual zero frequency term, - followed by the complex positive frequency term, in the order of increasing frequency. - Because the discrete Fourier transform of the actual input is Hermite symmetric, - the negative frequency term is regarded as the complex conjugate term of the corresponding - positive frequency term. - - Args: - x (Tensor): The input data. It's a Tensor type. It's a complex. - n (int, optional): The length of the output transform axis. For `n` output - points, ``n//2 + 1``input points are necessary. If the length of the input tensor is greater - than `n`, it will be cropped, if it is shorter than this, fill in zero. If `n` is not given, - it is considered to be ``2 * (k-1)``, where ``k`` is the length of the input axis specified - along the ` axis'. - axis (int, optional): Axis used to calculate FFT. If not specified, the last axis - is used by default. - norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward". - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name` . - - Returns: - Real tensor. Truncated or zero fill input for the transformation along the axis indicated by - `axis`, or the last input if `axis` is not specified. The length of the conversion axis - is `n`, or ``2 * k-2``, if `k` is None, where `k` is the length of the input conversion axis. - If the output is an odd number, you need to specify the value of 'n', such as ``2 * k-1`` - in some cases. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = np.array([1, -1j, -1]) - xp = paddle.to_tensor(x) - irfft_xp = paddle.fft.irfft(xp).numpy() - print(irfft_xp) - # [0. 1. 0. 0.] - - """ - return fft_c2r(x, n, axis, norm, forward=False, name=name) - - -def hfft(x, n=None, axis=-1, norm="backward", name=None): - """ - Compute the FFT of a signal that has Hermitian symmetry, a real - spectrum. - - Args: - x (Tensor): The input data. It's a Tensor type. It's a complex. - n (int, optional): The length of the output transform axis. For `n` output - points, ``n//2 + 1`` input points are necessary. 
If the length of the input tensor is greater - than `n`, it will be cropped, if it is shorter than this, fill in zero. If `n` is not given, - it is considered to be ``2 * (k-1)``, where ``k`` is the length of the input axis specified - along the ` axis'. - axis (int,optional): Axis used to calculate FFT. If not specified, the last axis - is used by default. - norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward". - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name` . - - Returns: - Real tensor. Truncated or zero fill input for the transformation along the axis indicated by - `axis`, or the last input if `axis` is not specified. The length of the conversion axis - is `n`, or ``2 * k-2``, if `k` is None, where `k` is the length of the input conversion axis. - If the output is an odd number, you need to specify the value of 'n', such as ``2 * k-1`` in - some cases. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = np.array([1, -1j, -1]) - xp = paddle.to_tensor(x) - hfft_xp = paddle.fft.hfft(xp).numpy() - print(hfft_xp) - # [0. 0. 0. 4.] - """ - - return fft_c2r(x, n, axis, norm, forward=True, name=name) - - -def ihfft(x, n=None, axis=-1, norm="backward", name=None): - """ - The inverse FFT of a signal that has Hermitian symmetry. - - This function computes the one dimensional *n*-point inverse FFT of a signal - that has Hermitian symmetry by means of an efficient algorithm called - the Fast Fourier Transform (FFT). - - When the DFT is computed for purely real input, the output is - Hermitian-symmetric. This function does not compute the negative frequency - terms, and the length of the transformed axis of the output is therefore - ``n//2 + 1``. - - Args: - x(Tensor): Input tensor. - n(int, optional): The number of points along transformation axis in the - input to use. If `n` is smaller than the length of the input, the - input is cropped. If it is larger, the input is padded with zeros. - If `n` is not given, the length of the input along the axis - specified by `axis` is used. - axis(int, optional) : Axis over which to compute the inverse FFT. If not - given, the last axis is used. - norm(str, optional) : Normalization mode, indicates which direction of - the forward/backward pair of transforms is scaled and with what - normalization factor. Include {"backward", "ortho", "forward"}, - default value is "backward". - name(str, optional): The default value is None. Normally there is no - need for user to set this property. For more information, please - refer to :ref:`api_guide_Name` . - - Returns: - out(Tensor) : complex tensor. - - Examples: - .. 
code-block:: python - import paddle - - spectrum = paddle.to_tensor([10.0, -5.0, 0.0, -1.0, 0.0, -5.0]) - print(paddle.fft.ifft(spectrum)) - # Tensor(shape=[6], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, - # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j), (2.3333334922790527+1.9868215517249155e-08j), (1+1.9868215517249155e-08j)]) - print(paddle.fft.ihfft(spectrum)) - # Tensor(shape = [4], dtype = complex64, place = CUDAPlace(0), stop_gradient = True, - # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j)]) - - """ - return fft_r2c(x, n, axis, norm, forward=False, onesided=True, name=name) - - -# public APIs nd -def fftn(x, s=None, axes=None, norm="backward", name=None): - """ - Compute the N-D discrete Fourier Transform. - - This function calculates the n-D discrete Fourier transform on any number of axes - in the M-D array by fast Fourier transform (FFT). - - Args: - x (Tensor): The input data. It's a Tensor type. It's a complex. - s (sequence of ints, optional): Shape (length of each transformed axis) of the output - (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). - This corresponds to ``n`` for ``fft(x, n)``. - Along any axis, if the given shape is smaller than that of the input, - the input is cropped. If it is larger, the input is padded with zeros. - if `s` is not given, the shape of the input along the axes specified - by `axes` is used. - axes (sequence of ints, optional): Axes used to calculate FFT. If not given, the last ``len(s)`` - axes are used, or all axes if `s` is also not specified. - norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on - the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies - the ``1/n`` factor on the forward tranform. For ``norm="ortho"``, both directions are - scaled by ``1/sqrt(n)``. - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - complex tensor. The truncated or zero-padded input, transformed along the axes indicated by - `axes`, or by a combination of `s` and `x`, as explained in the parameters section above. - - Examples: - - .. 
code-block:: python - - import numpy as np - import paddle - - x = np.mgrid[:4, :4, :4][1] - xp = paddle.to_tensor(x) - fftn_xp = paddle.fft.fftn(xp, axes=(1, 2)).numpy() - print(fftn_xp) - # [[[24.+0.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]] - # [[24.+0.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]] - # [[24.+0.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]] - # [[24.+0.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.+8.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] - # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]]] - """ - if is_interger(x) or is_floating_point(x): - return fftn_r2c( - x, s, axes, norm, forward=True, onesided=False, name=name) - else: - return fftn_c2c(x, s, axes, norm, forward=True, name=name) - - -def ifftn(x, s=None, axes=None, norm="backward", name=None): - """ - Compute the N-D inverse discrete Fourier Transform. - - This function computes the inverse of the N-D discrete - Fourier Transform over any number of axes in an M-D array by - means of the Fast Fourier Transform (FFT). In other words, - ``ifftn(fftn(x)) == x`` to within numerical accuracy. - - The input, analogously to `ifft`, should be ordered in the same way as is - returned by `fftn`, i.e., it should have the term for zero frequency - in all axes in the low-order corner, the positive frequency terms in the - first half of all axes, the term for the Nyquist frequency in the middle - of all axes and the negative frequency terms in the second half of all - axes, in order of decreasingly negative frequency. - - Args: - x (Tensor): The input data. It's a Tensor type. It's a complex. - s (sequence of ints, optional): Shape (length of each transformed axis) of the output - (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). - This corresponds to ``n`` for ``fft(x, n)``. - Along any axis, if the given shape is smaller than that of the input, - the input is cropped. If it is larger, the input is padded with zeros. - if `s` is not given, the shape of the input along the axes specified - by `axes` is used. - axes (sequence of ints, optional): Axes used to calculate FFT. If not given, the last ``len(s)`` - axes are used, or all axes if `s` is also not specified. - norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward", meaning no normalization on - the forward transforms and scaling by ``1/n`` on the `ifft`. "forward" instead applies - the ``1/n`` factor on the forward tranform. For ``norm="ortho"``, both directions are - scaled by ``1/sqrt(n)``. - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - complex tensor. The truncated or zero-padded input, transformed along the axes indicated by - `axes`, or by a combination of `s` and `x`, as explained in the parameters section above. - - Examples: - - .. 
code-block:: python - - import numpy as np - import paddle - - x = np.eye(3) - xp = paddle.to_tensor(x) - ifftn_xp = paddle.fft.ifftn(xp, axes=(1,)).numpy() - print(ifftn_xp) - - # [[ 0.33333333+0.j 0.33333333+0.j 0.33333333-0.j ] - # [ 0.33333333+0.j -0.16666667+0.28867513j -0.16666667-0.28867513j] - # [ 0.33333333+0.j -0.16666667-0.28867513j -0.16666667+0.28867513j]] - - """ - if is_interger(x) or is_floating_point(x): - return fftn_r2c( - x, s, axes, norm, forward=False, onesided=False, name=name) - else: - return fftn_c2c(x, s, axes, norm, forward=False, name=name) - - -def rfftn(x, s=None, axes=None, norm="backward", name=None): - """ - The N dimensional FFT for real input. - - This function computes the N-dimensional discrete Fourier Transform over - any number of axes in an M-dimensional real array by means of the Fast - Fourier Transform (FFT). By default, all axes are transformed, with the - real transform performed over the last axis, while the remaining - transforms are complex. - - The transform for real input is performed over the last transformation - axis, as by `rfft`, then the transform over the remaining axes is - performed as by `fftn`. The order of the output is as for `rfft` for the - final transformation axis, and as for `fftn` for the remaining - transformation axes. - - Args: - x(Tensor) : Input tensor, taken to be real. - s(Sequence[int]) : Shape to use from the exec fft. The final element of - `s` corresponds to `n` for ``rfft(x, n)``, while for the remaining - axes, it corresponds to `n` for ``fft(x, n)``. Along any axis, if - the given shape is smaller than that of the input, the input is - cropped. If it is larger, the input is padded with zeros. if `s` is - not given, the shape of the input along the axes specified by `axes` - is used. - axes(Sequence[int]) : Axes over which to compute the FFT. If not given, - the last ``len(s)`` axes are used, or all axes if `s` is also not - specified. - norm(str, optional) : Normalization mode, indicates which direction of - the forward/backward pair of transforms is scaled and with what - normalization factor. Include {"backward", "ortho", "forward"}, - default value is "backward". - name(str, optional): The default value is None. Normally there is no - need for user to set this property. For more information, please - refer to :ref:`api_guide_Name` . - - Returns: - out(Tensor): complex tensor - - - Raises: - ValueError: If `s` and `axes` have different length. - - Examples: - .. code-block:: python - import paddle - - # default, all axis will be used to exec fft - x = paddle.ones((2, 3, 4)) - print(paddle.fft.rfftn(x)) - # Tensor(shape=[2, 3, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, - # [[[(24+0j), 0j , 0j ], - # [0j , 0j , 0j ], - # [0j , 0j , 0j ]], - # - # [[0j , 0j , 0j ], - # [0j , 0j , 0j ], - # [0j , 0j , 0j ]]]) - - # use axes(2, 0) - print(paddle.fft.rfftn(x, axes=(2, 0))) - # Tensor(shape=[2, 3, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, - # [[[(8+0j), 0j , 0j ], - # [(8+0j), 0j , 0j ], - # [(8+0j), 0j , 0j ]], - # - # [[0j , 0j , 0j ], - # [0j , 0j , 0j ], - # [0j , 0j , 0j ]]]) - - """ - return fftn_r2c(x, s, axes, norm, forward=True, onesided=True, name=name) - - -def irfftn(x, s=None, axes=None, norm="backward", name=None): - """ - Computes the inverse of `rfftn`. - - This function computes the inverse of the N-D discrete - Fourier Transform for real input over any number of axes in an - M-D array by means of the Fast Fourier Transform (FFT). 
In - other words, ``irfftn(rfftn(x), x.shape) == x`` to within numerical - accuracy. (The ``a.shape`` is necessary like ``len(a)`` is for `irfft`, - and for the same reason.) - - The input should be ordered in the same way as is returned by `rfftn`, - i.e., as for `irfft` for the final transformation axis, and as for `ifftn` - along all the other axes. - - Args: - x (Tensor): The input data. It's a Tensor type. - s (sequence of ints, optional): The length of the output transform axis. - (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the - number of input points used along this axis, except for the last axis, - where ``s[-1]//2+1`` points of the input are used. Along any axis, if - the shape indicated by `s` is smaller than that of the input, the input - is cropped. If it is larger, the input is padded with zeros. - If `s` is not given, the shape of the input along the axes specified by axes - is used. Except for the last axis which is taken to be ``2*(k-1)`` where - ``k`` is the length of the input along that axis. - axes (sequence of ints, optional): Axes over which to compute the inverse FFT. If not given, the last - `len(s)` axes are used, or all axes if `s` is also not specified. - norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward". - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Real tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`, - or by a combination of `s` or `x`, as explained in the parameters section above. The length of - each transformed axis is as given by the corresponding element of `s`, or the length of the input - in every axis except for the last one if `s` is not given. In the final transformed axis the length - of the output when `s` is not given is ``2*(m-1)``, where ``m`` is the length of the final - transformed axis of the input. To get an odd number of output points in the final axis, - `s` must be specified. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = (np.array([2, 2, 3]) + 1j * np.array([2, 2, 3])).astype(np.complex128) - xp = paddle.to_tensor(x) - irfftn_xp = paddle.fft.irfftn(xp).numpy() - print(irfftn_xp) - # [ 2.25 -1.25 0.25 0.75] - - """ - return fftn_c2r(x, s, axes, norm, forward=False, name=name) - - -def hfftn(x, s=None, axes=None, norm="backward", name=None): - """ - Compute the N-D FFT of Hermitian symmetric complex input, i.e., a - signal with a real spectrum. - - This function calculates the n-D discrete Fourier transform of Hermite symmetric - complex input on any axis in M-D array by fast Fourier transform (FFT). - In other words, ``ihfftn(hfftn(x, s)) == x is within the numerical accuracy range. - (``s`` here are ``x.shape`` and ``s[-1] = x.shape[- 1] * 2 - 1``. This is necessary - for the same reason that ``irfft` requires ``x.shape``.) - - Args: - x (Tensor): The input data. It's a Tensor type. - s (sequence of ints, optional): The length of the output transform axis. - (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the - number of input points used along this axis, except for the last axis, - where ``s[-1]//2+1`` points of the input are used. 
Along any axis, if - the shape indicated by `s` is smaller than that of the input, the input - is cropped. If it is larger, the input is padded with zeros. - If `s` is not given, the shape of the input along the axes specified by axes - is used. Except for the last axis which is taken to be ``2*(k-1)`` where - ``k`` is the length of the input along that axis. - axes (sequence of ints, optional): Axes over which to compute the inverse FFT. If not given, the last - `len(s)` axes are used, or all axes if `s` is also not specified. - norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward". - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Real tensor. Truncate or zero fill input, transforming along the axis indicated by axis or - a combination of `s` or `X`. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = (np.array([2, 2, 3]) + 1j * np.array([2, 2, 3])).astype(np.complex128) - xp = paddle.to_tensor(x) - hfftn_xp = paddle.fft.hfftn(xp).numpy() - print(hfftn_xp) - # [ 9. 3. 1. -5.] - - - """ - return fftn_c2r(x, s, axes, norm, forward=True, name=name) - - -def ihfftn(x, s=None, axes=None, norm="backward", name=None): - """ - The n dimensional inverse FFT of a signal that has Hermitian symmetry. - - This function computes the n dimensional inverse FFT over any number of axes - in an M-dimensional of a signal that has Hermitian symmetry by means of an - efficient algorithm called the Fast Fourier Transform (FFT). - - Args: - x(Tensor): Input tensor. - s(Sequence[int], optional) : Shape (length along each transformed axis) - to use from the input. (``s[0]`` refers to axis 0, ``s[1]`` to axis - 1, etc.). Along any axis, if the given shape is smaller than that - of the input, the input is cropped. If it is larger, the input is - padded with zeros. if `s` is not given, the shape of the input - along the axes specified by `axes` is used. - axis(Sequence[int], optional) : Axis over which to compute the inverse FFT. If not - given, the last axis is used. - norm(str, optional) : Normalization mode, indicates which direction of - the forward/backward pair of transforms is scaled and with what - normalization factor. Include {"backward", "ortho", "forward"}, - default value is "backward". - name(str, optional): The default value is None. Normally there is no - need for user to set this property. For more information, please - refer to :ref:`api_guide_Name` . - - Returns: - out(Tensor) : complex tensor. - - Examples: - .. 
code-block:: python - import paddle - - spectrum = paddle.to_tensor([10.0, -5.0, 0.0, -1.0, 0.0, -5.0]) - print(paddle.fft.ifft(spectrum)) - # Tensor(shape=[6], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, - # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j), (2.3333334922790527+1.9868215517249155e-08j), (1+1.9868215517249155e-08j)]) - print(paddle.fft.ihfft(spectrum)) - # Tensor(shape = [4], dtype = complex64, place = CUDAPlace(0), stop_gradient = True, - # [(-0.1666666716337204+0j), (1-1.9868215517249155e-08j), (2.3333334922790527-1.9868215517249155e-08j), (3.5+0j)]) - - """ - return fftn_r2c(x, s, axes, norm, forward=False, onesided=True, name=name) - - -# public APIs 2d -def fft2(x, s=None, axes=(-2, -1), norm="backward", name=None): - """ - Compute the 2-D discrete Fourier Transform - - This function computes the N-D discrete Fourier Transform - over any axes in an M-D array by means of the - Fast Fourier Transform (FFT). By default, the transform is computed over - the last two axes of the input array, i.e., a 2-dimensional FFT. - - Args: - x (Tensor): The input data. It's a Tensor type. - s (sequence of ints, optional): Shape (length of each transformed axis) of the output. - It should be a sequence of 2 integers. This corresponds to ``n`` for ``fft(x, n)``. - Along each axis, if the given shape is smaller than that of the input, - the input is cropped. If it is larger, the input is padded with zeros. - if `s` is not given, the shape of the input along the axes specified - by `axes` is used. Default is None. - axes (sequence of ints, optional): Axes over which to compute the FFT. It should be a - sequence of 2 integers. If not specified, the last two axes are used by default. - norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward". - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Complex tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`, - or the last two axes if `axes` is not given. - - Raises: - ValueError: if `s` not be a sequence of 2 integers or None. - ValueError: if `axes` not be a sequence of 2 integers or None. - ValueError: If the input dimension is smaller than 2. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = np.mgrid[:2, :2][1] - xp = paddle.to_tensor(x) - fft2_xp = paddle.fft.fft2(xp).numpy() - print(fft2_xp) - # [[ 2.+0.j -2.+0.j] - # [ 0.+0.j 0.+0.j]] - - """ - _check_at_least_ndim(x, 2) - if s is not None: - if not isinstance(s, Sequence) or len(s) != 2: - raise ValueError( - "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". - format(s)) - if axes is not None: - if not isinstance(axes, Sequence) or len(axes) != 2: - raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". - format(axes)) - return fftn(x, s, axes, norm, name) - - -def ifft2(x, s=None, axes=(-2, -1), norm="backward", name=None): - """ - Compute the 2-D inverse discrete Fourier Transform. - - This function computes the inverse of the 2-D discrete Fourier - Transform over any number of axes in an M-D array by means of - the Fast Fourier Transform (FFT). 
In other words, ``ifft2(fft2(x)) == x`` - to within numerical accuracy. By default, the inverse transform is - computed over the last two axes of the input array. - - The input, analogously to `ifft`, should be ordered in the same way as is - returned by `fft2`, i.e., it should have the term for zero frequency - in the low-order corner of the two axes, the positive frequency terms in - the first half of these axes, the term for the Nyquist frequency in the - middle of the axes and the negative frequency terms in the second half of - both axes, in order of decreasingly negative frequency. - - Args: - x (Tensor): The input data. It's a Tensor type. - s (sequence of ints, optional): Shape (length of each transformed axis) of the output. - It should be a sequence of 2 integers. This corresponds to ``n`` for ``fft(x, n)``. - Along each axis, if the given shape is smaller than that of the input, - the input is cropped. If it is larger, the input is padded with zeros. - if `s` is not given, the shape of the input along the axes specified - by `axes` is used. Default is None. - axes (sequence of ints, optional): Axes over which to compute the FFT. It should be a - sequence of 2 integers. If not specified, the last two axes are used by default. - norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward". - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Complex tensor. The truncated or zero-padded input, transformed along the axes indicated by `axes`, - or the last two axes if `axes` is not given. - - Raises: - ValueError: if `s` not be a sequence of 2 integers or None. - ValueError: if `axes` not be a sequence of 2 integers or None. - ValueError: If the input dimension is smaller than 2. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = np.mgrid[:2, :2][1] - xp = paddle.to_tensor(x) - ifft2_xp = paddle.fft.ifft2(xp).numpy() - print(ifft2_xp) - # [[ 0.5+0.j -0.5+0.j] - # [ 0. +0.j 0. +0.j]] - """ - _check_at_least_ndim(x, 2) - if s is not None: - if not isinstance(s, Sequence) or len(s) != 2: - raise ValueError( - "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". - format(s)) - if axes is not None: - if not isinstance(axes, Sequence) or len(axes) != 2: - raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". - format(axes)) - return ifftn(x, s, axes, norm, name) - - -def rfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): - """ - The two dimensional FFT with real tensor input. - - This is really just `rfftn` with different default behavior. - For more details see `rfftn`. - - Args: - x(Tensor): Input tensor, taken to be real. - s(Sequence[int]) : Shape of the FFT. - axes(Sequence[int], optional): Axes over which to compute the FFT. - norm(str, optional) : {"backward", "ortho", "forward"}, - default is "backward". Indicates which direction of the - forward/backward pair of transforms is scaled and with what - normalization factor. - name(str, optional): The default value is None. Normally there is no - need for user to set this property. For more information, please - refer to :ref:`api_guide_Name` . - - Returns: - out(Tensor): The result of the real 2-D FFT. - - Raises: - - - Examples: - - .. 
code-block:: python - import paddle - import numpy as np - - x = paddle.to_tensor(np.mgrid[:5, :5][0].astype(np.float32)) - print(paddle.fft.rfft2(x)) - # Tensor(shape=[5, 3], dtype=complex64, place=CUDAPlace(0), stop_gradient=True, - # [[ (50+0j) , (1.1920928955078125e-07+0j) , 0j ], - # [(-12.5+17.204774856567383j) , (-9.644234211236835e-08+7.006946134424652e-08j) , 0j ], - # [(-12.500000953674316+4.061495304107666j) , (3.6837697336977726e-08-1.1337477445749755e-07j), 0j ], - # [(-12.500000953674316-4.061495304107666j) , (3.6837697336977726e-08+1.1337477445749755e-07j), 0j ], - # [(-12.5-17.204774856567383j) , (-9.644234211236835e-08-7.006946134424652e-08j) , 0j ]]) - """ - _check_at_least_ndim(x, 2) - if s is not None: - if not isinstance(s, Sequence) or len(s) != 2: - raise ValueError( - "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". - format(s)) - if axes is not None: - if not isinstance(axes, Sequence) or len(axes) != 2: - raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". - format(axes)) - return rfftn(x, s, axes, norm, name) - - -def irfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): - """ - Computes the inverse of `rfft2`. - - Args: - x (Tensor): The input data. It's a Tensor type. - s (sequence of ints, optional): Shape of the real output to the inverse FFT. Default is None. - axes (sequence of ints, optional): The axes over which to compute the inverse FFT. Axes - must be two-dimensional. If not specified, the last two axes are used by default. - norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward". - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name` . - - Returns: - Real tensor. The result of the inverse real 2-D FFT. - - Raises: - ValueError: if `s` not be a sequence of 2 integers or None. - ValueError: if `axes` not be a sequence of 2 integers or None. - ValueError: If the input dimension is smaller than 2. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = (np.array([[3,2,3],[2, 2, 3]]) + 1j * np.array([[3,2,3],[2, 2, 3]])).astype(np.complex128) - xp = paddle.to_tensor(x) - irfft2_xp = paddle.fft.irfft2(xp).numpy() - print(irfft2_xp) - # [[ 2.375 -1.125 0.375 0.875] - # [ 0.125 0.125 0.125 0.125]] - - """ - _check_at_least_ndim(x, 2) - if s is not None: - if not isinstance(s, Sequence) or len(s) != 2: - raise ValueError( - "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". - format(s)) - if axes is not None: - if not isinstance(axes, Sequence) or len(axes) != 2: - raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". - format(axes)) - return irfftn(x, s, axes, norm, name) - - -def hfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): - """ - Compute the 2-D FFT of a Hermitian complex array. - - Args: - x (Tensor): The input data. It's a Tensor type. - s (sequence of ints, optional): Shape of the real output. Default is None. - axes (sequence of ints, optional): Axes over which to compute the FFT. Axes must be - two-dimensional. If not specified, the last two axes are used by default. 
- norm (str): Indicates which direction to scale the `forward` or `backward` transform - pair and what normalization factor to use. The parameter value must be one - of "forward" or "backward" or "ortho". Default is "backward". - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Real tensor. The real result of the 2-D Hermitian complex real FFT. - - Raises: - ValueError: if `s` not be a sequence of 2 integers or None. - ValueError: if `axes` not be a sequence of 2 integers or None. - ValueError: If the input dimension is smaller than 2. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = (np.array([[3,2,3],[2, 2, 3]]) + 1j * np.array([[3,2,3],[2, 2, 3]])).astype(np.complex128) - xp = paddle.to_tensor(x) - hfft2_xp = paddle.fft.hfft2(xp).numpy() - print(hfft2_xp) - # [[19. 7. 3. -9.] - # [ 1. 1. 1. 1.]] - - - """ - _check_at_least_ndim(x, 2) - if s is not None: - if not isinstance(s, Sequence) or len(s) != 2: - raise ValueError( - "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". - format(s)) - if axes is not None: - if not isinstance(axes, Sequence) or len(axes) != 2: - raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". - format(axes)) - return hfftn(x, s, axes, norm, name) - - -def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): - """ - Compute the two dimensional inverse FFT of a real spectrum. - - This is really `ihfftn` with different defaults. - For more details see `ihfftn`. - - Args: - x(Tensor): Input tensor - s(Sequence[int], optional): Shape of the real input to the inverse FFT. - axes(Sequance[int], optional): The axes over which to compute the - inverse fft. Default is the last two axes. - norm(str, optional): {"backward", "ortho", "forward"}. Default is - "backward". - name(str, optional): The default value is None. Normally there is no - need for user to set this property. For more information, please - refer to :ref:`api_guide_Name` . - - Returns: - out(Tensor) : The result of the inverse hermitian 2-D FFT. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = np.mgrid[:5, :5][0].astype(np.float64) - xp = paddle.to_tensor(x) - ihfft2_xp = paddle.fft.ihfft2(xp).numpy() - print(ihfft2_xp) - # [[ 2. +0.j 0. +0.j 0. +0.j ] - # [-0.5-0.68819096j 0. +0.j 0. +0.j ] - # [-0.5-0.16245985j 0. +0.j 0. +0.j ] - # [-0.5+0.16245985j 0. +0.j 0. +0.j ] - # [-0.5+0.68819096j 0. +0.j 0. +0.j ]] - """ - _check_at_least_ndim(x, 2) - if s is not None: - if not isinstance(s, Sequence) or len(s) != 2: - raise ValueError( - "Invalid FFT argument s ({}), it should be a sequence of 2 integers.". - format(s)) - if axes is not None: - if not isinstance(axes, Sequence) or len(axes) != 2: - raise ValueError( - "Invalid FFT argument axes ({}), it should be a sequence of 2 integers.". - format(axes)) - return ihfftn(x, s, axes, norm, name) - - -# public APIs utilities -def fftfreq(n, d=1.0, dtype=None, name=None): - """ - Return the Discrete Fourier Transform sample frequencies. - - The returned float array `f` contains the frequency bin centers in cycles - per unit of the sample spacing (with zero at the start). For instance, if - the sample spacing is in seconds, then the frequency unit is cycles/second. 
- - Given input length `n` and a sample spacing `d`:: - - f = [0, 1, ..., n/2-1, -n/2, ..., -1] / (d*n) if n is even - f = [0, 1, ..., (n-1)/2, -(n-1)/2, ..., -1] / (d*n) if n is odd - - Args: - n (int): Dimension inputed. - d (scalar, optional): Sample spacing (inverse of the sampling rate). Defaults is 1. - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Tensor. A tensor of length 'n' containing the sampling frequency. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = np.array([3, 1, 2, 2, 3], dtype=float) - scalar_temp = 0.5 - n = x.size - fftfreq_xp = paddle.fft.fftfreq(n, d=scalar_temp) - print(fftfreq_xp) - - # Tensor(shape=[5], dtype=float32, place=CUDAPlace(0), stop_gradient=True, - # [ 0. , 0.40000001, 0.80000001, -0.80000001, -0.40000001]) - """ - - dtype = paddle.framework.get_default_dtype() - val = 1.0 / (n * d) - pos_max = (n + 1) // 2 - neg_max = n // 2 - indices = paddle.arange(-neg_max, pos_max, dtype=dtype, name=name) - indices = paddle.roll(indices, -neg_max, name=name) - return indices * val - - -def rfftfreq(n, d=1.0, dtype=None, name=None): - """ - Return the Discrete Fourier Transform sample frequencies. - - The returned floating-point array "F" contains the center of the frequency unit, - and the unit is the number of cycles of the sampling interval (the starting point is zero). - - Given input length `n` and a sample spacing `d`:: - - f = [0, 1, ..., n/2-1, n/2] / (d*n) if n is even - f = [0, 1, ..., (n-1)/2-1, (n-1)/2] / (d*n) if n is odd - - the Nyquist frequency component is considered to be positive. - - Args: - n (int): Dimension inputed. - d (scalar, optional): Sample spacing (inverse of the sampling rate). Defaults is 1. - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Tensor. A tensor of length ``n//2 + 1`` containing the sample frequencies. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = np.array([3, 1, 2, 2, 3], dtype=float) - scalar_temp = 0.3 - n = x.size - rfftfreq_xp = paddle.fft.rfftfreq(n, d=scalar_temp) - print(rfftfreq_xp) - - # Tensor(shape=[3], dtype=float32, place=CUDAPlace(0), stop_gradient=True, - # [0. , 0.66666669, 1.33333337]) - - """ - - dtype = paddle.framework.get_default_dtype() - val = 1.0 / (n * d) - pos_max = 1 + n // 2 - indices = paddle.arange(0, pos_max, dtype=dtype, name=name) - return indices * val - - -def fftshift(x, axes=None, name=None): - """ - Shift the zero-frequency component to the center of the spectrum. - - This function swaps half spaces for all the axes listed (all by default). - Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even. - - Args: - n (int): Dimension inputed. - axes (int|tuple, optional): The axis on which to move. The default is none, which moves all axes. - Default is None. - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Tensor. The shifted tensor. - - Examples: - - .. 
code-block:: python - - import numpy as np - import paddle - - x = np.array([3, 1, 2, 2, 3], dtype=float) - n = x.size - fftfreq_xp = paddle.fft.fftfreq(n, d=0.3) - res = paddle.fft.fftshift(fftfreq_xp).numpy() - print(res) - # [-1.3333334 -0.6666667 0. 0.6666667 1.3333334] - - """ - shape = paddle.shape(x) - if axes is None: - # shift all axes - rank = paddle.rank(x).reshape([1]) - axes = axes or paddle.arange(0, rank) - shifts = [size // 2 for size in shape] - elif isinstance(axes, int): - shifts = shape[axes] // 2 - else: - shifts = [shape[ax] // 2 for ax in axes] - return paddle.roll(x, shifts, axes, name=name) - - -def ifftshift(x, axes=None, name=None): - """ - The inverse of `fftshift`. Although the even length 'x' is the same, the function of the - odd length 'x' is different. An example. - - Args: - n (int): Dimension inputed. - axes (int|tuple, optional): The axis on which to move. The default is none, which moves all axes. - Default is None. - name (str, optional): The default value is None. Normally there is no need for user to set - this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Tensor. The shifted tensor. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - x = np.array([3, 1, 2, 2, 3], dtype=float) - n = x.size - fftfreq_xp = paddle.fft.fftfreq(n, d=0.3) - res = paddle.fft.ifftshift(fftfreq_xp).numpy() - print(res) - # [ 1.3333334 -1.3333334 -0.6666667 0. 0.6666667] - - """ - shape = paddle.shape(x) - if axes is None: - # shift all axes - rank = paddle.rank(x).reshape([1]) - axes = axes or paddle.arange(0, rank) - shifts = [-size // 2 for size in shape] - elif isinstance(axes, int): - shifts = -shape[axes] // 2 - else: - shifts = [-shape[ax] // 2 for ax in axes] - return paddle.roll(x, shifts, axes, name=name) - - -# internal functions -def fft_c2c(x, n, axis, norm, forward, name): - if is_interger(x): - x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) - elif is_floating_point(x): - x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) - _check_normalization(norm) - - axis = axis if axis is not None else -1 - _check_fft_axis(x, axis) - axes = [axis] - axes = _normalize_axes(x, axes) - if n is not None: - _check_fft_n(n) - s = [n] - x = _resize_fft_input(x, s, axes) - op_type = 'fft_c2c' - - check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) - if in_dygraph_mode(): - attrs = ('axes', axes, 'normalization', norm, 'forward', forward) - out = getattr(_C_ops, op_type)(x, *attrs) - else: - inputs = {'X': [x], } - attrs = {'axes': axes, 'normalization': norm, 'forward': forward} - helper = LayerHelper(op_type, **locals()) - dtype = helper.input_dtype(input_param_name='x') - out = helper.create_variable_for_type_inference(dtype) - outputs = {"Out": [out]} - helper.append_op( - type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) - return out - - -def fft_r2c(x, n, axis, norm, forward, onesided, name): - if is_interger(x): - x = paddle.cast(x, paddle.get_default_dtype()) - _check_normalization(norm) - axis = axis if axis is not None else -1 - _check_fft_axis(x, axis) - axes = [axis] - axes = _normalize_axes(x, axes) - if n is not None: - _check_fft_n(n) - s = [n] - x = _resize_fft_input(x, s, axes) - op_type = 'fft_r2c' - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type) - - if in_dygraph_mode(): - attrs = ('axes', axes, 'normalization', norm, 'forward', forward, - 'onesided', onesided) - out = getattr(_C_ops, op_type)(x, *attrs) - else: 
- inputs = {'X': [x], } - attrs = { - 'axes': axes, - 'normalization': norm, - 'forward': forward, - 'onesided': onesided, - } - helper = LayerHelper(op_type, **locals()) - dtype = helper.input_dtype(input_param_name='x') - out = helper.create_variable_for_type_inference( - _real_to_complex_dtype(dtype)) - outputs = {"Out": [out]} - helper.append_op( - type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) - return out - - -def fft_c2r(x, n, axis, norm, forward, name): - if is_interger(x): - x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) - elif is_floating_point(x): - x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) - _check_normalization(norm) - axis = axis if axis is not None else -1 - _check_fft_axis(x, axis) - axes = [axis] - axes = _normalize_axes(x, axes) - if n is not None: - _check_fft_n(n) - s = [n // 2 + 1] - x = _resize_fft_input(x, s, axes) - op_type = 'fft_c2r' - check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) - - if in_dygraph_mode(): - if n is not None: - attrs = ('axes', axes, 'normalization', norm, 'forward', forward, - 'last_dim_size', n) - else: - attrs = ('axes', axes, 'normalization', norm, 'forward', forward) - out = getattr(_C_ops, op_type)(x, *attrs) - else: - inputs = {'X': [x], } - attrs = {'axes': axes, 'normalization': norm, 'forward': forward} - if n is not None: - attrs['last_dim_size'] = n - helper = LayerHelper(op_type, **locals()) - dtype = helper.input_dtype(input_param_name='x') - out = helper.create_variable_for_type_inference( - _complex_to_real_dtype(dtype)) - outputs = {"Out": [out]} - helper.append_op( - type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) - return out - - -def fftn_c2c(x, s, axes, norm, forward, name): - if is_interger(x): - x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) - elif is_floating_point(x): - x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) - _check_normalization(norm) - if s is not None: - _check_fft_shape(x, s) - - rank = x.ndim - if axes is None: - if s is None: - axes = list(range(rank)) - else: - fft_ndims = len(s) - axes = list(range(rank - fft_ndims, rank)) - else: - _check_fft_axes(x, axes) - axes = _normalize_axes(x, axes) - axes_argsoft = np.argsort(axes).tolist() - axes = [axes[i] for i in axes_argsoft] - if s is not None: - if len(s) != len(axes): - raise ValueError( - "Length of s ({}) and length of axes ({}) does not match.". 
- format(len(s), len(axes))) - s = [s[i] for i in axes_argsoft] - - if s is not None: - x = _resize_fft_input(x, s, axes) - op_type = 'fft_c2c' - check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) - - if in_dygraph_mode(): - attrs = ('axes', axes, 'normalization', norm, 'forward', forward) - out = getattr(_C_ops, op_type)(x, *attrs) - else: - inputs = {'X': [x], } - attrs = {'axes': axes, 'normalization': norm, 'forward': forward} - helper = LayerHelper(op_type, **locals()) - dtype = helper.input_dtype(input_param_name='x') - out = helper.create_variable_for_type_inference(dtype) - outputs = {"Out": [out]} - helper.append_op( - type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) - return out - - -def fftn_r2c(x, s, axes, norm, forward, onesided, name): - if is_interger(x): - x = paddle.cast(x, paddle.get_default_dtype()) - _check_normalization(norm) - if s is not None: - _check_fft_shape(x, s) - - rank = x.ndim - if axes is None: - if s is None: - axes = list(range(rank)) - else: - fft_ndims = len(s) - axes = list(range(rank - fft_ndims, rank)) - else: - _check_fft_axes(x, axes) - axes = _normalize_axes(x, axes) - axes_argsoft = np.argsort(axes[:-1]).tolist() - axes = [axes[i] for i in axes_argsoft] + [axes[-1]] - if s is not None: - if len(s) != len(axes): - raise ValueError( - "Length of s ({}) and length of axes ({}) does not match.". - format(len(s), len(axes))) - s = [s[i] for i in axes_argsoft] + [s[-1]] - - if s is not None: - x = _resize_fft_input(x, s, axes) - - op_type = 'fft_r2c' - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], op_type) - - if in_dygraph_mode(): - attrs = ('axes', axes, 'normalization', norm, 'forward', forward, - 'onesided', onesided) - out = getattr(_C_ops, op_type)(x, *attrs) - else: - inputs = {'X': [x], } - attrs = { - 'axes': axes, - 'normalization': norm, - 'forward': forward, - 'onesided': onesided, - } - helper = LayerHelper(op_type, **locals()) - dtype = helper.input_dtype(input_param_name='x') - out = helper.create_variable_for_type_inference( - _real_to_complex_dtype(dtype)) - outputs = {"Out": [out]} - helper.append_op( - type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) - - return out - - -def fftn_c2r(x, s, axes, norm, forward, name): - if is_interger(x): - x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) - elif is_floating_point(x): - x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) - _check_normalization(norm) - if s is not None: - _check_fft_shape(x, s) - - rank = x.ndim - if axes is None: - if s is None: - axes = list(range(rank)) - else: - fft_ndims = len(s) - axes = list(range(rank - fft_ndims, rank)) - else: - _check_fft_axes(x, axes) - axes = _normalize_axes(x, axes) - axes_argsoft = np.argsort(axes[:-1]).tolist() - axes = [axes[i] for i in axes_argsoft] + [axes[-1]] - if s is not None: - if len(s) != len(axes): - raise ValueError( - "Length of s ({}) and length of axes ({}) does not match.". 
- format(len(s), len(axes))) - s = [s[i] for i in axes_argsoft] + [s[-1]] - - if s is not None: - fft_input_shape = list(s) - fft_input_shape[-1] = fft_input_shape[-1] // 2 + 1 - x = _resize_fft_input(x, fft_input_shape, axes) - - op_type = 'fft_c2r' - check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], op_type) - - if in_dygraph_mode(): - if s: - attrs = ('axes', axes, 'normalization', norm, 'forward', forward, - 'last_dim_size', s[-1]) - else: - attrs = ('axes', axes, 'normalization', norm, 'forward', forward) - out = getattr(_C_ops, op_type)(x, *attrs) - else: - inputs = {'X': [x], } - attrs = {'axes': axes, 'normalization': norm, 'forward': forward} - if s: - attrs["last_dim_size"] = s[-1] - helper = LayerHelper(op_type, **locals()) - dtype = helper.input_dtype(input_param_name='x') - out = helper.create_variable_for_type_inference( - _complex_to_real_dtype(dtype)) - outputs = {"Out": [out]} - helper.append_op( - type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) - return out diff --git a/python/paddle/tensor/signal.py b/python/paddle/tensor/signal.py deleted file mode 100644 index d5cb703dd09d7..0000000000000 --- a/python/paddle/tensor/signal.py +++ /dev/null @@ -1,571 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import paddle - -from .attribute import is_complex, is_floating_point -from .fft import fft_r2c, fft_c2r, fft_c2c -from ..fluid.data_feeder import check_variable_and_dtype -from ..fluid.framework import in_dygraph_mode -from ..fluid.layer_helper import LayerHelper -from .. import _C_ops - -__all__ = [] - - -def frame(x, frame_length, hop_length, axis=-1, name=None): - """ - Slice the N-dimensional (where N >= 1) input into (overlapping) frames. - - Args: - x (Tensor): The input data which is a N-dimensional (where N >= 1) Tensor - with shape `[..., seq_length]` or `[seq_length, ...]`. - frame_length (int): Length of the frame and `0 < frame_length <= x.shape[axis]`. - hop_length (int): Number of steps to advance between adjacent frames - and `0 < hop_length`. - axis (int, optional): Specify the axis to operate on the input Tensors. Its - value should be 0(the first dimension) or -1(the last dimension). If not - specified, the last axis is used by default. - - Returns: - The output frames tensor with shape `[..., frame_length, num_frames]` if `axis==-1`, - otherwise `[num_frames, frame_length, ...]` where - - `num_framse = 1 + (x.shape[axis] - frame_length) // hop_length` - - Examples: - - .. 
code-block:: python - - import paddle - from paddle.tensor.signal import frame - - # 1D - x = paddle.arange(8) - y0 = frame(x, frame_length=4, hop_length=2, axis=-1) # [4, 3] - # [[0, 2, 4], - # [1, 3, 5], - # [2, 4, 6], - # [3, 5, 7]] - - y1 = frame(x, frame_length=4, hop_length=2, axis=0) # [3, 4] - # [[0, 1, 2, 3], - # [2, 3, 4, 5], - # [4, 5, 6, 7]] - - # 2D - x0 = paddle.arange(16).reshape([2, 8]) - y0 = frame(x0, frame_length=4, hop_length=2, axis=-1) # [2, 4, 3] - # [[[0, 2, 4], - # [1, 3, 5], - # [2, 4, 6], - # [3, 5, 7]], - # - # [[8 , 10, 12], - # [9 , 11, 13], - # [10, 12, 14], - # [11, 13, 15]]] - - x1 = paddle.arange(16).reshape([8, 2]) - y1 = frame(x1, frame_length=4, hop_length=2, axis=0) # [3, 4, 2] - # [[[0 , 1 ], - # [2 , 3 ], - # [4 , 5 ], - # [6 , 7 ]], - # - # [4 , 5 ], - # [6 , 7 ], - # [8 , 9 ], - # [10, 11]], - # - # [8 , 9 ], - # [10, 11], - # [12, 13], - # [14, 15]]] - - # > 2D - x0 = paddle.arange(32).reshape([2, 2, 8]) - y0 = frame(x0, frame_length=4, hop_length=2, axis=-1) # [2, 2, 4, 3] - - x1 = paddle.arange(32).reshape([8, 2, 2]) - y1 = frame(x1, frame_length=4, hop_length=2, axis=0) # [3, 4, 2, 2] - """ - if axis not in [0, -1]: - raise ValueError(f'Unexpected axis: {axis}. It should be 0 or -1.') - - if not isinstance(frame_length, int) or frame_length <= 0: - raise ValueError( - f'Unexpected frame_length: {frame_length}. It should be an positive integer.' - ) - - if not isinstance(hop_length, int) or hop_length <= 0: - raise ValueError( - f'Unexpected hop_length: {hop_length}. It should be an positive integer.' - ) - - if frame_length > x.shape[axis]: - raise ValueError( - f'Attribute frame_length should be less equal than sequence length, ' - f'but got ({frame_length}) > ({x.shape[axis]}).') - - op_type = 'frame' - - if in_dygraph_mode(): - attrs = ('frame_length', frame_length, 'hop_length', hop_length, 'axis', - axis) - op = getattr(_C_ops, op_type) - out = op(x, *attrs) - else: - check_variable_and_dtype( - x, 'x', ['int32', 'int64', 'float16', 'float32', - 'float64'], op_type) - helper = LayerHelper(op_type, **locals()) - dtype = helper.input_dtype(input_param_name='x') - out = helper.create_variable_for_type_inference(dtype=dtype) - helper.append_op( - type=op_type, - inputs={'X': x}, - attrs={ - 'frame_length': frame_length, - 'hop_length': hop_length, - 'axis': axis - }, - outputs={'Out': out}) - return out - - -def overlap_add(x, hop_length, axis=-1, name=None): - """ - Reconstructs a tensor consisted of overlap added sequences from input frames. - - Args: - x (Tensor): The input data which is a N-dimensional (where N >= 2) Tensor - with shape `[..., frame_length, num_frames]` or - `[num_frames, frame_length ...]`. - hop_length (int): Number of steps to advance between adjacent frames and - `0 < hop_length <= frame_length`. - axis (int, optional): Specify the axis to operate on the input Tensors. Its - value should be 0(the first dimension) or -1(the last dimension). If not - specified, the last axis is used by default. - - Returns: - The output frames tensor with shape `[..., seq_length]` if `axis==-1`, - otherwise `[seq_length, ...]` where - - `seq_length = (n_frames - 1) * hop_length + frame_length` - - Examples: - - .. 
code-block:: python - - import paddle - from paddle.tensor.signal import overlap_add - - # 2D - x0 = paddle.arange(16).reshape([8, 2]) - # [[0 , 1 ], - # [2 , 3 ], - # [4 , 5 ], - # [6 , 7 ], - # [8 , 9 ], - # [10, 11], - # [12, 13], - # [14, 15]] - y0 = overlap_add(x0, hop_length=2, axis=-1) # [10] - # [0 , 2 , 5 , 9 , 13, 17, 21, 25, 13, 15] - - x1 = paddle.arange(16).reshape([2, 8]) - # [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ], - # [8 , 9 , 10, 11, 12, 13, 14, 15]] - y1 = overlap_add(x1, hop_length=2, axis=0) # [10] - # [0 , 1 , 10, 12, 14, 16, 18, 20, 14, 15] - - # > 2D - x0 = paddle.arange(32).reshape([2, 1, 8, 2]) - y0 = overlap_add(x0, hop_length=2, axis=-1) # [2, 1, 10] - - x1 = paddle.arange(32).reshape([2, 8, 1, 2]) - y1 = overlap_add(x1, hop_length=2, axis=0) # [10, 1, 2] - """ - if axis not in [0, -1]: - raise ValueError(f'Unexpected axis: {axis}. It should be 0 or -1.') - - if not isinstance(hop_length, int) or hop_length <= 0: - raise ValueError( - f'Unexpected hop_length: {hop_length}. It should be an positive integer.' - ) - - op_type = 'overlap_add' - - if in_dygraph_mode(): - attrs = ('hop_length', hop_length, 'axis', axis) - op = getattr(_C_ops, op_type) - out = op(x, *attrs) - else: - check_variable_and_dtype( - x, 'x', ['int32', 'int64', 'float16', 'float32', - 'float64'], op_type) - helper = LayerHelper(op_type, **locals()) - dtype = helper.input_dtype(input_param_name='x') - out = helper.create_variable_for_type_inference(dtype=dtype) - helper.append_op( - type=op_type, - inputs={'X': x}, - attrs={'hop_length': hop_length, - 'axis': axis}, - outputs={'Out': out}) - return out - - -def stft(x, - n_fft, - hop_length=None, - win_length=None, - window=None, - center=True, - pad_mode='reflect', - normalized=False, - onesided=True, - name=None): - """ - Short-time Fourier transform (STFT). - - The STFT computes the discrete Fourier transforms (DFT) of short overlapping - windows of the input using this formula: - - .. math:: - X_t[\omega] = \sum_{n = 0}^{N-1}% - \text{window}[n]\ x[t \times H + n]\ % - e^{-{2 \pi j \omega n}/{N}} - - Where: - - :math:`t`: The :math:`t`-th input window. - - :math:`\omega`: Frequency :math:`0 \leq \omega < \text{n\_fft}` for `onesided=False`, - or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for `onesided=True`. - - :math:`N`: Value of `n_fft`. - - :math:`H`: Value of `hop_length`. - - Args: - x (Tensor): The input data which is a 1-dimensional or 2-dimensional Tensor with - shape `[..., seq_length]`. It can be a real-valued or a complex Tensor. - n_fft (int): The number of input samples to perform Fourier transform. - hop_length (int, optional): Number of steps to advance between adjacent windows - and `0 < hop_length`. Default: `None`(treated as equal to `n_fft//4`) - win_length (int, optional): The size of window. Default: `None`(treated as equal - to `n_fft`) - window (Tensor, optional): A 1-dimensional tensor of size `win_length`. It will - be center padded to length `n_fft` if `win_length < n_fft`. Default: `None`( - treated as a rectangle window with value equal to 1 of size `win_length`). - center (bool, optional): Whether to pad `x` to make that the - :math:`t \times hop\_length` at the center of :math:`t`-th frame. Default: `True`. - pad_mode (str, optional): Choose padding pattern when `center` is `True`. See - `paddle.nn.functional.pad` for all padding options. Default: `"reflect"` - normalized (bool, optional): Control whether to scale the output by `1/sqrt(n_fft)`. 
- Default: `False` - onesided (bool, optional): Control whether to return half of the Fourier transform - output that satisfies the conjugate symmetry condition when input is a real-valued - tensor. It can not be `True` if input is a complex tensor. Default: `True` - name (str, optional): The default value is None. Normally there is no need for user - to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - The complex STFT output tensor with shape `[..., n_fft//2 + 1, num_frames]`( - real-valued input and `onesided` is `True`) or `[..., n_fft, num_frames]`( - `onesided` is `False`) - - Exampels: - .. code-block:: python - - import paddle - from paddle.signal import stft - - # real-valued input - x = paddle.randn([8, 48000], dtype=paddle.float64) - y1 = stft(x, n_fft=512) # [8, 257, 376] - y2 = stft(x, n_fft=512, onesided=False) # [8, 512, 376] - - # complex input - x = paddle.randn([8, 48000], dtype=paddle.float64) + \ - paddle.randn([8, 48000], dtype=paddle.float64)*1j # [8, 48000] complex128 - y1 = stft(x, n_fft=512, center=False, onesided=False) # [8, 512, 372] - """ - check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'complex64', 'complex128'], - 'stft') - - x_rank = len(x.shape) - assert x_rank in [1, 2], \ - f'x should be a 1D or 2D real tensor, but got rank of x is {x_rank}' - - if x_rank == 1: # (batch, seq_length) - x = x.unsqueeze(0) - - if hop_length is None: - hop_length = int(n_fft // 4) - - assert hop_length > 0, \ - f'hop_length should be > 0, but got {hop_length}.' - - if win_length is None: - win_length = n_fft - - assert 0 < n_fft <= x.shape[-1], \ - f'n_fft should be in (0, seq_length({x.shape[-1]})], but got {n_fft}.' - - assert 0 < win_length <= n_fft, \ - f'win_length should be in (0, n_fft({n_fft})], but got {win_length}.' - - if window is not None: - assert len(window.shape) == 1 and len(window) == win_length, \ - f'expected a 1D window tensor of size equal to win_length({win_length}), but got window with shape {window.shape}.' - else: - window = paddle.ones(shape=(win_length, ), dtype=x.dtype) - - if win_length < n_fft: - pad_left = (n_fft - win_length) // 2 - pad_right = n_fft - win_length - pad_left - window = paddle.nn.functional.pad(window, - pad=[pad_left, pad_right], - mode='constant') - - if center: - assert pad_mode in ['constant', 'reflect'], \ - 'pad_mode should be "reflect" or "constant", but got "{}".'.format(pad_mode) - - pad_length = n_fft // 2 - # FIXME: Input `x` can be a complex tensor but pad does not supprt complex input. - x = paddle.nn.functional.pad(x.unsqueeze(-1), - pad=[pad_length, pad_length], - mode=pad_mode, - data_format="NLC").squeeze(-1) - - x_frames = frame(x=x, frame_length=n_fft, hop_length=hop_length, axis=-1) - x_frames = x_frames.transpose( - perm=[0, 2, - 1]) # switch n_fft to last dim, egs: (batch, num_frames, n_fft) - x_frames = x_frames * window - - norm = 'ortho' if normalized else 'backward' - if is_complex(x_frames): - assert not onesided, \ - 'onesided should be False when input or window is a complex Tensor.' 
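
As a cross-check of the windowed-DFT formula given in the docstring above, a minimal NumPy sketch of the same frame-then-transform pipeline; this is illustrative only: `n_fft`, `hop_length` and the rectangular window are assumed, `center` padding is omitted, and `np.fft.rfft` stands in for the one-sided transform.

    import numpy as np

    def stft_reference(x, n_fft=512, hop_length=128, window=None):
        # Illustrative reference for a 1-D real signal; rectangular window when none is given.
        if window is None:
            window = np.ones(n_fft)
        num_frames = 1 + (len(x) - n_fft) // hop_length
        # Frame t covers x[t * hop_length : t * hop_length + n_fft].
        frames = np.stack([x[t * hop_length:t * hop_length + n_fft]
                           for t in range(num_frames)])
        # One-sided DFT of each windowed frame, i.e.
        # X_t[w] = sum_n window[n] * x[t * H + n] * exp(-2j * pi * w * n / N)
        return np.fft.rfft(frames * window, axis=-1).T  # [n_fft // 2 + 1, num_frames]

    x = np.random.randn(48000)
    print(stft_reference(x).shape)  # (257, 372)

The shape should agree with `paddle.signal.stft(paddle.to_tensor(x), n_fft=512, center=False)` for a real-valued input with the default one-sided output, up to numerical precision.
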
- - if not is_complex(x): - out = fft_r2c( - x=x_frames, - n=None, - axis=-1, - norm=norm, - forward=True, - onesided=onesided, - name=name) - else: - out = fft_c2c( - x=x_frames, n=None, axis=-1, norm=norm, forward=True, name=name) - - out = out.transpose(perm=[0, 2, 1]) # (batch, n_fft, num_frames) - - if x_rank == 1: - out.squeeze_(0) - - return out - - -def istft(x, - n_fft, - hop_length=None, - win_length=None, - window=None, - center=True, - normalized=False, - onesided=True, - length=None, - return_complex=False, - name=None): - """ - Inverse short-time Fourier transform (ISTFT). - - Reconstruct time-domain signal from the giving complex input and window tensor when - nonzero overlap-add (NOLA) condition is met: - - .. math:: - \sum_{t = -\infty}^{\infty}% - \text{window}^2[n - t \times H]\ \neq \ 0, \ \text{for } all \ n - - Where: - - :math:`t`: The :math:`t`-th input window. - - :math:`N`: Value of `n_fft`. - - :math:`H`: Value of `hop_length`. - - Result of `istft` expected to be the inverse of `paddle.tensor.signal.stft`, but it is - not guaranteed to reconstruct a exactly realizible time-domain signal from a STFT - complex tensor which has been modified (via masking or otherwise). Therefore, `istft` - gives the [Griffin-Lim optimal estimate](https://ieeexplore.ieee.org/document/1164317) - (optimal in a least-squares sense) for the corresponding signal. - - Args: - x (Tensor): The input data which is a 2-dimensional or 3-dimensional **complesx** - Tensor with shape `[..., n_fft, num_frames]`. - n_fft (int): The size of Fourier transform. - hop_length (int, optional): Number of steps to advance between adjacent windows - from time-domain signal and `0 < hop_length < win_length`. Default: `None`( - treated as equal to `n_fft//4`) - win_length (int, optional): The size of window. Default: `None`(treated as equal - to `n_fft`) - window (Tensor, optional): A 1-dimensional tensor of size `win_length`. It will - be center padded to length `n_fft` if `win_length < n_fft`. It should be a - real-valued tensor if `return_complex` is False. Default: `None`(treated as - a rectangle window with value equal to 1 of size `win_length`). - center (bool, optional): It means that whether the time-domain signal has been - center padded. Default: `True`. - normalized (bool, optional): Control whether to scale the output by `1/sqrt(n_fft)`. - Default: `False` - onesided (bool, optional): It means that whether the input STFT tensor is a half - of the conjugate symmetry STFT tensor transformed from a real-valued signal - and `istft` will return a real-valued tensor when it is set to `True`. - Default: `True`. - length (int, optional): Specify the length of time-domain signal. Default: `None`( - treated as the whole length of signal). - return_complex (bool, optional): It means that whether the time-domain signal is - real-valued. If `return_complex` is set to `True`, `onesided` should be set to - `False` cause the output is complex. - name (str, optional): The default value is None. Normally there is no need for user - to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - A tensor of least squares estimation of the reconstructed signal(s) with shape - `[..., seq_length]` - - Exampels: - .. 
code-block:: python - - import numpy as np - import paddle - from paddle.signal import stft, istft - - paddle.seed(0) - - # STFT - x = paddle.randn([8, 48000], dtype=paddle.float64) - y = stft(x, n_fft=512) # [8, 257, 376] - - # ISTFT - x_ = istft(y, n_fft=512) # [8, 48000] - - np.allclose(x, x_) # True - """ - check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'istft') - - x_rank = len(x.shape) - assert x_rank in [2, 3], \ - 'x should be a 2D or 3D complex tensor, but got rank of x is {}'.format(x_rank) - - if x_rank == 2: # (batch, n_fft, n_frames) - x = x.unsqueeze(0) - - if hop_length is None: - hop_length = int(n_fft // 4) - - if win_length is None: - win_length = n_fft - - # Assure no gaps between frames. - assert 0 < hop_length <= win_length, \ - 'hop_length should be in (0, win_length({})], but got {}.'.format(win_length, hop_length) - - assert 0 < win_length <= n_fft, \ - 'win_length should be in (0, n_fft({})], but got {}.'.format(n_fft, win_length) - - n_frames = x.shape[-1] - fft_size = x.shape[-2] - - if onesided: - assert (fft_size == n_fft // 2 + 1), \ - 'fft_size should be equal to n_fft // 2 + 1({}) when onesided is True, but got {}.'.format(n_fft // 2 + 1, fft_size) - else: - assert (fft_size == n_fft), \ - 'fft_size should be equal to n_fft({}) when onesided is False, but got {}.'.format(n_fft, fft_size) - - if window is not None: - assert len(window.shape) == 1 and len(window) == win_length, \ - 'expected a 1D window tensor of size equal to win_length({}), but got window with shape {}.'.format(win_length, window.shape) - else: - window = paddle.ones(shape=(win_length, )) - - if win_length < n_fft: - pad_left = (n_fft - win_length) // 2 - pad_right = n_fft - win_length - pad_left - # FIXME: Input `window` can be a complex tensor but pad does not supprt complex input. - window = paddle.nn.functional.pad(window, - pad=[pad_left, pad_right], - mode='constant') - - x = x.transpose( - perm=[0, 2, - 1]) # switch n_fft to last dim, egs: (batch, num_frames, n_fft) - norm = 'ortho' if normalized else 'backward' - - if return_complex: - assert not onesided, \ - 'onesided should be False when input(output of istft) or window is a complex Tensor.' - - out = fft_c2c(x=x, n=None, axis=-1, norm=norm, forward=False, name=None) - else: - assert not is_complex(window), \ - 'Data type of window should not be complex when return_complex is False.' - - if onesided is False: - x = x[:, :, :n_fft // 2 + 1] - out = fft_c2r(x=x, n=None, axis=-1, norm=norm, forward=False, name=None) - - out = overlap_add( - x=(out * window).transpose( - perm=[0, 2, 1]), # (batch, n_fft, num_frames) - hop_length=hop_length, - axis=-1) # (batch, seq_length) - - window_envelop = overlap_add( - x=paddle.tile( - x=window * window, repeat_times=[n_frames, 1]).transpose( - perm=[1, 0]), # (n_fft, num_frames) - hop_length=hop_length, - axis=-1) # (seq_length, ) - - if length is None: - if center: - out = out[:, (n_fft // 2):-(n_fft // 2)] - window_envelop = window_envelop[(n_fft // 2):-(n_fft // 2)] - else: - if center: - start = n_fft // 2 - else: - start = 0 - - out = out[:, start:start + length] - window_envelop = window_envelop[start:start + length] - - # Check whether the Nonzero Overlap Add (NOLA) constraint is met. - if window_envelop.abs().min().item() < 1e-11: - raise ValueError( - 'Abort istft because Nonzero Overlap Add (NOLA) condition failed. 
For more information about NOLA constraint please see `scipy.signal.check_NOLA`(https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.check_NOLA.html).' - ) - - out = out / window_envelop - - if x_rank == 2: - out.squeeze_(0) - - return out From fe351708bebc5cbafb63998c28d7affd9677c525 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Wed, 20 Oct 2021 15:28:38 +0800 Subject: [PATCH 09/16] fix typos in signal.py (#3) * move signal apis * move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py * fix typos --- python/paddle/signal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/signal.py b/python/paddle/signal.py index d550fa677d929..fc80c7cbc80f3 100644 --- a/python/paddle/signal.py +++ b/python/paddle/signal.py @@ -289,7 +289,7 @@ def stft(x, real-valued input and `onesided` is `True`) or `[..., n_fft, num_frames]`( `onesided` is `False`) - Exampels: + Examples: .. code-block:: python import paddle @@ -452,7 +452,7 @@ def istft(x, A tensor of least squares estimation of the reconstructed signal(s) with shape `[..., seq_length]` - Exampels: + Examples: .. code-block:: python import numpy as np From 7b1c8aef8cf73dcf06be63e9350a1bd2f1367d92 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Tue, 26 Oct 2021 21:15:55 +0800 Subject: [PATCH 10/16] disable Cache when CUFFT_VERSION >= 10200 (#4) * move signal apis * move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py * fix typos * Add LRUCache for fft plans --- paddle/fluid/operators/spectral_op.cu | 76 +++++++++++++++++---------- 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/spectral_op.cu b/paddle/fluid/operators/spectral_op.cu index e8a4fac2915d7..38f001ad3daae 100644 --- a/paddle/fluid/operators/spectral_op.cu +++ b/paddle/fluid/operators/spectral_op.cu @@ -68,9 +68,8 @@ void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out, } #if defined(PADDLE_WITH_CUDA) -CuFFTConfig create_cufft_config(const framework::Tensor& input, - const framework::Tensor& output, - int signal_ndim) { +FFTConfigKey create_fft_configkey(const framework::Tensor& input, + const framework::Tensor& output, int signal_ndim) { // Create the transform plan (either from cache or locally) const auto value_type = framework::IsComplexType(input.type()) ? framework::ToRealType(input.type()) @@ -85,11 +84,10 @@ CuFFTConfig create_cufft_config(const framework::Tensor& input, auto out_size = output.dims()[i]; signal_size[i] = std::max(in_size, out_size); } - PlanKey key(framework::vectorize(input.dims()), - framework::vectorize(output.dims()), signal_size, fft_type, - value_type); - - return CuFFTConfig(key); + FFTConfigKey key(framework::vectorize(input.dims()), + framework::vectorize(output.dims()), signal_size, fft_type, + value_type); + return key; } // Execute a pre-planned transform @@ -136,9 +134,8 @@ void exec_cufft_plan(const DeviceContext& ctx, const CuFFTConfig& config, #elif defined(PADDLE_WITH_HIP) -HIPFFTConfig create_hipfft_config(const framework::Tensor& input, - const framework::Tensor& output, - int signal_ndim) { +FFTConfigKey create_fft_configkey(const framework::Tensor& input, + const framework::Tensor& output, int signal_ndim) { // Create the transform plan (either from cache or locally) const auto value_type = framework::IsComplexType(input.type()) ? 
framework::ToRealType(input.type()) @@ -153,11 +150,10 @@ HIPFFTConfig create_hipfft_config(const framework::Tensor& input, auto out_size = output.dims()[i]; signal_size[i] = std::max(in_size, out_size); } - PlanKey key(framework::vectorize(input.dims()), - framework::vectorize(output.dims()), signal_size, fft_type, - value_type); - - return HIPFFTConfig(key); + FFTConfigKey key(framework::vectorize(input.dims()), + framework::vectorize(output.dims()), signal_size, fft_type, + value_type); + return key; } // Execute a pre-planned transform @@ -308,34 +304,58 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, collapsed_output.Resize(framework::make_ddim(collapsed_output_shape)); collapsed_output.mutable_data(tensor_place); +FFTConfig* config = nullptr; + #if defined(PADDLE_WITH_CUDA) + std::unique_ptr config_ = nullptr; // create plan - CuFFTConfig config = - create_cufft_config(collapsed_input, collapsed_output, signal_ndim); + FFTConfigKey key = + create_fft_configkey(collapsed_input, collapsed_output, signal_ndim); + if (CUFFT_VERSION < 10200) { + const int64_t device_id = static_cast( + reinterpret_cast(&collapsed_input.place()) + ->GetDeviceId()); + FFTConfigCache& plan_cache = get_fft_plan_cache(device_id); + std::unique_lock guard(plan_cache.mutex, std::defer_lock); + guard.lock(); + config = &(plan_cache.lookup(key)); + } else { + config_ = std::make_unique(key); + config = config_.get(); + } + // prepare cufft for execution PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cufftSetStream(config.plan(), ctx.stream())); + platform::dynload::cufftSetStream(config->plan(), ctx.stream())); framework::Tensor workspace_tensor; - workspace_tensor.mutable_data(tensor_place, config.workspace_size()); + workspace_tensor.mutable_data(tensor_place, config->workspace_size()); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftSetWorkArea( - config.plan(), workspace_tensor.data())); + config->plan(), workspace_tensor.data())); // execute transform plan - exec_cufft_plan(ctx, config, &collapsed_input, + exec_cufft_plan(ctx, *config, &collapsed_input, &collapsed_output, forward); #elif defined(PADDLE_WITH_HIP) // create plan - HIPFFTConfig config = - create_hipfft_config(collapsed_input, collapsed_output, signal_ndim); + FFTConfigKey key = + create_fft_configkey(collapsed_input, collapsed_output, signal_ndim); + const int64_t device_id = static_cast( + reinterpret_cast(&collapsed_input.place()) + ->GetDeviceId()); + FFTConfigCache& plan_cache = get_fft_plan_cache(device_id); + std::unique_lock guard(plan_cache.mutex, std::defer_lock); + guard.lock(); + config = &(plan_cache.lookup(key)); + // prepare cufft for execution PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::hipfftSetStream(config.plan(), ctx.stream())); + platform::dynload::hipfftSetStream(config->plan(), ctx.stream())); framework::Tensor workspace_tensor; - workspace_tensor.mutable_data(tensor_place, config.workspace_size()); + workspace_tensor.mutable_data(tensor_place, config->workspace_size()); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftSetWorkArea( - config.plan(), workspace_tensor.data())); + config->plan(), workspace_tensor.data())); // execute transform plan - exec_hipfft_plan(ctx, config, &collapsed_input, + exec_hipfft_plan(ctx, *config, &collapsed_input, &collapsed_output, forward); #endif From 14687c97f1177892aa1c756e8fcab243323f814f Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Tue, 26 Oct 2021 22:21:20 +0800 Subject: [PATCH 11/16] add LRUCache for cuff and hipfft (#5) * move signal apis * 
move fft.py and signal.py to paddle/, fix typos * fix relative imports from fft.py and signal.py * fix typos * WIP: add cache * delete move constructor and operator= for CuFFTHandle and FFTConfig * remove log from CuFFTHandle and FFTConfig * add lrucache for fft rocm backend * disable LRUCache when CUFFT_VERSION >= 10200 * disbale copy and move for hipFFTHandle; format code Co-authored-by: Xiaoxu Chen --- paddle/fluid/operators/spectral_helper.h | 245 +++++++++++++++++++++-- paddle/fluid/operators/spectral_op.cu | 29 +-- 2 files changed, 241 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/operators/spectral_helper.h b/paddle/fluid/operators/spectral_helper.h index 9c34d500eac92..7792f7618dcf0 100644 --- a/paddle/fluid/operators/spectral_helper.h +++ b/paddle/fluid/operators/spectral_helper.h @@ -27,12 +27,12 @@ namespace paddle { namespace operators { using ScalarType = framework::proto::VarType::Type; -const int64_t kMaxCUFFTNdim = 3; -const int64_t kMaxDataNdim = kMaxCUFFTNdim + 1; +const int64_t kMaxFFTNdim = 3; +const int64_t kMaxDataNdim = kMaxFFTNdim + 1; // This struct is used to easily compute hashes of the // parameters. It will be the **key** to the plan cache. -struct PlanKey { - // between 1 and kMaxCUFFTNdim, i.e., 1 <= signal_ndim <= 3 +struct FFTConfigKey { + // between 1 and kMaxFFTNdim, i.e., 1 <= signal_ndim <= 3 int64_t signal_ndim_; // These include additional batch dimension as well. int64_t sizes_[kMaxDataNdim]; @@ -41,12 +41,12 @@ struct PlanKey { FFTTransformType fft_type_; ScalarType value_type_; - PlanKey() = default; + FFTConfigKey() = default; - PlanKey(const std::vector& in_shape, - const std::vector& out_shape, - const std::vector& signal_size, FFTTransformType fft_type, - ScalarType value_type) { + FFTConfigKey(const std::vector& in_shape, + const std::vector& out_shape, + const std::vector& signal_size, + FFTTransformType fft_type, ScalarType value_type) { // Padding bits must be zeroed for hashing memset(this, 0, sizeof(*this)); signal_ndim_ = signal_size.size() - 1; @@ -67,12 +67,20 @@ class CuFFTHandle { public: CuFFTHandle() { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftCreate(&handle_)); + std::cout << "Constructed Handle " << handle_ << std::endl; } + CuFFTHandle(const CuFFTHandle& other) = delete; + CuFFTHandle& operator=(const CuFFTHandle& other) = delete; + + CuFFTHandle(CuFFTHandle&& other) = delete; + CuFFTHandle& operator=(CuFFTHandle&& other) = delete; + ::cufftHandle& get() { return handle_; } const ::cufftHandle& get() const { return handle_; } ~CuFFTHandle() { + std::cout << "Destructing Handle " << handle_ << std::endl; PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftDestroy(handle_)); } }; @@ -81,20 +89,20 @@ using plan_size_type = long long int; // NOLINT // This class contains all the information needed to execute a cuFFT plan: // 1. the plan // 2. the workspace size needed -class CuFFTConfig { +class FFTConfig { public: // Only move semantics is enought for this class. Although we already use // unique_ptr for the plan, still remove copy constructor and assignment op so // we don't accidentally copy and take perf hit. 
- explicit CuFFTConfig(const PlanKey& plan_key) - : CuFFTConfig( + explicit FFTConfig(const FFTConfigKey& plan_key) + : FFTConfig( std::vector(plan_key.sizes_, plan_key.sizes_ + plan_key.signal_ndim_ + 1), plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {} // sizes are full signal, including batch size and always two-sided - CuFFTConfig(const std::vector& sizes, const int64_t signal_ndim, - FFTTransformType fft_type, ScalarType dtype) + FFTConfig(const std::vector& sizes, const int64_t signal_ndim, + FFTTransformType fft_type, ScalarType dtype) : fft_type_(fft_type), value_type_(dtype) { // signal sizes (excluding batch dim) std::vector signal_sizes(sizes.begin() + 1, sizes.end()); @@ -144,6 +152,12 @@ class CuFFTConfig { ws_size = ws_size_t; } + FFTConfig(const FFTConfig& other) = delete; + FFTConfig& operator=(const FFTConfig& other) = delete; + + FFTConfig(FFTConfig&& other) = delete; + FFTConfig& operator=(FFTConfig&& other) = delete; + const cufftHandle& plan() const { return plan_ptr.get(); } FFTTransformType transform_type() const { return fft_type_; } @@ -167,6 +181,12 @@ class HIPFFTHandle { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftCreate(&handle_)); } + HIPFFTHandle(const HIPFFTHandle& other) = delete; + HIPFFTHandle& operator=(const HIPFFTHandle& other) = delete; + + HIPFFTHandle(HIPFFTHandle&& other) = delete; + HIPFFTHandle& operator=(HIPFFTHandle&& other) = delete; + ::hipfftHandle& get() { return handle_; } const ::hipfftHandle& get() const { return handle_; } @@ -178,20 +198,20 @@ using plan_size_type = int; // This class contains all the information needed to execute a cuFFT plan: // 1. the plan // 2. the workspace size needed -class HIPFFTConfig { +class FFTConfig { public: // Only move semantics is enought for this class. Although we already use // unique_ptr for the plan, still remove copy constructor and assignment op so // we don't accidentally copy and take perf hit. 
- explicit HIPFFTConfig(const PlanKey& plan_key) - : HIPFFTConfig( + explicit FFTConfig(const FFTConfigKey& plan_key) + : FFTConfig( std::vector(plan_key.sizes_, plan_key.sizes_ + plan_key.signal_ndim_ + 1), plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {} // sizes are full signal, including batch size and always two-sided - HIPFFTConfig(const std::vector& sizes, const int64_t signal_ndim, - FFTTransformType fft_type, ScalarType dtype) + FFTConfig(const std::vector& sizes, const int64_t signal_ndim, + FFTTransformType fft_type, ScalarType dtype) : fft_type_(fft_type), value_type_(dtype) { // signal sizes (excluding batch dim) std::vector signal_sizes(sizes.begin() + 1, sizes.end()); @@ -257,5 +277,192 @@ class HIPFFTConfig { ScalarType value_type_; }; #endif + +// Hashing machinery for Key +// Fowler–Noll–Vo hash function +// see +// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function +template +struct KeyHash { + // Key must be a POD because we read out its memory + // contenst as char* when hashing + static_assert(std::is_pod::value, "Key must be plain old data type"); + + size_t operator()(const Key& params) const { + auto ptr = reinterpret_cast(¶ms); + uint32_t value = 0x811C9DC5; + for (int i = 0; i < static_cast(sizeof(Key)); ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return static_cast(value); + } +}; + +template +struct KeyEqual { + // Key must be a POD because we read out its memory + // contenst as char* when comparing + static_assert(std::is_pod::value, "Key must be plain old data type"); + + bool operator()(const Key& a, const Key& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(Key)) == 0; + } +}; + +#if CUDA_VERSION < 10000 +// Note that the max plan number for CUDA version < 10 has to be 1023 +// due to a bug that fails on the 1024th plan +constexpr size_t CUFFT_MAX_PLAN_NUM = 1023; +constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM; +#else +constexpr size_t CUFFT_MAX_PLAN_NUM = std::numeric_limits::max(); +// The default max cache size chosen for CUDA version > 10 is arbitrary. +// This number puts a limit on how big of a plan cache should we maintain by +// default. Users can always configure it via cufft_set_plan_cache_max_size. +constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = 4096; +#endif +static_assert(CUFFT_MAX_PLAN_NUM >= 0 && + CUFFT_MAX_PLAN_NUM <= std::numeric_limits::max(), + "CUFFT_MAX_PLAN_NUM not in size_t range"); +static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 && + CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM, + "CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range"); + +// This cache assumes that the mapping from key to value never changes. +// This is **NOT** thread-safe. Please use a mutex when using it **AND** the +// value returned from try_emplace_value. +// The contract of using this cache is that try_emplace_value should only be +// used when the max_size is positive. 
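
The FFTConfigCache defined next follows a conventional LRU policy over these keys: a hit splices the entry to the front of the usage list, a miss evicts the least recently used plan once `max_size` is reached and constructs the new plan at the front. A minimal Python sketch of that lookup-or-insert-and-evict behaviour is below; the tuple key and `make_plan` callable are stand-ins for `FFTConfigKey` and `FFTConfig`, while the real cache hashes the POD key bytewise with the Fowler–Noll–Vo constants shown above and guards lookups with a mutex.

    from collections import OrderedDict

    class LRUPlanCache:
        def __init__(self, max_size=4096):  # mirrors CUFFT_DEFAULT_CACHE_SIZE for CUDA >= 10
            self.max_size = max_size
            self._cache = OrderedDict()

        def lookup(self, key, make_plan):
            if key in self._cache:
                # Hit: mark the entry as most recently used.
                self._cache.move_to_end(key)
                return self._cache[key]
            if len(self._cache) >= self.max_size:
                # Miss on a full cache: drop the least recently used plan.
                self._cache.popitem(last=False)
            plan = make_plan(key)
            self._cache[key] = plan
            return plan

    cache = LRUPlanCache()
    plan = cache.lookup((1, (8, 64), (8, 64), 'C2C'), make_plan=lambda k: object())  # placeholder plan
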
+class FFTConfigCache { + public: + using kv_t = typename std::pair; + using map_t = typename std::unordered_map< + std::reference_wrapper, typename std::list::iterator, + KeyHash, KeyEqual>; + using map_kkv_iter_t = typename map_t::iterator; + + FFTConfigCache() : FFTConfigCache(CUFFT_DEFAULT_CACHE_SIZE) {} + + explicit FFTConfigCache(int64_t max_size) { _set_max_size(max_size); } + + FFTConfigCache(const FFTConfigCache& other) = delete; + FFTConfigCache& operator=(const FFTConfigCache& other) = delete; + + FFTConfigCache(FFTConfigCache&& other) noexcept + : _usage_list(std::move(other._usage_list)), + _cache_map(std::move(other._cache_map)), + _max_size(other._max_size) {} + + FFTConfigCache& operator=(FFTConfigCache&& other) noexcept { + _usage_list = std::move(other._usage_list); + _cache_map = std::move(other._cache_map); + _max_size = other._max_size; + return *this; + } + + // If key is in this cache, return the cached config. Otherwise, emplace the + // config in this cache and return it. + FFTConfig& lookup(FFTConfigKey params) { + PADDLE_ENFORCE_GT(_max_size, 0, + platform::errors::InvalidArgument( + "The max size of FFTConfigCache must be great than 0," + "But received is [%d]", + _max_size)); + + map_kkv_iter_t map_it = _cache_map.find(params); + // Hit, put to list front + if (map_it != _cache_map.end()) { + _usage_list.splice(_usage_list.begin(), _usage_list, map_it->second); + return map_it->second->second; + } + + // Miss + // remove if needed + if (_usage_list.size() >= _max_size) { + auto last = _usage_list.end(); + last--; + _cache_map.erase(last->first); + _usage_list.pop_back(); + } + + // construct new plan at list front, then insert into _cache_map + _usage_list.emplace_front(std::piecewise_construct, + std::forward_as_tuple(params), + std::forward_as_tuple(params)); + auto kv_it = _usage_list.begin(); + _cache_map.emplace(std::piecewise_construct, + std::forward_as_tuple(kv_it->first), + std::forward_as_tuple(kv_it)); + return kv_it->second; + } + + void clear() { + _cache_map.clear(); + _usage_list.clear(); + } + + void resize(int64_t new_size) { + _set_max_size(new_size); + auto cur_size = _usage_list.size(); + if (cur_size > _max_size) { + auto delete_it = _usage_list.end(); + for (size_t i = 0; i < cur_size - _max_size; i++) { + delete_it--; + _cache_map.erase(delete_it->first); + } + _usage_list.erase(delete_it, _usage_list.end()); + } + } + + size_t size() const { return _cache_map.size(); } + + size_t max_size() const noexcept { return _max_size; } + + std::mutex mutex; + + private: + // Only sets size and does value check. Does not resize the data structures. + void _set_max_size(int64_t new_size) { + // We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since + // CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check + // first. 
+ PADDLE_ENFORCE_GE( + new_size, 0, + platform::errors::InvalidArgument( + "cuFFT plan cache size must be non-negative, But received is [%d]", + new_size)); + PADDLE_ENFORCE_LE(new_size, CUFFT_MAX_PLAN_NUM, + platform::errors::InvalidArgument( + "cuFFT plan cache size can not be larger than [%d], " + "But received is [%d]", + CUFFT_MAX_PLAN_NUM, new_size)); + _max_size = static_cast(new_size); + } + + std::list _usage_list; + map_t _cache_map; + size_t _max_size; +}; + +static std::vector> plan_caches; +static std::mutex plan_caches_mutex; + +static inline FFTConfigCache& get_fft_plan_cache(int64_t device_index) { + std::lock_guard guard(plan_caches_mutex); + + if (device_index >= plan_caches.size()) { + plan_caches.resize(device_index + 1); + } + + if (!plan_caches[device_index]) { + plan_caches[device_index] = std::make_unique(); + } + + return *plan_caches[device_index]; +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/spectral_op.cu b/paddle/fluid/operators/spectral_op.cu index 38f001ad3daae..dee5315b67fb4 100644 --- a/paddle/fluid/operators/spectral_op.cu +++ b/paddle/fluid/operators/spectral_op.cu @@ -69,7 +69,8 @@ void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out, #if defined(PADDLE_WITH_CUDA) FFTConfigKey create_fft_configkey(const framework::Tensor& input, - const framework::Tensor& output, int signal_ndim) { + const framework::Tensor& output, + int signal_ndim) { // Create the transform plan (either from cache or locally) const auto value_type = framework::IsComplexType(input.type()) ? framework::ToRealType(input.type()) @@ -91,7 +92,7 @@ FFTConfigKey create_fft_configkey(const framework::Tensor& input, } // Execute a pre-planned transform -static void exec_cufft_plan_raw(const CuFFTConfig& config, void* in_data, +static void exec_cufft_plan_raw(const FFTConfig& config, void* in_data, void* out_data, bool forward) { auto& plan = config.plan(); @@ -100,7 +101,7 @@ static void exec_cufft_plan_raw(const CuFFTConfig& config, void* in_data, } template -void exec_cufft_plan(const DeviceContext& ctx, const CuFFTConfig& config, +void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config, framework::Tensor* input, framework::Tensor* output, bool forward) { // execute transform plan @@ -135,7 +136,8 @@ void exec_cufft_plan(const DeviceContext& ctx, const CuFFTConfig& config, #elif defined(PADDLE_WITH_HIP) FFTConfigKey create_fft_configkey(const framework::Tensor& input, - const framework::Tensor& output, int signal_ndim) { + const framework::Tensor& output, + int signal_ndim) { // Create the transform plan (either from cache or locally) const auto value_type = framework::IsComplexType(input.type()) ? 
framework::ToRealType(input.type()) @@ -157,7 +159,7 @@ FFTConfigKey create_fft_configkey(const framework::Tensor& input, } // Execute a pre-planned transform -static void exec_hipfft_plan_raw(const HIPFFTConfig& config, void* in_data, +static void exec_hipfft_plan_raw(const FFTConfig& config, void* in_data, void* out_data, bool forward) { auto& plan = config.plan(); @@ -212,7 +214,7 @@ static void exec_hipfft_plan_raw(const HIPFFTConfig& config, void* in_data, } template -void exec_hipfft_plan(const DeviceContext& ctx, const HIPFFTConfig& config, +void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config, framework::Tensor* input, framework::Tensor* output, bool forward) { auto fft_type = config.transform_type(); @@ -304,7 +306,7 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, collapsed_output.Resize(framework::make_ddim(collapsed_output_shape)); collapsed_output.mutable_data(tensor_place); -FFTConfig* config = nullptr; + FFTConfig* config = nullptr; #if defined(PADDLE_WITH_CUDA) std::unique_ptr config_ = nullptr; @@ -313,8 +315,8 @@ FFTConfig* config = nullptr; create_fft_configkey(collapsed_input, collapsed_output, signal_ndim); if (CUFFT_VERSION < 10200) { const int64_t device_id = static_cast( - reinterpret_cast(&collapsed_input.place()) - ->GetDeviceId()); + reinterpret_cast(&collapsed_input.place()) + ->GetDeviceId()); FFTConfigCache& plan_cache = get_fft_plan_cache(device_id); std::unique_lock guard(plan_cache.mutex, std::defer_lock); guard.lock(); @@ -323,7 +325,6 @@ FFTConfig* config = nullptr; config_ = std::make_unique(key); config = config_.get(); } - // prepare cufft for execution PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cufftSetStream(config->plan(), ctx.stream())); @@ -378,10 +379,10 @@ FFTConfig* config = nullptr; // Use the optimized path to perform single R2C or C2R if transformation dim is // supported by cuFFT -bool use_optimized_cufft_path(const std::vector& axes) { +bool use_optimized_fft_path(const std::vector& axes) { // For performance reason, when axes starts with (0, 1), do not use the // optimized path. 
- if (axes.size() > kMaxCUFFTNdim || + if (axes.size() > kMaxFFTNdim || (axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) { return false; } else { @@ -411,7 +412,7 @@ struct FFTC2CFunctor { while (true) { max_dims = - std::min(static_cast(kMaxCUFFTNdim), working_axes.size()); + std::min(static_cast(kMaxFFTNdim), working_axes.size()); first_dims.assign(working_axes.end() - max_dims, working_axes.end()); exec_fft(ctx, p_working_tensor, @@ -438,7 +439,7 @@ struct FFTC2RFunctor { std::vector in_dims = framework::vectorize(X->dims()); std::vector out_dims = framework::vectorize(out->dims()); - if (use_optimized_cufft_path(axes)) { + if (use_optimized_fft_path(axes)) { framework::Tensor x_copy(X->type()); x_copy.mutable_data(X->dims(), ctx.GetPlace()); framework::TensorCopy(*X, ctx.GetPlace(), &x_copy); From d30b4bef3d61817183d79c7edd68b46731aaecf9 Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Tue, 26 Oct 2021 14:58:54 +0000 Subject: [PATCH 12/16] remove debug message of cufftHandler --- paddle/fluid/operators/spectral_helper.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/fluid/operators/spectral_helper.h b/paddle/fluid/operators/spectral_helper.h index 7792f7618dcf0..924ec7cd52d50 100644 --- a/paddle/fluid/operators/spectral_helper.h +++ b/paddle/fluid/operators/spectral_helper.h @@ -67,7 +67,6 @@ class CuFFTHandle { public: CuFFTHandle() { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftCreate(&handle_)); - std::cout << "Constructed Handle " << handle_ << std::endl; } CuFFTHandle(const CuFFTHandle& other) = delete; @@ -80,7 +79,6 @@ class CuFFTHandle { const ::cufftHandle& get() const { return handle_; } ~CuFFTHandle() { - std::cout << "Destructing Handle " << handle_ << std::endl; PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftDestroy(handle_)); } }; From a2e456a689b15dd6cd2fad7a041022a1fb7a6c82 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Tue, 26 Oct 2021 17:58:35 +0800 Subject: [PATCH 13/16] roll_op: support Tensor as input for shifts (#36727) --- paddle/fluid/operators/roll_op.cc | 39 ++++++++++++------- paddle/fluid/operators/roll_op.cu | 20 ++++++++++ paddle/fluid/operators/roll_op.h | 17 ++++++++ .../fluid/tests/unittests/test_roll_op.py | 28 +++++++++++++ python/paddle/tensor/manipulation.py | 23 +++++++---- 5 files changed, 105 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/roll_op.cc b/paddle/fluid/operators/roll_op.cc index b6a8111592fb7..b74dfc984affb 100644 --- a/paddle/fluid/operators/roll_op.cc +++ b/paddle/fluid/operators/roll_op.cc @@ -40,21 +40,23 @@ class RollOp : public framework::OperatorWithKernel { auto dims = ctx->Attrs().Get>("axis"); auto shifts = ctx->Attrs().Get>("shifts"); - if (dims.size() != 0) { - PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), - platform::errors::InvalidArgument( - "When dims.size() != 0, dims.size() " - "should be equal to " - "shifts.size(). But received " - "dims.size() = %d, shifts.size() = %d", - dims.size(), shifts.size())); - } else { - PADDLE_ENFORCE_EQ(shifts.size(), 1, - platform::errors::InvalidArgument( - "When dims.size() == 0, shifts.size() " - "should be equal to 1, But received " - "shifts.size() = %d", - shifts.size())); + if (!ctx->HasInput("ShiftsTensor")) { + if (dims.size() != 0) { + PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), + platform::errors::InvalidArgument( + "When dims.size() != 0, dims.size() " + "should be equal to " + "shifts.size(). 
But received " + "dims.size() = %d, shifts.size() = %d", + dims.size(), shifts.size())); + } else { + PADDLE_ENFORCE_EQ(shifts.size(), 1, + platform::errors::InvalidArgument( + "When dims.size() == 0, shifts.size() " + "should be equal to 1, But received " + "shifts.size() = %d", + shifts.size())); + } } ctx->SetOutputDim("Out", ctx->GetInputDim("X")); @@ -105,6 +107,10 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker { "The number of places by which the elements " "of the tensor are shifted.") .SetDefault({}); + AddInput("ShiftsTensor", + "The number of places by which the elements of the tensor " + "are shifted.") + .AsDispensable(); AddAttr>( "axis", "Axis along which to roll. It must have the same size " @@ -129,6 +135,9 @@ class RollGradMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr op) const override { op->SetType("roll_grad"); op->SetInput("X", this->Input("X")); + if (this->HasInput("ShiftsTensor")) { + op->SetInput("ShiftsTensor", this->Input("ShiftsTensor")); + } op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetAttrMap(this->Attrs()); diff --git a/paddle/fluid/operators/roll_op.cu b/paddle/fluid/operators/roll_op.cu index a170ce2fb111d..d70bd58887f84 100644 --- a/paddle/fluid/operators/roll_op.cu +++ b/paddle/fluid/operators/roll_op.cu @@ -59,6 +59,16 @@ class RollKernel auto* in = context.Input("X"); auto* out = context.Output("Out"); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + PADDLE_ENFORCE_EQ( + shifts_tensor->dims().size(), 1, + platform::errors::InvalidArgument( + "The rank of ShiftsTensor is expected to be 1, got %s", + shifts_tensor->dims().size())); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); auto* in_data = in->data(); @@ -134,6 +144,16 @@ class RollGradKernel auto* in = context.Input(framework::GradVarName("Out")); auto* out = context.Output(framework::GradVarName("X")); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + PADDLE_ENFORCE_EQ( + shifts_tensor->dims().size(), 1, + platform::errors::InvalidArgument( + "The rank of ShiftsTensor is expected to be 1, got %s", + shifts_tensor->dims().size())); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); auto* in_data = in->data(); diff --git a/paddle/fluid/operators/roll_op.h b/paddle/fluid/operators/roll_op.h index e58ff521d8df7..affb5f226ed55 100644 --- a/paddle/fluid/operators/roll_op.h +++ b/paddle/fluid/operators/roll_op.h @@ -16,6 +16,8 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { @@ -85,6 +87,16 @@ class RollKernel : public framework::OpKernel { auto& input = input_var->Get(); auto* output = output_var->GetMutable(); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + PADDLE_ENFORCE_EQ( + shifts_tensor->dims().size(), 1, + platform::errors::InvalidArgument( + "The rank of ShiftsTensor is expected to be 1, got %s", + shifts_tensor->dims().size())); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); std::vector 
out_vec; @@ -123,6 +135,11 @@ class RollGradKernel : public framework::OpKernel { auto& input = input_var->Get(); auto* output = output_var->GetMutable(); std::vector shifts = context.Attr>("shifts"); + if (context.HasInput("ShiftsTensor")) { + const auto* shifts_tensor = + context.Input("ShiftsTensor"); + shifts = GetDataFromTensor(shifts_tensor); + } std::vector dims = context.Attr>("axis"); std::vector out_vec; diff --git a/python/paddle/fluid/tests/unittests/test_roll_op.py b/python/paddle/fluid/tests/unittests/test_roll_op.py index 99121d2953a14..bca7665b814db 100644 --- a/python/paddle/fluid/tests/unittests/test_roll_op.py +++ b/python/paddle/fluid/tests/unittests/test_roll_op.py @@ -122,6 +122,34 @@ def test_axis_out_range(): self.assertRaises(ValueError, test_axis_out_range) + def test_shifts_as_tensor_dygraph(self): + with fluid.dygraph.guard(): + x = paddle.arange(9).reshape([3, 3]) + shape = paddle.shape(x) + shifts = shape // 2 + axes = [0, 1] + out = paddle.roll(x, shifts=shifts, axis=axes).numpy() + expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]]) + self.assertTrue(np.allclose(out, expected_out)) + + def test_shifts_as_tensor_static(self): + with program_guard(Program(), Program()): + x = paddle.arange(9).reshape([3, 3]).astype('float32') + shape = paddle.shape(x) + shifts = shape // 2 + axes = [0, 1] + out = paddle.roll(x, shifts=shifts, axis=axes) + expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]]) + + exe = fluid.Executor(fluid.CPUPlace()) + [out_np] = exe.run(fetch_list=[out]) + self.assertTrue(np.allclose(out_np, expected_out)) + + if paddle.is_compiled_with_cuda(): + exe = fluid.Executor(fluid.CPUPlace()) + [out_np] = exe.run(fetch_list=[out]) + self.assertTrue(np.allclose(out_np, expected_out)) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 4129a1060daf9..1158846706fba 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -696,15 +696,24 @@ def roll(x, shifts, axis=None, name=None): helper = LayerHelper("roll", **locals()) check_type(axis, 'axis', (list, tuple), 'roll') - check_type(shifts, 'shifts', (list, tuple), 'roll') + out = helper.create_variable_for_type_inference(x.dtype) - helper.append_op( - type='roll', - inputs={'X': x}, - outputs={'Out': out}, - attrs={'axis': axis, - 'shifts': shifts}) + if isinstance(shifts, Variable): + helper.append_op( + type='roll', + inputs={'X': x, + "ShiftsTensor": shifts}, + outputs={'Out': out}, + attrs={'axis': axis}) + else: + check_type(shifts, 'shifts', (list, tuple), 'roll') + helper.append_op( + type='roll', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'axis': axis, + 'shifts': shifts}) return out From 4e22f407fbfcb4ab17e507c601a523092458b500 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Tue, 26 Oct 2021 10:15:01 +0000 Subject: [PATCH 14/16] fix fftshift/ifftshift on static mode --- python/paddle/fft.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/paddle/fft.py b/python/paddle/fft.py index de15eba0feffa..7399ccc1ace59 100644 --- a/python/paddle/fft.py +++ b/python/paddle/fft.py @@ -1300,13 +1300,13 @@ def fftshift(x, axes=None, name=None): shape = paddle.shape(x) if axes is None: # shift all axes - rank = paddle.rank(x).reshape([1]) - axes = axes or paddle.arange(0, rank) - shifts = [size // 2 for size in shape] + rank = len(x.shape) + axes = list(range(0, rank)) + shifts = shape // 2 elif isinstance(axes, int): shifts 
= shape[axes] // 2 else: - shifts = [shape[ax] // 2 for ax in axes] + shifts = paddle.concat([shape[ax] // 2 for ax in axes]) return paddle.roll(x, shifts, axes, name=name) @@ -1343,13 +1343,13 @@ def ifftshift(x, axes=None, name=None): shape = paddle.shape(x) if axes is None: # shift all axes - rank = paddle.rank(x).reshape([1]) - axes = axes or paddle.arange(0, rank) - shifts = [-size // 2 for size in shape] + rank = len(x.shape) + axes = list(range(0, rank)) + shifts = shape // 2 elif isinstance(axes, int): shifts = -shape[axes] // 2 else: - shifts = [-shape[ax] // 2 for ax in axes] + shifts = paddle.concat([-shape[ax] // 2 for ax in axes]) return paddle.roll(x, shifts, axes, name=name) From 6e2fa31cbcaae28550ed30723b110bf1719fb359 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Tue, 26 Oct 2021 14:46:51 +0000 Subject: [PATCH 15/16] update roll_op version --- paddle/fluid/operators/roll_op.cc | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/roll_op.cc b/paddle/fluid/operators/roll_op.cc index b74dfc984affb..f82510556fde8 100644 --- a/paddle/fluid/operators/roll_op.cc +++ b/paddle/fluid/operators/roll_op.cc @@ -183,7 +183,12 @@ REGISTER_OP_VERSION(roll) "(std::vector) Axis along which to roll. " "It must have the same size with shifts, or size = 0.", std::vector()) - .DeleteAttr( - "dims", - "(std::vector) Dims along which to roll. " - "It must have the same size with shifts, or size = 0.")); + .DeleteAttr("dims", + "(std::vector) Dims along which to roll. " + "It must have the same size with shifts, or size = 0.")) + .AddCheckpoint( + R"ROC(Upgrade roll add a dispensable input "ShiftsTensor".)ROC", + paddle::framework::compatible::OpVersionDesc().NewInput( + "ShiftsTensor", + "The number of places by which the elements of" + "the tensor are shifted.")); From c5d92f378db09368bf7eb6f94537b06f9ab90593 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Tue, 26 Oct 2021 15:02:54 +0000 Subject: [PATCH 16/16] add more test cases for fftshift/ifftshift --- python/paddle/fluid/tests/unittests/fft/test_fft.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index c83c943217d4e..604de11521b7d 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -1009,10 +1009,11 @@ def test_rfftfreq(self): @place(DEVICES) -@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [ - ('test_1d', np.random.randn(10), (0, ), 'float64'), - ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), -]) +@parameterize( + (TEST_CASE_NAME, 'x', 'axes', 'dtype'), + [('test_1d', np.random.randn(10), (0, ), 'float64'), + ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64')]) class TestFftShift(unittest.TestCase): def test_fftshift(self): """Test fftshift with norm condition @@ -1030,6 +1031,7 @@ def test_fftshift(self): @parameterize((TEST_CASE_NAME, 'x', 'axes'), [ ('test_1d', np.random.randn(10), (0, ), 'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'), ]) class TestIfftShift(unittest.TestCase): def test_ifftshift(self):