Skip to content
2 changes: 1 addition & 1 deletion docs/source/en/kv_cache.md
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ In case you are using Sink Cache, you have to crop your inputs to that maximum l
>>> user_prompts = ["Hello, what's your name?", "Btw, yesterday I was on a rock concert."]

>>> past_key_values = DynamicCache()
>>> max_cache_length = past_key_values.get_max_length()
>>> max_cache_length = past_key_values.get_max_cache_shape()

>>> messages = []
>>> for prompt in user_prompts:
Expand Down
12 changes: 12 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class Cache(torch.nn.Module):
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""

is_compileable = False

def __init__(self):
super().__init__()

Expand Down Expand Up @@ -1098,6 +1100,8 @@ class StaticCache(Cache):
```
"""

is_compileable = True

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
Expand Down Expand Up @@ -1297,6 +1301,7 @@ class SlidingWindowCache(StaticCache):
"""

is_sliding = True
is_compileable = True

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
Expand Down Expand Up @@ -1421,6 +1426,7 @@ def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
super().__init__()
self.self_attention_cache = self_attention_cache
self.cross_attention_cache = cross_attention_cache
self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False)

self.is_updated = {}
for layer_idx in range(len(cross_attention_cache.key_cache)):
Expand Down Expand Up @@ -1612,6 +1618,8 @@ class HybridCache(Cache):
```
"""

is_compileable = True

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
Expand Down Expand Up @@ -1832,6 +1840,8 @@ class MambaCache:
```
"""

is_compileable = True

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
Expand Down Expand Up @@ -1975,6 +1985,8 @@ class OffloadedStaticCache(StaticCache):
```
"""

is_compileable = True

@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,7 +1579,7 @@ def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProces


@dataclass
class CompileConfig(object):
class CompileConfig:
"""
Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.
Expand Down Expand Up @@ -1620,7 +1620,9 @@ class CompileConfig(object):
backend: Union[str, Callable] = "inductor"
mode: str = "reduce-overhead"
options: Optional[dict] = None
# Used to flag our `generate` call to compile on e.g. CPU. Often not optimal, but useful for testing purposes.
_compile_all_devices = None

def to_dict(self) -> Dict[str, Any]:
"""Serializes this instance to a Python dictionary."""
return copy.deepcopy(self.__dict__)
return copy.deepcopy({key: value for key, value in self.__dict__.items() if key != "_compile_all_devices"})
8 changes: 5 additions & 3 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3177,9 +3177,11 @@ def _sample(
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

model_forward = self.__call__
if isinstance(model_kwargs.get("past_key_values"), StaticCache):
if self.device.type == "cuda":
logger.warning_once("Using `torch.compile`.")
if isinstance(model_kwargs.get("past_key_values"), Cache):
is_compileable = model_kwargs["past_key_values"].is_compileable
if is_compileable and (
self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
):
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/aria/modeling_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ class AriaPreTrainedModel(PreTrainedModel):
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = False

def _init_weights(self, module):
Expand Down Expand Up @@ -1561,6 +1561,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
logits_to_keep=logits_to_keep,
cache_position=cache_position,
)

logits = outputs[0]
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/aria/modular_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,7 @@ def _init_weights(self, module):


class AriaPreTrainedModel(LlamaPreTrainedModel):
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = False

def _init_weights(self, module):
Expand Down Expand Up @@ -1535,6 +1536,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
logits_to_keep=logits_to_keep,
cache_position=cache_position,
)

logits = outputs[0]
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)

def _init_weights(self, module: nn.Module):
std = self.config.initializer_range
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/emu3/modeling_emu3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,7 @@ def forward(

class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_model.lm_head.weight"]
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable

def __init__(self, config):
super().__init__(config)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/emu3/modular_emu3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,7 @@ def forward(**super_kwargs):

class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_model.lm_head.weight"]
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable

def __init__(self, config):
super().__init__(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_static_cache = False # TODO (fix me): compilation fails due to a stide error?

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)

def _init_weights(self, module):
std = self.config.initializer_range
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/idefics/modeling_idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = False # IDEFICS cannot compile due to dynamic control flow when checking inputs

def _init_weights(self, module):
# important: this ported version of Idefics isn't meant for training from scratch - only
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True

def _init_weights(self, module):
Expand Down
10 changes: 10 additions & 0 deletions src/transformers/models/mixtral/modular_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@
MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel,
MistralPreTrainedModel,
MistralRMSNorm,
MistralRotaryEmbedding,
)
from .configuration_mixtral import MixtralConfig

Expand Down Expand Up @@ -313,6 +315,14 @@ def forward(
return outputs


class MixtralRotaryEmbedding(MistralRotaryEmbedding):
pass


class MixtralPreTrainedModel(MistralPreTrainedModel):
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)


class MixtralModel(MistralModel):
def __init__(self, config: MixtralConfig):
super().__init__(config)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/olmoe/modeling_olmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ class OlmoePreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)

def _init_weights(self, module):
std = self.config.initializer_range
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/phimoe/modeling_phimoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ class PhimoePreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)

def _init_weights(self, module):
std = self.config.initializer_range
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
_supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`

def _init_weights(self, module):
std = self.config.initializer_range
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
_supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`

def _init_weights(self, module):
std = self.config.initializer_range
Expand Down
74 changes: 52 additions & 22 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1978,52 +1978,82 @@ def test_generate_with_quant_cache(self):
model.generate(**generation_kwargs, **inputs_dict)

@pytest.mark.generate
@require_torch_accelerator
@slow
def test_generate_compile_model_forward(self):
"""
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
end-to-end compilation and forward pass compilation only.
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache")
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4)

model = model_class(config).to(torch_device)
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time

input_ids = inputs_dict["input_ids"].to(torch_device)
main_input = inputs_dict[model.main_input_name].to(torch_device)
# creates two sets of *different* inputs with the same shape
half_batch_size = input_ids.shape[0] // 2
input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]]
self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape)
half_batch_size = main_input.shape[0] // 2
input_1 = {}
input_2 = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
input_1[key] = value[:half_batch_size, :].to(torch_device)
input_2[key] = value[half_batch_size : half_batch_size * 2, :].to(torch_device)
else:
input_1[key] = value
input_2[key] = value
model_input_sets = [input_1, input_2]
self.assertTrue(
model_input_sets[0][model.main_input_name].shape == model_input_sets[1][model.main_input_name].shape
)

# compilation-specific setup
torch.compiler.reset() # prevent cached compilation from being used in the test
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)

generation_kwargs = {
"do_sample": False,
"max_new_tokens": 10,
"max_new_tokens": 5,
"return_dict_in_generate": True,
"output_scores": True,
"cache_implementation": "static",
}

# get eager + dynamic cache results for future comparison
dynamic_outputs = []
for model_inputs in input_ids_sets:
dynamic_outputs.append(model.generate(model_inputs, **generation_kwargs))

# get compiled results
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
torch.compiler.reset()
for model_inputs in model_input_sets:
gen_out = model.generate(**model_inputs, **generation_kwargs)
dynamic_outputs.append(gen_out)
# sanity checks for the default cache implementation
if not has_defined_cache_implementation:
decoder_cache = (
gen_out.past_key_values.self_attention_cache
if config.is_encoder_decoder
else gen_out.past_key_values
)
self.assertTrue(isinstance(decoder_cache, DynamicCache))
self.assertFalse(decoder_cache.is_compileable)
self.assertFalse(hasattr(model, "_compiled_call")) # our auto compile should NOT have been called

model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
# get compiled results -- relies on the automatic compilation triggered by specific "cache_implementation"
if not has_defined_cache_implementation:
generation_kwargs["cache_implementation"] = "static"

compiled_outputs = []
for model_inputs in input_ids_sets:
compiled_outputs.append(model.generate(model_inputs, generation_config=generation_config))
for model_inputs in model_input_sets:
gen_out = model.generate(**model_inputs, **generation_kwargs)
compiled_outputs.append(gen_out)
# sanity checks
decoder_cache = (
gen_out.past_key_values.self_attention_cache
if config.is_encoder_decoder
else gen_out.past_key_values
)
self.assertFalse(isinstance(decoder_cache, DynamicCache))
self.assertTrue(decoder_cache.is_compileable)
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called

for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
self._check_similar_generate_outputs(dynamic_result, compiled_result)
Expand Down
5 changes: 0 additions & 5 deletions tests/models/chameleon/test_modeling_chameleon.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,6 @@ def test_model_rope_scaling(self, scaling_type):
def test_batching_equivalence(self):
pass

# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
@unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
def test_generate_compile_model_forward(self):
pass


@require_torch
class ChameleonIntegrationTest(unittest.TestCase):
Expand Down
4 changes: 0 additions & 4 deletions tests/models/dbrx/test_modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,6 @@ def test_disk_offload_safetensors(self):
def test_disk_offload_bin(self):
pass

@unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
def test_generate_compile_model_forward(self):
pass


@require_torch
class DbrxModelIntegrationTest(unittest.TestCase):
Expand Down
4 changes: 0 additions & 4 deletions tests/models/idefics/test_modeling_idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,10 +780,6 @@ def test_contrastive_generate_low_memory(self):
def test_custom_4d_attention_mask(self):
pass

@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
def test_generate_compile_model_forward(self):
pass

@unittest.skip(reason="We only test the model that takes in multiple images")
def test_model(self):
pass
Expand Down
Loading