Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def update_config_files(file_path: str, merge_name: str):

df_list.append(df)
else:
print(f"path {i+1}: {path} (not exist)")
logger.info(f"path {i+1}: {path} (not exist)")
merge_df = pd.concat(df_list, ignore_index=True) if df_list else pd.DataFrame()
## get keys from untuned file to drop_duplicates
untuned_name = (
Expand Down Expand Up @@ -135,7 +135,9 @@ def get_config_file(env_name, default_file, tuned_file_name):
else:
tuned_files = ":".join(str(p) for p in op_tuned_file_list)
tuned_files = default_file + ":" + tuned_files
print(f"merge tuned file under model_configs/ and configs/ ", tuned_files)
logger.info(
f"merge tuned file under model_configs/ and configs/ {tuned_files}"
)
config_file = update_config_files(tuned_files, tuned_file_name)
else:
config_file = update_config_files(config_env_file, tuned_file_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
if config is None:
config = _get_config(M, N, K)
config["BLOCK_SIZE_K"] = group_size
config["kpack"] = 1

grid = lambda META: ( # noqa: E731
B,
Expand Down
9 changes: 5 additions & 4 deletions aiter/ops/triton/pa_mqa_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
_deepgemm_fp8_paged_mqa_logits,
_deepgemm_fp8_paged_mqa_logits_ragged_k,
)
from aiter import dtypes


def deepgemm_fp8_paged_mqa_logits_ragged_k(
Expand All @@ -29,7 +30,7 @@ def deepgemm_fp8_paged_mqa_logits_ragged_k(
)
# Since triton doesn't have the reinterpret_cast, we slice the scale out and view it as float
kv_cache_scale = kv_cache_scale.view(torch.float32)
kv_cache_fp8 = kv_cache_fp8.view(torch.float8_e4m3fnuz)
kv_cache_fp8 = kv_cache_fp8.view(dtypes.fp8)

config = {
"ChunkQ": heads,
Expand Down Expand Up @@ -78,7 +79,7 @@ def deepgemm_fp8_paged_mqa_logits_stage1_ragged_k(
)
# Since triton doesn't have the reinterpret_cast, we slice the scale out and view it as float
kv_cache_scale = kv_cache_scale.view(torch.float32)
kv_cache_fp8 = kv_cache_fp8.view(torch.float8_e4m3fnuz)
kv_cache_fp8 = kv_cache_fp8.view(dtypes.fp8)

config = {
"ChunkQ": 32,
Expand Down Expand Up @@ -130,7 +131,7 @@ def deepgemm_fp8_paged_mqa_logits_stage1(
)
# Since triton doesn't have the reinterpret_cast, we slice the scale out and view it as float
kv_cache_scale = kv_cache_scale.view(torch.float32)
kv_cache_fp8 = kv_cache_fp8.view(torch.float8_e4m3fnuz)
kv_cache_fp8 = kv_cache_fp8.view(dtypes.fp8)

config = {
"ChunkQ": 32,
Expand Down Expand Up @@ -185,7 +186,7 @@ def deepgemm_fp8_paged_mqa_logits(
)
# Since triton doesn't have the reinterpret_cast, we slice the scale out and view it as float
kv_cache_scale = kv_cache_scale.view(torch.float32)
kv_cache_fp8 = kv_cache_fp8.view(torch.float8_e4m3fnuz)
kv_cache_fp8 = kv_cache_fp8.view(dtypes.fp8)

config = {
"ChunkQ": heads,
Expand Down
Loading