From 3933cb67d0be7db0eac67ce92c4679e5622a00b6 Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Tue, 7 Jan 2025 17:09:04 +0800 Subject: [PATCH 1/4] Apply calibration patch and deduplicate delegate cache patch --- examples/qualcomm/oss_scripts/llama/llama.py | 53 +++++++++++++++++--- exir/emit/test/test_emit.py | 11 ++++ 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 0af0f55b88f..0c96c275695 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -82,7 +82,6 @@ def _kv_calibrate( _, atten_mask, _, k_caches, v_caches = example_inputs # TODO: change criteria & support batch inputs if necessary - pos = torch.tensor(0, dtype=torch.int32) max_cache_len = max_seq_len - 1 token_list = [] @@ -114,10 +113,42 @@ def _kv_calibrate( for i, v_cache in enumerate(v_caches) ] - pos += 1 - atten_mask[0][-pos - 1] = 0 - if pos >= len(token_list): - token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) + # token_list = sp_model.encode(user_prompts, bos=True, eos=False) + + user_token_list = [ + # what is the capital of the united states + [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271], + # what is 1 + 1 + [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271], + # what is the meaning of life + [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271], + ] + + for token_list in user_token_list: + _, atten_mask, _, k_caches, v_caches = copy.deepcopy(example_inputs) + pos = torch.tensor(0, dtype=torch.int32) + with torch.no_grad(): + while token_list[-1] != sp_model.eos_id and pos < max_cache_len: + logits, new_k_caches, new_v_caches = module( + torch.full((1, 1), token_list[pos], dtype=torch.int32), + atten_mask, + torch.full((1, 1), pos), + *k_caches, + *v_caches, + ) + k_caches = [ + torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1) + for i, k_cache in enumerate(k_caches) + ] + v_caches = [ + torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1) + for i, v_cache in enumerate(v_caches) + ] + + pos += 1 + atten_mask[0][-pos - 1] = 0 + if pos >= len(token_list): + token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) print(f"kv calibration data:\n{tokenizer.decode(token_list)}") @@ -328,7 +359,17 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): max_seq_len=self.llama_meta["get_max_seq_len"], ) - self.llama_model = convert_pt2e(fx_graph_module) + fx_graph_module = convert_pt2e(fx_graph_module) + + logging.info("Evaluating the converted model...") + calibrate( + self.get_example_inputs(self.llama_meta["get_use_kv_cache"]), + args.prompt, + fx_graph_module, + tokenizer_model_path=args.tokenizer_model, + max_seq_len=self.llama_meta["get_max_seq_len"], + ) + self.llama_model = fx_graph_module def lowering_modules( self, diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 3fca3958feb..bc3b7975d71 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -1682,7 +1682,10 @@ def forward(self, x): ] self.assertEqual(external_map["linear.weight"], 0) self.assertEqual(external_map["linear.bias"], 1) +<<<<<<< HEAD +======= +>>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch) def test_delegate_deduplicate(self) -> None: class SharedModule(torch.nn.Module): def __init__(self): @@ -1692,6 +1695,10 @@ def __init__(self): def forward(self, x): return self.linear(x) +<<<<<<< HEAD +======= + +>>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch) class Module1(torch.nn.Module): def __init__(self, shared_module): super().__init__() @@ -1700,6 +1707,10 @@ def __init__(self, shared_module): def forward(self, x): return self.shared_module(x) +<<<<<<< HEAD +======= + +>>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch) class Module2(torch.nn.Module): def __init__(self, shared_module): super().__init__() From b4a9e53f010f9009fbf4b0ff57b13fceab36fa3c Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Thu, 9 Jan 2025 15:10:17 +0800 Subject: [PATCH 2/4] Comment out eot condition to generate all tokens --- backends/qualcomm/utils/utils.py | 1 + examples/qualcomm/oss_scripts/llama/llama.py | 53 +++---------------- .../oss_scripts/llama/runner/runner.cpp | 8 +-- exir/emit/test/test_emit.py | 11 ---- 4 files changed, 11 insertions(+), 62 deletions(-) diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index e13705b3a8f..a53f8e39071 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -823,6 +823,7 @@ def generate_multi_graph_program( ) assert qnn_mgr.Init().value == 0, "failed to load processed bytes" binary_info = bytes(qnn_mgr.Compile()) + print("Checking the size of QNN binary info: ", len(binary_info)) assert len(binary_info) != 0, "failed to generate QNN context binary" graph_names = qnn_mgr.GetGraphNames() for graph_name in graph_names: diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 0c96c275695..0af0f55b88f 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -82,6 +82,7 @@ def _kv_calibrate( _, atten_mask, _, k_caches, v_caches = example_inputs # TODO: change criteria & support batch inputs if necessary + pos = torch.tensor(0, dtype=torch.int32) max_cache_len = max_seq_len - 1 token_list = [] @@ -113,42 +114,10 @@ def _kv_calibrate( for i, v_cache in enumerate(v_caches) ] - # token_list = sp_model.encode(user_prompts, bos=True, eos=False) - - user_token_list = [ - # what is the capital of the united states - [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271], - # what is 1 + 1 - [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271], - # what is the meaning of life - [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271], - ] - - for token_list in user_token_list: - _, atten_mask, _, k_caches, v_caches = copy.deepcopy(example_inputs) - pos = torch.tensor(0, dtype=torch.int32) - with torch.no_grad(): - while token_list[-1] != sp_model.eos_id and pos < max_cache_len: - logits, new_k_caches, new_v_caches = module( - torch.full((1, 1), token_list[pos], dtype=torch.int32), - atten_mask, - torch.full((1, 1), pos), - *k_caches, - *v_caches, - ) - k_caches = [ - torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1) - for i, k_cache in enumerate(k_caches) - ] - v_caches = [ - torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1) - for i, v_cache in enumerate(v_caches) - ] - - pos += 1 - atten_mask[0][-pos - 1] = 0 - if pos >= len(token_list): - token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) + pos += 1 + atten_mask[0][-pos - 1] = 0 + if pos >= len(token_list): + token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) print(f"kv calibration data:\n{tokenizer.decode(token_list)}") @@ -359,17 +328,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): max_seq_len=self.llama_meta["get_max_seq_len"], ) - fx_graph_module = convert_pt2e(fx_graph_module) - - logging.info("Evaluating the converted model...") - calibrate( - self.get_example_inputs(self.llama_meta["get_use_kv_cache"]), - args.prompt, - fx_graph_module, - tokenizer_model_path=args.tokenizer_model, - max_seq_len=self.llama_meta["get_max_seq_len"], - ) - self.llama_model = fx_graph_module + self.llama_model = convert_pt2e(fx_graph_module) def lowering_modules( self, diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index e06d52fbb37..99d6d715db7 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -404,10 +404,10 @@ Error Runner::generate( token_callback(piece_res.get().c_str()); } - if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) { - ET_LOG(Info, "\nReached to the end of generation"); - break; - } + // if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) { + // ET_LOG(Info, "\nReached to the end of generation"); + // break; + // } } }; diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index bc3b7975d71..3fca3958feb 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -1682,10 +1682,7 @@ def forward(self, x): ] self.assertEqual(external_map["linear.weight"], 0) self.assertEqual(external_map["linear.bias"], 1) -<<<<<<< HEAD -======= ->>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch) def test_delegate_deduplicate(self) -> None: class SharedModule(torch.nn.Module): def __init__(self): @@ -1695,10 +1692,6 @@ def __init__(self): def forward(self, x): return self.linear(x) -<<<<<<< HEAD -======= - ->>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch) class Module1(torch.nn.Module): def __init__(self, shared_module): super().__init__() @@ -1707,10 +1700,6 @@ def __init__(self, shared_module): def forward(self, x): return self.shared_module(x) -<<<<<<< HEAD -======= - ->>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch) class Module2(torch.nn.Module): def __init__(self, shared_module): super().__init__() From a6aee94dcb507e66f66c5353fc39ee4e413cf68d Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Tue, 21 Jan 2025 13:53:06 +0800 Subject: [PATCH 3/4] md5 to check quantized weights align --- examples/qualcomm/oss_scripts/llama/llama.py | 28 +++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 0af0f55b88f..cbcb3b0c04d 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -494,7 +494,8 @@ def compile(args, pte_filename, tokenizer): annotate_linear_16a8w_in_affine_layer, ) if args.ptq != None: - kv_quant_attrs = {} + import hashlib + kv_quant_attrs, parameter_hash = {}, [] for i, llama_instance in enumerate(llama_instance_list): llama_instance.quantize( quant_dtype=quant_dtype, @@ -517,6 +518,31 @@ def compile(args, pte_filename, tokenizer): kv_quant_attrs=kv_quant_attrs, ), ) + + tensor_to_md5 = {} + for name, buffer in llama_instance.llama_model.named_buffers(): + md5_buffer = hashlib.md5(buffer.numpy().tobytes()).hexdigest() + if md5_buffer in tensor_to_md5: + tensor_to_md5[md5_buffer].append(name) + else: + tensor_to_md5[md5_buffer] = [name] + parameter_hash.append(tensor_to_md5) + + # check tensors in prefill & decode are exactly the same + assert len(parameter_hash[0]) == len(parameter_hash[1]) + num_keys = len(parameter_hash[0]) + # Remove common keys from both dictionaries + for key in set(parameter_hash[0]).intersection(set(parameter_hash[1])): + del parameter_hash[0][key] + del parameter_hash[1][key] + print(f"{num_keys - len(parameter_hash[0])} / {num_keys} tensors are matched") + + for buf, name in parameter_hash[0].items(): # kv + print(f"KV buffers: {name} cannot find a match") + for buf, name in parameter_hash[1].items(): # prefill + print(f"Prefill buffers: {name} cannot find a match") + + end_quantize_ts = time.time() logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}") From f228d74668179a4baf03f9e570c0ef089fc61a2b Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Tue, 21 Jan 2025 16:58:22 +0800 Subject: [PATCH 4/4] add print --- examples/qualcomm/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index c2d2f002aa8..a1890022be4 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -167,6 +167,8 @@ def execute(self, custom_runner_cmd=None, method_index=0): ) else: qnn_executor_runner_cmds = custom_runner_cmd + + print("Execution command is: ", qnn_executor_runner_cmds) self._adb(["shell", f"{qnn_executor_runner_cmds}"])