Adds l2hmc/utils/pytorch/utils.py
saforem2 committed Apr 14, 2022
1 parent db3b726 commit 333c34a
Showing 4 changed files with 116 additions and 21 deletions.
30 changes: 16 additions & 14 deletions src/l2hmc/scripts/pytorch/main.py
@@ -17,7 +17,7 @@
from l2hmc.configs import get_jobdir
from l2hmc.experiment import Experiment
from l2hmc.trainers.pytorch.trainer import Trainer
# from l2hmc.utils.pytorch.utils import get_summary_writer
from l2hmc.utils.pytorch.utils import get_summary_writer


log = logging.getLogger(__name__)
@@ -35,18 +35,20 @@ def evaluate(
trainer: Trainer,
job_type: str,
run: Optional[Any] = None,
writer: Optional[Any] = None,
# writer: Optional[Any] = None,
nchains: Optional[int] = None,
eps: Optional[Tensor] = None,
) -> dict:
"""Evaluate model (nested as `trainer.model`)"""
nchains = -1 if nchains is None else nchains
therm_frac = cfg.get('therm_frac', 0.2)

writer = None
# # writer = None
jobdir = get_jobdir(cfg, job_type=job_type)
# if trainer.accelerator.is_local_main_process:
# writer = get_summary_writer(cfg, job_type=job_type)
if trainer.accelerator.is_local_main_process:
writer = get_summary_writer(cfg, job_type=job_type)
else:
writer = None

output = trainer.eval(run=run,
writer=writer,
@@ -76,14 +78,14 @@ def train(
cfg: DictConfig,
trainer: Trainer,
run: Optional[Any] = None,
writer: Optional[Any] = None,
# writer: Optional[Any] = None,
nchains: Optional[int] = None,
) -> dict:
writer = None
nchains = 16 if nchains is None else nchains
jobdir = get_jobdir(cfg, job_type='train')
# if trainer.accelerator.is_local_main_process:
# writer = get_summary_writer(cfg, job_type='train')
if trainer.accelerator.is_local_main_process:
writer = get_summary_writer(cfg, job_type='train')

# ------------------------------------------
# NOTE: cfg.profile will be False by default
@@ -138,8 +140,8 @@ def main(cfg: DictConfig) -> dict:
# ----------------------------------------------------------
should_train = (cfg.steps.nera > 0 and cfg.steps.nepoch > 0)
if should_train:
tw = experiment.get_summary_writer('train')
outputs['train'] = train(cfg, trainer, run=run, writer=tw) # [1.]
# tw = experiment.get_summary_writer('train')
outputs['train'] = train(cfg, trainer, run=run) # , writer=tw) # [1.]

if run is not None:
run.unwatch(objs['dynamics'])
@@ -149,20 +151,20 @@ def main(cfg: DictConfig) -> dict:
nchains = max((4, cfg.dynamics.nchains // 8))
if should_train and cfg.steps.test > 0: # [2.]
log.warning('Evaluating trained model')
ew = experiment.get_summary_writer('eval')
# ew = experiment.get_summary_writer('eval')
outputs['eval'] = evaluate(cfg,
run=run,
writer=ew,
# writer=ew,
job_type='eval',
nchains=nchains,
trainer=trainer)
if cfg.steps.test > 0: # [3.]
log.warning('Running generic HMC')
eps_hmc = torch.tensor(cfg.get('eps_hmc', 0.118))
hw = experiment.get_summary_writer('hmc')
# hw = experiment.get_summary_writer('hmc')
outputs['hmc'] = evaluate(cfg=cfg,
run=run,
writer=hw,
# writer=hw,
eps=eps_hmc,
job_type='hmc',
nchains=nchains,
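The main.py hunks above replace the hard-coded writer = None with a per-rank guard: only the local main process (trainer.accelerator.is_local_main_process) constructs a SummaryWriter, and every other rank keeps writer = None. A minimal sketch of the same pattern, where maybe_get_writer and the is_main_process flag are hypothetical stand-ins for the accelerator check:

# A minimal sketch, not the project's code: maybe_get_writer and is_main_process
# are hypothetical stand-ins for trainer.accelerator.is_local_main_process.
from typing import Optional

from torch.utils.tensorboard.writer import SummaryWriter


def maybe_get_writer(logdir: str, is_main_process: bool) -> Optional[SummaryWriter]:
    """Create a SummaryWriter only on the local main process; other ranks get None."""
    if is_main_process:
        return SummaryWriter(logdir)
    return None


writer = maybe_get_writer('/tmp/l2hmc-train/summaries', is_main_process=True)
if writer is not None:  # downstream code must tolerate writer being None
    writer.add_scalar('loss', 0.0, global_step=0)
    writer.flush()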
12 changes: 7 additions & 5 deletions src/l2hmc/trainers/pytorch/trainer.py
@@ -345,6 +345,7 @@ def record_metrics(
avgs = {k: v.mean() for k, v in record.items()}

summary = summarize_dict(avgs)
# if step is not None:
if writer is not None:
assert step is not None
update_summaries(step=step,
@@ -353,7 +354,8 @@ def record_metrics(
metrics=record,
prefix=job_type,
optimizer=optimizer)
writer.flush()
if writer is not None:
writer.flush()

if run is not None:
dQint = record.get('dQint', None)
@@ -539,10 +541,10 @@ def train(
if layout is not None and self.rank == 0:
layout['root']['main'].update(table)
# console.width = min(int(main_panel.get), WIDTH)
# console.rule(', '.join([
# f'BETA: {beta}',
# f'ERA: {era} / {self.steps.nera}',
# ]))
console.rule(', '.join([
f'BETA: {beta}',
f'ERA: {era} / {self.steps.nera}',
]))

# if WIDTH is not None and WIDTH > 0 and console is not None:
# console.width = WIDTH
6 changes: 4 additions & 2 deletions src/l2hmc/trainers/tensorflow/trainer.py
@@ -286,14 +286,16 @@ def record_metrics(
avgs = {k: v.mean() for k, v in record.items()}

summary = summarize_dict(avgs)
if writer is not None:
# if writer is not None:
if step is not None:
assert step is not None
update_summaries(step=step,
prefix=job_type,
model=model,
metrics=record,
optimizer=optimizer)
writer.flush()
if writer is not None:
writer.flush()

if run is not None:
dQint = record.get('dQint', None)
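Both trainer hunks apply the same defensive pattern in record_metrics: summaries are written only when a global step is available, and flush() is called only on ranks that actually hold a writer. A rough sketch of that guard, with plain add_scalar logging standing in for the project's update_summaries call:

# A rough sketch of the record_metrics guard; add_scalar is used here purely
# for illustration in place of update_summaries.
from __future__ import annotations

from typing import Optional

import torch
from torch.utils.tensorboard.writer import SummaryWriter


def record_metrics_sketch(
        record: dict[str, torch.Tensor],
        writer: Optional[SummaryWriter] = None,
        step: Optional[int] = None,
) -> dict[str, torch.Tensor]:
    """Average recorded tensors; push to TensorBoard only when a writer and step exist."""
    avgs = {k: v.mean() for k, v in record.items()}
    if writer is not None and step is not None:
        for key, val in avgs.items():
            writer.add_scalar(key, val.item(), global_step=step)
        writer.flush()
    return avgs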
89 changes: 89 additions & 0 deletions src/l2hmc/utils/pytorch/utils.py
@@ -0,0 +1,89 @@
"""
l2hmc/utils/pytorch/utils.py
Contains utilities for use in PyTorch
"""
from __future__ import absolute_import, division, print_function, annotations
from omegaconf import DictConfig
from pathlib import Path
import os
import logging
from torch.utils.tensorboard.writer import SummaryWriter

from l2hmc.dynamics.pytorch.dynamics import Dynamics
import torch


log = logging.getLogger(__name__)


def get_summary_writer(
cfg: DictConfig,
job_type: str
):
"""Returns SummaryWriter object for tracking summaries."""
outdir = Path(cfg.get('outdir', os.getcwd()))
jobdir = outdir.joinpath(job_type)
summary_dir = jobdir.joinpath('summaries')
summary_dir.mkdir(exist_ok=True, parents=True)

writer = SummaryWriter(summary_dir.as_posix())

return writer


def load_from_ckpt(
dynamics: Dynamics,
optimizer: torch.optim.Optimizer,
cfg: DictConfig,
) -> tuple[torch.nn.Module, torch.optim.Optimizer, dict]:
outdir = Path(cfg.get('outdir', os.getcwd()))
ckpts = list(outdir.joinpath('train', 'checkpoints').rglob('*.tar'))
if len(ckpts) > 0:
latest = max(ckpts, key=lambda p: p.stat().st_ctime)
if latest.is_file():
log.info(f'Loading from checkpoint: {latest}')
ckpt = torch.load(latest)
else:
raise FileNotFoundError(f'No checkpoints found in {outdir}')
else:
raise FileNotFoundError(f'No checkpoints found in {outdir}')

dynamics.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
dynamics.assign_eps({
'xeps': ckpt['xeps'],
'veps': ckpt['veps'],
})

return dynamics, optimizer, ckpt


# def update_wandb_config(
# cfg: DictConfig,
# tag: Optional[str] = None,
# debug: Optional[bool] = None,
# ) -> DictConfig:
# group = [
# 'pytorch',
# 'gpu' if torch.cuda.is_available() else 'cpu',
# 'DDP' if torch.cuda.device_count() > 1 else 'local'
# ]
# if debug:
# group.append('debug')

# cfg.wandb.setup.update({'group': '/'.join(group)})
# if tag is not None:
# cfg.wandb.setup.update({'id': tag})

# cfg.wandb.setup.update({
# 'tags': [
# f'{cfg.framework}',
# f'nlf-{cfg.dynamics.nleapfrog}',
# f'beta_final-{cfg.annealing_schedule.beta_final}',
# f'{cfg.dynamics.latvolume[0]}x{cfg.dynamics.latvolume[1]}',
# f'{cfg.dynamics.group}',
# ]
# })

# return cfg

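The new module provides two helpers: get_summary_writer builds a <outdir>/<job_type>/summaries/ directory and returns a SummaryWriter pointed at it, while load_from_ckpt restores a Dynamics/optimizer pair from the newest *.tar checkpoint under <outdir>/train/checkpoints/. A hedged usage sketch, assuming a throwaway outdir and omitting the Dynamics and optimizer construction that load_from_ckpt needs:

# Usage sketch under stated assumptions: the temporary outdir and the logged
# scalar are placeholders; Dynamics/optimizer construction for load_from_ckpt
# is omitted and shown only as a comment.
from omegaconf import OmegaConf

from l2hmc.utils.pytorch.utils import get_summary_writer

cfg = OmegaConf.create({'outdir': '/tmp/l2hmc-run'})

# Creates /tmp/l2hmc-run/train/summaries/ and returns a SummaryWriter pointed at it
writer = get_summary_writer(cfg, job_type='train')
writer.add_scalar('loss', 0.0, global_step=0)
writer.flush()
writer.close()

# Resuming would look roughly like:
#   dynamics, optimizer, ckpt = load_from_ckpt(dynamics, optimizer, cfg)
# which loads the newest *.tar under /tmp/l2hmc-run/train/checkpoints/ and
# raises FileNotFoundError when no checkpoint exists.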