Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,23 +301,28 @@

# Write mdp_config.json file
if mdp_ts_num_devices > 1:
num_cores = compiler_options.get("aic_num_cores", 16)
mdp_ts_json = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
with open(mdp_ts_json, "w") as fp:
json.dump(
{
"connections": [{"devices": list(range(mdp_ts_num_devices)), "type": "p2p"}],
"partitions": [
{
"name": "Partition0",
"devices": [{"deviceId": d, "numCores": num_cores} for d in range(mdp_ts_num_devices)],
}
],
},
fp,
indent=4,
)
command.append(f"-mdp-load-partition-config={mdp_ts_json}")
if compiler_options.get("mdp_ts_json", None):
command.append(f"-mdp-load-partition-config={mdp_ts_json}")

Check failure on line 305 in QEfficient/base/modeling_qeff.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (F821)

QEfficient/base/modeling_qeff.py:305:62: F821 Undefined name `mdp_ts_json`
else:
num_cores = compiler_options.get("aic_num_cores", 16)
mdp_ts_json = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
with open(mdp_ts_json, "w") as fp:
json.dump(
{
"connections": [{"devices": list(range(mdp_ts_num_devices)), "type": "p2p"}],
"partitions": [
{
"name": "Partition0",
"devices": [
{"deviceId": d, "numCores": num_cores} for d in range(mdp_ts_num_devices)
],
}
],
},
fp,
indent=4,
)
command.append(f"-mdp-load-partition-config={mdp_ts_json}")

command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
Expand Down
92 changes: 62 additions & 30 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1522,10 +1522,27 @@ def compile(
:aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
:qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
:prefill_only (bool): To compile model only for ``prefill`` part. ``Defaults to False``.
:decode_only (bool): To compile model only for the ``decode`` part. ``Defaults to False``.
:compiler_options (dict, optional): Any other options that the `qaic-exec` takes. ``Defaults to None``.

Returns:
:str: Path of the compiled ``qpc`` package.
"""
self.onnx_path = onnx_path
self.compile_dir = compile_dir
self.num_cores = num_cores
self.num_devices = num_devices
self.batch_size = batch_size
self.prefill_seq_len = prefill_seq_len
self.ctx_len = ctx_len
self.full_batch_size = full_batch_size
self.mxfp6_matmul = mxfp6_matmul
self.mxint8_kv_cache = mxint8_kv_cache
self.num_speculative_tokens = num_speculative_tokens
self.enable_qnn = enable_qnn
self.qnn_config = qnn_config

if self.is_tlm:
# assert num_speculative_tokens cfg is acceptable if defined
if num_speculative_tokens is None:
Expand All @@ -1548,9 +1565,10 @@ def compile(
"Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call"
)

kv_cache_batch_size = (
self.kv_cache_batch_size = (
kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
)

# Define prefill specialization
prefill_specialization = {
# Prefill is always run with single BS for continuous batching.
Expand All @@ -1560,69 +1578,83 @@ def compile(
# TODO: should be renamed to kv_cache_batch_size in specialization too
}
prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
if self.continuous_batching:
prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
else:
prefill_specialization.update({"batch_size": kv_cache_batch_size})

prefill_specialization.update(
{"full_batch_size" if self.continuous_batching else "batch_size": self.kv_cache_batch_size}
)

prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
specializations = [
prefill_specialization,
]

specializations = [prefill_specialization]

# Compile for prefill_only
if compiler_options.pop("prefill_only", False):
prefill_qpc_path = self.compile_model(compiler_options, [prefill_specialization])
return prefill_qpc_path

# Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
# Define decode specialization
if prefill_seq_len != 1 or self.continuous_batching:
decode_specialization = {
"batch_size": full_batch_size if self.continuous_batching else batch_size,
"seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
"ctx_len": ctx_len,
}
if self.continuous_batching:
decode_specialization.update({"full_batch_size": kv_cache_batch_size})
decode_specialization.update({"full_batch_size": self.kv_cache_batch_size})
else:
decode_specialization.update({"batch_size": kv_cache_batch_size})
decode_specialization.update({"batch_size": self.kv_cache_batch_size})
decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
specializations.append(decode_specialization)

if enable_qnn:
if compiler_options.pop("decode_only", False):
decode_qpc_path = self.compile_model(compiler_options, [decode_specialization])
return decode_qpc_path

qpc_path = self.compile_model(compiler_options, specializations)
return qpc_path

def compile_model(self, compiler_options, specializations):
if self.enable_qnn:
if compiler_options:
logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")

qpc_path = self._qnn_compile(
onnx_path,
compile_dir,
self.onnx_path,
self.compile_dir,
specializations=specializations,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
batch_size=batch_size,
full_batch_size=full_batch_size,
mdp_ts_num_devices=num_devices,
num_cores=num_cores,
mxfp6_matmul=mxfp6_matmul,
mxint8_kv_cache=mxint8_kv_cache,
qnn_config=qnn_config,
kv_cache_batch_size=kv_cache_batch_size,
prefill_seq_len=self.prefill_seq_len,
ctx_len=self.ctx_len,
batch_size=self.batch_size,
full_batch_size=self.full_batch_size,
mdp_ts_num_devices=self.num_devices,
num_cores=self.num_cores,
mxfp6_matmul=self.mxfp6_matmul,
mxint8_kv_cache=self.mxint8_kv_cache,
qnn_config=self.qnn_config,
kv_cache_batch_size=self.kv_cache_batch_size,
)
else:
# Custom IO
custom_io = {}
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
kv_cache_dtype = "mxint8" if self.mxint8_kv_cache else "float16"
for suffix in ["", "_RetainedState"]:
for i in range(self.num_layers):
for kv in ["key", "value"]:
custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype

qpc_path = self._compile(
onnx_path,
compile_dir,
self.onnx_path,
self.compile_dir,
compile_only=True,
retained_state=True,
specializations=specializations,
convert_to_fp16=True,
mxfp6_matmul=mxfp6_matmul,
mxfp6_matmul=self.mxfp6_matmul,
custom_io=custom_io,
mdp_ts_num_devices=num_devices,
num_speculative_tokens=num_speculative_tokens,
aic_num_cores=num_cores,
mdp_ts_num_devices=self.num_devices,
num_speculative_tokens=self.num_speculative_tokens,
aic_num_cores=self.num_cores,
**compiler_options,
)
return qpc_path
Expand Down
Loading