diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index f2b3714fa..c36a8d1e3 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -301,23 +301,28 @@ def _compile( # Write mdp_config.json file if mdp_ts_num_devices > 1: - num_cores = compiler_options.get("aic_num_cores", 16) - mdp_ts_json = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json" - with open(mdp_ts_json, "w") as fp: - json.dump( - { - "connections": [{"devices": list(range(mdp_ts_num_devices)), "type": "p2p"}], - "partitions": [ - { - "name": "Partition0", - "devices": [{"deviceId": d, "numCores": num_cores} for d in range(mdp_ts_num_devices)], - } - ], - }, - fp, - indent=4, - ) - command.append(f"-mdp-load-partition-config={mdp_ts_json}") + if compiler_options.get("mdp_ts_json", None): + command.append(f"-mdp-load-partition-config={mdp_ts_json}") + else: + num_cores = compiler_options.get("aic_num_cores", 16) + mdp_ts_json = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json" + with open(mdp_ts_json, "w") as fp: + json.dump( + { + "connections": [{"devices": list(range(mdp_ts_num_devices)), "type": "p2p"}], + "partitions": [ + { + "name": "Partition0", + "devices": [ + {"deviceId": d, "numCores": num_cores} for d in range(mdp_ts_num_devices) + ], + } + ], + }, + fp, + indent=4, + ) + command.append(f"-mdp-load-partition-config={mdp_ts_json}") command.append(f"-aic-binary-dir={qpc_path}") logger.info(f"Running compiler: {' '.join(command)}") diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index d88c11d46..7e51dc842 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1522,10 +1522,27 @@ def compile( :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``. :enable_qnn (bool): Enables QNN Compilation. 
``Defaults to False.`` :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.`` + :prefill_only (bool): To compile model only for ``prefill`` part. ``Defaults to False``. + :decode_only (bool): To compile model only for the decoder part. ``Defaults to False``. + :compiler_options (dict, optional): Any other options that the `qaic-exec` takes. ``Defaults to None``. Returns: :str: Path of the compiled ``qpc`` package. """ + self.onnx_path = onnx_path + self.compile_dir = compile_dir + self.num_cores = num_cores + self.num_devices = num_devices + self.batch_size = batch_size + self.prefill_seq_len = prefill_seq_len + self.ctx_len = ctx_len + self.full_batch_size = full_batch_size + self.mxfp6_matmul = mxfp6_matmul + self.mxint8_kv_cache = mxint8_kv_cache + self.num_speculative_tokens = num_speculative_tokens + self.enable_qnn = enable_qnn + self.qnn_config = qnn_config + if self.is_tlm: # assert num_speculative_tokens cfg is acceptable if defined if num_speculative_tokens is None: @@ -1548,9 +1565,10 @@ def compile( "Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call" ) - kv_cache_batch_size = ( + self.kv_cache_batch_size = ( kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size) ) + # Define prefill specialization prefill_specialization = { # Prefill is always run with single BS for continuous batching. @@ -1560,16 +1578,22 @@ def compile( # TODO: should be renamed to kv_cache_batch_size in specialization too } prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ... 
- if self.continuous_batching: - prefill_specialization.update({"full_batch_size": kv_cache_batch_size}) - else: - prefill_specialization.update({"batch_size": kv_cache_batch_size}) + + prefill_specialization.update( + {"full_batch_size" if self.continuous_batching else "batch_size": self.kv_cache_batch_size} + ) + prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ... - specializations = [ - prefill_specialization, - ] + + specializations = [prefill_specialization] + + # Compile for prefill_only + if compiler_options.pop("prefill_only", False): + prefill_qpc_path = self.compile_model(compiler_options, [prefill_specialization]) + return prefill_qpc_path # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization + # Define decode specialization if prefill_seq_len != 1 or self.continuous_batching: decode_specialization = { "batch_size": full_batch_size if self.continuous_batching else batch_size, @@ -1577,52 +1601,60 @@ def compile( "ctx_len": ctx_len, } if self.continuous_batching: - decode_specialization.update({"full_batch_size": kv_cache_batch_size}) + decode_specialization.update({"full_batch_size": self.kv_cache_batch_size}) else: - decode_specialization.update({"batch_size": kv_cache_batch_size}) + decode_specialization.update({"batch_size": self.kv_cache_batch_size}) decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ... 
specializations.append(decode_specialization) - if enable_qnn: + if compiler_options.pop("decode_only", False): + decode_qpc_path = self.compile_model(compiler_options, [decode_specialization]) + return decode_qpc_path + + qpc_path = self.compile_model(compiler_options, specializations) + return qpc_path + + def compile_model(self, compiler_options, specializations): + if self.enable_qnn: if compiler_options: logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only") qpc_path = self._qnn_compile( - onnx_path, - compile_dir, + self.onnx_path, + self.compile_dir, specializations=specializations, - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - batch_size=batch_size, - full_batch_size=full_batch_size, - mdp_ts_num_devices=num_devices, - num_cores=num_cores, - mxfp6_matmul=mxfp6_matmul, - mxint8_kv_cache=mxint8_kv_cache, - qnn_config=qnn_config, - kv_cache_batch_size=kv_cache_batch_size, + prefill_seq_len=self.prefill_seq_len, + ctx_len=self.ctx_len, + batch_size=self.batch_size, + full_batch_size=self.full_batch_size, + mdp_ts_num_devices=self.num_devices, + num_cores=self.num_cores, + mxfp6_matmul=self.mxfp6_matmul, + mxint8_kv_cache=self.mxint8_kv_cache, + qnn_config=self.qnn_config, + kv_cache_batch_size=self.kv_cache_batch_size, ) else: # Custom IO custom_io = {} - kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" + kv_cache_dtype = "mxint8" if self.mxint8_kv_cache else "float16" for suffix in ["", "_RetainedState"]: for i in range(self.num_layers): for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype qpc_path = self._compile( - onnx_path, - compile_dir, + self.onnx_path, + self.compile_dir, compile_only=True, retained_state=True, specializations=specializations, convert_to_fp16=True, - mxfp6_matmul=mxfp6_matmul, + mxfp6_matmul=self.mxfp6_matmul, custom_io=custom_io, - mdp_ts_num_devices=num_devices, - num_speculative_tokens=num_speculative_tokens, - aic_num_cores=num_cores, + 
mdp_ts_num_devices=self.num_devices, + num_speculative_tokens=self.num_speculative_tokens, + aic_num_cores=self.num_cores, **compiler_options, ) return qpc_path