🐛 Describe the bug
Got the following error when applying Liger Kernel to Qwen2-VL with Transformers 4.47.0 and later:
Traceback (most recent call last):
  File "/workspaces/test/t.py", line 51, in <module>
    generated_ids = model.generate(**inputs, max_new_tokens=128)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 2255, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 3254, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: lce_forward() got an unexpected keyword argument 'cache_position'
The error does not occur with Transformers version 4.46.3.
This was mentioned in another post:
cache_position is added to qwen2_vl in transformers v4.47.0
see: huggingface/transformers#34274
We have to update liger's flce_forward() in qwen2_vl.py to match the new implementation.
Originally posted by @Tcc0403 in #515
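Per the traceback, `_sample` calls `self(**model_inputs, return_dict=True)`, and PyTorch forwards those keywords straight to whatever `forward` is installed on the model. A replacement forward that neither names `cache_position` nor accepts `**kwargs` therefore fails with exactly this `TypeError`. The sketch below reproduces that with hypothetical stand-ins (`upstream_forward`, `patched_forward`, `call_like_generate`) and adds a small `inspect.signature` check, `missing_keywords` (also hypothetical), that flags such signature drift. It is not Liger's code; an actual fix would follow the upstream change linked above rather than merely accepting the argument.

```python
import inspect

# Hypothetical stand-ins, NOT the real Liger or Transformers functions.
def upstream_forward(self, input_ids=None, attention_mask=None, rope_deltas=None,
                     cache_position=None):
    """Mimics an upstream forward that gained a cache_position parameter."""
    return len(input_ids)

def patched_forward(self, input_ids=None, attention_mask=None, rope_deltas=None):
    """Mimics a replacement forward written against the older signature."""
    return len(input_ids)

def call_like_generate(forward):
    """Call `forward` the way the traceback shows: every model input as a keyword."""
    model_inputs = {
        "input_ids": [101, 102, 103],
        "attention_mask": [1, 1, 1],
        "cache_position": [0, 1, 2],
    }
    try:
        return forward(None, **model_inputs)  # None stands in for `self`
    except TypeError as exc:
        return f"TypeError: {exc}"

def missing_keywords(reference, replacement):
    """Names `reference` accepts by keyword that `replacement` cannot accept."""
    params = inspect.signature(replacement).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []  # **kwargs swallows any extra keyword
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    accepted = {name for name, p in params.items() if p.kind in keyword_kinds}
    return [
        name
        for name, p in inspect.signature(reference).parameters.items()
        if p.kind in keyword_kinds and name not in accepted
    ]

print(call_like_generate(patched_forward))   # TypeError: ... unexpected keyword argument 'cache_position'
print(call_like_generate(upstream_forward))  # 3
print(missing_keywords(upstream_forward, patched_forward))  # ['cache_position']
```

With the real classes, the original `Qwen2VLForConditionalGeneration.forward` would have to be captured before `apply_liger_kernel_to_qwen2_vl()` runs, assuming that call replaces the method in place.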
Reproduce
Running this simple script raises TypeError: lce_forward() got an unexpected keyword argument 'cache_position'.
The script runs fine when apply_liger_kernel_to_qwen2_vl() is not called:
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from transformers import BitsAndBytesConfig
from qwen_vl_utils import process_vision_info
# Patch Qwen2-VL with Liger kernels (this is what triggers the error on Transformers >= 4.47.0)
apply_liger_kernel_to_qwen2_vl()
model_id = "Qwen/Qwen2-VL-2B-Instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")
# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
Versions
Operating System: Linux-6.11.0-13-generic-x86_64-with-glibc2.35
Python version: 3.11.10
Liger Kernel version: 0.5.2
PyTorch version: 2.5.1+cu124
CUDA version: 12.4
HIP(ROCm) version: Not available
Triton version: 3.1.0
Transformers version: 4.48.0
XPU version: XPU Not Available
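Until lce_forward() is updated, one stopgap consistent with the report (4.46.3 works, 4.47.0 and later fail) is to skip the Liger patch, or pin Transformers, on affected versions. A minimal stdlib-only sketch; `parse_version` and `has_qwen2_vl_cache_position` are hypothetical helpers, and the parser assumes plain `X.Y.Z` strings, ignoring any suffix such as `.dev0` or `rc1`.

```python
import re

# Per the report, cache_position reached Qwen2-VL's forward in Transformers 4.47.0.
FIRST_AFFECTED = (4, 47, 0)

def parse_version(version: str) -> tuple[int, int, int]:
    """Parse the leading 'X.Y.Z' of a version string into a tuple of ints.

    Anything after the numeric part (e.g. '.dev0', 'rc1') is ignored and
    missing components count as 0, so '4.47' parses as (4, 47, 0).
    """
    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = (int(part or 0) for part in match.groups())
    return major, minor, patch

def has_qwen2_vl_cache_position(version: str) -> bool:
    """True if this Transformers version is in the range the report says is affected."""
    return parse_version(version) >= FIRST_AFFECTED

for v in ("4.46.3", "4.47.0", "4.48.0"):
    print(v, has_qwen2_vl_cache_position(v))
# 4.46.3 False
# 4.47.0 True
# 4.48.0 True
```

For example, `if not has_qwen2_vl_cache_position(transformers.__version__): apply_liger_kernel_to_qwen2_vl()`. This only avoids the crash by running without Liger's Qwen2-VL patch; it does not fix lce_forward().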