128 changes: 128 additions & 0 deletions examples/models/export_hf_model.py
@@ -0,0 +1,128 @@
import argparse
import os

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.extension.export_util.utils import export_to_edge
from torch.nn.attention import SDPBackend
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-hfm",
        "--hf_model_repo",
        required=False,
        default=None,
        help="a valid huggingface model repo name",
    )
    parser.add_argument(
        "--compile",
        required=False,
        action="store_true",
        help="run HF model in eager with torch.compile",
    )
    parser.add_argument(
        "--export",
        required=False,
        action="store_true",
        help="export HF model to ExecuTorch",
    )

    args = parser.parse_args()

    # Configs for the HF model
    device = "cpu"
    dtype = torch.float32
    max_batch_size = 1
    max_seq_len = 32
    cache_implementation = "static"
    attn_implementation = "sdpa"
    use_sdpa_with_kv_cache = False
    prompt = ""  # Use an empty prompt as a hack to avoid parallel prefill so correctness can be verified in eager mode

    # Load the HF model in eager mode
    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
    config = AutoConfig.from_pretrained(
        args.hf_model_repo,
        torch_dtype=dtype,
        use_cache=True,
        max_length=max_seq_len,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.hf_model_repo,
        config=config,
        attn_implementation=attn_implementation,
        device_map=device,
    )
    # Make sure the generation config is consistent with the model config.
    # TODO: In HF, the cache implementation is a generation-time config. To make HF models
    # work properly with ExecuTorch, it needs to become a model-construction-time config
    # that does not change at generation time.
    model.generation_config.cache_implementation = cache_implementation
    model.generation_config.max_length = max_seq_len  # HF sets this independently of config.max_length and uses it to size the static KV cache
    print(f"DEBUG model config = {model.config}")
    print(f"DEBUG model generation_config = {model.generation_config}")

    if args.compile:
        # torch.compile
        compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
        input_tokens = tokenizer(prompt, return_tensors="pt")
        outputs = compiled_model.generate(**input_tokens)
        output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        print(f"DEBUG output_texts: {output_texts}")

    if args.export:
        # torch.export
        input_tokens = tokenizer(prompt, return_tensors="pt")
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            prog = (
                export_to_edge(
                    model,
                    (
                        torch.tensor(
                            [[1]], dtype=torch.long
                        ),  # tokens; with the kv cache the input is always a single token
                        torch.tensor(
                            [0], dtype=torch.long
                        ),  # input_pos, i.e. which output token we are on
                    ),
                    edge_compile_config=EdgeCompileConfig(
                        _check_ir_validity=False,
                        _skip_dim_order=True,
                    ),
                    edge_constant_methods={
                        # dtype codes follow torch ScalarType: 5 = Half (fp16), 6 = Float (fp32)
                        "get_dtype": 5 if config.torch_dtype == torch.float16 else 6,
                        "get_bos_id": config.bos_token_id,
                        "get_eos_id": config.eos_token_id,
                        "get_head_dim": config.hidden_size / config.num_attention_heads,
                        "get_max_batch_size": max_batch_size,
                        "get_max_seq_len": max_seq_len,
                        "get_n_bos": 1,
                        "get_n_eos": 1,
                        "get_n_kv_heads": config.num_key_value_heads,
                        "get_n_layers": config.num_hidden_layers,
                        "get_vocab_size": config.vocab_size,
                        "use_kv_cache": config.use_cache,
                        "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
                    },
                    verbose=False,
                )
                .to_backend(
                    XnnpackPartitioner(_lower_recomposed_sdpa=use_sdpa_with_kv_cache)
                )
                .to_executorch(
                    ExecutorchBackendConfig(
                        extract_constant_segment=True, extract_delegate_segments=True
                    )
                )
            )
            filename = os.path.join("./", f"{config.model_type}.pte")
            with open(filename, "wb") as f:
                prog.write_to_file(f)
                print(f"Saved exported program to {filename}")


if __name__ == "__main__":
main()
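
For reference, based on the argparse flags defined above, a plausible invocation of this script (the model repo name is only a placeholder, not something the diff specifies) would be:

    python examples/models/export_hf_model.py --hf_model_repo <hf_model_repo> --export

This path traces the model with export_to_edge, delegates supported ops to XNNPACK via XnnpackPartitioner, and writes <model_type>.pte to the current directory; passing --compile instead runs the eager model under torch.compile as a correctness sanity check.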