diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py
index e066f2a72ea..a6c7d9576df 100644
--- a/backends/xnnpack/partition/xnnpack_partitioner.py
+++ b/backends/xnnpack/partition/xnnpack_partitioner.py
@@ -1185,7 +1185,9 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         partitions = [
             Partition(
                 id=next(partition_id),
-                nodes=set(match),
+                nodes=set(
+                    filter(lambda x: x.target != torch.ops.aten.sym_size.int, match)
+                ),
             )
             for match in self.get_module_partitions(exported_program)
         ]
diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py
index c8d949eb6f2..ebb5cee9d13 100644
--- a/examples/models/llama2/builder.py
+++ b/examples/models/llama2/builder.py
@@ -76,6 +76,7 @@ def load_llama_model(
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
     weight_type: WeightType = WeightType.LLAMA,
+    enable_dynamic_shape: bool = False,
     verbose: bool = False,
     max_seq_len: int = 128,
 ) -> "LlamaEdgeManager":
@@ -101,6 +102,7 @@ def load_llama_model(
         use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
         fairseq2=weight_type == WeightType.FAIRSEQ2,
         max_seq_len=max_seq_len,
+        enable_dynamic_shape=enable_dynamic_shape,
     )
     state_dict = model.state_dict()
     dtype = state_dict[next(iter(state_dict))].dtype
@@ -128,6 +130,7 @@ def load_llama_model(
         use_kv_cache=use_kv_cache,
         use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
         example_inputs=example_inputs,
+        enable_dynamic_shape=enable_dynamic_shape,
         verbose=verbose,
     )

@@ -146,6 +149,7 @@ def __init__(
         use_kv_cache,
         use_sdpa_with_kv_cache,
         example_inputs,
+        enable_dynamic_shape: bool = False,
         verbose: bool = False,
     ):
         self.model = model
