diff --git a/autograd/scipy/linalg.py b/autograd/scipy/linalg.py
index 1defe019..9a8c06d7 100644
--- a/autograd/scipy/linalg.py
+++ b/autograd/scipy/linalg.py
@@ -1,4 +1,5 @@
 from __future__ import division
+from functools import partial
 import scipy.linalg
 
 import autograd.numpy as anp
@@ -35,6 +36,62 @@ def vjp(g):
        lambda ans, a, b, trans=0, lower=False, **kwargs:
        lambda g: solve_triangular(a, g, trans=_flip(a, trans), lower=lower))
 
+def grad_solve_banded(argnum, ans, l_and_u, a, b):
+
+    updim = lambda x: x if x.ndim == a.ndim else x[...,None]
+
+    def transpose_banded(l_and_u, a):
+
+        # Compute the transpose of a matrix given in banded storage.
+        # The transpose is itself banded, with the sub- and super-diagonal counts swapped.
+
+        num_rows = a.shape[0]
+
+        shifts = anp.arange(-l_and_u[1], l_and_u[0]+1)
+
+        T_a = anp.roll(a[:1, :], shifts[0])
+        for rr in range(1, num_rows):
+            T_a = anp.vstack([T_a, anp.flipud(anp.roll(a[rr:rr+1, :], shifts[rr]))])
+        T_a = anp.flipud(T_a)
+
+        T_l_and_u = anp.flip(l_and_u)
+
+        return T_l_and_u, T_a
+
+    def banded_dot(l_and_u, uu, vv):
+
+        # Compute the outer product of vectors uu and vv, keeping only the
+        # elements on the bands specified by l_and_u (banded storage).
+
+        # TODO: replace the brute-force ravel() with smarter dimension handling of uu and vv
+
+        # main diagonal
+        banded_uv = anp.ravel(uu)*anp.ravel(vv)
+
+        # stack the sub-diagonal rows below the main diagonal
+        for rr in range(1, l_and_u[0]+1):
+            banded_uv_rr = anp.hstack([anp.ravel(uu)[rr:]*anp.ravel(vv)[:-rr], anp.zeros(rr)])
+            banded_uv = anp.vstack([banded_uv, banded_uv_rr])
+
+        # stack the super-diagonal rows above the main diagonal
+        for rr in range(1, l_and_u[1]+1):
+            banded_uv_rr = anp.hstack([anp.zeros(rr), anp.ravel(uu)[:-rr]*anp.ravel(vv)[rr:]])
+            banded_uv = anp.vstack([banded_uv_rr, banded_uv])
+
+        return banded_uv
+
+    T_l_and_u, T_a = transpose_banded(l_and_u, a)
+
+    if argnum == 1:
+        return lambda g: -banded_dot(l_and_u, updim(solve_banded(T_l_and_u, T_a, g)), anp.transpose(updim(ans)))
+    elif argnum == 2:
+        return lambda g: solve_banded(T_l_and_u, T_a, g)
+
+defvjp(solve_banded,
+       partial(grad_solve_banded, 1),
+       partial(grad_solve_banded, 2),
+       argnums=[1, 2])
+
 def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
     assert disp, "sqrtm jvp not implemented for disp=False"
     return solve_sylvester(ans, ans, dA)
diff --git a/tests/test_scipy.py b/tests/test_scipy.py
index 343bd701..11acba7f 100644
--- a/tests/test_scipy.py
+++ b/tests/test_scipy.py
@@ -228,3 +228,4 @@ def test_odeint():
 def test_sqrtm(): combo_check(spla.sqrtm, modes=['fwd'], order=2)([R(3, 3)])
 def test_sqrtm(): combo_check(symmetrize_matrix_arg(spla.sqrtm, 0), modes=['fwd', 'rev'], order=2)([R(3, 3)])
 def test_solve_sylvester(): combo_check(spla.solve_sylvester, [0, 1, 2], modes=['rev', 'fwd'], order=2)([R(3, 3)], [R(3, 3)], [R(3, 3)])
+def test_solve_banded(): combo_check(spla.solve_banded, [1, 2], modes=['rev'], order=1)([(1, 1)], [R(3, 5)], [R(5)])
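
Note (not part of the patch): a quick way to exercise the new VJPs outside the test suite is autograd's own gradient checker. The sketch below assumes autograd's public grad and autograd.test_util.check_grads APIs; the tridiagonal example values and the loss helper are made up purely for illustration.

    # Minimal sketch: numerically check the new solve_banded gradients.
    # The example system and the `loss` helper are illustrative only.
    import autograd.numpy as np
    import autograd.scipy.linalg as spla
    from autograd import grad
    from autograd.test_util import check_grads

    l_and_u = (1, 1)                       # one sub- and one super-diagonal
    ab = np.array([[0., 2., 1., 3., 1.],   # super-diagonal (first entry unused)
                   [5., 4., 6., 5., 7.],   # main diagonal
                   [1., 2., 1., 2., 0.]])  # sub-diagonal (last entry unused)
    b = np.array([1., 2., 3., 4., 5.])

    # reverse-mode check w.r.t. ab and b, matching the settings of the new test
    check_grads(lambda ab, b: spla.solve_banded(l_and_u, ab, b),
                modes=['rev'], order=1)(ab, b)

    # gradient of a scalar loss w.r.t. the banded matrix has the same (3, 5) shape as ab
    loss = lambda ab: np.sum(spla.solve_banded(l_and_u, ab, b)**2)
    print(grad(loss)(ab))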