Fix Gemma2 synced multi-GPU generation #35232

Merged
ArthurZucker merged 4 commits into huggingface:main from ManukyanD:fix_gemma2_multi_gpu_generation
Feb 5, 2025

Conversation

@ManukyanD
Contributor

What does this PR do?

Generation with Gemma2ForCausalLM in synced multi-GPU settings crashes because the cache_position goes out of bounds. This PR addresses the issue.
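For context, here is a minimal sketch of the slicing pattern at issue, modeled on the generic fix from #34095 that this PR mirrors in Gemma2's overridden prepare_inputs_for_generation (the standalone slice_inputs helper is illustrative, not the actual diff):

import torch


def slice_inputs(input_ids, cache_position, past_key_values=None, inputs_embeds=None):
    # Sketch of cache_position-aware input slicing (modeled on huggingface#34095).
    if past_key_values is not None:
        if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:
            # Synced-GPU case: a rank that finished early keeps stepping so all
            # ranks stay in lockstep, and cache_position ends up pointing past
            # the end of input_ids. Slice from the end instead of indexing out
            # of bounds.
            input_ids = input_ids[:, -cache_position.shape[0] :]
        elif input_ids.shape[1] != cache_position.shape[0]:
            # Default decoding step: keep only the not-yet-processed tokens.
            # Pre-fix, the synced-GPU case fell through to this indexing and
            # crashed.
            input_ids = input_ids[:, cache_position]
    return input_ids


# A finished rank on step 5 of a 4-token prompt: pre-fix, this indexed out of bounds.
ids = torch.zeros((1, 4), dtype=torch.long)
print(slice_inputs(ids, torch.tensor([4]), past_key_values=object()).shape)  # torch.Size([1, 1])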

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker
@gante

@ManukyanD force-pushed the fix_gemma2_multi_gpu_generation branch from fcbc37b to 57c52db on December 12, 2024 at 13:58
@ManukyanD
Contributor Author

@gante Could you please review this PR?

@ArthurZucker
Collaborator

Thanks, could you additionally provide a reproducer? 🤗

@ManukyanD
Contributor Author

ManukyanD commented Dec 23, 2024

@ArthurZucker Sure, here is an example script:

from transformers import (
    GenerationConfig,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    AutoTokenizer,
    Gemma2ForCausalLM,
)

from datasets import load_dataset


model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Gemma2ForCausalLM.from_pretrained(model_id)

dataset = load_dataset("google/boolq")

prompt = """
Answer the question according to the provided passage.

Passage:
{passage}

Question:
{question}
"""


def collator(examples):
    inputs = [
        [
            {
                "role": "user",
                "content": prompt.format(
                    passage=example["passage"], question=example["question"]
                ),
            }
        ]
        for example in examples
    ]
    inputs = tokenizer.apply_chat_template(
        inputs, add_generation_prompt=True, tokenize=False
    )
    return tokenizer(
        inputs, padding=True, add_special_tokens=False, return_tensors="pt"
    )


def compute_metrics(pred):
    # compute some metrics
    return {"score": 0.5}


training_args = Seq2SeqTrainingArguments(
    output_dir="./output",
    generation_config=GenerationConfig(
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        use_cache=True,  # cache must be on: the crash is in cache_position handling
    ),
    remove_unused_columns=False,
    predict_with_generate=True,
    per_device_eval_batch_size=8,
    deepspeed="./deepspeed_config.json",
)
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=collator,
    compute_metrics=compute_metrics,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)
trainer.evaluate()

I am running this with DeepSpeed. The following is the DeepSpeed config:

{
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto"
}
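
(Note: reproducing this requires a distributed launch on at least two GPUs, e.g. deepspeed --num_gpus 2 repro.py, where repro.py is an assumed name for the script above saved next to deepspeed_config.json. Under DeepSpeed ZeRO-3 with more than one process, generate enables synced_gpus automatically, which is what drives cache_position out of bounds.)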

@zucchini-nlp
Member

@ArthurZucker let's merge this and the related PR (#35893); it seems like a real issue and has been reported by another user recently.

@ArthurZucker
Collaborator

Sounds good. I was waiting a bit to see more failure reports, but let's go.

@ArthurZucker ArthurZucker merged commit 4831a94 into huggingface:main Feb 5, 2025
@ManukyanD ManukyanD deleted the fix_gemma2_multi_gpu_generation branch February 5, 2025 09:26
@gante
Contributor

gante commented Feb 5, 2025

@ManukyanD thank you for the fix 💛

cc @SunMarc: this PR (and #35893) copies the fix from #34095 into functions that are overwritten in specific models. There, you mentioned we had no tests for it [multi-GPU + generate] -- any chance you've had a look at it, or could you share pointers for adding an appropriate lightweight test? 🙏

elvircrn pushed a commit to elvircrn/transformers that referenced this pull request on Feb 13, 2025:

* Fix Gemma2 synced multi-GPU generation
* Fix import ordering in modular_gemma2.py

sbucaille pushed a commit to sbucaille/transformers that referenced this pull request on Feb 16, 2025:

* Fix Gemma2 synced multi-GPU generation
* Fix import ordering in modular_gemma2.py