diff --git a/aiter/ops/gemm_op_a4w4.py b/aiter/ops/gemm_op_a4w4.py
index cedde7bae0..1967fc0a30 100644
--- a/aiter/ops/gemm_op_a4w4.py
+++ b/aiter/ops/gemm_op_a4w4.py
@@ -69,7 +69,7 @@ def gemm_a4w4_asm(
     A_scale: Tensor,  # A_scale:[M, K/32] e8m0 paded
     B_scale: Tensor,  # B_scale:[N, K/32] e8m0 paded
     out: Tensor,  # Out:[M, N] bf16
-    bias: Tensor,  # bias:[1, N] f32
+    bias: Optional[Tensor] = None,  # bias:[1, N] f32
     alpha: Optional[float] = 1.0,
     beta: Optional[float] = 0.0,
     bpreshuffle: Optional[bool] = True,
diff --git a/csrc/include/asm_gemm_a4w4.h b/csrc/include/asm_gemm_a4w4.h
index 5b63af5541..61b370b33a 100644
--- a/csrc/include/asm_gemm_a4w4.h
+++ b/csrc/include/asm_gemm_a4w4.h
@@ -8,7 +8,7 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A,       // A:[M, K/2] f4x2
                             torch::Tensor& A_scale, // A_scale:[M, K/32] e8m0 paded
                             torch::Tensor& B_scale, // B_scale:[N, K/32] e8m0 paded
                             torch::Tensor& out,     // Out:[M, N] bf16
-                            torch::Tensor& bias,    // bias:[M, N] f32
-                            std::optional<float> alpha      = 1.0,
-                            std::optional<float> beta       = 0.0,
-                            std::optional<bool> bpreshuffle = true);
+                            std::optional<torch::Tensor> bias = std::nullopt, // bias:[M, N] f32
+                            std::optional<float> alpha        = 1.0,
+                            std::optional<float> beta         = 0.0,
+                            std::optional<bool> bpreshuffle   = true);
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index 2201c32e6f..edca6449ec 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -302,18 +302,18 @@
           py::arg("pad_c") = 0,                        \
           py::arg("splitK") = 0);
 
-#define GEMM_A4W4_ASM_PYBIND                           \
-    m.def("gemm_a4w4_asm",                             \
-          &gemm_a4w4_asm,                              \
-          "Asm gemm a4w4",                             \
-          py::arg("A"),                                \
-          py::arg("B"),                                \
-          py::arg("A_scale"),                          \
-          py::arg("B_scale"),                          \
-          py::arg("out"),                              \
-          py::arg("bias"),                             \
-          py::arg("alpha") = 1.0,                      \
-          py::arg("beta") = 0.0,                       \
+#define GEMM_A4W4_ASM_PYBIND                           \
+    m.def("gemm_a4w4_asm",                             \
+          &gemm_a4w4_asm,                              \
+          "Asm gemm a4w4",                             \
+          py::arg("A"),                                \
+          py::arg("B"),                                \
+          py::arg("A_scale"),                          \
+          py::arg("B_scale"),                          \
+          py::arg("out"),                              \
+          py::arg("bias") = std::nullopt,              \
+          py::arg("alpha") = 1.0,                      \
+          py::arg("beta") = 0.0,                       \
           py::arg("bpreshuffle") = true);
 
 #define GEMM_A4W4_BLOCKSCALE_PYBIND                    \
@@ -760,7 +760,7 @@
           py::arg("unit_size"),                        \
           py::arg("local_expert_mask") = std::nullopt, \
           py::arg("num_local_tokens") = std::nullopt,  \
-          py::arg("dispatch_policy") = 0);
+          py::arg("dispatch_policy") = 0);
 
 #define NORM_PYBIND                                    \
     m.def("layernorm2d_fwd",                           \
diff --git a/csrc/py_itfs_cu/asm_gemm_a4w4.cu b/csrc/py_itfs_cu/asm_gemm_a4w4.cu
index 6f6b793ab3..6419c5bea5 100644
--- a/csrc/py_itfs_cu/asm_gemm_a4w4.cu
+++ b/csrc/py_itfs_cu/asm_gemm_a4w4.cu
@@ -64,10 +64,10 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A,       // A:[M, K/2] f4x2
                             torch::Tensor& A_scale, // A_scale:[M, K/32] e8m0 paded
                             torch::Tensor& B_scale, // B_scale:[N, K/32] e8m0 paded
                             torch::Tensor& out,     // Out:[M, N] bf16
-                            torch::Tensor& bias,    // bias:[M, N] f32
-                            std::optional<float> alpha      = 1.0,
-                            std::optional<float> beta       = 0.0,
-                            std::optional<bool> bpreshuffle = true)
+                            std::optional<torch::Tensor> bias = std::nullopt, // bias:[M, N] f32
+                            std::optional<float> alpha        = 1.0,
+                            std::optional<float> beta         = 0.0,
+                            std::optional<bool> bpreshuffle   = true)
 {
     TORCH_CHECK(
         out.dtype() == torch::ScalarType::BFloat16, __func__, " only support BFloat16 output now!");
@@ -77,7 +77,7 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2
     KernelArgs args;
     size_t arg_size = sizeof(args);
     args.ptr_D = (void*)out.data_ptr();
-    args.ptr_C = (void*)bias.data_ptr();
+    args.ptr_C = bias.has_value() ? (void*)bias.value().data_ptr() : nullptr;
     args.ptr_A = (void*)A.data_ptr();
     args.ptr_B = (void*)B.data_ptr();
diff --git a/op_tests/test_gemm_a4w4.py b/op_tests/test_gemm_a4w4.py
index 3c0294eef5..670727f9ff 100644
--- a/op_tests/test_gemm_a4w4.py
+++ b/op_tests/test_gemm_a4w4.py
@@ -81,7 +81,7 @@ def test_gemm(dtype, M, N, K):
     out1 = torch.empty(M, N, dtype=dtype)
     out2 = torch.empty((M + 255) // 256 * 256, N, dtype=dtype)
     out3 = torch.empty((M + 255) // 256 * 256, N, dtype=dtype)
-    bias_f32 = torch.zeros(M, N, dtype=dtype)
+    bias_f32 = None
 
     x_scales = x_scales.view(torch.uint8)
    w_scales = w_scales.view(torch.uint8)
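
Usage note (reviewer aid, not part of the patch): a minimal call-site sketch of the new optional-bias path, following the Python signature changed above. It assumes `aiter` re-exports `gemm_a4w4_asm` at package top level and that a ROCm device is present; the inputs are dummy packed-f4 bytes and unit e8m0 scales chosen only to satisfy the shape comments (M is a multiple of 256, so the padded output shapes used in the test collapse to M), so the numeric result is meaningless.

    import torch
    import aiter

    M, N, K = 256, 4096, 4096
    A = torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device="cuda")  # f4x2 packed
    B = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device="cuda")
    A_scale = torch.ones((M, K // 32), dtype=torch.uint8, device="cuda")      # e8m0
    B_scale = torch.ones((N, K // 32), dtype=torch.uint8, device="cuda")
    out = torch.empty((M, N), dtype=torch.bfloat16, device="cuda")

    # bias may now be omitted entirely; the kernel then receives nullptr in ptr_C.
    aiter.gemm_a4w4_asm(A, B, A_scale, B_scale, out)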