Skip to content

Commit

Permalink
large-scale cleanups; stricter typing
Browse files Browse the repository at this point in the history
  • Loading branch information
eringrant committed Oct 31, 2023
1 parent 6a24022 commit 6b8406a
Show file tree
Hide file tree
Showing 22 changed files with 869 additions and 765 deletions.
13 changes: 6 additions & 7 deletions nets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module-wide constants for `nets`."""
import os
from pathlib import Path

import nets

__version__ = "0.0.1"
Expand All @@ -9,15 +10,13 @@
# https://stackoverflow.com/a/59597733.
__package_path = os.path.split(nets.__path__[0])[0]
DATA_DIR = Path(__package_path, "data")
TMP_DIR = Path("/tmp", "nets")
os.makedirs(TMP_DIR, exist_ok=True)
# TODO(eringrant): Rethink tmpdir.
TMP_DIR = Path("/tmp", "nets") # noqa: S108
Path.mkdir(TMP_DIR, parents=True, exist_ok=True)

scratch_home = os.environ.get("SCRATCH_HOME")
if scratch_home is not None:
SCRATCH_DIR = Path(scratch_home, "nets")
else:
SCRATCH_DIR = TMP_DIR
os.makedirs(SCRATCH_DIR, exist_ok=True)
SCRATCH_DIR = Path(scratch_home, "nets") if scratch_home is not None else TMP_DIR
Path.mkdir(SCRATCH_DIR, parents=True, exist_ok=True)

del nets
del os
Expand Down
7 changes: 2 additions & 5 deletions nets/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""`Dataset`s to accompany models."""
from .base import DatasetSplit
from .base import ExemplarLabeling
from .base import HoldoutClassLabeling
from .base import Dataset
from .symbolic import SymbolicDataset
from .base import Dataset, DatasetSplit, ExemplarLabeling, HoldoutClassLabeling
from .parity import ParityDataset
from .symbolic import SymbolicDataset

__all__ = (
"DatasetSplit",
Expand Down
91 changes: 49 additions & 42 deletions nets/datasets/base.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
"""`Dataset`s are sequences of unique examples."""
from typing import Any
from collections.abc import Sequence
from nptyping import NDArray
from nptyping import Bool
from nptyping import Floating
from nptyping import Int
from jax.random import KeyArray
from jaxtyping import Array

from enum import Enum
from enum import unique
from functools import cached_property
from functools import partial
import numpy as np
from enum import Enum, unique
from functools import cached_property, partial
from pathlib import Path
from typing import Any, Self

import jax
import jax.numpy as jnp
import jax.nn as jnn

import jax.numpy as jnp
import numpy as np
from jax import Array
from nptyping import Bool, Floating, Int, NDArray

# Type hints.
IndexType = int | Sequence[int] | slice
Expand Down Expand Up @@ -65,7 +57,9 @@ def wrap_labels(labels: Array, num_classes: int, modulus: Array) -> Array:


def get_wrapped_indices(
prop_labels: float, num_classes: int, offset=0
prop_labels: float,
num_classes: int,
offset: int = 0,
) -> tuple[int, Array]:
"""Get indices to wrap `num_classes` into `prop_labels` labels."""
if prop_labels < 1.0:
Expand All @@ -92,8 +86,8 @@ class Dataset:
prop_valid_labels: float

def __init__(
self,
key: KeyArray,
self: Self,
key: Array,
split: DatasetSplit,
exemplar_labeling: ExemplarLabeling,
holdout_class_labeling: HoldoutClassLabeling,
Expand All @@ -104,7 +98,7 @@ def __init__(
num_valid_classes: int = 0,
prop_valid_labels: float = 0,
num_exemplars_per_class: int = 400,
):
) -> None:
"""A `Dataset` of class exemplars from which to draw sequences.
Args:
Expand All @@ -129,6 +123,11 @@ class set. If 1.0, then labels are identical to the underlying class labels;
num_exemplars_per_class: Number of exemplars per class to draw from the
underlying dataset.
"""
# TODO(eringrant): Remove these arguments.
del key
del split
del exemplar_labeling

self.num_train_classes = num_train_classes
self.num_valid_classes = num_valid_classes
self.num_test_classes = num_test_classes
Expand All @@ -139,15 +138,17 @@ class set. If 1.0, then labels are identical to the underlying class labels;
prop_train_labels * num_train_classes < prop_valid_labels * num_valid_classes
or prop_train_labels * num_train_classes < prop_test_labels * num_test_classes
):
raise ValueError(
"Relabeling of validation and test sets with train "
"labels usually assumes more train classes than "
"validation and test classes, but "
f"{prop_train_labels * num_train_classes} < "
f"{prop_valid_labels * num_valid_classes} or "
msg = (
"Relabeling of validation and test sets with train labels usually assumes "
"more train classes than validation and test classes, but "
f"{prop_train_labels * num_train_classes} < "
f"{prop_valid_labels * num_valid_classes} "
f"or {prop_train_labels * num_train_classes} < "
f"{prop_test_labels * num_test_classes}."
)
raise ValueError(
msg,
)

self.num_observed_classes = int(prop_train_labels * self.num_train_classes)

Expand All @@ -167,13 +168,17 @@ class set. If 1.0, then labels are identical to the underlying class labels;
prop_test_labels,
)
):
raise ValueError(
"One of `prop_{train,valid,test}_labels` was invalid: "
msg = (
f"One of `prop_{{train,valid,test}}_labels` was invalid: "
f"{prop_train_labels}, {prop_valid_labels}, {prop_test_labels}."
)
raise ValueError(
msg,
)

num_train_labels, train_indices = get_wrapped_indices(
prop_train_labels, num_train_classes
prop_train_labels,
num_train_classes,
)
num_valid_labels, valid_indices = get_wrapped_indices(
prop_valid_labels,
Expand All @@ -198,63 +203,65 @@ class set. If 1.0, then labels are identical to the underlying class labels;
wrap_labels,
num_classes=self.num_classes,
modulus=modulus,
)
),
)

