
remove `to` restriction for 4-bit model #33122

Merged: 5 commits, Sep 2, 2024
62 changes: 40 additions & 22 deletions src/transformers/modeling_utils.py
@@ -2861,38 +2861,56 @@ def get_memory_footprint(self, return_buffers=True):
def cuda(self, *args, **kwargs):
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
# Checks if the model has been loaded in 8-bit
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
"Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
" model has already been set to the correct devices and casted to the correct `dtype`."
)
if getattr(self, "is_loaded_in_8bit", False):
raise ValueError(
"Calling `cuda()` is not supported for `8-bit` quantized models. Please use the model as it is, since the"
" model has already been set to the correct devices and casted to the correct `dtype`."
)
elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
raise ValueError(
"Calling `cuda()` is not supported for `4-bit` quantized models. Please use the model as it is, since the"
" model has already been set to the correct devices and casted to the correct `dtype`. "
"However, if you still want to move the model, you need to install bitsandbytes >= 0.43.2 "
)
Member

The warning isn't super clear to me in terms of what the user should or should not do: should they install the new version, or should they just leave the model where it is? I'd try to clarify this a bit.

Member

Good feedback, thanks! Updated. I think in most cases the user would be calling `.cuda()` without realizing the model is already on a GPU, so I put the current `model.device` in the message. That should help them decide whether they really meant to move it somewhere else and need to upgrade.
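Purely as an illustration of the pattern described above, and not the wording merged in this PR, the clarified message could surface the current device along the following lines (`self` is the model instance inside `modeling_utils.py`):

# Hypothetical wording only -- not the exact text merged in this PR.
# `self.device` reports where the quantized model already lives.
raise ValueError(
    f"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of "
    f"bitsandbytes. The current device is `{self.device}`. If you intended to move the model, please "
    f"install bitsandbytes >= 0.43.2."
)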

else:
return super().cuda(*args, **kwargs)

@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs):
# For BNB/GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours.
# the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
dtype_present_in_args = "dtype" in kwargs

if not dtype_present_in_args:
for arg in args:
if isinstance(arg, torch.dtype):
dtype_present_in_args = True
break

if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.to` is not supported for HQQ-quantized models.")
# Checks if the model has been loaded in 8-bit
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
"`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
" model has already been set to the correct devices and casted to the correct `dtype`."
)
elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
# For GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours.
# the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
dtype_present_in_args = False

if "dtype" not in kwargs:
for arg in args:
if isinstance(arg, torch.dtype):
dtype_present_in_args = True
break
if getattr(self, "is_loaded_in_4bit", False):
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
Member

@SunMarc I've bumped this to 0.43.2 since that's when bitsandbytes-foundation/bitsandbytes#1279 landed.

Member Author

Nice, thanks for updating the PR!

raise ValueError(
"`.to` is not supported for `4-bit`. Please use the model as it is, since the"
" model has already been set to the correct devices and casted to the correct `dtype`. "
"However, if you still want to move the model, you need to install bitsandbytes >= 0.43.2 "
)
elif dtype_present_in_args:
raise ValueError(
"You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
" desired `dtype` by passing the correct `torch_dtype` argument."
)
else:
dtype_present_in_args = True

raise ValueError(
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
" model has already been set to the correct devices and casted to the correct `dtype`."
)
elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
if dtype_present_in_args:
raise ValueError(
"You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
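For readers skimming the diff, here is a minimal usage sketch of the behaviour this change enables. The checkpoint name and device index are illustrative, and it assumes bitsandbytes >= 0.43.2 plus an available CUDA device:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Illustrative checkpoint; any causal LM loaded in 4-bit behaves the same way.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# With bitsandbytes >= 0.43.2, moving the 4-bit model between devices now works.
model.to("cpu")
model.cuda()
model.to(torch.device("cuda:0"))

# Casting to a new dtype is still rejected; load with the desired torch_dtype instead.
try:
    model.to(torch.float16)
except ValueError as err:
    print(err)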
48 changes: 39 additions & 9 deletions tests/quantization/bnb/test_4bit.py
@@ -256,29 +256,56 @@ def test_generate_quality_dequantize(self):

self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def test_device_assignment(self):
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
self.skipTest(reason="This test requires bitsandbytes >= 0.43.2")

mem_before = self.model_4bit.get_memory_footprint()

# Move to CPU
self.model_4bit.to("cpu")
self.assertEqual(self.model_4bit.device.type, "cpu")
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)

# Move back to CUDA device
self.model_4bit.to(0)
self.assertEqual(self.model_4bit.device, torch.device(0))
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)

def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error.
Checks also if other models are casted correctly.
"""
with self.assertRaises(ValueError):
# Tries with `str`
self.model_4bit.to("cpu")

# Moving with `to` or `cuda` is not supported with versions < 0.43.2.
if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
with self.assertRaises(ValueError):
# Tries with `str`
self.model_4bit.to("cpu")

with self.assertRaises(ValueError):
# Tries with a `device`
self.model_4bit.to(torch.device("cuda:0"))

with self.assertRaises(ValueError):
# Tries with `cuda`
self.model_4bit.cuda()

with self.assertRaises(ValueError):
# Tries with a `dtype``
# Tries with a `dtype`
self.model_4bit.to(torch.float16)

with self.assertRaises(ValueError):
# Tries with a `device`
self.model_4bit.to(torch.device("cuda:0"))
# Tries with a `dtype` and `device`
self.model_4bit.to(device="cuda:0", dtype=torch.float16)

with self.assertRaises(ValueError):
# Tries with a `device`
# Tries with a cast
self.model_4bit.float()

with self.assertRaises(ValueError):
# Tries with a `device`
# Tries with a cast
self.model_4bit.half()

# Test if we did not break anything
@@ -287,6 +314,9 @@ def test_device_and_dtype_assignment(self):
self.model_fp16 = self.model_fp16.to(torch.float32)
_ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

# Check that this does not throw an error
_ = self.model_fp16.cuda()

# Check this does not throw an error
_ = self.model_fp16.to("cpu")
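The skip condition used in `test_device_assignment` doubles as a useful guard for downstream code; a minimal sketch, assuming the `packaging` dependency that transformers already requires:

import importlib.metadata

from packaging import version

def bnb_supports_device_moves() -> bool:
    """True when the installed bitsandbytes can move 4-bit models between devices."""
    try:
        bnb_version = importlib.metadata.version("bitsandbytes")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(bnb_version) >= version.parse("0.43.2")

# Usage (assuming `model_4bit` was loaded with BitsAndBytesConfig(load_in_4bit=True)):
# if bnb_supports_device_moves():
#     model_4bit.to("cpu")

As with the other bitsandbytes tests in transformers, the new tests are typically gated behind the slow/GPU markers, so running them locally usually requires RUN_SLOW=1 and a CUDA device.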
