Skip to content

Commit

Permalink
epoch-based evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
eringrant committed Nov 15, 2023
1 parent 0352419 commit 1cd8cc4
Showing 1 changed file with 64 additions and 33 deletions.
97 changes: 64 additions & 33 deletions nets/simulators/teacher_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,30 +252,31 @@ def simulate(
# Bookkeeping.
metrics = []

# Evaluate before starting training.
metrics.append(
evaluate(
key=eval_key,
teacher=teacher,
student=eqx.nn.inference_mode(student, value=True),
batch_size=eval_batch_size,
num_examples=eval_num_examples,
epoch=0,
iteration=0,
),
)

last_batch_size = train_num_examples % train_batch_size
ragged_last_batch = bool(last_batch_size)
num_training_batches = train_num_examples // train_batch_size + ragged_last_batch

logging.info(f"{num_training_batches} batches per epoch...")
logging.info("\nStarting training...")
training_start_time = time.time()

for epoch_num in range(1, train_num_epochs + 1):
# Reset data generation.
iter_train_key = train_key
for step_num in range(1, num_training_batches + 1):
# Evaluate at the start of each epoch.
metrics.append(
evaluate(
key=eval_key,
teacher=teacher,
student=eqx.nn.inference_mode(student, value=True),
batch_size=eval_batch_size,
num_examples=eval_num_examples,
epoch=0,
iteration=0,
),
)

# Mutate key.
(iter_train_key,) = jax.random.split(iter_train_key, 1)

Expand Down Expand Up @@ -310,6 +311,20 @@ def simulate(
training_time = time.time() - training_start_time
logging.info(f"Finished training in {training_time:0.2f} seconds.")

# Evaluate at the end of training.
if step_num % eval_interval != 0:
metrics.append(
evaluate(
key=eval_key,
teacher=teacher,
student=eqx.nn.inference_mode(student, value=True),
batch_size=eval_batch_size,
num_examples=eval_num_examples,
epoch=epoch_num,
iteration=step_num,
),
)

metrics_df = pd.concat(metrics)

# TODO(eringrant): Robust handling of hyperparameters.
Expand Down Expand Up @@ -491,28 +506,29 @@ def make_students(init_scale: float, key: Array) -> eqx.Module:
# Bookkeeping.
metrics = []

# Evaluate before starting training.
metrics.append(
batched_student_evaluate(
key=eval_key,
teacher=teacher,
students=eqx.nn.inference_mode(students, value=True),
batch_size=eval_batch_size,
num_examples=eval_num_examples,
epoch=0,
iteration=0,
batched_key="init scale",
batched_values=student_init_scale,
),
)

last_batch_size = train_num_examples % train_batch_size
ragged_last_batch = bool(last_batch_size)
num_training_batches = train_num_examples // train_batch_size + ragged_last_batch

logging.info(f"{num_training_batches} batches per epoch...")
logging.info("\nStarting training...")
training_start_time = time.time()
for epoch_num in range(1, train_num_epochs + 1):
# Evaluate at the start of each epoch.
metrics.append(
batched_student_evaluate(
key=eval_key,
teacher=teacher,
students=eqx.nn.inference_mode(students, value=True),
batch_size=eval_batch_size,
num_examples=eval_num_examples,
epoch=epoch_num,
iteration=0,
batched_key="init scale",
batched_values=student_init_scale,
),
)

# Reset data generation.
iter_train_key = train_key
for step_num in range(1, num_training_batches + 1):
Expand Down Expand Up @@ -562,8 +578,23 @@ def make_students(init_scale: float, key: Array) -> eqx.Module:
training_time = time.time() - training_start_time
logging.info(f"Finished training in {training_time:0.2f} seconds.")

metrics_df = pd.concat(metrics)
# Evaluate at the end of training.
if step_num % eval_interval != 0:
metrics.append(
batched_student_evaluate(
key=eval_key,
teacher=teacher,
students=eqx.nn.inference_mode(students, value=True),
batch_size=eval_batch_size,
num_examples=eval_num_examples,
epoch=epoch_num,
iteration=step_num,
batched_key="init scale",
batched_values=student_init_scale,
),
)

metrics_df = pd.concat(metrics)
return students, metrics_df


Expand All @@ -584,10 +615,10 @@ def make_students(init_scale: float, key: Array) -> eqx.Module:
optimizer_fn=optax.sgd,
learning_rate=1e-2,
train_num_epochs=int(1e1),
train_batch_size=256,
train_num_examples=int(1e5),
train_batch_size=2 * int(1e2),
train_num_examples=int(1e4),
eval_interval=int(1e1),
eval_batch_size=256,
eval_batch_size=4 * int(1e2),
eval_num_examples=int(1e4),
)

Expand All @@ -596,7 +627,7 @@ def make_students(init_scale: float, key: Array) -> eqx.Module:

sns.lineplot(
data=metrics_df.round(4),
x="training iteration",
x="training epoch",
y="accuracy @ 0.001",
errorbar=None,
hue="init scale",
Expand Down

0 comments on commit 1cd8cc4

Please sign in to comment.