def __len__(self) -> int:
def __len__(self: Self) -> int:
"""Number of exemplars in this `Dataset`."""
return len(self._exemplars)

@property
def num_classes(self) -> int:
def num_classes(self: Self) -> int:
"""Number of classes in this `Dataset`."""
return self.num_train_classes + self.num_valid_classes + self.num_test_classes

@property
def exemplar_shape(self) -> tuple[int]:
def exemplar_shape(self: Self) -> tuple[int]:
"""Shape of an exemplar."""
raise NotImplementedError("To be implemented by the subclass.")
msg = "To be implemented by the subclass."
raise NotImplementedError(msg)

def __getitem__(self, index: int | slice) -> ExemplarType:
def __getitem__(self: Self, index: int | slice) -> ExemplarType:
"""Get the exemplar(s) and the corresponding label(s) at `index`."""
raise NotImplementedError("To be implemented by the subclass.")
msg = "To be implemented by the subclass."
raise NotImplementedError(msg)

@cached_property
def unique_classes(self) -> Sequence[int]:
def unique_classes(self: Self) -> Sequence[int]:
"""Deterministic ordering of dataset class labels."""
return np.unique(self._labels).tolist()

@cached_property
def train_classes(self) -> Sequence[int]:
def train_classes(self: Self) -> Sequence[int]:
"""Deterministic ordering of training class labels."""
i = self.num_train_classes
return self.unique_classes[:i]

@cached_property
def valid_classes(self) -> Sequence[int]:
def valid_classes(self: Self) -> Sequence[int]:
"""Deterministic ordering of validation class labels."""
i = self.num_train_classes
j = self.num_train_classes + self.num_valid_classes
return self.unique_classes[i:j]

@cached_property
def test_classes(self) -> Sequence[int]:
def test_classes(self: Self) -> Sequence[int]:
"""Deterministic ordering of testing class labels."""
j = self.num_train_classes + self.num_valid_classes
k = self.num_train_classes + self.num_valid_classes + self.num_test_classes
return self.unique_classes[j:k]

@cached_property
def _train_idx(self) -> NDArray[Any, Bool]:
def _train_idx(self: Self) -> NDArray[Any, Bool]:
"""Mask for the train split."""
return np.in1d(self._labels, self.train_classes)

@cached_property
def _valid_idx(self) -> NDArray[Any, Bool]:
def _valid_idx(self: Self) -> NDArray[Any, Bool]:
"""Mask for the validation split."""
return np.in1d(self._labels, self.valid_classes)

@cached_property
def _test_idx(self) -> NDArray[Any, Bool]:
def _test_idx(self: Self) -> NDArray[Any, Bool]:
"""Mask for the test split."""
return np.in1d(self._labels, self.test_classes)
88 changes: 38 additions & 50 deletions nets/datasets/parity.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
"""A `ParityDataset` that generates parity-labelled examples in `D` dimensions."""
from jax.random import KeyArray
from functools import partial
from typing import Self

import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from jax import Array

