diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py
index 124dfd7c57..0788fafdab 100644
--- a/mmengine/_strategy/fsdp.py
+++ b/mmengine/_strategy/fsdp.py
@@ -408,7 +408,7 @@ def load_optim_state_dict(self, state_dict: dict) -> None:
                 ``optimizer.state_dict()``
         """
         optim_state_dict = FSDP.optim_state_dict_to_load(
-            self.model, self.optim_wrapper.optimizer, state_dict)
+            state_dict, self.model, self.optim_wrapper.optimizer)
         self.optim_wrapper.load_state_dict(optim_state_dict)
 
     def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index 8557f4d34c..b57ebc315a 100644
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -207,5 +207,6 @@ def build_optim_wrapper(model: nn.Module,
             type=constructor_type,
             optim_wrapper_cfg=optim_wrapper_cfg,
             paramwise_cfg=paramwise_cfg))
+    optim_wrapper = optim_wrapper_constructor(model)
     return optim_wrapper
diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index 5a678db7b9..25ff690f0b 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -12,6 +12,7 @@
 from mmengine.registry import LOOPS
 from mmengine.structures import BaseDataElement
 from mmengine.utils import is_list_of
+from mmengine.dataset.sampler import InfiniteSampler
 from .amp import autocast
 from .base_loop import BaseLoop
 from .utils import calc_dynamic_intervals
@@ -274,13 +275,14 @@ def run(self) -> None:
         # In iteration-based training loop, we treat the whole training process
         # as a big epoch and execute the corresponding hook.
         self.runner.call_hook('before_train_epoch')
-        if self._iter > 0:
+        if self._iter > 0 and not isinstance(self.dataloader.sampler,
+                                             InfiniteSampler):
             print_log(
                 f'Advance dataloader {self._iter} steps to skip data '
                 'that has already been trained',
                 logger='current',
                 level=logging.WARNING)
             for _ in range(self._iter):
                 next(self.dataloader_iterator)
         while self._iter < self._max_iters and not self.stop_training:
             self.runner.model.train()
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 7160ac84d7..435bd55ac0 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -903,9 +903,11 @@ def wrap_model(
                 find_unused_parameters=find_unused_parameters)
         else:
             model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')
+            model_wrapper_type = model_wrapper_cfg.get('type')
             if isinstance(model_wrapper_type, str):
-                model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type)  # type: ignore
+                model_wrapper_type = MODEL_WRAPPERS.get(
+                    model_wrapper_type)  # type: ignore
             elif inspect.isclass(model_wrapper_type):
                 pass
             else:
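
Note on the fsdp.py hunk: the positional order of FSDP.optim_state_dict_to_load
differs across PyTorch releases (torch 2.0 takes (optim_state_dict, model, optim),
while later releases moved optim_state_dict to the end), which is what the
reordering above works around. A minimal sketch of a version-proof alternative,
assuming only that the keyword names stay stable; load_fsdp_optim_state is a
hypothetical helper, not mmengine API:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def load_fsdp_optim_state(model, optimizer, state_dict):
        # Keyword arguments sidestep the positional reordering of
        # optim_state_dict_to_load between torch versions.
        flattened = FSDP.optim_state_dict_to_load(
            optim_state_dict=state_dict, model=model, optim=optimizer)
        optimizer.load_state_dict(flattened)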
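
The builder.py hunk has build_optim_wrapper materialize the wrapper by calling
the registry-built constructor on the model. A short usage sketch under standard
mmengine config conventions:

    import torch.nn as nn
    from mmengine.optim import build_optim_wrapper

    model = nn.Linear(2, 2)
    # The constructor built from OPTIM_WRAPPER_CONSTRUCTORS is itself a
    # callable that maps a model to an OptimWrapper instance.
    optim_wrapper = build_optim_wrapper(
        model, dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)))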
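
The loops.py change replays the dataloader on resume only for finite samplers:
an InfiniteSampler iterator never raises StopIteration, so fast-forwarding
self._iter steps can stall a resume for a very long time. An illustrative
standalone version of the guard (fast_forward is a made-up name, not mmengine
API):

    from mmengine.dataset.sampler import InfiniteSampler

    def fast_forward(dataloader, iterator, start_iter):
        # Replay the already-trained batches only for finite samplers; an
        # infinite sampler has no epoch boundary to restore.
        if start_iter > 0 and not isinstance(dataloader.sampler,
                                             InfiniteSampler):
            for _ in range(start_iter):
                next(iterator)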
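
Finally, the runner.py hunk reads the type field out of model_wrapper_cfg before
normalizing it, so the field can seemingly be either a registry key or the
wrapper class itself. A hedged sketch of the two spellings this enables:

    from mmengine.model import MMDistributedDataParallel

    # Both forms should resolve to the same wrapper class after this change.
    cfg_by_name = dict(model_wrapper_cfg=dict(type='MMDistributedDataParallel'))
    cfg_by_class = dict(model_wrapper_cfg=dict(type=MMDistributedDataParallel))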