Skip to content

Commit

Permalink
committee machine
Browse files Browse the repository at this point in the history
  • Loading branch information
eringrant committed Dec 19, 2023
1 parent 2d8622f commit 1c36bd3
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
3 changes: 2 additions & 1 deletion nets/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Neural network models."""
from .feedforward import MLP
from .teacher import CanonicalTeacher
from .teacher import CanonicalTeacher, CommitteeTeacher
from .transformers import SequenceClassifier

__all__ = (
"MLP",
"SequenceClassifier",
"CanonicalTeacher",
"CommitteeTeacher",
)
46 changes: 44 additions & 2 deletions nets/models/teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

import equinox as eqx
import jax
import jax.numpy as jnp
from jax import Array

from nets import models
from nets.models.feedforward import StopGradient


class CanonicalTeacher(eqx.Module):
"""Multi-layer perceptron over standard Normal input."""
"""A canonical teacher model for the teacher-student setup."""

input_sampler: Callable
net: eqx.Module
Expand All @@ -33,7 +35,7 @@ def __init__(

@jax.jit
def gaussian_sampler(key: Array) -> Array:
return jax.random.normal(key, shape=(in_features,))
return jax.random.normal(key, shape=(in_features,)) / jnp.sqrt(in_features)

self.input_sampler = gaussian_sampler
self.net = models.MLP(
Expand All @@ -53,3 +55,43 @@ def __call__(self: Self, key: Array) -> tuple[Array, Array]:
y = self.net(x, key=net_key)

return x, y


class CommitteeTeacher(CanonicalTeacher):
"""A teacher model that is a committee machine."""

def __init__(
self: Self,
in_features: int,
hidden_features: int,
activation: Callable = jax.nn.relu,
dropout_probs: tuple[float, ...] | None = None,
init_scale: float = 1.0,
*,
key: Array,
) -> None:
"""Initialize a CommitteeTeacher."""
super().__init__(
in_features=in_features,
hidden_features=(hidden_features,),
out_features=1,
activation=activation,
dropout_probs=dropout_probs,
init_scale=init_scale,
key=key,
)

# Fix last-layer weights to compute the mean of the hidden-layer activations.
self.net = eqx.tree_at(
lambda net: net.layers[-1].weight,
self.net,
StopGradient(
jnp.ones_like(self.net.layers[-1].weight)
/ self.net.layers[-1].weight.shape[-1],
),
)
self.net = eqx.tree_at(
lambda net: net.layers[-1].bias,
self.net,
StopGradient(jnp.zeros_like(self.net.layers[-1].bias)),
)

0 comments on commit 1c36bd3

Please sign in to comment.