Skip to content

Commit

Permalink
initial code dump
Browse files Browse the repository at this point in the history
  • Loading branch information
eringrant committed Sep 14, 2023
1 parent 6f735d4 commit 2de44b0
Show file tree
Hide file tree
Showing 16 changed files with 2,703 additions and 1 deletion.
13 changes: 13 additions & 0 deletions nets/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from .base import DatasetSplit
from .base import ExemplarLabeling
from .base import HoldoutClassLabeling
from .base import Dataset
from .symbolic import SymbolicDataset

# Public API of the `nets.datasets` package.
__all__ = (
  "DatasetSplit",
  "ExemplarLabeling",
  "HoldoutClassLabeling",
  "Dataset",
  "SymbolicDataset",
)
245 changes: 245 additions & 0 deletions nets/datasets/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
"""`Dataset`s are sequences of unique examples."""
from typing import Any
from typing import Sequence
from typing import Tuple
from typing import Union
from nptyping import NDArray
from nptyping import Bool
from nptyping import Floating
from nptyping import Int
from jax.random import KeyArray
from jaxtyping import Array

from enum import Enum
from enum import unique
from functools import cached_property
from functools import partial
import numpy as np
from pathlib import Path

import jax
import jax.numpy as jnp
import jax.nn as jnn


# Type hints.
# An index into a `Dataset`: a single position, a sequence of positions, or a slice.
IndexType = Union[int, Sequence[int], slice]
# A single batch drawn from a `Dataset`: (float-valued exemplars, integer labels).
ExemplarType = Tuple[NDArray[Any, Floating], NDArray[Any, Int]]


@unique
class ExemplarLabeling(Enum):
  """How to assign class labels to exemplars drawn from the underlying dataset."""

  # Use the original labels from the dataset.
  STANDARD = 1
  # Remove all but the first exemplar from each class.
  SINGLE = 2
  # Assign each exemplar a unique label, resulting in as many classes as there
  # are exemplars.
  SEPARATED = 3


@unique
class HoldoutClassLabeling(Enum):
  """How to assign class labels to the holdout (validation/test) splits."""

  # Use the original labels from the dataset.
  STANDARD = 1
  # Relabel validation and test classes with labels from the training set.
  TRAIN_LABELS = 2


@unique
class DatasetSplit(Enum):
  """Which split of the underlying dataset to use."""

  TRAIN = 1
  VALID = 2
  TEST = 3
  ALL = 4


def wrap_labels(labels: Array, num_classes: int, modulus: Array) -> Array:
  """Map each of `labels` to its wrapped label via the one-hot map `modulus`.

  Args:
    labels: Integer class labels in `[0, num_classes)`.
    num_classes: Number of underlying classes (rows of `modulus`).
    modulus: A `(num_classes, num_classes)` matrix whose row `i` one-hot
      encodes the wrapped label for class `i`.

  Returns:
    An array of wrapped integer labels, same shape as `labels`.
  """
  encoded = jnn.one_hot(labels, num_classes)
  selected_rows = encoded @ modulus
  return selected_rows.argmax(axis=-1)


def get_wrapped_indices(
  prop_labels: float, num_classes: int, offset: int = 0
) -> Tuple[int, Array]:
  """Get indices to wrap `num_classes` into `prop_labels` labels.

  Args:
    prop_labels: Size of the label set as a proportion of `num_classes`.
      If 1.0, classes map one-to-one onto labels; if < 1.0, classes are
      wrapped onto the smaller label set in increasing order (modulo).
    num_classes: Number of underlying classes to assign labels to.
    offset: Constant added to every index, used to place these labels
      after those of a preceding split.

  Returns:
    A tuple of (number of labels, integer array of length `num_classes`
    mapping each class to its label index, shifted by `offset`).
  """
  if prop_labels < 1.0:
    # Guard against a zero modulus when `prop_labels * num_classes < 1`:
    # integer mod-by-zero is undefined under XLA, so use at least 1 label.
    num_labels = max(int(prop_labels * num_classes), 1)
    indices = jnp.arange(num_classes) % num_labels
  else:
    num_labels = num_classes
    indices = jnp.arange(num_classes)
  indices += offset
  return num_labels, indices


