From f2e1d8483a0da33a877e1d7d6585e3cda14e331b Mon Sep 17 00:00:00 2001 From: "BREUER Axel (ENGIE Global Markets SAS)" Date: Fri, 1 Sep 2023 10:12:19 +0200 Subject: [PATCH 1/4] add support of scipy.solve_banded --- autograd/scipy/linalg.py | 57 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/autograd/scipy/linalg.py b/autograd/scipy/linalg.py index 1defe019..d093acf9 100644 --- a/autograd/scipy/linalg.py +++ b/autograd/scipy/linalg.py @@ -1,4 +1,5 @@ from __future__ import division +from functools import partial import scipy.linalg import autograd.numpy as anp @@ -35,6 +36,62 @@ def vjp(g): lambda ans, a, b, trans=0, lower=False, **kwargs: lambda g: solve_triangular(a, g, trans=_flip(a, trans), lower=lower)) +def grad_solve_banded(argnum, ans, l_and_u, a, b): + + updim = lambda x: x if x.ndim == a.ndim else x[...,None] + + def transpose_banded(l_and_u, a): + + # Compute the transpose of a banded matrix. + # The transpose is itself a banded matrix. + + num_rows = a.shape[0] + + shifts = anp.arange(-l_and_u[1], l_and_u[0]+1) + + T_a = anp.roll(a[:1, :], shifts[0]) + for rr in range(1, num_rows): + T_a = anp.vstack([T_a, anp.flipud(anp.roll(a[rr:rr+1, :], shifts[rr]))]) + T_a = anp.flipud(T_a) + + T_l_and_u = anp.flip(l_and_u) + + return T_l_and_u, T_a + + def banded_dot(l_and_u, uu, vv): + + # Compute tensor product of vectors uu and vv. + # Tensor product elements are resticted to the bands specified by l_and_u. + + # TODO: replace the brute-force ravel() by smarter dimension handeling of uu and vv + + # main diagonal + banded_uv = anp.ravel(uu)*anp.ravel(vv) + + # stack below the sub-diagonals + for rr in range(1, l_and_u[0]+1): + banded_uv_rr = anp.hstack([anp.ravel(uu)[rr:]*anp.ravel(vv)[:-rr], anp.zeros(rr)]) + banded_uv = anp.vstack([banded_uv, banded_uv_rr]) + + # stack above the sup-diagonals + for rr in range(1, l_and_u[1]+1): + banded_uv_rr = anp.hstack([anp.zeros(rr), anp.ravel(uu)[:-rr]*anp.ravel(vv)[rr:]]) + banded_uv = anp.vstack([banded_uv_rr, banded_uv]) + + return(banded_uv) + + T_l_and_u, T_a = transpose_banded(l_and_u, a) + + if argnum == 1: + return lambda g: -banded_dot(l_and_u, updim(solve_banded(T_l_and_u, T_a, g)), anp.transpose(updim(ans))) + elif argnum == 2: + return lambda g: solve_banded(T_l_and_u, T_a, g) + +defvjp(solve_banded, + partial(grad_solve_banded, 1), + partial(grad_solve_banded, 2), + argnums=[1, 2]) + def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64): assert disp, "sqrtm jvp not implemented for disp=False" return solve_sylvester(ans, ans, dA) From 7f3b9294dd635d5fb70871b06a80410d736d7b76 Mon Sep 17 00:00:00 2001 From: "BREUER Axel (ENGIE Global Markets SAS)" Date: Sun, 8 Oct 2023 02:29:54 +0200 Subject: [PATCH 2/4] test for scipy.linalg.test_banded --- tests/test_scipy.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index 343bd701..da5dc95d 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -10,6 +10,7 @@ warn('Skipping scipy tests.') else: import autograd.numpy as np + import autograd.numpy.linalg as npla import autograd.numpy.random as npr import autograd.scipy.signal import autograd.scipy.stats as stats @@ -189,12 +190,12 @@ def test_polygamma(): combo_check(special.polygamma, [1])([0], R(4)**2 + 1.3) def test_jn(): combo_check(special.jn, [1])([2], R(4)**2 + 1.3) def test_yn(): combo_check(special.yn, [1])([2], R(4)**2 + 1.3) - def test_psi(): unary_ufunc_check(special.psi, lims=[0.3, 2.0], test_complex=False) - def test_digamma(): unary_ufunc_check(special.digamma, lims=[0.3, 2.0], test_complex=False) - def test_gamma(): unary_ufunc_check(special.gamma, lims=[0.3, 2.0], test_complex=False) - def test_gammaln(): unary_ufunc_check(special.gammaln, lims=[0.3, 2.0], test_complex=False) - def test_gammasgn(): unary_ufunc_check(special.gammasgn,lims=[0.3, 2.0], test_complex=False) - def test_rgamma() : unary_ufunc_check(special.rgamma, lims=[0.3, 2.0], test_complex=False) + def test_psi(): unary_ufunc_check(special.psi, lims=[0.3, 2.0], test_complex=False) + def test_digamma(): unary_ufunc_check(special.digamma, lims=[0.3, 2.0], test_complex=False) + def test_gamma(): unary_ufunc_check(special.gamma, lims=[0.3, 2.0], test_complex=False) + def test_gammaln(): unary_ufunc_check(special.gammaln, lims=[0.3, 2.0], test_complex=False) + def test_gammasgn(): unary_ufunc_check(special.gammasgn, lims=[0.3, 2.0], test_complex=False) + def test_rgamma(): unary_ufunc_check(special.rgamma, lims=[0.3, 2.0], test_complex=False) def test_multigammaln(): combo_check(special.multigammaln, [0])([U(4., 5.), U(4., 5., (2,3))], [1, 2, 3]) @@ -228,3 +229,4 @@ def test_odeint(): def test_sqrtm(): combo_check(spla.sqrtm, modes=['fwd'], order=2)([R(3, 3)]) def test_sqrtm(): combo_check(symmetrize_matrix_arg(spla.sqrtm, 0), modes=['fwd', 'rev'], order=2)([R(3, 3)]) def test_solve_sylvester(): combo_check(spla.solve_sylvester, [0, 1, 2], modes=['rev', 'fwd'], order=2)([R(3, 3)], [R(3, 3)], [R(3, 3)]) + def test_solve_banded(): combo_check(spla.solve_banded, [1, 2], modes=['rev'], order=1)([(1, 1)], [R(3,5)], [R(5)]) From 9b5d154e1558ba181cf7472cd8ce61ba9eb72556 Mon Sep 17 00:00:00 2001 From: "BREUER Axel (ENGIE Global Markets SAS)" Date: Sun, 8 Oct 2023 02:31:57 +0200 Subject: [PATCH 3/4] fix transpose_banded in scipy.linalg --- autograd/scipy/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autograd/scipy/linalg.py b/autograd/scipy/linalg.py index d093acf9..9a8c06d7 100644 --- a/autograd/scipy/linalg.py +++ b/autograd/scipy/linalg.py @@ -52,7 +52,7 @@ def transpose_banded(l_and_u, a): T_a = anp.roll(a[:1, :], shifts[0]) for rr in range(1, num_rows): T_a = anp.vstack([T_a, anp.flipud(anp.roll(a[rr:rr+1, :], shifts[rr]))]) - T_a = anp.flipud(T_a) + T_a = anp.flipud(T_a) T_l_and_u = anp.flip(l_and_u) From 526051c99e901a360a08bd215b7f0deb1181ed56 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Thu, 16 Nov 2023 09:38:13 +0100 Subject: [PATCH 4/4] Revert unnecessary changes to test_scipy.py --- tests/test_scipy.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_scipy.py b/tests/test_scipy.py index da5dc95d..11acba7f 100644 --- a/tests/test_scipy.py +++ b/tests/test_scipy.py @@ -10,7 +10,6 @@ warn('Skipping scipy tests.') else: import autograd.numpy as np - import autograd.numpy.linalg as npla import autograd.numpy.random as npr import autograd.scipy.signal import autograd.scipy.stats as stats @@ -190,12 +189,12 @@ def test_polygamma(): combo_check(special.polygamma, [1])([0], R(4)**2 + 1.3) def test_jn(): combo_check(special.jn, [1])([2], R(4)**2 + 1.3) def test_yn(): combo_check(special.yn, [1])([2], R(4)**2 + 1.3) - def test_psi(): unary_ufunc_check(special.psi, lims=[0.3, 2.0], test_complex=False) - def test_digamma(): unary_ufunc_check(special.digamma, lims=[0.3, 2.0], test_complex=False) - def test_gamma(): unary_ufunc_check(special.gamma, lims=[0.3, 2.0], test_complex=False) - def test_gammaln(): unary_ufunc_check(special.gammaln, lims=[0.3, 2.0], test_complex=False) - def test_gammasgn(): unary_ufunc_check(special.gammasgn, lims=[0.3, 2.0], test_complex=False) - def test_rgamma(): unary_ufunc_check(special.rgamma, lims=[0.3, 2.0], test_complex=False) + def test_psi(): unary_ufunc_check(special.psi, lims=[0.3, 2.0], test_complex=False) + def test_digamma(): unary_ufunc_check(special.digamma, lims=[0.3, 2.0], test_complex=False) + def test_gamma(): unary_ufunc_check(special.gamma, lims=[0.3, 2.0], test_complex=False) + def test_gammaln(): unary_ufunc_check(special.gammaln, lims=[0.3, 2.0], test_complex=False) + def test_gammasgn(): unary_ufunc_check(special.gammasgn,lims=[0.3, 2.0], test_complex=False) + def test_rgamma() : unary_ufunc_check(special.rgamma, lims=[0.3, 2.0], test_complex=False) def test_multigammaln(): combo_check(special.multigammaln, [0])([U(4., 5.), U(4., 5., (2,3))], [1, 2, 3])