enable fp8 before reducing gradient and sync parameters in optimizer
tocean committed Jan 4, 2024
1 parent 8a89b01 commit 27f32b2
Showing 4 changed files with 68 additions and 20 deletions.
22 changes: 14 additions & 8 deletions examples/fsdp_mnist.py
@@ -77,6 +77,7 @@ def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler
ddp_loss[0] += loss.item()
ddp_loss[1] += len(data)

+        #break
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
@@ -145,11 +146,12 @@ def fsdp_main(rank, world_size, args):

model = Net().to(rank)

-    from msamp.nn import LinearReplacer
-    from msamp.common.dtype import Dtypes
-    from msamp.optim import FSDPAdam
+    if args.msamp:
+        from msamp.nn import LinearReplacer
+        from msamp.common.dtype import Dtypes
+        from msamp.optim import FSDPAdam

-    model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3)
+        model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3)

if rank == 0:
print(f'model:')
@@ -160,9 +162,11 @@ def fsdp_main(rank, world_size, args):
if rank == 0:
print(f'FSDP model:')
print(f'{model}')

-    # optimizer = LBAdam(model.parameters(), lr=args.lr)
-    optimizer = FSDPAdam(model.parameters(), lr=args.lr)

+    if args.msamp:
+        optimizer = FSDPAdam(model.parameters(), lr=args.lr)
+    else:
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
init_start_event.record()
@@ -190,7 +194,7 @@ def fsdp_main(rank, world_size, args):
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=1, metavar='N',
-                        help='number of epochs to train (default: 14)')
+                        help='number of epochs to train (default: 2)')
# parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
# help='learning rate (default: 1.0)')
parser.add_argument('--lr', type=float, default=3e-4, metavar='LR',
@@ -203,6 +207,8 @@ def fsdp_main(rank, world_size, args):
help='random seed (default: 1)')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
+    parser.add_argument('--msamp', action='store_true', default=False,
+                        help='whether use MS-AMP')
args = parser.parse_args()

torch.manual_seed(args.seed)
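For orientation, the example's new control flow consolidates to the sketch below. Only the names that appear in the diff are taken from it; the FSDP wrapping call is not visible in these hunks and is assumed here for illustration.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def build_model_and_optimizer(model, lr, use_msamp):
    if use_msamp:
        from msamp.nn import LinearReplacer
        from msamp.common.dtype import Dtypes
        from msamp.optim import FSDPAdam

        # Swap nn.Linear layers for FP8-aware linears before sharding.
        model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3)

    model = FSDP(model)  # assumption: the example wraps the (possibly replaced) model here

    if use_msamp:
        optimizer = FSDPAdam(model.parameters(), lr=lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer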
40 changes: 35 additions & 5 deletions msamp/fsdp/_runtime_utils.py
@@ -738,12 +738,42 @@ def _post_backward_hook(
if numel_to_pad > 0
else unsharded_grad
)

new_sharded_grad = torch.empty_like(chunks[0]) # padded
-        state._communication_hook(
-            state._communication_hook_state,
-            padded_unsharded_grad,
-            new_sharded_grad,
-        )

+        start = 0
+        end = 0
+        has_meta = False
+        for meta in flat_param._metas:
+            if meta is not None:
+                has_meta = True
+                break
+        if has_meta:
+            for i, meta in enumerate(flat_param._metas):
+                start += flat_param._numels[i - 1] if i >= 1 else 0
+                end += flat_param._numels[i]
+                if meta is not None:
+                    from msamp.common.dtype import Dtypes
+                    from msamp.operators.dist_op import DistOp
+                    dtype = Dtypes.get_dtype_from_qtype(meta.qtype)
+                    DistOp.enable_fp8(meta.qtype)
+                    torch.distributed.all_reduce(padded_unsharded_grad[start:end].view(dtype), group=state.process_group)
+                    DistOp.disable_fp8()
+                else:
+                    default_hooks.allreduce_hook(
+                        state=state._communication_hook_state,
+                        grad=padded_unsharded_grad[start:end],
+                    )
+            start = state.rank * new_sharded_grad.numel()
+            end = (state.rank + 1) * new_sharded_grad.numel()
+            new_sharded_grad.copy_(padded_unsharded_grad[start:end])
+        else:
+            state._communication_hook(
+                state._communication_hook_state,
+                padded_unsharded_grad,
+                new_sharded_grad,
+            )

if handle._sharding_strategy in (
HandleShardingStrategy.HYBRID_SHARD,
HandleShardingStrategy._HYBRID_SHARD_ZERO2,
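To make the new branch easier to follow: when any entry of flat_param._metas is non-None, the hook walks the padded unsharded gradient segment by segment, all-reduces FP8 segments with FP8 communication enabled and the remaining segments through the stock allreduce hook, then copies this rank's shard out of the fully reduced buffer. Below is a standalone sketch of the offset bookkeeping; the numels/metas values are made up for illustration and stand in for flat_param._numels and flat_param._metas.

numels = [4, 6, 2]
metas = [None, "fp8-meta", None]   # non-None marks an FP8 segment

start = end = 0
for i, meta in enumerate(metas):
    start += numels[i - 1] if i >= 1 else 0
    end += numels[i]
    route = "fp8 all_reduce (DistOp.enable_fp8)" if meta is not None else "default allreduce_hook"
    print(f"segment {i}: grad[{start}:{end}] -> {route}")

# After every segment is reduced in place, each rank keeps only its own shard:
rank, shard_numel = 1, 3
print(f"rank {rank} copies grad[{rank * shard_numel}:{(rank + 1) * shard_numel}] into new_sharded_grad")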
12 changes: 12 additions & 0 deletions msamp/fsdp/flat_param.py
@@ -643,6 +643,18 @@ def _init_shard_metadata(
self.flat_param._shard_indices, # type: ignore[attr-defined]
) = self._get_shard_metadata(start, end)
self.flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined]

+        # update meta.group here.
+        start_offset = 0
+        end_offset = 0
+        for i, meta in enumerate(self.flat_param._metas):
+            start_offset += self.flat_param._numels[i-1] if i >=1 else 0
+            end_offset += self.flat_param._numels[i]
+            if meta is not None:
+                start_rank = start_offset // sharded_flat_param_numel
+                end_rank = (end_offset-1) // sharded_flat_param_numel
+                ranks = list(range(start_rank, end_rank + 1))
+                meta.group = dist.new_group(ranks=ranks)

def _get_shard_metadata(
self,
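The added block gives each FP8 segment a communication group containing exactly the ranks whose shards overlap it. The rank-range arithmetic in isolation, with illustrative values (in the real code the offsets and shard size come from the flat parameter):

def ranks_covering_segment(start_offset, end_offset, sharded_flat_param_numel):
    # Which ranks own part of segment [start_offset, end_offset) of the flat parameter,
    # given a per-rank shard of sharded_flat_param_numel elements. Same arithmetic as above.
    start_rank = start_offset // sharded_flat_param_numel
    end_rank = (end_offset - 1) // sharded_flat_param_numel
    return list(range(start_rank, end_rank + 1))


print(ranks_covering_segment(0, 10, 8))   # [0, 1] -> segment spans the shards of ranks 0 and 1
print(ranks_covering_segment(16, 20, 8))  # [2]    -> segment fits inside rank 2's shard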
14 changes: 7 additions & 7 deletions msamp/optim/adam.py
@@ -68,7 +68,7 @@ def __init__(
self.use_adam = not adam_w_mode
self.set_grad_none = set_grad_none

-class FSDPAdam(LBAdamW):
+class FSDPAdam(LBAdam):
def __init__(
self,
params,
@@ -130,6 +130,7 @@ def zero_grad(self, set_to_none=False):
param.grad.zero_()

def step(self):
+        torch.set_printoptions(profile="full")
# cast gradient to ScalingTensor
for i, param in enumerate(self.original_params):
if param.grad is None:
@@ -141,12 +142,11 @@ def step(self):
self.master_weights[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta)
param.grad = None

-        # call self.optimizer.step() to update master weight
+        # call step() to update master weight
super().step()

-        # copy master weight to weight
+        # sync params and copy master weight to weight
for i, param in enumerate(self.original_params):
-            if param.grad is None:
-                continue
-            if hasattr(param, '_meta') and param._meta is not None:
-                param.copy_(self.master_weight[i].cast(param._meta.qtype).view(torch.float32))
+            if hasattr(param, '_meta') and param._meta is not None and param.numel() > 0:
+                data = self.master_weights[i].float().cast(param._meta.qtype, param._meta, True).value.view(torch.float32)
+                param.data.copy_(data)

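The optimizer change leans on FP8 payloads living inside FSDP's float32 flat buffers: step() reinterprets each gradient back to its FP8 dtype via .view(dtype) before wrapping it in a ScalingTensor, and after the master-weight update it casts the result back to the FP8 qtype and reinterprets it as float32 before copying it into the parameter. A self-contained sketch of that reinterpret round-trip, using torch.uint8 as a stand-in for the FP8 dtype so it runs without FP8 support:

import torch

payload = torch.arange(16, dtype=torch.uint8)   # pretend: raw FP8 gradient bytes
as_float32 = payload.view(torch.float32)        # how the bytes sit in FSDP's flat fp32 buffer
recovered = as_float32.view(torch.uint8)        # what step() does before building a ScalingTensor

assert torch.equal(payload, recovered)          # the two reinterpretations lose no data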