
Commit

auto scaling freq
wkcn committed Dec 10, 2023
1 parent 617ee0b commit 1be092c
Showing 1 changed file with 38 additions and 37 deletions.
msamp/megatron/optimizer/distrib_optimizer.py (75 changes: 38 additions & 37 deletions)
@@ -540,43 +540,44 @@ def reduce_model_grads(self, args, timers): # noqa: C901

        if args.wgrad_auto_scaling:
            # Weight Gradient Auto Scaling
-            timers('wgrad-auto-scaling', log_level=1).start(barrier=args.barrier_with_L1_time)
-
-            # update pre_scale in this partition
-            for model_group in self.model_fp8_groups:
-                for p in model_group:
-                    g = p.main_grad
-                    if g is not None and not torch.is_tensor(g):
-                        if g.qtype != Dtypes.kfloat8_e4m3:
-                            raise TypeError('g.qtype != Dtypes.kfloat8_e4m3: {}'.format(g.qtype))
-                        # stat overflow ratio
-                        num_infs = torch.count_nonzero((g.value & 0x7f) == 126)
-                        overflow_ratio = num_infs / g.numel()
-                        if overflow_ratio > args.wgrad_auto_scaling_ratio:
-                            g.meta.pre_scale.div_(2.0)
-                        else:
-                            g.meta.pre_scale.mul_(2.0**(1.0 / args.wgrad_auto_scaling_window))
-
-            # synchronize pre_scale in all partitions
-            for model_id, model in enumerate(self.models):
-                # all fp8 gradients
-                partitions = self.model_gbuf_ranges[model_id][torch.uint8]['partitions']
-                fp8_grads = [[p.main_grad for p in part.keys()] for part in partitions]
-                # pre_scales in the partition `data_parallel_rank`
-                pre_scales = [g.meta.pre_scale for g in fp8_grads[data_parallel_rank]]
-                max_elems_per_rank = max(model._grad_buffer_num_params)
-                pre_scales = torch.cat(pre_scales)
-                # padding to max_elems_per_rank
-                pad = max_elems_per_rank - pre_scales.numel()
-                pre_scales = F.pad(pre_scales, (0, pad))
-                output_pre_scales = pre_scales.new_empty((data_parallel_world_size, max_elems_per_rank))
-                torch.distributed._all_gather_base(output_pre_scales, pre_scales, group=data_parallel_group)
-                # assign pre_scale to all fp8 gradients
-                for grads, pre_scales in zip(fp8_grads, output_pre_scales):
-                    for g, pre_scale in zip(grads, pre_scales):
-                        g.meta.pre_scale.copy_(pre_scale)
-
-            timers('wgrad-auto-scaling').stop()
+            if args.curr_iteration % args.wgrad_auto_scaling_freq == 0:
+                timers('wgrad-auto-scaling', log_level=1).start(barrier=args.barrier_with_L1_time)
+
+                # update pre_scale in this partition
+                for model_group in self.model_fp8_groups:
+                    for p in model_group:
+                        g = p.main_grad
+                        if g is not None and not torch.is_tensor(g):
+                            if g.qtype != Dtypes.kfloat8_e4m3:
+                                raise TypeError('g.qtype != Dtypes.kfloat8_e4m3: {}'.format(g.qtype))
+                            # stat overflow ratio
+                            num_infs = torch.count_nonzero((g.value & 0x7f) == 126)
+                            overflow_ratio = num_infs / g.numel()
+                            if overflow_ratio > args.wgrad_auto_scaling_ratio:
+                                g.meta.pre_scale.div_(2.0)
+                            else:
+                                g.meta.pre_scale.mul_(2.0**(1.0 / args.wgrad_auto_scaling_window))
+
+                # synchronize pre_scale in all partitions
+                for model_id, model in enumerate(self.models):
+                    # all fp8 gradients
+                    partitions = self.model_gbuf_ranges[model_id][torch.uint8]['partitions']
+                    fp8_grads = [[p.main_grad for p in part.keys()] for part in partitions]
+                    # pre_scales in the partition `data_parallel_rank`
+                    pre_scales = [g.meta.pre_scale for g in fp8_grads[data_parallel_rank]]
+                    max_elems_per_rank = max(model._grad_buffer_num_params)
+                    pre_scales = torch.cat(pre_scales)
+                    # padding to max_elems_per_rank
+                    pad = max_elems_per_rank - pre_scales.numel()
+                    pre_scales = F.pad(pre_scales, (0, pad))
+                    output_pre_scales = pre_scales.new_empty((data_parallel_world_size, max_elems_per_rank))
+                    torch.distributed._all_gather_base(output_pre_scales, pre_scales, group=data_parallel_group)
+                    # assign pre_scale to all fp8 gradients
+                    for grads, pre_scales in zip(fp8_grads, output_pre_scales):
+                        for g, pre_scale in zip(grads, pre_scales):
+                            g.meta.pre_scale.copy_(pre_scale)
+
+                timers('wgrad-auto-scaling').stop()

    def gather_model_params(self, args, timers): # noqa: C901
        """All-gather updated model params.
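For readers skimming the diff: the functional change is that the whole weight-gradient auto-scaling pass is now gated on `args.curr_iteration % args.wgrad_auto_scaling_freq == 0`, so the per-gradient overflow statistics and the pre_scale all-gather run once every `wgrad_auto_scaling_freq` iterations instead of every step. The scaling rule itself is unchanged; a rough standalone restatement of it follows (the function and argument names are illustrative, not part of MS-AMP's API):

import torch

# Illustrative sketch of the pre_scale update rule shown in the diff. `fp8_bits`
# stands for the raw uint8 payload of an FP8 E4M3 gradient (g.value above), and
# `pre_scale` for the corresponding g.meta.pre_scale tensor.
def update_pre_scale(fp8_bits, pre_scale, overflow_ratio_threshold, window):
    # With the sign bit masked off, 126 (0b1111110) is the largest finite E4M3
    # encoding, so elements equal to it were saturated during quantization.
    num_saturated = torch.count_nonzero((fp8_bits & 0x7f) == 126)
    overflow_ratio = num_saturated / fp8_bits.numel()
    if overflow_ratio > overflow_ratio_threshold:
        pre_scale.div_(2.0)  # too many saturated values: back off quickly
    else:
        # otherwise grow slowly, doubling once per `window` auto-scaling passes
        pre_scale.mul_(2.0**(1.0 / window))

One consequence worth noting: the 2**(1.0 / window) growth is applied per auto-scaling pass, so after this change pre_scale doubles roughly every `wgrad_auto_scaling_window * wgrad_auto_scaling_freq` iterations rather than every `wgrad_auto_scaling_window` iterations.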

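The synchronization step pads each rank's pre_scale vector to a common length and all-gathers the result so every data-parallel rank ends up with identical scales. A minimal sketch of that pattern, assuming an already-initialized process group (the helper name `sync_pre_scales` and the use of the public `all_gather_into_tensor` instead of the private `_all_gather_base` called above are this note's choices, not the repository's):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def sync_pre_scales(local_pre_scales, max_elems_per_rank, group=None):
    """All-gather per-rank pre_scale vectors across data-parallel ranks.

    local_pre_scales: list of one-element tensors owned by this rank.
    Returns a (world_size, max_elems_per_rank) tensor whose row r holds
    rank r's pre_scales, zero-padded at the end.
    """
    world_size = dist.get_world_size(group=group)
    flat = torch.cat(local_pre_scales)
    # Pad so every rank contributes a tensor of the same length.
    flat = F.pad(flat, (0, max_elems_per_rank - flat.numel()))
    output = flat.new_empty((world_size, max_elems_per_rank))
    dist.all_gather_into_tensor(output, flat, group=group)
    return output

The loop at the end of the hunk then copies each gathered row back onto the gradients of the corresponding partition, which is the redistribution step.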