Commit

remove `.to()` restriction for 4-bit models (#33122)
* remove `.to()` restriction for 4-bit models

* Update src/transformers/modeling_utils.py

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* bitsandbytes: prevent dtype casting while allowing device movement with `.to` or `.cuda` (see the usage sketch below)

* quality fix

* Improve warning message for .to() and .cuda() on bnb quantized models

---------

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
SunMarc and matthewdouglas committed Sep 2, 2024
1 parent 97c0f45 commit 9ea1eac
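
For orientation, a hedged usage sketch of the behavior this commit enables. It is not part of the diff; the model id is only an example, and it assumes a CUDA machine with bitsandbytes >= 0.43.2 installed:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Example model id; any causal LM loaded in 4-bit behaves the same way.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# Device movement is now allowed for 4-bit models (bitsandbytes >= 0.43.2):
model.to("cpu")     # moves the quantized weights to CPU
model.to("cuda:0")  # and back to the GPU
model.cuda()        # `.cuda()` is equivalent device movement

# Casting to a new dtype is still rejected; load with the desired
# `torch_dtype` via `from_pretrained` instead.
try:
    model.to(torch.float16)
except ValueError as err:
    print(err)
```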
Showing 2 changed files with 77 additions and 31 deletions.
60 changes: 38 additions & 22 deletions src/transformers/modeling_utils.py
```diff
@@ -2861,38 +2861,54 @@ def get_memory_footprint(self, return_buffers=True):
     def cuda(self, *args, **kwargs):
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
             raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
-        # Checks if the model has been loaded in 8-bit
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            raise ValueError(
-                "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
-                " model has already been set to the correct devices and casted to the correct `dtype`."
-            )
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
+                    " Please use the model as it is, since the model has already been set to the correct devices."
+                )
+            elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
         else:
             return super().cuda(*args, **kwargs)
 
     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
+        # For BNB/GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours.
+        # the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
+        dtype_present_in_args = "dtype" in kwargs
+
+        if not dtype_present_in_args:
+            for arg in args:
+                if isinstance(arg, torch.dtype):
+                    dtype_present_in_args = True
+                    break
+
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
             raise ValueError("`.to` is not supported for HQQ-quantized models.")
-        # Checks if the model has been loaded in 8-bit
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            raise ValueError(
-                "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
-                " model has already been set to the correct devices and casted to the correct `dtype`."
-            )
-        elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
-            # For GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours.
-            # the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
-            dtype_present_in_args = False
-
-            if "dtype" not in kwargs:
-                for arg in args:
-                    if isinstance(arg, torch.dtype):
-                        dtype_present_in_args = True
-                        break
-            else:
-                dtype_present_in_args = True
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
+                    " desired `dtype` by passing the correct `torch_dtype` argument."
+                )
+
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                )
+            elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
             if dtype_present_in_args:
                 raise ValueError(
                     "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
```
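
The hoisted dtype check and the version gate above are the core of the change. A small standalone sketch of the same logic, with hypothetical helper names that are not part of transformers; it runs anywhere torch and packaging are installed:

```python
import importlib.metadata

import torch
from packaging import version


def dtype_present_in_to_args(*args, **kwargs) -> bool:
    # Mirrors the check hoisted to the top of `to()` above: a dtype can arrive
    # either as the `dtype=` keyword or as a positional argument.
    return "dtype" in kwargs or any(isinstance(arg, torch.dtype) for arg in args)


def bnb_allows_device_movement() -> bool:
    # Mirrors the version gate: moving 4-bit weights needs bitsandbytes >= 0.43.2.
    # Raises importlib.metadata.PackageNotFoundError if bitsandbytes is absent.
    return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.43.2")


# `.to()` accepts several call shapes; the guard must catch a dtype in all of them.
assert dtype_present_in_to_args(torch.float16)                         # model.to(torch.float16)
assert dtype_present_in_to_args("cuda:0", torch.float16)               # model.to("cuda:0", torch.float16)
assert dtype_present_in_to_args(device="cuda:0", dtype=torch.float16)  # keyword form
assert not dtype_present_in_to_args("cpu")                             # pure device move passes through
```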
48 changes: 39 additions & 9 deletions tests/quantization/bnb/test_4bit.py
```diff
@@ -256,29 +256,56 @@ def test_generate_quality_dequantize(self):
 
         self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
 
+    def test_device_assignment(self):
+        if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+            self.skipTest(reason="This test requires bitsandbytes >= 0.43.2")
+
+        mem_before = self.model_4bit.get_memory_footprint()
+
+        # Move to CPU
+        self.model_4bit.to("cpu")
+        self.assertEqual(self.model_4bit.device.type, "cpu")
+        self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
+
+        # Move back to CUDA device
+        self.model_4bit.to(0)
+        self.assertEqual(self.model_4bit.device, torch.device(0))
+        self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
+
     def test_device_and_dtype_assignment(self):
         r"""
-        Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
+        Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error.
         Checks also if other models are casted correctly.
         """
-        with self.assertRaises(ValueError):
-            # Tries with `str`
-            self.model_4bit.to("cpu")
 
+        # Moving with `to` or `cuda` is not supported with versions < 0.43.2.
+        if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+            with self.assertRaises(ValueError):
+                # Tries with `str`
+                self.model_4bit.to("cpu")
+
+            with self.assertRaises(ValueError):
+                # Tries with a `device`
+                self.model_4bit.to(torch.device("cuda:0"))
+
+            with self.assertRaises(ValueError):
+                # Tries with `cuda`
+                self.model_4bit.cuda()
+
         with self.assertRaises(ValueError):
-            # Tries with a `dtype``
+            # Tries with a `dtype`
             self.model_4bit.to(torch.float16)
 
         with self.assertRaises(ValueError):
-            # Tries with a `device`
-            self.model_4bit.to(torch.device("cuda:0"))
+            # Tries with a `dtype` and `device`
+            self.model_4bit.to(device="cuda:0", dtype=torch.float16)
 
         with self.assertRaises(ValueError):
-            # Tries with a `device`
+            # Tries with a cast
             self.model_4bit.float()
 
         with self.assertRaises(ValueError):
-            # Tries with a `device`
+            # Tries with a cast
             self.model_4bit.half()
 
         # Test if we did not break anything
```
```diff
@@ -287,6 +314,9 @@ def test_device_and_dtype_assignment(self):
         self.model_fp16 = self.model_fp16.to(torch.float32)
         _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
 
+        # Check that this does not throw an error
+        _ = self.model_fp16.cuda()
+
         # Check this does not throw an error
         _ = self.model_fp16.to("cpu")
 
```
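
The two `assertAlmostEqual` footprint checks in `test_device_assignment` encode the invariant that a device move relocates the quantized weights rather than reallocating them. A minimal sketch of that invariant, using a plain fp32 module as a stand-in, since the real test needs a GPU and a 4-bit model:

```python
import torch.nn as nn


def memory_footprint(model: nn.Module) -> int:
    # Simplified stand-in for PreTrainedModel.get_memory_footprint():
    # total bytes held by parameters and buffers.
    params = sum(p.numel() * p.element_size() for p in model.parameters())
    buffers = sum(b.numel() * b.element_size() for b in model.buffers())
    return params + buffers


model = nn.Linear(128, 128)  # stand-in; the real test uses a 4-bit LLM
before = memory_footprint(model)

model.to("cpu")  # a device move must not change the footprint
assert memory_footprint(model) == before
```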
