diff --git a/kvpress/presses/kvzip_press.py b/kvpress/presses/kvzip_press.py index 0f13e68a..f0ba82c9 100644 --- a/kvpress/presses/kvzip_press.py +++ b/kvpress/presses/kvzip_press.py @@ -116,10 +116,7 @@ def __call__(self, model: PreTrainedModel) -> Generator: def wrapped_forward(model_self, *args, **kwargs): self._context_ids = kwargs["input_ids"] - assert ( - "past_key_value" in kwargs or "past_key_values" in kwargs - ), f"KVzipPress requires 'past_key_value' or 'past_key_values' during prefilling. Got {kwargs.keys()}" - self._cache = kwargs.get("past_key_values", None) or kwargs.get("past_key_value", None) + self._cache = kwargs["past_key_values"] return original_forward(*args, **kwargs) model.model.forward = MethodType(wrapped_forward, model.model)