Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
tocean committed Feb 28, 2024
1 parent 359927d commit d6d2012
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions msamp/optim/adamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ def adamw_fn( # noqa: C901

for i, param in enumerate(params):
grad = grads[i].float() if not maximize else -grads[i].float()
- exp_avgs[i].meta.scale = _new_exp_avg_factors[i] if self.tensor_scale else 1.0
+ exp_avgs[i].meta.scale = _new_exp_avg_factors[i] if self.tensor_scale else torch.ones((), device='cuda')
exp_avgs[i].meta.scale_inv.fill_(1.0 / exp_avgs[i].meta.scale)
- exp_avg_sqs[i].meta.scale = _new_exp_avg_sq_factors[i] if self.tensor_scale else 1.0
+ exp_avg_sqs[i].meta.scale = _new_exp_avg_sq_factors[i] if self.tensor_scale else torch.ones((), device='cuda')
exp_avg_sqs[i].meta.scale_inv.fill_(1.0 / exp_avg_sqs[i].meta.scale)
# update state
msamp_adamw.adamw_fp8_stage2_compute(
Expand Down
2 changes: 1 addition & 1 deletion msamp/te/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out=None):
return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype)
return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out)

- staticmethod
+ @staticmethod
def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""Cast tensor to dtype"""
with torch.enable_grad():
Expand Down

0 comments on commit d6d2012

Please sign in to comment.