Skip to content

Commit

Permalink
Updates {common.py,main.py,utils/*}
Browse files Browse the repository at this point in the history
  • Loading branch information
saforem2 committed Aug 9, 2022
1 parent 8caf530 commit 208bf32
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 87 deletions.
24 changes: 13 additions & 11 deletions src/l2hmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from l2hmc.configs import AnnealingSchedule, Steps
from l2hmc.configs import OUTPUTS_DIR
from l2hmc.configs import State
from l2hmc.utils.plot_helpers import make_ridgeplots, plot_dataArray, set_plot_style
from l2hmc.utils.plot_helpers import (
make_ridgeplots, plot_dataArray, set_plot_style
)
from l2hmc.utils.rich import get_console, is_interactive

os.environ['AUTOGRAPH_VERBOSITY'] = '0'
Expand Down Expand Up @@ -55,7 +57,6 @@ def grab_tensor(x: Any) -> np.ndarray | ScalarLike:
return x



def clear_cuda_cache():
import gc
gc.collect()
Expand All @@ -76,14 +77,14 @@ def check_diff(x, y, name: Optional[str] = None):
if isinstance(x, State):
xd = {'x': x.x, 'v': x.v, 'beta': x.beta}
yd = {'x': y.x, 'v': y.v, 'beta': y.beta}
check_diff(xd, yd, name=f'State')
check_diff(xd, yd, name='State')

elif isinstance(x, dict) and isinstance(y, dict):
for (kx, vx), (ky, vy) in zip(x.items(), y.items()):
if kx == ky:
check_diff(vx, vy, name=kx)
else:
log.warning(f'Mismatch encountered!')
log.warning('Mismatch encountered!')
log.warning(f'kx: {kx}')
log.warning(f'ky: {ky}')
vy_ = y.get(kx, None)
Expand All @@ -92,12 +93,12 @@ def check_diff(x, y, name: Optional[str] = None):
else:
log.warning(f'{kx} not in y, skipping!')
continue

elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
assert len(x) == len(y)
for idx in range(len(x)):
check_diff(x[idx], y[idx], name=f'{name}, {idx}')

else:
x = grab_tensor(x)
y = grab_tensor(y)
Expand Down Expand Up @@ -608,22 +609,22 @@ def plot_dataset(
outdir: Optional[os.PathLike] = None,
title: Optional[str] = None,
job_type: Optional[str] = None,
# run: Optional[Any] = None,
# arun: Optional[Any] = None,
# run: Any = None,
) -> None:
outdir = Path(outdir) if outdir is not None else Path(os.getcwd())
outdir.mkdir(exist_ok=True, parents=True)
# outdir = outdir.joinpath('plots')
job_type = job_type if job_type is not None else f'job-{get_timestamp()}'
names = ['rainbow', 'viridis_r', 'magma', 'mako', 'turbo', 'spectral']
cmap = np.random.choice(names, replace=True)

set_plot_style()
_ = make_ridgeplots(
dataset,
outdir=outdir,
drop_nans=True,
drop_zeros=False,
num_chains=nchains
num_chains=nchains,
cmap=cmap,
)
for key, val in dataset.data_vars.items():
if key == 'x':
Expand Down Expand Up @@ -654,7 +655,7 @@ def analyze_dataset(
"""Save plot and analyze resultant `xarray.Dataset`."""
job_type = job_type if job_type is not None else f'job-{get_timestamp()}'
dirs = make_subdirs(outdir)
if nchains is not None and nchains > 1024:
if nchains is not None and nchains > 1000:
nchains_ = nchains // 4
log.warning(
f'Reducing `nchains` from: {nchains} -> {nchains_} for plotting'
Expand Down Expand Up @@ -746,6 +747,7 @@ def save_and_analyze_data(
nchains=nchains,
job_type=job_type,
title=title)

if not is_interactive():
edir = Path(outdir).joinpath('logs')
edir.mkdir(exist_ok=True, parents=True)
Expand Down
5 changes: 5 additions & 0 deletions src/l2hmc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ def main(cfg: DictConfig) -> None:
hstart = time.time()
_ = ex.evaluate(job_type='hmc')
log.info(f'HMC took: {time.time() - hstart:.5f}s')
from l2hmc.utils.plot_helpers import measure_improvement
measure_improvement(
experiment=ex,
title=f'{ex.config.framework}',
)


if __name__ == '__main__':
Expand Down
99 changes: 32 additions & 67 deletions src/l2hmc/utils/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import seaborn as sns
import tensorflow as tf
from tensorflow.python.framework.ops import EagerTensor
import torch
import xarray as xr

Expand Down Expand Up @@ -74,7 +75,6 @@ def era_summary(self, era) -> str:
raise ValueError

def _update(self, key: str, val: TensorLike) -> float:
from l2hmc.common import grab_tensor
if val is None:
raise ValueError(f'None encountered: {key}: {val}')

Expand All @@ -88,14 +88,18 @@ def _update(self, key: str, val: TensorLike) -> float:
else:
val = np.array(val)

arr = grab_tensor(val)
if isinstance(val, (EagerTensor, tf.Tensor)):
val = val.numpy() # type:ignore

try:
self.history[key].append(arr)
self.history[key].append(val)
except KeyError:
self.history[key] = [arr]
self.history[key] = [val]

return np.array(arr).mean()
# if isinstance(val, Scalar):
# return np.array(val).mean()

return np.array(val).mean()

def update(self, metrics: dict) -> dict:
avgs = {}
Expand Down Expand Up @@ -387,73 +391,34 @@ def to_DataArray(
x: Union[list, np.ndarray],
therm_frac: Optional[float] = 0.0,
) -> xr.DataArray:
"""Convert `x` to an `xarray.DataArray` with consistent named dims
Explicitly, since we accumulate data in an iterative fashion
(i.e. by sequentially appending new values to the end of a list)
the input data `x` can have shape:
- if len(x.shape) == 1:
x.shape = [ndraws]
- if len(x.shape) == 2:
x.shape = [ndraws, nchains]
- if len(x.shape) == 3:
x.shape = [ndraws, nleapfrog, nchains]
For consistency, we reshape these cases as:
- [ndraws] --> [ndraws]
- [ndraws, nchains] --> [nchains, ndraws]
- [ndraws, nleapfrog, nchains] --> [nchains, nleapfrog, ndraws]
This allows us to aggregate multiple different `xr.DataArray`s into a
single `xr.Dataset`
The resultant `xr.DataArray` will have coordinates corresponding
to these named dims.
"""
arr = np.array(x)
assert len(arr.shape) in [1, 2, 3]
xargs = {
'dims': ['draw'],
'coords': [np.arange(len(arr))],
}
if len(arr.shape) == 3:
arr = arr.T
nchains, nlf, ndraws = arr.shape
xargs = {
'dims': ('chain', 'leapfrog', 'draw'),
'coords': [
np.arange(nchains),
np.arange(nlf),
np.arange(ndraws),
]
}
elif len(arr.shape) == 2:
arr = arr.T
nchains, ndraws = arr.shape
xargs = {
'dims': ('chain', 'draw'),
'coords': [
np.arange(nchains),
np.arange(ndraws),
]
}
else:
assert len(arr.shape) == 1
if therm_frac is not None and therm_frac > 0:
drop = int(therm_frac * arr.shape[0])
arr = arr[drop:]
# steps = np.arange(len(arr))
if len(arr.shape) == 1: # [ndraws]
ndraws = arr.shape[0]
xargs = {
'dims': ('draw'),
'coords': [np.arange(len(arr))],
}
dims = ['draw']
coords = [np.arange(len(arr))]
return xr.DataArray(arr, dims=dims, coords=coords) # type:ignore

darr = xr.DataArray(arr, **xargs)
if len(arr.shape) == 2: # [nchains, ndraws]
arr = arr.T
nchains, ndraws = arr.shape
dims = ('chain', 'draw')
coords = [np.arange(nchains), np.arange(ndraws)]
return xr.DataArray(arr, dims=dims, coords=coords) # type:ignore

# Drop first `therm_frac` percent of `draws` to account for warmup
if therm_frac is not None and therm_frac > 0.:
darr = darr.drop_sel(
draw=np.arange(int(therm_frac * len(darr.draw)))
)
if len(arr.shape) == 3: # [nchains, nlf, ndraws]
arr = arr.T
nchains, nlf, ndraws = arr.shape
dims = ('chain', 'leapfrog', 'draw')
coords = [np.arange(nchains), np.arange(nlf), np.arange(ndraws)]
return xr.DataArray(arr, dims=dims, coords=coords) # type:ignore

return darr
else:
print(f'arr.shape: {arr.shape}')
raise ValueError('Invalid shape encountered')

def get_dataset(
self,
Expand Down
87 changes: 78 additions & 9 deletions src/l2hmc/utils/plot_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
"""
from __future__ import absolute_import, annotations, division, print_function
import datetime
import logging
import os
from pathlib import Path
from typing import Any, Tuple, Optional
from typing import Any, Optional, Tuple
import warnings

import matplotlib.pyplot as plt
Expand All @@ -16,7 +17,9 @@
import pandas as pd
import seaborn as sns
import xarray as xr
import logging

# from l2hmc.experiment.pytorch.experiment import Experiment as ptExperiment
# from l2hmc.experiment.tensorflow.experiment import Experiment as tfExperiment

warnings.filterwarnings('ignore')

Expand All @@ -26,7 +29,7 @@

LW = plt.rcParams.get('axes.linewidth', 1.75)

colors = {
COLORS = {
'blue': '#007DFF',
'red': '#FF5252',
'green': '#63FF5B',
Expand Down Expand Up @@ -101,14 +104,76 @@ def get_timestamp(fstr=None):
FigAxes = Tuple[plt.Figure, plt.Axes]


def save_figure(fig: plt.Figure, fname: str, outdir: os.PathLike):
    """Write `fig` to disk as both a PNG and an SVG.

    The PNG is stored at `<outdir>/pngs/<fname>.png` and the SVG at
    `<outdir>/<fname>.svg`. The `pngs` directory (and any missing parents,
    which includes `outdir` itself) is created before saving.

    Args:
        fig: Figure to save; only its `savefig` method is used.
        fname: File stem (no extension) shared by both output files.
        outdir: Directory that receives the SVG and the `pngs` subdirectory.
    """
    root = Path(outdir)
    png_root = root / 'pngs'
    png_root.mkdir(parents=True, exist_ok=True)
    # PNG first, then SVG -- both with the same save options.
    targets = (png_root / f'{fname}.png', root / f'{fname}.svg')
    for target in targets:
        fig.savefig(target, dpi=400, bbox_inches='tight')


def savefig(fig: plt.Figure, outfile: os.PathLike):
    """Save `fig` to `outfile`, creating the parent directory if needed.

    Args:
        fig: Figure to save; only its `savefig` method is used.
        outfile: Destination path. The output format is whatever
            `fig.savefig` derives from the path it is given.
    """
    fout = Path(outfile)
    fout.parent.mkdir(exist_ok=True, parents=True)
    # Report the destination once, through the module logger only, so the
    # message is not duplicated on stdout.
    log.info(f'Saving figure to: {fout.as_posix()}')
    fig.savefig(fout.as_posix(), dpi=400, bbox_inches='tight')


def measure_improvement(
        experiment: Any,
        title: Optional[str] = None,
) -> None:
    """Plot and record the trained sampler's `dQint` relative to HMC.

    Reads the 'eval' and 'hmc' entries of `experiment.trainer.histories`,
    averages each dataset's `dQint` over the 'chain' dim (dropping the first
    entry), plots both curves on one axis, writes
    `mean(dQint_eval / dQint_hmc)` to `model_improvement.txt` inside
    `experiment._outdir`, and saves the figure there via `save_figure`.

    Args:
        experiment: Object exposing `trainer.histories` (a mapping),
            `config.steps.log` and `_outdir`.
        title: Optional title for the axes.

    Returns:
        None. If either history is missing, a warning is logged and nothing
        is plotted or written.
    """
    # Function-scope import: only needed for the tick formatter below.
    from matplotlib.ticker import FuncFormatter

    histories = experiment.trainer.histories
    ehist = histories.get('eval', None)
    hhist = histories.get('hmc', None)
    if ehist is None or hhist is None:
        # Make the skip visible instead of returning silently.
        log.warning(
            'Skipping `measure_improvement`: '
            'both `eval` and `hmc` histories are required.'
        )
        return

    edset = ehist.get_dataset()
    hdset = hhist.get_dataset()
    dQint_eval = edset.dQint.mean('chain')[1:]
    dQint_hmc = hdset.dQint.mean('chain')[1:]

    fig, ax = plt.subplots()
    _ = ax.plot(
        dQint_eval,
        label='Trained',
        lw=2.,
        color=COLORS['blue'],
    )
    _ = ax.plot(
        dQint_hmc,
        label='HMC',
        ls=':',
        lw=1.5,
        color=COLORS['blue'],
    )
    _ = ax.grid(True, alpha=0.2)

    # Scale tick labels by the logging interval with a formatter rather than
    # `set_xticklabels`: fixed label text is only valid for fixed tick
    # positions, whereas a formatter stays in sync with whatever ticks the
    # locator produces at draw time.
    log_freq = experiment.config.steps.log
    ax.xaxis.set_major_formatter(
        FuncFormatter(lambda tick, _: f'{log_freq * int(tick)}')
    )
    _ = ax.set_xlabel('MD Step')
    _ = ax.set_ylabel('dQint')
    _ = ax.legend(
        loc='best',
        framealpha=0.1,
        ncol=2,
        labelcolor='#FFF',
        shadow=True
    )
    if title is not None:
        _ = ax.set_title(title)

    outdir = experiment._outdir
    # The text file is written before `save_figure` creates any directories,
    # so make sure the destination exists first.
    Path(outdir).mkdir(exist_ok=True, parents=True)

    # NOTE(review): element-wise ratio -- zeros in `dQint_hmc` would make
    # this inf/nan, and the two histories are presumably the same length;
    # confirm against the callers.
    improvement = np.mean(dQint_eval.values / dQint_hmc.values)
    txtfile = Path(outdir).joinpath('model_improvement.txt').as_posix()
    log.warning(f'Writing model improvement to: {txtfile}')
    with open(txtfile, 'w') as f:
        f.write(f'{improvement:.8f}')

    save_figure(fig, fname='model_improvement', outdir=outdir)


def plot_scalar(
y: np.ndarray,
x: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -699,14 +764,15 @@ def make_ridgeplots(
outdir: Optional[os.PathLike] = None,
drop_zeros: Optional[bool] = False,
drop_nans: Optional[bool] = True,
cmap: Optional[str] = 'viridis_r',
# default_style: dict = None,
cmap: Optional[str] = 'rainbow',
):
"""Make ridgeplots."""
data = {}
# with sns.axes_style('white', rc={'axes.facecolor': (0, 0, 0, 0)}):
# sns.set(style='white', palette='bright', context='paper')
# with sns.set_style(style='white'):
outdir = Path(os.getcwd()) if outdir is None else Path(outdir)
outdir = outdir.joinpath('ridgeplots')
with sns.plotting_context(
context='paper',
):
Expand Down Expand Up @@ -747,8 +813,11 @@ def make_ridgeplots(
# Initialize the FacetGrid object
ncolors = len(val.leapfrog.values)
pal = sns.color_palette(cmap, n_colors=ncolors)
g = sns.FacetGrid(lfdf, row='lf', hue='lf',
aspect=15, height=0.25, palette=pal)
g = sns.FacetGrid(
lfdf,
row='lf', hue='lf',
aspect=15, height=0.25, palette=pal # type:ignore
)

# Draw the densities in a few steps
_ = g.map(sns.kdeplot, key, cut=1,
Expand Down Expand Up @@ -783,7 +852,7 @@ def label(_, color, label): # type:ignore #noqa
outdir.mkdir(exist_ok=True, parents=True)
pngdir.mkdir(exist_ok=True, parents=True)

log.info(f'Saving figure to: {fsvg.as_posix()}')
log.warning(f'Saving figure to: {fsvg.as_posix()}')
plt.savefig(fsvg.as_posix(), dpi=500, bbox_inches='tight')
plt.savefig(fpng.as_posix(), dpi=500, bbox_inches='tight')

Expand Down

0 comments on commit 208bf32

Please sign in to comment.