Cleaner Cache dtype and device extraction for CUDA graph generation for quantizers compatibility
#29079
Changes from all commits: f93e730, a956ec8, 042c2d5, edebc72, 39d7603, 1c0adb8
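The modeling-side change that the PR title describes is not part of the test diff below. Purely as an illustration of the underlying idea, and not this PR's actual code (the helper name and the fp16 fallback are assumptions): when a static KV cache is allocated for torch.compile / CUDA-graph generation, its dtype and device should come from a floating-point reference tensor rather than from an arbitrary parameter, because a quantizer such as AQLM may store packed integer weights.

```python
import torch
from transformers import PreTrainedModel


def infer_cache_dtype_and_device(model: PreTrainedModel) -> tuple[torch.dtype, torch.device]:
    """Hypothetical sketch: choose a floating-point dtype/device for a static KV cache.

    For quantized checkpoints, `next(model.parameters())` can be a packed integer
    buffer, so taking its dtype blindly would give the cache an integer dtype. The
    output embedding (`lm_head` in most causal LMs) normally stays in fp16/bf16/fp32.
    """
    ref = model.get_output_embeddings()
    weight = ref.weight if ref is not None else next(model.parameters())
    if weight.is_floating_point():
        return weight.dtype, weight.device
    # Assumed fallback when even the reference weight is quantized.
    return torch.float16, weight.device
```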
@@ -14,10 +14,13 @@
 # limitations under the License.

 import gc
 import importlib
 import tempfile
 import unittest

-from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM
+from packaging import version
+
+from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM, StaticCache
 from transformers.testing_utils import (
     require_accelerate,
     require_aqlm,
@@ -26,7 +29,7 @@
     slow,
     torch_device,
 )
-from transformers.utils import is_accelerate_available, is_torch_available
+from transformers.utils import is_accelerate_available, is_aqlm_available, is_torch_available


 if is_torch_available():
@@ -71,11 +74,12 @@ def test_from_dict(self):
 @require_aqlm
 @require_accelerate
 class AqlmTest(unittest.TestCase):
-    model_name = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"
+    model_name = "BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf"

     input_text = "Hello my name is"
+    max_new_tokens = 32

-    EXPECTED_OUTPUT = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am currently a sophomore and am majoring in Psychology. I am"
+    EXPECTED_OUTPUT = "Hello my name is Katie. I am a 20 year old college student. I am a very outgoing person. I love to have fun and be active. I"

     device_map = "cuda"

@@ -144,7 +148,7 @@ def test_quantized_model(self):
         """
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

-        output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
+        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

     def test_raise_if_non_quantized(self):
@@ -164,7 +168,7 @@ def test_save_pretrained(self):

         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

-        output = model.generate(**input_ids, max_new_tokens=40)
+        output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

     @require_torch_multi_gpu
@@ -178,6 +182,56 @@ def test_quantized_model_multi_gpu(self):

         self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

-        output = quantized_model.generate(**input_ids, max_new_tokens=40)
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)

         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

+    @unittest.skipUnless(
+        is_aqlm_available() and version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.3"),
+        "test requires `aqlm>=1.0.3`",
+    )
+    def test_quantized_model_compile(self):
Comment on lines +189 to +193

Collaborator: Loving this test ❤️

Contributor (Author): @ArthurZucker @BlackSamorez The problem with it is that it's failing :). See this. So, advice is needed on what to do here.

Collaborator: I don't think it is super important that the outputs match for quantized models, no? The distributions are the same, but the kernels / ops are not run in the same order. It's small, but it could explain this.

Contributor (Author): I don't really know how to automatically check whether the text makes sense.

Collaborator: Fine with me 😉
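Purely as an illustration of the point discussed above, and not part of this diff: a looser final assertion could accept any member of a small set of known-good completions, or merely require that the decoded text extends the prompt. The sketch below is a hypothetical drop-in replacement for the last assertion of `test_quantized_model_compile`; `EXPECTED_OUTPUTS` is an assumed attribute, not one defined in this PR.

```python
# Hypothetical looser check for the compiled-generation test.
EXPECTED_OUTPUTS = {
    # Completions observed with and without torch.compile / CUDA graphs could both be listed here.
    "Hello my name is Katie. I am a 20 year old college student. I am a very outgoing person. "
    "I love to have fun and be active. I",
}

decoded = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertTrue(decoded.startswith(self.input_text))  # the prompt must be preserved
self.assertIn(decoded, EXPECTED_OUTPUTS)              # exact match against any accepted completion
```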
The added test then continues:

+        """
+        Simple test that checks if the quantized model is working properly
+        """
+
+        # Sample tokens greedily
+        def decode_one_tokens(model, cur_token, input_pos, cache_position):
+            logits = model(
+                cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
+            )[0]
+            new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
+
+            return new_token
+
+        # Tokenize the test input
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)["input_ids"]
+        seq_length = input_ids.shape[1]
+
+        # Setup static KV cache for generation
+        self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens + 1)
+
+        # Allocate token ids to be generated and copy prefix ids
+        cache_position = torch.arange(seq_length, device=torch_device)
+        generated_ids = torch.zeros(1, seq_length + self.max_new_tokens, dtype=torch.int, device=torch_device)
+        generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)
+
+        # Do a forward pass to fill the prefix cache and compile the kernels if necessary
+        logits = self.quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0]
+        next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
+        generated_ids[:, [seq_length]] = next_token
+
+        with torch.no_grad():
+            # Compile the CUDA graph
+            decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
+
+            # Generate tokens one by one
+            cache_position = torch.tensor([seq_length + 1], device=torch_device)
+            for _ in range(1, self.max_new_tokens):
+                with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+                    next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position)
+                    generated_ids.index_copy_(1, cache_position, next_token)
+                cache_position += 1
+
+        # Check generated text
+        self.assertEqual(self.tokenizer.decode(generated_ids[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
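As a usage note, not part of this diff: recent transformers releases can allocate the static cache from `generate` itself via `cache_implementation="static"`, which exercises roughly the same compile/CUDA-graph-friendly path without hand-rolling the decode loop. A minimal sketch, assuming that keyword is available in the installed version:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf"  # same checkpoint as the test above
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda")

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
with torch.no_grad():
    # Asks generate() to allocate a StaticCache, the cache that
    # test_quantized_model_compile builds manually via _setup_cache.
    output = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```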