From 502f86dc1e5b7399e330b3f702d787bd0c7f8600 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 25 Nov 2024 16:17:39 +0100 Subject: [PATCH 1/4] fix cache impl --- src/transformers/cache_utils.py | 51 ++++++++++++++----- .../generation/configuration_utils.py | 4 +- src/transformers/generation/utils.py | 4 +- tests/generation/test_utils.py | 25 +++++++++ tests/models/mllama/test_modeling_mllama.py | 8 +++ tests/models/whisper/test_modeling_whisper.py | 6 +++ 6 files changed, 81 insertions(+), 17 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 0f696cc3ac6a..5feddad20eb7 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1131,13 +1131,13 @@ def __init__( layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: super().__init__() - if max_batch_size is not None: + if batch_size is not None: logger.warning_once( - f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " - "v4.46. Use the more precisely named 'batch_size' argument instead." + f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'max_batch_size' argument instead." ) - self.batch_size = batch_size or max_batch_size + self.max_batch_size = batch_size or max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads @@ -1243,6 +1243,14 @@ def reset(self): self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() + @property + def batch_size(self): + logger.warning_once( + f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead." + ) + return self.max_batch_size + class SlidingWindowCache(StaticCache): """ @@ -1609,10 +1617,10 @@ def __init__( layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: super().__init__() - if max_batch_size is not None: + if batch_size is not None: logger.warning_once( - f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " - "v4.46. Use the more precisely named 'batch_size' argument instead." + f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'max_batch_size' argument instead." ) if not hasattr(config, "sliding_window") or config.sliding_window is None: raise ValueError( @@ -1621,7 +1629,7 @@ def __init__( "config and it's not set to None." ) self.max_cache_len = max_cache_len - self.batch_size = batch_size or max_batch_size + self.max_batch_size = batch_size or max_batch_size # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads self.head_dim = ( config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads @@ -1741,6 +1749,14 @@ def reset(self): self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() + @property + def batch_size(self): + logger.warning_once( + f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead." + ) + return self.max_batch_size + class MambaCache: """ @@ -1798,13 +1814,13 @@ def __init__( device: Optional[Union[torch.device, str]] = None, max_batch_size: Optional[int] = None, ): - if max_batch_size is not None: + if batch_size is not None: logger.warning_once( - f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " - "v4.46. Use the more precisely named 'batch_size' argument instead." + f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'max_batch_size' argument instead." ) self.dtype = dtype - self.batch_size = batch_size or max_batch_size + self.max_batch_size = batch_size or max_batch_size self.intermediate_size = config.intermediate_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel @@ -1888,6 +1904,10 @@ class OffloadedStaticCache(StaticCache): The device used to offload to. dtype (`torch.dtype`): The `dtype` used to initializing the cache. + layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): + Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus. + You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + Example: @@ -1916,10 +1936,11 @@ def __init__( device: Union[str, torch.device], dtype: Optional[torch.dtype] = None, offload_device: Union[str, torch.device] = torch.device("cpu"), + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len - self.device = torch.device(device) + self.device = torch.device(device) if layer_device_map is None else layer_device_map[0] self.offload_device = torch.device(offload_device) self.dtype = dtype if dtype is not None else torch.float32 @@ -1927,7 +1948,9 @@ def __init__( head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads num_key_value_heads = ( - config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads ) cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 9b543f6c3571..3403807dcac1 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -72,7 +72,9 @@ "mamba": MambaCache, } QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} - ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ALL_CACHE_IMPLEMENTATIONS = ( + list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"] + ) class GenerationMode(ExplicitEnum): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6e6d5b8bdce7..ddfb8038353a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1591,7 +1591,7 @@ def _get_cache( need_new_cache = ( not hasattr(self, "_cache") or (not isinstance(cache_to_check, cache_cls)) - or cache_to_check.batch_size != batch_size + or cache_to_check.max_batch_size != batch_size ) if cache_implementation != "mamba": need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len @@ -1644,7 +1644,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None): cache_kwargs = { "config": self.config.get_text_config(), - "batch_size": batch_size, + "max_batch_size": batch_size, "max_cache_len": max_cache_len, "device": device, "dtype": cache_dtype, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index cbe851e97e9a..7dbd05d58a19 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1869,6 +1869,31 @@ def test_new_cache_format(self, num_beams, do_sample): ) ) + @parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5) + @pytest.mark.generate + def test_offloaded_cache_implementation(self, cache_implementation): + """Tests we can generate by indicating `cache_implementation` for each possible cache class""" + for model_class in self.all_generative_model_classes: + if not model_class._supports_cache_class: + self.skipTest(reason="This model does not support the new cache format") + + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + + model = model_class(config).to(torch_device).eval() + generation_kwargs = { + "max_new_tokens": 5, + "use_cache": True, + "cache_implementation": cache_implementation, + } + + legacy_results = model.generate(**generation_kwargs, **inputs_dict) + + # Most cache classes have their own tests except for some that are tested here + # The ones here do not need special treatment when passing `cache_implementation` + # and are not bound to specific models only + new_results = model.generate(**generation_kwargs, **inputs_dict) + self.assertListEqual(legacy_results.tolist(), new_results.tolist()) + @pytest.mark.generate def test_generate_with_static_cache(self): """ diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py index 8da927f815db..cfd64aee5368 100644 --- a/tests/models/mllama/test_modeling_mllama.py +++ b/tests/models/mllama/test_modeling_mllama.py @@ -16,7 +16,9 @@ import unittest +import pytest import requests +from parameterized import parameterized from transformers import ( AutoProcessor, @@ -365,6 +367,12 @@ def test_sdpa_can_compile_dynamic(self): def test_model_parallelism(self): pass + @parameterized.expand([("offloaded",)]) + @pytest.mark.generate + @unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type") + def test_offloaded_cache_implementation(self, cache_implementation): + pass + def test_generate_text_only_with_cache(self): """ Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 12aedaca8cf9..1123bec459d7 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -567,6 +567,12 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_generate_with_head_masking(self): pass + @parameterized.expand([("offloaded",)]) + @pytest.mark.generate + @unittest.skip(reason="Whisper doesnt work with offloaded cache implementation yet") + def test_offloaded_cache_implementation(self, cache_implementation): + pass + @require_torch_fp16 def test_generate_fp16(self): config, input_dict = self.model_tester.prepare_config_and_inputs() From f6fc92d209d81625500bb9c0468ae431883af25e Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 25 Nov 2024 16:26:11 +0100 Subject: [PATCH 2/4] require_torch_gpu --- src/transformers/cache_utils.py | 3 +++ tests/generation/test_utils.py | 1 + 2 files changed, 4 insertions(+) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 5feddad20eb7..7296479e39e4 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1886,6 +1886,9 @@ class OffloadedStaticCache(StaticCache): The default `dtype` to use when initializing the cache. offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): The device to offload to. Defaults to CPU. + layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): + Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus. + You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. Attributes: key_cache (`List[torch.Tensor]`): diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 7dbd05d58a19..97a43b89d415 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1870,6 +1870,7 @@ def test_new_cache_format(self, num_beams, do_sample): ) @parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5) + @require_torch_gpu @pytest.mark.generate def test_offloaded_cache_implementation(self, cache_implementation): """Tests we can generate by indicating `cache_implementation` for each possible cache class""" From 1e142d178ffbbc844366f344383a28bf6b2cc6ae Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 25 Nov 2024 16:30:58 +0100 Subject: [PATCH 3/4] fix mamba --- src/transformers/cache_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 7296479e39e4..0f4e0b7464e1 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1827,7 +1827,7 @@ def __init__( self.conv_states: torch.Tensor = torch.zeros( config.num_hidden_layers, - self.batch_size, + self.max_batch_size, self.intermediate_size, self.conv_kernel_size, device=device, @@ -1835,7 +1835,7 @@ def __init__( ) self.ssm_states: torch.Tensor = torch.zeros( config.num_hidden_layers, - self.batch_size, + self.max_batch_size, self.intermediate_size, self.ssm_state_size, device=device, @@ -1865,6 +1865,14 @@ def reset(self): self.conv_states.zero_() self.ssm_states.zero_() + @property + def batch_size(self): + logger.warning_once( + f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead." + ) + return self.max_batch_size + class OffloadedStaticCache(StaticCache): """ From c773dcc162a77fb6d25b7832c1ebd9a0ff052a3f Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 25 Nov 2024 16:45:02 +0100 Subject: [PATCH 4/4] fix copies --- src/transformers/cache_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 0f4e0b7464e1..bcf28993fe83 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1894,7 +1894,7 @@ class OffloadedStaticCache(StaticCache): The default `dtype` to use when initializing the cache. offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): The device to offload to. Defaults to CPU. - layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): + layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*): Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus. You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. @@ -1915,10 +1915,6 @@ class OffloadedStaticCache(StaticCache): The device used to offload to. dtype (`torch.dtype`): The `dtype` used to initializing the cache. - layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): - Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus. - You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. - Example: