Skip to content

Commit

Permalink
empirical teacher-student
Browse files Browse the repository at this point in the history
  • Loading branch information
eringrant committed Oct 31, 2023
1 parent 8252685 commit 234d4d7
Show file tree
Hide file tree
Showing 4 changed files with 387 additions and 2 deletions.
4 changes: 3 additions & 1 deletion nets/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Neural network models."""
from .transformers import SequenceClassifier
from .feedforward import MLP
from .teacher import CanonicalTeacher
from .transformers import SequenceClassifier

# Public names of this package: the model classes imported from the
# submodules above.
__all__ = (
"MLP",
"SequenceClassifier",
"CanonicalTeacher",
)
55 changes: 55 additions & 0 deletions nets/models/teacher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Teacher models for the canonical teacher-student setup."""
from collections.abc import Callable
from typing import Self

import equinox as eqx
import jax
from jax import Array

from nets import models


class CanonicalTeacher(eqx.Module):
    """Teacher that draws a standard Normal input and maps it through an MLP.

    Attributes:
        input_sampler: Jitted callable taking a PRNG key and returning one
            input sample of shape `(in_features,)`.
        net: The `models.MLP` instance applied to each sampled input.
    """

    input_sampler: Callable
    net: eqx.Module

    def __init__(
        self: Self,
        in_features: int,
        hidden_features: tuple[int, ...],
        out_features: int,
        activation: Callable = jax.nn.relu,
        dropout_probs: tuple[float, ...] | None = None,
        init_scale: float = 1.0,
        *,
        key: Array,
    ) -> None:
        """Construct the teacher network and its input sampler.

        Args:
            in_features: Length of each sampled input vector; also forwarded
                to `models.MLP`.
            hidden_features: Forwarded to `models.MLP`.
            out_features: Forwarded to `models.MLP`.
            activation: Forwarded to `models.MLP`.
            dropout_probs: Accepted for signature compatibility but discarded.
            init_scale: Forwarded to `models.MLP`.
            key: PRNG key forwarded to `models.MLP` (presumably used for
                parameter initialization -- confirm against `MLP`).
        """
        super().__init__()

        del dropout_probs  # TODO(eringrant): Unused.

        self.net = models.MLP(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            activation=activation,
            init_scale=init_scale,
            key=key,
        )

        def gaussian_sampler(key: Array) -> Array:
            # One standard Normal draw per input feature.
            return jax.random.normal(key, shape=(in_features,))

        # Wrap the sampler with `jax.jit` once, at construction time.
        self.input_sampler = jax.jit(gaussian_sampler)

    def __call__(self: Self, key: Array) -> tuple[Array, Array]:
        """Sample one input and return it with the network's output for it.

        Args:
            key: PRNG key; split in two so the input sampler and the network
                call each receive their own key.

        Returns:
            The pair `(x, y)`: the sampled input and `self.net` applied to it.
        """
        sample_key, forward_key = jax.random.split(key, 2)

        x = self.input_sampler(sample_key)
        return x, self.net(x, key=forward_key)
1 change: 0 additions & 1 deletion nets/simulators/in_context_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ def simulate(
) -> tuple[pd.DataFrame, ...]:
"""Simulate in-context learning of classification tasks."""
logging.info(f"Using JAX backend: {jax.default_backend()}\n")

logging.info(f"Using configuration: {pprint.pformat(locals())}")

# Single source of randomness.
Expand Down
Loading

0 comments on commit 234d4d7

Please sign in to comment.