diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index c3b3ab1a9d4296..80634bb6f06b28 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3439,7 +3439,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, `torch.Tensor`: The tensor with training loss on this batch. """ model.train() - if hasattr(self.optimizer, 'train') and callable(getattr(self.optimizer, 'train')): + if hasattr(self.optimizer, "train") and callable(getattr(self.optimizer, "train")): self.optimizer.train() inputs = self._prepare_inputs(inputs) @@ -3992,7 +3992,7 @@ def evaluation_loop( logger.info(f" Batch size = {batch_size}") model.eval() - if hasattr(self.optimizer, 'eval') and callable(getattr(self.optimizer, 'eval')): + if hasattr(self.optimizer, "eval") and callable(getattr(self.optimizer, "eval")): self.optimizer.eval() self.callback_handler.eval_dataloader = dataloader @@ -4607,7 +4607,7 @@ def prediction_loop( inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) model.eval() - if hasattr(self.optimizer, 'eval') and callable(getattr(self.optimizer, 'eval')): + if hasattr(self.optimizer, "eval") and callable(getattr(self.optimizer, "eval")): self.optimizer.eval() if args.past_index >= 0: