diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index a01b05daa17..087e4d1efdc 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -551,6 +551,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="path to the input pruning token mapping file (token_map.json)",
     )
 
+    parser.add_argument(
+        "--nncf_compression",
+        default=False,
+        action="store_true",
+        help="If true, applies NNCF INT4 weight compression to the model after torch.export().",
+    )
+
     parser.add_argument(
         "--export_only",
         default=False,
@@ -1207,6 +1214,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         use_legacy_export=llm_config.backend.qnn.enabled,
         save_exported_program=llm_config.export.export_only,
         verbose=llm_config.debug.verbose,
+        nncf_compression=llm_config.nncf_compression,
         metadata=_load_llama_model_metadata(
             WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA,
             llm_config.model.use_kv_cache,
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 4128bfd8198..f185d9b346d 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -16,6 +16,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch
 
+import nncf
 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
     DuplicateDynamicQuantChainPass,
 )
@@ -40,6 +41,7 @@
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
 from torchao.utils import unwrap_tensor_subclass
+from functools import partial
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -98,6 +100,7 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
         use_legacy_export: bool = False,
         save_exported_program: bool = False,
+        nncf_compression: bool = False,
     ):
         # Store necessary constructor arguments.
         self.model = model
@@ -119,6 +122,7 @@
         self.dynamic_shapes = dynamic_shapes
         self.use_legacy_export = use_legacy_export
         self.save_exported_program = save_exported_program
+        self.nncf_compression = nncf_compression
         # Note: treat this as the source of truth for the result of
         # torch.export'ing a model. If the overall ExportedProgram is needed,
@@ -428,6 +432,43 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
             DuplicateDynamicQuantChainPass()(m)
             self.pre_autograd_graph_module = m
             return self
+        elif self.nncf_compression:
+            tokenizer = get_tokenizer(self.tokenizer_path)
+
+            def transform_fn(prompt: str, tokenizer):
+                # Tokenize one calibration prompt without BOS/EOS and package
+                # it to match the model's forward signature.
+                tokenized_text = tokenizer.encode(prompt, bos=False, eos=False)
+                logging.debug(tokenized_text)
+
+                return (
+                    torch.tensor(tokenized_text).unsqueeze(0),
+                    {"input_pos": torch.tensor([0])},
+                )
+
+            # Normalize calibration data to a list of prompts. Without dynamic
+            # shapes, split each prompt into single words so that every sample
+            # matches the fixed sequence length of the exported graph.
+            if isinstance(self.calibration_data, str):
+                self.calibration_data = [self.calibration_data]
+            if not self.dynamic_shapes:
+                self.calibration_data = [
+                    word
+                    for prompt in self.calibration_data
+                    for word in prompt.split()
+                ]
+
+            self.pre_autograd_graph_module = nncf.compress_weights(
+                self.pre_autograd_graph_module,
+                dataset=nncf.Dataset(
+                    self.calibration_data,
+                    transform_func=partial(transform_fn, tokenizer=tokenizer),
+                ),
+                mode=nncf.CompressWeightsMode.INT4_SYM,
+                ratio=0.8,
+                sensitivity_metric=nncf.SensitivityMetric.HESSIAN_INPUT_ACTIVATION,
+            )
+            return self
         else:
             logging.info("No quantizer provided, passing...")
             return self
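
Note for reviewers: the calibration-data normalization in pt2e_quantize() is easy to misread, so here is a minimal standalone sketch of its behavior when dynamic shapes are disabled (the prompt string is illustrative, not from the patch). Each prompt is split into single words, so every NNCF calibration sample becomes one short sequence:

    # Illustrative input; mirrors the normalization logic added above.
    calibration_data = "Once upon a time"
    dynamic_shapes = None

    if isinstance(calibration_data, str):
        calibration_data = [calibration_data]
    if not dynamic_shapes:
        calibration_data = [
            word for prompt in calibration_data for word in prompt.split()
        ]

    print(calibration_data)  # ['Once', 'upon', 'a', 'time']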
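
On the NNCF side, nncf.Dataset applies transform_func lazily to each raw sample, so compress_weights pulls inputs that already match the model's forward signature. A minimal sketch of that pattern, using a stand-in ToyTokenizer (hypothetical; the patch uses get_tokenizer(self.tokenizer_path)):

    from functools import partial

    import nncf
    import torch

    class ToyTokenizer:
        # Stand-in for the real tokenizer; only the encode() interface matters.
        def encode(self, text, bos=False, eos=False):
            return [ord(c) % 100 for c in text]

    def transform_fn(prompt, tokenizer):
        tokens = tokenizer.encode(prompt, bos=False, eos=False)
        return (torch.tensor(tokens).unsqueeze(0), {"input_pos": torch.tensor([0])})

    dataset = nncf.Dataset(
        ["Once", "upon", "a", "time"],
        transform_func=partial(transform_fn, tokenizer=ToyTokenizer()),
    )
    # Each item is already model-ready: (token_ids, {"input_pos": ...}).
    sample = next(iter(dataset.get_inference_data()))

With mode=INT4_SYM and ratio=0.8, NNCF compresses roughly 80% of the weight layers to symmetric INT4 and keeps the rest in the INT8 backup precision, using the Hessian-input-activation sensitivity metric to pick which layers stay at higher precision.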