use duck-typing to avoid per-optimizer patches
tmm1 authored and winglian committed Sep 6, 2024
1 parent 1145ede commit 367ac00
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions src/transformers/trainer.py
@@ -268,10 +268,6 @@ def _get_fsdp_ckpt_kwargs():
     return {}
 
 
-def _is_schedule_free_optimizer(optimizer):
-    return "ScheduleFree" in optimizer.__class__.__name__
-
-
 if TYPE_CHECKING:
     import optuna

@@ -3443,7 +3439,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
             `torch.Tensor`: The tensor with training loss on this batch.
         """
         model.train()
-        if _is_schedule_free_optimizer(self.optimizer):
+        if hasattr(self.optimizer, 'train') and callable(getattr(self.optimizer, 'train')):
             self.optimizer.train()
 
         inputs = self._prepare_inputs(inputs)
@@ -3996,7 +3992,7 @@ def evaluation_loop(
         logger.info(f"  Batch size = {batch_size}")
 
         model.eval()
-        if _is_schedule_free_optimizer(self.optimizer):
+        if hasattr(self.optimizer, 'eval') and callable(getattr(self.optimizer, 'eval')):
             self.optimizer.eval()
 
         self.callback_handler.eval_dataloader = dataloader
@@ -4611,7 +4607,7 @@ def prediction_loop(
         inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
 
         model.eval()
-        if _is_schedule_free_optimizer(self.optimizer):
+        if hasattr(self.optimizer, 'eval') and callable(getattr(self.optimizer, 'eval')):
             self.optimizer.eval()
 
         if args.past_index >= 0:
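
The replacement relies on duck typing: rather than testing optimizer class names, the Trainer now calls optimizer.train() / optimizer.eval() on any optimizer that exposes those methods, such as the schedule-free optimizers from facebookresearch/schedule_free. Below is a minimal sketch of the pattern; ScheduleFreeLike is a hypothetical stand-in for illustration, not an optimizer from this commit.

import torch


class ScheduleFreeLike(torch.optim.SGD):
    """Hypothetical optimizer exposing the train()/eval() interface."""

    def train(self):
        # Switch parameters to their training representation (details elided).
        pass

    def eval(self):
        # Switch parameters to their evaluation representation (details elided).
        pass


params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = ScheduleFreeLike(params, lr=0.1)

# The Trainer's new checks match any optimizer with this interface,
# with no per-class name test:
if hasattr(optimizer, "train") and callable(getattr(optimizer, "train")):
    optimizer.train()  # called at the start of each training step
if hasattr(optimizer, "eval") and callable(getattr(optimizer, "eval")):
    optimizer.eval()  # called before the evaluation/prediction loops

Because the check is purely structural, third-party optimizers gain the same behavior simply by implementing train() and eval(), with no further patches to trainer.py.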
