-
Notifications
You must be signed in to change notification settings - Fork 169
Open
Description
Problem Description
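The skinny GEMM path in `tuned_gemm` should also honor the `otype` argument. It should return its output in the requested dtype `otype`, as the other GEMM paths do, instead of the input dtype.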
Operating System
Ubuntu 22.04.4 LTS (Jammy Jellyfish)
CPU
AMD EPYC 9654 96-Core Processor
GPU
AMD Instinct MI300X
ROCm Version
ROCm 6.4.1
ROCm Component
No response
Steps to Reproduce
The following patch extends the unit tests to also check that the output dtype is correct for every requested `otype`.
diff --git a/op_tests/test_gemm.py b/op_tests/test_gemm.py
index 8dd6b20..e6d4f16 100644
--- a/op_tests/test_gemm.py
+++ b/op_tests/test_gemm.py
@@ -65,6 +65,11 @@ def test_gemm(dtype, m, n, k, bias=False, otype=None, scaleA=None, scaleB=None):
scaleB = torch.tensor(scaleB, dtype=dtypes.fp32, device="cuda")
(a, *_), avg_a = run_torch(x, weight, bias, otype, scaleA, scaleB)
(b, *_), avg_b = run_gemm_b(x, weight, bias, otype, scaleA, scaleB)
+
+ assert a.dtype == b.dtype, f"Expected a.dtype == b.dtype, but a={a.dtype}, b={b.dtype}, input dtype={dtype}"
+ if otype is not None:
+ assert a.dtype == otype, f"a={a.dtype}, expected output dtype={otype}, input dtype={dtype}"
+ assert b.dtype == otype, f"b={b.dtype}, expected output dtype={otype}, input dtype={dtype}"
msg = f"[perf] dim: {str(dim):<20} dtype: {dtype}, torch avg: {avg_a:<8.2f} us, B avg: {avg_b:<8.2f} us, uplift: {avg_a/avg_b-1:<5.1%}"
checkAllclose(a, b, msg=msg)
@@ -274,7 +279,8 @@ def test_normal_gemm():
)
test_gemm(dtypes.bf16, 128, 32, 8192)
for dtype in [dtypes.fp16, dtypes.bf16]:
- test_gemm(dtype, 128, 32, 8192)
+ for otype in [None, dtypes.fp16, dtypes.bf16, dtypes.fp32]:
+ test_gemm(dtype, 128, 32, 8192, otype=otype)
# # qkv_proj
# for (m, n, k) in [(4096, 1280, 8192),
# (128, 1280, 8192),
@@ -337,7 +343,8 @@ def test_skinny_gemm():
for mnk in test_mnk_list:
m, n, k = mnk
for dtype in [dtypes.fp16, dtypes.bf16]:
- test_gemm(dtype, m, n, k)
+ for otype in [None, dtypes.fp16, dtypes.bf16, dtypes.fp32]:
+ test_gemm(dtype, m, n, k, otype=otype)
# test_normal_gemm()
Run the test with `python3 op_tests/test_gemm.py`.
Proposed fix:
diff --git a/aiter/tuned_gemm.py b/aiter/tuned_gemm.py
index a06dea4..fd8656d 100644
--- a/aiter/tuned_gemm.py
+++ b/aiter/tuned_gemm.py
@@ -300,6 +300,8 @@ class TunedGemm:
)
if batched:
out = out.view(*inp.shape[:-1], weights.shape[0])
+ if otype is not None:
+ out = out.to(otype)
return out
(Optional for Linux users) Output of `/opt/rocm/bin/rocminfo --support`
No response
Additional Information
No response
Metadata
Metadata
Assignees
Labels
No labels