Skip to content

Commit

Permalink
move experiment runs to a separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
eringrant committed Sep 16, 2023
1 parent af8ea07 commit 91e1889
Show file tree
Hide file tree
Showing 4 changed files with 325 additions and 180 deletions.
144 changes: 144 additions & 0 deletions experiments/in_context_learning/launcher_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Launcher for local runs of in-context learning simulations."""
import logging
import os
from pathlib import Path

from dataclasses import dataclass
from dataclasses import field

import jax
import optax

from nets.launch import configs
from nets.launch import submit
from nets.launch.hparams import Param
from nets.launch.hparams import EnumParam
from nets.launch.hparams import FixedParam

from nets.simulators import in_context_learning

from nets import datasets
from nets import samplers


@dataclass(frozen=True, kw_only=True)
class SearchConfig(configs.Config):
"""Generic config for a hyperparameter search."""

seed: Param = field(default_factory=lambda: EnumParam(range(0, 3)))

# Model params.
embed_dim: Param = field(init=False)
num_heads: Param = field(init=False)
depth: Param = field(init=False)
mlp_ratio: Param = field(init=False)
causal: Param = field(init=False)

# Training and evaluation params.
optimizer_fn: Param = field(default_factory=lambda: FixedParam(optax.adam))
learning_rate: Param = field(default_factory=lambda: FixedParam(1e-3))
train_batch_size: Param = field(default_factory=lambda: FixedParam(32))
eval_batch_size: Param = field(default_factory=lambda: FixedParam(32))
num_epochs: Param = field(default_factory=lambda: FixedParam(1))
evaluations_per_epoch: Param = field(default_factory=lambda: FixedParam(100))
evaluate_on_test_split: Param = field(default_factory=lambda: FixedParam(False))

# Dataset params.
num_train_classes: Param = field(init=False) # `init=False` toavoid init
num_valid_classes: Param = field(init=False) # of to-be-overridden value.
num_test_classes: Param = field(init=False)
prop_train_labels: Param = field(init=False)
prop_valid_labels: Param = field(init=False)
prop_test_labels: Param = field(init=False)
dataset_cls: Param = field(init=False)
exemplar_labeling: Param = field(init=False)
holdout_class_labeling: Param = field(init=False)
num_exemplars_per_class: Param = field(init=False)
exemplar_noise_scale: Param = field(init=False)

# Sampler params.
num_train_seqs: Param = field(init=False)
num_eval_seqs: Param = field(init=False)
train_sampler_cls: Param = field(init=False)
eval_sampler_cls: Param = field(init=False)
train_query_type: Param = field(init=False)
train_context_len: Param = field(init=False)
train_zipf_exponent: Param = field(init=False)
train_relabeling: Param = field(init=False)


@dataclass(frozen=True, kw_only=True)
class DebugSearchConfig(SearchConfig):
"""Singleton config for debugging."""

seed: Param = field(default_factory=lambda: FixedParam(0))

# No training.
num_epochs: Param = field(default_factory=lambda: FixedParam(0))
evaluations_per_epoch: Param = field(default_factory=lambda: FixedParam(1))

# Teeny tiny model.
embed_dim: Param = field(default_factory=lambda: FixedParam(8))
num_heads: Param = field(default_factory=lambda: FixedParam(8))
depth: Param = field(default_factory=lambda: FixedParam(2))
mlp_ratio: Param = field(default_factory=lambda: FixedParam(4.0))
causal: Param = field(default_factory=lambda: FixedParam(True))

num_train_classes: Param = field(default_factory=lambda: FixedParam(80))
num_valid_classes: Param = field(default_factory=lambda: FixedParam(20))
num_test_classes: Param = field(default_factory=lambda: FixedParam(16))
prop_train_labels: Param = field(default_factory=lambda: FixedParam(0.8))
prop_valid_labels: Param = field(default_factory=lambda: FixedParam(0.7))
prop_test_labels: Param = field(default_factory=lambda: FixedParam(0.3))
dataset_cls: Param = field(
default_factory=lambda: FixedParam(datasets.SymbolicDataset)
)
exemplar_labeling: Param = field(
default_factory=lambda: FixedParam(datasets.ExemplarLabeling.STANDARD)
)
holdout_class_labeling: Param = field(
default_factory=lambda: FixedParam(datasets.HoldoutClassLabeling.STANDARD)
)
num_exemplars_per_class: Param = field(default_factory=lambda: FixedParam(20))
exemplar_noise_scale: Param = field(default_factory=lambda: FixedParam(1.0))

num_train_seqs: Param = field(default_factory=lambda: FixedParam(int(1e3)))
num_eval_seqs: Param = field(default_factory=lambda: FixedParam(int(1e2)))
train_sampler_cls: Param = field(
default_factory=lambda: FixedParam(samplers.DirichletMultinomialSampler)
)
eval_sampler_cls: Param = field(
default_factory=lambda: FixedParam(samplers.DirichletMultinomialSampler)
)
train_query_type: Param = field(
default_factory=lambda: FixedParam(samplers.QueryType.SUPPORTED)
)
train_context_len: Param = field(default_factory=lambda: FixedParam(2))
train_zipf_exponent: Param = field(default_factory=lambda: FixedParam(1.0))
train_relabeling: Param = field(default_factory=lambda: FixedParam(False))


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

executor = submit.get_submitit_executor(
cluster="local",
log_dir=Path(
"/tmp",
os.environ["USER"],
"in-ctx",
submit.get_timestamp(),
),
)

jobs = executor.map_array(
lambda kwargs: in_context_learning.simulate(
**kwargs,
),
DebugSearchConfig(
key=jax.random.PRNGKey(0),
num_configs=1,
),
)
result = jobs[0].result()
print(result)
154 changes: 154 additions & 0 deletions experiments/in_context_learning/launcher_slurm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Launcher for SLURM runs of in-context learning simulations."""
import logging
import os
from pathlib import Path