from nets.datasets.base import Dataset
from nets.datasets.base import DatasetSplit
from nets.datasets.base import ExemplarType
from nets.datasets.base import ExemplarLabeling
from nets.datasets.base import HoldoutClassLabeling
from nets.datasets.base import (
Dataset,
DatasetSplit,
ExemplarLabeling,
ExemplarType,
HoldoutClassLabeling,
)


class ParityDataset(Dataset):
"""Parity-labelled exmaples in many dimensions."""

def __init__(
self,
key: KeyArray,
self: Self,
key: Array,
num_dimensions: int = 2, # num classes is c == 2**d
num_exemplars_per_class: int = 400,
exemplar_noise_scale: float = 1e-1,
# TODO(eringrant): Decide whether to use these arguments.
split: DatasetSplit = DatasetSplit.TRAIN,
exemplar_labeling: ExemplarLabeling = ExemplarLabeling.STANDARD,
holdout_class_labeling: HoldoutClassLabeling = HoldoutClassLabeling.STANDARD,
):
) -> None:
"""Initializes a `ParityDataset` instance."""
super().__init__(
key=key, # TODO(eringrant): Use a separate key.
Expand All @@ -49,27 +51,38 @@ def __init__(
labels = jnp.arange(2**num_dimensions)
# TODO(eringrant): Assert labels are 32-bit integers for this conversion.
bit_labels = jnp.unpackbits(labels.view("uint8"), bitorder="little").reshape(
labels.size, 32
labels.size,
32,
)[:, :num_dimensions]
parity_labels = jax.lax.reduce(
bit_labels, init_values=jnp.uint8(0), computation=jnp.bitwise_xor, dimensions=(1,)
bit_labels,
init_values=jnp.uint8(0),
computation=jnp.bitwise_xor,
dimensions=(1,),
)

self._exemplars = bit_labels.astype(jnp.int32)
self._labels = parity_labels.astype(jnp.int32)

if num_exemplars_per_class > 1:
# Class exemplars are noised one-hots.

# Repeat each exemplar and label `num_exemplars_per_class` times.
self._exemplars = jnp.repeat(
self._exemplars[:, jnp.newaxis, :], num_exemplars_per_class, axis=1
).reshape(num_exemplars_per_class * self.num_classes, num_dimensions)
self._exemplars[:, jnp.newaxis, :],
num_exemplars_per_class,
axis=1,
).reshape(num_exemplars_per_class * (2**num_dimensions), num_dimensions)
self._labels = jnp.repeat(
self._labels[:, jnp.newaxis], num_exemplars_per_class, axis=-1
).reshape(num_exemplars_per_class * self.num_classes)
self._labels[:, jnp.newaxis],
num_exemplars_per_class,
axis=-1,
).reshape(num_exemplars_per_class * (2**num_dimensions))

# Produce unique keys for each exemplar.
self._exemplar_keys = jax.random.split(
key, self.num_classes * num_exemplars_per_class
key,
num_exemplars_per_class * (2**num_dimensions),
)

# Compile a function for sampling exemplars at `Dataset.__init__`.
Expand All @@ -79,16 +92,19 @@ def __init__(
jax.random.multivariate_normal,
# Isotropic with scale a/C to keep noise scale constant.
cov=exemplar_noise_scale / self.num_classes * jnp.eye(num_dimensions),
)
)
),
),
)

# TODO(eringrant): Implement this case
# Class exemplars are a single one-hot.

@property
def exemplar_shape(self) -> tuple[int]:
def exemplar_shape(self: Self) -> tuple[int]:
"""Returns the shape of an exemplar."""
return (self.num_dimensions,)

def __getitem__(self, index: int | slice) -> ExemplarType:
def __getitem__(self: Self, index: int | slice) -> ExemplarType:
"""Get the exemplar(s) and the corresponding label(s) at `index`."""
exemplars = self._exemplars[index]
labels = self._labels[index]
Expand All @@ -106,31 +122,3 @@ def __getitem__(self, index: int | slice) -> ExemplarType:
)

return exemplars, labels


if __name__ == "__main__":
# Test the class
key = random.PRNGKey(0)
dataset = ParityDataset(key)
# Only do the below for small datasets...
exemplars, labels = dataset[:]

import matplotlib.pyplot as plt

plt.figure()
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("2D XOR Dataset")
plt.grid(True)

# Use the labels to distinguish classes and plot
plt.scatter(
exemplars[labels == 0, 0], exemplars[labels == 0, 1], c="red", label="Class 0"
)
plt.scatter(
exemplars[labels == 1, 0], exemplars[labels == 1, 1], c="blue", label="Class 1"
)

# Add legend and show plot
plt.legend()
plt.show()
Loading

0 comments on commit 6b8406a

Please sign in to comment.