diff --git a/nets/models/feedforward.py b/nets/models/feedforward.py
index c25ae91..9b5b881 100644
--- a/nets/models/feedforward.py
+++ b/nets/models/feedforward.py
@@ -1,47 +1,14 @@
 """Simple feedforward neural networks."""
 from collections.abc import Callable
-from math import sqrt
 from typing import Self
 
 import equinox as eqx
 import equinox.nn as enn
 import jax
 import jax.numpy as jnp
-import numpy as np
 from jax import Array
 
 
-def trunc_normal_init(weight: Array, key: Array, stddev: float | None = None) -> Array:
-  """Truncated normal distribution initialization."""
-  _, in_ = weight.shape
-  stddev = stddev or sqrt(1.0 / max(1.0, in_))
-  return stddev * jax.random.truncated_normal(
-    key=key,
-    shape=weight.shape,
-    lower=-2,
-    upper=2,
-  )
-
-
-# Adapted from https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/initializers.py.
-def lecun_normal_init(
-  weight: Array,
-  key: Array,
-  scale: float = 1.0,
-) -> Array:
-  """LeCun (variance-scaling) normal distribution initialization."""
-  _, in_ = weight.shape
-  scale /= max(1.0, in_)
-
-  stddev = np.sqrt(scale)
-  # Adjust stddev for truncation.
-  # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
-  distribution_stddev = jnp.asarray(0.87962566103423978, dtype=float)
-  stddev = stddev / distribution_stddev
-
-  return trunc_normal_init(weight, key, stddev=stddev)
-
-
 class StopGradient(eqx.Module):
   """Stop gradient wrapper."""
 
@@ -73,8 +40,16 @@ def __init__(
       key=key,
     )
 
-    # Reinitialize weight from variance scaling distribution, reusing `key`.
-    self.weight: Array = lecun_normal_init(self.weight, key=key, scale=init_scale)
+    # Reinitialize weight to force a specific initializer, reusing `key`.
+    self.weight: Array = jax.nn.initializers.variance_scaling(
+      scale=init_scale,
+      mode="fan_in",
+      distribution="truncated_normal",
+    )(
+      key=key,
+      shape=self.weight.shape,
+    )
+
     if not trainable:
       self.weight = StopGradient(self.weight)
 
diff --git a/nets/models/transformers.py b/nets/models/transformers.py
index f4884dd..5e9385c 100644
--- a/nets/models/transformers.py
+++ b/nets/models/transformers.py
@@ -17,7 +17,7 @@
 import numpy as np
 from jax import Array
 
-from nets.models.feedforward import MLP, Linear, StopGradient, trunc_normal_init
+from nets.models.feedforward import MLP, Linear, StopGradient
 
 
 class TokenEmbed(eqx.Module):
@@ -31,7 +31,7 @@ def __init__(
     self: Self,
     input_shape: int | Sequence[int],
     embedding_size: int,
-    init_stddev: float | None = 1.0,
+    init_scale: float | None = 1.0,
     *,
     trainable: bool = True,
     key: Array,
@@ -39,6 +39,7 @@ def __init__(
     """Initialize a linear token embedding layer."""
     if isinstance(input_shape, int):
      input_shape = (input_shape,)
+
     super().__init__(
      in_features=prod(input_shape),
      out_features=embedding_size,
@@ -47,7 +48,15 @@ def __init__(
     )
 
     # Reinitialize weight from truncated normal distribution, reusing `key`.
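+    # Note: `fan_avg` scales by the mean of fan-in and fan-out, unlike the
+    # `fan_in` mode used for the reinitialization in `feedforward.py`.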
-    self.weight: Array = trunc_normal_init(self.weight, key=key, stddev=init_stddev)
+    self.weight: Array = jax.nn.initializers.variance_scaling(
+      scale=init_scale,
+      mode="fan_avg",
+      distribution="truncated_normal",
+    )(
+      key=key,
+      shape=self.weight.shape,
+    )
+
     if not trainable:
       self.weight = StopGradient(self.weight)
 
@@ -518,7 +527,7 @@ def __init__(
       embed_dim,
       trainable=train_embed,
       key=keys[1],
-      init_stddev=0.02,
+      init_scale=0.02**2,
     )
 
     self.label_embed_drop = enn.Dropout(p=label_embed_drop_rate)
diff --git a/nets/simulators/teacher_student.py b/nets/simulators/teacher_student.py
index 5a59652..e78642a 100644
--- a/nets/simulators/teacher_student.py
+++ b/nets/simulators/teacher_student.py
@@ -50,7 +50,7 @@ def batch_train_step(
   optimizer: optax.GradientTransformation,
   opt_state: Array,
 ) -> tuple[Array, eqx.Module, Array]:
-  """Update the model on a batch of example-target pair."""
+  """Update the student model on a batch of example-target pairs."""
   student_key, teacher_key = jax.random.split(key)
 
   x, y = jax.vmap(teacher)(jax.random.split(teacher_key, batch_size))
@@ -74,7 +74,7 @@ def eval_step(
   teacher: eqx.Module,
   student: eqx.Module,
 ) -> Mapping[str, Array]:
-  """Evaluate the model on a single example-target pairs."""
+  """Evaluate the student model on a single example-target pair."""
   teacher_key, student_key = jax.random.split(key)
 
   x, y = teacher(key=teacher_key)
@@ -100,16 +100,16 @@ def metrics_to_df(metrics: Mapping[str, Array]) -> pd.DataFrame:
 
   # Probe to get shape.
   num_iters = len(metrics_df)
-  num_examples = metrics_df["loss"][0].size
+  num_examples = metrics_df["ground truth target"][0].size
 
   # Determine metric structures.
-  def has_shape(col: str, shape: tuple[int]) -> bool:
+  def has_shape(col: str, shape: tuple | tuple[int]) -> bool:
     a = metrics_df[col][0]
     return hasattr(a, "shape") and a.shape == shape
 
   def has_no_shape(col: str) -> bool:
     a = metrics_df[col][0]
-    return not hasattr(a, "shape") or has_shape(col, (1,))
+    return not hasattr(a, "shape") or has_shape(col, ()) or has_shape(col, (1,))
 
   iterationwise_metrics = tuple(filter(has_no_shape, metrics_df.columns))
   elementwise_metrics = tuple(set(metrics_df.columns) - set(iterationwise_metrics))
@@ -293,10 +293,232 @@ def simulate(
   return student, metrics_df
 
 
+# TODO(eringrant): Allow varying multiple dimensions.
+def batched_metrics_to_df(
+  metrics: Mapping[str, Array],
+  batched_key: str,
+  batched_values: tuple[float, ...],
+) -> pd.DataFrame:
+  """Call `metrics_to_df` on each slice along a leading batch dimension."""
+  singleton_dfs = []
+
+  for i, singleton_value in enumerate(batched_values):
+    singleton_metrics = {
+      key: value[i] if value is not None else None for key, value in metrics.items()
+    }
+
+    singleton_df = metrics_to_df(singleton_metrics)
+    singleton_df[batched_key] = singleton_value
+    singleton_dfs.append(singleton_df)
+
+  return pd.concat(singleton_dfs, ignore_index=True)
+
+
+batched_student_eval_step = eqx.filter_vmap(
+  fun=eval_step,
+  in_axes=(None, None, eqx.if_array(0)),
+)
+
+
+batched_student_batch_train_step = eqx.filter_vmap(
+  fun=batch_train_step,
+  in_axes=(None, None, None, eqx.if_array(0), None, None),
+)
+
+
+def batched_student_evaluate(
+  key: Array,
+  teacher: eqx.Module,
+  students: eqx.Module,
+  batch_size: int,
+  num_examples: int,
+  iteration: int,
+  batched_key: str,
+  batched_values: tuple[float, ...],
+) -> pd.DataFrame:
+  """Evaluate the batched student models on `num_examples` example-target pairs."""
+  metrics = {}
+
+  # Probing metric shapes.
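+  # A single probing call to `batched_student_eval_step` exposes each metric's
+  # shape; `np.empty_like` and `np.repeat` then preallocate a buffer with room
+  # for `num_examples` evaluations, which the loop below fills batch by batch.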
+  incremental_metrics = {
+    metric_name: np.repeat(np.empty_like(metric_value), repeats=num_examples, axis=1)
+    for metric_name, metric_value in batched_student_eval_step(
+      key[jnp.newaxis],
+      teacher,
+      students,
+    ).items()
+  }
+
+  # Incremental evaluation.
+  for i, eval_keys in enumerate(
+    batcher(jax.random.split(key, num_examples), batch_size),
+  ):
+    batch_metrics = batched_student_eval_step(
+      eval_keys,
+      teacher,
+      students,
+    )
+
+    for metric_name in incremental_metrics:
+      incremental_metrics[metric_name][
+        :,
+        i * batch_size : min((i + 1) * batch_size, num_examples),
+        ...,
+      ] = batch_metrics[metric_name]
+
+  metrics.update(incremental_metrics)
+  num_models = incremental_metrics["loss"].shape[0]
+
+  # Metrics metadata.
+  metrics["training iteration"] = np.full(
+    shape=(num_models, num_examples),
+    fill_value=iteration,
+  )
+
+  return batched_metrics_to_df(
+    metrics,
+    batched_key=batched_key,
+    batched_values=batched_values,
+  )
+
+
+def batch_student_init_scale_simulate(
+  seed: int,
+  # Data params.
+  input_num_dimensions: int,
+  output_num_dimensions: int,
+  input_noise_scale: float,
+  # Model params.
+  teacher_num_hiddens: tuple[int, ...],
+  teacher_activation_fn: Callable,
+  teacher_init_scale: float,
+  student_num_hiddens: tuple[int, ...],
+  student_activation_fn: Callable,
+  student_init_scale: tuple[float, ...],
+  # Training and evaluation params.
+  optimizer_fn: Callable,
+  learning_rate: float,
+  train_batch_size: int,
+  eval_batch_size: int,
+  num_training_iterations: int,
+  eval_interval: int,
+  eval_num_examples: int,
+) -> tuple[eqx.Module, pd.DataFrame]:
+  """Simulate teacher-student learning for a batch of student init scales."""
+  logging.info(f"Using JAX backend: {jax.default_backend()}\n")
+  logging.info(f"Using configuration: {pprint.pformat(locals())}")
+
+  # Single source of randomness.
+  teacher_key, student_key, train_key, eval_key = jax.random.split(
+    jax.random.PRNGKey(seed),
+    4,
+  )
+
+  #########
+  # Data model setup.
+  teacher = models.CanonicalTeacher(
+    in_features=input_num_dimensions,
+    hidden_features=teacher_num_hiddens,
+    out_features=output_num_dimensions,
+    activation=teacher_activation_fn,
+    init_scale=teacher_init_scale,
+    key=teacher_key,
+  )
+  teacher = eqx.nn.inference_mode(teacher, value=True)
+
+  #########
+  # Learner model setup.
+  student_keys = jax.random.split(student_key, len(student_init_scale))
+  student_init_scales = jnp.asarray(student_init_scale)
+
+  @eqx.filter_vmap
+  def make_students(init_scale: float, key: Array) -> eqx.Module:
+    return models.MLP(
+      in_features=input_num_dimensions,
+      hidden_features=student_num_hiddens,
+      out_features=output_num_dimensions,
+      activation=student_activation_fn,
+      init_scale=init_scale,
+      key=key,
+    )
+
+  students = make_students(student_init_scales, student_keys)
+
+  logging.info(f"Teacher:\n{teacher}\n")
+  logging.info(f"Student:\n{students}\n")
+
+  #########
+  # Training loop.
+  optimizer = optimizer_fn(learning_rate=learning_rate)
+  opt_state = optimizer.init(eqx.filter(students, eqx.is_array))
+
+  # Bookkeeping.
+  metrics = []
+
+  # Evaluate before starting training.
+  metrics.append(
+    batched_student_evaluate(
+      key=eval_key,
+      teacher=teacher,
+      students=eqx.nn.inference_mode(students, value=True),
+      batch_size=eval_batch_size,
+      num_examples=eval_num_examples,
+      iteration=0,
+      batched_key="init scale",
+      batched_values=student_init_scale,
+    ),
+  )
+
+  logging.info("\nStarting training...")
+  training_start_time = time.time()
+
+  for train_step_num in range(1, num_training_iterations):
+    # Mutate key.
+    (train_key,) = jax.random.split(train_key, 1)
+
+    train_loss, students, opt_state = batched_student_batch_train_step(
+      train_key,
+      train_batch_size,
+      teacher,
+      students,
+      optimizer,
+      opt_state,
+    )
+
+    if train_step_num % eval_interval == 0:
+      metrics.append(
+        batched_student_evaluate(
+          key=eval_key,
+          teacher=teacher,
+          students=eqx.nn.inference_mode(students, value=True),
+          batch_size=eval_batch_size,
+          num_examples=eval_num_examples,
+          iteration=train_step_num,
+          batched_key="init scale",
+          batched_values=student_init_scale,
+        ),
+      )
+
+      logging.info(f"\titeration:\t{train_step_num}")
+      for init_scale, loss in zip(
+        student_init_scales.tolist(),
+        metrics[-1].groupby("init scale").loss.mean(),
+        strict=True,
+      ):
+        logging.info(f"\t\t\t\tinit scale:\t{init_scale:.4f}\tloss:\t{loss:.4f}")
+
+  training_time = time.time() - training_start_time
+  logging.info(f"Finished training in {training_time:0.2f} seconds.")
+
+  metrics_df = pd.concat(metrics)
+
+  return students, metrics_df
+
+
 if __name__ == "__main__":
   logging.basicConfig(level=logging.INFO)
 
-  _, metrics_df = simulate(
+  _, metrics_df = batch_student_init_scale_simulate(
     seed=0,
     input_noise_scale=1e-1,
     input_num_dimensions=5,
@@ -306,12 +528,12 @@
     teacher_init_scale=1e-1,
     student_num_hiddens=(100, 100),
     student_activation_fn=jax.nn.tanh,
-    student_init_scale=1e-1,
+    student_init_scale=np.logspace(-4, 1, num=40).tolist(),
     optimizer_fn=optax.sgd,
     learning_rate=1e-2,
     train_batch_size=256,
     eval_batch_size=256,
-    num_training_iterations=int(5e2),
+    num_training_iterations=int(1e3),
     eval_interval=int(1e1),
     eval_num_examples=int(1e4),
   )
@@ -320,10 +542,14 @@
   import seaborn as sns
 
   sns.lineplot(
-    data=metrics_df,
+    data=metrics_df.round(4),
     x="training iteration",
     y="accuracy @ 0.001",
-    errorbar=("ci", 95),
+    hue="init scale",
+    markers=True,
+    palette=sns.cubehelix_palette(n_colors=metrics_df["init scale"].nunique()),
   )
+  plt.legend(title="student init scale", bbox_to_anchor=(1.05, 1), loc="upper left")
+  plt.tight_layout()
 
   plt.show()