diff --git a/msamp/optim/adamw.py b/msamp/optim/adamw.py
index 19e41277..487034d8 100644
--- a/msamp/optim/adamw.py
+++ b/msamp/optim/adamw.py
@@ -185,9 +185,9 @@ def adamw_fn(    # noqa: C901
         for i, param in enumerate(params):
             grad = grads[i].float() if not maximize else -grads[i].float()
 
-            exp_avgs[i].meta.scale = _new_exp_avg_factors[i] if self.tensor_scale else 1.0
+            exp_avgs[i].meta.scale = _new_exp_avg_factors[i] if self.tensor_scale else torch.ones((), device='cuda')
             exp_avgs[i].meta.scale_inv.fill_(1.0 / exp_avgs[i].meta.scale)
-            exp_avg_sqs[i].meta.scale = _new_exp_avg_sq_factors[i] if self.tensor_scale else 1.0
+            exp_avg_sqs[i].meta.scale = _new_exp_avg_sq_factors[i] if self.tensor_scale else torch.ones((), device='cuda')
             exp_avg_sqs[i].meta.scale_inv.fill_(1.0 / exp_avg_sqs[i].meta.scale)
             # update state
             msamp_adamw.adamw_fp8_stage2_compute(
diff --git a/msamp/te/extension.py b/msamp/te/extension.py
index d7642be8..70b01bd7 100644
--- a/msamp/te/extension.py
+++ b/msamp/te/extension.py
@@ -121,7 +121,7 @@ def cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out=None):
             return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype)
         return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out)
 
-    staticmethod
+    @staticmethod
     def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
         """Cast tensor to dtype"""
         with torch.enable_grad():
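
The two hunks fix separate issues: in `adamw.py`, the non-`tensor_scale` branch now assigns a 0-dim tensor instead of a Python float, so `meta.scale` has the same type in both branches; in `extension.py`, the bare `staticmethod` line was a no-op expression, so without the `@` the `cast_if_needed` function was never actually wrapped as a static method. Below is a minimal standalone sketch of both points; it is not MS-AMP code, the class and variable names are illustrative only, and CPU tensors are used so it runs without a GPU.

```python
# Standalone illustration (not MS-AMP code); names here are hypothetical.
import torch


class Overrider:
    staticmethod              # bare name: evaluates the builtin and discards it, a no-op
    def broken(x):
        return x

    @staticmethod             # decorator form: actually wraps the function
    def fixed(x):
        return x


print(type(Overrider.__dict__['broken']))    # <class 'function'>
print(type(Overrider.__dict__['fixed']))     # <class 'staticmethod'>

# Keeping scale a 0-dim tensor in both branches (tensor_scale on or off) means
# downstream code such as scale_inv.fill_(1.0 / scale) sees the same type either way.
scale = torch.ones(())                       # stands in for torch.ones((), device='cuda')
scale_inv = torch.empty(())
scale_inv.fill_(1.0 / scale)                 # fill_ accepts a one-element tensor
print(scale_inv.item())                      # 1.0
```

In the actual patch the scalar is created with `device='cuda'`, presumably so it lives on the same device as the rest of the FP8 meta state.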