diff --git a/aiter/jit/core.py b/aiter/jit/core.py
index 3ceb4b823a..560727795a 100644
--- a/aiter/jit/core.py
+++ b/aiter/jit/core.py
@@ -90,7 +90,7 @@ def update_config_files(file_path: str, merge_name: str):
             df_list.append(df)
         else:
-            print(f"path {i+1}: {path} (not exist)")
+            logger.info(f"path {i+1}: {path} (not exist)")
     merge_df = pd.concat(df_list, ignore_index=True) if df_list else pd.DataFrame()
     ## get keys from untuned file to drop_duplicates
     untuned_name = (
@@ -135,7 +135,9 @@ def get_config_file(env_name, default_file, tuned_file_name):
         else:
             tuned_files = ":".join(str(p) for p in op_tuned_file_list)
             tuned_files = default_file + ":" + tuned_files
-            print(f"merge tuned file under model_configs/ and configs/ ", tuned_files)
+            logger.info(
+                f"merge tuned file under model_configs/ and configs/ {tuned_files}"
+            )
             config_file = update_config_files(tuned_files, tuned_file_name)
     else:
         config_file = update_config_files(config_env_file, tuned_file_name)
diff --git a/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py b/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py
index 9fd8e505b8..7701fb249e 100644
--- a/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py
+++ b/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py
@@ -88,6 +88,7 @@ def batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
     if config is None:
         config = _get_config(M, N, K)
         config["BLOCK_SIZE_K"] = group_size
+        config["kpack"] = 1
 
     grid = lambda META: (  # noqa: E731
         B,
diff --git a/aiter/ops/triton/pa_mqa_logits.py b/aiter/ops/triton/pa_mqa_logits.py
index 3bcda1b8f6..d343eeed6e 100644
--- a/aiter/ops/triton/pa_mqa_logits.py
+++ b/aiter/ops/triton/pa_mqa_logits.py
@@ -9,6 +9,7 @@
     _deepgemm_fp8_paged_mqa_logits,
     _deepgemm_fp8_paged_mqa_logits_ragged_k,
 )
+from aiter import dtypes
 
 
 def deepgemm_fp8_paged_mqa_logits_ragged_k(
@@ -29,7 +30,7 @@ def deepgemm_fp8_paged_mqa_logits_ragged_k(
     )
     # Since triton doesn't have have the reinterpret_cast, we slice the scale out and view it as float
     kv_cache_scale = kv_cache_scale.view(torch.float32)
-    kv_cache_fp8 = kv_cache_fp8.view(torch.float8_e4m3fnuz)
+    kv_cache_fp8 = kv_cache_fp8.view(dtypes.fp8)
 
     config = {
         "ChunkQ": heads,
@@ -78,7 +79,7 @@ def deepgemm_fp8_paged_mqa_logits_stage1_ragged_k(
     )
     # Since triton doesn't have the reinterpret_cast, we slice the scale out and view it as float
     kv_cache_scale = kv_cache_scale.view(torch.float32)
-    kv_cache_fp8 = kv_cache_fp8.view(torch.float8_e4m3fnuz)
+    kv_cache_fp8 = kv_cache_fp8.view(dtypes.fp8)
 
     config = {
         "ChunkQ": 32,
@@ -130,7 +131,7 @@ def deepgemm_fp8_paged_mqa_logits_stage1(
     )
     # Since triton doesn't have the reinterpret_cast, we slice the scale out and view it as float
     kv_cache_scale = kv_cache_scale.view(torch.float32)
-    kv_cache_fp8 = kv_cache_fp8.view(torch.float8_e4m3fnuz)
+    kv_cache_fp8 = kv_cache_fp8.view(dtypes.fp8)
 
     config = {
         "ChunkQ": 32,
@@ -185,7 +186,7 @@ def deepgemm_fp8_paged_mqa_logits(
     )
     # Since triton doesn't have the reinterpret_cast, we slice the scale out and view it as float
     kv_cache_scale = kv_cache_scale.view(torch.float32)
-    kv_cache_fp8 = kv_cache_fp8.view(torch.float8_e4m3fnuz)
+    kv_cache_fp8 = kv_cache_fp8.view(dtypes.fp8)
 
     config = {
         "ChunkQ": heads,
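Note (reviewer sketch, not part of the patch): switching the view target from a hard-coded torch.float8_e4m3fnuz to dtypes.fp8 lets the same code pick whichever FP8 format the running GPU supports. The snippet below is a minimal, self-contained illustration of that idea; the selector name and the gfx94x architecture check are assumptions for illustration only, not aiter's actual dtypes implementation.

import torch

def select_fp8() -> torch.dtype:
    # Hypothetical selector: assume MI300-class (gfx94x) ROCm devices use the
    # "fnuz" FP8 encoding, while other backends use the OCP e4m3fn encoding.
    if torch.version.hip is not None:
        arch = torch.cuda.get_device_properties(0).gcnArchName
        if "gfx94" in arch:
            return torch.float8_e4m3fnuz
    return torch.float8_e4m3fn

fp8 = select_fp8()

# Usage mirroring the patch: reinterpret the raw KV-cache bytes as FP8, e.g.
#   kv_cache_fp8 = kv_cache_fp8.view(fp8)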