fix static cache data type mismatch #34799
Conversation
BTW, I suppose transformers is missing some static cache tests. Do you have any guidance on where I can add this kind of test? Thanks!
All the tests related to the cache are in the test_utils.py file.
SunMarc
left a comment
Thanks for the bug fix! Left a comment.
```python
key = key.permute(0, 2, 1, 3).to(value.dtype)
query = query.permute(0, 2, 1, 3).to(value.dtype)
```
Could you explain why this is needed for this particular model and why it doesn't happen for llama, for example? Many models have approximately the same modeling code.
For llama, sin and cos come from position_embeddings (a bf16 tensor), which comes from here. You can see that llama's rotary embedding converts the data type. But for gptj, the position embeddings come from here; it sets the data type to float32, so a data type mismatch happens when the input data type is bf16 or fp16.
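The mechanism can be illustrated with a minimal sketch (numpy stands in for torch here, and the shapes and table construction are illustrative, not GPT-J's actual code): a sinusoidal sin/cos table built in float32 promotes a half-precision query to float32, so it no longer matches value's dtype unless it is cast back explicitly.

```python
import numpy as np

def sincos_table(seq_len, dim):
    # Fixed sinusoidal table computed in float32, GPT-J-style
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
    pos = np.arange(seq_len, dtype=np.float32)
    angles = np.outer(pos, inv_freq)            # (seq_len, dim // 2), float32
    return np.sin(angles), np.cos(angles)

# Half-precision activations, as with torch_dtype=float16 or bfloat16
query = np.ones((4, 16), dtype=np.float16)
value = np.ones((4, 16), dtype=np.float16)

sin, cos = sincos_table(4, 16)                   # float32 tables
cos_full = np.concatenate([cos, cos], axis=-1)   # (4, 16), float32
rotated = query * cos_full                       # float16 * float32 promotes to float32
assert rotated.dtype == np.float32               # no longer matches value.dtype

fixed = rotated.astype(value.dtype)              # the explicit cast the diff adds
assert fixed.dtype == np.float16
```

In torch the same promotion happens when fp16/bf16 activations are multiplied by float32 sin/cos, which is why the mismatch only surfaces for low-precision inputs.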
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
SunMarc
left a comment
Thanks for the explanation! LGTM! Did you run the static cache tests you added to see if there are other models that require this fix?
ArthurZucker
left a comment
➕ on Marc's comment. The safe way to do this is to cast key and query to the cache's dtype, no? And do this in cache_utils rather than at the modeling level!
Yes, I have applied your suggestions, thanks!
The CI already runs the tests that I changed, so currently no other models require this fix. Besides, I have moved the change into cache_utils, so it applies to all language models that use the static cache.
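A sketch of the idea (a numpy stand-in with hypothetical names, not the actual transformers implementation): the static cache pre-allocates its buffers once in a fixed dtype, so its update() can cast the incoming states to the buffers' dtype in one place instead of every model file doing it.

```python
import numpy as np

class TinyStaticCache:
    """Minimal static-cache stand-in: fixed-size buffers, in-place writes."""
    def __init__(self, max_len, heads, head_dim, dtype=np.float32):
        shape = (1, heads, max_len, head_dim)
        self.k_out = np.zeros(shape, dtype=dtype)
        self.v_out = np.zeros(shape, dtype=dtype)

    def update(self, key_states, value_states, pos):
        # Cast to the cache dtype here, once, rather than at the modeling level
        key_states = key_states.astype(self.k_out.dtype)
        value_states = value_states.astype(self.v_out.dtype)
        n = key_states.shape[2]
        self.k_out[:, :, pos : pos + n] = key_states
        self.v_out[:, :, pos : pos + n] = value_states
        return self.k_out, self.v_out

cache = TinyStaticCache(max_len=8, heads=2, head_dim=4, dtype=np.float16)
k = np.random.randn(1, 2, 3, 4).astype(np.float32)  # e.g. float32 after rotary embedding
v = np.random.randn(1, 2, 3, 4).astype(np.float16)
k_out, v_out = cache.update(k, v, pos=0)
```

With the cast inside update(), a float32 key produced by a float32 sin/cos table lands in the fp16 buffers without a dtype mismatch, regardless of which model produced it.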
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Hi @SunMarc, please review the new changes, thanks!
All good from my side. Pinging @ArthurZucker
ArthurZucker
left a comment
Not completely sure we want to test for float32, as it's quite heavy.
I think it was testing for float32 initially, and @jiqing-feng added coverage for float16 @ArthurZucker
Sounds good then, merging!
* fix gptj data type mismatch
* add low precision static cache tests
* fix format
* fix low-precision static cache tests
* fix format
* avoid config change
* change data type convert in cache copy
* fix comment
* cast key value after k v out

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Hi @SunMarc. This PR fixes the data type mismatch when using a low-precision static cache. The following code can reproduce the bug:
Output: