Skip to content

Commit

Permalink
[Fix] Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
fanqiNO1 committed Feb 22, 2024
1 parent 6f8c6c7 commit 5d1ecca
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 26 deletions.
4 changes: 2 additions & 2 deletions mmengine/model/base_model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def train_step(self, data: Union[dict, tuple, list],
optim_wrapper.update_params(parsed_losses)
return log_vars

def val_step(self, data: Union[tuple, dict, list]) -> list:
def val_step(self, data: Union[tuple, dict, list]) -> Union[tuple, list]:
"""Gets the predictions of given data.
Calls ``self.data_preprocessor(data, False)`` and
Expand All @@ -132,7 +132,7 @@ def val_step(self, data: Union[tuple, dict, list]) -> list:
data = self.data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore

def test_step(self, data: Union[dict, tuple, list]) -> list:
def test_step(self, data: Union[dict, tuple, list]) -> Union[tuple, list]:
"""``BaseModel`` implements ``test_step`` the same as ``val_step``.
Args:
Expand Down
49 changes: 25 additions & 24 deletions mmengine/runner/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,19 +399,19 @@ def run_iter(self, idx, data_batch: Sequence[dict]):
'before_val_iter', batch_idx=idx, data_batch=data_batch)
# outputs should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.val_step(data_batch)
results = self.runner.model.test_step(data_batch)
if isinstance(results, tuple):
outputs, loss = results
elif isinstance(results, list):
outputs, loss = results, dict()
# get val loss and avoid breaking change
if len(outputs) > len(data_batch):
loss = outputs[-1].loss
for loss_name, loss_value in loss.items():
if loss_name not in self.val_loss:
self.val_loss[loss_name] = []
if isinstance(loss_value, torch.Tensor):
self.val_loss[loss_name].append(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
self.val_loss[loss_name].extend(
[v.item() for v in loss_value])
outputs = outputs[:-1]
for loss_name, loss_value in loss.items():
if loss_name not in self.val_loss:
self.val_loss[loss_name] = []
if isinstance(loss_value, torch.Tensor):
self.val_loss[loss_name].append(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
self.val_loss[loss_name].extend([v.item() for v in loss_value])

self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
Expand Down Expand Up @@ -493,19 +493,20 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
'before_test_iter', batch_idx=idx, data_batch=data_batch)
# predictions should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.test_step(data_batch)
results = self.runner.model.test_step(data_batch)
if isinstance(results, tuple):
outputs, loss = results
elif isinstance(results, list):
outputs, loss = results, dict()
# get val loss and avoid breaking change
if len(outputs) > len(data_batch):
loss = outputs[-1].loss
for loss_name, loss_value in loss.items():
if loss_name not in self.test_loss:
self.test_loss[loss_name] = []
if isinstance(loss_value, torch.Tensor):
self.test_loss[loss_name].append(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
self.test_loss[loss_name].extend(
[v.item() for v in loss_value])
outputs = outputs[:-1]
for loss_name, loss_value in loss.items():
if loss_name not in self.test_loss:
self.test_loss[loss_name] = []
if isinstance(loss_value, torch.Tensor):
self.test_loss[loss_name].append(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
self.test_loss[loss_name].extend(
[v.item() for v in loss_value])

self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
Expand Down

0 comments on commit 5d1ecca

Please sign in to comment.