Skip to content

Commit

Permalink
Cleaning up CodeFactor issues
Browse files Browse the repository at this point in the history
  • Loading branch information
saforem2 committed Apr 14, 2022
1 parent def991a commit e5b52da
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
16 changes: 8 additions & 8 deletions src/l2hmc/group/pytorch/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,11 @@ def SU3Gradient(
"""
x.requires_grad_(True)
y = f(x)
id = torch.ones(x.shape[0], device=x.device)
identity = torch.ones(x.shape[0], device=x.device)
dydx, = torch.autograd.grad(y, x,
create_graph=create_graph,
retain_graph=True,
grad_outputs=id)
grad_outputs=identity)
return y, dydx


Expand All @@ -292,9 +292,9 @@ def mul(
) -> Tensor:
if adjoint_a and adjoint_b:
return torch.matmul(a.adjoint(), b.adjoint())
elif adjoint_a:
if adjoint_a:
return torch.matmul(a.adjoint(), b)
elif adjoint_b:
if adjoint_b:
return torch.matmul(a, b.adjoint())
return torch.matmul(a, b)

Expand All @@ -309,9 +309,9 @@ def mul(
) -> Tensor:
if adjoint_a and adjoint_b:
return -a - b
elif adjoint_a:
if adjoint_a:
return -a + b
elif adjoint_b:
if adjoint_b:
return a - b
return a + b

Expand Down Expand Up @@ -354,9 +354,9 @@ def mul(
) -> Tensor:
if adjoint_a and adjoint_b:
return torch.matmul(a.adjoint(), b.adjoint())
elif adjoint_a:
if adjoint_a:
return torch.matmul(a.adjoint(), b)
elif adjoint_b:
if adjoint_b:
return torch.matmul(a, b.adjoint())
return torch.matmul(a, b)

Expand Down
6 changes: 3 additions & 3 deletions src/l2hmc/lattice/u1/pytorch/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ def grad_action(
"""Compute the gradient of the potential function."""
x.requires_grad_(True)
s = self.action(x, beta)
id = torch.ones(x.shape[0], device=x.device)
identity = torch.ones(x.shape[0], device=x.device)
dsdx, = torch.autograd.grad(s, x,
create_graph=create_graph,
retain_graph=True,
grad_outputs=id)
create_graph=create_graph,
grad_outputs=identity)
return dsdx

def plaqs_diff(
Expand Down

0 comments on commit e5b52da

Please sign in to comment.