class Dataset:
  """A `Dataset` of class exemplars from which to draw sequences.

  This base class partitions classes into train/validation/test splits and
  builds `wrap_labels`, which maps underlying class labels onto a (possibly
  smaller) label set. Subclasses populate `_exemplars` and `_labels` and
  implement `exemplar_shape` and `__getitem__`.
  """

  # Exemplars, either as paths to on-disk data or as an in-memory array.
  _exemplars: Union[Sequence[Path], NDArray]
  # Underlying class label for each exemplar, aligned with `_exemplars`.
  _labels: NDArray

  # NOTE(review): `prop_{train,test,valid}_labels` are declared here but
  # `__init__` below does not assign them as instance attributes — confirm
  # whether they are meant to be stored.
  num_train_classes: int
  prop_train_labels: float
  num_test_classes: int
  prop_test_labels: float
  num_valid_classes: int
  prop_valid_labels: float

  def __init__(
    self,
    key: KeyArray,
    split: DatasetSplit,
    exemplar_labeling: ExemplarLabeling,
    holdout_class_labeling: HoldoutClassLabeling,
    num_train_classes: int,
    prop_train_labels: float,
    num_test_classes: int,
    prop_test_labels: float,
    num_valid_classes: int = 0,
    prop_valid_labels: float = 0,
    num_exemplars_per_class: int = 400,
    exemplar_noise_scale: float = 1e-2,
  ):
    """A `Dataset` of class exemplars from which to draw sequences.

    Args:
      key: A key for randomness in sampling.
      split: Which split of the underlying dataset to use.
      exemplar_labeling: How to assign class labels to exemplars from the
        underlying dataset (reproduction of
        https://github.com/deepmind/emergent_in_context_learning/blob/main/datasets/data_generators.py#L60-L65).
      holdout_class_labeling: How to assign class labels to holdout
        (validation and testing) splits of this `Dataset`.
      num_{train,test,valid}_classes: Number of {training, testing, validation}
        classes in this `Dataset`.
      prop_{train,test,valid}_labels: Size of the {training, testing,
        validation} label set proportional to the underlying class set. If 1.0,
        then labels are identical to the underlying class labels; if < 1.0,
        then labels are wrapped in increasing order.
      num_exemplars_per_class: Number of exemplars per class to draw from the
        underlying dataset.
      exemplar_noise_scale: The scale of noise to add to each additional
        exemplar.

    Raises:
      ValueError: If holdout relabeling with train labels is requested but
        the train label set is smaller than a holdout label set, or if any
        of `prop_{train,valid,test}_labels` falls outside (0.0, 1.0].
    """

    self.num_train_classes = num_train_classes
    self.num_valid_classes = num_valid_classes
    self.num_test_classes = num_test_classes
    self.num_exemplars_per_class = num_exemplars_per_class
    self.exemplar_noise_scale = exemplar_noise_scale

    if holdout_class_labeling == HoldoutClassLabeling.TRAIN_LABELS:
      # Holdout splits reuse training labels, so the train label set must be
      # at least as large as each of the holdout label sets.
      if (
        prop_train_labels * num_train_classes < prop_valid_labels * num_valid_classes
        or prop_train_labels * num_train_classes < prop_test_labels * num_test_classes
      ):
        raise ValueError(
          "Relabeling of validation and test sets with train "
          "labels usually assumes more train classes than "
          "validation and test classes, but "
          f"{prop_train_labels * num_train_classes} < "
          f"{prop_valid_labels * num_valid_classes} or "
          f"{prop_train_labels * num_train_classes} < "
          f"{prop_test_labels * num_test_classes}."
        )

      # All observable labels come from the training split.
      self.num_observed_classes = int(prop_train_labels * self.num_train_classes)

    else:
      # Each split contributes its own disjoint label set.
      self.num_observed_classes = (
        int(prop_train_labels * self.num_train_classes)
        + int(prop_valid_labels * self.num_valid_classes)
        + int(prop_test_labels * self.num_test_classes)
      )

    # NOTE(review): The default `prop_valid_labels=0` fails this check —
    # confirm whether 0.0 should be permitted when `num_valid_classes == 0`.
    if not all(
      0.0 < p <= 1.0
      for p in (
        prop_train_labels,
        prop_valid_labels,
        prop_test_labels,
      )
    ):
      raise ValueError(
        "One of `prop_{train,valid,test}_labels` was invalid: "
        f"{prop_train_labels}, {prop_valid_labels}, {prop_test_labels}."
      )

    # Class-to-label index maps for each split. Holdout splits are offset
    # past the preceding splits' labels unless they reuse training labels.
    num_train_labels, train_indices = get_wrapped_indices(
      prop_train_labels, num_train_classes
    )
    num_valid_labels, valid_indices = get_wrapped_indices(
      prop_valid_labels,
      num_valid_classes,
      offset=0
      if holdout_class_labeling == HoldoutClassLabeling.TRAIN_LABELS
      else num_train_labels,
    )
    num_test_labels, test_indices = get_wrapped_indices(
      prop_test_labels,
      num_test_classes,
      offset=0
      if holdout_class_labeling == HoldoutClassLabeling.TRAIN_LABELS
      else num_train_labels + num_valid_labels,
    )

    # Row `i` of `modulus` one-hot encodes the wrapped label for class `i`,
    # so `wrap_labels` reduces to a matrix product followed by an argmax.
    indices = jnp.concatenate((train_indices, valid_indices, test_indices))
    modulus = jnp.eye(self.num_classes, dtype=int)[indices, :]

    # Pre-compile the label-wrapping function with the label map baked in.
    self.wrap_labels = jax.jit(
      partial(
        wrap_labels,
        num_classes=self.num_classes,
        modulus=modulus,
      )
    )

  def __len__(self) -> int:
    """Return the total number of exemplars in this `Dataset`."""
    return len(self._exemplars)

  @property
  def num_classes(self) -> int:
    """Total number of classes across all splits."""
    return self.num_train_classes + self.num_valid_classes + self.num_test_classes

  @property
  def exemplar_shape(self) -> Tuple[int]:
    """Shape of a single exemplar; provided by the subclass."""
    raise NotImplementedError("To be implemented by the subclass.")

  def __getitem__(self, index: Union[int, slice]) -> ExemplarType:
    """Return the (exemplars, labels) at `index`; provided by the subclass."""
    raise NotImplementedError("To be implemented by the subclass.")

  @cached_property
  def unique_classes(self) -> Sequence[int]:
    """Deterministic ordering of dataset class labels."""
    return np.unique(self._labels).tolist()

  @cached_property
  def train_classes(self) -> Sequence[int]:
    """Class labels of the training split (first in the sorted ordering)."""
    i = self.num_train_classes
    return self.unique_classes[:i]

  @cached_property
  def valid_classes(self) -> Sequence[int]:
    """Class labels of the validation split (after the training classes)."""
    i = self.num_train_classes
    j = self.num_train_classes + self.num_valid_classes
    return self.unique_classes[i:j]

  @cached_property
  def test_classes(self) -> Sequence[int]:
    """Class labels of the test split (after the validation classes)."""
    j = self.num_train_classes + self.num_valid_classes
    k = self.num_train_classes + self.num_valid_classes + self.num_test_classes
    return self.unique_classes[j:k]

  @cached_property
  def _train_idx(self) -> NDArray[Any, Bool]:
    """Mask for the train split."""
    return np.in1d(self._labels, self.train_classes)

  @cached_property
  def _valid_idx(self) -> NDArray[Any, Bool]:
    """Mask for the validation split."""
    return np.in1d(self._labels, self.valid_classes)

  @cached_property
  def _test_idx(self) -> NDArray[Any, Bool]:
    """Mask for the test split."""
    return np.in1d(self._labels, self.test_classes)
