
Commit

fix typing errors

eringrant committed Sep 16, 2023
1 parent 202384a commit cf1fd7e
Showing 15 changed files with 379 additions and 239 deletions.
1 change: 1 addition & 0 deletions nets/datasets/__init__.py
@@ -1,3 +1,4 @@
"""`Dataset`s to accompany models."""
from .base import DatasetSplit
from .base import ExemplarLabeling
from .base import HoldoutClassLabeling
51 changes: 34 additions & 17 deletions nets/datasets/base.py
@@ -1,7 +1,6 @@
"""`Dataset`s are sequences of unique examples."""
from typing import Any
from collections.abc import Sequence
from typing import Union
from nptyping import NDArray
from nptyping import Bool
from nptyping import Floating
@@ -22,12 +21,14 @@


# Type hints.
IndexType = Union[int, Sequence[int], slice]
IndexType = int | Sequence[int] | slice
ExemplarType = tuple[NDArray[Any, Floating], NDArray[Any, Int]]
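A brief note on the rewritten `IndexType` alias: because it is a module-level assignment, the `int | Sequence[int] | slice` union is evaluated at runtime, so this spelling needs Python 3.10 or newer; the removed `typing.Union` form is the pre-3.10 equivalent. A minimal sketch, in which the `describe` helper is hypothetical and only illustrates how the alias is meant to be used:

from collections.abc import Sequence

# Runtime-evaluated union alias (PEP 604); requires Python 3.10+.
IndexType = int | Sequence[int] | slice

def describe(index: IndexType) -> str:
  """Hypothetical helper: name the kind of indexing a `Dataset` would see."""
  if isinstance(index, int):
    return "single exemplar"
  if isinstance(index, slice):
    return "contiguous range of exemplars"
  return "collection of exemplars"

print(describe(3), describe(slice(0, 4)), describe([1, 5, 7]))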


@unique
class ExemplarLabeling(Enum):
"""How to assign class labels to exemplars from the underlying dataset."""

# Use the original labels from the dataset.
STANDARD = 1
# Remove all but the first exemplar from each class.
@@ -39,6 +40,8 @@ class ExemplarLabeling(Enum):

@unique
class HoldoutClassLabeling(Enum):
"""How to assign class labels to holdout (validation and testing) splits."""

# Use the original labels from the dataset.
STANDARD = 1
# Relabel validation and test classes with labels from the training set.
@@ -47,13 +50,16 @@ class HoldoutClassLabeling(Enum):

@unique
class DatasetSplit(Enum):
"""Which split of the underlying dataset to use."""

TRAIN = 1
VALID = 2
TEST = 3
ALL = 4


def wrap_labels(labels: Array, num_classes: int, modulus: Array) -> Array:
"""Wrap `num_classes` into `labels`."""
onehot_labels = jnn.one_hot(labels, num_classes)
return (onehot_labels @ modulus).argmax(axis=-1)
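To make the label-wrapping behavior concrete, here is a small hedged sketch in which six classes are wrapped onto three labels. The `modulus` matrix is assumed here to send class `c` to label `c % num_labels`; the repository builds it elsewhere (in `get_wrapped_indices`, whose body is not fully shown in this hunk), so this construction is illustrative only:

import jax.nn as jnn
import jax.numpy as jnp

def wrap_labels(labels, num_classes, modulus):
  """Wrap `num_classes` into `labels` (as in the function above)."""
  onehot_labels = jnn.one_hot(labels, num_classes)
  return (onehot_labels @ modulus).argmax(axis=-1)

num_classes, num_labels = 6, 3  # assumed sizes: a label set half the class set
# Assumed modulus matrix sending class c to label c % num_labels.
modulus = jnn.one_hot(jnp.arange(num_classes) % num_labels, num_labels)
print(wrap_labels(jnp.arange(num_classes), num_classes, modulus))
# -> [0 1 2 0 1 2]: labels wrap in increasing order, as described in the
#    constructor docstring below.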

@@ -73,6 +79,8 @@ def get_wrapped_indices(


class Dataset:
"""A `Dataset` of class exemplars from which to draw sequences."""

_exemplars: Sequence[Path] | NDArray
_labels: NDArray

@@ -96,33 +104,35 @@ def __init__(
num_valid_classes: int = 0,
prop_valid_labels: float = 0,
num_exemplars_per_class: int = 400,
exemplar_noise_scale: float = 1e-2,
):
"""A `Dataset` of class exemplars from which to draw sequences.
Args:
key: A key for randomness in sampling.
split: Which split of the underlying dataset to use.
exemplar_labeling: How to assign class labels to exemplars from the
underlying dataset (reproduction of
https://github.com/deepmind/emergent_in_context_learning/blob/main/datasets/data_generators.py#L60-L65).
holdout_class_labeling: How to assign class labels to holdout
(validation and testing) splits of this `Dataset`.
num_{train,test,valid}_classes: Number of {training, testing, validation}
classes in this `Dataset`.
prop_{train,test,valid}_labels: Size of the {training, testing,
validation} label set proportional to the underlying class set. If 1.0,
then labels are identical to the underlying class labels; if < 1.0,
then labels are wrapped in increasing order.
exemplar_labeling: How to assign class labels to exemplars from the underlying
dataset.
holdout_class_labeling: How to assign class labels to holdout (validation and
testing) splits of this `Dataset`.
num_train_classes: Number of training classes in this `Dataset`.
prop_train_labels: Size of the training label set proportional to the underlying
class set. If 1.0, then labels are identical to the underlying class labels;
if < 1.0, then labels are wrapped in increasing order.
num_valid_classes: Number of validation classes in this `Dataset`.
prop_valid_labels: Size of the validation label set proportional to the
underlying class set. If 1.0, then labels are identical to the underlying
class labels; if < 1.0, then labels are wrapped in increasing order.
num_test_classes: Number of testing classes in this `Dataset`.
prop_test_labels: Size of the testing label set proportional to the underlying
class set. If 1.0, then labels are identical to the underlying class labels;
if < 1.0, then labels are wrapped in increasing order.
num_exemplars_per_class: Number of exemplars per class to draw from the
underlying dataset.
exemplar_noise_scale: The scale of noise to add to each additional exemplar.
underlying dataset.
"""
self.num_train_classes = num_train_classes
self.num_valid_classes = num_valid_classes
self.num_test_classes = num_test_classes
self.num_exemplars_per_class = num_exemplars_per_class
self.exemplar_noise_scale = exemplar_noise_scale

if holdout_class_labeling == HoldoutClassLabeling.TRAIN_LABELS:
if (
@@ -191,17 +201,21 @@ def __init__(
)

def __len__(self) -> int:
"""Number of exemplars in this `Dataset`."""
return len(self._exemplars)

@property
def num_classes(self) -> int:
"""Number of classes in this `Dataset`."""
return self.num_train_classes + self.num_valid_classes + self.num_test_classes

@property
def exemplar_shape(self) -> tuple[int]:
"""Shape of an exemplar."""
raise NotImplementedError("To be implemented by the subclass.")

def __getitem__(self, index: int | slice) -> ExemplarType:
"""Get the exemplar(s) and the corresponding label(s) at `index`."""
raise NotImplementedError("To be implemented by the subclass.")

@cached_property
Expand All @@ -211,17 +225,20 @@ def unique_classes(self) -> Sequence[int]:

@cached_property
def train_classes(self) -> Sequence[int]:
"""Deterministic ordering of training class labels."""
i = self.num_train_classes
return self.unique_classes[:i]

@cached_property
def valid_classes(self) -> Sequence[int]:
"""Deterministic ordering of validation class labels."""
i = self.num_train_classes
j = self.num_train_classes + self.num_valid_classes
return self.unique_classes[i:j]

@cached_property
def test_classes(self) -> Sequence[int]:
"""Deterministic ordering of testing class labels."""
j = self.num_train_classes + self.num_valid_classes
k = self.num_train_classes + self.num_valid_classes + self.num_test_classes
return self.unique_classes[j:k]
42 changes: 34 additions & 8 deletions nets/datasets/symbolic.py
@@ -1,3 +1,4 @@
"""A `SymbolicDataset` of class exemplars from which to draw sequences."""
from jax.random import KeyArray

from functools import partial
@@ -15,14 +16,17 @@


class SymbolicDataset(Dataset):
"""A `SymbolicDataset` of class exemplars from which to draw sequences."""

_exemplars: np.ndarray
_labels: np.ndarray

num_train_classes: int
num_valid_classes: int
prop_train_labels: int
prop_test_labels: int
num_test_classes: int
num_train_labels: int
num_valid_labels: int
num_test_labels: int
num_valid_classes: int
prop_valid_labels: int

def __init__(
self,
@@ -42,7 +46,27 @@ def __init__(
"""A `SymbolicDataset` of class exemplars from which to draw sequences.
Args:
...`Dataset` args...
key: A key for randomness in sampling.
split: Which split of the underlying dataset to use.
exemplar_labeling: How to assign class labels to exemplars from the underlying
dataset.
holdout_class_labeling: How to assign class labels to holdout (validation and
testing) splits of this `Dataset`.
num_train_classes: Number of training classes in this `Dataset`.
prop_train_labels: Size of the training label set proportional to the underlying
class set. If 1.0, then labels are identical to the underlying class labels;
if < 1.0, then labels are wrapped in increasing order.
num_valid_classes: Number of validation classes in this `Dataset`.
prop_valid_labels: Size of the validation label set proportional to the
underlying class set. If 1.0, then labels are identical to the underlying
class labels; if < 1.0, then labels are wrapped in increasing order.
num_test_classes: Number of testing classes in this `Dataset`.
prop_test_labels: Size of the testing label set proportional to the underlying
class set. If 1.0, then labels are identical to the underlying class labels;
if < 1.0, then labels are wrapped in increasing order.
num_exemplars_per_class: Number of exemplars per class to draw from the
underlying dataset.
exemplar_noise_scale: Scale of the noise to add to exemplars.
"""
super().__init__(
key=key,
@@ -56,9 +80,10 @@ def __init__(
num_valid_classes=num_valid_classes,
prop_valid_labels=prop_valid_labels,
num_exemplars_per_class=num_exemplars_per_class,
exemplar_noise_scale=exemplar_noise_scale,
)

self.exemplar_noise_scale = exemplar_noise_scale

# Exemplar generation for `SymbolicDataset`.
labels = np.arange(self.num_classes)

@@ -92,9 +117,11 @@ def __init__(

@property
def exemplar_shape(self) -> tuple[int]:
"""Shape of an exemplar."""
return (self.num_classes,)

def __getitem__(self, index: int | slice) -> ExemplarType:
"""Get the exemplar(s) and the corresponding label(s) at `index`."""
labels = self._labels[index]
onehot_labels = jnn.one_hot(labels, self.num_classes)

Expand All @@ -104,9 +131,8 @@ def __getitem__(self, index: int | slice) -> ExemplarType:
else:
exemplar_key = self._exemplar_keys[index]

# TODO(eringrant): Deal with other `index` shapes.
if isinstance(index, int):
exemplar_key = jnp.expand_dims(exemplar_key, 0) # type: ignore[arg-type]
exemplar_key = jnp.expand_dims(exemplar_key, 0)

exemplars = self.generate_exemplar(
key=exemplar_key,
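To make the constructor arguments documented above concrete, a hedged usage sketch follows. The argument values are assumed for illustration, not taken from the repository's experiment configs, and `SymbolicDataset` is assumed importable from this module as shown:

import jax
from nets.datasets.base import DatasetSplit, ExemplarLabeling, HoldoutClassLabeling
from nets.datasets.symbolic import SymbolicDataset

# Assumed configuration: 32 training classes wrapped onto 16 labels,
# 8 testing classes relabeled with training labels, no validation split.
dataset = SymbolicDataset(
  key=jax.random.PRNGKey(0),
  split=DatasetSplit.TRAIN,
  exemplar_labeling=ExemplarLabeling.STANDARD,
  holdout_class_labeling=HoldoutClassLabeling.TRAIN_LABELS,
  num_train_classes=32,
  prop_train_labels=0.5,
  num_test_classes=8,
  prop_test_labels=1.0,
  num_valid_classes=0,
  prop_valid_labels=0,
  num_exemplars_per_class=400,
  exemplar_noise_scale=1e-2,
)
x, y = dataset[0]      # one exemplar and its (possibly wrapped) label
xs, ys = dataset[:16]  # a slice of exemplars, per `__getitem__` above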
1 change: 1 addition & 0 deletions nets/launch/__init__.py
@@ -0,0 +1 @@
"""Utilities to launch jobs."""
8 changes: 6 additions & 2 deletions nets/launch/analyze.py
@@ -1,3 +1,4 @@
"""Utilities for analyzing results."""
import dill as pickle
import logging
import os
@@ -22,6 +23,7 @@


def compress_df(df: pd.DataFrame) -> pd.DataFrame:
"""Compress `df` by optimizing data types."""
for col in df.select_dtypes("category"):
if df[col].dtype.categories.dtype is np.dtype("object"):
df[col] = df[col].map(str)
@@ -66,8 +68,8 @@ def postprocess_result(result: pd.DataFrame, cfg: configs.Config) -> pd.DataFram
categories[field.name] = CategoricalDtype(categories=param, ordered=True)

# Optimize data types.
for field in categories:
result[field] = result[field].astype(categories[field])
for field_name in categories:
result[field_name] = result[field_name].astype(categories[field_name])
result = compress_df(result)
if result.isnull().values.any():
raise ValueError("Failed to cast.")
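The loop-variable rename above avoids reusing `field`, which the preceding loop binds to a dataclass field object (note `field.name`); rebinding it to a string key is the kind of typing error this commit fixes. A small, self-contained sketch of the same casting pattern, with made-up column data:

import pandas as pd
from pandas.api.types import CategoricalDtype

# Made-up results frame and category specification, for illustration only.
result = pd.DataFrame({"learning_rate": ["1e-3", "1e-2", "1e-3"], "loss": [0.3, 0.2, 0.1]})
categories = {
  "learning_rate": CategoricalDtype(categories=["1e-3", "1e-2"], ordered=True),
}

for field_name in categories:
  result[field_name] = result[field_name].astype(categories[field_name])

if result.isnull().values.any():
  raise ValueError("Failed to cast.")
print(result.dtypes)  # learning_rate is now an ordered categorical dtype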
@@ -86,6 +88,7 @@ def truncate(df, col: str, n: int = int(1e2 * 32)):


def load_result_from_pkl(filepath):
"""Load result from `filepath`."""
if filepath is None:
return None
else:
@@ -151,6 +154,7 @@ def read_concat_hdf(f1, f2):

# TODO(eringrant): Generalize to arbitrary #s.
def read_concat_hdf(f1, f2):
"""Read and concatenate two HDFs."""
ignore_columns = (
"optimizer_fn",
"learning_rate",