@@ -156,6 +160,7 @@ def __init__(
         self.dtype = dtype
         self.example_inputs = example_inputs
         self.use_kv_cache = use_kv_cache
+        self.enable_dynamic_shape = enable_dynamic_shape
         self.use_sdpa_with_kv_cache = use_sdpa_with_kv_cache
         self.metadata = None
         self.verbose = verbose
@@ -220,7 +225,10 @@ def source_transform(
     def _get_dynamic_shape(self) -> Any:
         dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
         if self.use_kv_cache:
-            return None
+            if self.enable_dynamic_shape:
+                return ({1: dim}, {0: dim})
+            else:
+                return None
         else:
             return ({1: dim},)

@@ -250,6 +258,7 @@ def _get_metadata(self):
             "get_vocab_size": params.vocab_size,
             "use_kv_cache": self.use_kv_cache,
             "use_sdpa_with_kv_cache": self.use_sdpa_with_kv_cache,
+            "enable_dynamic_shape": self.enable_dynamic_shape,
         }
         if self.metadata:
             try:
diff --git a/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py b/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
index bada40220bc..1b89dddce3a 100644
--- a/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
+++ b/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
@@ -86,8 +86,6 @@ def _validate_params(
     #     1
     # ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {key_cache.size(2)}"

-    assert seq_len == 1, "Only support seq_len = 1 for now."
-
     if attn_mask is not None:
         assert (
             attn_mask.dim() == 2
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index c1fae0eb77b..c62813643bf 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -174,6 +174,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Whether to use sdpa_with_kv_cache update op when using kv cache",
     )
+    parser.add_argument(
+        "--disable_dynamic_shape",
+        dest="enable_dynamic_shape",
+        default=True,  # Enable this by default
+        action="store_false",
+        help="Enable dynamic shape along seq dim. Used for faster prefill",
+    )
     parser.add_argument(
         "-p",
         "--params",
@@ -369,6 +376,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
         weight_type=weight_type,
+        enable_dynamic_shape=args.enable_dynamic_shape,
         verbose=args.verbose,
         max_seq_len=args.max_seq_length,
     )
@@ -391,7 +399,19 @@ def get_quantizer_and_quant_params(args):
     return pt2e_quant_params, quantizers, quant_dtype


+def _validate_args(args):
+    """
+    TODO: Combine all the backends under --backend args
+    """
+    if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
+        raise ValueError(
+            "Dynamic shape is not supported with coreml, MPS or qnn backends."
+            " Please use --disable_dynamic_shape."
+        )
+
+
 def _export_llama(modelname, args) -> LlamaEdgeManager:  # noqa: C901
+    _validate_args(args)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

     # export_to_edge
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 298e0463c07..8bebfb8e9b8 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -191,6 +191,7 @@ def __init__(
         dtype=torch.float32,
     ):
         super().__init__()
+        self.max_seq_length = max_seq_length
         if transpose_cache:
             cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
         else:
@@ -208,12 +209,23 @@ def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
-        k_out = self.k_cache
-        v_out = self.v_cache
-        k_out[:, :, input_pos] = k_val
-        v_out[:, :, input_pos] = v_val
-
-        return k_out, v_out
+        start_pos = input_pos[-1].item()
+        torch._check_is_size(start_pos)
+        torch._check(start_pos < self.max_seq_length)
+        seq_length = k_val.size(2)
+        # Replace the entry in the cache for this token
+        # The following lines are equivalent to:
+        # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+        # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+        # We use .narrow() here to make the compiler happy
+        # pyre-ignore: Incompatible parameter type [6]
+        narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
+        # pyre-ignore: Incompatible parameter type [6]
+        narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
+
+        narrowed_k.copy_(k_val)
+        narrowed_v.copy_(v_val)
+        return self.k_cache, self.v_cache


 class SDPA(nn.Module):
@@ -223,12 +235,14 @@ def __init__(
         dim: int,
         head_dim: int,
         n_rep: int,
+        max_seq_len: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
+        self.max_seq_len = max_seq_len

     def forward(
         self,
@@ -245,7 +259,12 @@ def forward(
         v = v.transpose(1, 2)

         k, v = self.kv_cache.update(input_pos, k, v)
-        attn_mask = mask[None, None, input_pos]
+        start_pos = input_pos[-1].item()
+        torch._check_is_size(start_pos)
+        torch._check(start_pos < self.max_seq_len)
+        seq_length = q.size(2)
+        # pyre-ignore: Incompatible parameter type [6]
+        attn_mask = mask.narrow(0, start_pos, seq_length)

         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
@@ -299,6 +318,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
             dim=self.dim,
             head_dim=self.head_dim,
             n_rep=self.n_rep,
+            max_seq_len=self.max_seq_len,
         )

     def forward(
@@ -447,6 +467,7 @@ def __init__(self, params: ModelArgs):
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
+        self.max_seq_len = params.max_seq_len

         freqs_cos, freqs_sin = precompute_freqs_cis(
             params.dim // params.n_heads,
@@ -476,8 +497,17 @@ def forward(
             ), "input_pos must be provided when use_kv_cache is True"

             # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-            freqs_cos = self.freqs_cos[input_pos]
-            freqs_sin = self.freqs_sin[input_pos]
+            input_pos_item = input_pos[-1].item()
+            torch._check_is_size(input_pos_item)
+            # Setting this value to 32 for no particular reason.
+            # It is mainly to make export happy as the resulting
+            # asserts are ignored anyway.
+            # We really need unbounded start_pos
+            torch._check(input_pos_item < self.params.max_seq_len)
+            # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+            freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
+            # pyre-ignore: Incompatible parameter type [6]
+            freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
         else:
             assert input_pos is None, "input_pos is unused when use_kv_cache is False"
             freqs_cos = self.freqs_cos[:seqlen]
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index aa997aa56ea..197de2289b0 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -74,6 +74,11 @@ def __init__(self, **kwargs):
             if "use_sdpa_with_kv_cache" in kwargs
             else False
         )
+        self.enable_dynamic_shape = (
+            kwargs["enable_dynamic_shape"]
+            if "enable_dynamic_shape" in kwargs
+            else False
+        )
         self.max_seq_len = kwargs["max_seq_len"] if "max_seq_len" in kwargs else 128

         # The example is using a dummy small model with random weights for demo purpose only.
@@ -220,11 +225,17 @@ def get_example_inputs(self):

    # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
     def get_example_inputs_kvcache_sdpa(self):
-        return (
-            torch.tensor(
-                [[1]], dtype=torch.long
-            ),  # tokens, with kv cache our input token length is always just 1 token.
-            torch.tensor(
-                [0], dtype=torch.long
-            ),  # start_pos, what token of output are we on.)
-        )
+        if self.enable_dynamic_shape:
+            return (
+                torch.tensor([[2, 3, 4]], dtype=torch.long),
+                torch.tensor([0, 1, 2], dtype=torch.long),
+            )
+        else:
+            return (
+                torch.tensor(
+                    [[1]], dtype=torch.long
+                ),  # tokens, with kv cache our input token length is always just 1 token.
+                torch.tensor(
+                    [0], dtype=torch.long
+                ),  # start_pos, what token of output are we on.
+            )
diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 559c71fd81a..fd560bca3f7 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -76,6 +76,7 @@ Error Runner::load() {
   use_kv_cache_ = getMetadataHelper("use_kv_cache", true);
   use_sdpa_with_kv_cache_ = getMetadataHelper("use_sdpa_with_kv_cache", false);
   append_eos_ = getMetadataHelper("append_eos_to_prompt", false);
+  enable_parallel_prefill_ = getMetadataHelper("enable_dynamic_shape", false);

   // Load tokenizer
 #if ET_USE_TIKTOKEN
@@ -122,20 +123,141 @@ T Runner::getMetadataHelper(const std::string& method_name, T default_val) {
   return res;
 }

-template <typename T>
-int32_t Runner::logitsToToken(
-    const exec_aten::Tensor& logits_tensor,
-    int64_t pos,
-    T _) {
-  (void)_;
-  T* logits = logits_tensor.mutable_data_ptr<T>();
-
-  // Since the logits are for all tokens, get the last token probabilities
-  T* logits_last = logits;
-  if (!use_kv_cache_) {
-    logits_last += pos * tokenizer_->vocab_size();
+int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
+  ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
+  auto num_tokens = logits_tensor.size(1);
+
+  switch (logits_tensor.scalar_type()) {
+    case ScalarType::Float: {
+      float* logits = logits_tensor.mutable_data_ptr<float>();
+      float* logits_last = logits;
+      logits_last += (num_tokens - 1) * tokenizer_->vocab_size();
+      return sampler_->sample(logits_last);
+    }
+    case ScalarType::Half: {
+      exec_aten::Half* logits =
+          logits_tensor.mutable_data_ptr<exec_aten::Half>();
+      exec_aten::Half* logits_last = logits;
+      logits_last += (num_tokens - 1) * tokenizer_->vocab_size();
+      return sampler_->sample(logits_last);
+    }
+    default:
+      ET_CHECK_MSG(
+          false,
+          "Unsupported dtype output %hhd",
+          static_cast<int8_t>(logits_tensor.scalar_type()));
+  }
+}
+
+Result<exec_aten::Tensor> Runner::prefill(
+    const std::vector<uint64_t>& tokens,
+    ManagedTensor& managed_tokens,
+    ManagedTensor& managed_start_pos,
+    std::function<void(const std::string&)> token_callback) {
+  // enable_parallel_prefill_ may be set even when not using kv cache
+  // When kv cache is not used, start pos is ignored
+  int32_t num_tokens = tokens.size();
+  if (enable_parallel_prefill_) {
+    managed_tokens.resize({1, num_tokens});
+    int64_t* tokens_ptr =
+        managed_tokens.get_aliasing_tensor().mutable_data_ptr<int64_t>();
+    for (int i = 0; i < num_tokens; i++) {
+      // The following assumes batch size = 1
+      tokens_ptr[i] = tokens[i];
+    }
+    std::vector<EValue> inputs;
+    auto tokens_tensor = managed_tokens.get_aliasing_tensor();
+    auto start_pos = managed_start_pos.get_aliasing_tensor();
+
+    // inputs:[tokens, start_pos]
+    inputs.push_back(tokens_tensor);
+    inputs.push_back(start_pos);
+
+    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+    ET_CHECK_MSG(
+        outputs_res.get()[0].isTensor(),
+        "Non Tensor Output returned from executing LLM");
+    ET_CHECK_MSG(
+        outputs_res.get()[0].toTensor().size(1) == num_tokens,
+        "Expected number of output tokens %d does not match returned value %zu.",
+        num_tokens,
+        outputs_res.get()[0].toTensor().size(1));
+
+    start_pos.mutable_data_ptr<int64_t>()[0] = num_tokens;
+
+    uint64_t prev = tokens[0];
+    uint64_t cur;
+    for (int i = 1; i < num_tokens; i++) {
+      cur = tokens[i];
+      auto piece_res = tokenizer_->decode(prev, cur);
+      ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
+      util::safe_printf(piece_res.get().c_str());
+      fflush(stdout);
+      prev = cur;
+      if (token_callback) {
+        token_callback(piece_res.get().c_str());
+      }
+    }
+    cur = logitsToToken(outputs_res.get()[0].toTensor());
+    auto piece_res = tokenizer_->decode(prev, cur);
+    ET_CHECK(piece_res.ok());
+    const char* piece = piece_res.get().c_str();
+    util::safe_printf(piece);
+    fflush(stdout);
+    if (token_callback) {
+      token_callback(piece_res.get().c_str());
+    }
+
+    // Return the logits tensor
+    stats_.first_token_ms = util::time_in_ms();
+    stats_.prompt_eval_end_ms = util::time_in_ms();
+    return outputs_res.get()[0].toTensor();
+  } else { // sequential prefill
+    int64_t pos = 0; // position in the sequence
+    int64_t cur_token = tokens[0];
+    int64_t prev_token;
+    // This is a hack to enable returning a logits tensor from prefill
+    auto logits_tensor = managed_tokens.get_aliasing_tensor();
+    while (pos < num_tokens) {
+      // Run the model
+      Result<exec_aten::Tensor> logits_res = run_model_step(
+          cur_token, managed_tokens, managed_start_pos, num_tokens);
+
+      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
+      logits_tensor = logits_res.get();
+      // Hack to enable returning a logits tensor from prefill
+
+      prev_token = cur_token;
+
+      long sample_start_time_ms = util::time_in_ms();
+      cur_token = logitsToToken(logits_tensor);
+      stats_.aggregate_sampling_time_ms +=
+          util::time_in_ms() - sample_start_time_ms;
+
+      // advance the state machine
+      if (pos < num_tokens - 1) {
+        // prefill, force the next token to be the next prompt token
+        cur_token = tokens[pos + 1];
+      }
+      pos++;
+
+      // print the token as string, decode it with the Tokenizer object
+      auto piece_res = tokenizer_->decode(prev_token, cur_token);
+      ET_CHECK(piece_res.ok());
+      const char* piece = piece_res.get().c_str();
+      util::safe_printf(piece);
+      fflush(stdout);
+      if (token_callback) {
+        token_callback(piece_res.get().c_str());
+      }
+    }
+    auto start_pos = managed_start_pos.get_aliasing_tensor();
+    start_pos.mutable_data_ptr<int64_t>()[0] = num_tokens;
+    stats_.first_token_ms = util::time_in_ms();
+    stats_.prompt_eval_end_ms = util::time_in_ms();
+    return logits_tensor;
   }
-  return sampler_->sample(logits_last);
 }

 // Given an input token. Set up the inputs for the model and execute a single
@@ -197,6 +319,9 @@ Result<exec_aten::Tensor> Runner::run_model_step(

   if (tokens.size(1) < max_seq_len) {
     // Resize the tokens tensor to be 1 larger for next step.
+    // Note that this relies on the fact that underlying memory is the same
+    // such that previous tokens stored there will still exist.
+    // Not a good thing to rely upon.
     managed_tokens.resize({1, static_cast<int>(tokens.size(1) + 1)});
   }

@@ -256,16 +381,11 @@ Error Runner::generate(
   std::vector<int64_t> start_pos_data; // allocate space for the tokens
   std::vector<exec_aten::SizesType> start_pos_shape = {1};

+  token_data.resize(seq_len);
   if (use_kv_cache_) {
     // hard code these to size 1 as kv cache is locked to static size right now.
-    token_data.resize(1);
-    token_shape[1] = 1;
     start_pos_data.resize(1);
     start_pos_data.push_back(0);
-  } else {
-    // reserve data for tokens, notice the size is still 0 but the capacity is
-    // seq_len.
-    token_data.resize(seq_len);
   }

   // initialize tensor wrappers
@@ -281,31 +401,35 @@ Error Runner::generate(
   int64_t prev_token;
   int64_t cur_token = prompt_tokens[0];

-  // If we arent using the kv cache then we can batch prefill the prompt
-  if (!use_kv_cache_) {
-    tokens_managed.resize({1, num_prompt_tokens});
-    for (int i = 0; i < num_prompt_tokens - 1; i++) {
-      tokens_managed.get_aliasing_tensor().mutable_data_ptr<int64_t>()[i] =
-          prompt_tokens[i];
-    }
-    // prefill tokens up to the last prompt token and then enter the loop with
-    // the last promp token as the current token.
-    cur_token = prompt_tokens[num_prompt_tokens - 1];
-    pos = num_prompt_tokens - 1;
-
-    // Print the prompt for consistent output between single token prefill and
-    // batch prefill.
-    uint64_t prev = prompt_tokens[0];
-    uint64_t cur;
-    for (int i = 1; i < num_prompt_tokens; i++) {
-      cur = prompt_tokens[i];
-      auto piece_res = tokenizer_->decode(prev, cur);
-      ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
-      util::safe_printf(piece_res.get().c_str());
-      fflush(stdout);
-      prev = cur;
-    }
+  // Prefill first
+  // Here feed all tokens to the model and get the next predicted token
+  // after the prompt. After that we will enter generate loop.
+  auto prefill_res =
+      prefill(prompt_tokens, tokens_managed, start_pos_managed, token_callback);
+  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
+  exec_aten::Tensor& prefill_res_tensor = prefill_res.get();
+  cur_token = logitsToToken(prefill_res_tensor);
+  if (use_kv_cache_) {
+    // Prefill could be parallel or sequential.
+    // Parallel:
+    // kv cache:
+    // - tokens_managed should be resized to 1 as inference expects one token at
+    // a time.
+    // no kv cache:
+    // - tokens_managed should be resized to prompt length + 1, as inference
+    // expects all tokens at once.
+    // Sequential prefill:
+    // kv cache:
+    // - tokens_managed should be resized to 1, as inference expects one
+    // token at a time.
+    // no kv cache:
+    // - tokens_managed should be resized to prompt length + 1, as inference
+    // expects all tokens at once.
+    tokens_managed.resize({1, 1});
+  } else {
+    tokens_managed.resize({1, num_prompt_tokens + 1});
   }
+  pos = num_prompt_tokens;

   // Generate our tokens
   while (pos < seq_len - 1) {
@@ -313,41 +437,16 @@ Error Runner::generate(
     Result<exec_aten::Tensor> logits_res =
         run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);

-    if (pos == num_prompt_tokens) {
-      stats_.first_token_ms = util::time_in_ms();
-    } else if (pos == num_prompt_tokens - 1) {
-      stats_.prompt_eval_end_ms = util::time_in_ms();
-    }
-
     ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
     exec_aten::Tensor& logits_tensor = logits_res.get();

     prev_token = cur_token;

     long sample_start_time_ms = util::time_in_ms();
-    switch (logits_tensor.scalar_type()) {
-      case ScalarType::Float: {
-        cur_token = logitsToToken<float>(logits_tensor, pos, 0);
-        break;
-      }
-      case ScalarType::Half: {
-        cur_token = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
-        break;
-      }
-      default:
-        ET_CHECK_MSG(
-            false,
-            "Unsupported dtype output %hhd",
-            static_cast<int8_t>(logits_tensor.scalar_type()));
-    }
+    cur_token = logitsToToken(logits_tensor);
     stats_.aggregate_sampling_time_ms +=
         util::time_in_ms() - sample_start_time_ms;

-    // advance the state machine
-    if (pos < num_prompt_tokens - 1) {
-      // prefill, force the next token to be the next prompt token
-      cur_token = prompt_tokens[pos + 1];
-    }
     pos++;

     // print the token as string, decode it with the Tokenizer object
diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h
index 4e200d5e6ca..ff76d205d53 100644
--- a/examples/models/llama2/runner/runner.h
+++ b/examples/models/llama2/runner/runner.h
@@ -70,9 +70,12 @@ class Runner {
   // metadata
   template <typename T>
   T getMetadataHelper(const std::string& method_name, T default_val);
-  template <typename T>
-  int32_t
-  logitsToToken(const exec_aten::Tensor& logits_tensor, int64_t pos, T _);
+  int32_t logitsToToken(const exec_aten::Tensor& logits_tensor);
+  Result<exec_aten::Tensor> prefill(
+      const std::vector<uint64_t>& tokens,
+      ManagedTensor& managed_tokens,
+      ManagedTensor& managed_start_pos,
+      std::function<void(const std::string&)> token_callback);
   Result<exec_aten::Tensor> run_model_step(
       int64_t input_token,
      ManagedTensor& tokens,
@@ -96,6 +99,7 @@ class Runner {
   std::unique_ptr<Sampler> sampler_;
   bool shouldStop_{false};
   Stats stats_;
+  bool enable_parallel_prefill_;
 };

 } // namespace torch::executor
diff --git a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py
index 8a8a0cac7c2..cc5890a1753 100644
--- a/examples/models/llama2/source_transformation/sdpa.py
+++ b/examples/models/llama2/source_transformation/sdpa.py
@@ -41,6 +41,9 @@ def forward(
             self.kv_cache.v_cache,
             input_pos[-1].item(),
             seqlen,
+            None,  # Attention mask
+            0,  # dropout probability. Ignored by the code
+            True,  # is_causal
         )
         return output.view(bsz, seqlen, self.dim)

diff --git a/examples/models/llama2/tests/test_simple_sdpa.py b/examples/models/llama2/tests/test_simple_sdpa.py
index 9113059fd5d..61f14e58dc5 100644
--- a/examples/models/llama2/tests/test_simple_sdpa.py
+++ b/examples/models/llama2/tests/test_simple_sdpa.py
@@ -32,7 +32,11 @@ def test_simple_sdpa(self):
             transpose_cache=True,
         )
         sdpa = SDPA(
-            kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
+            kv_cache=copy.deepcopy(kv_cache),
+            dim=dim,
+            head_dim=head_dim,
+            n_rep=n_rep,
+            max_seq_len=max_seq_length,
         )
         input_pos = torch.tensor([0])
         query = torch.randn(1, 1, n_local_heads, head_dim)
@@ -40,7 +44,13 @@ def test_simple_sdpa(self):
         value = torch.randn(1, 1, n_local_heads, head_dim)
         mask = torch.randn(max_seq_length, max_seq_length)
         sdpa_output = sdpa(
-            input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
+            input_pos,
+            query,
+            key,
+            value,
+            bsz=bsz,
+            seqlen=seqlen,
+            mask=mask,
         )

         simple_sdpa = SDPASimple(