diff --git a/src/l2hmc/group/pytorch/group.py b/src/l2hmc/group/pytorch/group.py index 15f048d6..2aa341bd 100644 --- a/src/l2hmc/group/pytorch/group.py +++ b/src/l2hmc/group/pytorch/group.py @@ -273,11 +273,11 @@ def SU3Gradient( """ x.requires_grad_(True) y = f(x) - id = torch.ones(x.shape[0], device=x.device) + identity = torch.ones(x.shape[0], device=x.device) dydx, = torch.autograd.grad(y, x, create_graph=create_graph, retain_graph=True, - grad_outputs=id) + grad_outputs=identity) return y, dydx @@ -292,9 +292,9 @@ def mul( ) -> Tensor: if adjoint_a and adjoint_b: return torch.matmul(a.adjoint(), b.adjoint()) - elif adjoint_a: + if adjoint_a: return torch.matmul(a.adjoint(), b) - elif adjoint_b: + if adjoint_b: return torch.matmul(a, b.adjoint()) return torch.matmul(a, b) @@ -309,9 +309,9 @@ def mul( ) -> Tensor: if adjoint_a and adjoint_b: return -a - b - elif adjoint_a: + if adjoint_a: return -a + b - elif adjoint_b: + if adjoint_b: return a - b return a + b @@ -354,9 +354,9 @@ def mul( ) -> Tensor: if adjoint_a and adjoint_b: return torch.matmul(a.adjoint(), b.adjoint()) - elif adjoint_a: + if adjoint_a: return torch.matmul(a.adjoint(), b) - elif adjoint_b: + if adjoint_b: return torch.matmul(a, b.adjoint()) return torch.matmul(a, b) diff --git a/src/l2hmc/lattice/u1/pytorch/lattice.py b/src/l2hmc/lattice/u1/pytorch/lattice.py index d5b74ee8..68328534 100644 --- a/src/l2hmc/lattice/u1/pytorch/lattice.py +++ b/src/l2hmc/lattice/u1/pytorch/lattice.py @@ -72,11 +72,11 @@ def grad_action( """Compute the gradient of the potential function.""" x.requires_grad_(True) s = self.action(x, beta) - id = torch.ones(x.shape[0], device=x.device) + identity = torch.ones(x.shape[0], device=x.device) dsdx, = torch.autograd.grad(s, x, - create_graph=create_graph, retain_graph=True, - grad_outputs=id) + create_graph=create_graph, + grad_outputs=identity) return dsdx def plaqs_diff(