Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 69 additions & 25 deletions QEfficient/utils/check_ccl_specializations.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def automatic_ccl_generation(
max_elements=constants.CCL_MAX_ELEMENTS_LISTS,
last_value=prefill_last,
)
# Set the last element in prefill_list to maximum possible input prompt to support all input lengths
prefill_list[-1] = mapped_cl

return prefill_list, decode_list, mapped_cl

Expand All @@ -126,36 +128,78 @@ def automatic_ccl_generation(
logger.warning("prefill_seq_len cannot be less than 1!")


def validate_ccl_lists(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
# Check CCL values are not negative and more than the CCL minimum context length = constants.CCL_MIN_CTX_LEN
if ccl_prefill:
ccl_prefill = [x if x >= constants.CCL_MIN_CTX_LEN else constants.CCL_MIN_CTX_LEN for x in ccl_prefill]
if ccl_decode:
ccl_decode = [x if x >= constants.CCL_MIN_CTX_LEN else constants.CCL_MIN_CTX_LEN for x in ccl_decode]

# Check the last element of ccl_prefill and ccl_decode to make sure it's not less than ctx_len
if ccl_prefill[-1] < ctx_len - 1:
ccl_prefill.append(ctx_len)
if ccl_decode[-1] < ctx_len:
ccl_decode.append(ctx_len)

if prefill_seq_len == 1:
# both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them.
ccl_union_all = sorted(set([min(x, ctx_len) for x in ccl_prefill + ccl_decode]))
ccl_prefill = ccl_union_all
ccl_decode = ccl_union_all
else:
# Sort ccl_prefill and ccl_decode lists and make sure they don't have repeated elements and also are less than ctx_len
if ccl_prefill:
ccl_prefill = sorted({min(x, ctx_len) for x in (ccl_prefill)})
if ccl_decode:
ccl_decode = sorted({min(x, ctx_len) for x in (ccl_decode)})

# Handling the common values between ccl_prefill and ccl_decode. The elements of these two lists should be unique (COMPILER)
tmp_prefill = ccl_prefill
ccl_prefill = []
for val in tmp_prefill:
while val in ccl_decode or val in ccl_prefill:
val -= 1
if val < 0:
break # Prevent negative values
if val >= 0:
ccl_prefill.append(val)
ccl_prefill.sort()

return ccl_prefill, ccl_decode


def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
"""
This function evaluates the values of CCL lists based on three inputs:
- ccl_prefill: optional [list]
- ccl_decode: optional [list]
- ccl_enabled: optional [bool]

Conditions to handle:
1) ccl_prefill AND ccl_decode AND ccl_enabled == True
2) ccl_prefill AND ccl_decode (ccl_enabled not provided)
3) ccl_prefill ONLY AND ccl_enabled == True and ccl_decode not provided
4) ccl_decode ONLY AND ccl_enabled == True and ccl_prefill not provided
5) ccl_prefill ONLY (ccl_enabled and ccl_decode are not provided)
6) ccl_decode ONLY (ccl_enabled and ccl_prefill are not provided)
7) ccl_enabled == True (no ccl_prefill, no ccl_decode) -> Automatic CCL lists generation
"""
# Automatic CCL generation: If both ccl_prefill and ccl_decode are None
if ccl_prefill is None and ccl_decode is None:
# Condition #7
if not ccl_prefill and not ccl_decode:
# Generate optimized context length lists for prefill and decode based on ctx_len
# Due to compiler limitations, ccl_prefill and ccl_decode must have distinct values
ccl_prefill, ccl_decode, ctx_len = automatic_ccl_generation(ctx_len, prefill_seq_len)
else:
if prefill_seq_len == 1:
if ccl_prefill is not None and ccl_decode is not None:
# both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them.
ccl_union_all = sorted(set([min(x, ctx_len) for x in ccl_prefill + ccl_decode]))
ccl_prefill = ccl_union_all
ccl_decode = ccl_union_all
else:
if ccl_prefill:
ccl_prefill = sorted({min(x, ctx_len) for x in (ccl_prefill)})
if ccl_decode:
ccl_decode = sorted({min(x, ctx_len) for x in (ccl_decode)})

if ccl_prefill is not None and ccl_decode is not None:
tmp_prefill = ccl_prefill
ccl_prefill = []
for val in tmp_prefill:
while val in ccl_decode or val in ccl_prefill:
val -= 1
if val < 0:
break # Prevent negative values
if val >= 0:
ccl_prefill.append(val)
ccl_prefill.sort()

# One of ccl lists is [] or None -> replace it with [ctx_len] -> CCL lists have to have a value when CCL is enabled
# Condition #3, #4, #5, and #6
elif not ccl_prefill or not ccl_decode:
# Initial setting and will be checked with edge cases later
ccl_prefill = ccl_prefill if ccl_prefill else [ctx_len]
ccl_decode = ccl_decode if ccl_decode else [ctx_len]

# Verifying ccl_prefill and ccl_decode values for all conditions
ccl_prefill, ccl_decode = validate_ccl_lists(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len)

logger.info("CCL Configuration:")
logger.info(f" - Prefill context lengths: {ccl_prefill}")
Expand Down
1 change: 1 addition & 0 deletions QEfficient/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def get_models_dir():
# Limitation in the maximum number of elements in comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists during automatic lists generation process.
CCL_MAX_ELEMENTS_LISTS = 5
CCL_START_CTX_LEN = 4096
CCL_MIN_CTX_LEN = 1024

# used for gpt-oss prefill-only model Q-blocking
GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256
Expand Down
Loading