From 27f32b220f91c301f60b783ec12c93ada9ea3c68 Mon Sep 17 00:00:00 2001
From: tocean
Date: Thu, 4 Jan 2024 07:49:12 +0000
Subject: [PATCH] enable fp8 before reducing gradient and sync parameters in optimizer

---
 examples/fsdp_mnist.py       | 22 ++++++++++++--------
 msamp/fsdp/_runtime_utils.py | 40 +++++++++++++++++++++++++++++++-----
 msamp/fsdp/flat_param.py     | 12 +++++++++++
 msamp/optim/adam.py          | 14 ++++++-------
 4 files changed, 68 insertions(+), 20 deletions(-)

diff --git a/examples/fsdp_mnist.py b/examples/fsdp_mnist.py
index 7d799fcb..37f45e7f 100644
--- a/examples/fsdp_mnist.py
+++ b/examples/fsdp_mnist.py
@@ -77,6 +77,7 @@ def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler

         ddp_loss[0] += loss.item()
         ddp_loss[1] += len(data)
+        # break  (uncomment to stop after one batch when debugging)
     dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
     if rank == 0:
         print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
@@ -145,11 +146,12 @@ def fsdp_main(rank, world_size, args):

     model = Net().to(rank)

-    from msamp.nn import LinearReplacer
-    from msamp.common.dtype import Dtypes
-    from msamp.optim import FSDPAdam
+    if args.msamp:
+        from msamp.nn import LinearReplacer
+        from msamp.common.dtype import Dtypes
+        from msamp.optim import FSDPAdam

-    model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3)
+        model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3)

     if rank == 0:
         print(f'model:')
@@ -160,9 +162,11 @@ def fsdp_main(rank, world_size, args):
     if rank == 0:
         print(f'FSDP model:')
         print(f'{model}')
-    
-    # optimizer = LBAdam(model.parameters(), lr=args.lr)
-    optimizer = FSDPAdam(model.parameters(), lr=args.lr)
+
+    if args.msamp:
+        optimizer = FSDPAdam(model.parameters(), lr=args.lr)
+    else:
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

     init_start_event.record()
@@ -190,7 +194,7 @@ def fsdp_main(rank, world_size, args):
     parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                         help='input batch size for testing (default: 1000)')
     parser.add_argument('--epochs', type=int, default=1, metavar='N',
-                        help='number of epochs to train (default: 14)')
+                        help='number of epochs to train (default: 1)')
     # parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
     #                     help='learning rate (default: 1.0)')
     parser.add_argument('--lr', type=float, default=3e-4, metavar='LR',
@@ -203,6 +207,8 @@ def fsdp_main(rank, world_size, args):
                         help='random seed (default: 1)')
     parser.add_argument('--save-model', action='store_true', default=False,
                         help='For Saving the current Model')
+    parser.add_argument('--msamp', action='store_true', default=False,
+                        help='whether to use MS-AMP')
     args = parser.parse_args()

     torch.manual_seed(args.seed)
diff --git a/msamp/fsdp/_runtime_utils.py b/msamp/fsdp/_runtime_utils.py
index 8d6726dd..529ab7c8 100644
--- a/msamp/fsdp/_runtime_utils.py
+++ b/msamp/fsdp/_runtime_utils.py
@@ -738,12 +738,42 @@ def _post_backward_hook(
                     if numel_to_pad > 0
                     else unsharded_grad
                 )
+
                 new_sharded_grad = torch.empty_like(chunks[0])  # padded
-                state._communication_hook(
-                    state._communication_hook_state,
-                    padded_unsharded_grad,
-                    new_sharded_grad,
-                )
+
+                start = 0
+                end = 0
+                has_meta = False
+                for meta in flat_param._metas:
+                    if meta is not None:
+                        has_meta = True
+                        break
+                if has_meta:
+                    for i, meta in enumerate(flat_param._metas):
+                        start += flat_param._numels[i - 1] if i >= 1 else 0
+                        end += flat_param._numels[i]
+                        if meta is not None:
+                            from msamp.common.dtype import Dtypes
+                            from msamp.operators.dist_op import DistOp
+                            dtype = Dtypes.get_dtype_from_qtype(meta.qtype)
+                            DistOp.enable_fp8(meta.qtype)
+                            torch.distributed.all_reduce(padded_unsharded_grad[start:end].view(dtype), group=state.process_group)
+                            DistOp.disable_fp8()
+                        else:
+                            default_hooks.allreduce_hook(
+                                state=state._communication_hook_state,
+                                grad=padded_unsharded_grad[start:end],
+                            )
+                    start = state.rank * new_sharded_grad.numel()
+                    end = (state.rank + 1) * new_sharded_grad.numel()
+                    new_sharded_grad.copy_(padded_unsharded_grad[start:end])
+                else:
+                    state._communication_hook(
+                        state._communication_hook_state,
+                        padded_unsharded_grad,
+                        new_sharded_grad,
+                    )
+
                 if handle._sharding_strategy in (
                     HandleShardingStrategy.HYBRID_SHARD,
                     HandleShardingStrategy._HYBRID_SHARD_ZERO2,
diff --git a/msamp/fsdp/flat_param.py b/msamp/fsdp/flat_param.py
index 8303c875..6b65e206 100644
--- a/msamp/fsdp/flat_param.py
+++ b/msamp/fsdp/flat_param.py
@@ -643,6 +643,18 @@ def _init_shard_metadata(
             self.flat_param._shard_indices,  # type: ignore[attr-defined]
         ) = self._get_shard_metadata(start, end)
         self.flat_param._shard_numel_padded = numel_padded  # type: ignore[attr-defined]
+
+        # update meta.group: give each FP8 param a process group over the ranks that own a shard of it.
+        start_offset = 0
+        end_offset = 0
+        for i, meta in enumerate(self.flat_param._metas):
+            start_offset += self.flat_param._numels[i - 1] if i >= 1 else 0
+            end_offset += self.flat_param._numels[i]
+            if meta is not None:
+                start_rank = start_offset // sharded_flat_param_numel
+                end_rank = (end_offset - 1) // sharded_flat_param_numel
+                ranks = list(range(start_rank, end_rank + 1))
+                meta.group = dist.new_group(ranks=ranks)

     def _get_shard_metadata(
         self,
diff --git a/msamp/optim/adam.py b/msamp/optim/adam.py
index 0793d632..8ad7e29f 100644
--- a/msamp/optim/adam.py
+++ b/msamp/optim/adam.py
@@ -68,7 +68,7 @@ def __init__(
         self.use_adam = not adam_w_mode
         self.set_grad_none = set_grad_none

-class FSDPAdam(LBAdamW):
+class FSDPAdam(LBAdam):
     def __init__(
         self,
         params,
@@ -130,6 +130,7 @@ def zero_grad(self, set_to_none=False):
                 param.grad.zero_()

     def step(self):
+        torch.set_printoptions(profile="full")  # debug: print tensors in full
         # cast gradient to ScalingTensor
         for i, param in enumerate(self.original_params):
             if param.grad is None:
@@ -141,12 +142,11 @@ def step(self):
             self.master_weights[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta)
             param.grad = None

-        # call self.optimizer.step() to update master weight
+        # call the base optimizer's step() to update the master weights
         super().step()

-        # copy master weight to weight
+        # sync params and copy the master weights back to the model weights
         for i, param in enumerate(self.original_params):
-            if param.grad is None:
-                continue
-            if hasattr(param, '_meta') and param._meta is not None:
-                param.copy_(self.master_weight[i].cast(param._meta.qtype).view(torch.float32))
\ No newline at end of file
+            if hasattr(param, '_meta') and param._meta is not None and param.numel() > 0:
+                data = self.master_weights[i].float().cast(param._meta.qtype, param._meta, True).value.view(torch.float32)
+                param.data.copy_(data)
\ No newline at end of file
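
Reviewer note (not part of the patch): the hunks in _post_backward_hook and _init_shard_metadata
share the same offset arithmetic, walking flat_param._numels and flat_param._metas to turn the flat
gradient into per-parameter [start, end) segments and to find which ranks own each FP8 segment. The
standalone sketch below shows that walk on made-up numbers; numels, metas, sharded_numel and the
printed messages are illustrative only, while the real values come from the FlatParameter and the
FSDP shard size.

# Hypothetical layout of one flat parameter.
numels = [16, 8, 24]          # stand-in for flat_param._numels (per-parameter element counts)
metas = ['fp8', None, 'fp8']  # stand-in for flat_param._metas (None = non-FP8 parameter)
sharded_numel = 16            # padded numel owned by each rank (world size 3)

start = 0
end = 0
for i, meta in enumerate(metas):
    start += numels[i - 1] if i >= 1 else 0
    end += numels[i]
    if meta is not None:
        # In _init_shard_metadata the patch sets meta.group to a process group over exactly the
        # ranks whose shard overlaps [start, end); in _post_backward_hook the same slice of the
        # padded gradient is all-reduced with FP8 enabled.
        start_rank = start // sharded_numel
        end_rank = (end - 1) // sharded_numel
        print(f'param {i}: [{start}, {end}) -> fp8, ranks {list(range(start_rank, end_rank + 1))}')
    else:
        # Non-FP8 segments keep going through default_hooks.allreduce_hook unchanged.
        print(f'param {i}: [{start}, {end}) -> default allreduce_hook')

With the patch applied, the example can be launched with or without the new flag, e.g.
"python examples/fsdp_mnist.py --msamp" (assuming the script's existing mp.spawn entry point).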