diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 1aab51e1bc9b..11bde4d279b6 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -93,6 +93,7 @@ jax.scipy.signal correlate correlate2d csd + detrend istft stft welch diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 284509be29eb..cdac2ba6e098 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -34,15 +34,66 @@ from jax._src.lax.lax import PrecisionLike from jax._src.numpy import linalg from jax._src.numpy.util import ( - check_arraylike, implements, promote_dtypes_inexact, promote_dtypes_complex) + check_arraylike, promote_dtypes_inexact, promote_dtypes_complex) from jax._src.third_party.scipy import signal_helper from jax._src.typing import Array, ArrayLike from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert -@implements(osp_signal.fftconvolve) def fftconvolve(in1: ArrayLike, in2: ArrayLike, mode: str = "full", axes: Sequence[int] | None = None) -> Array: + """ + Convolve two N-dimensional arrays using Fast Fourier Transform (FFT). + + JAX implementation of :func:`scipy.signal.fftconvolve`. + + Args: + in1: left-hand input to the convolution. + in2: right-hand input to the convolution. Must have ``in1.ndim == in2.ndim``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full convolution of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``in1``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + axes: optional sequence of axes along which to apply the convolution. + + Returns: + Array containing the convolved result. + + See Also: + - :func:`jax.numpy.convolve`: 1D convolution + - :func:`jax.scipy.signal.convolve`: direct convolution + + Examples: + A few 1D convolution examples. Because FFT-based convolution is approximate, + We use :func:`jax.numpy.printoptions` below to adjust the printing precision: + + >>> x = jnp.array([1, 2, 3, 2, 1]) + >>> y = jnp.array([1, 1, 1]) + + Full convolution uses implicit zero-padding at the edges: + + >>> with jax.numpy.printoptions(precision=3): + ... print(jax.scipy.signal.fftconvolve(x, y, mode='full')) + [1. 3. 6. 7. 6. 3. 1.] + + Specifying ``mode = 'same'`` returns a centered convolution the same size + as the first input: + + >>> with jax.numpy.printoptions(precision=3): + ... print(jax.scipy.signal.fftconvolve(x, y, mode='same')) + [3. 6. 7. 6. 3.] + + Specifying ``mode = 'valid'`` returns only the portion where the two arrays + fully overlap: + + >>> with jax.numpy.printoptions(precision=3): + ... print(jax.scipy.signal.fftconvolve(x, y, mode='valid')) + [6. 7. 6.] + """ check_arraylike('fftconvolve', in1, in2) in1, in2 = promote_dtypes_inexact(in1, in2) if in1.ndim != in2.ndim: @@ -133,9 +184,63 @@ def _convolve_nd(in1: Array, in2: Array, mode: str, *, precision: PrecisionLike) return result[0, 0] -@implements(osp_signal.convolve) def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', precision: PrecisionLike = None) -> Array: + """Convolution of two N-dimensional arrays. + + JAX implementation of :func:`jax.scipy.signal.convolve`. + + Args: + in1: left-hand input to the convolution. + in2: right-hand input to the convolution. Must have ``in1.ndim == in2.ndim``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full convolution of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``in1``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + method: controls the computation method. Options are + + * ``"auto"``: (default) always uses the ``"direct"`` method. + * ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`. + * ``"fft"``: compute the result via a fast Fourier transform. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + Returns: + Array containing the convolved result. + + See Also: + - :func:`jax.numpy.convolve`: 1D convolution + - :func:`jax.scipy.signal.convolve2d`: 2D convolution + - :func:`jax.scipy.signal.correlate`: ND correlation + + Examples: + A few 1D convolution examples: + + >>> x = jnp.array([1, 2, 3, 2, 1]) + >>> y = jnp.array([1, 1, 1]) + + Full convolution uses implicit zero-padding at the edges: + + >>> jax.scipy.signal.convolve(x, y, mode='full') + Array([1., 3., 6., 7., 6., 3., 1.], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered convolution the same size + as the first input: + + >>> jax.scipy.signal.convolve(x, y, mode='same') + Array([3., 6., 7., 6., 3.], dtype=float32) + + Specifying ``mode = 'valid'`` returns only the portion where the two arrays + fully overlap: + + >>> jax.scipy.signal.convolve(x, y, mode='valid') + Array([6., 7., 6.], dtype=float32) + """ if method == 'fft': return fftconvolve(in1, in2, mode=mode) elif method in ['direct', 'auto']: @@ -144,9 +249,42 @@ def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', raise ValueError(f"Got {method=}; expected 'auto', 'fft', or 'direct'.") -@implements(osp_signal.convolve2d) def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill', fillvalue: float = 0, precision: PrecisionLike = None) -> Array: + """Convolution of two 2-dimensional arrays. + + JAX implementation of :func:`jax.scipy.signal.convolve2d`. + + Args: + in1: left-hand input to the convolution. Must have ``in1.ndim == 2``. + in2: right-hand input to the convolution. Must have ``in2.ndim == 2``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full convolution of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``in1``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + boundary: only ``"fill"`` is supported. + fillvalue: only ``0`` is supported. + method: controls the computation method. Options are + + * ``"auto"``: (default) always uses the ``"direct"`` method. + * ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`. + * ``"fft"``: compute the result via a fast Fourier transform. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + Returns: + Array containing the convolved result. + + See Also: + - :func:`jax.numpy.convolve`: 1D convolution + - :func:`jax.scipy.signal.convolve`: ND convolution + - :func:`jax.scipy.signal.correlate`: ND correlation + """ if boundary != 'fill' or fillvalue != 0: raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0") if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2: @@ -154,15 +292,79 @@ def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill return _convolve_nd(in1, in2, mode, precision=precision) -@implements(osp_signal.correlate) def correlate(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', precision: PrecisionLike = None) -> Array: + """Cross-correlation of two N-dimensional arrays. + + JAX implementation of :func:`jax.scipy.signal.correlate`. + + Args: + in1: left-hand input to the cross-correlation. + in2: right-hand input to the cross-correlation. Must have ``in1.ndim == in2.ndim``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full cross-correlation of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``in1``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + method: controls the computation method. Options are + + * ``"auto"``: (default) always uses the ``"direct"`` method. + * ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`. + * ``"fft"``: compute the result via a fast Fourier transform. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + Returns: + Array containing the cross-correlation result. + + See Also: + - :func:`jax.numpy.correlate`: 1D cross-correlation + - :func:`jax.scipy.signal.correlate2d`: 2D cross-correlation + - :func:`jax.scipy.signal.convolve`: ND convolution + """ return convolve(in1, jnp.flip(in2.conj()), mode, precision=precision, method=method) -@implements(osp_signal.correlate2d) def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill', fillvalue: float = 0, precision: PrecisionLike = None) -> Array: + """Cross-correlation of two 2-dimensional arrays. + + JAX implementation of :func:`jax.scipy.signal.correlate2d`. + + Args: + in1: left-hand input to the cross-correlation. Must have ``in1.ndim == 2``. + in2: right-hand input to the cross-correlation. Must have ``in2.ndim == 2``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full cross-correlation of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``in1``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + boundary: only ``"fill"`` is supported. + fillvalue: only ``0`` is supported. + method: controls the computation method. Options are + + * ``"auto"``: (default) always uses the ``"direct"`` method. + * ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`. + * ``"fft"``: compute the result via a fast Fourier transform. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + Returns: + Array containing the cross-correlation result. + + See Also: + - :func:`jax.numpy.correlate`: 1D cross-correlation + - :func:`jax.scipy.signal.correlate`: ND cross-correlation + - :func:`jax.scipy.signal.convolve`: ND convolution + """ if boundary != 'fill' or fillvalue != 0: raise NotImplementedError("correlate2d() only supports boundary='fill', fillvalue=0") if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2: @@ -191,9 +393,51 @@ def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fil return result -@implements(osp_signal.detrend) def detrend(data: ArrayLike, axis: int = -1, type: str = 'linear', bp: int = 0, overwrite_data: None = None) -> Array: + """ + Remove linear or piecewise linear trends from data. + + JAX implementation of :func:`scipy.signal.detrend`. + + Args: + data: The input array containing the data to detrend. + axis: The axis along which to detrend. Default is -1 (the last axis). + type: The type of detrending. Can be: + + * ``'linear'``: Fit a single linear trend for the entire data. + * ``'constant'``: Remove the mean value of the data. + + bp: A sequence of breakpoints. If given, piecewise linear trends + are fit between these breakpoints. + overwrite_data: This argument is not supported by JAX's implementation. + + Returns: + The detrended data array. + + Example: + A simple detrend operation in one dimension: + + >>> data = jnp.array([1., 4., 8., 8., 9.]) + + Removing a linear trend from the data: + + >>> detrended = jax.scipy.signal.detrend(data) + >>> with jnp.printoptions(precision=3, suppress=True): # suppress float error + ... print("Detrended:", detrended) + ... print("Underlying trend:", data - detrended) + Detrended: [-1. -0. 2. -0. -1.] + Underlying trend: [ 2. 4. 6. 8. 10.] + + Removing a constant trend from the data: + + >>> detrended = jax.scipy.signal.detrend(data, type='constant') + >>> with jnp.printoptions(precision=3): # suppress float error + ... print("Detrended:", detrended) + ... print("Underlying trend:", data - detrended) + Detrended: [-5. -2. 2. 2. 3.] + Underlying trend: [6. 6. 6. 6. 6.] + """ if overwrite_data is not None: raise NotImplementedError("overwrite_data argument not implemented.") if type not in ['constant', 'linear']: @@ -499,11 +743,44 @@ def detrend_func(d): return freqs, time, result -@implements(osp_signal.stft) def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256, noverlap: int | None = None, nfft: int | None = None, detrend: bool = False, return_onesided: bool = True, boundary: str | None = 'zeros', padded: bool = True, axis: int = -1) -> tuple[Array, Array, Array]: + """ + Compute the short-time Fourier transform (STFT). + + JAX implementation of :func:`scipy.signal.stft`. + + Args: + x: Array representing a time series of input values. + fs: Sampling frequency of the time series (default: 1.0). + window: Data tapering window to apply to each segment. Can be a window function name, + a tuple specifying a window length and function, or an array (default: ``'hann'``). + nperseg: Length of each segment (default: 256). + noverlap: Number of points to overlap between segments (default: ``nperseg // 2``). + nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default), + the FFT length is ``nperseg``. + detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending), + ``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable + accepting a segment and returning a detrended segment. + return_onesided: If True (default), return a one-sided spectrum for real inputs. + If False, return a two-sided spectrum. + boundary: Specifies whether the input signal is extended at both ends, and how. + Options are ``None`` (no extension), ``'zeros'`` (default), ``'even'``, ``'odd'``, + or ``'constant'``. + padded: Specifies whether the input signal is zero-padded at the end to make its + length a multiple of `nperseg`. If True (default), the padded signal length is + the next multiple of ``nperseg``. + axis: Axis along which the STFT is computed; the default is over the last axis (-1). + + Returns: + A length-3 tuple of arrays ``(f, t, Zxx)``. ``f`` is the Array of sample frequencies. + ``t`` is the Array of segment times, and ``Zxx`` is the STFT of ``x``. + + See Also: + :func:`jax.scipy.signal.istft`: inverse short-time Fourier transform. + """ return _spectral_helper(x, None, fs, window, nperseg, noverlap, nfft, detrend, return_onesided, scaling='spectrum', axis=axis, @@ -511,19 +788,56 @@ def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256 padded=padded) -_csd_description = """ -The original SciPy function exhibits slightly different behavior between -``csd(x, x)``` and ```csd(x, x.copy())```. The LAX-backend version is designed -to follow the latter behavior. For using the former behavior, call this -function as `csd(x, None)`.""" - - -@implements(osp_signal.csd, lax_description=_csd_description) def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int | None = None, noverlap: int | None = None, nfft: int | None = None, detrend: str = 'constant', return_onesided: bool = True, scaling: str = 'density', axis: int = -1, average: str = 'mean') -> tuple[Array, Array]: + """ + Estimate cross power spectral density (CSD) using Welch's method. + + This is a JAX implementation of :func:`scipy.signal.csd`. It is similar to + :func:`jax.scipy.signal.welch`, but it operates on two input signals and + estimates their cross-spectral density instead of the power spectral density + (PSD). + + Args: + x: Array representing a time series of input values. + y: Array representing the second time series of input values, the same length as ``x`` + along the specified ``axis``. If not specified, then assume ``y = x`` and compute + the PSD ``Pxx`` of ``x`` via Welch's method. + fs: Sampling frequency of the inputs (default: 1.0). + window: Data tapering window to apply to each segment. Can be a window function name, + a tuple specifying a window length and function, or an array (default: ``'hann'``). + nperseg: Length of each segment (default: 256). + noverlap: Number of points to overlap between segments (default: ``nperseg // 2``). + nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default), + the FFT length is ``nperseg``. + detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending), + ``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable + accepting a segment and returning a detrended segment. + return_onesided: If True (default), return a one-sided spectrum for real inputs. + If False, return a two-sided spectrum. + scaling: Selects between computing the power spectral density (``'density'``, default) + or the power spectrum (``'spectrum'``) + axis: Axis along which the CSD is computed (default: -1). + average: The type of averaging to use on the periodograms; one of ``'mean'`` (default) + or ``'median'``. + + Returns: + A length-2 tuple of arrays ``(f, Pxy)``. ``f`` is the array of sample frequencies, + and ``Pxy`` is the cross spectral density of `x` and `y` + + Notes: + The original SciPy function exhibits slightly different behavior between + ``csd(x, x)`` and ``csd(x, x.copy())``. The LAX-backend version is designed + to follow the latter behavior. To replicate the former, call this function + function as ``csd(x, None)``. + + See Also: + - :func:`jax.scipy.signal.welch`: Power spectral density. + - :func:`jax.scipy.signal.stft`: Short-time Fourier transform. + """ freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft, detrend, return_onesided, scaling, axis, mode='psd') @@ -551,12 +865,46 @@ def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann' return freqs, Pxy -@implements(osp_signal.welch) def welch(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int | None = None, noverlap: int | None = None, nfft: int | None = None, detrend: str = 'constant', return_onesided: bool = True, scaling: str = 'density', axis: int = -1, average: str = 'mean') -> tuple[Array, Array]: + """ + Estimate power spectral density (PSD) using Welch's method. + + This is a JAX implementation of :func:`scipy.signal.welch`. It divides the + input signal into overlapping segments, computes the modified periodogram for + each segment, and averages the results to obtain a smoother estimate of the PSD. + + Args: + x: Array representing a time series of input values. + fs: Sampling frequency of the inputs (default: 1.0). + window: Data tapering window to apply to each segment. Can be a window function name, + a tuple specifying a window length and function, or an array (default: ``'hann'``). + nperseg: Length of each segment (default: 256). + noverlap: Number of points to overlap between segments (default: ``nperseg // 2``). + nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default), + the FFT length is ``nperseg``. + detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending), + ``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable + accepting a segment and returning a detrended segment. + return_onesided: If True (default), return a one-sided spectrum for real inputs. + If False, return a two-sided spectrum. + scaling: Selects between computing the power spectral density (``'density'``, default) + or the power spectrum (``'spectrum'``) + axis: Axis along which the PSD is computed (default: -1). + average: The type of averaging to use on the periodograms; one of ``'mean'`` (default) + or ``'median'``. + + Returns: + A length-2 tuple of arrays ``(f, Pxx)``. ``f`` is the array of sample frequencies, + and ``Pxx`` is the power spectral density of ``x``. + + See Also: + - :func:`jax.scipy.signal.csd`: Cross power spectral density. + - :func:`jax.scipy.signal.stft`: Short-time Fourier transform. + """ freqs, Pxx = csd(x, None, fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, @@ -613,12 +961,54 @@ def _overlap_and_add(x: Array, step_size: int) -> Array: return x.reshape(tuple(batch_shape) + (-1,)) -@implements(osp_signal.istft) def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int | None = None, noverlap: int | None = None, nfft: int | None = None, input_onesided: bool = True, boundary: bool = True, time_axis: int = -1, freq_axis: int = -2) -> tuple[Array, Array]: + """ + Perform the inverse short-time Fourier transform (ISTFT). + + JAX implementation of :func:`scipy.signal.istft`; computes the inverse of + :func:`jax.scipy.signal.stft`. + + Args: + Zxx: STFT of the signal to be reconstructed. + fs: Sampling frequency of the time series (default: 1.0) + window: Data tapering window to apply to each segment. Can be a window function name, + a tuple specifying a window length and function, or an array (default: ``'hann'``). + nperseg: Number of data points per segment in the STFT. If ``None`` (default), the + value is determined from the size of ``Zxx``. + noverlap: Number of points to overlap between segments (default: ``nperseg // 2``). + nfft: Number of FFT points used in the STFT. If ``None`` (default), the + value is determined from the size of ``Zxx``. + input_onesided: If Tru` (default), interpret the input as a one-sided STFT + (positive frequencies only). If False, interpret the input as a two-sided STFT. + boundary: If True (default), it is assumed that the input signal was extended at + its boundaries by ``stft``. If `False`, the input signal is assumed to have been truncated at the boundaries by `stft`. + time_axis: Axis in `Zxx` corresponding to time segments (default: -1). + freq_axis: Axis in `Zxx` corresponding to frequency bins (default: -2). + + Returns: + A length-2 tuple of arrays ``(t, x)``. ``t`` is the Array of signal times, and ``x`` + is the reconstructed time series. + + See Also: + :func:`jax.scipy.signal.stft`: short-time Fourier transform. + + Example: + Demonstrate that this gives the inverse of :func:`~jax.scipy.signal.stft`: + + >>> x = jnp.array([1., 2., 3., 2., 1., 0., 1., 2.]) + >>> f, t, Zxx = jax.scipy.signal.stft(x, nperseg=4) + >>> print(Zxx) + [[ 1. +0.j 2.5+0.j 1. +0.j 1. +0.j 0.5+0.j ] + [-0.5+0.5j -1.5+0.j -0.5-0.5j -0.5+0.5j 0. -0.5j] + [ 0. +0.j 0.5+0.j 0. +0.j 0. +0.j -0.5+0.j ]] + >>> t, x_reconstructed = jax.scipy.signal.istft(Zxx) + >>> print(x_reconstructed) + [1. 2. 3. 2. 1. 0. 1. 2.] + """ # Input validation check_arraylike("istft", Zxx) if Zxx.ndim < 2: