Skip to content

Commit

Permalink
[UserWarning] Fixing the warnings appearing in `test_nn.py::test_grou…
Browse files Browse the repository at this point in the history
…p_rev_res`. (#7486)
  • Loading branch information
drivanov committed Jul 1, 2024
1 parent 489671c commit 65b949f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/dgl/nn/pytorch/conv/grouprevres.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def forward(ctx, fn, fn_inverse, num_inputs, *inputs_and_weights):
outputs = ctx.fn(*x).detach_()

# clear memory of input node features
inputs[1].storage().resize_(0)
inputs[1].untyped_storage().resize_(0)

# store for backward pass
ctx.inputs = [inputs]
Expand Down Expand Up @@ -63,10 +63,10 @@ def backward(ctx, *grad_outputs):
*((inputs[0], outputs) + inputs[2:])
)
# clear memory of outputs
outputs.storage().resize_(0)
outputs.untyped_storage().resize_(0)

x = inputs[1]
x.storage().resize_(int(np.prod(x.size())))
x.untyped_storage().resize_(int(np.prod(x.size())))
x.set_(inputs_inverted)

# compute gradients
Expand Down

0 comments on commit 65b949f

Please sign in to comment.