38 changes: 15 additions & 23 deletions QEfficient/base/modeling_qeff.py
@@ -60,7 +60,6 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
self.model = model
self.hash_params = create_model_params(self, **kwargs)
self.prefill_onnx_path: Optional[str] = None
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
@@ -240,10 +239,7 @@ def _export(

# Return early if ONNX already exists
if onnx_path.is_file():
if prefill_only:
self.prefill_onnx_path = onnx_path
else:
self.onnx_path = onnx_path
self.onnx_path = onnx_path
return onnx_path

# check if the model is in meta state or weights are offloaded
@@ -322,10 +318,7 @@ def _export(
finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)

if prefill_only:
self.prefill_onnx_path = onnx_path
else:
self.onnx_path = onnx_path
self.onnx_path = onnx_path
return onnx_path

def get_onnx_path(
@@ -342,21 +335,18 @@ def get_onnx_path(
"use_onnx_subfunctions": use_onnx_subfunctions,
"retain_full_kv": retain_full_kv,
}

if prefill_only:
if self.prefill_onnx_path is None:
kwargs.update(
{
"prefill_only": prefill_only,
"prefill_seq_len": specializations[0].get("seq_len"),
"enable_chunking": enable_chunking,
}
)
self.export(**kwargs)
return self.prefill_onnx_path
else:
if self.onnx_path is None:
self.export(**kwargs)
return self.onnx_path
kwargs.update(
{
"prefill_only": prefill_only,
"prefill_seq_len": specializations[0].get("seq_len"),
"enable_chunking": enable_chunking,
}
)

self.export(**kwargs)
return self.onnx_path

@dump_qconfig
def _compile(
@@ -404,6 +394,8 @@ def _compile(
onnx_path = Path(
onnx_path
if onnx_path
else self.onnx_path
if self.onnx_path
else self.get_onnx_path(
prefill_only,
enable_chunking,
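For orientation, a minimal standalone sketch (not the library code) of the caching behavior this file now has: the separate prefill_onnx_path attribute is removed, get_onnx_path always forwards the prefill-related kwargs and calls export, and deduplication relies on _export returning early when the ONNX artifact already exists. The class name and the fake path below are illustrative assumptions.

from typing import List, Optional


class OnnxPathCacheSketch:
    """Illustrative only: mirrors the simplified single-onnx_path flow."""

    def __init__(self) -> None:
        self.onnx_path: Optional[str] = None  # no separate prefill_onnx_path anymore

    def export(self, **kwargs) -> str:
        # Stand-in for _export: skip the work if an earlier call already produced the
        # artifact, otherwise "export" and cache the single onnx_path.
        if self.onnx_path is None:
            suffix = "prefill" if kwargs.get("prefill_only") else "full"
            self.onnx_path = f"model_{suffix}.onnx"
        return self.onnx_path

    def get_onnx_path(self, prefill_only: bool, specializations: List[dict], enable_chunking: bool = False) -> str:
        kwargs = {
            "prefill_only": prefill_only,
            "prefill_seq_len": specializations[0].get("seq_len"),
            "enable_chunking": enable_chunking,
        }
        # No prefill/decode branching here anymore; export is always invoked.
        self.export(**kwargs)
        return self.onnx_path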
2 changes: 1 addition & 1 deletion QEfficient/transformers/modeling_utils.py
@@ -189,7 +189,7 @@
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}

# This is for supporting different modelling classes specially written for prefill-only model
SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"}
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss"}

# Define a transformers layers to QEff layers dictionary
# While onboarding new models make sure to add the new layer maps to this dictionary.
57 changes: 28 additions & 29 deletions QEfficient/transformers/models/modeling_auto.py
@@ -40,7 +40,7 @@
from QEfficient.generation.vlm_generation import VisionLanguageGeneration
from QEfficient.transformers.modeling_utils import (
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH,
SPECIALIZED_PREFILL_ONLY_MODEL_ARCH,
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH,
)
from QEfficient.transformers.models.pytorch_transforms import (
BlockedKVAttentionTransform,
@@ -2522,15 +2522,18 @@ def get_seq_len_and_handle_specialized_prefill_model(

num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None)
if num_q_blocks is None:
block_size = 256
if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128:
if (
prefill_seq_len is None
or prefill_seq_len % constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE != 0
or prefill_seq_len < constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE
):
raise ValueError(
f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. "
f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE}. "
f"Or set `NUM_Q_BLOCKS` ENV variable"
f"Received: prefill_seq_len={prefill_seq_len}"
)

num_q_blocks = prefill_seq_len // block_size
num_q_blocks = prefill_seq_len // constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE
logger.warning(
f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override"
)
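For illustration, a small hedged sketch of the block-count derivation performed above, using the same 256-token GPT_OSS_PREFILL_Q_BLOCK_SIZE this PR adds to QEfficient/utils/constants.py: a prefill_seq_len of 1024 yields NUM_Q_BLOCKS=4, while 1000 is rejected because it is not a multiple of the block size. The helper name is an assumption for the sketch only.

GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256  # same value the PR adds to constants.py


def derive_num_q_blocks(prefill_seq_len):
    # Mirrors the validation in get_seq_len_and_handle_specialized_prefill_model (sketch only).
    if (
        prefill_seq_len is None
        or prefill_seq_len % GPT_OSS_PREFILL_Q_BLOCK_SIZE != 0
        or prefill_seq_len < GPT_OSS_PREFILL_Q_BLOCK_SIZE
    ):
        raise ValueError(
            f"prefill_seq_len must be set and divisible by {GPT_OSS_PREFILL_Q_BLOCK_SIZE}, got {prefill_seq_len}"
        )
    return prefill_seq_len // GPT_OSS_PREFILL_Q_BLOCK_SIZE


assert derive_num_q_blocks(1024) == 4  # 1024 / 256
assert derive_num_q_blocks(256) == 1   # smallest allowed prefill length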
@@ -2588,31 +2591,28 @@ def export(
self.model.config, fbs if self.continuous_batching else bs, seq_len
)
enable_chunking = kwargs.get("enable_chunking", False)
if prefill_only:
if not enable_chunking and self.continuous_batching:
raise NotImplementedError(
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
)
self.prefill(enable=True, enable_chunking=enable_chunking)
self.hash_params.pop("retain_full_kv", None)
seq_len = (
self.get_seq_len_and_handle_specialized_prefill_model(

# TODO: move this to a DA Serving utility class
if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH:
if prefill_only:
if self.continuous_batching and not enable_chunking:
raise NotImplementedError("Can't enable prefix-caching without chunking")
self.prefill(enable=True, enable_chunking=enable_chunking)
self.hash_params.pop("retain_full_kv", None)
seq_len = self.get_seq_len_and_handle_specialized_prefill_model(
prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking
)
if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
else seq_len
)
kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
else:
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
self.hash_params.pop("prefill_only", None)
self.hash_params.pop("NUM_Q_BLOCKS", None)
self.hash_params.pop("NUM_FFN_BLOCKS", None)
self.hash_params.pop("ENABLE_OPT_SWA", None)
self.hash_params.pop("chunking", None)
if kwargs.get("retain_full_kv", False):
kv_cache_shape[2] = seq_len + self.model.config.sliding_window
self.hash_params["retain_full_kv"] = True
kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
else:
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
self.hash_params.pop("prefill_only", None)
self.hash_params.pop("NUM_Q_BLOCKS", None)
self.hash_params.pop("NUM_FFN_BLOCKS", None)
self.hash_params.pop("ENABLE_OPT_SWA", None)
self.hash_params.pop("chunking", None)
if kwargs.get("retain_full_kv", False):
kv_cache_shape[2] = seq_len + self.model.config.sliding_window
self.hash_params["retain_full_kv"] = True

example_inputs = {
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
@@ -2942,7 +2942,6 @@ def compile(
if prefill_only is None or not prefill_only:
if self.continuous_batching and full_batch_size is None:
raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")

else:
if self.continuous_batching and kv_cache_batch_size is None and full_batch_size is None:
raise ValueError(
3 changes: 3 additions & 0 deletions QEfficient/utils/constants.py
@@ -178,6 +178,9 @@ def get_models_dir():
CCL_MAX_ELEMENTS_LISTS = 5
CCL_START_CTX_LEN = 4096

# used for gpt-oss prefill-only model Q-blocking
GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256


class Constants:
# Export Constants.
20 changes: 16 additions & 4 deletions tests/transformers/test_causal_lm.py
@@ -158,12 +158,17 @@ def test_causal_lm_export_and_hash(config, cb, tmp_path):


@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
@pytest.mark.parametrize("subfunc", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("subfunc", [False, True], ids=["non-subfunc", "subfunc"])
@pytest.mark.parametrize("prefill_only", [False, True], ids=["pref+decode", "prefill-only"])
@pytest.mark.parametrize("config", configs, ids=config_ids)
def test_causal_lm_hash_creation(config, cb, subfunc, tmp_path):
def test_causal_lm_hash_creation(config, cb, subfunc, prefill_only, tmp_path):
if config.model_type == "gpt_oss" and prefill_only:
pytest.skip(
"gpt_oss prefill_only mode has different logic to create hash as we have two different ONNX for prefill/decode for this model for disagg serving"
)
model = AutoModelForCausalLM.from_config(config, **model_kwargs)
qeff_model = QEFFAutoModelForCausalLM(model, cb)
qeff_model.export(tmp_path, use_onnx_subfunctions=subfunc)
qeff_model.export(tmp_path, use_onnx_subfunctions=subfunc, prefill_only=prefill_only)
hash_params = {}
hash_params["config"] = qeff_model.model.config.to_diff_dict()
hash_params["peft_config"] = None
@@ -251,12 +256,19 @@ def tmp_cache(tmp_path, monkeypatch):
yield tmp_path


@pytest.mark.parametrize("prefill_only", [False, True], ids=["pref+decode", "prefill_only"])
@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
@pytest.mark.parametrize("config", configs, ids=config_ids)
def test_causal_lm_compile(config, cb, tmp_cache):
def test_causal_lm_compile(config, cb, prefill_only, tmp_cache):
if config.model_type == "gpt_oss":
pytest.skip(
"gpt_oss prefill_only mode has different logic to create hash as we have two different ONNX for prefill/decode for this model for disagg serving"
)
model = AutoModelForCausalLM.from_config(config, **model_kwargs)
qeff_model = QEFFAutoModelForCausalLM(model, cb)
compile_params = {"prefill_seq_len": 8, "ctx_len": 16}
if prefill_only:
compile_params["prefill_only"] = True
if cb:
compile_params["full_batch_size"] = 32
compile_params["batch_size"] = 8
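For context, a hedged usage sketch of the compile path this parametrization exercises, assuming the test forwards compile_params to qeff_model.compile(); the model architecture and sizes are placeholders, and gpt_oss is skipped exactly as the test does because its disaggregated-serving flow exports separate prefill/decode ONNX graphs.

from transformers import AutoConfig, AutoModelForCausalLM

from QEfficient import QEFFAutoModelForCausalLM

# Placeholder tiny config; any architecture other than gpt_oss follows the common path.
config = AutoConfig.for_model("gpt2", n_layer=1, n_head=2, n_embd=64, vocab_size=128)
model = AutoModelForCausalLM.from_config(config)
qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=False)

compile_params = {"prefill_seq_len": 8, "ctx_len": 16, "prefill_only": True}
qeff_model.compile(**compile_params)  # prefill-only QPC; a decode graph would be a second compile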