diff --git a/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py b/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
index 25d65457..6ae3a548 100644
--- a/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
+++ b/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
@@ -34,6 +34,15 @@ def __init__(self, init_optimizer, *args, **kwargs):    # noqa: C901
             **kwargs: Arbitrary keyword arguments.
         """
         self.fp8_param_groups = []
+        dtype = torch.float16
+        for pg in init_optimizer.param_groups:
+            for p in pg['params']:
+                if p.requires_grad and not isinstance(p, ScalingTensor):
+                    dtype = p.dtype
+                    break
+
+        fake_param = torch.nn.parameter.Parameter(torch.zeros((), dtype=dtype))
+        fake_index = 0
         for pg in init_optimizer.param_groups:
             fp8_params = []
             hp_params = []
@@ -45,6 +54,13 @@ def __init__(self, init_optimizer, *args, **kwargs):    # noqa: C901
                 else:
                     hp_params.append(p)
             self.fp8_param_groups.append(fp8_params)
+            # DeepSpeedZeroOptimizer will crash if there are no parameters in any parameter group,
+            # so add a fake parameter.
+            if len(hp_params) == 0:
+                param_names = args[0]
+                param_names[fake_param] = 'fake_' + str(fake_index)
+                fake_index += 1
+                hp_params.append(fake_param)
             pg['params'] = hp_params

         assert len(self.fp8_param_groups) == len(init_optimizer.param_groups)
@@ -139,9 +155,14 @@ def _pad_and_flat(self, values_partitions, group_fp8_mems, group_id):
             torch.Tensor: flat fp8 groups.
        """
         partition_size = dist.get_world_size(group=self.dp_process_group)
-        ref_value = values_partitions[0][0]
-        dtype = ref_value.dtype
-        assert all(v.dtype == dtype for v in chain(*values_partitions))
+        ref_value = None
+        for partition in values_partitions:
+            if len(partition) > 0:
+                ref_value = partition[0]
+                break
+        if ref_value is not None:
+            dtype = ref_value.dtype
+            assert all(v.dtype == dtype for v in chain(*values_partitions))

         align = self.fp8_nccl_start_alignment_factor
         max_flat_numels = max(group_fp8_mems)
@@ -777,12 +798,12 @@ def all_gather_fp8_metas(self):
                 continue
             partition_size = len(params_partitions)
             scale_invs_partitions = [[p.meta.scale_inv for p in ps] for ps in params_partitions]
-            ref_scale = scale_invs_partitions[0][0]
             align = self.fp8_nccl_start_alignment_factor
             max_flat_numels = (max_flat_numels + align - 1) // align * align
             for pi in range(partition_size):
                 pad = max_flat_numels - numels[pi]
-                scale_invs_partitions[pi].append(ref_scale.new_empty((pad, )))
+                scale_invs_partitions[pi].append(torch.empty((pad, ), dtype=torch.float32, device='cuda'))
+
             scales = list(chain(*scale_invs_partitions))
             scale_invs_groups.append(scales)
             flat = _flatten_dense_tensors(scales)
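
For readers outside the MS-AMP codebase, here is a condensed sketch of the empty-group workaround the first two hunks apply. It is illustrative, not the literal patch: `is_fp8` stands in for the real `isinstance(p, ScalingTensor)` check, and `param_names` plays the role of the mapping DeepSpeed passes in as `args[0]`.

```python
import torch

def split_param_groups(param_groups, param_names):
    """Split optimizer param groups into FP8 and high-precision params.

    Groups left with no high-precision params receive a shared scalar
    fake parameter, because DeepSpeedZeroOptimizer assumes every group
    is non-empty. `is_fp8` is a stand-in for the ScalingTensor check.
    """
    def is_fp8(p):
        return getattr(p, 'is_fp8', False)

    # Borrow the dtype of the first high-precision param (default fp16).
    dtype = next(
        (p.dtype for pg in param_groups for p in pg['params']
         if p.requires_grad and not is_fp8(p)),
        torch.float16,
    )
    fake_param = torch.nn.Parameter(torch.zeros((), dtype=dtype))

    fp8_param_groups = []
    for i, pg in enumerate(param_groups):
        fp8_param_groups.append([p for p in pg['params'] if is_fp8(p)])
        hp_params = [p for p in pg['params'] if not is_fp8(p)]
        if not hp_params:
            # The single shared zero-dim fake param keeps the group
            # non-empty; give it a recognizable name.
            param_names[fake_param] = 'fake_' + str(i)
            hp_params.append(fake_param)
        pg['params'] = hp_params
    return fp8_param_groups
```

Note that the fake parameter is a zero-dimensional scalar (one element), so it adds only a negligible entry to the flattened buffers and optimizer state.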
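
Similarly, a sketch of the padding fix in the last two hunks: the original code borrowed a reference tensor from the first partition (`scale_invs_partitions[0][0]`) to allocate padding via `new_empty`, which fails when that partition is empty. Allocating with an explicit dtype avoids the lookup entirely. The float32 dtype of scale inverses matches the diff; the function name is hypothetical, and the CPU allocation (no `device='cuda'`) is only so the example runs anywhere.

```python
import torch
from itertools import chain

def pad_and_flatten_scales(scale_invs_partitions, align):
    """Pad every partition to one aligned length, then flatten.

    Padding is allocated with an explicit dtype instead of
    `scale_invs_partitions[0][0].new_empty(...)`, so an empty first
    partition no longer crashes the gather path.
    """
    numels = [sum(s.numel() for s in p) for p in scale_invs_partitions]
    # Round the largest partition up to the alignment boundary.
    max_flat_numels = (max(numels) + align - 1) // align * align
    for partition, numel in zip(scale_invs_partitions, numels):
        pad = max_flat_numels - numel
        partition.append(torch.empty((pad, ), dtype=torch.float32))
    scales = list(chain(*scale_invs_partitions))
    return torch.cat([s.flatten() for s in scales])

# Second rank's partition is empty; the explicit allocation still works.
flat = pad_and_flatten_scales([[torch.ones(3)], []], align=4)
assert flat.numel() == 8
```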