From 5d1eccab8f171f9912c728a511b15062a169a1e8 Mon Sep 17 00:00:00 2001 From: fanqiNO1 <1848839264@qq.com> Date: Thu, 22 Feb 2024 19:05:32 +0800 Subject: [PATCH] [Fix] Fix test --- mmengine/model/base_model/base_model.py | 4 +- mmengine/runner/loops.py | 49 +++++++++++++------------ 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py index 299cd67557..a20ce553d4 100644 --- a/mmengine/model/base_model/base_model.py +++ b/mmengine/model/base_model/base_model.py @@ -116,7 +116,7 @@ def train_step(self, data: Union[dict, tuple, list], optim_wrapper.update_params(parsed_losses) return log_vars - def val_step(self, data: Union[tuple, dict, list]) -> list: + def val_step(self, data: Union[tuple, dict, list]) -> Union[tuple, list]: """Gets the predictions of given data. Calls ``self.data_preprocessor(data, False)`` and @@ -132,7 +132,7 @@ def val_step(self, data: Union[tuple, dict, list]) -> list: data = self.data_preprocessor(data, False) return self._run_forward(data, mode='predict') # type: ignore - def test_step(self, data: Union[dict, tuple, list]) -> list: + def test_step(self, data: Union[dict, tuple, list]) -> Union[tuple, list]: """``BaseModel`` implements ``test_step`` the same as ``val_step``. 
Args: diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index ec7cc55dd9..71e97f8e83 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -399,19 +399,19 @@ def run_iter(self, idx, data_batch: Sequence[dict]): 'before_val_iter', batch_idx=idx, data_batch=data_batch) # outputs should be sequence of BaseDataElement with autocast(enabled=self.fp16): - outputs = self.runner.model.val_step(data_batch) + results = self.runner.model.val_step(data_batch) + if isinstance(results, tuple): + outputs, loss = results + elif isinstance(results, list): + outputs, loss = results, dict() # get val loss and avoid breaking change - if len(outputs) > len(data_batch): - loss = outputs[-1].loss - for loss_name, loss_value in loss.items(): - if loss_name not in self.val_loss: - self.val_loss[loss_name] = [] - if isinstance(loss_value, torch.Tensor): - self.val_loss[loss_name].append(loss_value.item()) - elif is_list_of(loss_value, torch.Tensor): - self.val_loss[loss_name].extend( - [v.item() for v in loss_value]) - outputs = outputs[:-1] + for loss_name, loss_value in loss.items(): + if loss_name not in self.val_loss: + self.val_loss[loss_name] = [] + if isinstance(loss_value, torch.Tensor): + self.val_loss[loss_name].append(loss_value.item()) + elif is_list_of(loss_value, torch.Tensor): + self.val_loss[loss_name].extend([v.item() for v in loss_value]) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( @@ -493,19 +493,20 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: 'before_test_iter', batch_idx=idx, data_batch=data_batch) # predictions should be sequence of BaseDataElement with autocast(enabled=self.fp16): - outputs = self.runner.model.test_step(data_batch) + results = self.runner.model.test_step(data_batch) + if isinstance(results, tuple): + outputs, loss = results + elif isinstance(results, list): + outputs, loss = results, dict() # get val loss and avoid breaking change - if len(outputs) > len(data_batch): - loss = outputs[-1].loss - for loss_name, loss_value in loss.items(): - if loss_name not in self.test_loss: - self.test_loss[loss_name] = [] - if isinstance(loss_value, torch.Tensor): - self.test_loss[loss_name].append(loss_value.item()) - elif is_list_of(loss_value, torch.Tensor): - self.test_loss[loss_name].extend( - [v.item() for v in loss_value]) - outputs = outputs[:-1] + for loss_name, loss_value in loss.items(): + if loss_name not in self.test_loss: + self.test_loss[loss_name] = [] + if isinstance(loss_value, torch.Tensor): + self.test_loss[loss_name].append(loss_value.item()) + elif is_list_of(loss_value, torch.Tensor): + self.test_loss[loss_name].extend( + [v.item() for v in loss_value]) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook(