Llama: fix batched generation#29109
Conversation
These changed test results were checked against 4b236aed7618d90546cd2e8797dab5b4a24c5dce (the commit before the static caches were introduced).
These tests do batched generation, hence the need to change.
👉 the fact that this PR matches the commit before the static caches in this test means that we can now do left-padded batched generation with the same results!
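For readers following along, left padding keeps the real tokens right-aligned so new tokens are appended directly after the last real token. A minimal tensor-level sketch (the pad token id and token values are made up for illustration):

```python
import torch

# Toy ragged batch; pad_token_id=0 is an assumption for illustration.
pad_token_id = 0
seqs = [[11, 12, 13, 14], [21, 22]]
max_len = max(len(s) for s in seqs)

# Left padding: prepend pad tokens so real tokens are right-aligned.
input_ids = torch.tensor([[pad_token_id] * (max_len - len(s)) + s for s in seqs])
attention_mask = (input_ids != pad_token_id).long()

# Position ids start at 0 on the first real token; pad positions clamp to 0.
position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)

print(input_ids.tolist())     # [[11, 12, 13, 14], [0, 0, 21, 22]]
print(position_ids.tolist())  # [[0, 1, 2, 3], [0, 0, 0, 1]]
```

This mirrors what a tokenizer with `padding_side = "left"` produces, and the per-row position ids are what the batched rotary-embedding computation must consume.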
I'll have to run the benchmark on the A100 to make sure everything is alright, but otherwise this should be good.
ArthurZucker
left a comment
Great work, nice catch! I'll approve but let me run the benchmark on my side!
```python
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
```
let's unsqueeze in the rotary embedding, no? Or does that change the shape we previously had?
Same shapes/no shape problems, but unsqueezing here is preferable by some users (see #27117)
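To make the shape discussion concrete, here is a sketch of the rotary application where `cos`/`sin` arrive as `[bsz, seq_len, head_dim]` and `q`/`k` as `[bsz, num_heads, seq_len, head_dim]`, so `unsqueeze_dim=1` inserts the missing head axis for broadcasting (the tensor sizes are made up):

```python
import torch

def rotate_half(x):
    # Rotates half the hidden dims of the input.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin: [bsz, seq_len, head_dim] -> [bsz, 1, seq_len, head_dim],
    # broadcasting over the num_heads axis of q and k.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

bsz, num_heads, seq_len, head_dim = 2, 8, 5, 64
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
cos = torch.randn(bsz, seq_len, head_dim)
sin = torch.randn(bsz, seq_len, head_dim)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
print(q_embed.shape)  # torch.Size([2, 8, 5, 64])
```

The letting callers pick `unsqueeze_dim` is what #27117 asked for: some users carry `q`/`k` with the head axis in a different position.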
```python
)
freqs = freqs.transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
```
BTW, for BC we could / should still cache the rope, no?
With a property `_sin_cached` that emits `logger.warning_once("will be removed in v4.39")`? WDYT?
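For reference, the per-batch frequency computation that the diff above ends with can be sketched end to end like this (`head_dim` and the position values are made up; the explicit batch dimension is what keeps batch members separate instead of summing them together):

```python
import torch

# Assumed toy sizes: head_dim=4, a batch of 2 sequences of length 3.
head_dim = 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))

# [bsz, seq_len]; the second row mimics a different (e.g. left-padded) offset.
position_ids = torch.tensor([[0, 1, 2], [5, 6, 7]])

# Per-batch outer product: [bsz, head_dim/2, 1] @ [bsz, 1, seq_len]
# -> [bsz, head_dim/2, seq_len], one set of angles per batch member.
inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
freqs = inv_freq_expanded @ position_ids[:, None, :].float()
freqs = freqs.transpose(1, 2)            # [bsz, seq_len, head_dim/2]
emb = torch.cat((freqs, freqs), dim=-1)  # [bsz, seq_len, head_dim]
cos, sin = emb.cos(), emb.sin()
print(cos.shape)  # torch.Size([2, 3, 4])
```

With batch size 1 the extra dimension is a no-op, which is why the bug only showed up in batched generation.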
```diff
-causal_mask = torch.triu(mask, diagonal=1).to(dtype)
+causal_mask = torch.triu(mask, diagonal=1)
+causal_mask = causal_mask.to(dtype=dtype, device=device)
```
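A small sketch of the pattern the diff above moves to: build the mask once, then cast dtype and device in a single `.to` call at the end (the sizes and dtype here are chosen for illustration):

```python
import torch

seq_len, dtype, device = 4, torch.float16, "cpu"  # assumed toy settings

# Fill with the most negative representable value, then keep only the
# strictly upper triangle so each position attends to itself and the past.
mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min)
causal_mask = torch.triu(mask, diagonal=1)
causal_mask = causal_mask.to(dtype=dtype, device=device)

print(causal_mask.dtype)               # torch.float16
print((causal_mask[0] == 0).tolist())  # [True, False, False, False]
```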
```python
EXPECTED_GENERATION = [
    "The best color is\n\n\n\n\n\n\n\n\n\n",
    "We should not undermind the issues at hand, but address them head on.\nI think",
    "The best color isЋ the one that complements the skin tone of",
```
```diff
-isЋ t
+is t
```
seems strange 😅 but alright
hehe this weird one is a copy/paste
(it has right-padding, so we should expect weird things at generation time)
Alright, no significant slowdowns so 🟢, but I can't do naive dynamic generation with the same script as before:

```
  File "/home/arthur/transformers/../static-kv-cache/clean_bench.py", line 147, in <module>
    outputs = model(input_ids, past_key_values=past_key_values,position_ids=position_ids,cache_position=cache_position, return_dict=False, use_cache = True)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 1155, in forward
    outputs = self.model(
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 995, in forward
    layer_outputs = decoder_layer(
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 721, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 628, in forward
    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 107, in forward
    position_ids[:, None, :].float()
IndexError: too many indices for tensor of dimension 1
```
@ArthurZucker regarding the benchmark error: position ids should be a 2D tensor, just like the input ids :D I also had to adapt it on my end
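Concretely, the fix to the benchmark script would look something like this (a hypothetical decode step; `pos` and the batch size are made up):

```python
import torch

bsz, pos = 4, 7  # assumed batch size and current decode position

# Wrong: a 1D tensor trips position_ids[:, None, :] in the rotary forward.
bad_position_ids = torch.tensor([pos])

# Right: [bsz, seq_len], just like input_ids (seq_len == 1 when decoding).
position_ids = torch.full((bsz, 1), pos, dtype=torch.long)
print(position_ids.shape)  # torch.Size([4, 1])
```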
Alright, as long as passing a 1D tensor was erroring out before as well!
@gante thanks a lot for this
```python
self._cos_cached = cos
self._sin_cached = sin
```
We should not always overwrite them. We need them accessible, but not overwritten in the forward pass.
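One way to read that suggestion (entirely a sketch, with assumed names): keep the cached tensors behind read-only properties that warn on access, instead of attributes rewritten by every forward:

```python
import logging

logger = logging.getLogger(__name__)

class LlamaRotaryEmbeddingSketch:
    """Hypothetical sketch: cos/sin caches kept accessible for BC,
    but not overwritten by the forward pass."""

    def __init__(self):
        self.__cos_cached = None  # internal storage, set once when computed
        self.__sin_cached = None

    @property
    def _cos_cached(self):
        logger.warning("_cos_cached is deprecated and will be removed in v4.39")
        return self.__cos_cached

    @property
    def _sin_cached(self):
        logger.warning("_sin_cached is deprecated and will be removed in v4.39")
        return self.__sin_cached

rope = LlamaRotaryEmbeddingSketch()
print(rope._sin_cached)  # None, with a deprecation warning logged
```

(The real suggestion uses `logger.warning_once`, a transformers logging utility; plain `logging` stands in here.)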
### What does this PR do?

Fixes batched inference on Llama after the static cache changes were added. For instance, `RUN_SLOW=1 py.test tests/test_cache_utils.py::CacheIntegrationTest::test_dynamic_cache_beam_search` now passes.

### What was wrong?

`position_ids` has shape `[bsz, seq_len]`. The line computing `freqs` was correct for batch size 1, but incorrect for larger batch sizes: it was summing the values for the different batch members. Therefore, we need to create another dimension to prevent this sum from happening, which is what this PR does.

### Throughput impact of changes

None 🙌 [Measured on my end, RTX 3090 + `TinyLlama/TinyLlama-1.1B-Chat-v1.0`]

Before this PR

After this PR
