Commit

batched student training by initialization scale
eringrant committed Nov 13, 2023
1 parent f3d0482 commit 16c2cac
Showing 3 changed files with 259 additions and 49 deletions.
45 changes: 10 additions & 35 deletions nets/models/feedforward.py
@@ -1,47 +1,14 @@
"""Simple feedforward neural networks."""
from collections.abc import Callable
from math import sqrt
from typing import Self

import equinox as eqx
import equinox.nn as enn
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array


def trunc_normal_init(weight: Array, key: Array, stddev: float | None = None) -> Array:
"""Truncated normal distribution initialization."""
_, in_ = weight.shape
stddev = stddev or sqrt(1.0 / max(1.0, in_))
return stddev * jax.random.truncated_normal(
key=key,
shape=weight.shape,
lower=-2,
upper=2,
)


# Adapted from https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/initializers.py.
def lecun_normal_init(
weight: Array,
key: Array,
scale: float = 1.0,
) -> Array:
"""LeCun (variance-scaling) normal distribution initialization."""
_, in_ = weight.shape
scale /= max(1.0, in_)

stddev = np.sqrt(scale)
# Adjust stddev for truncation.
# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
distribution_stddev = jnp.asarray(0.87962566103423978, dtype=float)
stddev = stddev / distribution_stddev

return trunc_normal_init(weight, key, stddev=stddev)


class StopGradient(eqx.Module):
"""Stop gradient wrapper."""

@@ -73,8 +40,16 @@ def __init__(
       key=key,
     )

-    # Reinitialize weight from variance scaling distribution, reusing `key`.
-    self.weight: Array = lecun_normal_init(self.weight, key=key, scale=init_scale)
+    # Reinitialize weight to force a specific initializer, reusing `key`.
+    self.weight: Array = jax.nn.initializers.variance_scaling(
+      scale=init_scale,
+      mode="fan_in",
+      distribution="truncated_normal",
+    )(
+      key=key,
+      shape=self.weight.shape,
+    )

     if not trainable:
       self.weight = StopGradient(self.weight)

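For context on the replacement, jax.nn.initializers.variance_scaling with distribution="truncated_normal" draws from a [-2, 2] truncated normal whose effective standard deviation is sqrt(scale / fan), with a matching truncation correction, so it plays the role of the deleted lecun_normal_init here. A minimal sketch of the initializer in isolation; the shape and the explicit in_axis/out_axis arguments are assumptions for illustration (the commit itself uses the defaults), chosen so fan-in is read from the last axis the way the old helper read weight.shape[1]:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
out_features, in_features = 64, 128

# Variance-scaling initializer in the spirit of the deleted lecun_normal_init:
# effective stddev is sqrt(scale / fan_in), with the truncation correction applied.
init_fn = jax.nn.initializers.variance_scaling(
  scale=1.0,
  mode="fan_in",
  distribution="truncated_normal",
  in_axis=-1,  # read fan-in from the last axis of an (out_features, in_features) weight
  out_axis=-2,
)
weight = init_fn(key, (out_features, in_features))

# Empirical check: the sample spread should be close to sqrt(scale / in_features).
print(weight.std(), jnp.sqrt(1.0 / in_features))
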
17 changes: 13 additions & 4 deletions nets/models/transformers.py
@@ -17,7 +17,7 @@
 import numpy as np
 from jax import Array

-from nets.models.feedforward import MLP, Linear, StopGradient, trunc_normal_init
+from nets.models.feedforward import MLP, Linear, StopGradient


 class TokenEmbed(eqx.Module):
@@ -31,14 +31,15 @@ def __init__(
     self: Self,
     input_shape: int | Sequence[int],
     embedding_size: int,
-    init_stddev: float | None = 1.0,
+    init_scale: float | None = 1.0,
     *,
     trainable: bool = True,
     key: Array,
   ) -> None:
     """Initialize a linear token embedding layer."""
     if isinstance(input_shape, int):
       input_shape = (input_shape,)
+
     super().__init__(
       in_features=prod(input_shape),
       out_features=embedding_size,
@@ -47,7 +48,15 @@ def __init__(
     )

     # Reinitialize weight from truncated normal distribution, reusing `key`.
-    self.weight: Array = trunc_normal_init(self.weight, key=key, stddev=init_stddev)
+    self.weight: Array = jax.nn.initializers.variance_scaling(
+      scale=init_scale,
+      mode="fan_avg",
+      distribution="truncated_normal",
+    )(
+      key=key,
+      shape=self.weight.shape,
+    )

     if not trainable:
       self.weight = StopGradient(self.weight)

@@ -518,7 +527,7 @@ def __init__(
       embed_dim,
       trainable=train_embed,
       key=keys[1],
-      init_stddev=0.02,
+      init_scale=0.02**2,
     )
     self.label_embed_drop = enn.Dropout(p=label_embed_drop_rate)
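
The init_stddev=0.02 to init_scale=0.02**2 change reflects that variance_scaling takes a variance rather than a standard deviation: with mode="fan_avg" the effective standard deviation is sqrt(scale / fan_avg), so the old stddev is squared to stay in the same units (the division by fan_avg and the truncation correction are additional behavior that comes with the switch). A small sketch with an assumed embedding shape, not taken from the commit:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
shape = (512, 64)  # hypothetical (num_tokens, embedding_size) weight, not from the commit

# With mode="fan_avg", the effective stddev is sqrt(scale / fan_avg),
# where fan_avg averages the fan-in and fan-out read from in_axis/out_axis.
embed_init = jax.nn.initializers.variance_scaling(
  scale=0.02**2,
  mode="fan_avg",
  distribution="truncated_normal",
)
weight = embed_init(key, shape)

fan_avg = (shape[-2] + shape[-1]) / 2.0  # defaults: in_axis=-2, out_axis=-1
print(weight.std(), jnp.sqrt(0.02**2 / fan_avg))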

