
Commit

finite dataset; train by epoch
eringrant committed Nov 14, 2023
1 parent 16c2cac commit 0352419
Showing 1 changed file, nets/simulators/teacher_student.py, with 122 additions and 68 deletions.
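This change replaces the fixed iteration budget (num_training_iterations) with epoch-based training over a finite dataset of train_num_examples examples: each epoch replays the same finite set of batches, with a possibly smaller final batch. Evaluation metrics gain a "training epoch" column alongside "training iteration". In the reconstructed hunks below, removed lines are prefixed with - and added lines with +.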
@@ -139,13 +139,21 @@ def evaluate(
student: eqx.Module,
batch_size: int,
num_examples: int,
+epoch: int,
iteration: int,
) -> pd.DataFrame:
"""Evaluate the model on `num_examples` example-target pairs."""
metrics = {}

# Metrics metadata.
-metrics["training iteration"] = iteration
+metrics["training epoch"] = np.full(
+shape=(num_examples,),
+fill_value=epoch,
+)
+metrics["training iteration"] = np.full(
+shape=(num_examples,),
+fill_value=iteration,
+)

# Probing metric shapes.
incremental_metrics = {
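The metadata columns are now materialized as per-example arrays rather than scalars, so they align with the per-example metric columns when everything is collected into a DataFrame. A minimal sketch of the pattern, using only numpy and pandas (the metric values are placeholders):

import numpy as np
import pandas as pd

num_examples = 4
metrics = {
    "training epoch": np.full(shape=(num_examples,), fill_value=2),
    "training iteration": np.full(shape=(num_examples,), fill_value=17),
    "loss": np.zeros(num_examples),  # placeholder per-example metric
}
df = pd.DataFrame(metrics)  # one row per evaluation example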
@@ -189,13 +197,15 @@ def simulate(
student_num_hiddens: tuple[int, ...],
student_activation_fn: Callable,
student_init_scale: float,
-# Training and evaluation params.
+# Training params.
optimizer_fn: Callable,
learning_rate: float,
+train_num_epochs: int,
train_batch_size: int,
-eval_batch_size: int,
-num_training_iterations: int,
+train_num_examples: int,
+# Evaluation params.
eval_interval: int,
+eval_batch_size: int,
eval_num_examples: int,
) -> tuple[pd.DataFrame, ...]:
"""Simulate teacher-student learning."""
@@ -250,46 +260,61 @@ def simulate(
student=eqx.nn.inference_mode(student, value=True),
batch_size=eval_batch_size,
num_examples=eval_num_examples,
+epoch=0,
iteration=0,
),
)

+last_batch_size = train_num_examples % train_batch_size
+ragged_last_batch = bool(last_batch_size)
+num_training_batches = train_num_examples // train_batch_size + ragged_last_batch

logging.info("\nStarting training...")
training_start_time = time.time()

-for train_step_num in range(1, num_training_iterations):
-# Mutate key.
-(train_key,) = jax.random.split(train_key, 1)

-train_loss, student, opt_state = batch_train_step(
-train_key,
-train_batch_size,
-teacher,
-student,
-optimizer,
-opt_state,
-)

-if train_step_num % eval_interval == 0:
-metrics.append(
-evaluate(
-key=eval_key,
-teacher=teacher,
-student=eqx.nn.inference_mode(student, value=True),
-batch_size=eval_batch_size,
-num_examples=eval_num_examples,
-iteration=train_step_num,
-),
+for epoch_num in range(1, train_num_epochs + 1):
+# Reset data generation.
+iter_train_key = train_key
+for step_num in range(1, num_training_batches + 1):
+# Mutate key.
+(iter_train_key,) = jax.random.split(iter_train_key, 1)

+train_loss, student, opt_state = batch_train_step(
+iter_train_key,
+train_batch_size,
+teacher,
+student,
+optimizer,
+opt_state,
+)

-loss = metrics[-1]["loss"].mean()
-logging.info(f"\titeration:\t{train_step_num}\tloss:\t{loss:.4f}")
+if step_num % eval_interval == 0:
+metrics.append(
+evaluate(
+key=eval_key,
+teacher=teacher,
+student=eqx.nn.inference_mode(student, value=True),
+batch_size=eval_batch_size,
+num_examples=eval_num_examples,
+epoch=epoch_num,
+iteration=step_num,
+),
+)

+loss = metrics[-1]["loss"].mean()
+logging.info(
+f"\titeration:\t{(epoch_num - 1) * num_training_batches + step_num}"
+f"\tloss:\t{loss:.4f}",
+)

training_time = time.time() - training_start_time
logging.info(f"Finished training in {training_time:0.2f} seconds.")

metrics_df = pd.concat(metrics)

# TODO(eringrant): Robust handling of hyperparameters.
metrics_df["init scale"] = student_init_scale

return student, metrics_df
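The batch accounting added to simulate above counts a final short batch whenever train_num_examples is not a multiple of train_batch_size. (Note that this loop still draws train_batch_size fresh examples on every step, including the last one; only the batched variant further down shrinks the final batch to last_batch_size.) A worked example of the arithmetic, with illustrative numbers:

train_num_examples = 1000
train_batch_size = 256

last_batch_size = train_num_examples % train_batch_size  # 232
ragged_last_batch = bool(last_batch_size)  # True
num_training_batches = train_num_examples // train_batch_size + ragged_last_batch  # 3 + True == 4

assert (num_training_batches - 1) * train_batch_size + last_batch_size == train_num_examples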


@@ -332,6 +357,7 @@ def batched_student_evaluate(
students: eqx.Module,
batch_size: int,
num_examples: int,
+epoch: int,
iteration: int,
batched_key: str,
batched_values: tuple[float, ...],
@@ -370,6 +396,10 @@ def batched_student_evaluate(
num_models = incremental_metrics["loss"].shape[0]

# Metrics metadata.
+metrics["training epoch"] = np.full(
+shape=(num_models, num_examples),
+fill_value=epoch,
+)
metrics["training iteration"] = np.full(
shape=(num_models, num_examples),
fill_value=iteration,
@@ -395,16 +425,22 @@ def batch_student_init_scale_simulate(
student_num_hiddens: tuple[int, ...],
student_activation_fn: Callable,
student_init_scale: tuple[float, ...],
-# Training and evaluation params.
+# Training params.
optimizer_fn: Callable,
learning_rate: float,
+train_num_epochs: int,
train_batch_size: int,
-eval_batch_size: int,
-num_training_iterations: int,
+train_num_examples: int,
+# Evaluation params.
eval_interval: int,
+eval_batch_size: int,
eval_num_examples: int,
) -> tuple[pd.DataFrame, ...]:
"""Simulate teacher-student learning."""
+if eval_interval > train_num_examples // train_batch_size:
+msg = "Evaluation interval must be no more than the number of training batches."
+raise ValueError(msg)

logging.info(f"Using JAX backend: {jax.default_backend()}\n")
logging.info(f"Using configuration: {pprint.pformat(locals())}")
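The new guard rejects configurations in which evaluation would never fire: step_num runs from 1 to num_training_batches each epoch, so step_num % eval_interval == 0 is reachable only if eval_interval is at most the number of batches. A small standalone sketch of the check, with illustrative values:

def validate_eval_interval(eval_interval, train_num_examples, train_batch_size):
    # Mirrors the guard above: evaluation must fire at least once per epoch.
    if eval_interval > train_num_examples // train_batch_size:
        msg = "Evaluation interval must be no more than the number of training batches."
        raise ValueError(msg)

validate_eval_interval(10, int(1e5), 256)  # passes: 390 full batches per epoch
# validate_eval_interval(500, int(1e5), 256) would raise ValueError.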

@@ -463,49 +499,65 @@ def make_students(init_scale: float, key: Array) -> eqx.Module:
students=eqx.nn.inference_mode(students, value=True),
batch_size=eval_batch_size,
num_examples=eval_num_examples,
+epoch=0,
iteration=0,
batched_key="init scale",
batched_values=student_init_scale,
),
)

+last_batch_size = train_num_examples % train_batch_size
+ragged_last_batch = bool(last_batch_size)
+num_training_batches = train_num_examples // train_batch_size + ragged_last_batch

logging.info("\nStarting training...")
training_start_time = time.time()

-for train_step_num in range(1, num_training_iterations):
-# Mutate key.
-(train_key,) = jax.random.split(train_key, 1)

-train_loss, students, opt_state = batched_student_batch_train_step(
-train_key,
-train_batch_size,
-teacher,
-students,
-optimizer,
-opt_state,
-)

-if train_step_num % eval_interval == 0:
-metrics.append(
-batched_student_evaluate(
-key=eval_key,
-teacher=teacher,
-students=eqx.nn.inference_mode(students, value=True),
-batch_size=eval_batch_size,
-num_examples=eval_num_examples,
-iteration=train_step_num,
-batched_key="init scale",
-batched_values=student_init_scale,
-),
+for epoch_num in range(1, train_num_epochs + 1):
+# Reset data generation.
+iter_train_key = train_key
+for step_num in range(1, num_training_batches + 1):
+# Mutate key.
+(iter_train_key,) = jax.random.split(iter_train_key, 1)

+# Last batch may be smaller.
+if ragged_last_batch and step_num == num_training_batches:
+iter_batch_size = last_batch_size
+else:
+iter_batch_size = train_batch_size

+train_loss, students, opt_state = batched_student_batch_train_step(
+iter_train_key,
+iter_batch_size,
+teacher,
+students,
+optimizer,
+opt_state,
+)

-logging.info(f"\titeration:\t{train_step_num}")
-for init_scale, loss in zip(
-student_init_scales.tolist(),
-metrics[-1].groupby("init scale").loss.mean(),
-strict=True,
-):
-logging.info(f"\t\t\t\tinit scale:\t{init_scale:.4f}\tloss:\t{loss:.4f}")
+if step_num % eval_interval == 0:
+metrics.append(
+batched_student_evaluate(
+key=eval_key,
+teacher=teacher,
+students=eqx.nn.inference_mode(students, value=True),
+batch_size=eval_batch_size,
+num_examples=eval_num_examples,
+epoch=epoch_num,
+iteration=step_num,
+batched_key="init scale",
+batched_values=student_init_scale,
+),
+)

+logging.info(
+f"\titeration:\t{(epoch_num - 1) * num_training_batches + step_num}",
+)
+for init_scale, loss in zip(
+student_init_scales.tolist(),
+metrics[-1].groupby("init scale").loss.mean(),
+strict=True,
+):
+logging.info(f"\t\t\t\tinit scale:\t{init_scale:.4f}\tloss:\t{loss:.4f}")

training_time = time.time() - training_start_time
logging.info(f"Finished training in {training_time:0.2f} seconds.")
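In both training loops, train_key itself is never advanced; each epoch resets iter_train_key = train_key and re-derives the same chain of per-batch keys, which is what turns on-the-fly sampling into a fixed, finite dataset. A self-contained sketch of that property (function and variable names are illustrative):

import jax

def per_batch_keys(base_key, num_batches):
    # Reproduce the loops' key chain: split off one key per batch.
    keys, key = [], base_key
    for _ in range(num_batches):
        (key,) = jax.random.split(key, 1)
        keys.append(key)
    return keys

base_key = jax.random.PRNGKey(0)
epoch_one = per_batch_keys(base_key, 4)
epoch_two = per_batch_keys(base_key, 4)  # "reset" to the same base key
assert all(bool((a == b).all()) for a, b in zip(epoch_one, epoch_two))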
Expand All @@ -528,13 +580,14 @@ def make_students(init_scale: float, key: Array) -> eqx.Module:
teacher_init_scale=1e-1,
student_num_hiddens=(100, 100),
student_activation_fn=jax.nn.tanh,
-student_init_scale=np.logspace(-4, 1, num=40).tolist(),
+student_init_scale=np.logspace(-4, 1, num=20).tolist(),
optimizer_fn=optax.sgd,
learning_rate=1e-2,
+train_num_epochs=int(1e1),
train_batch_size=256,
-eval_batch_size=256,
-num_training_iterations=int(1e3),
+train_num_examples=int(1e5),
eval_interval=int(1e1),
+eval_batch_size=256,
eval_num_examples=int(1e4),
)
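Under these defaults the schedule resolves as follows (arithmetic only; this just instantiates the batch accounting above):

train_num_examples, train_batch_size, train_num_epochs = int(1e5), 256, int(1e1)

last_batch_size = train_num_examples % train_batch_size  # 160
batches_per_epoch = train_num_examples // train_batch_size + bool(last_batch_size)  # 391
total_steps = train_num_epochs * batches_per_epoch  # 3910, evaluated every 10 steps

assert (batches_per_epoch, total_steps) == (391, 3910)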

@@ -545,6 +598,7 @@ def make_students(init_scale: float, key: Array) -> eqx.Module:
data=metrics_df.round(4),
x="training iteration",
y="accuracy @ 0.001",
+errorbar=None,
hue="init scale",
markers=True,
palette=sns.cubehelix_palette(n_colors=metrics_df["init scale"].nunique()),
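With many per-example rows at each iteration, seaborn's default bootstrapped error bars are expensive to compute; errorbar=None (seaborn >= 0.12) skips that aggregation so only the mean line is drawn, which is presumably the motivation here. A minimal usage sketch with toy data:

import pandas as pd
import seaborn as sns

df = pd.DataFrame({
    "training iteration": [0, 0, 10, 10],
    "loss": [1.0, 0.9, 0.5, 0.4],
})
# errorbar=None disables the bootstrapped confidence interval over
# repeated x values, so only the mean line is drawn.
ax = sns.lineplot(data=df, x="training iteration", y="loss", errorbar=None)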
