working launchers
eringrant committed Sep 17, 2023
1 parent 491c734 commit f55ba72
Showing 10 changed files with 193 additions and 26 deletions.
14 changes: 12 additions & 2 deletions README.md
@@ -17,14 +17,24 @@ To install a Conda environment with the requisite packages on CPU:
conda env create --file environment-cpu.yml
```
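Since this commit also touches `environment-gpu.yml` (see below), the GPU install is presumably the analogous command:

```sh
conda env create --file environment-gpu.yml
```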

To test the code by running a quick visualization:
To test the code with a quick debug run:

```sh
TODO
python -m nets.experiments.in_context_learning.launcher_local
```

## Installation

Optionally, define the `SCRATCH_HOME` environment variable
by adding the following to a shell configuration file such as
`~/.bashrc`, `~/.bash_profile`, `~/.bash_login`, or `~/.profile`:

```sh
export SCRATCH_HOME="..."
```
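For example, on a cluster with per-user scratch storage, a plausible (hypothetical) setting is:

```sh
# Hypothetical path; substitute your system's scratch location.
export SCRATCH_HOME="/scratch/$USER"
```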

Then, follow one of two methods below to install `nets`.

### Method #1: via Conda

Use Conda to develop `nets` code directly.
1 change: 1 addition & 0 deletions environment-cpu.yml
@@ -14,6 +14,7 @@ dependencies:
- pandas>=2.0
- pytables
- submitit
- tqdm

# JAX:
- conda-forge::jax
1 change: 1 addition & 0 deletions environment-gpu.yml
@@ -14,6 +14,7 @@ dependencies:
- pandas>=2.0
- pytables
- submitit
- tqdm

# NOTE: The following requires a GPU node available
# at the time of creating the environment.
14 changes: 11 additions & 3 deletions nets/__init__.py
@@ -1,16 +1,24 @@
"""Module-wide constants for `nets`."""
import nets
import os
from pathlib import Path
import nets

__version__ = "0.0.1"

# Module-wide path constants as in
# https://stackoverflow.com/a/59597733.
__package_path = os.path.split(nets.__path__[0])[0]
DATA_DIR = os.path.join(__package_path, "data")
TMP_DIR = os.path.join("/tmp", "nets")
DATA_DIR = Path(__package_path, "data")
TMP_DIR = Path("/tmp", "nets")
os.makedirs(TMP_DIR, exist_ok=True)

scratch_home = os.environ.get("SCRATCH_HOME")
if scratch_home is not None:
    SCRATCH_DIR = Path(scratch_home, "nets")
else:
    SCRATCH_DIR = TMP_DIR
os.makedirs(SCRATCH_DIR, exist_ok=True)

del nets
del os
del __package_path
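A minimal usage sketch of the new scratch-root constant (the experiment name here is hypothetical; per the logic above, `SCRATCH_DIR` falls back to `/tmp/nets` when `SCRATCH_HOME` is unset):

```python
from pathlib import Path

import nets

# Build a run directory under the scratch root.
run_dir = Path(nets.SCRATCH_DIR, "my-experiment")  # hypothetical name
run_dir.mkdir(parents=True, exist_ok=True)
```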
1 change: 1 addition & 0 deletions nets/experiments/__init__.py
@@ -0,0 +1 @@
"""Experiments using `nets`."""
@@ -1,6 +1,5 @@
"""Launcher for local runs of in-context learning simulations."""
"""Launcher for local runs of analyzable online SGD."""
import logging
import os
from pathlib import Path

from dataclasses import dataclass
@@ -9,6 +8,7 @@
import jax
import optax

import nets
from nets.launch import configs
from nets.launch import submit
from nets.launch.hparams import Param
@@ -124,9 +124,8 @@ class DebugSearchConfig(SearchConfig):
    executor = submit.get_submitit_executor(
        cluster="local",
        log_dir=Path(
            "/tmp",
            os.environ["USER"],
            "in-ctx",
            nets.SCRATCH_DIR,
            "osgd",
            submit.get_timestamp(),
        ),
    )
142 changes: 142 additions & 0 deletions nets/experiments/in_context_learning/launcher_local.py
@@ -0,0 +1,142 @@
"""Launcher for local runs of in-context learning simulations."""
import logging
from pathlib import Path

from dataclasses import dataclass
from dataclasses import field

import jax
import optax

import nets
from nets.launch import configs
from nets.launch import submit
from nets.launch.hparams import Param
from nets.launch.hparams import EnumParam
from nets.launch.hparams import FixedParam


from nets import datasets
from nets import samplers


@dataclass(frozen=True, kw_only=True)
class SearchConfig(configs.Config):
"""Generic config for a hyperparameter search."""

seed: Param = field(default_factory=lambda: EnumParam(range(0, 3)))

# Model params.
embed_dim: Param = field(init=False)
num_heads: Param = field(init=False)
depth: Param = field(init=False)
mlp_ratio: Param = field(init=False)
causal: Param = field(init=False)

# Training and evaluation params.
optimizer_fn: Param = field(default_factory=lambda: FixedParam(optax.adam))
learning_rate: Param = field(default_factory=lambda: FixedParam(1e-3))
train_batch_size: Param = field(default_factory=lambda: FixedParam(32))
eval_batch_size: Param = field(default_factory=lambda: FixedParam(32))
num_epochs: Param = field(default_factory=lambda: FixedParam(1))
evaluations_per_epoch: Param = field(default_factory=lambda: FixedParam(100))
evaluate_on_test_split: Param = field(default_factory=lambda: FixedParam(False))

# Dataset params.
    num_train_classes: Param = field(init=False)  # `init=False` to avoid init
    num_valid_classes: Param = field(init=False)  # of to-be-overridden values.
    num_test_classes: Param = field(init=False)
    prop_train_labels: Param = field(init=False)
    prop_valid_labels: Param = field(init=False)
    prop_test_labels: Param = field(init=False)
    dataset_cls: Param = field(init=False)
    exemplar_labeling: Param = field(init=False)
    holdout_class_labeling: Param = field(init=False)
    num_exemplars_per_class: Param = field(init=False)
    exemplar_noise_scale: Param = field(init=False)

    # Sampler params.
    num_train_seqs: Param = field(init=False)
    num_eval_seqs: Param = field(init=False)
    train_sampler_cls: Param = field(init=False)
    eval_sampler_cls: Param = field(init=False)
    train_query_type: Param = field(init=False)
    train_context_len: Param = field(init=False)
    train_zipf_exponent: Param = field(init=False)
    train_relabeling: Param = field(init=False)


@dataclass(frozen=True, kw_only=True)
class DebugSearchConfig(SearchConfig):
"""Singleton config for debugging."""

seed: Param = field(default_factory=lambda: FixedParam(0))

# No training.
num_epochs: Param = field(default_factory=lambda: FixedParam(0))
evaluations_per_epoch: Param = field(default_factory=lambda: FixedParam(1))

# Teeny tiny model.
embed_dim: Param = field(default_factory=lambda: FixedParam(8))
num_heads: Param = field(default_factory=lambda: FixedParam(8))
depth: Param = field(default_factory=lambda: FixedParam(2))
mlp_ratio: Param = field(default_factory=lambda: FixedParam(4.0))
causal: Param = field(default_factory=lambda: FixedParam(True))

num_train_classes: Param = field(default_factory=lambda: FixedParam(80))
num_valid_classes: Param = field(default_factory=lambda: FixedParam(20))
num_test_classes: Param = field(default_factory=lambda: FixedParam(16))
prop_train_labels: Param = field(default_factory=lambda: FixedParam(0.8))
prop_valid_labels: Param = field(default_factory=lambda: FixedParam(0.7))
prop_test_labels: Param = field(default_factory=lambda: FixedParam(0.3))
dataset_cls: Param = field(
default_factory=lambda: FixedParam(datasets.SymbolicDataset)
)
exemplar_labeling: Param = field(
default_factory=lambda: FixedParam(datasets.ExemplarLabeling.STANDARD)
)
holdout_class_labeling: Param = field(
default_factory=lambda: FixedParam(datasets.HoldoutClassLabeling.STANDARD)
)
num_exemplars_per_class: Param = field(default_factory=lambda: FixedParam(20))
exemplar_noise_scale: Param = field(default_factory=lambda: FixedParam(1.0))

num_train_seqs: Param = field(default_factory=lambda: FixedParam(int(1e3)))
num_eval_seqs: Param = field(default_factory=lambda: FixedParam(int(1e2)))
train_sampler_cls: Param = field(
default_factory=lambda: FixedParam(samplers.DirichletMultinomialSampler)
)
eval_sampler_cls: Param = field(
default_factory=lambda: FixedParam(samplers.DirichletMultinomialSampler)
)
train_query_type: Param = field(
default_factory=lambda: FixedParam(samplers.QueryType.SUPPORTED)
)
train_context_len: Param = field(default_factory=lambda: FixedParam(2))
train_zipf_exponent: Param = field(default_factory=lambda: FixedParam(1.0))
train_relabeling: Param = field(default_factory=lambda: FixedParam(False))


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    executor = submit.get_submitit_executor(
        cluster="debug",
        log_dir=Path(
            nets.SCRATCH_DIR,
            "in-ctx",
            submit.get_timestamp(),
        ),
        gpus_per_node=0,
    )

    jobs = submit.submit_jobs(
        executor=executor,
        cfg=DebugSearchConfig(
            key=jax.random.PRNGKey(0),
            num_configs=1,
        ),
    )

    result = jobs[0].result()
    print(result)
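For reference, a hedged sketch of pointing this same launcher at submitit's local executor instead of the in-process debug one, mirroring the `cluster="local"` call in the online-SGD launcher above and reusing this file's imports:

```python
# Sketch only; arguments mirror those used elsewhere in this commit.
executor = submit.get_submitit_executor(
    cluster="local",
    log_dir=Path(
        nets.SCRATCH_DIR,
        "in-ctx",
        submit.get_timestamp(),
    ),
    gpus_per_node=0,
)
```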
@@ -1,6 +1,5 @@
"""Launcher for SLURM runs of in-context learning simulations."""
import logging
import os
from pathlib import Path

from dataclasses import dataclass
@@ -9,6 +8,7 @@
import jax
import optax

import nets
from nets.launch import configs
from nets.launch import submit
from nets.launch.hparams import Param
@@ -116,17 +116,16 @@ class SymbolicSearchConfig(SearchConfig):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    job_folder = Path(
        "/tmp",
        os.environ["USER"],
        "in-ctx",
        submit.get_timestamp(),
    )

    executor = submit.get_submitit_executor(
        log_dir=job_folder,
        cluster="slurm",
        #
        ### Output directory. ###
        log_dir=Path(
            nets.SCRATCH_DIR,
            "in-ctx",
            submit.get_timestamp(),
        ),
        #
        ### GPU mode. ###
        slurm_partition="gpu",
        slurm_parallelism=50,
9 changes: 7 additions & 2 deletions nets/launch/submit.py
@@ -131,8 +131,9 @@ def submit_jobs(executor: Executor, cfg: configs.Config):

    # Dump the config at root.
    job_root = executor.folder.parent
    with open(os.path.join(job_root, "config.pkl"), "wb") as f:
        pickle.dump(cfg, f)
    if executor.cluster != "debug":
        with open(os.path.join(job_root, "config.pkl"), "wb") as f:
            pickle.dump(cfg, f)

    async def async_annotate():
        # Annotate results as they become available.
@@ -150,6 +151,10 @@ async def async_annotate():
    results_paths = asyncio.run(async_annotate())
    logging.info("All jobs terminated.")

    if executor.cluster == "debug":
        # In debug mode, the config is dumped only once all jobs have
        # finished (the pre-submission dump is skipped; see the guard above).
        with open(os.path.join(job_root, "config.pkl"), "wb") as f:
            pickle.dump(cfg, f)

    # Last step: Try to concatenate all results into a single HDF file.
    # This might error out depending on the joint size of results and the
    # available memory on the current machine.
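A hedged sketch of reading back the dumped config, assuming only the `config.pkl` layout established in `submit_jobs` above (`load_config` is a hypothetical helper, not part of the repo):

```python
import pickle
from pathlib import Path


def load_config(job_root: Path):
    """Load the config that `submit_jobs` dumped at the job root."""
    with open(job_root / "config.pkl", "rb") as f:
        return pickle.load(f)
```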
11 changes: 6 additions & 5 deletions nets/samplers/base.py
@@ -232,14 +232,15 @@ def relabel_sequence(key, labels):
            relabeling = jnp.eye(n)[perm]
            return (onehot_labels @ relabeling).argmax(axis=-1)

        def do_not_relabel_sequence(key, labels):
            del key
            return labels

        # TODO(eringrant): Satisfy type-checker but avoid branching.
        if relabel_sequences:
            self.relabel_sequences = jax.jit(jax.vmap(relabel_sequence))
        else:

            def identity(x):
                return x

            self.relabel_sequences = jax.jit(jax.vmap(identity))
            self.relabel_sequences = jax.jit(jax.vmap(do_not_relabel_sequence))

        # PRNG depends on `MAX_NUM_SEQS` parameter in the infinite `Sampler` case.
        self._seq_keys = jax.random.split(key, num_seqs or MAX_NUM_SEQS)
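A small, self-contained sketch of the permutation trick in `relabel_sequence` (toy values; `n` is the number of classes and `perm` a random permutation, matching the names above):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
labels = jnp.array([0, 2, 1, 0])  # toy label sequence
n = 3                             # number of classes
perm = jax.random.permutation(key, n)

onehot_labels = jax.nn.one_hot(labels, n)  # (seq_len, n)
relabeling = jnp.eye(n)[perm]              # permutation matrix
# Each label l is consistently remapped to perm[l] across the sequence.
print((onehot_labels @ relabeling).argmax(axis=-1))
```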