Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug Fixed] scale=INF when casting a tensor to scaling FP32/BF16 tensors #131

Merged
merged 4 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions msamp/common/tensor/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import copy
import torch

from msamp.common.dtype import Floating
from msamp.common.dtype import Floating, Dtypes


class ScalingMeta:
Expand Down Expand Up @@ -104,9 +104,12 @@ def reset_scaling_factor(self, qtype=None):
if qtype is None:
qtype = self.qtype

fp_max = Floating.qfp_max[qtype]
sf = ScalingMeta.compute_scaling_factor(self.amax[0], self.scale, fp_max, 0)
self.scale.copy_(sf)
if qtype in [Dtypes.kfloat32, Dtypes.kbfloat16]:
self.scale.fill_(1)
else:
fp_max = Floating.qfp_max[qtype]
sf = ScalingMeta.compute_scaling_factor(self.amax[0], self.scale, fp_max, 0)
self.scale.copy_(sf)

def copy_(self, src):
"""Copies the members from src into self and returns self.
Expand Down
9 changes: 9 additions & 0 deletions tests/common/tensor/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ def _allclose(input, other):
# check if tensor is not changed
self.assertTrue(torch.equal(tensor, tensor_bak))

@decorator.cuda_test
def test_tensor_cast_to_scaling_fp32(self):
    """Cast a float32 tensor to ScalingFP32/ScalingBF16 and check it round-trips exactly.

    Regression test for the scale=INF bug: casting to Dtypes.kfloat32 /
    Dtypes.kbfloat16 must use scale=1 so the value survives unchanged.
    1/512 is exactly representable in both float32 and bfloat16, so the
    round-trip through .float() must be bit-exact.
    """
    for dtype in [Dtypes.kfloat32, Dtypes.kbfloat16]:
        with self.subTest(dtype=dtype):
            x = torch.tensor([1.0 / 512], dtype=torch.float32, device=self.device)
            y = x.cast(dtype)
            # torch.equal (not `assertTrue(x == y)`) — implicit bool() on a
            # tensor only works for 1-element tensors and would break if this
            # test is ever extended to multi-element inputs; matches the
            # torch.equal convention used elsewhere in this file.
            self.assertTrue(torch.equal(x, y.float()))

@decorator.cuda_test
def test_tensor_cast_with_exception_value(self):
"""Test cast function in ScalingTensor with exception value."""
Expand Down
Loading