generated from eringrant/coding-project-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
2,703 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from .base import DatasetSplit | ||
from .base import ExemplarLabeling | ||
from .base import HoldoutClassLabeling | ||
from .base import Dataset | ||
from .symbolic import SymbolicDataset | ||
|
||
__all__ = ( | ||
"DatasetSplit", | ||
"ExemplarLabeling", | ||
"HoldoutClassLabeling", | ||
"Dataset", | ||
"SymbolicDataset", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
"""`Dataset`s are sequences of unique examples.""" | ||
from typing import Any | ||
from typing import Sequence | ||
from typing import Tuple | ||
from typing import Union | ||
from nptyping import NDArray | ||
from nptyping import Bool | ||
from nptyping import Floating | ||
from nptyping import Int | ||
from jax.random import KeyArray | ||
from jaxtyping import Array | ||
|
||
from enum import Enum | ||
from enum import unique | ||
from functools import cached_property | ||
from functools import partial | ||
import numpy as np | ||
from pathlib import Path | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import jax.nn as jnn | ||
|
||
|
||
# Type hints. | ||
IndexType = Union[int, Sequence[int], slice] | ||
ExemplarType = Tuple[NDArray[Any, Floating], NDArray[Any, Int]] | ||
|
||
|
||
@unique | ||
class ExemplarLabeling(Enum): | ||
# Use the original labels from the dataset. | ||
STANDARD = 1 | ||
# Remove all but the first exemplar from each class. | ||
SINGLE = 2 | ||
# Assign each exemplar a unique label, resulting in as many classes as there | ||
# are exemplars. | ||
SEPARATED = 3 | ||
|
||
|
||
@unique | ||
class HoldoutClassLabeling(Enum): | ||
# Use the original labels from the dataset. | ||
STANDARD = 1 | ||
# Relabel validation and test classes with labels from the training set. | ||
TRAIN_LABELS = 2 | ||
|
||
|
||
@unique | ||
class DatasetSplit(Enum): | ||
TRAIN = 1 | ||
VALID = 2 | ||
TEST = 3 | ||
ALL = 4 | ||
|
||
|
||
def wrap_labels(labels: Array, num_classes: int, modulus: Array) -> Array: | ||
onehot_labels = jnn.one_hot(labels, num_classes) | ||
return (onehot_labels @ modulus).argmax(axis=-1) | ||
|
||
|
||
def get_wrapped_indices( | ||
prop_labels: float, num_classes: int, offset=0 | ||
) -> Tuple[int, Array]: | ||
"""Get indices to wrap `num_classes` into `prop_labels` labels.""" | ||
if prop_labels < 1.0: | ||
num_labels = int(prop_labels * num_classes) | ||
indices = jnp.arange(num_classes) % num_labels | ||
else: | ||
num_labels = num_classes | ||
indices = jnp.arange(num_classes) | ||
indices += offset | ||
return num_labels, indices | ||
|
||
|
||
class Dataset: | ||
|
||
_exemplars: Union[Sequence[Path], NDArray] | ||
_labels: NDArray | ||
|
||
num_train_classes: int | ||
prop_train_labels: float | ||
num_test_classes: int | ||
prop_test_labels: float | ||
num_valid_classes: int | ||
prop_valid_labels: float | ||
|
||
def __init__( | ||
self, | ||
key: KeyArray, | ||
split: DatasetSplit, | ||
exemplar_labeling: ExemplarLabeling, | ||
holdout_class_labeling: HoldoutClassLabeling, | ||
num_train_classes: int, | ||
prop_train_labels: float, | ||
num_test_classes: int, | ||
prop_test_labels: float, | ||
num_valid_classes: int = 0, | ||
prop_valid_labels: float = 0, | ||
num_exemplars_per_class: int = 400, | ||
exemplar_noise_scale: float = 1e-2, | ||
): | ||
"""A `Dataset` of class exemplars from which to draw sequences. | ||
Args: | ||
key: A key for randomness in sampling. | ||
split: Which split of the underlying dataset to use. | ||
exemplar_labeling: How to assign class labels to exemplars from the | ||
underlying dataset (reproduction of | ||
https://github.com/deepmind/emergent_in_context_learning/blob/main/datasets/data_generators.py#L60-L65). | ||
holdout_class_labeling: How to assign class labels to holdout | ||
(validation and testing) splits of this `Dataset`. | ||
num_{train,test,valid}_classes: Number of {training, testing, validation} | ||
classes in this `Dataset`. | ||
prop_{train,test,valid}_labels: Size of the {training, testing, | ||
validation} label set proportional to the underlying class set. If 1.0, | ||
then labels are identical to the underlying class labels; if < 1.0, | ||
then labels are wrapped in increasing order. | ||
num_exemplars_per_class: Number of exemplars per class to draw from the | ||
underlying dataset. | ||
exemplar_noise_scale: The scale of noise to add to each additional exemplar. | ||
""" | ||
|
||
self.num_train_classes = num_train_classes | ||
self.num_valid_classes = num_valid_classes | ||
self.num_test_classes = num_test_classes | ||
self.num_exemplars_per_class = num_exemplars_per_class | ||
self.exemplar_noise_scale = exemplar_noise_scale | ||
|
||
if holdout_class_labeling == HoldoutClassLabeling.TRAIN_LABELS: | ||
if ( | ||
prop_train_labels * num_train_classes < prop_valid_labels * num_valid_classes | ||
or prop_train_labels * num_train_classes < prop_test_labels * num_test_classes | ||
): | ||
raise ValueError( | ||
"Relabeling of validation and test sets with train " | ||
"labels usually assumes more train classes than " | ||
"validation and test classes, but " | ||
f"{prop_train_labels * num_train_classes} < " | ||
f"{prop_valid_labels * num_valid_classes} or " | ||
f"{prop_train_labels * num_train_classes} < " | ||
f"{prop_test_labels * num_test_classes}." | ||
) | ||
|
||
self.num_observed_classes = int(prop_train_labels * self.num_train_classes) | ||
|
||
else: | ||
self.num_observed_classes = ( | ||
int(prop_train_labels * self.num_train_classes) | ||
+ int(prop_valid_labels * self.num_valid_classes) | ||
+ int(prop_test_labels * self.num_test_classes) | ||
) | ||
|
||
if not all( | ||
0.0 < p <= 1.0 | ||
for p in ( | ||
prop_train_labels, | ||
prop_valid_labels, | ||
prop_test_labels, | ||
) | ||
): | ||
raise ValueError( | ||
"One of `prop_{train,valid,test}_labels` was invalid: " | ||
f"{prop_train_labels}, {prop_valid_labels}, {prop_test_labels}." | ||
) | ||
|
||
num_train_labels, train_indices = get_wrapped_indices( | ||
prop_train_labels, num_train_classes | ||
) | ||
num_valid_labels, valid_indices = get_wrapped_indices( | ||
prop_valid_labels, | ||
num_valid_classes, | ||
offset=0 | ||
if holdout_class_labeling == HoldoutClassLabeling.TRAIN_LABELS | ||
else num_train_labels, | ||
) | ||
num_test_labels, test_indices = get_wrapped_indices( | ||
prop_test_labels, | ||
num_test_classes, | ||
offset=0 | ||
if holdout_class_labeling == HoldoutClassLabeling.TRAIN_LABELS | ||
else num_train_labels + num_valid_labels, | ||
) | ||
|
||
indices = jnp.concatenate((train_indices, valid_indices, test_indices)) | ||
modulus = jnp.eye(self.num_classes, dtype=int)[indices, :] | ||
|
||
self.wrap_labels = jax.jit( | ||
partial( | ||
wrap_labels, | ||
num_classes=self.num_classes, | ||
modulus=modulus, | ||
) | ||
) | ||
|
||
def __len__(self) -> int: | ||
return len(self._exemplars) | ||
|
||
@property | ||
def num_classes(self) -> int: | ||
return self.num_train_classes + self.num_valid_classes + self.num_test_classes | ||
|
||
@property | ||
def exemplar_shape(self) -> Tuple[int]: | ||
raise NotImplementedError("To be implemented by the subclass.") | ||
|
||
def __getitem__(self, index: Union[int, slice]) -> ExemplarType: | ||
raise NotImplementedError("To be implemented by the subclass.") | ||
|
||
@cached_property | ||
def unique_classes(self) -> Sequence[int]: | ||
"""Deterministic ordering of dataset class labels.""" | ||
return np.unique(self._labels).tolist() | ||
|
||
@cached_property | ||
def train_classes(self) -> Sequence[int]: | ||
i = self.num_train_classes | ||
return self.unique_classes[:i] | ||
|
||
@cached_property | ||
def valid_classes(self) -> Sequence[int]: | ||
i = self.num_train_classes | ||
j = self.num_train_classes + self.num_valid_classes | ||
return self.unique_classes[i:j] | ||
|
||
@cached_property | ||
def test_classes(self) -> Sequence[int]: | ||
j = self.num_train_classes + self.num_valid_classes | ||
k = self.num_train_classes + self.num_valid_classes + self.num_test_classes | ||
return self.unique_classes[j:k] | ||
|
||
@cached_property | ||
def _train_idx(self) -> NDArray[Any, Bool]: | ||
"""Mask for the train split.""" | ||
return np.in1d(self._labels, self.train_classes) | ||
|
||
@cached_property | ||
def _valid_idx(self) -> NDArray[Any, Bool]: | ||
"""Mask for the validation split.""" | ||
return np.in1d(self._labels, self.valid_classes) | ||
|
||
@cached_property | ||
def _test_idx(self) -> NDArray[Any, Bool]: | ||
"""Mask for the test split.""" | ||
return np.in1d(self._labels, self.test_classes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from typing import Tuple | ||
from typing import Union | ||
from jax.random import KeyArray | ||
|
||
from functools import partial | ||
import numpy as np | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import jax.nn as jnn | ||
|
||
from nets.datasets.base import Dataset | ||
from nets.datasets.base import DatasetSplit | ||
from nets.datasets.base import ExemplarLabeling | ||
from nets.datasets.base import ExemplarType | ||
from nets.datasets.base import HoldoutClassLabeling | ||
|
||
|
||
class SymbolicDataset(Dataset): | ||
|
||
_exemplars: np.ndarray | ||
_labels: np.ndarray | ||
num_train_classes: int | ||
num_valid_classes: int | ||
num_test_classes: int | ||
num_train_labels: int | ||
num_valid_labels: int | ||
num_test_labels: int | ||
|
||
def __init__( | ||
self, | ||
key: KeyArray, | ||
split: DatasetSplit, | ||
exemplar_labeling: ExemplarLabeling, | ||
holdout_class_labeling: HoldoutClassLabeling, | ||
num_train_classes: int, | ||
prop_train_labels: float, | ||
num_test_classes: int, | ||
prop_test_labels: float, | ||
num_valid_classes: int = 0, | ||
prop_valid_labels: float = 0, | ||
num_exemplars_per_class: int = 400, | ||
exemplar_noise_scale: float = 1e-2, | ||
): | ||
"""A `SymbolicDataset` of class exemplars from which to draw sequences. | ||
Args: | ||
...`Dataset` args... | ||
""" | ||
super().__init__( | ||
key=key, | ||
split=split, | ||
exemplar_labeling=exemplar_labeling, | ||
holdout_class_labeling=holdout_class_labeling, | ||
num_train_classes=num_train_classes, | ||
prop_train_labels=prop_train_labels, | ||
num_test_classes=num_test_classes, | ||
prop_test_labels=prop_test_labels, | ||
num_valid_classes=num_valid_classes, | ||
prop_valid_labels=prop_valid_labels, | ||
num_exemplars_per_class=num_exemplars_per_class, | ||
exemplar_noise_scale=exemplar_noise_scale, | ||
) | ||
|
||
# Exemplar generation for `SymbolicDataset`. | ||
labels = np.arange(self.num_classes) | ||
|
||
if num_exemplars_per_class > 1: | ||
labels = np.repeat( | ||
labels[:, np.newaxis], num_exemplars_per_class, axis=-1 | ||
).reshape(-1) | ||
|
||
# TODO(eringrant): Deal with this params. | ||
del exemplar_labeling | ||
|
||
self._labels = labels | ||
|
||
if self.num_exemplars_per_class > 1: | ||
self._exemplar_keys = jax.random.split( | ||
key, self.num_classes * num_exemplars_per_class | ||
) | ||
|
||
# Compile functions for sampling at `Dataset.__init__`. | ||
self.generate_exemplar = jax.jit( | ||
jax.vmap( | ||
jax.vmap( | ||
partial( | ||
jax.random.multivariate_normal, | ||
# Isotropic with scale a/C to keep noise level in embeddings constant. | ||
cov=exemplar_noise_scale / self.num_classes * jnp.eye(self.num_classes), | ||
) | ||
) | ||
) | ||
) | ||
|
||
@property | ||
def exemplar_shape(self) -> Tuple[int]: | ||
return (self.num_classes,) | ||
|
||
def __getitem__(self, index: Union[int, slice]) -> ExemplarType: | ||
labels = self._labels[index] | ||
onehot_labels = jnn.one_hot(labels, self.num_classes) | ||
|
||
if self.num_exemplars_per_class == 1: | ||
exemplars = onehot_labels | ||
|
||
else: | ||
exemplar_key = self._exemplar_keys[index] | ||
|
||
# TODO(eringrant): Deal with other `index` shapes. | ||
if isinstance(index, int): | ||
exemplar_key = jnp.expand_dims(exemplar_key, 0) # type: ignore[arg-type] | ||
|
||
exemplars = self.generate_exemplar( | ||
key=exemplar_key, | ||
mean=onehot_labels, | ||
) | ||
|
||
labels = self.wrap_labels(labels) | ||
|
||
if isinstance(index, int): | ||
assert len(exemplars) == 1 and len(labels) == 1 | ||
exemplars = exemplars[0] | ||
labels = labels[0] | ||
|
||
return exemplars, labels |
Empty file.
Oops, something went wrong.