34 changes: 27 additions & 7 deletions src/transformers/cache_utils.py
@@ -982,18 +982,38 @@ def update_recurrent_state(self, recurrent_states: torch.Tensor, layer_idx: int,
         return recurrent_states
 
     def early_initialization(
-        self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device
+        self,
+        batch_size: int,
+        num_heads: int | list[int],
+        head_dim: int | list[int],
+        dtype: torch.dtype,
+        device: torch.device,
     ):
         """
         Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call).
         This is useful for our `export` recipes, as `export` needs everything in advance.
         """
-        # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
-        # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
-        # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
-        fake_kv_tensor = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device)
-        # Init all layers
-        for layer in self.layers:
+        # To allow different num_heads and head_dim depending on layers, we accept lists
+        if isinstance(num_heads, int):
+            num_heads = [num_heads] * len(self)
+        if isinstance(head_dim, int):
+            head_dim = [head_dim] * len(self)
+
+        if len(num_heads) != len(self.layers):
+            raise ValueError(
+                f"`num_heads` was provided as a list of length {len(num_heads)}, but the Cache currently has {len(self.layers)} layers"
+            )
+        if len(head_dim) != len(self.layers):
+            raise ValueError(
+                f"`head_dim` was provided as a list of length {len(head_dim)}, but the Cache currently has {len(self.layers)} layers"
+            )
+
+        for layer, layer_num_heads, layer_head_dim in zip(self.layers, num_heads, head_dim):
+            # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
+            # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
+            # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
+            fake_kv_tensor = torch.zeros((batch_size, layer_num_heads, 0, layer_head_dim), dtype=dtype, device=device)
+            # Init the layer
             layer.lazy_initialization(fake_kv_tensor, fake_kv_tensor)
 
     def get_seq_length(self, layer_idx: int = 0) -> int:
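Note (not part of the diff): a minimal usage sketch of the extended signature. The config-based `StaticCache` constructor and the checkpoint name are assumptions to make the snippet self-contained; adjust to the installed `transformers` version.

import torch
from transformers import AutoConfig
from transformers.cache_utils import StaticCache

# Assumed construction: a StaticCache built from a model config (here an arbitrary public checkpoint).
config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B")
cache = StaticCache(config=config, max_cache_len=64)

# Plain ints are still accepted and broadcast to every layer; lists give per-layer shapes.
num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
cache.early_initialization(
    batch_size=1,
    num_heads=[num_heads] * len(cache.layers),  # one entry per cache layer
    head_dim=[head_dim] * len(cache.layers),
    dtype=torch.float32,
    device=torch.device("cpu"),
)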
37 changes: 28 additions & 9 deletions src/transformers/integrations/executorch.py
@@ -443,6 +443,28 @@ def generate(
     return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
 
+def get_head_shapes(config) -> tuple[int | list[int], int | list[int]]:
+    """Returns a tuple `(num_heads, head_dim)` containing either two ints, or two lists of ints with one value
+    per layer."""
+    # Gemma4 has different head_dim and num_heads depending on layer type
+    if hasattr(config, "global_head_dim"):
+        head_dim = [
+            config.global_head_dim if layer == "full_attention" else config.head_dim
+            for layer in config.layer_types[: -config.num_kv_shared_layers]
+        ]
+        num_heads = [
+            config.num_global_key_value_heads
+            if layer == "full_attention" and config.attention_k_eq_v
+            else config.num_key_value_heads
+            for layer in config.layer_types[: -config.num_kv_shared_layers]
+        ]
+    else:
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+
+    return num_heads, head_dim
+
+
 class TorchExportableModuleWithStaticCache(torch.nn.Module):
     """
     A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
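Note (not part of the diff): an illustration of the two branches of `get_head_shapes`, with the function above in scope. The attribute names are the ones checked in the function; the numeric values are made up.

from types import SimpleNamespace

# Uniform case: no `global_head_dim`, so two plain ints come back.
uniform_cfg = SimpleNamespace(num_attention_heads=8, num_key_value_heads=2, head_dim=64, hidden_size=512)
print(get_head_shapes(uniform_cfg))  # (2, 64)

# Per-layer case: `global_head_dim` is present, so two lists come back,
# one entry per non-KV-shared layer.
mixed_cfg = SimpleNamespace(
    head_dim=64,
    global_head_dim=128,
    num_key_value_heads=2,
    num_global_key_value_heads=4,
    attention_k_eq_v=True,
    num_kv_shared_layers=1,
    layer_types=["sliding_attention", "full_attention", "full_attention"],
)
print(get_head_shapes(mixed_cfg))  # ([2, 4], [64, 128])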
@@ -523,9 +545,8 @@ def __init__(
         # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
         for i, layer in enumerate(self.static_cache.layers):
             if isinstance(layer, StaticSlidingWindowLayer):
-                self.static_cache.layers[i] = StaticLayer(layer.max_cache_len)
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+                self.static_cache.layers[i] = StaticLayer(max_cache_len)
+        num_heads, head_dim = get_head_shapes(config)
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
         self.static_cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
@@ -702,9 +723,8 @@ def __init__(
         # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
         for i, layer in enumerate(self.cache.layers):
             if isinstance(layer, StaticSlidingWindowLayer):
-                self.cache.layers[i] = StaticLayer(layer.max_cache_len)
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
+                self.cache.layers[i] = StaticLayer(max_cache_len)
+        num_heads, head_dim = get_head_shapes(config)
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
         self.cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
@@ -856,9 +876,8 @@ def __init__(self, model, max_static_cache_length, batch_size):
         # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
         for i, layer in enumerate(self.static_cache.layers):
             if isinstance(layer, StaticSlidingWindowLayer):
-                self.static_cache.layers[i] = StaticLayer(layer.max_cache_len)
-        head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
-        num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
+                self.static_cache.layers[i] = StaticLayer(max_static_cache_length)
+        num_heads, head_dim = get_head_shapes(self.config)
         self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model_device)
         self.cache = EncoderDecoderCache(self.static_cache, DynamicCache(config=self.config))
 
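Note (not part of the diff): the three `__init__` hunks above apply the same two-step pattern. A condensed sketch follows; the helper name is hypothetical, and the surrounding construction of `cache`, `config`, `max_cache_len`, `batch_size`, `dtype` and `device` is assumed.

from transformers.cache_utils import StaticLayer, StaticSlidingWindowLayer
from transformers.integrations.executorch import get_head_shapes  # added in this PR

def prepare_cache_for_export(cache, config, max_cache_len, batch_size, dtype, device):
    # Hypothetical helper illustrating the shared pattern, not part of the PR.
    # 1) Sliding-window layers are swapped for full static layers sized to the whole
    #    cache length, since generation beyond the window is not supported on export.
    for i, layer in enumerate(cache.layers):
        if isinstance(layer, StaticSlidingWindowLayer):
            cache.layers[i] = StaticLayer(max_cache_len)
    # 2) Per-layer head counts/dims (lists or plain ints) drive the eager
    #    initialization that torch.export needs in place of lazy allocation.
    num_heads, head_dim = get_head_shapes(config)
    cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)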