Cleaner Cache dtype and device extraction for CUDA graph generation for quantizers compatibility #29079
Conversation
Thanks a lot for fixing !
For retrieving the correct device, the fix sounds correct.
However for the dtype, I am afraid this might lead to some bugs / unexpected behaviours 😭 As many users call perform text generation after calling some utility methods such as prepare_model_for_kbit_training (using PEFT), we do sometimes cast the layer norms in FP32. This is quite a niche usecase though. I propose to be on the safe zone and retrieve the dtype similarly as what we do here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L451-L457 - can you let me know if applying that logic here would fix CUDA graph generation for quantized models?
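The fallback logic in the linked llama code can be sketched with a pure-Python mock (dtypes represented as strings for illustration; the function name and config objects here are hypothetical, not part of the transformers API):

```python
from types import SimpleNamespace

def get_target_dtype(config, weight_dtype):
    """Prefer the dtype the model had before quantization (recorded on the
    config by the quantizer), falling back to an actual weight's dtype.
    Illustrative sketch of the pattern in modeling_llama.py referenced above."""
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    return weight_dtype

# A quantized model records its original dtype on the config...
quantized_cfg = SimpleNamespace(_pre_quantization_dtype="float16")
assert get_target_dtype(quantized_cfg, "int8") == "float16"

# ...while an unquantized model just uses its weights' dtype.
plain_cfg = SimpleNamespace()
assert get_target_dtype(plain_cfg, "float32") == "float32"
```

This avoids reading the dtype off a quantized weight tensor, which may be an integer packed format (or FP32 after `prepare_model_for_kbit_training` casts layer norms).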
Also, can you elaborate a bit on the original issue, i.e. what you are trying to achieve and what bug do you get
Thanks !
@younesbelkada The error I get on main is quite simple: [traceback omitted]. That's why, I believe, it would make sense to source …
And I don't really understand what you're proposing with the code you referenced.
@BlackSamorez thanks! You can do something like:

```python
if hasattr(self.config, "_pre_quantization_dtype"):
    target_dtype = self.config._pre_quantization_dtype
else:
    target_dtype = layer.self_attn.o_proj.weight.dtype
```

Does that fix the issue?
What you proposed seems to work fine with both FP16 and AQLM in a notebook test based on @ArthurZucker's test script.
younesbelkada
left a comment
Amazing work!
Could you add a simple test in the aqlm testing file to cover that use case 🙏
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@younesbelkada CUDA graph generation diverges at some point: [generation outputs omitted]. A stupid solution would be to generate shorter texts, but I'm not sure it's a good idea to have unstable tests. P.S. As you might have guessed, I added a CUDA graph generation test for AQLM.
ArthurZucker
left a comment
Looks very good to me!
Do you have any numbers to share regarding benchmarks?
```python
@unittest.skipUnless(
    is_aqlm_available() and version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.3"),
    "test requires `aqlm>=1.0.3`",
)
def test_quantized_model_compile(self):
```
@ArthurZucker @BlackSamorez The problem with it is that it's failing :). See this. So, advice is needed on what to do here.
I don't think it is super important that the outputs match exactly for quantized models, no? The distributions are the same, but kernels/ops are not run in the same order. The difference is small, but it could explain this.
I would just add a long generation and make sure it still makes sense!
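The reordering point can be illustrated with a toy example (not the actual CUDA behaviour, just an analogy): summing the same numbers in a different order gives slightly different floating-point results, a stand-in for kernels executing in a different order under CUDA graphs versus eager mode:

```python
import random

# Same values, two summation orders: the results are extremely close but
# typically not bit-identical, and the discrepancy grows with the number
# of operations -- analogous to compiled vs. eager generation diverging
# only after many tokens.
random.seed(0)
values = [random.uniform(-1.0, 1.0) for _ in range(100_000)]

forward = sum(values)
backward = sum(reversed(values))

print(abs(forward - backward))  # tiny, but generally nonzero
```

This is consistent with the observation below that shortening the generation keeps the outputs in exact agreement.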
I don't really know how to automatically check whether generated text makes sense.
Alternatively, I've shortened the generation length from 40 tokens to 32, and it matches perfectly on an RTX 3090, an RTX 2080 Ti, and an A6000. Maybe we could just leave it as is, since the tests above are exact-match anyway.
(Current iteration tests pass)
younesbelkada
left a comment
Amazing work @BlackSamorez !
What does this PR do?
As of now, this PR fixes a small problem preventing one from using CUDA graph generation from #28937 with quantized models.
In the long run, it would be great to have compiled generation actually working for GPTQ, AQLM, and other quantization methods.
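The reason the cache dtype/device matter here at all: CUDA-graph generation pre-allocates a static KV cache whose element type and shape are fixed before capture. A minimal pure-Python sketch (using `array('f')` in place of a CUDA tensor; class and method names are illustrative, not the transformers API):

```python
import array

class StaticCacheSketch:
    """Toy model of a statically pre-allocated KV cache. The buffer is sized
    and typed once up front and only written in place afterwards -- which is
    why the correct dtype must be extracted before graph capture."""

    def __init__(self, max_len: int, head_dim: int):
        self.max_len = max_len
        self.head_dim = head_dim
        # Allocated once; never resized during generation.
        self.keys = array.array("f", [0.0] * (max_len * head_dim))
        self.pos = 0

    def update(self, new_key):
        assert len(new_key) == self.head_dim
        assert self.pos < self.max_len, "static cache is full"
        start = self.pos * self.head_dim
        self.keys[start:start + self.head_dim] = array.array("f", new_key)
        self.pos += 1

cache = StaticCacheSketch(max_len=4, head_dim=2)
cache.update([1.0, 2.0])
cache.update([3.0, 4.0])
print(cache.pos, list(cache.keys[:4]))  # 2 [1.0, 2.0, 3.0, 4.0]
```

With a quantized model, reading the dtype naively off a packed weight would size this buffer with the wrong element type, hence the fix in this PR.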
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.