from dataclasses import dataclass
from dataclasses import field

import jax
import optax

from nets.launch import configs
from nets.launch import submit
from nets.launch.hparams import Param
from nets.launch.hparams import EnumParam
from nets.launch.hparams import FixedParam
from nets.launch.hparams import LogUniformParam
from nets.launch.hparams import UniformParam

from nets import datasets
from nets import samplers


@dataclass(frozen=True, kw_only=True)
class SearchConfig(configs.Config):
"""Generic config for a hyperparameter search."""

seed: Param = field(default_factory=lambda: EnumParam(range(0, 3)))

# Model params.
embed_dim: Param = field(init=False)
num_heads: Param = field(init=False)
depth: Param = field(init=False)
mlp_ratio: Param = field(init=False)
causal: Param = field(init=False)

# Training and evaluation params.
optimizer_fn: Param = field(default_factory=lambda: FixedParam(optax.adam))
learning_rate: Param = field(default_factory=lambda: FixedParam(1e-3))
train_batch_size: Param = field(default_factory=lambda: FixedParam(32))
eval_batch_size: Param = field(default_factory=lambda: FixedParam(32))
num_epochs: Param = field(default_factory=lambda: FixedParam(1))
evaluations_per_epoch: Param = field(default_factory=lambda: FixedParam(100))
evaluate_on_test_split: Param = field(default_factory=lambda: FixedParam(False))

# Dataset params.
num_train_classes: Param = field(init=False) # `init=False` toavoid init
num_valid_classes: Param = field(init=False) # of to-be-overridden value.
num_test_classes: Param = field(init=False)
prop_train_labels: Param = field(init=False)
prop_valid_labels: Param = field(init=False)
prop_test_labels: Param = field(init=False)
dataset_cls: Param = field(init=False)
exemplar_labeling: Param = field(init=False)
holdout_class_labeling: Param = field(init=False)
num_exemplars_per_class: Param = field(init=False)
exemplar_noise_scale: Param = field(init=False)

