Make StaticCache configurable at model construct time

Guang Yang committed Aug 15, 2024
1 parent 9d2ab88 commit 57aecb3

Showing 5 changed files with 118 additions and 36 deletions.
41 changes: 41 additions & 0 deletions src/transformers/cache_utils.py
@@ -293,6 +293,47 @@ def validate(self):
)


@dataclass
class StaticCacheConfig(CacheConfig):
"""
Configuration class for static cache settings.
"""

cache_implementation = "static"

def __init__(self, max_batch_size: int, max_cache_len: int, device="cpu"):
self.max_batch_size = max_batch_size
self.max_cache_len = max_cache_len
self.device = device

def validate(self):
"""Validates if the arguments passed are correct"""

incorrect_arg_msg = (
"Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
"but found {found_value}"
)
# Check that the static cache dimensions (`max_batch_size`, `max_cache_len`) are positive
if self.max_batch_size <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="max_batch_size",
correct_value="> 0",
found_value=self.max_batch_size,
),
)

if self.max_cache_len <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="max_cache_len",
correct_value="> 0",
found_value=self.max_cache_len,
),
)


class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
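A minimal sketch of how the `StaticCacheConfig` added above can be exercised on its own, assuming a transformers build that includes this commit (values are illustrative):

from transformers.cache_utils import StaticCacheConfig

cfg = StaticCacheConfig(max_batch_size=1, max_cache_len=1234, device="cpu")
cfg.validate()  # passes: both sizes are > 0

bad = StaticCacheConfig(max_batch_size=0, max_cache_len=1234)
bad.validate()  # raises ValueError because `max_batch_size` is not > 0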
3 changes: 2 additions & 1 deletion src/transformers/generation/configuration_utils.py
@@ -45,9 +45,10 @@
NEEDS_CACHE_CONFIG = {}

if is_torch_available():
from ..cache_utils import QuantizedCacheConfig
from ..cache_utils import QuantizedCacheConfig, StaticCacheConfig

NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig


class GenerationMode(ExplicitEnum):
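The `NEEDS_CACHE_CONFIG` registry updated above maps a cache-implementation name to its config class, so a plain dict passed as `cache_config` can be turned into a typed object and validated. A rough sketch of that lookup pattern, not the exact code in `configuration_utils.py`:

from transformers.cache_utils import QuantizedCacheConfig, StaticCacheConfig

NEEDS_CACHE_CONFIG = {"quantized": QuantizedCacheConfig, "static": StaticCacheConfig}

cache_implementation = "static"
raw_cache_config = {"max_batch_size": 1, "max_cache_len": 1234}

if cache_implementation in NEEDS_CACHE_CONFIG:
    cache_config = NEEDS_CACHE_CONFIG[cache_implementation](**raw_cache_config)
    cache_config.validate()  # StaticCacheConfig checks that both sizes are > 0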
2 changes: 2 additions & 0 deletions src/transformers/models/gemma/configuration_gemma.py
@@ -135,6 +135,8 @@ def __init__(
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.cache_implementation = None
self.cache_config = None

super().__init__(
pad_token_id=pad_token_id,
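The two attributes added above default to `None` and are populated from kwargs forwarded through `from_pretrained`, which the updated test below exercises. A minimal illustration of a populated config (values are illustrative, not code from the commit):

from transformers import GemmaConfig

config = GemmaConfig()  # cache_implementation and cache_config default to None
config.cache_implementation = "static"
config.cache_config = {"max_batch_size": 1, "max_cache_len": 1234}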
34 changes: 34 additions & 0 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -1002,6 +1002,28 @@ def __init__(self, config):
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

if self.config.cache_implementation == "static":
max_batch_size = self.config.cache_config.get("max_batch_size", 1)
max_cache_len = self.config.cache_config.get("max_cache_len", 32)
device = self.config.cache_config.get("device", "cpu")
dtype = self.config.cache_config.get("dtype", torch.float32)
self.static_cache = StaticCache(
config=config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=dtype,
)
causal_mask = torch.tril(
torch.ones(
max_cache_len,
max_cache_len,
dtype=torch.bool,
device=device,
)
)
self.register_buffer("mask", causal_mask, persistent=False)

# Initialize weights and apply final processing
self.post_init()

@@ -1070,6 +1092,14 @@ def forward(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if self.config.cache_implementation == "static" and self.static_cache is not None:
# Static cache config is passed explicitly via `from_pretrained` at model construction time.
# `cache_position` can't be None when using static cache.
assert cache_position is not None
past_key_values = self.static_cache
attention_mask = self.mask[cache_position, : input_ids.shape[1]]
position_ids = cache_position.unsqueeze(0)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
Expand All @@ -1084,6 +1114,10 @@ def forward(
cache_position=cache_position,
)

if self.config.cache_implementation == "static":
# Outputs should not include a `Cache` object, as `torch.export` does not support `Cache` as an output type.
outputs.past_key_values = None

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
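To make the mask built in `__init__` and the indexing in `forward` above concrete, here is a standalone sketch with a small cache length (illustrative only, not code from the commit):

import torch

max_cache_len = 8
mask = torch.tril(torch.ones(max_cache_len, max_cache_len, dtype=torch.bool))

# Decoding a single new token at position 3 (input_ids.shape[1] == 1):
cache_position = torch.tensor([3])
attention_mask = mask[cache_position, :1]   # shape (1, 1)
position_ids = cache_position.unsqueeze(0)  # tensor([[3]])

# Prefilling a 4-token prompt (input_ids.shape[1] == 4):
cache_position = torch.arange(4)
attention_mask = mask[cache_position, :4]   # shape (4, 4), lower-triangular
position_ids = cache_position.unsqueeze(0)  # tensor([[0, 1, 2, 3]])

The `StaticCache` constructed alongside the mask pre-allocates fixed-size key/value tensors per layer, which is what lets `torch.export` trace the model without the cache growing dynamically.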
74 changes: 39 additions & 35 deletions tests/utils/test_cache_utils.py
@@ -34,7 +34,6 @@
import torch

from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
@@ -181,54 +180,59 @@ def test_static_cache_exportability(self):

device = "cpu"
dtype = torch.float32
cache_implementation = "static"
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
max_batch_size = 1

config = AutoConfig.from_pretrained(
max_cache_len = 1234
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
device_map=device,
torch_dtype=dtype,
attn_implementation=attn_implementation,
use_cache=True,
cache_implementation=cache_implementation,
cache_config={
"max_batch_size": max_batch_size,
"max_cache_len": max_cache_len,
},
)
m = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
config=config,
torch_dtype=dtype,
attn_implementation="sdpa", # Export and ExecuTorch only works for SdpaAttention
).to(device)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]

class ExportatibleModelWithStaticCache(torch.nn.Module):
def __init__(self, config, model):
super().__init__()
self.config = config
self.model = model
self.static_cache = StaticCache(
config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device
)

def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
outs = self.model(
input_ids=tokens,
attention_mask=None,
position_ids=input_pos.unsqueeze(0),
cache_position=input_pos,
past_key_values=self.static_cache,
use_cache=True,
)
return outs.logits
# Check if cache config is passed through correctly
self.assertEqual(model.config.use_cache, True)
self.assertEqual(model.config.cache_implementation, cache_implementation)
self.assertTrue(model.config.cache_config is not None)
self.assertEqual(model.config.cache_config.get("max_batch_size"), max_batch_size)
self.assertEqual(model.config.cache_config.get("max_cache_len"), max_cache_len)

set_seed(0)
with torch.no_grad():
import torch.export._trace
from torch.export import ExportedProgram

model = ExportatibleModelWithStaticCache(config, m)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
input_ids = tokenizer([""], return_tensors="pt").to(device)["input_ids"]
cache_position = torch.tensor([0], dtype=torch.long)
# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to use an internal
# export API with pre_dispatch=False. Switch to the public API once the fix is included in a 2.4.1+ release.
exported_program = torch.export._trace._export(
model, args=(inputs,), kwargs={"input_pos": torch.arange(1)}, pre_dispatch=False, strict=True
model,
args=(input_ids,),
kwargs={"cache_position": cache_position},
pre_dispatch=False,
strict=True,
)
self.assertTrue(isinstance(exported_program, ExportedProgram))

# Check if the exported model is configured with the `StaticCache` correctly
n_static_key_caches = n_static_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("static_cache.key_cache"):
self.assertTrue(buffer.shape[0] == max_batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_key_caches = n_static_key_caches + 1
if buffer_name.startswith("static_cache.value_cache"):
self.assertTrue(buffer.shape[0] == max_batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_value_caches = n_static_value_caches + 1
self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)


@require_torch_gpu
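Once the PyTorch issue referenced in the test is fixed in a release, the same export could go through the public `torch.export` API, and the exported graph can be executed directly. A sketch under that assumption, reusing `model`, `input_ids`, and `cache_position` from the test above (not part of the commit):

from torch.export import export

exported_program = export(model, args=(input_ids,), kwargs={"cache_position": cache_position}, strict=True)

out = exported_program.module()(input_ids, cache_position=cache_position)
# `out` mirrors the eager forward's return; `past_key_values` is dropped before
# returning, as noted in the modeling_gemma.py change above.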
