Skip to content

Commit

Permalink
reconstruct
Browse the repository at this point in the history
  • Loading branch information
MGAMZ committed Sep 23, 2024
1 parent 0934d75 commit 7103c3e
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mmengine/_strategy/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def load_optim_state_dict(self, state_dict: dict) -> None:
``optimizer.state_dict()``
"""
optim_state_dict = FSDP.optim_state_dict_to_load(
self.model, self.optim_wrapper.optimizer, state_dict)
state_dict, self.model, self.optim_wrapper.optimizer)
self.optim_wrapper.load_state_dict(optim_state_dict)

def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
Expand Down
3 changes: 3 additions & 0 deletions mmengine/model/wrappers/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(self,

def train_step(self, data: Union[dict, tuple, list],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
return self.module.train_step(data, optim_wrapper)
"""Interface for model forward, backward and parameters updating during
training process.
Expand Down Expand Up @@ -126,6 +127,7 @@ def train_step(self, data: Union[dict, tuple, list],
return log_vars

def val_step(self, data: Union[dict, tuple, list]) -> list:
return self.module.val_step(data)
"""Gets the prediction of module during validation process.
Args:
Expand All @@ -137,6 +139,7 @@ def val_step(self, data: Union[dict, tuple, list]) -> list:
return self.module.val_step(data)

def test_step(self, data: Union[dict, tuple, list]) -> list:
return self.module.test_step(data)
"""Gets the predictions of module during testing process.
Args:
Expand Down
1 change: 1 addition & 0 deletions mmengine/optim/optimizer/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,5 +207,6 @@ def build_optim_wrapper(model: nn.Module,
type=constructor_type,
optim_wrapper_cfg=optim_wrapper_cfg,
paramwise_cfg=paramwise_cfg))

optim_wrapper = optim_wrapper_constructor(model)
return optim_wrapper
4 changes: 3 additions & 1 deletion mmengine/runner/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mmengine.registry import LOOPS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from mmengine.dataset.sampler import InfiniteSampler
from .amp import autocast
from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals
Expand Down Expand Up @@ -274,13 +275,14 @@ def run(self) -> None:
# In iteration-based training loop, we treat the whole training process
# as a big epoch and execute the corresponding hook.
self.runner.call_hook('before_train_epoch')
if self._iter > 0:
if self._iter > 0 and not isinstance(self.dataloader.sampler, InfiniteSampler):
print_log(
f'Advance dataloader {self._iter} steps to skip data '
'that has already been trained',
logger='current',
level=logging.WARNING)
for _ in range(self._iter):
break
next(self.dataloader_iterator)
while self._iter < self._max_iters and not self.stop_training:
self.runner.model.train()
Expand Down
4 changes: 3 additions & 1 deletion mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,9 +903,11 @@ def wrap_model(
find_unused_parameters=find_unused_parameters)
else:
model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')

model_wrapper_type = model_wrapper_cfg.get('type')
if isinstance(model_wrapper_type, str):
model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type) # type: ignore
model_wrapper_type = MODEL_WRAPPERS.get(
model_wrapper_type) # type: ignore
elif inspect.isclass(model_wrapper_type):
pass
else:
Expand Down

0 comments on commit 7103c3e

Please sign in to comment.