Fix: Jamba batched generation #32914
Conversation
Passed locally without the higher rtol/atol. Will see if the CI agrees.
Seems like it does :D
Keeping it open for visibility: Left padding works fine now, it was an issue of how padding has been handled in general (for mamba-related models).
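For illustration, the padding problem for mamba-related models boils down to pad tokens leaking into the recurrent state. A minimal, hypothetical sketch in plain Python (not the actual transformers implementation) of zeroing hidden states at padded positions so padding cannot influence the state:

```python
# Hypothetical sketch: zero out hidden states wherever the attention mask
# is 0 (left padding), before they feed into a recurrent/conv state update.
def mask_padded_states(hidden_states, attention_mask):
    """hidden_states: per-token feature vectors; attention_mask: 0/1 per token."""
    return [
        [x * m for x in token]
        for token, m in zip(hidden_states, attention_mask)
    ]

states = [[0.5, -1.0], [2.0, 3.0], [1.0, 1.0]]
mask = [0, 1, 1]  # first position is left padding
print(mask_padded_states(states, mask))
```

In the real model this is done on tensors (e.g. multiplying by the mask before the conv/scan step), but the idea is the same: padded positions must contribute nothing to the state.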
CI failure seems unrelated to the PR, some import issues from another model.
ArthurZucker
left a comment
What I would find weird is if this does not improve / change the results. Especially for batched generation! The model is tiny random, would be nice if we can run this with the big one 👀
Yea, I think it should definitely improve batched generation. Too GPU poor to run the Jamba models though, iirc they require at least an 80GB VRAM GPU 😢 Maybe we could notify the guys behind Jamba? I doubt they are aware of this issue.
Force-pushed from cfc73d9 to e2c2341: … batch gen (with todo on logits comp)
with torch.no_grad():
    logits = self.model(input_ids=inputs["input_ids"]).logits

# TODO fix logits
For more visibility so that I don't forget about it.
# No need for zeroing states when
# 1. Cached forward
# 2. Attending to all inputs
if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
I suspect this line will fail at compilation time (data-dependent conditional branch). Can you confirm, i.e. try running a compiled forward pass?
If it fails, we can add a compile guard, i.e. start the if with not is_torchdynamo_compiling()
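A hedged sketch of what that guard would look like, with the branch logic pulled into a plain function for clarity (`is_torchdynamo_compiling` is assumed to come from `transformers.utils` and to return False in eager mode; the flag is passed in here so the sketch stays self-contained):

```python
# Sketch of the suggested compile guard: skip the data-dependent branch
# while dynamo is tracing, so the conditional never reaches the compiler.
def should_zero_states(compiling, cache_position_0, attention_mask_all_ones):
    # Mirrors the condition in the PR: no zeroing during a cached forward
    # (cache_position[0] > 0), when attending to all inputs, or under compile.
    if compiling or cache_position_0 > 0 or attention_mask_all_ones:
        return False
    return True

print(should_zero_states(False, 0, False))  # eager prefill with padding -> True
print(should_zero_states(True, 0, False))   # compiling: branch skipped -> False
```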
Tested via the following mini script:

import torch
from transformers import JambaForCausalLM, AutoTokenizer

model_id = "ai21labs/Jamba-tiny-random"
model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to("cuda")
model = torch.compile(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# tested on both, batched or non-batched input
#input = tokenizer(["Hey how are you doing on this lovely evening?", "What is the purpose of life?"], padding=True, return_tensors="pt").to("cuda")
input = tokenizer(["What is the purpose of life?"], padding=True, return_tensors="pt").to("cuda")

# tested on both, forward call or generate
out = model(**input)
#out = model.generate(**input, do_sample=False, max_new_tokens=10)

Haven't encountered any compilation errors locally, so it seems to be fine. Is this what you had in mind to test compilation?
Yes, that's it!
Perfect, thank you for confirming :)
ArthurZucker
left a comment
LGTM thanks again @vasqu for your great contributions!
* init fix
* fix mask during cached forward, move mask related stuff to own function
* adjust tests as left padding does not change logits as much anymore + batch gen (with todo on logits comp)
* revert overwriting new integration tests
* move some comments to docstring
What does this PR do?
Basically a continuation of #32677, which implements the fixes for Jamba this time. The batched generation tests might need to be changed, especially the expected logits, but I'm not sure how to proceed there as the logits are hardware-dependent.
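One common way around hardware-dependent logits is to compare with a tolerance instead of exact equality. A minimal sketch of the rtol/atol idea (mirroring the `atol + rtol * |expected|` criterion that `torch.testing.assert_close` uses; the helper name and values here are made up for illustration):

```python
# Tolerance-based comparison: two logit lists count as equal if each pair
# differs by at most atol + rtol * |expected|.
def logits_close(observed, expected, rtol=1e-3, atol=1e-3):
    return all(
        abs(o - e) <= atol + rtol * abs(e)
        for o, e in zip(observed, expected)
    )

expected = [0.1234, -2.5001, 7.8125]
observed = [0.1235, -2.5003, 7.8121]  # e.g. same model on a different GPU
print(logits_close(observed, expected))  # -> True
```

In the actual tests this would be a `torch.testing.assert_close(logits, expected_logits, rtol=..., atol=...)` call; the open question is what tolerances are loose enough across GPUs but still tight enough to catch regressions.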
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@molbap @ArthurZucker