
Commit

fixup style
tmm1 authored and winglian committed Sep 6, 2024
1 parent 2e857a8 commit 179232c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/transformers/trainer.py
@@ -3439,7 +3439,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
- if hasattr(self.optimizer, "train") and callable(getattr(self.optimizer, "train")):
+ if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
self.optimizer.train()

inputs = self._prepare_inputs(inputs)
@@ -3992,7 +3992,7 @@ def evaluation_loop(
logger.info(f" Batch size = {batch_size}")

model.eval()
- if hasattr(self.optimizer, "eval") and callable(getattr(self.optimizer, "eval")):
+ if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
self.optimizer.eval()

self.callback_handler.eval_dataloader = dataloader
@@ -4607,7 +4607,7 @@ def prediction_loop(
inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

model.eval()
- if hasattr(self.optimizer, "eval") and callable(getattr(self.optimizer, "eval")):
+ if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
self.optimizer.eval()

if args.past_index >= 0:
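For context, a minimal sketch of the duck-typed train()/eval() toggle these hunks clean up. ToyScheduleFreeOptimizer is hypothetical and only stands in for an optimizer (e.g. a schedule-free variant) that exposes such methods; it is not part of the commit.

import torch

class ToyScheduleFreeOptimizer(torch.optim.SGD):
    # Hypothetical optimizer that must be switched between train and eval mode.
    def train(self):
        self._mode = "train"

    def eval(self):
        self._mode = "eval"

model = torch.nn.Linear(4, 2)
optimizer = ToyScheduleFreeOptimizer(model.parameters(), lr=0.1)

model.train()
# hasattr() guards plain optimizers such as torch.optim.AdamW, which define no
# train() method; callable() guards against a non-callable attribute of that
# name. Once hasattr() has succeeded, callable(optimizer.train) is equivalent
# to callable(getattr(optimizer, "train")), which is why the commit drops the
# redundant getattr().
if hasattr(optimizer, "train") and callable(optimizer.train):
    optimizer.train()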
