4 changes: 3 additions & 1 deletion backends/xnnpack/partition/xnnpack_partitioner.py
@@ -1185,7 +1185,9 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
partitions = [
Partition(
id=next(partition_id),
nodes=set(match),
nodes=set(
filter(lambda x: x.target != torch.ops.aten.sym_size.int, match)
),
)
for match in self.get_module_partitions(exported_program)
]
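For context on the change above: with dynamic shapes enabled, the exported graph contains `aten.sym_size.int` nodes that merely read a dimension, and the partitioner now keeps them out of the XNNPACK partitions so the delegate is never asked to lower pure shape queries. A minimal sketch of the filtering idea, assuming `match` is an iterable of `torch.fx.Node` objects as in the partitioner:

```python
# Hedged sketch, not the partitioner itself: drop shape-query nodes from a
# matched node set before building a Partition.
import torch

def drop_sym_size_nodes(match):
    # aten.sym_size.int only reads a dimension; excluding it keeps symbolic
    # shape resolution outside the delegated subgraph.
    return {n for n in match if n.target != torch.ops.aten.sym_size.int}
```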
11 changes: 10 additions & 1 deletion examples/models/llama2/builder.py
@@ -76,6 +76,7 @@ def load_llama_model(
use_kv_cache: bool = False,
use_sdpa_with_kv_cache: bool = False,
weight_type: WeightType = WeightType.LLAMA,
enable_dynamic_shape: bool = False,
verbose: bool = False,
max_seq_len: int = 128,
) -> "LlamaEdgeManager":
@@ -101,6 +102,7 @@ def load_llama_model(
use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
enable_dynamic_shape=enable_dynamic_shape,
)
state_dict = model.state_dict()
dtype = state_dict[next(iter(state_dict))].dtype
@@ -128,6 +130,7 @@ def load_llama_model(
use_kv_cache=use_kv_cache,
use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
example_inputs=example_inputs,
enable_dynamic_shape=enable_dynamic_shape,
verbose=verbose,
)

@@ -146,6 +149,7 @@ def __init__(
use_kv_cache,
use_sdpa_with_kv_cache,
example_inputs,
enable_dynamic_shape: bool = False,
verbose: bool = False,
):
self.model = model
@@ -156,6 +160,7 @@ def __init__(
self.dtype = dtype
self.example_inputs = example_inputs
self.use_kv_cache = use_kv_cache
self.enable_dynamic_shape = enable_dynamic_shape
self.use_sdpa_with_kv_cache = use_sdpa_with_kv_cache
self.metadata = None
self.verbose = verbose
@@ -220,7 +225,10 @@ def source_transform(
def _get_dynamic_shape(self) -> Any:
dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
if self.use_kv_cache:
return None
if self.enable_dynamic_shape:
return ({1: dim}, {0: dim})
else:
                return None
else:
return ({1: dim},)

@@ -250,6 +258,7 @@ def _get_metadata(self):
"get_vocab_size": params.vocab_size,
"use_kv_cache": self.use_kv_cache,
"use_sdpa_with_kv_cache": self.use_sdpa_with_kv_cache,
"enable_dynamic_shape": self.enable_dynamic_shape,
}
if self.metadata:
try:
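In `_get_dynamic_shape`, the tuple `({1: dim}, {0: dim})` maps each positional example input to its dynamic dimensions: dim 1 of the token tensor (the sequence dim) and dim 0 of `input_pos` may range up to `max_seq_len - 1`. A hedged sketch of how such a spec is typically handed to `torch.export` (the manager's actual export call may differ):

```python
# Illustrative only; `model` stands in for the wrapped Llama module.
import torch

dim = torch.export.Dim("token_dim", max=127)  # i.e. max_seq_len - 1
example_inputs = (
    torch.tensor([[2, 3, 4]], dtype=torch.long),  # tokens: [batch, seq]
    torch.tensor([0, 1, 2], dtype=torch.long),    # input_pos: [seq]
)
dynamic_shapes = ({1: dim}, {0: dim})  # seq dim of tokens, dim 0 of input_pos
# ep = torch.export.export(model, example_inputs, dynamic_shapes=dynamic_shapes)
```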
2 changes: 0 additions & 2 deletions examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
@@ -86,8 +86,6 @@ def _validate_params(
# 1
# ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {key_cache.size(2)}"

assert seq_len == 1, "Only support seq_len = 1 for now."

if attn_mask is not None:
assert (
attn_mask.dim() == 2
20 changes: 20 additions & 0 deletions examples/models/llama2/export_llama_lib.py
@@ -174,6 +174,13 @@ def build_args_parser() -> argparse.ArgumentParser:
action="store_true",
help="Whether to use sdpa_with_kv_cache update op when using kv cache",
)
parser.add_argument(
"--disable_dynamic_shape",
dest="enable_dynamic_shape",
default=True, # Enable this by default
action="store_false",
help="Enable dynamic shape along seq dim. Used for faster prefill",
)
parser.add_argument(
"-p",
"--params",
@@ -369,6 +376,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
use_kv_cache=args.use_kv_cache,
use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
weight_type=weight_type,
enable_dynamic_shape=args.enable_dynamic_shape,
verbose=args.verbose,
max_seq_len=args.max_seq_length,
)
@@ -391,7 +399,19 @@ def get_quantizer_and_quant_params(args):
return pt2e_quant_params, quantizers, quant_dtype


def _validate_args(args):
"""
    TODO: Combine all the backends under a single --backend arg.
"""
if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
raise ValueError(
"Dynamic shape is not supported with coreml, MPS or qnn backends."
" Please us --disble_dynamic_shape."
)


def _export_llama(modelname, args) -> LlamaEdgeManager: # noqa: C901
_validate_args(args)
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

# export_to_edge
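Note the inverted flag above: `--disable_dynamic_shape` uses `action="store_false"` with `default=True`, so `args.enable_dynamic_shape` is on unless the user opts out. A standalone illustration of the pattern:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable_dynamic_shape",
    dest="enable_dynamic_shape",  # the flag writes to the *enable* attribute
    default=True,
    action="store_false",
)

print(parser.parse_args([]).enable_dynamic_shape)                           # True
print(parser.parse_args(["--disable_dynamic_shape"]).enable_dynamic_shape)  # False
```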
48 changes: 39 additions & 9 deletions examples/models/llama2/llama_transformer.py
@@ -191,6 +191,7 @@ def __init__(
dtype=torch.float32,
):
super().__init__()
self.max_seq_length = max_seq_length
if transpose_cache:
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
else:
@@ -208,12 +209,23 @@ def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val

return k_out, v_out
        start_pos = input_pos[0].item()  # first position of this chunk
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_seq_length)
seq_length = k_val.size(2)
# Replace the entry in the cache for this token
# The following lines are equivalent to:
# cache_k[:bsz, start_pos : start_pos + seqlen] = xk
# cache_v[:bsz, start_pos : start_pos + seqlen] = xv
# We use .narrow() here to make the compiler happy
# pyre-ignore: Incompatible parameter type [6]
narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
# pyre-ignore: Incompatible parameter type [6]
narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)

narrowed_k.copy_(k_val)
narrowed_v.copy_(v_val)
return self.k_cache, self.v_cache
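The rewrite above swaps tensor-indexed assignment for `narrow` + `copy_`: `narrow` takes an explicit `(start, length)` pair that export can bound-check via the `torch._check` calls, while remaining an in-place update of the cache. A minimal, self-contained illustration of the equivalence:

```python
import torch

cache = torch.zeros(1, 4, 8, 16)   # [B, H, max_seq_len, D]
k_val = torch.randn(1, 4, 3, 16)   # a 3-token chunk
start_pos, seq_length = 2, k_val.size(2)

# Equivalent to: cache[:, :, start_pos : start_pos + seq_length] = k_val
cache.narrow(2, start_pos, seq_length).copy_(k_val)  # narrow returns a view

assert torch.equal(cache[:, :, 2:5], k_val)
```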


class SDPA(nn.Module):
@@ -223,12 +235,14 @@ def __init__(
dim: int,
head_dim: int,
n_rep: int,
max_seq_len: int,
):
super().__init__()
self.kv_cache = kv_cache
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep
self.max_seq_len = max_seq_len

def forward(
self,
@@ -245,7 +259,12 @@ def forward(
v = v.transpose(1, 2)

k, v = self.kv_cache.update(input_pos, k, v)
attn_mask = mask[None, None, input_pos]
        start_pos = input_pos[0].item()  # first position of this chunk
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_seq_len)
seq_length = q.size(2)
# pyre-ignore: Incompatible parameter type [6]
attn_mask = mask.narrow(0, start_pos, seq_length)

k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
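The mask handling mirrors the cache update: because the positions in a chunk are contiguous, gathering rows with a tensor index and slicing `seq_length` rows starting at the first position select the same data (the leading singleton dims of the old `mask[None, None, input_pos]` are left to SDPA broadcasting). A small check of that equivalence:

```python
import torch

# Causal mask as in the model: row i allows attention to positions <= i.
mask = torch.triu(torch.full((8, 8), float("-inf")), diagonal=1)

input_pos = torch.tensor([0, 1, 2])  # contiguous positions of the chunk
start_pos, seq_length = 0, input_pos.numel()

rows_by_index = mask[input_pos]                        # old: tensor indexing
rows_by_slice = mask.narrow(0, start_pos, seq_length)  # new: contiguous slice
assert torch.equal(rows_by_index, rows_by_slice)
```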
@@ -299,6 +318,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
dim=self.dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
max_seq_len=self.max_seq_len,
)

def forward(
@@ -447,6 +467,7 @@ def __init__(self, params: ModelArgs):
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.use_kv_cache = params.use_kv_cache
self.max_seq_len = params.max_seq_len

freqs_cos, freqs_sin = precompute_freqs_cis(
params.dim // params.n_heads,
@@ -476,8 +497,17 @@ def forward(
), "input_pos must be provided when use_kv_cache is True"

            # When the KV cache is used, seqlen is most likely 1. We want to slice starting at start_pos.
freqs_cos = self.freqs_cos[input_pos]
freqs_sin = self.freqs_sin[input_pos]
            input_pos_item = input_pos[0].item()  # first position of this chunk
            torch._check_is_size(input_pos_item)
            # The upper bound below exists mainly to make export happy; the
            # resulting runtime asserts are ignored anyway. Ideally start_pos
            # would be unbounded.
torch._check(input_pos_item < self.params.max_seq_len)
# pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
# pyre-ignore: Incompatible parameter type [6]
freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
else:
assert input_pos is None, "input_pos is unused when use_kv_cache is False"
freqs_cos = self.freqs_cos[:seqlen]
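One pattern worth calling out, since it appears three times in this file: calling `.item()` on an input tensor yields an unbacked symbolic integer during export, and `torch._check_is_size` plus `torch._check` record the non-negativity and upper-bound facts the exporter needs before a data-dependent `narrow` is accepted. A minimal sketch of the idiom, using a hypothetical module with a buffer named `table`:

```python
import torch

class SlidingSlice(torch.nn.Module):
    """Illustrative module: slice a fixed table at a data-dependent offset."""

    def __init__(self, rows: int = 8, cols: int = 4):
        super().__init__()
        self.register_buffer("table", torch.randn(rows, cols))

    def forward(self, input_pos: torch.Tensor) -> torch.Tensor:
        start = input_pos[0].item()            # unbacked SymInt under export
        torch._check_is_size(start)            # records start >= 0
        torch._check(start < self.table.size(0))
        return self.table.narrow(0, start, 1)

# Sketch of the export call; exact behavior depends on the PyTorch version.
# ep = torch.export.export(SlidingSlice(), (torch.tensor([3]),))
```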
27 changes: 19 additions & 8 deletions examples/models/llama2/model.py
@@ -74,6 +74,11 @@ def __init__(self, **kwargs):
if "use_sdpa_with_kv_cache" in kwargs
else False
)
self.enable_dynamic_shape = (
kwargs["enable_dynamic_shape"]
if "enable_dynamic_shape" in kwargs
else False
)

self.max_seq_len = kwargs["max_seq_len"] if "max_seq_len" in kwargs else 128
# The example is using a dummy small model with random weights for demo purpose only.
@@ -220,11 +225,17 @@ def get_example_inputs(self):

# Assumption: the custom op doesn't support dynamic shape right now. It might, but it's untested, so let's get static shape working first.
def get_example_inputs_kvcache_sdpa(self):
return (
torch.tensor(
[[1]], dtype=torch.long
), # tokens, with kv cache our input token length is always just 1 token.
torch.tensor(
[0], dtype=torch.long
), # start_pos, what token of output are we on.)
)
if self.enable_dynamic_shape:
return (
torch.tensor([[2, 3, 4]], dtype=torch.long),
torch.tensor([0, 1, 2], dtype=torch.long),
)
else:
return (
torch.tensor(
[[1]], dtype=torch.long
), # tokens, with kv cache our input token length is always just 1 token.
torch.tensor(
[0], dtype=torch.long
), # start_pos, what token of output are we on.
)
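The dynamic-shape example inputs above use three tokens rather than one for a reason: `torch.export` specializes dimensions whose example size is 0 or 1, so a dimension meant to be dynamic must be traced with an example length of at least 2. A sketch of what the exporter sees in each mode:

```python
import torch

# Dynamic-shape example pair, as returned above: a 3-token prefill chunk.
tokens = torch.tensor([[2, 3, 4]], dtype=torch.long)   # [batch=1, seq=3]
input_pos = torch.tensor([0, 1, 2], dtype=torch.long)  # positions 0..2

# Static (single-token decode) example pair, for comparison: a seq dim of
# size 1 here would be specialized rather than kept dynamic.
tokens_static = torch.tensor([[1]], dtype=torch.long)
input_pos_static = torch.tensor([0], dtype=torch.long)
```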