# Sampler params.
num_train_seqs: Param = field(init=False)
num_eval_seqs: Param = field(init=False)
train_sampler_cls: Param = field(init=False)
eval_sampler_cls: Param = field(init=False)
train_query_type: Param = field(init=False)
train_context_len: Param = field(init=False)
train_zipf_exponent: Param = field(init=False)
train_relabeling: Param = field(init=False)


@dataclass(frozen=True, kw_only=True)
class SymbolicSearchConfig(SearchConfig):
"""Singleton hyperparameter search for the symbolic dataset."""

evaluations_per_epoch: Param = field(default_factory=lambda: FixedParam(100))

embed_dim: Param = field(default_factory=lambda: FixedParam(64))
num_heads: Param = field(default_factory=lambda: FixedParam(8))
depth: Param = field(default_factory=lambda: FixedParam(2))
mlp_ratio: Param = field(default_factory=lambda: FixedParam(4.0))
causal: Param = field(default_factory=lambda: FixedParam(True))

num_train_classes: Param = field(default_factory=lambda: FixedParam(1600))
num_valid_classes: Param = field(default_factory=lambda: FixedParam(2))
num_test_classes: Param = field(default_factory=lambda: FixedParam(2))
prop_train_labels: Param = field(default_factory=lambda: FixedParam(1.0))
prop_valid_labels: Param = field(default_factory=lambda: FixedParam(1.0))
prop_test_labels: Param = field(default_factory=lambda: FixedParam(1.0))
dataset_cls: Param = field(
default_factory=lambda: FixedParam(datasets.SymbolicDataset)
)
exemplar_labeling: Param = field(
default_factory=lambda: FixedParam(datasets.ExemplarLabeling.STANDARD)
)
holdout_class_labeling: Param = field(
default_factory=lambda: FixedParam(datasets.HoldoutClassLabeling.TRAIN_LABELS)
)
num_exemplars_per_class: Param = field(default_factory=lambda: FixedParam(20))
exemplar_noise_scale: Param = field(default_factory=lambda: FixedParam(0.1))

num_train_seqs: Param = field(default_factory=lambda: FixedParam(int(1e5 * 32)))
num_eval_seqs: Param = field(default_factory=lambda: FixedParam(int(1e2 * 32)))
train_sampler_cls: Param = field(
default_factory=lambda: FixedParam(samplers.DirichletMultinomialSampler)
)
eval_sampler_cls: Param = field(
default_factory=lambda: FixedParam(samplers.DirichletMultinomialSampler)
)
train_query_type: Param = field(
default_factory=lambda: FixedParam(samplers.QueryType.SUPPORTED)
)
train_context_len: Param = field(default_factory=lambda: FixedParam(2))
train_zipf_exponent: Param = field(default_factory=lambda: FixedParam(1.0))
train_relabeling: Param = field(default_factory=lambda: FixedParam(True))


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

job_folder = Path(
"/tmp",
os.environ["USER"],
"in-ctx",
submit.get_timestamp(),
)

executor = submit.get_submitit_executor(
log_dir=job_folder,
cluster="slurm",
#
### GPU mode. ###
slurm_partition="gpu",
slurm_parallelism=50,
#
### CPU mode. ###
# slurm_partition="cpu",
# gpus_per_node=0,
#
# 24-hour time limit per job.
timeout_min=60 * 24,
)

# Change config here.
cfg = SymbolicSearchConfig(
key=jax.random.PRNGKey(0),
num_configs=500,
seed=UniformParam(0, (1 << 15) - 1),
embed_dim=EnumParam((16, 32, 64)),
num_train_classes=LogUniformParam(20, 2000, base=10),
prop_train_labels=UniformParam(0.25, 1.0),
num_exemplars_per_class=LogUniformParam(1, 1000, base=10),
exemplar_noise_scale=LogUniformParam(1e-1, 1e3, base=10),
)

jobs = submit.submit_jobs(executor, cfg)
Loading

0 comments on commit 91e1889

Please sign in to comment.