Remove `.to()` restriction for 4-bit models #33122

Merged: 5 commits, Sep 2, 2024
src/transformers/modeling_utils.py (38 additions, 22 deletions)
@@ -2861,38 +2861,54 @@ def get_memory_footprint(self, return_buffers=True):
     def cuda(self, *args, **kwargs):
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
             raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
-        # Checks if the model has been loaded in 8-bit
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            raise ValueError(
-                "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
-                " model has already been set to the correct devices and casted to the correct `dtype`."
-            )
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
+                    " Please use the model as it is, since the model has already been set to the correct devices."
+                )
+            elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
         else:
             return super().cuda(*args, **kwargs)

     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
+        # For BNB/GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours.
+        # the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
+        dtype_present_in_args = "dtype" in kwargs
+
+        if not dtype_present_in_args:
+            for arg in args:
+                if isinstance(arg, torch.dtype):
+                    dtype_present_in_args = True
+                    break
+
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
             raise ValueError("`.to` is not supported for HQQ-quantized models.")
-        # Checks if the model has been loaded in 8-bit
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            raise ValueError(
-                "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
-                " model has already been set to the correct devices and casted to the correct `dtype`."
-            )
-        elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
-            # For GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours.
-            # the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
-            dtype_present_in_args = False
-
-            if "dtype" not in kwargs:
-                for arg in args:
-                    if isinstance(arg, torch.dtype):
-                        dtype_present_in_args = True
-                        break
-            else:
-                dtype_present_in_args = True
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
+                    " desired `dtype` by passing the correct `torch_dtype` argument."
+                )
+
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                )
+            elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
             if dtype_present_in_args:
                 raise ValueError(
                     "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
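With this change, a bitsandbytes 4-bit model can be moved between devices with `.to()` or `.cuda()` once bitsandbytes >= 0.43.2 is installed, while casting to a new dtype still raises. A minimal sketch of the new behaviour follows; the checkpoint name and the availability of a CUDA device are assumptions made for the example, not part of the PR.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Illustrative checkpoint; assumes a CUDA device and bitsandbytes >= 0.43.2.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="cuda:0",
)

model.to("cpu")  # now allowed for 4-bit models (raised a ValueError before this change)
model.cuda()     # moving back to the GPU is also allowed

try:
    model.to(torch.float16)  # casting to a new dtype still raises
except ValueError as err:
    print(err)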
tests/quantization/bnb/test_4bit.py (39 additions, 9 deletions)
@@ -256,29 +256,56 @@ def test_generate_quality_dequantize(self):

         self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

+    def test_device_assignment(self):
+        if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+            self.skipTest(reason="This test requires bitsandbytes >= 0.43.2")
+
+        mem_before = self.model_4bit.get_memory_footprint()
+
+        # Move to CPU
+        self.model_4bit.to("cpu")
+        self.assertEqual(self.model_4bit.device.type, "cpu")
+        self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
+
+        # Move back to CUDA device
+        self.model_4bit.to(0)
+        self.assertEqual(self.model_4bit.device, torch.device(0))
+        self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
+
     def test_device_and_dtype_assignment(self):
         r"""
-        Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
+        Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error.
         Checks also if other models are casted correctly.
         """
-        with self.assertRaises(ValueError):
-            # Tries with `str`
-            self.model_4bit.to("cpu")
+        # Moving with `to` or `cuda` is not supported with versions < 0.43.2.
+        if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+            with self.assertRaises(ValueError):
+                # Tries with `str`
+                self.model_4bit.to("cpu")
+
+            with self.assertRaises(ValueError):
+                # Tries with a `device`
+                self.model_4bit.to(torch.device("cuda:0"))
+
+            with self.assertRaises(ValueError):
+                # Tries with `cuda`
+                self.model_4bit.cuda()

         with self.assertRaises(ValueError):
-            # Tries with a `dtype``
+            # Tries with a `dtype`
             self.model_4bit.to(torch.float16)

         with self.assertRaises(ValueError):
-            # Tries with a `device`
-            self.model_4bit.to(torch.device("cuda:0"))
+            # Tries with a `dtype` and `device`
+            self.model_4bit.to(device="cuda:0", dtype=torch.float16)

         with self.assertRaises(ValueError):
-            # Tries with a `device`
+            # Tries with a cast
             self.model_4bit.float()

         with self.assertRaises(ValueError):
-            # Tries with a `device`
+            # Tries with a cast
             self.model_4bit.half()

         # Test if we did not break anything
@@ -287,6 +314,9 @@ def test_device_and_dtype_assignment(self):
         self.model_fp16 = self.model_fp16.to(torch.float32)
         _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

+        # Check that this does not throw an error
+        _ = self.model_fp16.cuda()
+
         # Check this does not throw an error
         _ = self.model_fp16.to("cpu")

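The memory-footprint assertions in the new test_device_assignment capture the intent of the change: moving a 4-bit model must not silently dequantize its weights. A standalone version of the same check, outside the test harness, might look like the sketch below; the checkpoint name is again only an assumption for illustration, and the test suite uses its own fixture model instead.

import importlib.metadata

from packaging import version
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same precondition the test enforces via skipTest.
assert version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.43.2")

model_4bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="cuda:0",
)

mem_before = model_4bit.get_memory_footprint()

model_4bit.to("cpu")  # move the packed 4-bit weights to the CPU
model_4bit.to(0)      # and back to the first CUDA device

# The footprint should be unchanged: the move must not unpack or dequantize the weights.
assert model_4bit.get_memory_footprint() == mem_before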