Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 46 additions & 16 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,13 +1140,13 @@ def __init__(
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
if batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)

self.batch_size = batch_size or max_batch_size
self.max_batch_size = batch_size or max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len

# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
Expand Down Expand Up @@ -1254,6 +1254,14 @@ def reset(self):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size
Comment on lines +1257 to +1263
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool, let's make sure all our codes uses max_batch_size as logger warning is incompatible with compile

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, I check the generation files and the repo, seems like all changes are done already. Plus the compile tests on llama still pass



class SlidingWindowCache(StaticCache):
"""
Expand Down Expand Up @@ -1626,10 +1634,10 @@ def __init__(
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
if batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
Expand All @@ -1638,7 +1646,7 @@ def __init__(
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self.batch_size = batch_size or max_batch_size
self.max_batch_size = batch_size or max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
Expand Down Expand Up @@ -1758,6 +1766,14 @@ def reset(self):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size


class MambaCache:
"""
Expand Down Expand Up @@ -1815,28 +1831,28 @@ def __init__(
device: Optional[Union[torch.device, str]] = None,
max_batch_size: Optional[int] = None,
):
if max_batch_size is not None:
if batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
self.dtype = dtype
self.batch_size = batch_size or max_batch_size
self.max_batch_size = batch_size or max_batch_size
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel

self.conv_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.batch_size,
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
dtype=dtype,
)
self.ssm_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.batch_size,
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
Expand Down Expand Up @@ -1866,6 +1882,14 @@ def reset(self):
self.conv_states.zero_()
self.ssm_states.zero_()

@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size


class OffloadedStaticCache(StaticCache):
"""
Expand All @@ -1887,6 +1911,9 @@ class OffloadedStaticCache(StaticCache):
The default `dtype` to use when initializing the cache.
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
The device to offload to. Defaults to CPU.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus.
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.

Attributes:
key_cache (`List[torch.Tensor]`):
Expand Down Expand Up @@ -1933,18 +1960,21 @@ def __init__(
device: Union[str, torch.device],
dtype: Optional[torch.dtype] = None,
offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device)
self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32

# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads

num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)

cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
"mamba": MambaCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys())
ALL_CACHE_IMPLEMENTATIONS = (
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"]
)


class GenerationMode(ExplicitEnum):
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1610,7 +1610,7 @@ def _get_cache(
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.batch_size != batch_size
or cache_to_check.max_batch_size != batch_size
)
if cache_implementation != "mamba":
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
Expand Down Expand Up @@ -1666,7 +1666,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None):

cache_kwargs = {
"config": self.config.get_text_config(),
"batch_size": batch_size,
"max_batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype,
Expand Down
26 changes: 26 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1880,6 +1880,32 @@ def test_new_cache_format(self, num_beams, do_sample):
)
)

@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
@require_torch_gpu
@pytest.mark.generate
def test_offloaded_cache_implementation(self, cache_implementation):
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
self.skipTest(reason="This model does not support the new cache format")

config, inputs_dict = self.prepare_config_and_inputs_for_generate()

model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"use_cache": True,
"cache_implementation": cache_implementation,
}

legacy_results = model.generate(**generation_kwargs, **inputs_dict)

# Most cache classes have their own tests except for some that are tested here
# The ones here do not need special treatment when passing `cache_implementation`
# and are not bound to specific models only
new_results = model.generate(**generation_kwargs, **inputs_dict)
self.assertListEqual(legacy_results.tolist(), new_results.tolist())

@pytest.mark.generate
def test_generate_with_static_cache(self):
"""
Expand Down
8 changes: 8 additions & 0 deletions tests/models/mllama/test_modeling_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

import unittest

import pytest
import requests
from parameterized import parameterized

from transformers import (
AutoProcessor,
Expand Down Expand Up @@ -365,6 +367,12 @@ def test_sdpa_can_compile_dynamic(self):
def test_model_parallelism(self):
pass

@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type")
def test_offloaded_cache_implementation(self, cache_implementation):
pass

def test_generate_text_only_with_cache(self):
"""
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
Expand Down
6 changes: 6 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,12 @@ def test_training_gradient_checkpointing_use_reentrant_false(self):
def test_generate_with_head_masking(self):
pass

@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Whisper doesnt work with offloaded cache implementation yet")
def test_offloaded_cache_implementation(self, cache_implementation):
pass

@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
Expand Down