Skip to content

Commit

Permalink
Cleaning up CodeFactor issues
Browse files Browse the repository at this point in the history
  • Loading branch information
saforem2 committed Apr 14, 2022
1 parent 8e21a7d commit def991a
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/l2hmc/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,11 @@ def build_loss(self):
from l2hmc.loss.pytorch.loss import LatticeLoss
return LatticeLoss(lattice=self.lattice, # type: ignore
loss_config=self.config.loss)
elif self.config.framework == 'tensorflow':
if self.config.framework == 'tensorflow':
from l2hmc.loss.tensorflow.loss import LatticeLoss
return LatticeLoss(lattice=self.lattice, # type: ignore
loss_config=self.config.loss)
else:
raise ValueError('Unexpected value for `config.framework`')
raise ValueError('Unexpected value for `config.framework`')

def build_optimizer(self, dynamics: Optional[Any] = None):
lr = self.config.learning_rate.lr_init
Expand Down Expand Up @@ -288,15 +287,18 @@ def get_summary_writer(
if self.config.framework == 'tensorflow':
import tensorflow as tf
return tf.summary.create_file_writer(sdir) # type:ignore
elif self.config.framework == 'pytorch':
if self.config.framework == 'pytorch':
from torch.utils.tensorboard.writer import SummaryWriter
return SummaryWriter(sdir)

raise ValueError('Unable to get summary writer')

def build(self):
loss_fn = self.build_loss()
dynamics = self.build_dynamics()
optimizer = self.build_optimizer(dynamics)
if self.config.framework == 'pytorch':
assert self.config.framework in ['torch', 'pytorch', 'tensorflow']
if self.config.framework in ['torch', 'pytorch']:
accelerator = self.build_accelerator()
IS_CHIEF = accelerator.is_local_main_process
trainer = self.build_trainer(dynamics=dynamics,
Expand Down

0 comments on commit def991a

Please sign in to comment.