From e91febe0c021a3450e697c6dffcd413ad09ca5c6 Mon Sep 17 00:00:00 2001 From: Shiyan Deng Date: Mon, 4 Aug 2025 23:48:24 -0700 Subject: [PATCH 1/2] fix arg parser --- csrc/cpp_itfs/pa/pa_v1.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/cpp_itfs/pa/pa_v1.py b/csrc/cpp_itfs/pa/pa_v1.py index 80d0cba47b..e0faec6b98 100644 --- a/csrc/cpp_itfs/pa/pa_v1.py +++ b/csrc/cpp_itfs/pa/pa_v1.py @@ -247,7 +247,8 @@ def paged_attention_v1( parser.add_argument("--fp8_kv_dtype", type=str, required=True) parser.add_argument("--out_dtype", type=str, required=True) parser.add_argument("--block_size", type=int, required=True) - parser.add_argument("--alibi_enabled", type=str, required=True) + parser.add_argument("--alibi_enabled", type=bool, required=True) + parser.add_argument("--logits_soft_cap_enabled", type=bool, required=True) parser.add_argument("--mtp", type=int, default=1) parser.add_argument("--folder", type=str, default=None) args = parser.parse_args() From c12943db995d415a38792f6f768875ac9c1ecf53 Mon Sep 17 00:00:00 2001 From: Shiyan Deng Date: Tue, 5 Aug 2025 14:41:08 -0700 Subject: [PATCH 2/2] use str_to_bool --- csrc/cpp_itfs/pa/pa_v1.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/cpp_itfs/pa/pa_v1.py b/csrc/cpp_itfs/pa/pa_v1.py index e0faec6b98..746c5cc9ea 100644 --- a/csrc/cpp_itfs/pa/pa_v1.py +++ b/csrc/cpp_itfs/pa/pa_v1.py @@ -1,5 +1,5 @@ from jinja2 import Template -from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR +from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool import ctypes import math @@ -247,8 +247,8 @@ def paged_attention_v1( parser.add_argument("--fp8_kv_dtype", type=str, required=True) parser.add_argument("--out_dtype", type=str, required=True) parser.add_argument("--block_size", type=int, required=True) - parser.add_argument("--alibi_enabled", type=bool, required=True) - parser.add_argument("--logits_soft_cap_enabled", type=bool, required=True) + parser.add_argument("--alibi_enabled", type=str_to_bool, required=True) + parser.add_argument("--logits_soft_cap_enabled", type=str_to_bool, required=True) parser.add_argument("--mtp", type=int, default=1) parser.add_argument("--folder", type=str, default=None) args = parser.parse_args()