Skip to content

Commit

Permalink
factor out simulate function
Browse files Browse the repository at this point in the history
  • Loading branch information
eringrant committed Sep 17, 2023
1 parent f55ba72 commit b8d02fb
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 14 deletions.
7 changes: 4 additions & 3 deletions nets/experiments/in_context_learning/launcher_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import optax

import nets
from nets import datasets
from nets import samplers
from nets.launch import configs
from nets.launch import submit
from nets.launch.hparams import Param
from nets.launch.hparams import EnumParam
from nets.launch.hparams import FixedParam


from nets import datasets
from nets import samplers
from nets.simulators.in_context_learning import simulate


@dataclass(frozen=True, kw_only=True)
Expand Down Expand Up @@ -132,6 +132,7 @@ class DebugSearchConfig(SearchConfig):

jobs = submit.submit_jobs(
executor=executor,
fn=simulate,
cfg=DebugSearchConfig(
key=jax.random.PRNGKey(0),
num_configs=1,
Expand Down
7 changes: 4 additions & 3 deletions nets/experiments/in_context_learning/launcher_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import optax

import nets
from nets import datasets
from nets import samplers
from nets.launch import configs
from nets.launch import submit
from nets.launch.hparams import Param
Expand All @@ -17,8 +19,7 @@
from nets.launch.hparams import LogUniformParam
from nets.launch.hparams import UniformParam

from nets import datasets
from nets import samplers
from nets.simulators.in_context_learning import simulate


@dataclass(frozen=True, kw_only=True)
Expand Down Expand Up @@ -150,4 +151,4 @@ class SymbolicSearchConfig(SearchConfig):
exemplar_noise_scale=LogUniformParam(1e-1, 1e3, base=10),
)

jobs = submit.submit_jobs(executor, cfg)
jobs = submit.submit_jobs(executor, simulate, cfg)
12 changes: 4 additions & 8 deletions nets/launch/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@

import submitit

from nets.simulators.in_context_learning import simulate
from nets.launch import analyze
from nets.launch import configs


# Ignore warnings about invalid column names for PyTables.
import warnings
from tables import NaturalNameWarning
Expand All @@ -29,7 +27,7 @@


def augment_df_with_kwargs(func):
"""Return a function augments a `pd.DataFrame` with keyword arguments."""
"""Return a function that augments a `pd.DataFrame` with keyword arguments."""

def wrapped(**kwargs):
results_df = func(**kwargs)
Expand All @@ -42,9 +40,7 @@ def wrapped(**kwargs):
class Executor(submitit.AutoExecutor):
"""A `submitit.AutoExecutor` with a custom `starmap_array` method."""

def starmap_array(
self, fn: Callable, iterable: Iterable | configs.Config
) -> list[Any]:
def starmap_array(self, fn: Callable, iterable: Iterable) -> list[Any]:
"""A distributed equivalent of the `itertools.starmap` function."""
submissions = [
submitit.core.utils.DelayedSubmission(fn, **kwargs) for kwargs in iterable
Expand Down Expand Up @@ -117,14 +113,14 @@ def get_submitit_executor(
return executor


def submit_jobs(executor: Executor, cfg: configs.Config):
def submit_jobs(executor: Executor, fn: Callable, cfg: configs.Config):
"""Submit jobs to the cluster."""
logging.info(f"Using config {pprint.pformat(cfg)}.")

# Launch jobs.
logging.info("Launching jobs...")
jobs = executor.starmap_array(
augment_df_with_kwargs(simulate),
augment_df_with_kwargs(fn),
cfg,
)
logging.info(f"Waiting for {len(jobs)} jobs to terminate...")
Expand Down

0 comments on commit b8d02fb

Please sign in to comment.