diff --git a/nets/simulators/teacher_student.py b/nets/simulators/teacher_student.py index db83f18..13c9d1a 100644 --- a/nets/simulators/teacher_student.py +++ b/nets/simulators/teacher_student.py @@ -252,23 +252,11 @@ def simulate( # Bookkeeping. metrics = [] - # Evaluate before starting training. - metrics.append( - evaluate( - key=eval_key, - teacher=teacher, - student=eqx.nn.inference_mode(student, value=True), - batch_size=eval_batch_size, - num_examples=eval_num_examples, - epoch=0, - iteration=0, - ), - ) - last_batch_size = train_num_examples % train_batch_size ragged_last_batch = bool(last_batch_size) num_training_batches = train_num_examples // train_batch_size + ragged_last_batch + logging.info(f"{num_training_batches} batches per epoch...") logging.info("\nStarting training...") training_start_time = time.time() @@ -276,6 +264,19 @@ def simulate( # Reset data generation. iter_train_key = train_key for step_num in range(1, num_training_batches + 1): + # Evaluate at the start of each epoch. + metrics.append( + evaluate( + key=eval_key, + teacher=teacher, + student=eqx.nn.inference_mode(student, value=True), + batch_size=eval_batch_size, + num_examples=eval_num_examples, + epoch=0, + iteration=0, + ), + ) + # Mutate key. (iter_train_key,) = jax.random.split(iter_train_key, 1) @@ -310,6 +311,20 @@ def simulate( training_time = time.time() - training_start_time logging.info(f"Finished training in {training_time:0.2f} seconds.") + # Evaluate at the end of training. + if step_num % eval_interval != 0: + metrics.append( + evaluate( + key=eval_key, + teacher=teacher, + student=eqx.nn.inference_mode(student, value=True), + batch_size=eval_batch_size, + num_examples=eval_num_examples, + epoch=epoch_num, + iteration=step_num, + ), + ) + metrics_df = pd.concat(metrics) # TODO(eringrant): Robust handling of hyperparameters. @@ -491,28 +506,29 @@ def make_students(init_scale: float, key: Array) -> eqx.Module: # Bookkeeping. metrics = [] - # Evaluate before starting training. - metrics.append( - batched_student_evaluate( - key=eval_key, - teacher=teacher, - students=eqx.nn.inference_mode(students, value=True), - batch_size=eval_batch_size, - num_examples=eval_num_examples, - epoch=0, - iteration=0, - batched_key="init scale", - batched_values=student_init_scale, - ), - ) - last_batch_size = train_num_examples % train_batch_size ragged_last_batch = bool(last_batch_size) num_training_batches = train_num_examples // train_batch_size + ragged_last_batch + logging.info(f"{num_training_batches} batches per epoch...") logging.info("\nStarting training...") training_start_time = time.time() for epoch_num in range(1, train_num_epochs + 1): + # Evaluate at the start of each epoch. + metrics.append( + batched_student_evaluate( + key=eval_key, + teacher=teacher, + students=eqx.nn.inference_mode(students, value=True), + batch_size=eval_batch_size, + num_examples=eval_num_examples, + epoch=epoch_num, + iteration=0, + batched_key="init scale", + batched_values=student_init_scale, + ), + ) + # Reset data generation. iter_train_key = train_key for step_num in range(1, num_training_batches + 1): @@ -562,8 +578,23 @@ def make_students(init_scale: float, key: Array) -> eqx.Module: training_time = time.time() - training_start_time logging.info(f"Finished training in {training_time:0.2f} seconds.") - metrics_df = pd.concat(metrics) + # Evaluate at the end of training. + if step_num % eval_interval != 0: + metrics.append( + batched_student_evaluate( + key=eval_key, + teacher=teacher, + students=eqx.nn.inference_mode(students, value=True), + batch_size=eval_batch_size, + num_examples=eval_num_examples, + epoch=epoch_num, + iteration=step_num, + batched_key="init scale", + batched_values=student_init_scale, + ), + ) + metrics_df = pd.concat(metrics) return students, metrics_df @@ -584,10 +615,10 @@ def make_students(init_scale: float, key: Array) -> eqx.Module: optimizer_fn=optax.sgd, learning_rate=1e-2, train_num_epochs=int(1e1), - train_batch_size=256, - train_num_examples=int(1e5), + train_batch_size=2 * int(1e2), + train_num_examples=int(1e4), eval_interval=int(1e1), - eval_batch_size=256, + eval_batch_size=4 * int(1e2), eval_num_examples=int(1e4), ) @@ -596,7 +627,7 @@ def make_students(init_scale: float, key: Array) -> eqx.Module: sns.lineplot( data=metrics_df.round(4), - x="training iteration", + x="training epoch", y="accuracy @ 0.001", errorbar=None, hue="init scale",