Commit: running OSGD

eringrant committed Sep 17, 2023
1 parent 46bd732 commit 5f1b84f
Showing 7 changed files with 188 additions and 259 deletions.
2 changes: 2 additions & 0 deletions nets/datasets/__init__.py
@@ -4,11 +4,13 @@
from .base import HoldoutClassLabeling
from .base import Dataset
from .symbolic import SymbolicDataset
from .parity import ParityDataset

__all__ = (
"DatasetSplit",
"ExemplarLabeling",
"HoldoutClassLabeling",
"Dataset",
"SymbolicDataset",
"ParityDataset",
)
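
This commit switches the experiment onto the new `ParityDataset`. Its implementation in `nets/datasets/parity.py` is not shown in this diff, so the following is only a minimal sketch of the parity task it presumably wraps, with dimensions, exemplar count, and noise scale mirroring the debug config later in this commit (`parity_exemplars` is a hypothetical helper, not repository code):

```python
import jax
import jax.numpy as jnp


def parity_exemplars(key, num_dimensions=2, num_exemplars=16, noise_scale=0.1):
  """Sample noisy binary exemplars labeled by their parity (XOR of bits)."""
  bits_key, noise_key = jax.random.split(key)
  bits = jax.random.bernoulli(bits_key, 0.5, (num_exemplars, num_dimensions))
  labels = jnp.sum(bits, axis=-1) % 2  # Class 0 or 1: the parity of the bits.
  noise = noise_scale * jax.random.normal(noise_key, bits.shape)
  return bits.astype(jnp.float32) + noise, labels


examples, labels = parity_exemplars(jax.random.PRNGKey(0))
```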
72 changes: 13 additions & 59 deletions nets/experiments/analyzable_online_sgd/launcher_local.py
@@ -27,11 +27,8 @@ class SearchConfig(configs.Config):
seed: Param = field(default_factory=lambda: EnumParam(range(0, 3)))

# Model params.
embed_dim: Param = field(init=False)
num_heads: Param = field(init=False)
depth: Param = field(init=False)
mlp_ratio: Param = field(init=False)
causal: Param = field(init=False)
num_hiddens: Param = field(init=False)
init_scale: Param = field(init=False)

# Training and evaluation params.
optimizer_fn: Param = field(default_factory=lambda: FixedParam(optax.adam))
@@ -43,27 +40,13 @@ class SearchConfig(configs.Config):
evaluate_on_test_split: Param = field(default_factory=lambda: FixedParam(False))

# Dataset params.
num_train_classes: Param = field(init=False) # `init=False` to avoid init
num_valid_classes: Param = field(init=False) # of the to-be-overridden value.
num_test_classes: Param = field(init=False)
prop_train_labels: Param = field(init=False)
prop_valid_labels: Param = field(init=False)
prop_test_labels: Param = field(init=False)
dataset_cls: Param = field(init=False)
exemplar_labeling: Param = field(init=False)
holdout_class_labeling: Param = field(init=False)
num_dimensions: Param = field(init=False)
num_exemplars_per_class: Param = field(init=False)
exemplar_noise_scale: Param = field(init=False)

# Sampler params.
num_train_seqs: Param = field(init=False)
num_eval_seqs: Param = field(init=False)
train_sampler_cls: Param = field(init=False)
eval_sampler_cls: Param = field(init=False)
train_query_type: Param = field(init=False)
train_context_len: Param = field(init=False)
train_zipf_exponent: Param = field(init=False)
train_relabeling: Param = field(init=False)
sampler_cls: Param = field(init=False)


@dataclass(frozen=True, kw_only=True)
@@ -77,44 +60,14 @@ class DebugSearchConfig(SearchConfig):
evaluations_per_epoch: Param = field(default_factory=lambda: FixedParam(1))

# Teeny tiny model.
embed_dim: Param = field(default_factory=lambda: FixedParam(8))
num_heads: Param = field(default_factory=lambda: FixedParam(8))
depth: Param = field(default_factory=lambda: FixedParam(2))
mlp_ratio: Param = field(default_factory=lambda: FixedParam(4.0))
causal: Param = field(default_factory=lambda: FixedParam(True))

num_train_classes: Param = field(default_factory=lambda: FixedParam(80))
num_valid_classes: Param = field(default_factory=lambda: FixedParam(20))
num_test_classes: Param = field(default_factory=lambda: FixedParam(16))
prop_train_labels: Param = field(default_factory=lambda: FixedParam(0.8))
prop_valid_labels: Param = field(default_factory=lambda: FixedParam(0.7))
prop_test_labels: Param = field(default_factory=lambda: FixedParam(0.3))
dataset_cls: Param = field(
default_factory=lambda: FixedParam(datasets.SymbolicDataset)
)
exemplar_labeling: Param = field(
default_factory=lambda: FixedParam(datasets.ExemplarLabeling.STANDARD)
)
holdout_class_labeling: Param = field(
default_factory=lambda: FixedParam(datasets.HoldoutClassLabeling.STANDARD)
)
num_exemplars_per_class: Param = field(default_factory=lambda: FixedParam(20))
exemplar_noise_scale: Param = field(default_factory=lambda: FixedParam(1.0))
num_hiddens: Param = field(default_factory=lambda: FixedParam(8))
init_scale: Param = field(default_factory=lambda: FixedParam(1.0))

num_train_seqs: Param = field(default_factory=lambda: FixedParam(int(1e3)))
num_eval_seqs: Param = field(default_factory=lambda: FixedParam(int(1e2)))
train_sampler_cls: Param = field(
default_factory=lambda: FixedParam(samplers.DirichletMultinomialSampler)
)
eval_sampler_cls: Param = field(
default_factory=lambda: FixedParam(samplers.DirichletMultinomialSampler)
)
train_query_type: Param = field(
default_factory=lambda: FixedParam(samplers.QueryType.SUPPORTED)
)
train_context_len: Param = field(default_factory=lambda: FixedParam(2))
train_zipf_exponent: Param = field(default_factory=lambda: FixedParam(1.0))
train_relabeling: Param = field(default_factory=lambda: FixedParam(False))
dataset_cls: Param = field(default_factory=lambda: FixedParam(datasets.ParityDataset))
num_dimensions: Param = field(default_factory=lambda: FixedParam(2))
num_exemplars_per_class: Param = field(default_factory=lambda: FixedParam(16))
exemplar_noise_scale: Param = field(default_factory=lambda: FixedParam(0.1))
sampler_cls: Param = field(default_factory=lambda: FixedParam(samplers.EpochSampler))


if __name__ == "__main__":
@@ -124,7 +77,7 @@ class DebugSearchConfig(SearchConfig):
cluster="debug",
log_dir=Path(
nets.SCRATCH_DIR,
"in-ctx",
"osgd",
submit.get_timestamp(),
),
gpus_per_node=0,
@@ -134,6 +87,7 @@ class DebugSearchConfig(SearchConfig):
executor=executor,
fn=simulate,
cfg=DebugSearchConfig(
num_epochs=FixedParam(10),
key=jax.random.PRNGKey(0),
num_configs=1,
),
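
Aside on the pattern above: the base `SearchConfig` declares each hyperparameter as `field(init=False)`, and concrete subclasses such as `DebugSearchConfig` re-declare the field with a `default_factory` to pin a value. A stripped-down sketch of why this works with frozen, keyword-only dataclasses (`FixedParam` and the config classes below are simplified stand-ins, not the real `nets` classes):

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class FixedParam:
  """Stand-in for the repository's `FixedParam`: a single pinned value."""
  value: object


@dataclass(frozen=True, kw_only=True)
class BaseSearchConfig:
  # Declared but deliberately uninitializable: subclasses must override.
  num_hiddens: FixedParam = field(init=False)


@dataclass(frozen=True, kw_only=True)
class DebugConfig(BaseSearchConfig):
  # Re-declaring the field with a `default_factory` supplies the value.
  num_hiddens: FixedParam = field(default_factory=lambda: FixedParam(8))


assert DebugConfig().num_hiddens.value == 8
```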
6 changes: 5 additions & 1 deletion nets/models/__init__.py
@@ -1,4 +1,8 @@
"""Neural network models."""
from .transformers import SequenceClassifier
from .feedforward import MLP

__all__ = ("SequenceClassifier",)
__all__ = (
"MLP",
"SequenceClassifier",
)
18 changes: 14 additions & 4 deletions nets/models/feedforward.py
@@ -30,7 +30,9 @@ def trunc_normal_init(

# Adapted from https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/initializers.py.
def lecun_normal_init(
weight: Array, key: KeyArray, scale: float | None = None
weight: Array,
key: KeyArray,
scale: float = 1.0,
) -> Array:
"""LeCun (variance-scaling) normal distribution initialization."""
_, in_ = weight.shape
@@ -66,7 +68,7 @@ def __init__(
trainable: bool = True,
*,
key: KeyArray,
init_scale: float | None = 1.0,
init_scale: float = 1.0,
):
"""Initialize a linear layer."""
super().__init__(
@@ -107,6 +109,7 @@ def __init__(
drop: float | tuple[float] = 0.0,
*,
key: KeyArray = None,
init_scale: float = 1.0,
):
"""Initialize an MLP.
@@ -118,6 +121,7 @@
drop: The probability associated with `Dropout`.
key: A `jax.random.PRNGKey` used to provide randomness for parameter
initialisation.
init_scale: The scale of the variance of the initial weights.
"""
super().__init__()
out_features = out_features or in_features
@@ -126,12 +130,18 @@
keys = jrandom.split(key, 2)

self.fc1 = Linear(
in_features=in_features, out_features=hidden_features, key=keys[0]
in_features=in_features,
out_features=hidden_features,
key=keys[0],
init_scale=init_scale,
)
self.act = act
self.drop1 = enn.Dropout(drop_probs[0])
self.fc2 = Linear(
in_features=hidden_features, out_features=out_features, key=keys[1]
in_features=hidden_features,
out_features=out_features,
key=keys[1],
init_scale=init_scale,
)
self.drop2 = enn.Dropout(drop_probs[1])

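
The hunk above truncates `lecun_normal_init` right after the fan-in lookup. As a sketch, a variance-scaling initializer typically consumes the now-required `scale` argument as below; the exact body is an assumption, adapted from the dm-haiku initializer credited in the file:

```python
import jax
import jax.numpy as jnp


def lecun_normal_init(weight, key, scale=1.0):
  """LeCun (variance-scaling) normal initialization: Var[w] = scale / fan_in."""
  _, in_ = weight.shape
  stddev = jnp.sqrt(scale / in_)
  # dm-haiku draws from a truncated normal and corrects its variance; a plain
  # normal is used here for brevity.
  return stddev * jax.random.normal(key, weight.shape)


# Hypothetical usage of the `MLP` in this file (signature inferred from the
# diff): `init_scale` now reaches both `Linear` sublayers.
# mlp = MLP(in_features=2, hidden_features=8, key=jax.random.PRNGKey(0), init_scale=0.1)
```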
2 changes: 2 additions & 0 deletions nets/samplers/__init__.py
@@ -3,6 +3,7 @@
from .base import Sampler
from .base import SequenceSampler
from .base import SingletonSampler
from .base import EpochSampler
from .base import ClassificationSequenceSampler
from .dirichlet_multinomial import DirichletMultinomialSampler

@@ -11,6 +12,7 @@
"Sampler",
"SequenceSampler",
"SingletonSampler",
"EpochSampler",
"ClassificationSequenceSampler",
"DirichletMultinomialSampler",
)
63 changes: 63 additions & 0 deletions nets/samplers/base.py
@@ -160,6 +160,69 @@ class SingletonSampler(Sampler):
"""Sampler of a sequence of examples."""


# TODO(eringrant): Is this faster than a recursive call to `__getitem__`?
def slice_to_array(s: slice, array_length: int) -> Array:
"""Convert a `slice` object to an array of indices."""
start = s.start if s.start is not None else 0
stop = s.stop if s.stop is not None else array_length
step = s.step if s.step is not None else 1

return jnp.arange(start, stop, step)


class EpochSampler(SingletonSampler):
"""Sampler of example-label pairs over multiple epochs."""

def __init__(
self,
key: KeyArray,
dataset: Dataset,
num_epochs: int | None = None,
):
"""Sampler of example-label pairs over multiple epochs."""
self.key = key
self.dataset = dataset
self.num_epochs = num_epochs
self.epoch_count = 0
self.index_in_epoch = 0

self.dataset_size = len(self.dataset)

def __len__(self) -> int:
"""Return the number of example-label pairs in `Sampler`."""
if self.num_epochs is None:
  # `int(float("inf"))` raises `OverflowError`; an unbounded sampler has no finite length.
  raise TypeError("`len()` is undefined when `num_epochs` is None.")
return self.num_epochs * self.dataset_size

def __getitem__(self, index: int | slice) -> ExemplarType:
"""Return exemplar-class pairs at index `index` of `Sampler`."""
# TODO(eringrant): Simplify this while maintaining type-validity.
if isinstance(index, slice):
transformed_index = slice_to_array(index, len(self))
else:
transformed_index = index

epoch_idx = transformed_index // self.dataset_size
if not isinstance(epoch_idx, int):
unique_vals = jnp.unique(epoch_idx)
if unique_vals.size != 1:
# TODO(eringrant): Implement this case.
raise ValueError("Array should contain only one unique value.")
epoch_idx = unique_vals[0]
index_in_epoch = transformed_index % self.dataset_size

if self.num_epochs is not None and epoch_idx >= self.num_epochs:
  # `IndexError`, not `StopIteration`, is the sequence-protocol signal for exhaustion.
  raise IndexError("Reached the end of data generation.")

epoch_key = jax.random.fold_in(self.key, epoch_idx)
permuted_index = jax.random.permutation(
epoch_key,
jnp.arange(self.dataset_size),
)[index_in_epoch]

return self.dataset[permuted_index]


class SequenceSampler(Sampler):
"""Sampler of context + query sequences for in-context learning."""

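
To make the new sampler concrete, here is a usage sketch. It assumes only that a `Dataset` supports `len()` and (array) indexing; `ToyDataset` is a hypothetical stand-in for the real `nets.datasets` interface:

```python
import jax
import jax.numpy as jnp


class ToyDataset:
  """Hypothetical stand-in: ten scalar exemplars with binary labels."""

  def __init__(self):
    self.x = jnp.arange(10.0)
    self.y = jnp.arange(10) % 2

  def __len__(self):
    return 10

  def __getitem__(self, i):
    return self.x[i], self.y[i]


sampler = EpochSampler(key=jax.random.PRNGKey(0), dataset=ToyDataset(), num_epochs=2)

# `jax.random.fold_in(key, epoch_idx)` derives an independent key per epoch, so
# each epoch is a different but reproducible permutation of the same exemplars.
first_epoch_x, first_epoch_y = sampler[0:10]
second_epoch_x, second_epoch_y = sampler[10:20]
```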