Commit

Merge branch 'main' into yuxaing/ds_te_example
tocean committed Dec 14, 2023
2 parents b9329cb + bf6f01a commit 97ace2e
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
@@ -34,6 +34,15 @@ def __init__(self, init_optimizer, *args, **kwargs): # noqa: C901
**kwargs: Arbitrary keyword arguments.
"""
self.fp8_param_groups = []
dtype = torch.float16
for pg in init_optimizer.param_groups:
for p in pg['params']:
if p.requires_grad and not isinstance(p, ScalingTensor):
dtype = p.dtype
break

fake_param = torch.nn.parameter.Parameter(torch.zeros((), dtype=dtype))
fake_index = 0
for pg in init_optimizer.param_groups:
fp8_params = []
hp_params = []
@@ -45,6 +54,13 @@ def __init__(self, init_optimizer, *args, **kwargs): # noqa: C901
else:
hp_params.append(p)
self.fp8_param_groups.append(fp8_params)
# DeepSpeedZeroOptimizer will crash if a parameter group has no parameters,
# so add a fake parameter.
if len(hp_params) == 0:
param_names = args[0]
param_names[fake_param] = 'fake_' + str(fake_index)
fake_index += 1
hp_params.append(fake_param)
pg['params'] = hp_params

assert len(self.fp8_param_groups) == len(init_optimizer.param_groups)
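The two hunks above handle parameter groups that contain only ScalingTensor (fp8) parameters: DeepSpeedZeroOptimizer requires every group to hold at least one high-precision parameter, so a zero-sized fake parameter is registered in its place. Below is a minimal, standalone sketch of the same idea; is_fp8_param and its uint8 check are stand-ins for msamp's isinstance(p, ScalingTensor) test, and the group/name structures are simplified.

    import torch


    def is_fp8_param(p):
        # Hypothetical stand-in for msamp's isinstance(p, ScalingTensor) check:
        # treat uint8 tensors as fp8 parameters for this sketch.
        return p.dtype == torch.uint8


    def split_param_groups(param_groups, param_names):
        # Pick the dtype of a trainable non-fp8 parameter; default to fp16.
        dtype = torch.float16
        for pg in param_groups:
            for p in pg['params']:
                if p.requires_grad and not is_fp8_param(p):
                    dtype = p.dtype
                    break

        # Keep a zero-sized fake parameter ready so no group is left empty.
        fake_param = torch.nn.parameter.Parameter(torch.zeros((), dtype=dtype))
        fake_index = 0
        fp8_param_groups = []
        for pg in param_groups:
            fp8_params = [p for p in pg['params'] if is_fp8_param(p)]
            hp_params = [p for p in pg['params'] if not is_fp8_param(p)]
            fp8_param_groups.append(fp8_params)
            if len(hp_params) == 0:
                # Give the fake parameter a name so downstream bookkeeping still works.
                param_names[fake_param] = 'fake_' + str(fake_index)
                fake_index += 1
                hp_params.append(fake_param)
            pg['params'] = hp_params
        return fp8_param_groups


    # Example: a group holding only fp8 tensors receives the fake parameter.
    groups = [{'params': [torch.zeros(4, dtype=torch.uint8)]}]
    names = {}
    fp8_groups = split_param_groups(groups, names)
    print(len(fp8_groups[0]), [p.shape for p in groups[0]['params']], list(names.values()))
    # -> 1 [torch.Size([])] ['fake_0']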
@@ -139,9 +155,14 @@ def _pad_and_flat(self, values_partitions, group_fp8_mems, group_id):
torch.Tensor: flat fp8 groups.
"""
partition_size = dist.get_world_size(group=self.dp_process_group)
ref_value = values_partitions[0][0]
dtype = ref_value.dtype
assert all(v.dtype == dtype for v in chain(*values_partitions))
ref_value = None
for partition in values_partitions:
if len(partition) > 0:
ref_value = partition[0]
break
if ref_value is not None:
dtype = ref_value.dtype
assert all(v.dtype == dtype for v in chain(*values_partitions))

align = self.fp8_nccl_start_alignment_factor
max_flat_numels = max(group_fp8_mems)
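The hunk above makes _pad_and_flat tolerate empty partitions, which can now occur when a rank owns only the fake parameter: the reference value is taken from the first non-empty partition, and the dtype check is skipped if every partition is empty. A small sketch of that lookup pattern, under the assumption that each partition is simply a list of tensors:

    from itertools import chain

    import torch


    def check_partition_dtypes(values_partitions):
        # Find a reference value in the first non-empty partition; some
        # partitions can legitimately be empty.
        ref_value = None
        for partition in values_partitions:
            if len(partition) > 0:
                ref_value = partition[0]
                break
        if ref_value is None:
            return None
        dtype = ref_value.dtype
        assert all(v.dtype == dtype for v in chain(*values_partitions))
        return dtype


    # Example: an empty leading partition no longer raises an IndexError.
    print(check_partition_dtypes([[], [torch.zeros(2, dtype=torch.uint8)]]))  # torch.uint8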
@@ -777,12 +798,12 @@ def all_gather_fp8_metas(self):
continue
partition_size = len(params_partitions)
scale_invs_partitions = [[p.meta.scale_inv for p in ps] for ps in params_partitions]
ref_scale = scale_invs_partitions[0][0]
align = self.fp8_nccl_start_alignment_factor
max_flat_numels = (max_flat_numels + align - 1) // align * align
for pi in range(partition_size):
pad = max_flat_numels - numels[pi]
scale_invs_partitions[pi].append(ref_scale.new_empty((pad, )))
scale_invs_partitions[pi].append(torch.empty((pad, ), dtype=torch.float32, device='cuda'))

scales = list(chain(*scale_invs_partitions))
scale_invs_groups.append(scales)
flat = _flatten_dense_tensors(scales)
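The last hunk removes the dependence on ref_scale, which was indexed out of the first partition and therefore failed when that partition was empty; the padding tensor is now allocated directly with torch.empty(..., dtype=torch.float32, device='cuda'). A rough sketch of the padding-and-flattening step, assuming every scale_inv is a scalar fp32 tensor (CPU tensors here for portability):

    from itertools import chain

    import torch
    from torch._utils import _flatten_dense_tensors


    def pad_and_flatten_scale_invs(scale_invs_partitions, align):
        # Pad every partition up to one aligned length, then flatten. The pad
        # tensor is created directly with torch.empty, so no reference tensor
        # has to be taken from a possibly empty first partition.
        numels = [sum(s.numel() for s in partition) for partition in scale_invs_partitions]
        max_flat_numels = (max(numels) + align - 1) // align * align
        for pi, partition in enumerate(scale_invs_partitions):
            pad = max_flat_numels - numels[pi]
            # Pad values are placeholders; only the flattened length must line up across ranks.
            partition.append(torch.empty((pad, ), dtype=torch.float32))
        scales = list(chain(*scale_invs_partitions))
        return _flatten_dense_tensors(scales)


    # Example: the first partition may be empty.
    flat = pad_and_flatten_scale_invs([[], [torch.ones(1), torch.ones(1)]], align=4)
    print(flat.shape)  # torch.Size([8])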