From 1be092c358e6c135fa5503bd62ad07bc4d11f4e1 Mon Sep 17 00:00:00 2001
From: wkcn
Date: Sun, 10 Dec 2023 21:23:15 +0800
Subject: [PATCH] auto scaling freq

---
 msamp/megatron/optimizer/distrib_optimizer.py | 75 ++++++++++---------
 1 file changed, 38 insertions(+), 37 deletions(-)

diff --git a/msamp/megatron/optimizer/distrib_optimizer.py b/msamp/megatron/optimizer/distrib_optimizer.py
index d3c5bb98..a0c5599f 100644
--- a/msamp/megatron/optimizer/distrib_optimizer.py
+++ b/msamp/megatron/optimizer/distrib_optimizer.py
@@ -540,43 +540,44 @@ def reduce_model_grads(self, args, timers):    # noqa: C901
         if args.wgrad_auto_scaling:    # Weight Gradient Auto Scaling
-            timers('wgrad-auto-scaling', log_level=1).start(barrier=args.barrier_with_L1_time)
-
-            # update pre_scale in this partition
-            for model_group in self.model_fp8_groups:
-                for p in model_group:
-                    g = p.main_grad
-                    if g is not None and not torch.is_tensor(g):
-                        if g.qtype != Dtypes.kfloat8_e4m3:
-                            raise TypeError('g.qtype != Dtypes.kfloat8_e4m3: {}'.format(g.qtype))
-                        # stat overflow ratio
-                        num_infs = torch.count_nonzero((g.value & 0x7f) == 126)
-                        overflow_ratio = num_infs / g.numel()
-                        if overflow_ratio > args.wgrad_auto_scaling_ratio:
-                            g.meta.pre_scale.div_(2.0)
-                        else:
-                            g.meta.pre_scale.mul_(2.0**(1.0 / args.wgrad_auto_scaling_window))
-
-            # synchonize pre_scale in all partitions
-            for model_id, model in enumerate(self.models):
-                # all fp8 gradients
-                partitions = self.model_gbuf_ranges[model_id][torch.uint8]['partitions']
-                fp8_grads = [[p.main_grad for p in part.keys()] for part in partitions]
-                # pre_scales in the partition `data_parallel_rank`
-                pre_scales = [g.meta.pre_scale for g in fp8_grads[data_parallel_rank]]
-                max_elems_per_rank = max(model._grad_buffer_num_params)
-                pre_scales = torch.cat(pre_scales)
-                # padding to max_elems_per_rank
-                pad = max_elems_per_rank - pre_scales.numel()
-                pre_scales = F.pad(pre_scales, (0, pad))
-                output_pre_scales = pre_scales.new_empty((data_parallel_world_size, max_elems_per_rank))
-                torch.distributed._all_gather_base(output_pre_scales, pre_scales, group=data_parallel_group)
-                # assign pre_scale to all fp8 gradients
-                for grads, pre_scales in zip(fp8_grads, output_pre_scales):
-                    for g, pre_scale in zip(grads, pre_scales):
-                        g.meta.pre_scale.copy_(pre_scale)
-
-            timers('wgrad-auto-scaling').stop()
+            if args.curr_iteration % args.wgrad_auto_scaling_freq == 0:
+                timers('wgrad-auto-scaling', log_level=1).start(barrier=args.barrier_with_L1_time)
+
+                # update pre_scale in this partition
+                for model_group in self.model_fp8_groups:
+                    for p in model_group:
+                        g = p.main_grad
+                        if g is not None and not torch.is_tensor(g):
+                            if g.qtype != Dtypes.kfloat8_e4m3:
+                                raise TypeError('g.qtype != Dtypes.kfloat8_e4m3: {}'.format(g.qtype))
+                            # stat overflow ratio
+                            num_infs = torch.count_nonzero((g.value & 0x7f) == 126)
+                            overflow_ratio = num_infs / g.numel()
+                            if overflow_ratio > args.wgrad_auto_scaling_ratio:
+                                g.meta.pre_scale.div_(2.0)
+                            else:
+                                g.meta.pre_scale.mul_(2.0**(1.0 / args.wgrad_auto_scaling_window))
+
+                # synchronize pre_scale in all partitions
+                for model_id, model in enumerate(self.models):
+                    # all fp8 gradients
+                    partitions = self.model_gbuf_ranges[model_id][torch.uint8]['partitions']
+                    fp8_grads = [[p.main_grad for p in part.keys()] for part in partitions]
+                    # pre_scales in the partition `data_parallel_rank`
+                    pre_scales = [g.meta.pre_scale for g in fp8_grads[data_parallel_rank]]
+                    max_elems_per_rank = max(model._grad_buffer_num_params)
+                    pre_scales = torch.cat(pre_scales)
+                    # padding to max_elems_per_rank
+                    pad = max_elems_per_rank - pre_scales.numel()
+                    pre_scales = F.pad(pre_scales, (0, pad))
+                    output_pre_scales = pre_scales.new_empty((data_parallel_world_size, max_elems_per_rank))
+                    torch.distributed._all_gather_base(output_pre_scales, pre_scales, group=data_parallel_group)
+                    # assign pre_scale to all fp8 gradients
+                    for grads, pre_scales in zip(fp8_grads, output_pre_scales):
+                        for g, pre_scale in zip(grads, pre_scales):
+                            g.meta.pre_scale.copy_(pre_scale)
+
+                timers('wgrad-auto-scaling').stop()
 
     def gather_model_params(self, args, timers):    # noqa: C901
         """All-gather updated model params.
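
The patch gates the whole block behind `args.curr_iteration % args.wgrad_auto_scaling_freq == 0`, so the gradient scan and the all-gather run once every `wgrad_auto_scaling_freq` iterations instead of every step. The overflow test itself relies on E4M3 having no inf encoding: overflow saturates at the maximum magnitude (code 0x7E == 126), so counting max-magnitude codes after masking the sign bit with 0x7f approximates the overflow count. Below is a minimal standalone sketch of that per-gradient update; the function name and parameters are illustrative, not part of MS-AMP's API.

    import torch

    # Sketch of the pre_scale update, assuming `grad_fp8` holds the raw
    # E4M3 bytes as a uint8 tensor and `pre_scale` is a float32 scalar
    # tensor (names and defaults here are illustrative).
    def update_pre_scale(grad_fp8, pre_scale,
                         overflow_threshold=0.001, window=1000):
        # E4M3 saturates at the max-magnitude code 0x7E (== 126);
        # masking the sign bit counts saturated ("overflowed") elements.
        num_saturated = torch.count_nonzero((grad_fp8 & 0x7F) == 0x7E)
        overflow_ratio = num_saturated / grad_fp8.numel()
        if overflow_ratio > overflow_threshold:
            pre_scale.div_(2.0)    # back off quickly on overflow
        else:
            # grow slowly: doubles once per `window` quiet iterations
            pre_scale.mul_(2.0**(1.0 / window))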
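The synchronization step gathers a variable number of pre_scales per data-parallel rank, so each rank right-pads its vector to a common length before the collective. A self-contained sketch of that pad + all-gather pattern, assuming an initialized process group (`local_pre_scales` and `max_len` are illustrative names):

    import torch
    import torch.nn.functional as F
    import torch.distributed as dist

    # Sketch of the pad + all-gather pattern used above, assuming
    # torch.distributed has been initialized.
    def sync_pre_scales(local_pre_scales, max_len, group=None):
        world_size = dist.get_world_size(group=group)
        # right-pad so every rank contributes exactly max_len elements
        padded = F.pad(local_pre_scales,
                       (0, max_len - local_pre_scales.numel()))
        out = padded.new_empty((world_size, max_len))
        # same collective the patch uses; row r ends up holding rank r's values
        dist._all_gather_base(out, padded, group=group)
        return out

Each rank then reads row r of the result to copy rank r's pre_scales back onto the corresponding gradients, mirroring the final copy loop in the hunk.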