Add support for Pixtral #33449

Merged · 62 commits · Sep 14, 2024
Diff shown: changes from 10 commits

Commits
2424d20
initial commit
ArthurZucker Sep 12, 2024
4780f37
gloups
ArthurZucker Sep 12, 2024
1b897b3
updates
ArthurZucker Sep 12, 2024
1e97527
work
ArthurZucker Sep 12, 2024
fb0e78c
weights match
ArthurZucker Sep 12, 2024
eb76b0c
nits
ArthurZucker Sep 12, 2024
334d7a9
nits
ArthurZucker Sep 12, 2024
30439a1
updates to support the tokenizer :)
ArthurZucker Sep 12, 2024
6544127
updates
ArthurZucker Sep 12, 2024
a45122b
Pixtral processor (#33454)
amyeroberts Sep 13, 2024
b6db4ee
Fix token expansion
amyeroberts Sep 13, 2024
b8df95d
nit in conversion script
ArthurZucker Sep 13, 2024
a6443bd
Merge branch 'add-pixtral' of github.com:huggingface/transformers int…
ArthurZucker Sep 13, 2024
185c435
Fix image token list creation
amyeroberts Sep 13, 2024
92c2735
done
ArthurZucker Sep 13, 2024
ea2d9fb
add expected results
ArthurZucker Sep 13, 2024
6ee62a7
Process list of list of images (#33465)
amyeroberts Sep 13, 2024
5f33680
updates
ArthurZucker Sep 13, 2024
cc18d88
working image and processor
ArthurZucker Sep 13, 2024
f04075e
this is the expected format
ArthurZucker Sep 13, 2024
732071b
some fixes
ArthurZucker Sep 13, 2024
3a15b4e
push current updated
ArthurZucker Sep 13, 2024
b773bde
working mult images!
ArthurZucker Sep 13, 2024
6c58167
add a small integration test
ArthurZucker Sep 13, 2024
c4c32fb
Uodate configuration docstring
amyeroberts Sep 13, 2024
9c621af
Formatting
amyeroberts Sep 13, 2024
172b5bc
Config docstring fix
amyeroberts Sep 13, 2024
d34bac0
simplify model test
ArthurZucker Sep 13, 2024
ed74135
Merge branch 'add-pixtral' of github.com:huggingface/transformers int…
ArthurZucker Sep 13, 2024
e090756
fixup modeling and etests
ArthurZucker Sep 13, 2024
3725e23
Return BatchMixFeature in image processor
amyeroberts Sep 13, 2024
26adfec
fix some copies
ArthurZucker Sep 13, 2024
4a65050
Merge branch 'add-pixtral' of github.com:huggingface/transformers int…
ArthurZucker Sep 13, 2024
4ee5cfb
update
ArthurZucker Sep 13, 2024
66de967
nits
ArthurZucker Sep 13, 2024
07c7600
Update model docstring
amyeroberts Sep 13, 2024
ff04e9f
Apply suggestions from code review
amyeroberts Sep 13, 2024
e7fae23
Fix up
amyeroberts Sep 13, 2024
324ba36
updates
ArthurZucker Sep 13, 2024
c4ad4e5
revert modeling changes
ArthurZucker Sep 13, 2024
9f2d98b
update
ArthurZucker Sep 13, 2024
4562fa4
Merge branch 'add-pixtral' of github.com:huggingface/transformers int…
ArthurZucker Sep 13, 2024
97b4d93
update
ArthurZucker Sep 13, 2024
ce23dc3
fix load safe
ArthurZucker Sep 13, 2024
bbf516c
addd liscence
ArthurZucker Sep 13, 2024
b783e7a
update
ArthurZucker Sep 13, 2024
443917f
use pixel_values as required by the model
ArthurZucker Sep 13, 2024
f9291ea
skip some tests and refactor
ArthurZucker Sep 13, 2024
db84a7d
Add pixtral image processing tests (#33476)
amyeroberts Sep 13, 2024
908233f
fixup post merge
ArthurZucker Sep 13, 2024
8c0f8f6
images -> pixel values
ArthurZucker Sep 13, 2024
b7d7760
oups sorry Mr docbuilder
ArthurZucker Sep 13, 2024
be154e1
isort
ArthurZucker Sep 13, 2024
bd721e2
fix
ArthurZucker Sep 13, 2024
2eda353
fix processor tests
ArthurZucker Sep 13, 2024
36e5525
small fixes
ArthurZucker Sep 13, 2024
7d4bb19
nit
ArthurZucker Sep 13, 2024
df84fe7
update
ArthurZucker Sep 13, 2024
0da8b22
last nits
ArthurZucker Sep 13, 2024
8cfff1a
oups this was really breaking!
ArthurZucker Sep 13, 2024
1effab2
nits
ArthurZucker Sep 13, 2024
f33fe19
is composition needs to be true
ArthurZucker Sep 14, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -862,6 +862,8 @@
       title: Perceiver
     - local: model_doc/pix2struct
       title: Pix2Struct
+    - local: model_doc/pixtral
+      title: Pixtral
     - local: model_doc/sam
       title: Segment Anything
     - local: model_doc/siglip
45 changes: 45 additions & 0 deletions docs/source/en/model_doc/pixtral.md
@@ -0,0 +1,45 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Pixtral


## Overview

The Pixtral model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>

The abstract from the paper is the following:

*<INSERT PAPER ABSTRACT HERE>*

Tips:

<INSERT TIPS ABOUT MODEL HERE>

This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).


## PixtralConfig

[[autodoc]] PixtralConfig

## PixtralModel

[[autodoc]] PixtralModel
- forward
20 changes: 20 additions & 0 deletions src/transformers/__init__.py
@@ -526,6 +526,9 @@
         "LlavaConfig",
         "LlavaProcessor",
     ],
+    "models.pixtral": [
+        "PixtralConfig",
+    ],
     "models.llava_next": [
         "LlavaNextConfig",
         "LlavaNextProcessor",
@@ -643,6 +646,7 @@
     "models.phi": ["PhiConfig"],
     "models.phi3": ["Phi3Config"],
     "models.phobert": ["PhobertTokenizer"],
+    "models.pixtral": ["PixtralConfig", "PixtralProcessor"],
     "models.pix2struct": [
         "Pix2StructConfig",
         "Pix2StructProcessor",
@@ -1198,6 +1202,7 @@
     _import_structure["models.owlv2"].append("Owlv2ImageProcessor")
     _import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"])
     _import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"])
+    _import_structure["models.pixtral"].append("PixtralImageProcessor")
     _import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
     _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
     _import_structure["models.pvt"].extend(["PvtImageProcessor"])
@@ -2524,6 +2529,12 @@
             "LlavaPreTrainedModel",
         ]
     )
+    _import_structure["models.pixtral"].extend(
+        [
+            "PixtralModel",
+            "PixtralPreTrainedModel",
+        ]
+    )
     _import_structure["models.llava_next"].extend(
         [
             "LlavaNextForConditionalGeneration",
@@ -5434,6 +5445,10 @@
         Pix2StructTextConfig,
         Pix2StructVisionConfig,
     )
+    from .models.pixtral import (
+        PixtralConfig,
+        PixtralProcessor,
+    )
     from .models.plbart import PLBartConfig
     from .models.poolformer import (
         PoolFormerConfig,
@@ -6009,6 +6024,7 @@
     from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor
     from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor
     from .models.pix2struct import Pix2StructImageProcessor
+    from .models.pixtral import PixtralImageProcessor
     from .models.poolformer import (
         PoolFormerFeatureExtractor,
         PoolFormerImageProcessor,
@@ -7448,6 +7464,10 @@
         Pix2StructTextModel,
         Pix2StructVisionModel,
     )
+    from .models.pixtral import (
+        PixtralModel,
+        PixtralPreTrainedModel,
+    )
     from .models.plbart import (
         PLBartForCausalLM,
         PLBartForConditionalGeneration,
42 changes: 22 additions & 20 deletions src/transformers/modeling_utils.py
@@ -548,11 +548,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool
         # Check format of the archive
         with safe_open(checkpoint_file, framework="pt") as f:
             metadata = f.metadata()
-        if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
-            raise OSError(
-                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
-                "you save your model with the `save_pretrained` method."
-            )
+        if metadata is not None:
+            if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
+                raise OSError(
+                    f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+                    "you save your model with the `save_pretrained` method."
+                )
         return safe_load_file(checkpoint_file)
     try:
         if (
@@ -3751,21 +3752,22 @@ def from_pretrained(
                 with safe_open(resolved_archive_file, framework="pt") as f:
                     metadata = f.metadata()

-                if metadata.get("format") == "pt":
-                    pass
-                elif metadata.get("format") == "tf":
-                    from_tf = True
-                    logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
-                elif metadata.get("format") == "flax":
-                    from_flax = True
-                    logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
-                elif metadata.get("format") == "mlx":
-                    # This is a mlx file, we assume weights are compatible with pt
-                    pass
-                else:
-                    raise ValueError(
-                        f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}"
-                    )
+                if metadata is not None:
+                    if metadata.get("format") == "pt":
+                        pass
+                    elif metadata.get("format") == "tf":
+                        from_tf = True
+                        logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
+                    elif metadata.get("format") == "flax":
+                        from_flax = True
+                        logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
+                    elif metadata.get("format") == "mlx":
+                        # This is a mlx file, we assume weights are compatible with pt
+                        pass
+                    else:
+                        raise ValueError(
+                            f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}"
+                        )

             from_pt = not (from_tf | from_flax)

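The `modeling_utils.py` hunks above guard against safetensors checkpoints that carry no metadata at all (for example, files not written by `save_pretrained`): `f.metadata()` returns `None` in that case, and the old `metadata.get("format")` call would raise `AttributeError`. A minimal standalone sketch of the guard, with a helper name of our own for illustration (not a transformers function):

```python
def check_safetensors_format(metadata):
    """Validate a safetensors metadata dict the way the patched code does.

    `metadata` is whatever `safe_open(...).metadata()` returned: either a
    dict like {"format": "pt"} or None when the archive has no metadata.
    """
    if metadata is not None:
        # Only inspect the format key when metadata actually exists;
        # calling .get() on None was the crash this PR fixes.
        if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
            raise OSError("The safetensors archive does not contain valid metadata.")
    # metadata is None: assume plain PyTorch weights, as the PR does.
    return True

print(check_safetensors_format(None))            # no crash on missing metadata
print(check_safetensors_format({"format": "pt"}))
```

The same two-step `is not None` / `.get("format")` shape appears in both `load_state_dict` and `from_pretrained` above.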
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -187,6 +187,7 @@
     phi3,
     phobert,
     pix2struct,
+    pixtral,
     plbart,
     poolformer,
     pop2piano,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -205,6 +205,7 @@
         ("phi", "PhiConfig"),
         ("phi3", "Phi3Config"),
         ("pix2struct", "Pix2StructConfig"),
+        ("pixtral", "PixtralConfig"),
         ("plbart", "PLBartConfig"),
         ("poolformer", "PoolFormerConfig"),
         ("pop2piano", "Pop2PianoConfig"),
@@ -509,6 +510,7 @@
         ("phi3", "Phi3"),
         ("phobert", "PhoBERT"),
         ("pix2struct", "Pix2Struct"),
+        ("pixtral", "Pixtral"),
         ("plbart", "PLBart"),
         ("poolformer", "PoolFormer"),
         ("pop2piano", "Pop2Piano"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
@@ -114,6 +114,7 @@
         ("owlvit", ("OwlViTImageProcessor",)),
         ("perceiver", ("PerceiverImageProcessor",)),
         ("pix2struct", ("Pix2StructImageProcessor",)),
+        ("pixtral", ("PixtralImageProcessor",)),
         ("poolformer", ("PoolFormerImageProcessor",)),
         ("pvt", ("PvtImageProcessor",)),
         ("pvt_v2", ("PvtImageProcessor",)),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -193,6 +193,7 @@
         ("persimmon", "PersimmonModel"),
         ("phi", "PhiModel"),
         ("phi3", "Phi3Model"),
+        ("pixtral", "PixtralModel"),
         ("plbart", "PLBartModel"),
         ("poolformer", "PoolFormerModel"),
         ("prophetnet", "ProphetNetModel"),
@@ -733,6 +734,7 @@
         ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
         ("paligemma", "PaliGemmaForConditionalGeneration"),
         ("pix2struct", "Pix2StructForConditionalGeneration"),
+        ("pixtral", "PixtralModel"),
         ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
         ("video_llava", "VideoLlavaForConditionalGeneration"),
         ("vipllava", "VipLlavaForConditionalGeneration"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -82,6 +82,7 @@
         ("owlvit", "OwlViTProcessor"),
         ("paligemma", "PaliGemmaProcessor"),
         ("pix2struct", "Pix2StructProcessor"),
+        ("pixtral", "PixtralProcessor"),
         ("pop2piano", "Pop2PianoProcessor"),
         ("qwen2_audio", "Qwen2AudioProcessor"),
         ("qwen2_vl", "Qwen2VLProcessor"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -385,6 +385,7 @@
         ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
         ("phobert", ("PhobertTokenizer", None)),
         ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
+        ("pixtral", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
         ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
         ("prophetnet", ("ProphetNetTokenizer", None)),
         ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
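The auto-class hunks above all follow the same pattern: each `auto/*.py` file keeps an ordered table of `("model_type", "ClassName")` pairs, so registering Pixtral is a one-line insertion per table. A toy sketch of how such a registry resolves a model type string to a class name (illustrative names only, not the transformers internals):

```python
# Toy registry mirroring the ("pixtral", "PixtralConfig") entries added above.
REGISTRY = {
    "pix2struct": "Pix2StructConfig",
    "pixtral": "PixtralConfig",
}

def resolve(model_type):
    """Look up the class name registered for a model_type string."""
    try:
        return REGISTRY[model_type]
    except KeyError:
        # AutoConfig raises a similar error for unknown model types.
        raise ValueError(f"Unrecognized model type: {model_type}") from None

print(resolve("pixtral"))  # → PixtralConfig
```

In the real library the table values are class *names*, imported lazily only when the entry is hit, which keeps `from transformers import AutoConfig` cheap.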
70 changes: 70 additions & 0 deletions src/transformers/models/pixtral/__init__.py
@@ -0,0 +1,70 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {
    "configuration_pixtral": ["PixtralConfig"],
    "processing_pixtral": ["PixtralProcessor"],
}


try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_pixtral"] = [
        "PixtralModel",
        "PixtralPreTrainedModel",
    ]

try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]


if TYPE_CHECKING:
    from .configuration_pixtral import PixtralConfig
    from .processing_pixtral import PixtralProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_pixtral import (
            PixtralModel,
            PixtralPreTrainedModel,
        )

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_pixtral import PixtralImageProcessor

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
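The new `pixtral/__init__.py` follows the library's lazy-import convention: `_import_structure` maps submodule names to exported symbols, and `_LazyModule` replaces the package in `sys.modules` so the real imports run only on first attribute access. A toy illustration of the idea (this `ToyLazyModule` is ours for demonstration; transformers' `_LazyModule` resolves actual submodules by path):

```python
import types

class ToyLazyModule(types.ModuleType):
    """Module whose attributes are loaded on first access, then cached."""

    def __init__(self, name, loaders):
        super().__init__(name)
        self._loaders = loaders  # attribute name -> zero-arg loader callable

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails, i.e. first access.
        if item in self._loaders:
            value = self._loaders[item]()   # do the deferred work now
            setattr(self, item, value)      # cache so __getattr__ isn't hit again
            return value
        raise AttributeError(item)

# Nothing in the table runs until the attribute is touched.
mod = ToyLazyModule("toy_pixtral", {"ANSWER": lambda: 42})
print(mod.ANSWER)  # loader runs on this first access → 42
```

The payoff is the same as in transformers: importing the package is nearly free, and heavyweight dependencies (torch, vision libraries) are only pulled in for the symbols a user actually touches.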