Fix: Jamba batched generation#32914

Merged
ArthurZucker merged 5 commits into huggingface:main from vasqu:jamba-batched-gen-fix
Aug 28, 2024

Conversation

@vasqu
Contributor

@vasqu vasqu commented Aug 21, 2024

What does this PR do?

Basically a continuation of #32677, this time implementing the fixes for Jamba. The batched generation tests might need to be changed, especially the expected logits, but I'm not sure how to proceed there since the logits are hardware dependent.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@molbap @ArthurZucker

Comment thread src/transformers/models/jamba/modeling_jamba.py Outdated
Comment on lines 461 to 505
Contributor Author


Passed locally without the higher rtol/atol. Will see if the CI agrees.

Contributor Author


Seems like it does :D

Contributor Author


Keeping this open for visibility: left padding works fine now; the issue was how padding has been handled in general for Mamba-related models.
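To illustrate why padding handling matters here, a toy sketch (not the actual Jamba code; the recurrence and the 0.5 decay factor are made up): Mamba-style layers carry a recurrent state across time steps, so left-padding embeddings leak into that state and change the outputs of the real tokens unless they are zeroed out first.

```python
def run_recurrence(hidden, mask):
    # Toy stand-in for a state-space recurrence: each step mixes the running
    # state with the current (masked) hidden value.
    state = 0.0
    outputs = []
    for h, m in zip(hidden, mask):
        # zero padded positions before they enter the recurrent state
        state = 0.5 * state + h * m
        outputs.append(state)
    return outputs

# first two entries play the role of left-padding embeddings (nonzero, as in practice)
hidden = [9.0, 9.0, 1.0, 2.0]
with_fix = run_recurrence(hidden, [0, 0, 1, 1])
without_fix = run_recurrence(hidden, [1, 1, 1, 1])
print(with_fix[-1], without_fix[-1])  # 2.5 5.875 -- unmasked padding corrupts the last real token
```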

Comment thread tests/models/jamba/test_modeling_jamba.py Outdated
@vasqu
Contributor Author

vasqu commented Aug 21, 2024

The CI failure seems unrelated to this PR; it stems from import issues in another model.

Comment thread tests/models/jamba/test_modeling_jamba.py Outdated
Collaborator

@ArthurZucker ArthurZucker left a comment


What I would find weird is if this does not improve/change the results, especially for batched generation! The test model is tiny-random; it would be nice if we could run this with the big one 👀

Comment thread src/transformers/models/jamba/modeling_jamba.py Outdated
Comment thread tests/models/jamba/test_modeling_jamba.py Outdated
@vasqu
Contributor Author

vasqu commented Aug 23, 2024

Yeah, I think it should definitely improve batched generation. Since test_left_padding_compatibility no longer needs higher atols, padding is not as big of a problem as before (I think).

I'm too GPU-poor to run the full Jamba models; IIRC they require at least an 80GB-VRAM GPU 😢 Maybe we could notify the team behind Jamba? I doubt they are aware of this issue.
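The general shape of the fix is an elementwise multiply of the hidden states by the attention mask before they reach the conv/SSM state. A plain-Python sketch with flattened shapes (the helper name is assumed, not the PR's actual function):

```python
def apply_mask_to_hidden_states(hidden_states, attention_mask):
    # hidden_states: [batch][seq][dim], attention_mask: [batch][seq]
    # a torch equivalent would broadcast: hidden_states * attention_mask.unsqueeze(-1)
    return [
        [[h * m for h in position] for position, m in zip(sequence, mask)]
        for sequence, mask in zip(hidden_states, attention_mask)
    ]

batch = [[[3.0, 3.0], [1.0, 2.0]]]  # one sequence, first position is padding
print(apply_mask_to_hidden_states(batch, [[0, 1]]))  # [[[0.0, 0.0], [1.0, 2.0]]]
```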

@vasqu vasqu force-pushed the jamba-batched-gen-fix branch from cfc73d9 to e2c2341 on August 23, 2024 10:08
@vasqu
Contributor Author

vasqu commented Aug 23, 2024

#32250 seems to have changed the integration tests cc @gante

Guess we have to redo them again 👀

with torch.no_grad():
    logits = self.model(input_ids=inputs["input_ids"]).logits

# TODO fix logits
Contributor Author


For more visibility so that I don't forget about it.

Contributor

@gante gante left a comment


LGTM, thank you for fixing! 🙌

Added a nit to confirm. Pre-approving assuming the logits tests will be addressed (sorry about that :) )

# No need for zeroing states when
# 1. Cached forward
# 2. Attending to all inputs
if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
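A standalone sketch of the condition above (plain Python lists standing in for the tensors): zeroing is only needed on the prefill pass of a batch that actually contains padding.

```python
def skip_state_zeroing(cache_position, attention_mask):
    # mirrors: cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1))
    cached_forward = cache_position[0] > 0
    attends_to_all = attention_mask is not None and all(m == 1 for m in attention_mask)
    return cached_forward or attends_to_all

print(skip_state_zeroing([0], [0, 1, 1]))  # False: prefill with padding, states must be zeroed
print(skip_state_zeroing([7], [0, 1, 1]))  # True: cached decode step, state already clean
print(skip_state_zeroing([0], [1, 1, 1]))  # True: no padding anywhere in the batch
```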
Contributor


I suspect this line will fail at compilation time (data-dependent conditional branch). Can you confirm, i.e. try running a compiled forward pass?

If it fails, we can add a compile guard, i.e. start the if with not is_torchdynamo_compiling()

Contributor Author


Tested via the following mini script:

import torch
from transformers import JambaForCausalLM, AutoTokenizer

model_id = "ai21labs/Jamba-tiny-random"
model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to("cuda")
model = torch.compile(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# tested with both batched and non-batched input
#input = tokenizer(["Hey how are you doing on this lovely evening?", "What is the purpose of life?"], padding=True, return_tensors="pt").to("cuda")
input = tokenizer(["What is the purpose of life?"], padding=True, return_tensors="pt").to("cuda")

# tested with both a plain forward call and generate
out = model(**input)
#out = model.generate(**input, do_sample=False, max_new_tokens=10)

I haven't encountered any compilation errors locally, so it seems to be fine. Is this what you had in mind for testing compilation?

Contributor


Yes, that's it!

Perfect, thank you for confirming :)

Collaborator

@ArthurZucker ArthurZucker left a comment


LGTM, thanks again @vasqu for your great contributions!

@ArthurZucker ArthurZucker merged commit 3bfd3e4 into huggingface:main Aug 28, 2024
@vasqu vasqu deleted the jamba-batched-gen-fix branch August 28, 2024 11:15
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* init fix

* fix mask during cached forward, move mask related stuff to own function

* adjust tests as left padding does not change logits as much anymore + batch gen (with todo on logits comp)

* revert overwriting new integration tests

* move some comments to docstring