126 changes: 126 additions & 0 deletions nets/datasets/symbolic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from typing import Tuple
from typing import Union
from jax.random import KeyArray

from functools import partial
import numpy as np

import jax
import jax.numpy as jnp
import jax.nn as jnn

from nets.datasets.base import Dataset
from nets.datasets.base import DatasetSplit
from nets.datasets.base import ExemplarLabeling
from nets.datasets.base import ExemplarType
from nets.datasets.base import HoldoutClassLabeling


class SymbolicDataset(Dataset):
  """A `Dataset` of one-hot symbolic exemplars.

  Each class is represented by the one-hot vector of its class index; when
  `num_exemplars_per_class > 1`, additional exemplars are sampled from a
  Gaussian centered on that one-hot vector.
  """

  # In-memory exemplars and their underlying class labels.
  _exemplars: np.ndarray
  _labels: np.ndarray
  num_train_classes: int
  num_valid_classes: int
  num_test_classes: int
  num_train_labels: int
  num_valid_labels: int
  num_test_labels: int

  def __init__(
    self,
    key: KeyArray,
    split: DatasetSplit,
    exemplar_labeling: ExemplarLabeling,
    holdout_class_labeling: HoldoutClassLabeling,
    num_train_classes: int,
    prop_train_labels: float,
    num_test_classes: int,
    prop_test_labels: float,
    num_valid_classes: int = 0,
    prop_valid_labels: float = 0,
    num_exemplars_per_class: int = 400,
    exemplar_noise_scale: float = 1e-2,
  ):
    """A `SymbolicDataset` of class exemplars from which to draw sequences.

    Args:
      ...`Dataset` args...
    """
    super().__init__(
      key=key,
      split=split,
      exemplar_labeling=exemplar_labeling,
      holdout_class_labeling=holdout_class_labeling,
      num_train_classes=num_train_classes,
      prop_train_labels=prop_train_labels,
      num_test_classes=num_test_classes,
      prop_test_labels=prop_test_labels,
      num_valid_classes=num_valid_classes,
      prop_valid_labels=prop_valid_labels,
      num_exemplars_per_class=num_exemplars_per_class,
      exemplar_noise_scale=exemplar_noise_scale,
    )

    # Exemplar generation for `SymbolicDataset`.
    labels = np.arange(self.num_classes)

    # Repeat each class label once per exemplar so `_labels` reads
    # [0, 0, ..., 1, 1, ..., C-1, C-1, ...].
    if num_exemplars_per_class > 1:
      labels = np.repeat(
        labels[:, np.newaxis], num_exemplars_per_class, axis=-1
      ).reshape(-1)

    # TODO(eringrant): Deal with this params.
    del exemplar_labeling

    self._labels = labels

    if self.num_exemplars_per_class > 1:
      # One PRNG key per exemplar, used to sample its noise in `__getitem__`.
      self._exemplar_keys = jax.random.split(
        key, self.num_classes * num_exemplars_per_class
      )

    # Compile functions for sampling at `Dataset.__init__`.
    # NOTE(review): the double `vmap` implies `generate_exemplar` expects
    # batched `key`/`mean` arguments of rank >= 2 — confirm against the
    # shapes produced in `__getitem__`.
    self.generate_exemplar = jax.jit(
      jax.vmap(
        jax.vmap(
          partial(
            jax.random.multivariate_normal,
            # Isotropic with scale a/C to keep noise level in embeddings constant.
            cov=exemplar_noise_scale / self.num_classes * jnp.eye(self.num_classes),
          )
        )
      )
    )

  @property
  def exemplar_shape(self) -> Tuple[int]:
    """Shape of a single exemplar: a length-`num_classes` one-hot vector."""
    return (self.num_classes,)

  def __getitem__(self, index: Union[int, slice]) -> ExemplarType:
    """Return the (exemplars, wrapped labels) at `index`.

    With a single exemplar per class, exemplars are exact one-hot vectors;
    otherwise each exemplar is drawn from a Gaussian centered on the
    one-hot vector of its class, using that exemplar's pre-split key.
    """
    labels = self._labels[index]
    onehot_labels = jnn.one_hot(labels, self.num_classes)

    if self.num_exemplars_per_class == 1:
      # Noise-free case: the exemplar is the one-hot vector itself.
      exemplars = onehot_labels

    else:
      exemplar_key = self._exemplar_keys[index]

      # TODO(eringrant): Deal with other `index` shapes.
      if isinstance(index, int):
        exemplar_key = jnp.expand_dims(exemplar_key, 0)  # type: ignore[arg-type]

      exemplars = self.generate_exemplar(
        key=exemplar_key,
        mean=onehot_labels,
      )

    # Map underlying class labels onto the (possibly wrapped) label set.
    labels = self.wrap_labels(labels)

    # NOTE(review): for an integer `index`, `labels` is a scalar after
    # wrapping, so `len(labels)` below would raise — confirm the intended
    # integer-indexing behavior (see the TODO above).
    if isinstance(index, int):
      assert len(exemplars) == 1 and len(labels) == 1
      exemplars = exemplars[0]
      labels = labels[0]

    return exemplars, labels
Empty file added nets/launch/__init__.py
Empty file.
Loading

0 comments on commit 2de44b0

Please sign in to comment.