Skip to content

Commit

Permalink
Re-pack gradients of jax.experimental.sparse.grad() to match original…
Browse files Browse the repository at this point in the history
… pytrees & test cases
  • Loading branch information
Blair-Johnson committed Jul 29, 2024
1 parent 85e83b5 commit 802a14c
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 2 deletions.
8 changes: 6 additions & 2 deletions jax/experimental/sparse/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ def reconstruct(i, grad_out):
return f_recons(grad_out)

def postprocess_gradients(grads_out):
out = [reconstruct(*args) for args in safe_zip(argnums_flat1, grads_out)]
return out[0] if isinstance(argnums, int) else out
leaf_grads = [None] * tree1.num_leaves
for i, grad in safe_zip(argnums_flat1, grads_out):
leaf_grads[i] = reconstruct(i, grad)
grad_tree = tree_util.tree_unflatten(tree1, leaf_grads)
grad_tree = tuple(filter(lambda x: jax.tree.leaves(x), grad_tree))
return grad_tree[0] if len(grad_tree) == 1 else grad_tree

return fun_flat, argnums_flat, args_flat, postprocess_gradients

Expand Down
91 changes: 91 additions & 0 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,97 @@ def f(X, y):
self.assertAllClose(jac_dense(f, argnums=1, has_aux=has_aux)(X, y),
jac_sparse(f, argnums=1, has_aux=has_aux)(Xsp, y), rtol=rtol)

@jtu.sample_product(has_aux=[True, False],
deep=[True,False],
arg0=[True,False],
bias=[True,False])
def test_sparse_pytree_grad(self, has_aux, deep, arg0, bias):
rng_sparse = sptu.rand_sparse(self.rng())
rng = jtu.rand_default(self.rng())

y = rng(5, "float32")
X = rng_sparse((10, 5), "float32")
b = rng(10, "float32")
Xsp = sparse.BCOO.fromdense(X)
Xtree_sp = {'deep':{'X':Xsp},
'X':Xsp,
'list':[None,(b,None)]}
Xtree_de = {'deep':{'X':X},
'X':X,
'list':[None,(b,None)]}

def f(Xtree, y):
if deep:
out = Xtree['deep']['X'] @ y
else:
out = Xtree['X'] @ y
# Other grad variables
if bias:
out += Xtree['list'][1][0]
out = jnp.sum(out)
if has_aux:
return out, {'y': y.shape}
else:
return out

def g(y, Xtree):
if deep:
out = Xtree['deep']['X'] @ y
else:
out = Xtree['X'] @ y
# Other grad variables
if bias:
out += Xtree['list'][1][0]
out = jnp.sum(out)
if has_aux:
return out, {'y': y.shape}
return out

with self.subTest("wrt sparse"):
# Argument ordering
if arg0:
grad_de = jax.grad(f, argnums=0, has_aux=has_aux)(Xtree_de, y)
grad_sp = sparse.grad(f, argnums=0, has_aux=has_aux)(Xtree_sp, y)
else:
grad_de = jax.grad(g, argnums=1, has_aux=has_aux)(y, Xtree_de)
grad_sp = sparse.grad(g, argnums=1, has_aux=has_aux)(y, Xtree_sp)

if has_aux:
grad_de, aux_de = grad_de
grad_sp, aux_sp = grad_sp
self.assertAllClose(aux_de, aux_sp)

# Pytree structure
is_bcoo = lambda x: isinstance(x, sparse.bcoo.BCOO)
grad_densified = jax.tree_util.tree_map(sparse.todense, grad_sp,
is_leaf=is_bcoo)
self.assertEqual(jax.tree_util.tree_structure(grad_de),
jax.tree_util.tree_structure(grad_densified))

# Depth in tree
if deep:
grad_sp_arr = grad_sp['deep']['X']
grad_de_arr = grad_de['deep']['X']
else:
grad_sp_arr = grad_sp['X']
grad_de_arr = grad_de['X']
self.assertIsInstance(grad_sp_arr, sparse.BCOO)
self.assertAllClose(grad_sp_arr.data,
sparse_bcoo._bcoo_extract(grad_sp_arr.indices,
grad_de_arr))
# Other grad variables
if bias:
self.assertAllClose(grad_sp['list'][1][0],
grad_de['list'][1][0])

with self.subTest("wrt dense"):
# Argument ordering
if arg0:
self.assertAllClose(jax.grad(f, argnums=1, has_aux=has_aux)(Xtree_de, y),
sparse.grad(f, argnums=1, has_aux=has_aux)(Xtree_sp, y))
else:
self.assertAllClose(jax.grad(g, argnums=0, has_aux=has_aux)(y, Xtree_de),
sparse.grad(g, argnums=0, has_aux=has_aux)(y, Xtree_sp))

class SparseObjectTest(sptu.SparseTestCase):
@parameterized.named_parameters(
Expand Down

0 comments on commit 802a14c

Please sign in to comment.