From 1c36bd358b8b3419baa7b14e7e010179749964d9 Mon Sep 17 00:00:00 2001
From: Erin Grant
Date: Tue, 19 Dec 2023 15:26:24 -0800
Subject: [PATCH] committee machine

---
 nets/models/__init__.py |  3 ++-
 nets/models/teacher.py  | 46 +++++++++++++++++++++++++++++++++++++++--
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/nets/models/__init__.py b/nets/models/__init__.py
index 148c1e6..acb4b5a 100644
--- a/nets/models/__init__.py
+++ b/nets/models/__init__.py
@@ -1,10 +1,11 @@
 """Neural network models."""
 from .feedforward import MLP
-from .teacher import CanonicalTeacher
+from .teacher import CanonicalTeacher, CommitteeTeacher
 from .transformers import SequenceClassifier
 
 __all__ = (
   "MLP",
   "SequenceClassifier",
   "CanonicalTeacher",
+  "CommitteeTeacher",
 )
diff --git a/nets/models/teacher.py b/nets/models/teacher.py
index d3ba199..c6e26f5 100644
--- a/nets/models/teacher.py
+++ b/nets/models/teacher.py
@@ -4,13 +4,15 @@
 
 import equinox as eqx
 import jax
+import jax.numpy as jnp
 from jax import Array
 
 from nets import models
+from nets.models.feedforward import StopGradient
 
 
 class CanonicalTeacher(eqx.Module):
-  """Multi-layer perceptron over standard Normal input."""
+  """A canonical teacher model for the teacher-student setup."""
 
   input_sampler: Callable
   net: eqx.Module
@@ -33,7 +35,7 @@ def __init__(
 
     @jax.jit
     def gaussian_sampler(key: Array) -> Array:
-      return jax.random.normal(key, shape=(in_features,))
+      return jax.random.normal(key, shape=(in_features,)) / jnp.sqrt(in_features)
 
     self.input_sampler = gaussian_sampler
     self.net = models.MLP(
@@ -53,3 +55,43 @@ def __call__(self: Self, key: Array) -> tuple[Array, Array]:
     y = self.net(x, key=net_key)
 
     return x, y
+
+
+class CommitteeTeacher(CanonicalTeacher):
+  """A teacher model that is a committee machine."""
+
+  def __init__(
+    self: Self,
+    in_features: int,
+    hidden_features: int,
+    activation: Callable = jax.nn.relu,
+    dropout_probs: tuple[float, ...] | None = None,
+    init_scale: float = 1.0,
+    *,
+    key: Array,
+  ) -> None:
+    """Initialize a CommitteeTeacher."""
+    super().__init__(
+      in_features=in_features,
+      hidden_features=(hidden_features,),
+      out_features=1,
+      activation=activation,
+      dropout_probs=dropout_probs,
+      init_scale=init_scale,
+      key=key,
+    )
+
+    # Fix last-layer weights to compute the mean of the hidden-layer activations.
+    self.net = eqx.tree_at(
+      lambda net: net.layers[-1].weight,
+      self.net,
+      StopGradient(
+        jnp.ones_like(self.net.layers[-1].weight)
+        / self.net.layers[-1].weight.shape[-1],
+      ),
+    )
+    self.net = eqx.tree_at(
+      lambda net: net.layers[-1].bias,
+      self.net,
+      StopGradient(jnp.zeros_like(self.net.layers[-1].bias)),
+    )
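
Usage sketch (not part of the patch): a minimal example of sampling from the
new CommitteeTeacher. The import path and constructor keywords follow the diff
above; the seed, feature sizes, and printed shapes are illustrative
assumptions, and the (1,)-shaped output assumes models.MLP returns a rank-1
array when out_features=1.

  import jax

  from nets.models import CommitteeTeacher

  # Arbitrary seed, chosen for illustration.
  init_key, sample_key = jax.random.split(jax.random.PRNGKey(0))

  # A committee of 4 hidden units over 8-dimensional Gaussian inputs.
  teacher = CommitteeTeacher(in_features=8, hidden_features=4, key=init_key)

  # Each call draws one input and returns the (input, output) pair; the output
  # is the fixed mean of the hidden-layer activations set by the constructor.
  x, y = teacher(sample_key)
  print(x.shape, y.shape)  # expected: (8,) (1,)

Because the readout weight and bias are wrapped in StopGradient, the mean
readout stays fixed even if the teacher's parameters are handed to an
optimizer; only the first-layer weights distinguish differently seeded
teachers.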