Handle cache_position for transformers 4.47.0 and later (#528) #529

Merged

lancerts merged 2 commits into linkedin:main from BenasdTW:fix_qwen2vl_528 on Jan 21, 2025

Conversation

@BenasdTW
Contributor

@BenasdTW BenasdTW commented Jan 18, 2025

Summary

Fix issue #528 by adopting the new way of handling RoPE introduced in transformers 4.48.0:

        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids, image_grid_thw, video_grid_thw, attention_mask
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
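
For context, the pre-fill branch computes and caches rope_deltas once, and the decode branch only shifts positions by those cached offsets. A small standalone sketch of the decode-step arithmetic, using made-up shapes and values rather than code from this PR:

import torch

# Hypothetical example: batch of 2, one new token per decode step.
batch_size, seq_length = 2, 1
inputs_embeds = torch.zeros(batch_size, seq_length, 8)
cache_position = torch.tensor([5])       # index of the first new token in the KV cache
rope_deltas = torch.tensor([[3], [7]])   # per-sample deltas saved during pre-fill

delta = cache_position[0] + rope_deltas  # shape (batch_size, 1)
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # (3, batch, seq) for mRoPE

print(position_ids.shape)  # torch.Size([3, 2, 1])
print(position_ids[0])     # tensor([[ 8], [12]])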

Testing Done

Tested on the following versions (all worked with this PR; without it, 4.48.0 did not work):
pip install transformers==4.46.2
pip install transformers==4.46.3
pip install transformers==4.48.0

Before applying this PR, running Qwen2-VL with liger-kernel and transformers>=4.47.0 would result in this error (issue #528):

Traceback (most recent call last):
  File "/workspaces/test/t.py", line 51, in <module>
    generated_ids = model.generate(**inputs, max_new_tokens=128)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 2255, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 3254, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: lce_forward() got an unexpected keyword argument 'cache_position'
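
The error occurs because transformers >= 4.47.0 passes cache_position into the model's forward during generation, and the monkey-patched forward did not accept that keyword. A minimal sketch of the shape of the change (a hypothetical stub, not the actual PR diff):

from typing import Optional

import torch


def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    # ... the other existing Qwen2-VL forward arguments ...
    cache_position: Optional[torch.LongTensor] = None,  # newly accepted keyword
):
    # The body would reuse the rope_deltas logic quoted in the summary above
    # to derive position_ids when they are not provided by the caller.
    ...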

Inference test script:

import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from transformers import BitsAndBytesConfig
from qwen_vl_utils import process_vision_info

apply_liger_kernel_to_qwen2_vl()

model_id = "Qwen/Qwen2-VL-2B-Instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=False, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)


messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

Training test script:

import torch
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import BitsAndBytesConfig
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_loraplus_optimizer
import bitsandbytes as bnb
from trl import SFTTrainer, SFTConfig
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
from configs_and_helpers import clear_memory, vl_format_data, generate_text_from_sample

apply_liger_kernel_to_qwen2_vl()

model_id = "Qwen/Qwen2-VL-2B-Instruct"
dataset_id = "HuggingFaceM4/ChartQA"

system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images.
Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
Focus on delivering accurate, succinct answers based on the visual information. Avoid additional explanation unless absolutely necessary."""

train_dataset, eval_dataset, test_dataset = load_dataset(dataset_id, split=["train[:20%]", "val[:2%]", "test[:1%]"])

train_dataset = [vl_format_data(sample, system_message) for sample in train_dataset]
eval_dataset = [vl_format_data(sample, system_message) for sample in eval_dataset]
test_dataset = [vl_format_data(sample, system_message) for sample in test_dataset]


model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

processor = Qwen2VLProcessor.from_pretrained(model_id)
print(f"{train_dataset[0]=}")
print(f"{train_dataset[0][1:2]=}")


output = generate_text_from_sample(model, processor, train_dataset[0])
print(f"{output=}")

clear_memory(globals())

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=False, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    use_cache=False
)

processor = Qwen2VLProcessor.from_pretrained(model_id)
processor.padding_side = "right"  # Ensure padding is added to the right side
processor.tokenizer.padding_side = "right"  # Ensure padding is added to the right side

# Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# Apply PEFT model adaptation
print(model)
model = get_peft_model(model, peft_config)
print(model)

# Print trainable parameters
model.print_trainable_parameters()

optimizer = create_loraplus_optimizer(
    model=model,
    optimizer_cls=bnb.optim.PagedAdamW8bit,
    # optimizer_cls=torch.optim.AdamW,
    lr=2e-4,
    eps=1e-6,
    # eps=1e-8,
    betas=(0.9, 0.999),
    weight_decay=0.0,
    loraplus_lr_ratio=16,
)
scheduler = None


# Configure training arguments
training_args = SFTConfig(
    output_dir="qwen2-2b-instruct-trl-sft-ChartQA",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    eval_accumulation_steps=4,
    # Logging and evaluation
    logging_steps=1,
    eval_steps=10,
    torch_empty_cache_steps=1,
    eval_strategy="steps",
    save_strategy="epoch",
    bf16=True,
    # Gradient checkpointing settings
    gradient_checkpointing_kwargs={"use_reentrant": False},
    gradient_checkpointing=True,
    # Dataset configuration
    dataset_kwargs={"skip_prepare_dataset": True},
    # max_seq_length=1024  # Maximum sequence length for input
    remove_unused_columns=False,  # Keep unused columns in dataset
)

# Create a data collator to encode text and image pairs
def collator_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]
    image_inputs = [process_vision_info(example)[0] for example in examples]  # Process the images to extract inputs

    # Tokenize the texts and process the images
    batch = processor(
        text=texts, images=image_inputs, return_tensors="pt", padding=True
    )

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    # Mask padding tokens in labels
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Ignore the image token index in the loss computation (model specific)
    if isinstance(processor, Qwen2VLProcessor):
        image_tokens = [151652, 151653, 151655]  # Specific image token IDs for Qwen2VLProcessor
    else:
        image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]  # Convert image token to ID

    # Mask image token IDs in the labels
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100

    batch["labels"] = labels

    return batch

print(f"Processed data:\n{collator_fn(train_dataset[:2])}")

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collator_fn,
    peft_config=peft_config,
    processing_class=processor.tokenizer,
    optimizers=(optimizer, scheduler)
)
trainer.train()
trainer.save_model(training_args.output_dir)

# model.save_model(output_name)

My hardware is fairly weak; make test runs out of memory (OOM) on it. Might need further testing.

  • Hardware Type: Nvidia RTX 4070 Laptop
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@Tcc0403
Collaborator

Tcc0403 commented Jan 18, 2025

LGTM! All tests passed on my machine, including transformers 4.48.0 and 4.47.1.

Tested on (all worked with this PR, 4.48.0 didn't work with this PR):

btw, could you clarify what you mean by that?

@BenasdTW
Contributor Author

LGTM! All tests passed on my machine, including transformers 4.48.0 and 4.47.1.

Tested on (all worked with this PR, 4.48.0 didn't work with this PR):

btw, could you clarify what you mean by that?

@Tcc0403 Thank you for taking the time to look into this! I have updated my previous comment.

Before applying this PR, running Qwen2-VL with liger-kernel and transformers>=4.47.0 would result in this error (issue #528):

Traceback (most recent call last):
  File "/workspaces/test/t.py", line 51, in <module>
    generated_ids = model.generate(**inputs, max_new_tokens=128)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 2255, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 3254, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: lce_forward() got an unexpected keyword argument 'cache_position'

@Tcc0403
Collaborator

Tcc0403 commented Jan 18, 2025

Oh I see what you mean. You meant 4.48.0 couldn't pass the test without this PR, and applying this PR fixes the issue.

Before merging it, let me ask the liger team whether we should modify liger's backward-compatibility CI. If so, we can update the CI to test this PR.
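
If the CI does get extended, a hedged sketch of what a version-gated test could look like (hypothetical test name and constant, not liger's actual CI code):

import pytest
import transformers
from packaging import version

SUPPORTS_CACHE_POSITION = version.parse(transformers.__version__) >= version.parse("4.47.0")

@pytest.mark.skipif(
    not SUPPORTS_CACHE_POSITION,
    reason="cache_position is only forwarded to forward() by transformers >= 4.47.0",
)
def test_qwen2_vl_forward_accepts_cache_position():
    # Placeholder body: the real test would apply apply_liger_kernel_to_qwen2_vl()
    # and run a short generate() call, asserting that no TypeError is raised.
    ...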

@lancerts lancerts merged commit 2ea3cfb into linkedin:main Jan 21, 2025
@BenasdTW BenasdTW deleted the fix_qwen2vl_528 branch January 23, 2025 08:13


Development

Successfully merging this pull request may close these issues.

Qwen2-VL breaks with transformers version 4.47.0+: TypeError: lce_forward() got an unexpected keyword argument 'cache_position'

3 participants