diff --git a/jax/experimental/sparse/ad.py b/jax/experimental/sparse/ad.py
index 2c235c9320d5..018047e3d5e1 100644
--- a/jax/experimental/sparse/ad.py
+++ b/jax/experimental/sparse/ad.py
@@ -67,8 +67,12 @@ def reconstruct(i, grad_out):
     return f_recons(grad_out)
 
   def postprocess_gradients(grads_out):
-    out = [reconstruct(*args) for args in safe_zip(argnums_flat1, grads_out)]
-    return out[0] if isinstance(argnums, int) else out
+    leaf_grads = [None] * tree1.num_leaves
+    for i, grad in safe_zip(argnums_flat1, grads_out):
+      leaf_grads[i] = reconstruct(i, grad)
+    grad_tree = tree_util.tree_unflatten(tree1, leaf_grads)
+    grad_tree = tuple(filter(lambda x: jax.tree.leaves(x), grad_tree))
+    return grad_tree[0] if len(grad_tree) == 1 else grad_tree
 
   return fun_flat, argnums_flat, args_flat, postprocess_gradients
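For context, the new `postprocess_gradients` scatters each computed gradient back into its leaf slot of the first-stage flattening (`tree1`), unflattens to recover the original argument structure, and then drops arguments that were not differentiated, since their sub-trees hold only `None` leaves. Below is a minimal sketch of that unflatten-and-filter pattern with hypothetical stand-in values for `argnums_flat1` and `grads_out`, and with the `reconstruct` step (which re-sparsifies gradients in the real code) replaced by identity; it is an illustration, not the library's actual call sequence:

```python
import jax
from jax import tree_util

# Hypothetical example: two arguments, only the second one differentiated.
args = ({'X': 1.0}, 2.0)
_, tree1 = tree_util.tree_flatten(args)  # two leaves: args[0]['X'] and args[1]
argnums_flat1, grads_out = [1], [0.5]    # a gradient for leaf 1 only

leaf_grads = [None] * tree1.num_leaves   # [None, None]
for i, grad in zip(argnums_flat1, grads_out):
  leaf_grads[i] = grad                   # [None, 0.5]

# Rebuild the argument pytree, then drop sub-trees with no gradient leaves.
grad_tree = tree_util.tree_unflatten(tree1, leaf_grads)  # ({'X': None}, 0.5)
grad_tree = tuple(filter(lambda x: jax.tree.leaves(x), grad_tree))
print(grad_tree[0] if len(grad_tree) == 1 else grad_tree)  # 0.5
```

The final filter is what keeps the single-differentiated-argument case returning a bare gradient pytree rather than a one-element tuple.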
diff --git a/tests/sparse_test.py b/tests/sparse_test.py
index df5bc647faa9..616396222ec6 100644
--- a/tests/sparse_test.py
+++ b/tests/sparse_test.py
@@ -712,6 +712,97 @@ def f(X, y):
     self.assertAllClose(jac_dense(f, argnums=1, has_aux=has_aux)(X, y),
                         jac_sparse(f, argnums=1, has_aux=has_aux)(Xsp, y),
                         rtol=rtol)
+
+  @jtu.sample_product(has_aux=[True, False],
+                      deep=[True, False],
+                      arg0=[True, False],
+                      bias=[True, False])
+  def test_sparse_pytree_grad(self, has_aux, deep, arg0, bias):
+    rng_sparse = sptu.rand_sparse(self.rng())
+    rng = jtu.rand_default(self.rng())
+
+    y = rng(5, "float32")
+    X = rng_sparse((10, 5), "float32")
+    b = rng(10, "float32")
+    Xsp = sparse.BCOO.fromdense(X)
+    Xtree_sp = {'deep': {'X': Xsp},
+                'X': Xsp,
+                'list': [None, (b, None)]}
+    Xtree_de = {'deep': {'X': X},
+                'X': X,
+                'list': [None, (b, None)]}
+
+    def f(Xtree, y):
+      if deep:
+        out = Xtree['deep']['X'] @ y
+      else:
+        out = Xtree['X'] @ y
+      # Other grad variables
+      if bias:
+        out += Xtree['list'][1][0]
+      out = jnp.sum(out)
+      if has_aux:
+        return out, {'y': y.shape}
+      else:
+        return out
+
+    def g(y, Xtree):
+      if deep:
+        out = Xtree['deep']['X'] @ y
+      else:
+        out = Xtree['X'] @ y
+      # Other grad variables
+      if bias:
+        out += Xtree['list'][1][0]
+      out = jnp.sum(out)
+      if has_aux:
+        return out, {'y': y.shape}
+      return out
+
+    with self.subTest("wrt sparse"):
+      # Argument ordering
+      if arg0:
+        grad_de = jax.grad(f, argnums=0, has_aux=has_aux)(Xtree_de, y)
+        grad_sp = sparse.grad(f, argnums=0, has_aux=has_aux)(Xtree_sp, y)
+      else:
+        grad_de = jax.grad(g, argnums=1, has_aux=has_aux)(y, Xtree_de)
+        grad_sp = sparse.grad(g, argnums=1, has_aux=has_aux)(y, Xtree_sp)
+
+      if has_aux:
+        grad_de, aux_de = grad_de
+        grad_sp, aux_sp = grad_sp
+        self.assertAllClose(aux_de, aux_sp)
+
+      # Pytree structure
+      is_bcoo = lambda x: isinstance(x, sparse.bcoo.BCOO)
+      grad_densified = jax.tree_util.tree_map(sparse.todense, grad_sp,
+                                              is_leaf=is_bcoo)
+      self.assertEqual(jax.tree_util.tree_structure(grad_de),
+                       jax.tree_util.tree_structure(grad_densified))
+
+      # Depth in tree
+      if deep:
+        grad_sp_arr = grad_sp['deep']['X']
+        grad_de_arr = grad_de['deep']['X']
+      else:
+        grad_sp_arr = grad_sp['X']
+        grad_de_arr = grad_de['X']
+      self.assertIsInstance(grad_sp_arr, sparse.BCOO)
+      self.assertAllClose(grad_sp_arr.data,
+                          sparse_bcoo._bcoo_extract(grad_sp_arr.indices,
+                                                    grad_de_arr))
+      # Other grad variables
+      if bias:
+        self.assertAllClose(grad_sp['list'][1][0],
+                            grad_de['list'][1][0])
+
+    with self.subTest("wrt dense"):
+      # Argument ordering
+      if arg0:
+        self.assertAllClose(jax.grad(f, argnums=1, has_aux=has_aux)(Xtree_de, y),
+                            sparse.grad(f, argnums=1, has_aux=has_aux)(Xtree_sp, y))
+      else:
+        self.assertAllClose(jax.grad(g, argnums=0, has_aux=has_aux)(y, Xtree_de),
+                            sparse.grad(g, argnums=0, has_aux=has_aux)(y, Xtree_sp))
 
 class SparseObjectTest(sptu.SparseTestCase):
 
   @parameterized.named_parameters(
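A hedged usage sketch of the behavior the new test pins down: with this patch applied, `sparse.grad` should accept a pytree argument mixing BCOO and dense leaves and return a gradient pytree of the same structure, sparse where the input was sparse. The shapes and values here are illustrative only, not taken from the test:

```python
import jax.numpy as jnp
from jax.experimental import sparse

X = jnp.array([[1., 0.],
               [0., 2.]])
params = {'X': sparse.BCOO.fromdense(X),  # sparse leaf
          'b': jnp.ones(2)}               # dense leaf

def loss(params, y):
  return jnp.sum(params['X'] @ y + params['b'])

grads = sparse.grad(loss, argnums=0)(params, jnp.arange(2.0))
print(type(grads['X']).__name__)  # BCOO, same sparsity pattern as params['X']
print(grads['b'])                 # [1. 1.]
```

Densifying the sparse leaf with `sparse.todense` should recover a tree with the same structure as `jax.grad` applied to all-dense inputs, which is exactly what the structural assertion in the "wrt sparse" subtest checks.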