make auto_wrap_policy work
tocean committed Jan 8, 2024
1 parent 27f32b2 commit 75cc279
Showing 6 changed files with 64 additions and 46 deletions.
15 changes: 9 additions & 6 deletions examples/fsdp_mnist.py
@@ -147,17 +147,17 @@ def fsdp_main(rank, world_size, args):
model = Net().to(rank)

if args.msamp:
from msamp.nn import LinearReplacer
from msamp.common.dtype import Dtypes
from msamp.fsdp.replacer import FsdpReplacer
from msamp.optim import FSDPAdam

model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3)

model = FsdpReplacer.replace(model)

if rank == 0:
print(f'model:')
print(f'{model}')
for name, parameter in model.named_parameters():
print(f'name:{name}, numel:{parameter.numel()}')

model = FSDP(model, use_orig_params=True)
model = FSDP(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy)

if rank == 0:
print(f'FSDP model:')
@@ -167,6 +167,9 @@ def fsdp_main(rank, world_size, args):
optimizer = FSDPAdam(model.parameters(), lr=args.lr)
else:
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

if rank == 0:
print(f'optimizer initialized')

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
init_start_event.record()
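Note: my_auto_wrap_policy is defined elsewhere in fsdp_mnist.py and is not shown in this hunk. As a rough sketch, such a policy is commonly built from PyTorch's size-based helper; the parameter threshold below is illustrative, not the example's actual value.

import functools

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap every submodule whose not-yet-wrapped parameter count exceeds min_num_params
# in its own FSDP unit, so each wrapped block gets its own flat parameter.
my_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=20000,
)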
12 changes: 12 additions & 0 deletions msamp/fsdp/flat_param.py
@@ -656,6 +656,7 @@ def _init_shard_metadata(
ranks = list(range(start_rank, end_rank + 1))
meta.group = dist.new_group(ranks=ranks)


def _get_shard_metadata(
self,
start: int,
@@ -1469,12 +1470,19 @@ def _use_unsharded_views(self, as_params: bool) -> None:
tensor.data = view # type: ignore[union-attr]
assert tensor is not None # mypy
param_var = tensor

setattr(module, param_name, param_var)
if (
self._use_orig_params
and self._training_state == HandleTrainingState.FORWARD
):
module._parameters[param_name] = param_var # type: ignore[assignment]

param_var._fp8 = True
param_var._scaling_metas = self.flat_param._scaling_metas[i]
param_var._meta = self.flat_param._metas[i]
param_var._padded = self.flat_param._paddeds[i]
param_var._original_shape = self.flat_param._original_shapes[i]
for i, (
param_name,
module,
@@ -1615,6 +1623,10 @@ def _use_sharded_views(self) -> None:
zip(self.flat_param._params, self.flat_param._param_infos)
):
setattr(module, param_name, param)
if self.flat_param._metas[i] is not None:
param._meta = self.flat_param._metas[i]
param._grad_meta = self.flat_param._scaling_metas[i]['wgrad']

in_sharded_flat_param = (
i >= start
and i <= end
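The flat_param.py changes move the FP8 bookkeeping onto the parameter views that FSDP itself rebuilds: _use_unsharded_views re-attaches the full set of scaling attributes used by the FP8 GEMM in forward, while _use_sharded_views only attaches the weight meta and the wgrad meta needed by the optimizer. A minimal sketch of that attach step, using the attribute names from the diff; the helper name and the surrounding FSDP view machinery are illustrative, not MS-AMP code.

def attach_fp8_attrs(view, flat_param, i, sharded):
    """Copy the i-th original parameter's FP8 metadata from the flat parameter onto its view."""
    if flat_param._metas[i] is None:  # not an FP8 parameter, nothing to attach
        return
    view._meta = flat_param._metas[i]
    if sharded:
        # The sharded view feeds the optimizer, which only needs the weight-gradient meta.
        view._grad_meta = flat_param._scaling_metas[i]['wgrad']
    else:
        # The unsharded view is what the FP8 GEMM sees during forward.
        view._fp8 = True
        view._scaling_metas = flat_param._scaling_metas[i]
        view._padded = flat_param._paddeds[i]
        view._original_shape = flat_param._original_shapes[i]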
37 changes: 1 addition & 36 deletions msamp/fsdp/fully_sharded_data_parallel.py
@@ -407,26 +407,6 @@ def __init__(
_init_prefetching_state(self, backward_prefetch, forward_prefetch)
_init_buffer_state(self, module)

for name, submodule in module.named_modules():
params_to_process = list(submodule.named_parameters(recurse=False))
for param_name, param in params_to_process:
if not isinstance(param, torch.Tensor):
data = param.value.view(-1)
padded = 0
if data.numel() % 4 != 0:
padded = 4 - data.numel() % 4
data = torch.nn.functional.pad(data, (0, padded))

data = data.view(dtype=torch.float32)
new_param = torch.nn.Parameter(data)
new_param._fp8 = True
new_param._original_shape = param.shape
new_param._padded = 0
new_param._meta = param.meta
new_param._scaling_metas = param._scaling_metas

setattr(submodule, param_name, new_param)

_init_param_handle_from_module(
self,
module,
@@ -770,17 +750,6 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
f"{self.compute_device} but got {handle.flat_param.device}",
)

i = 0
for _, submodule in self._fsdp_wrapped_module.named_modules():
for param_name, param in submodule.named_parameters(recurse=False):
if self._flat_param._metas[i] is not None:
param._fp8 = True
param._scaling_metas = self._flat_param._scaling_metas[i]
param._meta = self._flat_param._metas[i]
param._padded = self._flat_param._paddeds[i]
param._original_shape = self._flat_param._original_shapes[i]
i += 1

output = self._fsdp_wrapped_module(*args, **kwargs)
return _post_forward(self, self._handles, reshard_fn, self, unused, output)

@@ -928,12 +897,8 @@ def named_parameters(
when inside the :meth:`summon_full_params` context manager.
"""
should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
i = 0

for param_name, param in super().named_parameters(*args, **kwargs):
if self._flat_param._metas[i] is not None:
param._meta = self._flat_param._metas[i]
param._grad_meta = self._flat_param._scaling_metas[i]['wgrad']
i += 1
if should_clean_name:
# Remove any instances of the FSDP-specific prefix; there can
# be multiple in the case of nested FSDP modules
39 changes: 39 additions & 0 deletions msamp/fsdp/replacer.py
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""MS-AMP fsdp.replacer module."""

import torch

from msamp.common.dtype import Dtypes
from msamp.nn import LinearReplacer


class FsdpReplacer:
    """A replacer that converts FP8 weights into flat FP32 nn.Parameter instances carrying FP8 attributes."""

    @classmethod
    def replace(cls, model):
        """Replace FP8 ScalingParameter weights with flat FP32 torch.nn.Parameter instances plus FP8 metadata."""

        model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3)
        for _, submodule in model.named_modules():
            params_to_process = list(submodule.named_parameters(recurse=False))
            for param_name, param in params_to_process:
                if not isinstance(param, torch.Tensor):
                    data = param.value.view(-1)
                    padded = 0
                    if data.numel() % 4 != 0:
                        padded = 4 - data.numel() % 4
                        data = torch.nn.functional.pad(data, (0, padded))

                    data = data.view(dtype=torch.float32)
                    new_param = torch.nn.Parameter(data)
                    new_param._fp8 = True
                    new_param._original_shape = param.shape
                    new_param._padded = 0
                    new_param._meta = param.meta
                    new_param._scaling_metas = param._scaling_metas

                    setattr(submodule, param_name, new_param)
        return model
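The core trick in FsdpReplacer is storage-level: the FP8 weight's uint8 payload is padded to a multiple of 4 bytes and reinterpreted as float32, so FSDP can flatten and shard it like any ordinary parameter, and msamp/nn/functional.py later reverses the view. A minimal, self-contained sketch of that round trip on a plain uint8 buffer; pack_fp8_bytes and unpack_fp8_bytes are illustrative names, not MS-AMP APIs.

import torch

def pack_fp8_bytes(fp8_bytes):
    """Pad a flat uint8 buffer to a multiple of 4 bytes and reinterpret it as float32."""
    data = fp8_bytes.view(-1)
    padded = (-data.numel()) % 4
    if padded:
        data = torch.nn.functional.pad(data, (0, padded))
    return data.view(dtype=torch.float32), padded

def unpack_fp8_bytes(packed, padded, original_shape):
    """Reverse the view: back to uint8, strip the padding, restore the original shape."""
    data = packed.view(dtype=torch.uint8)
    if padded:
        data = data[:data.numel() - padded]
    return data.view(original_shape)

# Round trip on a dummy 3x5 "FP8" weight: 15 bytes -> padded to 16 -> 4 float32 elements.
raw = torch.randint(0, 255, (3, 5), dtype=torch.uint8)
packed, pad = pack_fp8_bytes(raw)
assert torch.equal(unpack_fp8_bytes(packed, pad, raw.shape), raw)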
6 changes: 3 additions & 3 deletions msamp/nn/functional.py
@@ -26,7 +26,7 @@ def forward(ctx, input, weight, metas, dtype_holder):
dtype_holder (torch.Tensor): A tensor to hold the output dtype. The requires_grad of this tensor
should be True if input.requires_grad is False.
"""
if hasattr(weight, '_fp8'):
if isinstance(weight, torch.Tensor):
padded = weight._padded
original_shape = weight._original_shape
meta = weight._meta
@@ -36,7 +36,7 @@ def forward(ctx, input, weight, metas, dtype_holder):
weight = weight[0: weight.numel() - padded]
weight = weight.view(original_shape)
weight = ScalingParameter(ScalingTensor(weight, meta))
ctx._fp8 = True
ctx._returnWgrad = True

ctx.metas = metas
model_state.check_metas_in_flat(metas)
@@ -109,7 +109,7 @@ def backward(ctx, output_grad):
)
del old_wgrad

if ctx._fp8:
if ctx._returnWgrad:
wgrad = wgrad.cast(Dtypes.kfloat8_e4m3, meta=wgrad_meta, sync=True)
wgrad = wgrad.value.view(-1).view(dtype=torch.float32)
wgrad.meta = wgrad_meta
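On the backward side, the renamed _returnWgrad flag makes the kernel return a weight gradient that matches the flat FP32 parameter view: the gradient is cast to FP8 and its byte payload is viewed as float32, so autograd can hand it to the float32 parameter FSDP owns. A tiny illustration of why the element counts line up; the tensors are dummies, not MS-AMP API.

import torch

fp8_grad_bytes = torch.randint(0, 255, (16,), dtype=torch.uint8)  # stand-in for wgrad.value
flat_fp32_param = torch.zeros(4, dtype=torch.float32)             # stand-in for the FSDP-owned flat storage

grad_as_fp32 = fp8_grad_bytes.view(-1).view(dtype=torch.float32)
assert grad_as_fp32.shape == flat_fp32_param.shape  # 16 FP8 bytes == 4 float32 elements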
1 change: 0 additions & 1 deletion msamp/optim/adam.py
@@ -130,7 +130,6 @@ def zero_grad(self, set_to_none=False):
param.grad.zero_()

def step(self):
torch.set_printoptions(profile="full")
# cast gradient to ScalingTensor
for i, param in enumerate(self.original_params):
if param.grad is None: