Fix the initialization of the cache when we have multi gpu #33303

Merged: 12 commits, Sep 13, 2024
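For context, a minimal sketch of the scenario this PR targets: a model sharded across several GPUs with accelerate's `device_map="auto"`, then used with a static cache during generation. The checkpoint name and exact kwargs below are illustrative, not taken from the PR.

```python
# Illustrative only: with the model split across GPUs, the static cache built inside
# generate() should place each layer's key/value buffers on the GPU that runs that layer.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # any decoder-only checkpoint; the name is an example
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```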
13 changes: 8 additions & 5 deletions src/transformers/cache_utils.py
@@ -1020,6 +1020,7 @@ def __init__(
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_mapping: Optional[dict] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
@@ -1047,6 +1048,8 @@ def __init__(
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
for idx in range(config.num_hidden_layers):
if layer_device_mapping is not None:
device = layer_device_mapping[idx]
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
# Notes:
@@ -1090,8 +1093,6 @@ def update(
A tuple containing the updated key and value states.
"""
cache_position = cache_kwargs.get("cache_position")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

@@ -1190,6 +1191,7 @@ def __init__(
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_mapping: Optional[dict] = None,
) -> None:
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
@@ -1206,6 +1208,7 @@ def __init__(
device=device,
dtype=dtype,
max_batch_size=max_batch_size,
layer_device_mapping=layer_device_mapping,
)

def update(
@@ -1239,7 +1242,6 @@ def update(
v_out = v_out[:, :, indices]

try:
cache_position.to(device=k_out.device)
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
@@ -1484,6 +1486,7 @@ def __init__(
device: Union[torch.device, str] = "cpu",
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_mapping: Optional[dict] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
@@ -1521,6 +1524,8 @@ def __init__(
self.head_dim,
)
for i in range(config.num_hidden_layers):
if layer_device_mapping is not None:
device = layer_device_mapping[i]
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
@@ -1576,8 +1581,6 @@ def update(
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
if sliding_window:
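To see the allocation pattern above outside of diff context, here is a minimal, self-contained sketch; the class and argument names are illustrative, not the library's. Each layer's key/value buffer is created directly on the device that will execute that layer, so `update()` no longer needs to move the cache with `.to(...)`.

```python
from typing import Optional, Tuple
import torch


class PerLayerStaticCacheSketch:
    """Illustrative sketch of the allocation pattern in the diff above (not the real StaticCache)."""

    def __init__(
        self,
        num_layers: int,
        cache_shape: Tuple[int, int, int, int],  # (batch, num_kv_heads, max_cache_len, head_dim)
        dtype: torch.dtype = torch.float32,
        device: torch.device = None,
        layer_device_mapping: Optional[dict] = None,
    ) -> None:
        self.key_cache, self.value_cache = [], []
        for idx in range(num_layers):
            if layer_device_mapping is not None:
                # The per-layer execution device overrides the single global device.
                device = layer_device_mapping[idx]
            self.key_cache.append(torch.zeros(cache_shape, dtype=dtype, device=device))
            self.value_cache.append(torch.zeros(cache_shape, dtype=dtype, device=device))

    def update(self, key_states, value_states, layer_idx, cache_position):
        # Buffers already live on the layer's device, so no `.to(device=...)` is needed here.
        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
        k_out.index_copy_(2, cache_position, key_states)
        v_out.index_copy_(2, cache_position, value_states)
        return k_out, v_out
```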
27 changes: 27 additions & 0 deletions src/transformers/generation/utils.py
@@ -1449,12 +1449,39 @@ def _get_cache(
# models. May cause troubles with non-text modalities.
cache_dtype = self.get_output_embeddings().weight.dtype

def get_layer_device_mapping(execution_device: Optional[dict] = None):
if execution_device is None or len(execution_device) <= 1:
return None
layer_device_mapping = {}
for layer in execution_device:
for idx in range(self.config.num_hidden_layers):
if f".{idx}." in f"{layer}.":
layer_device_mapping[idx] = execution_device[layer]
break
for idx in range(self.config.num_hidden_layers):
if idx not in layer_device_mapping:
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
return layer_device_mapping

execution_device = None
# Taken from dispatch_model from accelerate.
# This is needed here if we don't want to make changes in accelerate in order to save execution_device
# For the offloaded case, we need to get the execution device, not just the device where it is offloaded
if hasattr(self, "hf_device_map"):
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
execution_device = {
name: main_device if device in ["cpu", "disk"] else device
for name, device in self.hf_device_map.items()
}
layer_device_mapping = get_layer_device_mapping(execution_device)

cache_kwargs = {
"config": self.config,
"batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype,
"layer_device_mapping": layer_device_mapping,
}
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
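For reference, a standalone sketch of the device-mapping helper added above (the function name is mine): it maps each decoder layer index to its execution device, resolving "cpu"/"disk" offload entries to the main device, as accelerate's dispatch_model does.

```python
from typing import Optional


def layer_device_mapping_from_device_map(
    device_map: dict, num_hidden_layers: int, main_device
) -> Optional[dict]:
    """Sketch of the helper in the diff above: layer index -> execution device."""
    execution_device = {
        name: main_device if device in ("cpu", "disk") else device
        for name, device in device_map.items()
    }
    if len(execution_device) <= 1:
        return None
    mapping = {}
    for name, device in execution_device.items():
        for idx in range(num_hidden_layers):
            # "model.layers.3" matches ".3." once a trailing dot is appended,
            # without also matching ".13." or ".30.".
            if f".{idx}." in f"{name}.":
                mapping[idx] = device
                break
    for idx in range(num_hidden_layers):
        if idx not in mapping:
            raise RuntimeError(f"layer {idx} has not been mapped to a device.")
    return mapping


# Example: a 4-layer model split across two GPUs, with the LM head offloaded to CPU.
device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 0,
    "model.layers.2": 1,
    "model.layers.3": 1,
    "lm_head": "cpu",
}
print(layer_device_mapping_from_device_map(device_map, 4, main_device=0))
# -> {0: 0, 1: 0, 2: 1, 3: 1}
```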