Skip to content

Commit

Permalink
Make StaticCache configurable at model construct time
Browse files Browse the repository at this point in the history
  • Loading branch information
Guang Yang committed Sep 6, 2024
1 parent 0b066be commit 5199230
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 51 deletions.
31 changes: 31 additions & 0 deletions docs/source/en/main_classes/executorch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# ExecuTorch

[`ExecuTorch`](https://github.com/pytorch/executorch) is an end-to-end solution for enabling on-device inference capabilities across mobile and edge devices including wearables, embedded devices and microcontrollers. It is part of the PyTorch ecosystem and supports the deployment of PyTorch models with a focus on portability, productivity, and performance.

ExecuTorch introduces well defined entry points to perform model, device, and/or use-case specific optimizations such as backend delegation, user-defined compiler transformations, memory planning, and more. The first step in preparing a PyTorch model for execution on an edge device using ExecuTorch is to export the model. This is achieved through the use of a PyTorch API called [`torch.export`](https://pytorch.org/docs/stable/export.html).


## ExecuTorch Integration

An integration point is being developed to ensure that 🤗 Transformers can be exported using `torch.export`. The goal of this integration is not only to enable export but also to ensure that the exported artifact can be further lowered and optimized to run efficiently in `ExecuTorch`, particularly for mobile and edge use cases.

[[autodoc]] integrations.executorch.TorchExportableModuleWithStaticCache
- forward

[[autodoc]] integrations.executorch.convert_and_export_with_cache
10 changes: 10 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,6 +1309,12 @@
"WhisperTimeStampLogitsProcessor",
]
)

# PyTorch domain libraries integration
_import_structure["integrations.executorch"] = []
_import_structure["integrations.executorch"].append("TorchExportableModuleWithStaticCache")
_import_structure["integrations.executorch"].append("convert_and_export_with_cache")

_import_structure["modeling_flash_attention_utils"] = []
_import_structure["modeling_outputs"] = []
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
Expand Down Expand Up @@ -6074,6 +6080,10 @@
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .integrations.executorch import (
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
from .modeling_utils import PreTrainedModel
from .models.albert import (
Expand Down
40 changes: 40 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,46 @@ def validate(self):
)


@dataclass
class StaticCacheConfig(CacheConfig):
"""
Configuration class for static cache settings.
"""

cache_implementation = "static"

def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
self.batch_size = batch_size
self.max_cache_len = max_cache_len
self.device = device

def validate(self):
"""Validates if the arguments passed are correct"""

incorrect_arg_msg = (
"Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
"but found {found_value}"
)

if self.batch_size <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="batch_size",
correct_value="> 0",
found_value=self.batch_size,
),
)

if self.max_cache_len <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="max_cache_len",
correct_value="> 0",
found_value=self.max_cache_len,
),
)


class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@
NEEDS_CACHE_CONFIG = {}

if is_torch_available():
from ..cache_utils import QuantizedCacheConfig
from ..cache_utils import QuantizedCacheConfig, StaticCacheConfig

NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig


class GenerationMode(ExplicitEnum):
Expand Down
147 changes: 147 additions & 0 deletions src/transformers/integrations/executorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import torch

from transformers import (
PreTrainedModel,
StaticCache,
)
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3


class TorchExportableModuleWithStaticCache(torch.nn.Module):
"""
A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
specifically for use with static caching. This module ensures that the exported model
is compatible with further lowering and execution in `ExecuTorch`.
Note:
This class is specifically designed to support export process using `torch.export`
in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
"""

def __init__(self, model: PreTrainedModel):
"""
Initializes the wrapper module with the pretrained model.
Args:
model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
enabled and use a 'static' caching implementation.
Raises:
AssertionError: If the pretrained model does not have caching enabled or if it does
not use a 'static' caching implementation in `model.generation_config`.
"""
super().__init__()

# Sanity checks
if model.generation_config is None:
raise AssertionError(
"The model must have a generation config to be exported with static caching. "
"Please set `generation_config`."
)

if not model.generation_config.use_cache:
raise AssertionError(
"The model must have caching enabled to be exported with static caching. "
"Please set `generation_config.use_cache=True`."
)

if model.generation_config.cache_implementation != "static":
raise AssertionError(
"The model must use a 'static' caching implementation to be exported with static caching. "
"Please set `generation_config.cache_implementation='static'`."
)

self.model = model
self.static_cache = StaticCache(
config=self.model.config,
batch_size=self.model.generation_config.cache_config.batch_size,
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
dtype=self.model.config.torch_dtype,
)
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
if self.is_causal:
causal_mask = torch.tril(
torch.ones(
self.static_cache.max_cache_len,
self.static_cache.max_cache_len,
dtype=torch.bool,
)
)
self.register_buffer("mask", causal_mask, persistent=False)

def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
"""
Forward pass of the module, which is compatible with the ExecuTorch runtime.
Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
Returns:
torch.Tensor: Logits output from the model.
This forward adapter serves two primary purposes:
1. **Making the Model `torch.export`-Compatible**:
The adapter hides unsupported objects, such as the `Cache`, from the graph inputs and outputs,
enabling the model to be exportable using `torch.export` without encountering issues.
2. **Ensuring Compatibility with `ExecuTorch` runtime**:
The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`,
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
"""
_, seqlen = input_ids.shape
attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
outs = self.model(
input_ids=input_ids,
attention_mask=attn_mask,
position_ids=cache_position.unsqueeze(0),
cache_position=cache_position,
past_key_values=self.static_cache,
use_cache=True,
)
return outs.logits


def convert_and_export_with_cache(
model: PreTrainedModel,
example_input_ids: torch.Tensor = None,
example_cache_position: torch.Tensor = None,
):
"""
Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
ensuring the exported model is compatible with `ExecuTorch`.
Args:
model (`PreTrainedModel`): The pretrained model to be exported.
example_input_ids (`torch.Tensor`): Example input token id used by `torch.export`.
example_cache_position (`torch.Tensor`): Example current cache position used by `torch.export`.
Returns:
Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
"""

if not is_torch_greater_or_equal_than_2_3:
raise ImportError("torch >= 2.3 is required.")

import torch.export._trace

with torch.no_grad():
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
example_input_ids = (
example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long)
)
example_cache_position = (
example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
)

# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
exported_program = torch.export._trace._export(
TorchExportableModuleWithStaticCache(model),
args=(example_input_ids,),
kwargs={"cache_position": example_cache_position},
pre_dispatch=False,
strict=True,
)
return exported_program
6 changes: 5 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3184,6 +3184,7 @@ def from_pretrained(
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
adapter_name = kwargs.pop("adapter_name", "default")
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
generation_config = kwargs.pop("generation_config", None)

gguf_file = kwargs.pop("gguf_file", None)
# Cache path to the GGUF file
Expand Down Expand Up @@ -3959,7 +3960,10 @@ def from_pretrained(
model.eval()

# If it is a model with generation capabilities, attempt to load the generation config
if model.can_generate() and pretrained_model_name_or_path is not None:
if model.can_generate() and generation_config is not None:
logger.info("The user-defined `generation_config` will be used to override the default generation config.")
model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
elif model.can_generate() and pretrained_model_name_or_path is not None:
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
Expand Down
11 changes: 11 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,17 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class TorchExportableModuleWithStaticCache(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


def convert_and_export_with_cache(*args, **kwargs):
requires_backends(convert_and_export_with_cache, ["torch"])


ROPE_INIT_FUNCTIONS = None


Expand Down
Loading

0 comments on commit 5199230

Please sign in to comment.