Skip to content

Commit

Permalink
Add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge committed Jul 7, 2024
1 parent ff396bb commit b20e18a
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ Flax), PyTorch, and/or TensorFlow.
| [DialoGPT](model_doc/dialogpt) ||||
| [DiNAT](model_doc/dinat) ||||
| [DINOv2](model_doc/dinov2) ||||
| [Dinov2WithRegistersWithRegisters](model_doc/dinov2-with-registers) ||||
| [Dinov2WithRegsiters](model_doc/dinov2-with-registers) ||||
| [DistilBERT](model_doc/distilbert) ||||
| [DiT](model_doc/dit) ||||
| [DonutSwin](model_doc/donut) ||||
Expand Down
15 changes: 7 additions & 8 deletions docs/source/en/model_doc/dinov2_with_registers.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,22 @@ specific language governing permissions and limitations under the License.

# Dinov2WithRegisters

# Dinov2WithRegisters

## Overview

The Dinov2WithRegisters model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>
The Dinov2 With Registers model was proposed in [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588) by Timothée Darcet, Maxime Oquab, Julien Mairal, Piotr Bojanowski.

This paper shows that by adding more tokens to the input sequence of a Vision Transformer useful for internal computations, one can enhance the performance.

The abstract from the paper is the following:

*<INSERT PAPER ABSTRACT HERE>*
*Transformers have recently emerged as a powerful tool for learning visual representations. In this paper, we identify and characterize artifacts in feature maps of both supervised and self-supervised ViT networks. The artifacts correspond to high-norm tokens appearing during inference primarily in low-informative background areas of images, that are repurposed for internal computations. We propose a simple yet effective solution based on providing additional tokens to the input sequence of the Vision Transformer to fill that role. We show that this solution fixes that problem entirely for both supervised and self-supervised models, sets a new state of the art for self-supervised visual models on dense visual prediction tasks, enables object discovery methods with larger models, and most importantly leads to smoother feature maps and attention maps for downstream visual processing.*

Tips:

<INSERT TIPS ABOUT MODEL HERE>
- Usage of Dinov2 with registers is identical to Dinov2 without, you'll just get better performance.

This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
This model was contributed by [nielsr](https://huggingface.co/nielsr).
The original code can be found [here](https://github.com/facebookresearch/dinov2).


## Dinov2WithRegistersConfig
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Dinov2WithRegisters checkpoints from the original repository.
"""Convert Dinov2 With Registers checkpoints from the original repository.
URL: https://github.com/facebookresearch/dinov2_with_registers/tree/main
URL: https://github.com/facebookresearch/dinov2/tree/main
"""

import argparse
Expand Down Expand Up @@ -181,8 +181,12 @@ def convert_dinov2_with_registers_checkpoint(model_name, pytorch_dump_folder_pat
if image_classifier:
model = Dinov2WithRegistersForImageClassification(config).eval()
model.dinov2_with_registers.load_state_dict(state_dict)
raise NotImplementedError("To do")
model_name_to_classifier_dict_url = {}
model_name_to_classifier_dict_url = {
"dinov2_vits14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth",
"dinov2_vitb14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth",
"dinov2_vitl14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth",
"dinov2_vitg14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth",
}
url = model_name_to_classifier_dict_url[model_name]
classifier_state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
model.classifier.weight = nn.Parameter(classifier_state_dict["weight"])
Expand Down Expand Up @@ -246,7 +250,10 @@ def convert_dinov2_with_registers_checkpoint(model_name, pytorch_dump_folder_pat
"dinov2_vitb14_reg": "dinov2-with-registers-base",
"dinov2_vitl14_reg": "dinov2-with-registers-large",
"dinov2_vitg14_reg": "dinov2-with-registers-giant",
# TODO 1-layer image classifiers
"dinov2_vits14_reg_1layer": "dinov2-with-registers-small-imagenet1k-1-layer",
"dinov2_vitb14_reg_1layer": "dinov2-with-registers-base-imagenet1k-1-layer",
"dinov2_vitl14_reg_1layer": "dinov2-with-registers-large-imagenet1k-1-layer",
"dinov2_vitg14_reg_1layer": "dinov2-with-registers-giant-imagenet1k-1-layer",
}

name = model_name_to_hf_name[model_name]
Expand All @@ -266,6 +273,10 @@ def convert_dinov2_with_registers_checkpoint(model_name, pytorch_dump_folder_pat
"dinov2_vitb14_reg",
"dinov2_vitl14_reg",
"dinov2_vitg14_reg",
"dinov2_vits14_reg_1layer",
"dinov2_vitb14_reg_1layer",
"dinov2_vitl14_reg_1layer",
"dinov2_vitg14_reg_1layer",
],
help="Name of the model you'd like to convert.",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,19 @@
_CONFIG_FOR_DOC = "Dinov2WithRegistersConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers"
_CHECKPOINT_FOR_DOC = "facebook/dinov2-with-registers-base"
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer"
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-with-registers-small-imagenet1k-1-layer"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"


# Copied from transformers.models.dinov2.modeling_dinov2.Dinov2Embeddings with Dinov2->Dinov2WithRegisters
class Dinov2WithRegistersEmbeddings(nn.Module):
"""
Construct the CLS token, mask token, position and patch embeddings.
Construct the CLS token, mask token, register tokens, position and patch embeddings.
"""

# Ignore copy
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
super().__init__()

Expand Down Expand Up @@ -116,7 +114,6 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

# Ignore copy
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
target_dtype = self.patch_embeddings.projection.weight.dtype
Expand Down Expand Up @@ -791,12 +788,12 @@ def forward(
""",
DINOV2_WITH_REGISTERS_START_DOCSTRING,
)
# Copied from transformers.models.dinov2.modeling_dinov2.Dinov2Backbone with DINOV2->DINOV2_WITH_REGISTERS,Dinov2->Dinov2WithRegisters,facebook/dinov2-base->facebook/dinov2_with_registers
class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)

self.num_register_tokens = config.num_register_tokens
self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
self.embeddings = Dinov2WithRegistersEmbeddings(config)
self.encoder = Dinov2WithRegistersEncoder(config)
Expand Down Expand Up @@ -864,7 +861,7 @@ def forward(
if self.config.apply_layernorm:
hidden_state = self.layernorm(hidden_state)
if self.config.reshape_hidden_states:
hidden_state = hidden_state[:, 1:]
hidden_state = hidden_state[:, self.num_register_tokens + 1 :]
# this was actually a bug in the original implementation that we copied here,
# cause normally the order is height, width
batch_size, _, height, width = pixel_values.shape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
num_register_tokens=2,
scope=None,
):
self.parent = parent
Expand All @@ -86,11 +87,12 @@ def __init__(
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_register_tokens = num_register_tokens
self.scope = scope

# in Dinov2WithRegisters, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
# in Dinov2 With Registers, the seq length equals the number of patches + 1 + num_register_tokens (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
self.seq_length = num_patches + 1 + self.num_register_tokens

def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
Expand All @@ -117,6 +119,7 @@ def get_config(self):
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
num_register_tokens=self.num_register_tokens,
)

def create_and_check_model(self, config, pixel_values, labels):
Expand Down Expand Up @@ -218,6 +221,14 @@ class Dinov2WithRegistersModelTest(ModelTesterMixin, PipelineTesterMixin, unitte
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"image-feature-extraction": Dinov2WithRegistersModel,
"image-classification": Dinov2WithRegistersForImageClassification,
}
if is_torch_available()
else {}
)
fx_compatible = False

test_pruning = False
Expand Down

0 comments on commit b20e18a

Please sign in to comment.