From 3f69021fdf123c452ebdcbca5b21b73af388991d Mon Sep 17 00:00:00 2001 From: eplatero Date: Tue, 10 Dec 2024 22:13:47 -0600 Subject: [PATCH 01/12] adding spd inference script by apoorva Signed-off-by: eplatero --- tests/transformers/spd/test_spd_inference.py | 349 +++++++++++++++++++ 1 file changed, 349 insertions(+) create mode 100644 tests/transformers/spd/test_spd_inference.py diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py new file mode 100644 index 000000000..082f30a51 --- /dev/null +++ b/tests/transformers/spd/test_spd_inference.py @@ -0,0 +1,349 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import List, Optional + +import numpy as np +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + + +def run_prefill_on_draft_and_target( + tlm_session: QAICInferenceSession, + dlm_session: QAICInferenceSession, + prompt: dict, + prompt_len: int, + ctx_len: int, + prefill_batch_size: int, + decode_batch_size: int, + slot_idx: int +): + tlm_decode_start_input = dict() + dlm_decode_start_input = dict() + inputs = prompt + input_len = prompt.input_ids.shape[1] + num_chunks = -(input_len // -prompt_len) # ceil divide without float + input_len = num_chunks * prompt_len # Convert input_len to a multiple of prompt_len + assert input_len <= ctx_len, "input_len should be less than ctx_len" + # pad the prompt tokens to match the input_len + inputs = prompt + # TODO need to store the attention mask and position ids for each batch element so that we can access them + # at decode time + inputs["attention_mask"] = np.concatenate( + [inputs["attention_mask"].astype(bool) for j in range(decode_batch_size)], 0 + ) + inputs["position_ids"] = (np.cumsum(inputs["attention_mask"][0:1], 1) - 1) * inputs["attention_mask"][0:1] + + # FIXME "not" does not work for below line in place of the "== False" check, but code formatter recommends it + inputs["position_ids"][inputs["attention_mask"][0:1] == False] = -1 + cache_index = np.array([[0]], np.int64) + batch_index = np.array([[slot_idx]], np.int64) + inputs["batch_index"] = batch_index + + # Run chunked prefill + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prompt_len] + chunk_inputs["position_ids"] = inputs["position_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prompt_len] + + chunk_inputs.pop("attention_mask") + tlm_outputs = tlm_session.run(chunk_inputs) + dlm_outputs = dlm_session.run(chunk_inputs) + cache_index += prompt_len + + tlm_logits = tlm_outputs["logits"] + dlm_logits = dlm_outputs["logits"] + + if len(tlm_logits.shape) == 2: + tlm_logits = np.expand_dims(tlm_logits, 1) + if len(dlm_logits.shape) == 2: + dlm_logits = np.expand_dims(dlm_logits, 1) + + tlm_decode_start_pos_id = inputs["attention_mask"][0:1].sum(1, keepdims=True) + tlm_decode_start_input_id = tlm_logits.argmax(2) + dlm_decode_start_input_id = dlm_logits.argmax(2) + dlm_decode_start_pos_id = inputs["attention_mask"][0:1].sum(1, keepdims=True) + + inputs.pop("attention_mask") + + tlm_decode_start_input = { + "logits": tlm_logits, + "input_ids": 
tlm_decode_start_input_id, + "position_ids": tlm_decode_start_pos_id, + "batch_index": batch_index, + "input_len": tlm_decode_start_pos_id[0, 0], + } + dlm_decode_start_input = { + "logits": dlm_logits, + "input_ids": dlm_decode_start_input_id, + "position_ids": dlm_decode_start_pos_id, + "batch_index": batch_index, + "input_len": tlm_decode_start_pos_id[0, 0], + } + + return tlm_decode_start_input, dlm_decode_start_input + + +def get_padded_input_len(input_len: int, prompt_len: int, ctx_len: int): + """return padded input length (must be factor of `prompt_len`) + + Args: + input_len (int): prompt length + prompt_len (int): prefill sequence length + ctx_len (int): context length + + Returns: + input_len_padded (int): padded input length + """ + num_chunks = -(input_len // -prompt_len) # ceil divide without float + input_len_padded = num_chunks * prompt_len # Convert input_len to a multiple of prompt_len + assert input_len_padded <= ctx_len, "input_len rounded to nearest prompt_len multiple should be less than ctx_len" + return input_len_padded + + +def populate_inputs(source, dest, index=None): + for k, v in dest.items(): + if k == "batch_index": + continue + if index is None: + # during decode + dest[k] = source[k] + else: + # during prefill with bs=1 + dest[k][index] = source[k] + +def split_dlm_bonus_token_inputs(dlm_decode_inputs): + bonus_token_inputs = dict() + bonus_token_inputs["input_ids"] = dlm_decode_inputs["input_ids"][:,0:1] + bonus_token_inputs["position_ids"] = dlm_decode_inputs["input_ids"][:,0:1] + dlm_decode_inputs["input_ids"] = dlm_decode_inputs["input_ids"][:,1:] + dlm_decode_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:,1:] + return bonus_token_inputs, dlm_decode_inputs + +def test_spec_decode_inference( + prompt: List[str], + device_group: List[int], + num_speculative_tokens: int, + prompt_len: int, + ctx_len: int, + prefill_bsz: int, + draft_model_name: str, + target_model_name: str, + full_batch_size: Optional[int] = None, +): + # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size + # get vocab size + tokenizer = AutoTokenizer.from_pretrained(target_model_name) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + vocab_size = len(tokenizer) + + # export_and_compile tlm and dlm + target_model = AutoModelForCausalLM.from_pretrained(target_model_name, continuous_batching=True,is_tlm=True) + draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=True) + + num_devices = len(device_group) + target_model_qpc_path: str = target_model.compile(num_cores=11,num_devices=num_devices,prefill_seq_len=prompt_len,ctx_len=ctx_len,mxfp6_matmul=True,aic_enable_depth_first=True, full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens) + + draft_model_qpc_path: str = draft_model.compile(is_dlm=False, num_cores=5,prefill_seq_len=prompt_len,ctx_len=ctx_len,mxfp6_matmul=True,aic_enable_depth_first=True, full_batch_size=full_batch_size) + + # init qaic session + target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=[2]) + draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=[3]) + + # skip inputs/outputs buffers + target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")])) + target_model_session.skip_buffers( + set([x for x in target_model_session.output_names if x.endswith("_RetainedState")]) + ) + 
draft_model_session.skip_buffers(set([x for x in draft_model_session.input_names if x.startswith("past_")])) + draft_model_session.skip_buffers(set([x for x in draft_model_session.output_names if x.endswith("_RetainedState")])) + + is_cb = full_batch_size is not None + if not is_cb: + prompts = prompt * prefill_bsz + decode_batch_size = prefill_bsz + else: + prompts = prompt + decode_batch_size = full_batch_size + # tokenize the prompts + prompts_tokenized: List[dict] = [] + for p in prompts: + input_len: int = tokenizer(p, return_tensors="np", padding=True).input_ids.shape[1] + input_len_padded: int = get_padded_input_len(input_len, prompt_len, ctx_len) + p_tok: dict = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded) + prompts_tokenized.append(p_tok) + # create caches to hold generated ids and input prompt lengths + generated_ids = [[] for i in range(decode_batch_size)] + input_lengths = [0] * decode_batch_size + # run prefill on both draft and target models + dlm_decode_inputs = dict() + dlm_decode_inputs["position_ids"] = np.zeros((decode_batch_size, 1), np.int64) + dlm_decode_inputs["input_ids"] = np.full((decode_batch_size, 1), tokenizer.pad_token_id) + dlm_decode_inputs["batch_index"] = np.reshape( + np.array(np.arange(decode_batch_size), np.int64), (decode_batch_size, 1) + ) + # mock input key "logits" to store the first batch of output logits + dlm_decode_inputs["logits"] = np.full((decode_batch_size, 1, vocab_size), 0) + tlm_precode_inputs = dict(dlm_decode_inputs) + is_prefill = True + generation_done = False + max_gen_len = [ctx_len] * decode_batch_size + num_logits_to_keep = num_speculative_tokens+1 + all_accept = np.full((decode_batch_size, num_speculative_tokens), False, dtype=bool) + tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) + dlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) + decode_logits_ph = np.zeros((decode_batch_size, 1, vocab_size), dtype=np.float32) + precode_logits_ph = np.zeros((decode_batch_size, num_logits_to_keep, vocab_size), dtype=np.float32) + + target_model_session.set_buffers({"logits": tlm_prefill_logits_ph}) + draft_model_session.set_buffers({"logits": dlm_prefill_logits_ph}) + for bi in range(decode_batch_size): + # assumes that prefill queue will always be popped from the front + tlm_prefill_output, dlm_prefill_output = run_prefill_on_draft_and_target( + tlm_session=target_model_session, + dlm_session=draft_model_session, + prompt=prompts_tokenized[bi], + prompt_len=prompt_len, + ctx_len=ctx_len, + prefill_batch_size=prefill_bsz, + decode_batch_size=decode_batch_size, + slot_idx=bi, + ) + # this way, we will directly get the updated full batch input dict to run decode + populate_inputs(dlm_prefill_output, dlm_decode_inputs, bi) + populate_inputs(tlm_prefill_output, tlm_precode_inputs, bi) + # assumes that prefill queue will always be popped from the front + input_lengths[bi] = tlm_prefill_output["input_len"] + max_gen_len[bi] -= input_lengths[bi] + + target_model_session.set_buffers({"logits": precode_logits_ph}) + draft_model_session.set_buffers({"logits": decode_logits_ph}) + dlm_run_bonus_token = False + while not generation_done: + # compute the processed context length before each iteration to prepare the position id inputs + processed_context = [len(generated_ids[j]) + input_lengths[j] for j in range(decode_batch_size)] + # generate proposals from draft model + if is_prefill: + draft_logits = [dlm_decode_inputs.pop("logits")] + target_logits 
= [tlm_precode_inputs.pop("logits")] + else: + if np.any(all_accept): + input_ids = [] + position_ids = [] + dlm_run_bonus_token = True + for bi in range(decode_batch_size): + if all_accept[bi]: + # both last DLM token and bonus TLM token to be passed as input to DLM + input_ids.append([generated_ids[bi][-2], generated_ids[bi][-1]]) + position_ids.append([processed_context[bi] - 2, processed_context[bi] - 1]) + else: + # only the correct token from TLM from previous iteration and the pad_token as a dummy + input_ids.append([generated_ids[bi][-1], tokenizer.pad_token_id]) + position_ids.append([processed_context[bi] - 1, -1]) + dlm_decode_inputs["input_ids"] = np.array(input_ids) + dlm_decode_inputs["position_ids"] = np.array(position_ids) + else: + dlm_decode_inputs["input_ids"] = np.array([gid[-1] for gid in generated_ids], dtype=np.int64).reshape( + (decode_batch_size, 1) + ) + dlm_decode_inputs["position_ids"] = np.array( + [(pc - 1) for pc in processed_context], dtype=np.int64 + ).reshape((decode_batch_size, 1)) + # prepare the inputs for the dlm speculation + # TODO in case of even one of the batch having all_accept, we have to use the seqlen=2 specialization + # hence need to have dummy -1 position id for other sequences. + # dlm_decode_inputs["position_ids"] = len(generated_ids per batch) + # dlm_decode_inputs["input_ids"] = (last gen dlm token) + last true token from TLM + for k_ in range(num_speculative_tokens): + if dlm_run_bonus_token: + #running decode one extra time in the first speculative iteration + # workaround to avoid the incorrect precode with 3-specialized multi-batch DLM + bonus_token_inputs, dlm_decode_inputs = split_dlm_bonus_token_inputs(dlm_decode_inputs) + dlm_outputs = draft_model_session.run(bonus_token_inputs) + dlm_run_bonus_token = False + dlm_outputs = draft_model_session.run(dlm_decode_inputs) + draft_logits.append(dlm_outputs["logits"]) + dlm_decode_inputs["input_ids"] = dlm_outputs["logits"].argmax(-1) + dlm_decode_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:, -1:] + 1 + + draft_logits = np.array(draft_logits).squeeze(2).transpose((1, 0, 2)) + # greedy sampling from draft model + draft_tokens = draft_logits.argmax(-1) + + # construct precode inputs + tlm_precode_inputs["input_ids"] = draft_tokens + if not is_prefill: + last_genid = np.array([gid[-1] for gid in generated_ids], dtype=np.int64).reshape(decode_batch_size, 1) + tlm_precode_inputs["input_ids"] = np.concatenate((last_genid, tlm_precode_inputs["input_ids"]), axis=1) + # in case of general precode, first token in input sequence is = last generated TLM token (kv cache backfill) + tlm_precode_inputs["position_ids"] = np.array( + [np.arange(start=pc - 1, stop=pc + num_speculative_tokens) for pc in processed_context], dtype=np.int64 + ) + else: + # in case of just first precode, we are feeding in all new positions + tlm_precode_inputs["position_ids"] = np.array( + [np.arange(start=pc, stop=pc + num_speculative_tokens + 1) for pc in processed_context], dtype=np.int64 + ) + + # run precode on TLM to score the proposed tokens + tlm_outputs = target_model_session.run(tlm_precode_inputs) + target_precode_logits = tlm_outputs["logits"] + if is_prefill: + target_logits = np.concatenate((target_logits[0], target_precode_logits), axis=1) + # stack the prefill output logit and precode logits into a single tensor + else: + target_logits = target_precode_logits + # greedy sampling from target model + target_tokens = target_logits.argmax(-1) + # exact matching between draft and target tokens + 
matching = draft_tokens == target_tokens[:, :-1] + num_tokens_selected = np.argmin(matching, axis=1) + all_accept = matching[np.arange(decode_batch_size), num_tokens_selected] + num_tokens_selected = np.where(all_accept, matching.shape[1], num_tokens_selected) + + # append selected tokens to the generated_ids + for bi in range(decode_batch_size): + if len(generated_ids[bi]) >= max_gen_len[bi]: + continue + num_tokens_to_append = min(num_tokens_selected[bi], max_gen_len[bi] - len(generated_ids[bi])) + generated_ids[bi] += list(draft_tokens[bi, :num_tokens_to_append]) + # append bonus/corrected token where applicable + for bi in range(decode_batch_size): + if len(generated_ids[bi]) >= max_gen_len[bi]: + continue + if all_accept[bi]: + # bonus token + generated_ids[bi].append(target_tokens[bi, -1]) + else: + # correct token + generated_ids[bi].append(target_tokens[bi, num_tokens_selected[bi]]) + generation_done = True + for bi in range(decode_batch_size): + if len(generated_ids[bi]) < max_gen_len[bi]: + generation_done = False + is_prefill = False + draft_logits = [] + target_logits = [] + print("max generation len = ", max_gen_len) + print("actual generation len = ", [len(gid) for gid in generated_ids]) + print(tokenizer.batch_decode(generated_ids)) + + +test_spec_decode_inference( + ["My name is", "Hello", "Hi", "My name is"], + [0], + 5, + 32, + 128, + 1, + "JackFram/llama-68m", + "JackFram/llama-68m", + 4, +) \ No newline at end of file From 92410c6d0c2d410cfa8bd998eaf2a6304e9e9241 Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 11 Dec 2024 06:19:57 -0600 Subject: [PATCH 02/12] use pytest parametrize configs Signed-off-by: eplatero --- tests/transformers/spd/test_spd_inference.py | 102 ++++++++++++------- 1 file changed, 64 insertions(+), 38 deletions(-) diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index 082f30a51..5ec967945 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -8,17 +8,34 @@ from typing import List, Optional import numpy as np +import pytest from transformers import AutoTokenizer from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils.device_utils import get_available_device_id + +configs = [ + pytest.param( + ["My name is", "Hello", "Hi", "My name is"], # prompt + 2, # num_speculative_tokens + 32, # prefill_seq_len + 128, # ctx_len + 1, # prefill_bsz + "JackFram/llama-68m", # draft_model_name + "JackFram/llama-68m", # target_model_name + 4, # full_batch_size + id="CB llama", + ), +] + def run_prefill_on_draft_and_target( tlm_session: QAICInferenceSession, dlm_session: QAICInferenceSession, prompt: dict, - prompt_len: int, + prefill_seq_len: int, ctx_len: int, prefill_batch_size: int, decode_batch_size: int, @@ -28,8 +45,8 @@ def run_prefill_on_draft_and_target( dlm_decode_start_input = dict() inputs = prompt input_len = prompt.input_ids.shape[1] - num_chunks = -(input_len // -prompt_len) # ceil divide without float - input_len = num_chunks * prompt_len # Convert input_len to a multiple of prompt_len + num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float + input_len = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len assert input_len <= ctx_len, "input_len should be less than ctx_len" # pad the prompt tokens to match the input_len inputs = prompt @@ -49,16 +66,17 @@ def 
run_prefill_on_draft_and_target( # Run chunked prefill for i in range(num_chunks): chunk_inputs = inputs.copy() - chunk_inputs["input_ids"] = inputs["input_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prompt_len] - chunk_inputs["position_ids"] = inputs["position_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prompt_len] + chunk_inputs["input_ids"] = inputs["input_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len] + chunk_inputs["position_ids"] = inputs["position_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len] chunk_inputs.pop("attention_mask") tlm_outputs = tlm_session.run(chunk_inputs) dlm_outputs = dlm_session.run(chunk_inputs) - cache_index += prompt_len + cache_index += prefill_seq_len tlm_logits = tlm_outputs["logits"] dlm_logits = dlm_outputs["logits"] + assert (tlm_logits == dlm_logits).sum().all() if len(tlm_logits.shape) == 2: tlm_logits = np.expand_dims(tlm_logits, 1) @@ -90,20 +108,20 @@ def run_prefill_on_draft_and_target( return tlm_decode_start_input, dlm_decode_start_input -def get_padded_input_len(input_len: int, prompt_len: int, ctx_len: int): - """return padded input length (must be factor of `prompt_len`) +def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int): + """return padded input length (must be factor of `prefill_seq_len`) Args: input_len (int): prompt length - prompt_len (int): prefill sequence length + prefill_seq_len (int): prefill sequence length ctx_len (int): context length Returns: input_len_padded (int): padded input length """ - num_chunks = -(input_len // -prompt_len) # ceil divide without float - input_len_padded = num_chunks * prompt_len # Convert input_len to a multiple of prompt_len - assert input_len_padded <= ctx_len, "input_len rounded to nearest prompt_len multiple should be less than ctx_len" + num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float + input_len_padded = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len + assert input_len_padded <= ctx_len, "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" return input_len_padded @@ -126,17 +144,24 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs): dlm_decode_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:,1:] return bonus_token_inputs, dlm_decode_inputs +@pytest.mark.parametrize( + "prompt, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size", + configs, +) def test_spec_decode_inference( prompt: List[str], - device_group: List[int], num_speculative_tokens: int, - prompt_len: int, + prefill_seq_len: int, ctx_len: int, prefill_bsz: int, draft_model_name: str, target_model_name: str, - full_batch_size: Optional[int] = None, + full_batch_size: Optional[int], ): + # get device group + device_group: List[int] = get_available_device_id() + if not device_group: + pytest.skip("No available devices to run model on Cloud AI 100") # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size # get vocab size tokenizer = AutoTokenizer.from_pretrained(target_model_name) @@ -145,17 +170,26 @@ def test_spec_decode_inference( vocab_size = len(tokenizer) # export_and_compile tlm and dlm - target_model = AutoModelForCausalLM.from_pretrained(target_model_name, continuous_batching=True,is_tlm=True) - draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=True) + continuous_batching = full_batch_size is not 
None + target_model = AutoModelForCausalLM.from_pretrained(target_model_name, continuous_batching=continuous_batching, is_tlm=True) + draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=continuous_batching) num_devices = len(device_group) - target_model_qpc_path: str = target_model.compile(num_cores=11,num_devices=num_devices,prefill_seq_len=prompt_len,ctx_len=ctx_len,mxfp6_matmul=True,aic_enable_depth_first=True, full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens) - - draft_model_qpc_path: str = draft_model.compile(is_dlm=False, num_cores=5,prefill_seq_len=prompt_len,ctx_len=ctx_len,mxfp6_matmul=True,aic_enable_depth_first=True, full_batch_size=full_batch_size) - + target_model_qpc_path: str = target_model.compile(num_cores=11, + num_devices=num_devices, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + aic_enable_depth_first=True, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens) + draft_model_qpc_path: str = draft_model.compile(num_cores=5, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + aic_enable_depth_first=True, + full_batch_size=full_batch_size) # init qaic session - target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=[2]) - draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=[3]) + target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group) + draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group) # skip inputs/outputs buffers target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")])) @@ -176,7 +210,7 @@ def test_spec_decode_inference( prompts_tokenized: List[dict] = [] for p in prompts: input_len: int = tokenizer(p, return_tensors="np", padding=True).input_ids.shape[1] - input_len_padded: int = get_padded_input_len(input_len, prompt_len, ctx_len) + input_len_padded: int = get_padded_input_len(input_len, prefill_seq_len, ctx_len) p_tok: dict = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded) prompts_tokenized.append(p_tok) # create caches to hold generated ids and input prompt lengths @@ -210,7 +244,7 @@ def test_spec_decode_inference( tlm_session=target_model_session, dlm_session=draft_model_session, prompt=prompts_tokenized[bi], - prompt_len=prompt_len, + prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, prefill_batch_size=prefill_bsz, decode_batch_size=decode_batch_size, @@ -225,6 +259,7 @@ def test_spec_decode_inference( target_model_session.set_buffers({"logits": precode_logits_ph}) draft_model_session.set_buffers({"logits": decode_logits_ph}) + num_tokens_selected_per_validation = [] dlm_run_bonus_token = False while not generation_done: # compute the processed context length before each iteration to prepare the position id inputs @@ -307,6 +342,7 @@ def test_spec_decode_inference( num_tokens_selected = np.argmin(matching, axis=1) all_accept = matching[np.arange(decode_batch_size), num_tokens_selected] num_tokens_selected = np.where(all_accept, matching.shape[1], num_tokens_selected) + num_tokens_selected_per_validation.append(num_tokens_selected) # append selected tokens to the generated_ids for bi in range(decode_batch_size): @@ -331,19 +367,9 @@ def test_spec_decode_inference( is_prefill = False draft_logits = [] target_logits = [] + num_tokens_selected_per_validation = 
np.concatenate(num_tokens_selected_per_validation).reshape(len(num_tokens_selected_per_validation), decode_batch_size) + mean_num_accepted_tokens_per_batch = num_tokens_selected_per_validation.mean(axis=0) + print("mean number of accepted tokens per batch = ", mean_num_accepted_tokens_per_batch) print("max generation len = ", max_gen_len) print("actual generation len = ", [len(gid) for gid in generated_ids]) print(tokenizer.batch_decode(generated_ids)) - - -test_spec_decode_inference( - ["My name is", "Hello", "Hi", "My name is"], - [0], - 5, - 32, - 128, - 1, - "JackFram/llama-68m", - "JackFram/llama-68m", - 4, -) \ No newline at end of file From 23311ac3a5c2378d6027f27030bfffdf69a3431c Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 11 Dec 2024 10:56:04 -0600 Subject: [PATCH 03/12] rm function as it was causing some corruption when populating logits Signed-off-by: eplatero --- tests/transformers/spd/test_spd_inference.py | 74 +++++++------------- 1 file changed, 27 insertions(+), 47 deletions(-) diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index 5ec967945..eb27f6ef7 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- from typing import List, Optional +from pprint import pprint import numpy as np import pytest @@ -17,14 +18,15 @@ configs = [ pytest.param( - ["My name is", "Hello", "Hi", "My name is"], # prompt - 2, # num_speculative_tokens + #["My name is", "Hello", "Hi", "My name is"], # prompt + ["My name is"], # prompt + 1, # num_speculative_tokens 32, # prefill_seq_len 128, # ctx_len 1, # prefill_bsz "JackFram/llama-68m", # draft_model_name "JackFram/llama-68m", # target_model_name - 4, # full_batch_size + 1, # full_batch_size id="CB llama", ), ] @@ -34,31 +36,14 @@ def run_prefill_on_draft_and_target( tlm_session: QAICInferenceSession, dlm_session: QAICInferenceSession, - prompt: dict, + inputs: dict, prefill_seq_len: int, - ctx_len: int, - prefill_batch_size: int, - decode_batch_size: int, slot_idx: int ): tlm_decode_start_input = dict() dlm_decode_start_input = dict() - inputs = prompt - input_len = prompt.input_ids.shape[1] - num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float - input_len = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len - assert input_len <= ctx_len, "input_len should be less than ctx_len" - # pad the prompt tokens to match the input_len - inputs = prompt - # TODO need to store the attention mask and position ids for each batch element so that we can access them - # at decode time - inputs["attention_mask"] = np.concatenate( - [inputs["attention_mask"].astype(bool) for j in range(decode_batch_size)], 0 - ) - inputs["position_ids"] = (np.cumsum(inputs["attention_mask"][0:1], 1) - 1) * inputs["attention_mask"][0:1] - - # FIXME "not" does not work for below line in place of the "== False" check, but code formatter recommends it - inputs["position_ids"][inputs["attention_mask"][0:1] == False] = -1 + input_len = inputs.input_ids.shape[1] + num_chunks = input_len // prefill_seq_len cache_index = np.array([[0]], np.int64) batch_index = np.array([[slot_idx]], np.int64) inputs["batch_index"] = batch_index @@ -69,12 +54,12 @@ def run_prefill_on_draft_and_target( chunk_inputs["input_ids"] = inputs["input_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len] chunk_inputs["position_ids"] = 
inputs["position_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len] - chunk_inputs.pop("attention_mask") tlm_outputs = tlm_session.run(chunk_inputs) dlm_outputs = dlm_session.run(chunk_inputs) cache_index += prefill_seq_len tlm_logits = tlm_outputs["logits"] + return tlm_logits dlm_logits = dlm_outputs["logits"] assert (tlm_logits == dlm_logits).sum().all() @@ -125,16 +110,6 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int): return input_len_padded -def populate_inputs(source, dest, index=None): - for k, v in dest.items(): - if k == "batch_index": - continue - if index is None: - # during decode - dest[k] = source[k] - else: - # during prefill with bs=1 - dest[k][index] = source[k] def split_dlm_bonus_token_inputs(dlm_decode_inputs): bonus_token_inputs = dict() @@ -164,7 +139,7 @@ def test_spec_decode_inference( pytest.skip("No available devices to run model on Cloud AI 100") # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size # get vocab size - tokenizer = AutoTokenizer.from_pretrained(target_model_name) + tokenizer = AutoTokenizer.from_pretrained(target_model_name, padding_side="right") if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id vocab_size = len(tokenizer) @@ -212,6 +187,8 @@ def test_spec_decode_inference( input_len: int = tokenizer(p, return_tensors="np", padding=True).input_ids.shape[1] input_len_padded: int = get_padded_input_len(input_len, prefill_seq_len, ctx_len) p_tok: dict = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded) + position_ids = np.where(p_tok.pop("attention_mask"), np.arange(input_len_padded), -1) + p_tok["position_ids"] = position_ids prompts_tokenized.append(p_tok) # create caches to hold generated ids and input prompt lengths generated_ids = [[] for i in range(decode_batch_size)] @@ -224,7 +201,6 @@ def test_spec_decode_inference( np.array(np.arange(decode_batch_size), np.int64), (decode_batch_size, 1) ) # mock input key "logits" to store the first batch of output logits - dlm_decode_inputs["logits"] = np.full((decode_batch_size, 1, vocab_size), 0) tlm_precode_inputs = dict(dlm_decode_inputs) is_prefill = True generation_done = False @@ -240,21 +216,19 @@ def test_spec_decode_inference( draft_model_session.set_buffers({"logits": dlm_prefill_logits_ph}) for bi in range(decode_batch_size): # assumes that prefill queue will always be popped from the front - tlm_prefill_output, dlm_prefill_output = run_prefill_on_draft_and_target( + tlm_logits = run_prefill_on_draft_and_target( tlm_session=target_model_session, dlm_session=draft_model_session, - prompt=prompts_tokenized[bi], + inputs=prompts_tokenized[bi], prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - prefill_batch_size=prefill_bsz, - decode_batch_size=decode_batch_size, slot_idx=bi, ) - # this way, we will directly get the updated full batch input dict to run decode - populate_inputs(dlm_prefill_output, dlm_decode_inputs, bi) - populate_inputs(tlm_prefill_output, tlm_precode_inputs, bi) + input_ids = tlm_logits.argmax(2) + dlm_decode_inputs["input_ids"] = input_ids + input_len = prompts_tokenized[bi]['position_ids'].max(1).item() + 1 + dlm_decode_inputs["position_ids"][bi, 0] = input_len # assumes that prefill queue will always be popped from the front - input_lengths[bi] = tlm_prefill_output["input_len"] + input_lengths[bi] = input_len max_gen_len[bi] -= input_lengths[bi] target_model_session.set_buffers({"logits": 
precode_logits_ph}) @@ -266,8 +240,8 @@ def test_spec_decode_inference( processed_context = [len(generated_ids[j]) + input_lengths[j] for j in range(decode_batch_size)] # generate proposals from draft model if is_prefill: - draft_logits = [dlm_decode_inputs.pop("logits")] - target_logits = [tlm_precode_inputs.pop("logits")] + draft_logits = [tlm_logits] + target_logits = [tlm_logits] else: if np.any(all_accept): input_ids = [] @@ -296,6 +270,8 @@ def test_spec_decode_inference( # hence need to have dummy -1 position id for other sequences. # dlm_decode_inputs["position_ids"] = len(generated_ids per batch) # dlm_decode_inputs["input_ids"] = (last gen dlm token) + last true token from TLM + from pprint import pprint + pprint(f"{dlm_decode_inputs=}") for k_ in range(num_speculative_tokens): if dlm_run_bonus_token: #running decode one extra time in the first speculative iteration @@ -303,6 +279,7 @@ def test_spec_decode_inference( bonus_token_inputs, dlm_decode_inputs = split_dlm_bonus_token_inputs(dlm_decode_inputs) dlm_outputs = draft_model_session.run(bonus_token_inputs) dlm_run_bonus_token = False + pprint(f"{dlm_decode_inputs=}") dlm_outputs = draft_model_session.run(dlm_decode_inputs) draft_logits.append(dlm_outputs["logits"]) dlm_decode_inputs["input_ids"] = dlm_outputs["logits"].argmax(-1) @@ -328,6 +305,8 @@ def test_spec_decode_inference( ) # run precode on TLM to score the proposed tokens + pprint(f"{dlm_decode_inputs=}") + pprint(f"{tlm_precode_inputs=}") tlm_outputs = target_model_session.run(tlm_precode_inputs) target_precode_logits = tlm_outputs["logits"] if is_prefill: @@ -342,6 +321,7 @@ def test_spec_decode_inference( num_tokens_selected = np.argmin(matching, axis=1) all_accept = matching[np.arange(decode_batch_size), num_tokens_selected] num_tokens_selected = np.where(all_accept, matching.shape[1], num_tokens_selected) + print(num_tokens_selected) num_tokens_selected_per_validation.append(num_tokens_selected) # append selected tokens to the generated_ids From ba08731032b0f5783bccdc438bdb4d5aa601601c Mon Sep 17 00:00:00 2001 From: eplatero Date: Fri, 13 Dec 2024 13:29:30 -0600 Subject: [PATCH 04/12] first draft of script getting 100% acceptance rate when using same TLM/DLM Signed-off-by: eplatero --- tests/transformers/spd/test_spd_inference.py | 212 +++++++------------ 1 file changed, 77 insertions(+), 135 deletions(-) diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index eb27f6ef7..a635ef3d3 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -55,42 +55,11 @@ def run_prefill_on_draft_and_target( chunk_inputs["position_ids"] = inputs["position_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len] tlm_outputs = tlm_session.run(chunk_inputs) - dlm_outputs = dlm_session.run(chunk_inputs) + _ = dlm_session.run(chunk_inputs) cache_index += prefill_seq_len tlm_logits = tlm_outputs["logits"] return tlm_logits - dlm_logits = dlm_outputs["logits"] - assert (tlm_logits == dlm_logits).sum().all() - - if len(tlm_logits.shape) == 2: - tlm_logits = np.expand_dims(tlm_logits, 1) - if len(dlm_logits.shape) == 2: - dlm_logits = np.expand_dims(dlm_logits, 1) - - tlm_decode_start_pos_id = inputs["attention_mask"][0:1].sum(1, keepdims=True) - tlm_decode_start_input_id = tlm_logits.argmax(2) - dlm_decode_start_input_id = dlm_logits.argmax(2) - dlm_decode_start_pos_id = inputs["attention_mask"][0:1].sum(1, keepdims=True) - - inputs.pop("attention_mask") - - 
tlm_decode_start_input = { - "logits": tlm_logits, - "input_ids": tlm_decode_start_input_id, - "position_ids": tlm_decode_start_pos_id, - "batch_index": batch_index, - "input_len": tlm_decode_start_pos_id[0, 0], - } - dlm_decode_start_input = { - "logits": dlm_logits, - "input_ids": dlm_decode_start_input_id, - "position_ids": dlm_decode_start_pos_id, - "batch_index": batch_index, - "input_len": tlm_decode_start_pos_id[0, 0], - } - - return tlm_decode_start_input, dlm_decode_start_input def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int): @@ -113,10 +82,12 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int): def split_dlm_bonus_token_inputs(dlm_decode_inputs): bonus_token_inputs = dict() - bonus_token_inputs["input_ids"] = dlm_decode_inputs["input_ids"][:,0:1] - bonus_token_inputs["position_ids"] = dlm_decode_inputs["input_ids"][:,0:1] - dlm_decode_inputs["input_ids"] = dlm_decode_inputs["input_ids"][:,1:] - dlm_decode_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:,1:] + bonus, regular = np.hsplit(dlm_decode_inputs["input_ids"], 2) + bonus_token_inputs["input_ids"] = bonus + dlm_decode_inputs["input_ids"] = regular + bonus, regular = np.hsplit(dlm_decode_inputs["position_ids"], 2) + bonus_token_inputs["position_ids"] = bonus + dlm_decode_inputs["position_ids"] = regular return bonus_token_inputs, dlm_decode_inputs @pytest.mark.parametrize( @@ -134,7 +105,8 @@ def test_spec_decode_inference( full_batch_size: Optional[int], ): # get device group - device_group: List[int] = get_available_device_id() + #device_group: List[int] = get_available_device_id() + device_group: List[int] = [1] if not device_group: pytest.skip("No available devices to run model on Cloud AI 100") # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size @@ -201,11 +173,15 @@ def test_spec_decode_inference( np.array(np.arange(decode_batch_size), np.int64), (decode_batch_size, 1) ) # mock input key "logits" to store the first batch of output logits - tlm_precode_inputs = dict(dlm_decode_inputs) - is_prefill = True + tlm_precode_inputs = dict( + input_ids = np.zeros((decode_batch_size, num_speculative_tokens+1), dtype=np.int64), + position_ids = np.zeros((decode_batch_size, num_speculative_tokens+1), dtype=np.int64), + batch_index = np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1) + ) generation_done = False max_gen_len = [ctx_len] * decode_batch_size num_logits_to_keep = num_speculative_tokens+1 + # setup buffers all_accept = np.full((decode_batch_size, num_speculative_tokens), False, dtype=bool) tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) dlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) @@ -223,130 +199,96 @@ def test_spec_decode_inference( prefill_seq_len=prefill_seq_len, slot_idx=bi, ) - input_ids = tlm_logits.argmax(2) + input_ids = tlm_logits.argmax(2).astype(np.int64) dlm_decode_inputs["input_ids"] = input_ids + tlm_precode_inputs["input_ids"][bi, 0] = input_ids.item() input_len = prompts_tokenized[bi]['position_ids'].max(1).item() + 1 dlm_decode_inputs["position_ids"][bi, 0] = input_len + tlm_precode_inputs["position_ids"][bi] = np.arange(input_len, input_len+num_speculative_tokens+1, dtype=np.int64) # assumes that prefill queue will always be popped from the front input_lengths[bi] = input_len max_gen_len[bi] -= input_lengths[bi] + # set decode logits buffers target_model_session.set_buffers({"logits": 
precode_logits_ph}) draft_model_session.set_buffers({"logits": decode_logits_ph}) + # start decode phase + valid_batch_indices = list(range(decode_batch_size))[::-1] num_tokens_selected_per_validation = [] - dlm_run_bonus_token = False - while not generation_done: - # compute the processed context length before each iteration to prepare the position id inputs - processed_context = [len(generated_ids[j]) + input_lengths[j] for j in range(decode_batch_size)] + all_accept = False + it = 0 + break_idx = 1000 + while True: + print('-'*60) + print(f"{it=}") # generate proposals from draft model - if is_prefill: - draft_logits = [tlm_logits] - target_logits = [tlm_logits] - else: - if np.any(all_accept): - input_ids = [] - position_ids = [] - dlm_run_bonus_token = True - for bi in range(decode_batch_size): - if all_accept[bi]: - # both last DLM token and bonus TLM token to be passed as input to DLM - input_ids.append([generated_ids[bi][-2], generated_ids[bi][-1]]) - position_ids.append([processed_context[bi] - 2, processed_context[bi] - 1]) - else: - # only the correct token from TLM from previous iteration and the pad_token as a dummy - input_ids.append([generated_ids[bi][-1], tokenizer.pad_token_id]) - position_ids.append([processed_context[bi] - 1, -1]) - dlm_decode_inputs["input_ids"] = np.array(input_ids) - dlm_decode_inputs["position_ids"] = np.array(position_ids) - else: - dlm_decode_inputs["input_ids"] = np.array([gid[-1] for gid in generated_ids], dtype=np.int64).reshape( - (decode_batch_size, 1) - ) - dlm_decode_inputs["position_ids"] = np.array( - [(pc - 1) for pc in processed_context], dtype=np.int64 - ).reshape((decode_batch_size, 1)) - # prepare the inputs for the dlm speculation - # TODO in case of even one of the batch having all_accept, we have to use the seqlen=2 specialization - # hence need to have dummy -1 position id for other sequences. 
- # dlm_decode_inputs["position_ids"] = len(generated_ids per batch) - # dlm_decode_inputs["input_ids"] = (last gen dlm token) + last true token from TLM - from pprint import pprint - pprint(f"{dlm_decode_inputs=}") for k_ in range(num_speculative_tokens): - if dlm_run_bonus_token: + if all_accept: #running decode one extra time in the first speculative iteration # workaround to avoid the incorrect precode with 3-specialized multi-batch DLM bonus_token_inputs, dlm_decode_inputs = split_dlm_bonus_token_inputs(dlm_decode_inputs) - dlm_outputs = draft_model_session.run(bonus_token_inputs) - dlm_run_bonus_token = False - pprint(f"{dlm_decode_inputs=}") + pprint(f"{bonus_token_inputs=}") + _ = draft_model_session.run(bonus_token_inputs) dlm_outputs = draft_model_session.run(dlm_decode_inputs) - draft_logits.append(dlm_outputs["logits"]) - dlm_decode_inputs["input_ids"] = dlm_outputs["logits"].argmax(-1) - dlm_decode_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:, -1:] + 1 - - draft_logits = np.array(draft_logits).squeeze(2).transpose((1, 0, 2)) - # greedy sampling from draft model - draft_tokens = draft_logits.argmax(-1) - - # construct precode inputs - tlm_precode_inputs["input_ids"] = draft_tokens - if not is_prefill: - last_genid = np.array([gid[-1] for gid in generated_ids], dtype=np.int64).reshape(decode_batch_size, 1) - tlm_precode_inputs["input_ids"] = np.concatenate((last_genid, tlm_precode_inputs["input_ids"]), axis=1) - # in case of general precode, first token in input sequence is = last generated TLM token (kv cache backfill) - tlm_precode_inputs["position_ids"] = np.array( - [np.arange(start=pc - 1, stop=pc + num_speculative_tokens) for pc in processed_context], dtype=np.int64 - ) - else: - # in case of just first precode, we are feeding in all new positions - tlm_precode_inputs["position_ids"] = np.array( - [np.arange(start=pc, stop=pc + num_speculative_tokens + 1) for pc in processed_context], dtype=np.int64 - ) - + pprint(f"{dlm_decode_inputs=}") + pprint(f"{tlm_precode_inputs=}") + input_ids = dlm_outputs["logits"].argmax(2) + tlm_precode_inputs["input_ids"][:, k_+1] = input_ids.flatten() + dlm_decode_inputs["input_ids"] = input_ids + dlm_decode_inputs["position_ids"][valid_batch_indices] += 1 # run precode on TLM to score the proposed tokens pprint(f"{dlm_decode_inputs=}") pprint(f"{tlm_precode_inputs=}") + if it == break_idx: breakpoint() tlm_outputs = target_model_session.run(tlm_precode_inputs) - target_precode_logits = tlm_outputs["logits"] - if is_prefill: - target_logits = np.concatenate((target_logits[0], target_precode_logits), axis=1) - # stack the prefill output logit and precode logits into a single tensor - else: - target_logits = target_precode_logits + target_logits = tlm_outputs["logits"][valid_batch_indices] # greedy sampling from target model - target_tokens = target_logits.argmax(-1) + target_tokens = target_logits.argmax(-1) # shape: [len(valid_batch_indices), num_speculative_tokens+1] # exact matching between draft and target tokens + draft_tokens = tlm_precode_inputs["input_ids"][valid_batch_indices,1:] # shape: [len(valid_batch_indices), num_speculative_tokens] matching = draft_tokens == target_tokens[:, :-1] - num_tokens_selected = np.argmin(matching, axis=1) - all_accept = matching[np.arange(decode_batch_size), num_tokens_selected] - num_tokens_selected = np.where(all_accept, matching.shape[1], num_tokens_selected) - print(num_tokens_selected) + num_tokens_selected = matching.cumprod(axis=1).sum(axis=1) # shape: [len(valid_batch_indices)] 
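+ # note (editor's comment): cumprod(axis=1) zeroes everything after the first mismatch,
+ # so the row-wise sum counts the longest accepted prefix of the draft proposals --
+ # the standard greedy acceptance rule for speculative decoding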
+ all_accept = (num_tokens_selected == num_speculative_tokens).all() num_tokens_selected_per_validation.append(num_tokens_selected) + print(f"{target_tokens=}") + print(f"{num_tokens_selected=}") + print(f"{all_accept=}") + if it == break_idx: breakpoint() + if num_tokens_selected.item() == 0: breakpoint() # append selected tokens to the generated_ids - for bi in range(decode_batch_size): - if len(generated_ids[bi]) >= max_gen_len[bi]: - continue - num_tokens_to_append = min(num_tokens_selected[bi], max_gen_len[bi] - len(generated_ids[bi])) - generated_ids[bi] += list(draft_tokens[bi, :num_tokens_to_append]) - # append bonus/corrected token where applicable - for bi in range(decode_batch_size): + for bi in valid_batch_indices: + accepted_tokens = num_tokens_selected[bi] + if all_accept: + accepted_tokens += 1 + num_tokens_to_append = min(accepted_tokens, max_gen_len[bi] - len(generated_ids[bi])) + generated_ids[bi].extend(target_tokens[bi, :num_tokens_to_append].tolist()) if len(generated_ids[bi]) >= max_gen_len[bi]: + del valid_batch_indices[bi] continue - if all_accept[bi]: - # bonus token - generated_ids[bi].append(target_tokens[bi, -1]) - else: - # correct token - generated_ids[bi].append(target_tokens[bi, num_tokens_selected[bi]]) - generation_done = True - for bi in range(decode_batch_size): - if len(generated_ids[bi]) < max_gen_len[bi]: - generation_done = False - is_prefill = False - draft_logits = [] - target_logits = [] + # check if all generations are done + if not valid_batch_indices: break + # prepare decode inputs for next decode iteration + common_input_ids = target_tokens[:, num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1) + common_position_ids = tlm_precode_inputs["position_ids"][valid_batch_indices, num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1)+1 + if all_accept: + # all_accept input_ids + input_ids = np.zeros((decode_batch_size, 2), dtype=np.int64) + input_ids[valid_batch_indices] = np.concatenate([target_tokens[:, num_tokens_selected-1].reshape(-1,1), common_input_ids], axis=1) + dlm_decode_inputs["input_ids"] = input_ids + # all_accept position_ids + position_ids = np.zeros((decode_batch_size, 2), dtype=np.int64) + position_ids[valid_batch_indices] = np.concatenate([tlm_precode_inputs["position_ids"][valid_batch_indices, num_tokens_selected-1].reshape(-1,1)+1, common_position_ids], axis=1) + dlm_decode_inputs["position_ids"] = position_ids + else: + dlm_decode_inputs["input_ids"][valid_batch_indices] = common_input_ids + dlm_decode_inputs["position_ids"][valid_batch_indices] = common_position_ids + tlm_precode_inputs["input_ids"][valid_batch_indices,0] = common_input_ids + tlm_precode_inputs["position_ids"][valid_batch_indices] += num_tokens_selected.reshape(len(valid_batch_indices),1)+1 + pprint(dlm_decode_inputs) + pprint(tlm_precode_inputs) + if it == break_idx: breakpoint() + it += 1 num_tokens_selected_per_validation = np.concatenate(num_tokens_selected_per_validation).reshape(len(num_tokens_selected_per_validation), decode_batch_size) mean_num_accepted_tokens_per_batch = num_tokens_selected_per_validation.mean(axis=0) print("mean number of accepted tokens per batch = ", mean_num_accepted_tokens_per_batch) From c9abb94a8af5deb3ff08b52116ee9dc75159a8fb Mon Sep 17 00:00:00 2001 From: eplatero Date: Sat, 14 Dec 2024 11:42:05 -0600 Subject: [PATCH 05/12] validation with full_batch_size=2 is passing Signed-off-by: eplatero --- tests/transformers/spd/test_spd_inference.py | 35 +++++++++++--------- 1 file 
changed, 20 insertions(+), 15 deletions(-) diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index a635ef3d3..866f63c9b 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -16,17 +16,18 @@ from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils.device_utils import get_available_device_id +fbs = 2 configs = [ pytest.param( #["My name is", "Hello", "Hi", "My name is"], # prompt - ["My name is"], # prompt + ["My name is"]*fbs, # prompt 1, # num_speculative_tokens 32, # prefill_seq_len 128, # ctx_len 1, # prefill_bsz "JackFram/llama-68m", # draft_model_name "JackFram/llama-68m", # target_model_name - 1, # full_batch_size + fbs, # full_batch_size id="CB llama", ), ] @@ -40,8 +41,6 @@ def run_prefill_on_draft_and_target( prefill_seq_len: int, slot_idx: int ): - tlm_decode_start_input = dict() - dlm_decode_start_input = dict() input_len = inputs.input_ids.shape[1] num_chunks = input_len // prefill_seq_len cache_index = np.array([[0]], np.int64) @@ -88,6 +87,7 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs): bonus, regular = np.hsplit(dlm_decode_inputs["position_ids"], 2) bonus_token_inputs["position_ids"] = bonus dlm_decode_inputs["position_ids"] = regular + bonus_token_inputs["batch_index"] = dlm_decode_inputs["batch_index"] return bonus_token_inputs, dlm_decode_inputs @pytest.mark.parametrize( @@ -178,7 +178,6 @@ def test_spec_decode_inference( position_ids = np.zeros((decode_batch_size, num_speculative_tokens+1), dtype=np.int64), batch_index = np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1) ) - generation_done = False max_gen_len = [ctx_len] * decode_batch_size num_logits_to_keep = num_speculative_tokens+1 # setup buffers @@ -200,7 +199,7 @@ def test_spec_decode_inference( slot_idx=bi, ) input_ids = tlm_logits.argmax(2).astype(np.int64) - dlm_decode_inputs["input_ids"] = input_ids + dlm_decode_inputs["input_ids"][bi, 0] = input_ids tlm_precode_inputs["input_ids"][bi, 0] = input_ids.item() input_len = prompts_tokenized[bi]['position_ids'].max(1).item() + 1 dlm_decode_inputs["position_ids"][bi, 0] = input_len @@ -217,10 +216,12 @@ def test_spec_decode_inference( num_tokens_selected_per_validation = [] all_accept = False it = 0 - break_idx = 1000 + #break_idx = 61 + break_idx = 10000 while True: print('-'*60) print(f"{it=}") + print(f"{valid_batch_indices=}") # generate proposals from draft model for k_ in range(num_speculative_tokens): if all_accept: @@ -254,9 +255,10 @@ def test_spec_decode_inference( print(f"{num_tokens_selected=}") print(f"{all_accept=}") if it == break_idx: breakpoint() - if num_tokens_selected.item() == 0: breakpoint() + if not all_accept: breakpoint() # append selected tokens to the generated_ids + indices_to_rm = [] for bi in valid_batch_indices: accepted_tokens = num_tokens_selected[bi] if all_accept: @@ -264,27 +266,30 @@ def test_spec_decode_inference( num_tokens_to_append = min(accepted_tokens, max_gen_len[bi] - len(generated_ids[bi])) generated_ids[bi].extend(target_tokens[bi, :num_tokens_to_append].tolist()) if len(generated_ids[bi]) >= max_gen_len[bi]: - del valid_batch_indices[bi] - continue + indices_to_rm.append(bi) + for idx in indices_to_rm: + del valid_batch_indices[idx] + print(f"{valid_batch_indices=}") # check if all generations are done if not valid_batch_indices: break # prepare decode inputs for next decode iteration - common_input_ids = target_tokens[:, 
num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1) + common_input_ids = target_tokens[np.arange(len(valid_batch_indices)), num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1) common_position_ids = tlm_precode_inputs["position_ids"][valid_batch_indices, num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1)+1 if all_accept: # all_accept input_ids input_ids = np.zeros((decode_batch_size, 2), dtype=np.int64) - input_ids[valid_batch_indices] = np.concatenate([target_tokens[:, num_tokens_selected-1].reshape(-1,1), common_input_ids], axis=1) + #if it == 61: breakpoint() + input_ids[valid_batch_indices] = np.concatenate([target_tokens[np.arange(len(valid_batch_indices)), num_tokens_selected[valid_batch_indices]-1].reshape(-1,1), common_input_ids], axis=1) dlm_decode_inputs["input_ids"] = input_ids # all_accept position_ids position_ids = np.zeros((decode_batch_size, 2), dtype=np.int64) - position_ids[valid_batch_indices] = np.concatenate([tlm_precode_inputs["position_ids"][valid_batch_indices, num_tokens_selected-1].reshape(-1,1)+1, common_position_ids], axis=1) + position_ids[valid_batch_indices] = np.concatenate([tlm_precode_inputs["position_ids"][valid_batch_indices, num_tokens_selected[valid_batch_indices]-1].reshape(-1,1)+1, common_position_ids], axis=1) dlm_decode_inputs["position_ids"] = position_ids else: dlm_decode_inputs["input_ids"][valid_batch_indices] = common_input_ids dlm_decode_inputs["position_ids"][valid_batch_indices] = common_position_ids - tlm_precode_inputs["input_ids"][valid_batch_indices,0] = common_input_ids - tlm_precode_inputs["position_ids"][valid_batch_indices] += num_tokens_selected.reshape(len(valid_batch_indices),1)+1 + tlm_precode_inputs["input_ids"][valid_batch_indices,0] = common_input_ids.flatten() + tlm_precode_inputs["position_ids"][valid_batch_indices] += num_tokens_selected[valid_batch_indices].reshape(len(valid_batch_indices),1)+1 pprint(dlm_decode_inputs) pprint(tlm_precode_inputs) if it == break_idx: breakpoint() From 8a1c47f960b3a7415245c67db86b3b2717d4cff4 Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 16 Dec 2024 14:59:55 -0600 Subject: [PATCH 06/12] solved bugs with batch_size>1 and num_spec_tokens>1. 
1 final bug remaining

Signed-off-by: eplatero
---
 tests/transformers/spd/test_spd_inference.py | 94 +++++++++++++++++---
 1 file changed, 80 insertions(+), 14 deletions(-)

diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py
index 866f63c9b..16489b9c2 100644
--- a/tests/transformers/spd/test_spd_inference.py
+++ b/tests/transformers/spd/test_spd_inference.py
@@ -7,6 +7,7 @@

 from typing import List, Optional
 from pprint import pprint
+from time import perf_counter

 import numpy as np
 import pytest
@@ -14,18 +15,23 @@

 from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM
 from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils.constants import Constants
 from QEfficient.utils.device_utils import get_available_device_id

-fbs = 2
+#fbs = 4 # passed with spec_len=1
+fbs = 4
 configs = [
     pytest.param(
         #["My name is", "Hello", "Hi", "My name is"], # prompt
-        ["My name is"]*fbs, # prompt
-        1, # num_speculative_tokens
+        #Constants.INPUT_STR*fbs, # prompt
+        ['hello', 'hi', 'hola', 'bonjour'], # prompt
+        4, # num_speculative_tokens
         32, # prefill_seq_len
         128, # ctx_len
         1, # prefill_bsz
-        "JackFram/llama-68m", # draft_model_name
+#        "JackFram/llama-68m", # draft_model_name
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name
-        "JackFram/llama-68m", # target_model_name
+#        "JackFram/llama-68m", # target_model_name
         fbs, # full_batch_size
         id="CB llama",
     ),
 ]
@@ -90,6 +96,33 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs):
     bonus_token_inputs["batch_index"] = dlm_decode_inputs["batch_index"]
     return bonus_token_inputs, dlm_decode_inputs

+
+def compare_pkvs(dlm_outputs, tlm_outputs, valid_batch_indices):
+    for key in dlm_outputs:
+        if "past" not in key: continue
+        dlm_output = dlm_outputs[key][valid_batch_indices]
+        tlm_output = tlm_outputs[key][valid_batch_indices]
+        array_equal = np.array_equal(dlm_output, tlm_output)
+        if not array_equal:
+            print(f"{key} do NOT match!")
+
+def compare_idx_pkvs(dlm_outputs, tlm_outputs, valid_batch_indices, idx):
+    for key in dlm_outputs:
+        if "past" not in key: continue
+        dlm_output = dlm_outputs[key][valid_batch_indices, :, idx]
+        tlm_output = tlm_outputs[key][valid_batch_indices, :, idx]
+        array_equal = np.array_equal(dlm_output, tlm_output)
+        if not array_equal:
+            a = dlm_output.flatten()
+            b = tlm_output.flatten()
+            scalar = np.dot(a, b)
+            a_mag = np.sqrt(np.dot(a,a))
+            b_mag = np.sqrt(np.dot(b,b))
+            sim = scalar / (a_mag * b_mag)
+            print(f"{key} do NOT match! 
Similary sore: {sim}") + + + @pytest.mark.parametrize( "prompt, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size", configs, @@ -106,7 +139,7 @@ def test_spec_decode_inference( ): # get device group #device_group: List[int] = get_available_device_id() - device_group: List[int] = [1] + device_group: List[int] = [31] if not device_group: pytest.skip("No available devices to run model on Cloud AI 100") # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size @@ -140,11 +173,11 @@ def test_spec_decode_inference( # skip inputs/outputs buffers target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")])) - target_model_session.skip_buffers( - set([x for x in target_model_session.output_names if x.endswith("_RetainedState")]) - ) +# target_model_session.skip_buffers( +# set([x for x in target_model_session.output_names if x.endswith("_RetainedState")]) +# ) draft_model_session.skip_buffers(set([x for x in draft_model_session.input_names if x.startswith("past_")])) - draft_model_session.skip_buffers(set([x for x in draft_model_session.output_names if x.endswith("_RetainedState")])) + #draft_model_session.skip_buffers(set([x for x in draft_model_session.output_names if x.endswith("_RetainedState")])) is_cb = full_batch_size is not None if not is_cb: @@ -189,8 +222,11 @@ def test_spec_decode_inference( target_model_session.set_buffers({"logits": tlm_prefill_logits_ph}) draft_model_session.set_buffers({"logits": dlm_prefill_logits_ph}) + e2e_start = perf_counter() + ttfts = [] for bi in range(decode_batch_size): # assumes that prefill queue will always be popped from the front + start = perf_counter() tlm_logits = run_prefill_on_draft_and_target( tlm_session=target_model_session, dlm_session=draft_model_session, @@ -198,6 +234,8 @@ def test_spec_decode_inference( prefill_seq_len=prefill_seq_len, slot_idx=bi, ) + ttft = perf_counter() - start + ttfts.append(ttft) input_ids = tlm_logits.argmax(2).astype(np.int64) dlm_decode_inputs["input_ids"][bi, 0] = input_ids tlm_precode_inputs["input_ids"][bi, 0] = input_ids.item() @@ -208,6 +246,9 @@ def test_spec_decode_inference( input_lengths[bi] = input_len max_gen_len[bi] -= input_lengths[bi] + print('# PREFILL') + print(f"{dlm_decode_inputs=}") + print(f"{tlm_precode_inputs=}") # set decode logits buffers target_model_session.set_buffers({"logits": precode_logits_ph}) draft_model_session.set_buffers({"logits": decode_logits_ph}) @@ -218,6 +259,8 @@ def test_spec_decode_inference( it = 0 #break_idx = 61 break_idx = 10000 + print('# DECODE') + decode_start = perf_counter() while True: print('-'*60) print(f"{it=}") @@ -230,6 +273,8 @@ def test_spec_decode_inference( bonus_token_inputs, dlm_decode_inputs = split_dlm_bonus_token_inputs(dlm_decode_inputs) pprint(f"{bonus_token_inputs=}") _ = draft_model_session.run(bonus_token_inputs) + all_accept = False + compare_pkvs(_, tlm_outputs, valid_batch_indices) dlm_outputs = draft_model_session.run(dlm_decode_inputs) pprint(f"{dlm_decode_inputs=}") pprint(f"{tlm_precode_inputs=}") @@ -241,6 +286,7 @@ def test_spec_decode_inference( pprint(f"{dlm_decode_inputs=}") pprint(f"{tlm_precode_inputs=}") if it == break_idx: breakpoint() + #if it == 41: breakpoint() tlm_outputs = target_model_session.run(tlm_precode_inputs) target_logits = tlm_outputs["logits"][valid_batch_indices] # greedy sampling from target model @@ -251,13 +297,17 @@ def 
test_spec_decode_inference( num_tokens_selected = matching.cumprod(axis=1).sum(axis=1) # shape: [len(valid_batch_indices)] all_accept = (num_tokens_selected == num_speculative_tokens).all() num_tokens_selected_per_validation.append(num_tokens_selected) + print(f"{draft_tokens=}") print(f"{target_tokens=}") print(f"{num_tokens_selected=}") print(f"{all_accept=}") + if it == 41: + compare_idx_pkvs(dlm_outputs, tlm_outputs, valid_batch_indices, 127) if it == break_idx: breakpoint() if not all_accept: breakpoint() # append selected tokens to the generated_ids + tlm_precode_position_ids = tlm_precode_inputs["position_ids"][valid_batch_indices] + num_tokens_selected[valid_batch_indices].reshape(len(valid_batch_indices),1)+1 indices_to_rm = [] for bi in valid_batch_indices: accepted_tokens = num_tokens_selected[bi] @@ -265,7 +315,9 @@ def test_spec_decode_inference( accepted_tokens += 1 num_tokens_to_append = min(accepted_tokens, max_gen_len[bi] - len(generated_ids[bi])) generated_ids[bi].extend(target_tokens[bi, :num_tokens_to_append].tolist()) - if len(generated_ids[bi]) >= max_gen_len[bi]: + # position_ids >= ctx_len-1 result in erronous output for logits at each seq_len of TLM + # (e.g., ctx_len=128 -> position_ids=[127,128,129] will give erronous output at each predicted token) + if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] >= ctx_len-1).any(): indices_to_rm.append(bi) for idx in indices_to_rm: del valid_batch_indices[idx] @@ -294,9 +346,23 @@ def test_spec_decode_inference( pprint(tlm_precode_inputs) if it == break_idx: breakpoint() it += 1 + end = perf_counter() + decode_end = end - decode_start + e2e_end = end - e2e_start num_tokens_selected_per_validation = np.concatenate(num_tokens_selected_per_validation).reshape(len(num_tokens_selected_per_validation), decode_batch_size) - mean_num_accepted_tokens_per_batch = num_tokens_selected_per_validation.mean(axis=0) - print("mean number of accepted tokens per batch = ", mean_num_accepted_tokens_per_batch) - print("max generation len = ", max_gen_len) - print("actual generation len = ", [len(gid) for gid in generated_ids]) - print(tokenizer.batch_decode(generated_ids)) + mean_num_accepted_tokens_per_prompt = num_tokens_selected_per_validation.mean(axis=0) + mean_num_accepted_tokens = mean_num_accepted_tokens_per_prompt.mean() + mean_ttft = sum(ttfts) / len(ttfts) + generated_tokens_per_prompt = [len(gid)+1 for gid in generated_ids] + decode_throughput = sum(generated_tokens_per_prompt) / decode_end + e2e_throughput = (sum(generated_tokens_per_prompt)+decode_batch_size) / e2e_end + batch_decode = tokenizer.batch_decode(generated_ids) + print(f"Avg TLM+DLM TTFT = {mean_ttft}") + print(f"Decode Throughput = {decode_throughput}") + print(f"E2E Throughput = {e2e_throughput}") + print("Avg number of accepted tokens per prompt = ", mean_num_accepted_tokens_per_prompt) + print("Avg number of accepted tokens = ", mean_num_accepted_tokens) + print("Max generation len = ", max_gen_len) + print("Total Generated Tokens per Prompt: = ", generated_tokens_per_prompt) + for prompt,generation in zip(prompts, batch_decode): + print(f"{prompt=} {generation=}") From fd2c21bbb43ed959b6b2d6a98c08845aafb2c2e1 Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 16 Dec 2024 15:51:17 -0600 Subject: [PATCH 07/12] resolved some bugs Signed-off-by: eplatero --- tests/transformers/spd/test_spd_inference.py | 39 ++++++++++---------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/tests/transformers/spd/test_spd_inference.py 
b/tests/transformers/spd/test_spd_inference.py index 16489b9c2..4ed1ddf07 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -29,10 +29,10 @@ 32, # prefill_seq_len 128, # ctx_len 1, # prefill_bsz - "JackFram/llama-68m", # draft_model_name -# "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name -# "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name - "JackFram/llama-68m", # target_model_name +# "JackFram/llama-68m", # draft_model_name + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name +# "JackFram/llama-68m", # target_model_name fbs, # full_batch_size id="CB llama", ), @@ -254,14 +254,15 @@ def test_spec_decode_inference( draft_model_session.set_buffers({"logits": decode_logits_ph}) # start decode phase valid_batch_indices = list(range(decode_batch_size))[::-1] - num_tokens_selected_per_validation = [] all_accept = False it = 0 #break_idx = 61 break_idx = 10000 print('# DECODE') decode_start = perf_counter() + mean_num_accepted_tokens = 0 while True: + it += 1 print('-'*60) print(f"{it=}") print(f"{valid_batch_indices=}") @@ -288,17 +289,19 @@ def test_spec_decode_inference( if it == break_idx: breakpoint() #if it == 41: breakpoint() tlm_outputs = target_model_session.run(tlm_precode_inputs) - target_logits = tlm_outputs["logits"][valid_batch_indices] + target_logits = tlm_outputs["logits"] + #target_logits = tlm_outputs["logits"][valid_batch_indices] # greedy sampling from target model - target_tokens = target_logits.argmax(-1) # shape: [len(valid_batch_indices), num_speculative_tokens+1] + target_tokens = target_logits.argmax(-1) # shape: [decode_batch_size, num_speculative_tokens+1] # exact matching between draft and target tokens draft_tokens = tlm_precode_inputs["input_ids"][valid_batch_indices,1:] # shape: [len(valid_batch_indices), num_speculative_tokens] - matching = draft_tokens == target_tokens[:, :-1] + matching = draft_tokens == target_tokens[valid_batch_indices, :-1] num_tokens_selected = matching.cumprod(axis=1).sum(axis=1) # shape: [len(valid_batch_indices)] all_accept = (num_tokens_selected == num_speculative_tokens).all() - num_tokens_selected_per_validation.append(num_tokens_selected) + mean_num_accepted_tokens += num_tokens_selected.mean().item() + #num_tokens_selected_per_validation.append(num_tokens_selected) print(f"{draft_tokens=}") - print(f"{target_tokens=}") + print(f"{target_tokens[valid_batch_indices]=}") print(f"{num_tokens_selected=}") print(f"{all_accept=}") if it == 41: @@ -315,9 +318,10 @@ def test_spec_decode_inference( accepted_tokens += 1 num_tokens_to_append = min(accepted_tokens, max_gen_len[bi] - len(generated_ids[bi])) generated_ids[bi].extend(target_tokens[bi, :num_tokens_to_append].tolist()) - # position_ids >= ctx_len-1 result in erronous output for logits at each seq_len of TLM + # position_ids > ctx_len-1 result in erronous output for logits at each seq_len of TLM # (e.g., ctx_len=128 -> position_ids=[127,128,129] will give erronous output at each predicted token) - if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] >= ctx_len-1).any(): + #if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] > ctx_len-1).any(): + if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] > ctx_len-1).any(): indices_to_rm.append(bi) for idx in indices_to_rm: del valid_batch_indices[idx] @@ -325,13 +329,14 @@ def test_spec_decode_inference( # check if all 
generations are done if not valid_batch_indices: break # prepare decode inputs for next decode iteration - common_input_ids = target_tokens[np.arange(len(valid_batch_indices)), num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1) + common_input_ids = target_tokens[np.asarray(valid_batch_indices), num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1) + #common_input_ids = target_tokens[np.arange(len(valid_batch_indices)), num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1) common_position_ids = tlm_precode_inputs["position_ids"][valid_batch_indices, num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1)+1 if all_accept: # all_accept input_ids input_ids = np.zeros((decode_batch_size, 2), dtype=np.int64) #if it == 61: breakpoint() - input_ids[valid_batch_indices] = np.concatenate([target_tokens[np.arange(len(valid_batch_indices)), num_tokens_selected[valid_batch_indices]-1].reshape(-1,1), common_input_ids], axis=1) + input_ids[valid_batch_indices] = np.concatenate([target_tokens[np.asarray(valid_batch_indices), num_tokens_selected[valid_batch_indices]-1].reshape(-1,1), common_input_ids], axis=1) dlm_decode_inputs["input_ids"] = input_ids # all_accept position_ids position_ids = np.zeros((decode_batch_size, 2), dtype=np.int64) @@ -345,22 +350,18 @@ def test_spec_decode_inference( pprint(dlm_decode_inputs) pprint(tlm_precode_inputs) if it == break_idx: breakpoint() - it += 1 end = perf_counter() decode_end = end - decode_start e2e_end = end - e2e_start - num_tokens_selected_per_validation = np.concatenate(num_tokens_selected_per_validation).reshape(len(num_tokens_selected_per_validation), decode_batch_size) - mean_num_accepted_tokens_per_prompt = num_tokens_selected_per_validation.mean(axis=0) - mean_num_accepted_tokens = mean_num_accepted_tokens_per_prompt.mean() mean_ttft = sum(ttfts) / len(ttfts) generated_tokens_per_prompt = [len(gid)+1 for gid in generated_ids] decode_throughput = sum(generated_tokens_per_prompt) / decode_end e2e_throughput = (sum(generated_tokens_per_prompt)+decode_batch_size) / e2e_end batch_decode = tokenizer.batch_decode(generated_ids) + mean_num_accepted_tokens /= it print(f"Avg TLM+DLM TTFT = {mean_ttft}") print(f"Decode Throughput = {decode_throughput}") print(f"E2E Throughput = {e2e_throughput}") - print("Avg number of accepted tokens per prompt = ", mean_num_accepted_tokens_per_prompt) print("Avg number of accepted tokens = ", mean_num_accepted_tokens) print("Max generation len = ", max_gen_len) print("Total Generated Tokens per Prompt: = ", generated_tokens_per_prompt) From 86c4952709137b7738a0003d1dd2e024af4e0fbd Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 16 Dec 2024 16:44:53 -0600 Subject: [PATCH 08/12] rm most debug logs Signed-off-by: eplatero --- tests/transformers/spd/test_spd_inference.py | 54 +++----------------- 1 file changed, 8 insertions(+), 46 deletions(-) diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index 4ed1ddf07..6d8aa8173 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -18,22 +18,16 @@ from QEfficient.utils.constants import Constants from QEfficient.utils.device_utils import get_available_device_id -#fbs = 4 # passed with spec_len=1 -fbs = 4 configs = [ pytest.param( - #["My name is", "Hello", "Hi", "My name is"], # prompt - #Constants.INPUT_STR*fbs, # prompt ['hello', 'hi', 'hola', 'bonjour'], # prompt 4, # 
num_speculative_tokens 32, # prefill_seq_len 128, # ctx_len 1, # prefill_bsz -# "JackFram/llama-68m", # draft_model_name "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name -# "JackFram/llama-68m", # target_model_name - fbs, # full_batch_size + 4, # full_batch_size id="CB llama", ), ] @@ -139,7 +133,8 @@ def test_spec_decode_inference( ): # get device group #device_group: List[int] = get_available_device_id() - device_group: List[int] = [31] + device_group: List[int] = [0] + #device_group: List[int] = [31] if not device_group: pytest.skip("No available devices to run model on Cloud AI 100") # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size @@ -173,11 +168,11 @@ def test_spec_decode_inference( # skip inputs/outputs buffers target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")])) -# target_model_session.skip_buffers( -# set([x for x in target_model_session.output_names if x.endswith("_RetainedState")]) -# ) + target_model_session.skip_buffers( + set([x for x in target_model_session.output_names if x.endswith("_RetainedState")]) + ) draft_model_session.skip_buffers(set([x for x in draft_model_session.input_names if x.startswith("past_")])) - #draft_model_session.skip_buffers(set([x for x in draft_model_session.output_names if x.endswith("_RetainedState")])) + draft_model_session.skip_buffers(set([x for x in draft_model_session.output_names if x.endswith("_RetainedState")])) is_cb = full_batch_size is not None if not is_cb: @@ -246,9 +241,6 @@ def test_spec_decode_inference( input_lengths[bi] = input_len max_gen_len[bi] -= input_lengths[bi] - print('# PREFILL') - print(f"{dlm_decode_inputs=}") - print(f"{tlm_precode_inputs=}") # set decode logits buffers target_model_session.set_buffers({"logits": precode_logits_ph}) draft_model_session.set_buffers({"logits": decode_logits_ph}) @@ -256,41 +248,27 @@ def test_spec_decode_inference( valid_batch_indices = list(range(decode_batch_size))[::-1] all_accept = False it = 0 - #break_idx = 61 - break_idx = 10000 - print('# DECODE') decode_start = perf_counter() mean_num_accepted_tokens = 0 while True: it += 1 - print('-'*60) - print(f"{it=}") - print(f"{valid_batch_indices=}") # generate proposals from draft model for k_ in range(num_speculative_tokens): if all_accept: #running decode one extra time in the first speculative iteration # workaround to avoid the incorrect precode with 3-specialized multi-batch DLM bonus_token_inputs, dlm_decode_inputs = split_dlm_bonus_token_inputs(dlm_decode_inputs) - pprint(f"{bonus_token_inputs=}") _ = draft_model_session.run(bonus_token_inputs) all_accept = False compare_pkvs(_, tlm_outputs, valid_batch_indices) dlm_outputs = draft_model_session.run(dlm_decode_inputs) - pprint(f"{dlm_decode_inputs=}") - pprint(f"{tlm_precode_inputs=}") input_ids = dlm_outputs["logits"].argmax(2) tlm_precode_inputs["input_ids"][:, k_+1] = input_ids.flatten() dlm_decode_inputs["input_ids"] = input_ids dlm_decode_inputs["position_ids"][valid_batch_indices] += 1 # run precode on TLM to score the proposed tokens - pprint(f"{dlm_decode_inputs=}") - pprint(f"{tlm_precode_inputs=}") - if it == break_idx: breakpoint() - #if it == 41: breakpoint() tlm_outputs = target_model_session.run(tlm_precode_inputs) target_logits = tlm_outputs["logits"] - #target_logits = tlm_outputs["logits"][valid_batch_indices] # greedy sampling from target model target_tokens = 
target_logits.argmax(-1) # shape: [decode_batch_size, num_speculative_tokens+1] # exact matching between draft and target tokens @@ -299,16 +277,7 @@ def test_spec_decode_inference( num_tokens_selected = matching.cumprod(axis=1).sum(axis=1) # shape: [len(valid_batch_indices)] all_accept = (num_tokens_selected == num_speculative_tokens).all() mean_num_accepted_tokens += num_tokens_selected.mean().item() - #num_tokens_selected_per_validation.append(num_tokens_selected) - print(f"{draft_tokens=}") - print(f"{target_tokens[valid_batch_indices]=}") - print(f"{num_tokens_selected=}") - print(f"{all_accept=}") - if it == 41: - compare_idx_pkvs(dlm_outputs, tlm_outputs, valid_batch_indices, 127) - if it == break_idx: breakpoint() - if not all_accept: breakpoint() - + #if not all_accept: breakpoint() # append selected tokens to the generated_ids tlm_precode_position_ids = tlm_precode_inputs["position_ids"][valid_batch_indices] + num_tokens_selected[valid_batch_indices].reshape(len(valid_batch_indices),1)+1 indices_to_rm = [] @@ -320,22 +289,18 @@ def test_spec_decode_inference( generated_ids[bi].extend(target_tokens[bi, :num_tokens_to_append].tolist()) # position_ids > ctx_len-1 result in erronous output for logits at each seq_len of TLM # (e.g., ctx_len=128 -> position_ids=[127,128,129] will give erronous output at each predicted token) - #if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] > ctx_len-1).any(): if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] > ctx_len-1).any(): indices_to_rm.append(bi) for idx in indices_to_rm: del valid_batch_indices[idx] - print(f"{valid_batch_indices=}") # check if all generations are done if not valid_batch_indices: break # prepare decode inputs for next decode iteration common_input_ids = target_tokens[np.asarray(valid_batch_indices), num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1) - #common_input_ids = target_tokens[np.arange(len(valid_batch_indices)), num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1) common_position_ids = tlm_precode_inputs["position_ids"][valid_batch_indices, num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1)+1 if all_accept: # all_accept input_ids input_ids = np.zeros((decode_batch_size, 2), dtype=np.int64) - #if it == 61: breakpoint() input_ids[valid_batch_indices] = np.concatenate([target_tokens[np.asarray(valid_batch_indices), num_tokens_selected[valid_batch_indices]-1].reshape(-1,1), common_input_ids], axis=1) dlm_decode_inputs["input_ids"] = input_ids # all_accept position_ids @@ -347,9 +312,6 @@ def test_spec_decode_inference( dlm_decode_inputs["position_ids"][valid_batch_indices] = common_position_ids tlm_precode_inputs["input_ids"][valid_batch_indices,0] = common_input_ids.flatten() tlm_precode_inputs["position_ids"][valid_batch_indices] += num_tokens_selected[valid_batch_indices].reshape(len(valid_batch_indices),1)+1 - pprint(dlm_decode_inputs) - pprint(tlm_precode_inputs) - if it == break_idx: breakpoint() end = perf_counter() decode_end = end - decode_start e2e_end = end - e2e_start From 7e1efa31b6a8bd2fdc67f1a67701c5eb06a9f5fb Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 18 Dec 2024 06:45:38 -0600 Subject: [PATCH 09/12] fix bug when some samples get all accepted and others do not Signed-off-by: eplatero --- tests/transformers/spd/test_spd_inference.py | 97 ++++++++------------ 1 file changed, 39 insertions(+), 58 deletions(-) diff --git 
a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py
index 6d8aa8173..74816ec75 100644
--- a/tests/transformers/spd/test_spd_inference.py
+++ b/tests/transformers/spd/test_spd_inference.py
@@ -25,8 +25,10 @@
         32, # prefill_seq_len
         128, # ctx_len
         1, # prefill_bsz
-        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name
-        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name
+        "JackFram/llama-68m", # draft_model_name
+        "JackFram/llama-68m", # target_model_name
+        #"TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name
+        #"TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name
         4, # full_batch_size
         id="CB llama",
     ),
@@ -91,32 +93,6 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs):
     return bonus_token_inputs, dlm_decode_inputs

-def compare_pkvs(dlm_outputs, tlm_outputs, valid_batch_indices):
-    for key in dlm_outputs:
-        if not "past" in key: continue
-        dlm_output = dlm_outputs[key][valid_batch_indices]
-        tlm_output = tlm_outputs[key][valid_batch_indices]
-        array_equal = np.array_equal(dlm_output, tlm_output)
-        if not array_equal:
-            print(f"{key} do NOT match!")
-
-def compare_idx_pkvs(dlm_outputs, tlm_outputs, valid_batch_indices, idx):
-    for key in dlm_outputs:
-        if not "past" in key: continue
-        dlm_output = dlm_outputs[key][valid_batch_indices, :, idx]
-        tlm_output = tlm_outputs[key][valid_batch_indices, :, idx]
-        array_equal = np.array_equal(dlm_output, tlm_output)
-        if not array_equal:
-            a = dlm_output.flatten()
-            b = tlm_output.flatten()
-            scalar = np.dot(a, b)
-            a_mag = np.sqrt(np.dot(a,a))
-            b_mag = np.sqrt(np.dot(b,b))
-            sim = scalar / (a_mag * b_mag)
-            print(f"{key} do NOT match! Similary sore: {sim}")
-
-
 @pytest.mark.parametrize(
     "prompt, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size",
     configs,
@@ -133,8 +109,8 @@ def test_spec_decode_inference(
 ):
     # get device group
     #device_group: List[int] = get_available_device_id()
-    device_group: List[int] = [0]
-    #device_group: List[int] = [31]
+    #device_group: List[int] = [0]
+    device_group: List[int] = [31]
     if not device_group:
         pytest.skip("No available devices to run model on Cloud AI 100")
     # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size
@@ -209,7 +185,6 @@ def test_spec_decode_inference(
     max_gen_len = [ctx_len] * decode_batch_size
     num_logits_to_keep = num_speculative_tokens+1
     # setup buffers
-    all_accept = np.full((decode_batch_size, num_speculative_tokens), False, dtype=bool)
     tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32)
     dlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32)
     decode_logits_ph = np.zeros((decode_batch_size, 1, vocab_size), dtype=np.float32)
@@ -240,27 +215,28 @@ def test_spec_decode_inference(
         # assumes that prefill queue will always be popped from the front
         input_lengths[bi] = input_len
         max_gen_len[bi] -= input_lengths[bi]
+    batch_ttft = perf_counter() - e2e_start

     # set decode logits buffers
     target_model_session.set_buffers({"logits": precode_logits_ph})
     draft_model_session.set_buffers({"logits": decode_logits_ph})
     # start decode phase
-    valid_batch_indices = list(range(decode_batch_size))[::-1]
+    valid_batch_indices = np.full(decode_batch_size, True, dtype=bool)
     all_accept = False
     it = 0
     decode_start = perf_counter()
     mean_num_accepted_tokens = 0
+    all_accept = np.full(decode_batch_size, False, dtype=bool)
     while True:
         it += 1
         # generate proposals from draft model
         for k_ in
range(num_speculative_tokens): - if all_accept: + if all_accept.any(): #running decode one extra time in the first speculative iteration # workaround to avoid the incorrect precode with 3-specialized multi-batch DLM bonus_token_inputs, dlm_decode_inputs = split_dlm_bonus_token_inputs(dlm_decode_inputs) _ = draft_model_session.run(bonus_token_inputs) - all_accept = False - compare_pkvs(_, tlm_outputs, valid_batch_indices) + all_accept[:] = False dlm_outputs = draft_model_session.run(dlm_decode_inputs) input_ids = dlm_outputs["logits"].argmax(2) tlm_precode_inputs["input_ids"][:, k_+1] = input_ids.flatten() @@ -270,48 +246,47 @@ def test_spec_decode_inference( tlm_outputs = target_model_session.run(tlm_precode_inputs) target_logits = tlm_outputs["logits"] # greedy sampling from target model - target_tokens = target_logits.argmax(-1) # shape: [decode_batch_size, num_speculative_tokens+1] + target_tokens = target_logits.argmax(-1) # exact matching between draft and target tokens - draft_tokens = tlm_precode_inputs["input_ids"][valid_batch_indices,1:] # shape: [len(valid_batch_indices), num_speculative_tokens] - matching = draft_tokens == target_tokens[valid_batch_indices, :-1] - num_tokens_selected = matching.cumprod(axis=1).sum(axis=1) # shape: [len(valid_batch_indices)] - all_accept = (num_tokens_selected == num_speculative_tokens).all() - mean_num_accepted_tokens += num_tokens_selected.mean().item() - #if not all_accept: breakpoint() + draft_tokens = tlm_precode_inputs["input_ids"][:,1:] + matching = draft_tokens == target_tokens[:, :-1] # shape: [decode_batch_size, num_speculative_tokens] + num_tokens_selected = matching.cumprod(axis=1).sum(axis=1)+1 # shape: [decode_batch_size] + all_accept[valid_batch_indices] = (num_tokens_selected[valid_batch_indices] == num_speculative_tokens+1) + mean_num_accepted_tokens += num_tokens_selected[valid_batch_indices].mean().item() # append selected tokens to the generated_ids - tlm_precode_position_ids = tlm_precode_inputs["position_ids"][valid_batch_indices] + num_tokens_selected[valid_batch_indices].reshape(len(valid_batch_indices),1)+1 - indices_to_rm = [] - for bi in valid_batch_indices: + tlm_precode_position_ids = tlm_precode_inputs["position_ids"] + num_tokens_selected.reshape(decode_batch_size,1) + #tlm_precode_position_ids = tlm_precode_inputs["position_ids"] + num_tokens_selected.reshape(decode_batch_size,1)+1 + for bi,valid in enumerate(valid_batch_indices): + if not valid: continue accepted_tokens = num_tokens_selected[bi] - if all_accept: - accepted_tokens += 1 num_tokens_to_append = min(accepted_tokens, max_gen_len[bi] - len(generated_ids[bi])) generated_ids[bi].extend(target_tokens[bi, :num_tokens_to_append].tolist()) # position_ids > ctx_len-1 result in erronous output for logits at each seq_len of TLM # (e.g., ctx_len=128 -> position_ids=[127,128,129] will give erronous output at each predicted token) if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] > ctx_len-1).any(): - indices_to_rm.append(bi) - for idx in indices_to_rm: - del valid_batch_indices[idx] + valid_batch_indices[bi] = False # check if all generations are done - if not valid_batch_indices: break + if not valid_batch_indices.any(): break # prepare decode inputs for next decode iteration - common_input_ids = target_tokens[np.asarray(valid_batch_indices), num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1) - common_position_ids = tlm_precode_inputs["position_ids"][valid_batch_indices, 
num_tokens_selected[valid_batch_indices]].reshape(len(valid_batch_indices), 1)+1 - if all_accept: + num_valid_batch_indices = (valid_batch_indices == True).sum().item() + common_input_ids = target_tokens[valid_batch_indices, num_tokens_selected[valid_batch_indices]-1].reshape(num_valid_batch_indices, 1) + common_position_ids = tlm_precode_inputs["position_ids"][valid_batch_indices, num_tokens_selected[valid_batch_indices]-1].reshape(num_valid_batch_indices, 1)+1 + if all_accept.any(): # all_accept input_ids input_ids = np.zeros((decode_batch_size, 2), dtype=np.int64) - input_ids[valid_batch_indices] = np.concatenate([target_tokens[np.asarray(valid_batch_indices), num_tokens_selected[valid_batch_indices]-1].reshape(-1,1), common_input_ids], axis=1) + last_spec_token_id = target_tokens[valid_batch_indices, -2].reshape(-1,1) + input_ids[valid_batch_indices] = np.concatenate([last_spec_token_id, common_input_ids], axis=1) dlm_decode_inputs["input_ids"] = input_ids # all_accept position_ids - position_ids = np.zeros((decode_batch_size, 2), dtype=np.int64) - position_ids[valid_batch_indices] = np.concatenate([tlm_precode_inputs["position_ids"][valid_batch_indices, num_tokens_selected[valid_batch_indices]-1].reshape(-1,1)+1, common_position_ids], axis=1) + position_ids = np.full((decode_batch_size,2), -1, dtype=np.int64) + last_spec_position_id = tlm_precode_inputs["position_ids"][valid_batch_indices, -1].reshape(-1,1) + position_ids[valid_batch_indices] = np.concatenate([last_spec_position_id, common_position_ids], axis=1) dlm_decode_inputs["position_ids"] = position_ids else: dlm_decode_inputs["input_ids"][valid_batch_indices] = common_input_ids dlm_decode_inputs["position_ids"][valid_batch_indices] = common_position_ids tlm_precode_inputs["input_ids"][valid_batch_indices,0] = common_input_ids.flatten() - tlm_precode_inputs["position_ids"][valid_batch_indices] += num_tokens_selected[valid_batch_indices].reshape(len(valid_batch_indices),1)+1 + tlm_precode_inputs["position_ids"][valid_batch_indices] += num_tokens_selected[valid_batch_indices].reshape(num_valid_batch_indices,1) end = perf_counter() decode_end = end - decode_start e2e_end = end - e2e_start @@ -322,6 +297,7 @@ def test_spec_decode_inference( batch_decode = tokenizer.batch_decode(generated_ids) mean_num_accepted_tokens /= it print(f"Avg TLM+DLM TTFT = {mean_ttft}") + print(f"Total TLM+DLM Batch TTFT = {batch_ttft}") print(f"Decode Throughput = {decode_throughput}") print(f"E2E Throughput = {e2e_throughput}") print("Avg number of accepted tokens = ", mean_num_accepted_tokens) @@ -329,3 +305,8 @@ def test_spec_decode_inference( print("Total Generated Tokens per Prompt: = ", generated_tokens_per_prompt) for prompt,generation in zip(prompts, batch_decode): print(f"{prompt=} {generation=}") + # validation check + del target_model_session + del draft_model_session + outputs = draft_model.generate(tokenizer, prompts, device_group) + breakpoint() From 9ec726a39f75dcc1fa1271ad210474969704ff7e Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 18 Dec 2024 07:08:15 -0600 Subject: [PATCH 10/12] assert spd output matches vanilla dlm output Signed-off-by: eplatero --- tests/transformers/spd/test_spd_inference.py | 24 +++++++++----------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index 74816ec75..b71937f65 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -6,7 +6,6 @@ # 
----------------------------------------------------------------------------- from typing import List, Optional -from pprint import pprint from time import perf_counter import numpy as np @@ -20,16 +19,14 @@ configs = [ pytest.param( - ['hello', 'hi', 'hola', 'bonjour'], # prompt - 4, # num_speculative_tokens + Constants.INPUT_STR, # prompt + 1, # num_speculative_tokens 32, # prefill_seq_len 128, # ctx_len 1, # prefill_bsz - "JackFram/llama-68m", # draft_model_name - "JackFram/llama-68m", # draft_model_name - #"TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name - #"TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name - 4, # full_batch_size + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name + 1, # full_batch_size id="CB llama", ), ] @@ -108,9 +105,7 @@ def test_spec_decode_inference( full_batch_size: Optional[int], ): # get device group - #device_group: List[int] = get_available_device_id() - #device_group: List[int] = [0] - device_group: List[int] = [31] + device_group: List[int] = get_available_device_id() if not device_group: pytest.skip("No available devices to run model on Cloud AI 100") # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size @@ -207,6 +202,7 @@ def test_spec_decode_inference( ttft = perf_counter() - start ttfts.append(ttft) input_ids = tlm_logits.argmax(2).astype(np.int64) + generated_ids[bi].append(input_ids.item()) dlm_decode_inputs["input_ids"][bi, 0] = input_ids tlm_precode_inputs["input_ids"][bi, 0] = input_ids.item() input_len = prompts_tokenized[bi]['position_ids'].max(1).item() + 1 @@ -306,7 +302,9 @@ def test_spec_decode_inference( for prompt,generation in zip(prompts, batch_decode): print(f"{prompt=} {generation=}") # validation check + assert mean_num_accepted_tokens == float(num_speculative_tokens+1), f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens+1}" del target_model_session del draft_model_session - outputs = draft_model.generate(tokenizer, prompts, device_group) - breakpoint() + exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR, device_group) + cloud_ai_100_tokens = exec_info.generated_ids[0][:max_gen_len[0]] # Because we always run for single input and single batch size + assert (cloud_ai_100_tokens == np.asarray(generated_ids)[0]).all(), "Tokens don't match for SpD output and vanilla DLM output." 
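A note on the acceptance rule before the linting pass below. With greedy sampling on both models, a draft token is kept only if it matches the target's token at the same position and every earlier draft token also matched, so the cumprod-then-sum over the match mask yields the length of the longest matching prefix rather than the total number of matches; the target's own prediction just past that prefix is the guaranteed bonus token, hence the +1. A minimal NumPy sketch of the rule — the helper name and the sample arrays here are illustrative, not taken from the patch:

import numpy as np

def count_accepted_tokens(draft_tokens: np.ndarray, target_tokens: np.ndarray) -> np.ndarray:
    # draft_tokens:  [batch, k]   proposals from the draft model (DLM)
    # target_tokens: [batch, k+1] greedy picks from the target model (TLM)
    matching = draft_tokens == target_tokens[:, :-1]
    # cumprod zeroes every position after the first mismatch, so the row
    # sum counts only the matching prefix, not all matches in the row.
    return matching.cumprod(axis=1).sum(axis=1) + 1  # +1 for the bonus token

draft = np.array([[5, 7, 9, 2]])      # hypothetical k=4 proposals
target = np.array([[5, 7, 4, 2, 8]])  # target disagrees at position 2
print(count_accepted_tokens(draft, target))  # [3]: 2 matching proposals + bonus

Note that draft position 3 matches the target (2 == 2) but is still rejected because position 2 already diverged. This prefix property is also what lets the patch above assert a mean of num_speculative_tokens + 1 accepted tokens when draft and target are the same checkpoint: greedy decoding then agrees at every position, so the matching prefix is always full.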
From a1a99f599c8d41298508eed88f80c74e7ae4505b Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 18 Dec 2024 07:18:55 -0600 Subject: [PATCH 11/12] linting Signed-off-by: eplatero --- tests/transformers/spd/test_spd_inference.py | 161 +++++++++++-------- 1 file changed, 95 insertions(+), 66 deletions(-) diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index b71937f65..9a17d2a63 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -5,8 +5,8 @@ # # ----------------------------------------------------------------------------- -from typing import List, Optional from time import perf_counter +from typing import List, Optional import numpy as np import pytest @@ -19,26 +19,25 @@ configs = [ pytest.param( - Constants.INPUT_STR, # prompt - 1, # num_speculative_tokens - 32, # prefill_seq_len - 128, # ctx_len - 1, # prefill_bsz - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name - 1, # full_batch_size + Constants.INPUT_STR, # prompt + 1, # num_speculative_tokens + 32, # prefill_seq_len + 128, # ctx_len + 1, # prefill_bsz + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name + 1, # full_batch_size id="CB llama", ), ] - def run_prefill_on_draft_and_target( - tlm_session: QAICInferenceSession, - dlm_session: QAICInferenceSession, - inputs: dict, - prefill_seq_len: int, - slot_idx: int + tlm_session: QAICInferenceSession, + dlm_session: QAICInferenceSession, + inputs: dict, + prefill_seq_len: int, + slot_idx: int, ): input_len = inputs.input_ids.shape[1] num_chunks = input_len // prefill_seq_len @@ -50,7 +49,9 @@ def run_prefill_on_draft_and_target( for i in range(num_chunks): chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len] - chunk_inputs["position_ids"] = inputs["position_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len] + chunk_inputs["position_ids"] = inputs["position_ids"][ + :, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len + ] tlm_outputs = tlm_session.run(chunk_inputs) _ = dlm_session.run(chunk_inputs) @@ -65,7 +66,7 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int): Args: input_len (int): prompt length - prefill_seq_len (int): prefill sequence length + prefill_seq_len (int): prefill sequence length ctx_len (int): context length Returns: @@ -73,11 +74,12 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int): """ num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float input_len_padded = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len - assert input_len_padded <= ctx_len, "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" + assert ( + input_len_padded <= ctx_len + ), "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" return input_len_padded - def split_dlm_bonus_token_inputs(dlm_decode_inputs): bonus_token_inputs = dict() bonus, regular = np.hsplit(dlm_decode_inputs["input_ids"], 2) @@ -117,22 +119,28 @@ def test_spec_decode_inference( # export_and_compile tlm and dlm continuous_batching = full_batch_size is not None - target_model = AutoModelForCausalLM.from_pretrained(target_model_name, continuous_batching=continuous_batching, is_tlm=True) + target_model = 
AutoModelForCausalLM.from_pretrained( + target_model_name, continuous_batching=continuous_batching, is_tlm=True + ) draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=continuous_batching) num_devices = len(device_group) - target_model_qpc_path: str = target_model.compile(num_cores=11, - num_devices=num_devices, - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - aic_enable_depth_first=True, - full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens) - draft_model_qpc_path: str = draft_model.compile(num_cores=5, - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - aic_enable_depth_first=True, - full_batch_size=full_batch_size) + target_model_qpc_path: str = target_model.compile( + num_cores=11, + num_devices=num_devices, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + aic_enable_depth_first=True, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + ) + draft_model_qpc_path: str = draft_model.compile( + num_cores=5, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + aic_enable_depth_first=True, + full_batch_size=full_batch_size, + ) # init qaic session target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group) draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group) @@ -173,12 +181,12 @@ def test_spec_decode_inference( ) # mock input key "logits" to store the first batch of output logits tlm_precode_inputs = dict( - input_ids = np.zeros((decode_batch_size, num_speculative_tokens+1), dtype=np.int64), - position_ids = np.zeros((decode_batch_size, num_speculative_tokens+1), dtype=np.int64), - batch_index = np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1) + input_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), + position_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), + batch_index=np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1), ) max_gen_len = [ctx_len] * decode_batch_size - num_logits_to_keep = num_speculative_tokens+1 + num_logits_to_keep = num_speculative_tokens + 1 # setup buffers tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) dlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) @@ -205,9 +213,11 @@ def test_spec_decode_inference( generated_ids[bi].append(input_ids.item()) dlm_decode_inputs["input_ids"][bi, 0] = input_ids tlm_precode_inputs["input_ids"][bi, 0] = input_ids.item() - input_len = prompts_tokenized[bi]['position_ids'].max(1).item() + 1 + input_len = prompts_tokenized[bi]["position_ids"].max(1).item() + 1 dlm_decode_inputs["position_ids"][bi, 0] = input_len - tlm_precode_inputs["position_ids"][bi] = np.arange(input_len, input_len+num_speculative_tokens+1, dtype=np.int64) + tlm_precode_inputs["position_ids"][bi] = np.arange( + input_len, input_len + num_speculative_tokens + 1, dtype=np.int64 + ) # assumes that prefill queue will always be popped from the front input_lengths[bi] = input_len max_gen_len[bi] -= input_lengths[bi] @@ -228,68 +238,81 @@ def test_spec_decode_inference( # generate proposals from draft model for k_ in range(num_speculative_tokens): if all_accept.any(): - #running decode one extra time in the first speculative iteration + # running decode one extra time in the first speculative iteration # workaround to avoid the incorrect precode with 3-specialized multi-batch DLM bonus_token_inputs, dlm_decode_inputs = 
split_dlm_bonus_token_inputs(dlm_decode_inputs) _ = draft_model_session.run(bonus_token_inputs) all_accept[:] = False dlm_outputs = draft_model_session.run(dlm_decode_inputs) input_ids = dlm_outputs["logits"].argmax(2) - tlm_precode_inputs["input_ids"][:, k_+1] = input_ids.flatten() + tlm_precode_inputs["input_ids"][:, k_ + 1] = input_ids.flatten() dlm_decode_inputs["input_ids"] = input_ids dlm_decode_inputs["position_ids"][valid_batch_indices] += 1 # run precode on TLM to score the proposed tokens tlm_outputs = target_model_session.run(tlm_precode_inputs) target_logits = tlm_outputs["logits"] # greedy sampling from target model - target_tokens = target_logits.argmax(-1) + target_tokens = target_logits.argmax(-1) # exact matching between draft and target tokens - draft_tokens = tlm_precode_inputs["input_ids"][:,1:] - matching = draft_tokens == target_tokens[:, :-1] # shape: [decode_batch_size, num_speculative_tokens] - num_tokens_selected = matching.cumprod(axis=1).sum(axis=1)+1 # shape: [decode_batch_size] - all_accept[valid_batch_indices] = (num_tokens_selected[valid_batch_indices] == num_speculative_tokens+1) + draft_tokens = tlm_precode_inputs["input_ids"][:, 1:] + matching = draft_tokens == target_tokens[:, :-1] # shape: [decode_batch_size, num_speculative_tokens] + num_tokens_selected = matching.cumprod(axis=1).sum(axis=1) + 1 # shape: [decode_batch_size] + all_accept[valid_batch_indices] = num_tokens_selected[valid_batch_indices] == num_speculative_tokens + 1 mean_num_accepted_tokens += num_tokens_selected[valid_batch_indices].mean().item() # append selected tokens to the generated_ids - tlm_precode_position_ids = tlm_precode_inputs["position_ids"] + num_tokens_selected.reshape(decode_batch_size,1) - #tlm_precode_position_ids = tlm_precode_inputs["position_ids"] + num_tokens_selected.reshape(decode_batch_size,1)+1 - for bi,valid in enumerate(valid_batch_indices): - if not valid: continue + tlm_precode_position_ids = tlm_precode_inputs["position_ids"] + num_tokens_selected.reshape( + decode_batch_size, 1 + ) + # tlm_precode_position_ids = tlm_precode_inputs["position_ids"] + num_tokens_selected.reshape(decode_batch_size,1)+1 + for bi, valid in enumerate(valid_batch_indices): + if not valid: + continue accepted_tokens = num_tokens_selected[bi] num_tokens_to_append = min(accepted_tokens, max_gen_len[bi] - len(generated_ids[bi])) generated_ids[bi].extend(target_tokens[bi, :num_tokens_to_append].tolist()) - # position_ids > ctx_len-1 result in erronous output for logits at each seq_len of TLM + # position_ids > ctx_len-1 result in erronous output for logits at each seq_len of TLM # (e.g., ctx_len=128 -> position_ids=[127,128,129] will give erronous output at each predicted token) - if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] > ctx_len-1).any(): + if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] > ctx_len - 1).any(): valid_batch_indices[bi] = False # check if all generations are done - if not valid_batch_indices.any(): break + if not valid_batch_indices.any(): + break # prepare decode inputs for next decode iteration - num_valid_batch_indices = (valid_batch_indices == True).sum().item() - common_input_ids = target_tokens[valid_batch_indices, num_tokens_selected[valid_batch_indices]-1].reshape(num_valid_batch_indices, 1) - common_position_ids = tlm_precode_inputs["position_ids"][valid_batch_indices, num_tokens_selected[valid_batch_indices]-1].reshape(num_valid_batch_indices, 1)+1 + num_valid_batch_indices = 
valid_batch_indices.sum().item() + common_input_ids = target_tokens[valid_batch_indices, num_tokens_selected[valid_batch_indices] - 1].reshape( + num_valid_batch_indices, 1 + ) + common_position_ids = ( + tlm_precode_inputs["position_ids"][ + valid_batch_indices, num_tokens_selected[valid_batch_indices] - 1 + ].reshape(num_valid_batch_indices, 1) + + 1 + ) if all_accept.any(): # all_accept input_ids input_ids = np.zeros((decode_batch_size, 2), dtype=np.int64) - last_spec_token_id = target_tokens[valid_batch_indices, -2].reshape(-1,1) + last_spec_token_id = target_tokens[valid_batch_indices, -2].reshape(-1, 1) input_ids[valid_batch_indices] = np.concatenate([last_spec_token_id, common_input_ids], axis=1) dlm_decode_inputs["input_ids"] = input_ids # all_accept position_ids - position_ids = np.full((decode_batch_size,2), -1, dtype=np.int64) - last_spec_position_id = tlm_precode_inputs["position_ids"][valid_batch_indices, -1].reshape(-1,1) + position_ids = np.full((decode_batch_size, 2), -1, dtype=np.int64) + last_spec_position_id = tlm_precode_inputs["position_ids"][valid_batch_indices, -1].reshape(-1, 1) position_ids[valid_batch_indices] = np.concatenate([last_spec_position_id, common_position_ids], axis=1) dlm_decode_inputs["position_ids"] = position_ids else: dlm_decode_inputs["input_ids"][valid_batch_indices] = common_input_ids dlm_decode_inputs["position_ids"][valid_batch_indices] = common_position_ids - tlm_precode_inputs["input_ids"][valid_batch_indices,0] = common_input_ids.flatten() - tlm_precode_inputs["position_ids"][valid_batch_indices] += num_tokens_selected[valid_batch_indices].reshape(num_valid_batch_indices,1) + tlm_precode_inputs["input_ids"][valid_batch_indices, 0] = common_input_ids.flatten() + tlm_precode_inputs["position_ids"][valid_batch_indices] += num_tokens_selected[valid_batch_indices].reshape( + num_valid_batch_indices, 1 + ) end = perf_counter() decode_end = end - decode_start e2e_end = end - e2e_start mean_ttft = sum(ttfts) / len(ttfts) - generated_tokens_per_prompt = [len(gid)+1 for gid in generated_ids] + generated_tokens_per_prompt = [len(gid) + 1 for gid in generated_ids] decode_throughput = sum(generated_tokens_per_prompt) / decode_end - e2e_throughput = (sum(generated_tokens_per_prompt)+decode_batch_size) / e2e_end + e2e_throughput = (sum(generated_tokens_per_prompt) + decode_batch_size) / e2e_end batch_decode = tokenizer.batch_decode(generated_ids) mean_num_accepted_tokens /= it print(f"Avg TLM+DLM TTFT = {mean_ttft}") @@ -299,12 +322,18 @@ def test_spec_decode_inference( print("Avg number of accepted tokens = ", mean_num_accepted_tokens) print("Max generation len = ", max_gen_len) print("Total Generated Tokens per Prompt: = ", generated_tokens_per_prompt) - for prompt,generation in zip(prompts, batch_decode): + for prompt, generation in zip(prompts, batch_decode): print(f"{prompt=} {generation=}") # validation check - assert mean_num_accepted_tokens == float(num_speculative_tokens+1), f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens+1}" + assert mean_num_accepted_tokens == float( + num_speculative_tokens + 1 + ), f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens+1}" del target_model_session del draft_model_session exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR, device_group) - cloud_ai_100_tokens = exec_info.generated_ids[0][:max_gen_len[0]] # Because we always run for single input and single batch size - assert (cloud_ai_100_tokens == 
np.asarray(generated_ids)[0]).all(), "Tokens don't match for SpD output and vanilla DLM output." + cloud_ai_100_tokens = exec_info.generated_ids[0][ + : max_gen_len[0] + ] # Because we always run for single input and single batch size + assert ( + cloud_ai_100_tokens == np.asarray(generated_ids)[0] + ).all(), "Tokens don't match for SpD output and vanilla DLM output." From 4c1662a881905cd0d9658d262ec7e9f39e8599c7 Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 18 Dec 2024 07:45:20 -0600 Subject: [PATCH 12/12] added higher spec_len Signed-off-by: eplatero --- tests/transformers/spd/test_spd_inference.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index 9a17d2a63..6e1b70f79 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -20,7 +20,7 @@ configs = [ pytest.param( Constants.INPUT_STR, # prompt - 1, # num_speculative_tokens + 4, # num_speculative_tokens 32, # prefill_seq_len 128, # ctx_len 1, # prefill_bsz @@ -330,10 +330,11 @@ def test_spec_decode_inference( ), f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens+1}" del target_model_session del draft_model_session + generated_ids = np.asarray(generated_ids).flatten() + gen_len = generated_ids.shape[0] exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR, device_group) cloud_ai_100_tokens = exec_info.generated_ids[0][ - : max_gen_len[0] + :gen_len ] # Because we always run for single input and single batch size - assert ( - cloud_ai_100_tokens == np.asarray(generated_ids)[0] - ).all(), "Tokens don't match for SpD output and vanilla DLM output." + all_matching = np.array_equal(cloud_ai_100_tokens, generated_ids) + assert all_matching, "Tokens don't match for SpD output and vanilla DLM output."
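A closing note on the bonus-token handling that the final form of the test relies on. When every proposal in a speculation window is accepted, the draft model ends up one token behind the target, so its next decode input carries two tokens (the last speculative token plus the bonus token) and is split into two single-token session calls; this keeps the DLM on its single-token decode specialization and sidesteps the incorrect precode that the in-code comments attribute to the 3-specialized multi-batch DLM. A sketch of that split with simplified stand-in names for split_dlm_bonus_token_inputs, assuming the same dict-of-arrays input convention:

import numpy as np

def split_bonus_token_inputs(decode_inputs: dict) -> tuple:
    # np.hsplit turns each [batch, 2] array into two [batch, 1] arrays: the
    # first column becomes a standalone bonus decode step, the second stays
    # behind as the regular single-token decode input.
    bonus_ids, regular_ids = np.hsplit(decode_inputs["input_ids"], 2)
    bonus_pos, regular_pos = np.hsplit(decode_inputs["position_ids"], 2)
    bonus_inputs = {
        "input_ids": bonus_ids,
        "position_ids": bonus_pos,
        "batch_index": decode_inputs["batch_index"],
    }
    decode_inputs["input_ids"] = regular_ids
    decode_inputs["position_ids"] = regular_pos
    return bonus_inputs, decode_inputs

# Hypothetical all-accept step with decode_batch_size=2:
inputs = {
    "input_ids": np.array([[11, 12], [21, 22]], dtype=np.int64),
    "position_ids": np.array([[7, 8], [9, 10]], dtype=np.int64),
    "batch_index": np.arange(2, dtype=np.int64).reshape(-1, 1),
}
bonus, regular = split_bonus_token_inputs(inputs)
print(bonus["input_ids"].shape, regular["input_ids"].shape)  # (2, 1) (2, 1)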