
[Issue]: The skinny GEMM in tuned_gemm does not respect the specified output dtype (otype) #663

@vllmellm

Description


Problem Description

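When otype is passed, the skinny GEMM path in TunedGemm (aiter/tuned_gemm.py) ignores it and returns its result in whatever dtype the kernel produced. Like the other GEMM paths, it should also override the output dtype to otype.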
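To make the expected contract concrete, here is a minimal, stdlib-only model of a GEMM dispatcher (no torch, no GPU). Everything in it is hypothetical: `Tensor`, `gemm`, `_skinny_kernel`, `_general_kernel` and `SKINNY_MAX_M` are illustrative stand-ins, not aiter's API, and it assumes the output defaults to the input dtype when otype is None. The skinny branch ignores otype, mirroring the reported bug; routing every branch through one exit that applies the otype cast keeps it from leaking the input dtype.

```python
import struct
from dataclasses import dataclass
from typing import List, Optional

SKINNY_MAX_M = 8  # hypothetical cut-over: small-M problems take the skinny path


def round_to(x: float, dtype: str) -> float:
    """Round a Python float to the nearest fp16 / bf16 / fp32 value."""
    if dtype == "fp16":
        return struct.unpack("<e", struct.pack("<e", x))[0]
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if dtype == "bf16":
        # Round-to-nearest-even onto the upper 16 bits of the fp32 bit pattern.
        bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


@dataclass
class Tensor:
    """Toy stand-in for a 2-D torch.Tensor: nested lists plus a dtype tag."""
    data: List[List[float]]
    dtype: str

    @classmethod
    def of(cls, data: List[List[float]], dtype: str) -> "Tensor":
        # Build a tensor whose values are rounded to the given dtype.
        return cls([[round_to(v, dtype) for v in row] for row in data], dtype)

    def to(self, dtype: str) -> "Tensor":
        return Tensor.of(self.data, dtype)


def _matmul(x: Tensor, w: Tensor) -> List[List[float]]:
    # y = x @ w.T with weights stored as [n, k]; accumulated in Python floats.
    return [[sum(a * b for a, b in zip(row, wrow)) for wrow in w.data] for row in x.data]


def _skinny_kernel(x: Tensor, w: Tensor) -> Tensor:
    # Hypothetical skinny kernel: always writes its result in the input dtype.
    return Tensor.of(_matmul(x, w), x.dtype)


def _general_kernel(x: Tensor, w: Tensor, otype: Optional[str]) -> Tensor:
    # Hypothetical general kernel: already honors otype itself.
    return Tensor.of(_matmul(x, w), otype or x.dtype)


def gemm(x: Tensor, w: Tensor, otype: Optional[str] = None) -> Tensor:
    if len(x.data) <= SKINNY_MAX_M:
        out = _skinny_kernel(x, w)
    else:
        out = _general_kernel(x, w, otype)
    # Single exit point: the otype override applies no matter which branch ran.
    if otype is not None and out.dtype != otype:
        out = out.to(otype)
    return out


x_small = Tensor.of([[1.0, 2.0, 3.0]], "bf16")                # m = 1  -> skinny path
x_large = Tensor.of([[1.0, 2.0, 3.0]] * 16, "bf16")           # m = 16 -> general path
w = Tensor.of([[0.5, 0.25, 0.125], [1.0, 1.0, 1.0]], "bf16")  # n = 2, k = 3
for x in (x_small, x_large):
    for otype in (None, "fp16", "bf16", "fp32"):
        assert gemm(x, w, otype).dtype == (otype or x.dtype)
```

Centralizing the cast at the return means a future dispatch branch cannot forget otype, and the `out.dtype != otype` guard skips the extra copy when a kernel already honored it.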

Operating System

Ubuntu 22.04.4 LTS (Jammy Jellyfish)

CPU

AMD EPYC 9654 96-Core Processor

GPU

AMD Instinct MI300X

ROCm Version

ROCm 6.4.1

ROCm Component

No response

Steps to Reproduce

The following patch extends the unit tests to also check that the output dtype is correct:

diff --git a/op_tests/test_gemm.py b/op_tests/test_gemm.py
index 8dd6b20..e6d4f16 100644
--- a/op_tests/test_gemm.py
+++ b/op_tests/test_gemm.py
@@ -65,6 +65,11 @@ def test_gemm(dtype, m, n, k, bias=False, otype=None, scaleA=None, scaleB=None):
         scaleB = torch.tensor(scaleB, dtype=dtypes.fp32, device="cuda")
     (a, *_), avg_a = run_torch(x, weight, bias, otype, scaleA, scaleB)
     (b, *_), avg_b = run_gemm_b(x, weight, bias, otype, scaleA, scaleB)
+    
+    assert a.dtype == b.dtype, f"Expected a.dtype == b.dtype, but a={a.dtype}, b={b.dtype}, input dtype={dtype}"
+    if otype is not None:
+        assert a.dtype == otype, f"a={a.dtype}, expected output dtype={otype}, input dtype={dtype}"
+        assert b.dtype == otype, f"b={b.dtype}, expected output dtype={otype}, input dtype={dtype}"
 
     msg = f"[perf] dim: {str(dim):<20} dtype: {dtype}, torch avg: {avg_a:<8.2f} us, B avg: {avg_b:<8.2f} us, uplift: {avg_a/avg_b-1:<5.1%}"
     checkAllclose(a, b, msg=msg)
@@ -274,7 +279,8 @@ def test_normal_gemm():
     )
     test_gemm(dtypes.bf16, 128, 32, 8192)
     for dtype in [dtypes.fp16, dtypes.bf16]:
-        test_gemm(dtype, 128, 32, 8192)
+        for otype in [None, dtypes.fp16, dtypes.bf16, dtypes.fp32]:
+            test_gemm(dtype, 128, 32, 8192, otype=otype)
         # # qkv_proj
         # for (m, n, k) in [(4096, 1280, 8192),
         #                   (128, 1280, 8192),
@@ -337,7 +343,8 @@ def test_skinny_gemm():
         for mnk in test_mnk_list:
             m, n, k = mnk
             for dtype in [dtypes.fp16, dtypes.bf16]:
-                test_gemm(dtype, m, n, k)
+                for otype in [None, dtypes.fp16, dtypes.bf16, dtypes.fp32]:
+                    test_gemm(dtype, m, n, k, otype=otype)
 
 
 # test_normal_gemm()

Then run the tests with python3 op_tests/test_gemm.py.
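As a possible variation on the test patch above, the same three dtype checks can be collected into a list instead of raised, so a single run over the dtype × otype grid reports every failing combination rather than stopping at the first one. This is a stdlib-only sketch: dtypes are plain strings, and `dtype_failures`, `run_grid`, `buggy` and `fixed` are hypothetical names. `buggy` only simulates the reported behavior (implementation output left in the input dtype) and assumes the reference output is otype, or the input dtype when otype is None; it is not aiter code.

```python
from typing import Callable, List, Optional, Tuple

DTYPES = ["fp16", "bf16"]
OTYPES = [None, "fp16", "bf16", "fp32"]


def dtype_failures(a_dtype: str, b_dtype: str, otype: Optional[str], dtype: str) -> List[str]:
    """The three dtype checks from the patched test_gemm, returned instead of asserted."""
    failures = []
    if a_dtype != b_dtype:
        failures.append(f"a={a_dtype} != b={b_dtype}, input dtype={dtype}")
    if otype is not None:
        if a_dtype != otype:
            failures.append(f"a={a_dtype}, expected output dtype={otype}, input dtype={dtype}")
        if b_dtype != otype:
            failures.append(f"b={b_dtype}, expected output dtype={otype}, input dtype={dtype}")
    return failures


def run_grid(out_dtypes: Callable[[str, Optional[str]], Tuple[str, str]]) -> List[str]:
    """out_dtypes(dtype, otype) -> (reference dtype, implementation dtype)."""
    failures = []
    for dtype in DTYPES:
        for otype in OTYPES:
            a_dtype, b_dtype = out_dtypes(dtype, otype)
            failures += dtype_failures(a_dtype, b_dtype, otype, dtype)
    return failures


def buggy(dtype: str, otype: Optional[str]) -> Tuple[str, str]:
    # Simulated pre-fix skinny path: reference honors otype, implementation ignores it.
    return (otype or dtype, dtype)


def fixed(dtype: str, otype: Optional[str]) -> Tuple[str, str]:
    # Simulated post-fix behavior: both sides end up in otype when it is given.
    return (otype or dtype, otype or dtype)


for msg in run_grid(buggy):
    print(msg)
print(len(run_grid(buggy)), len(run_grid(fixed)))  # prints "8 0"
```

With the simulated bug, each of the four combinations where otype is set and differs from the input dtype trips two checks (a != b and b != otype), giving eight messages in one pass.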

Proposed fix:

diff --git a/aiter/tuned_gemm.py b/aiter/tuned_gemm.py
index a06dea4..fd8656d 100644
--- a/aiter/tuned_gemm.py
+++ b/aiter/tuned_gemm.py
@@ -300,6 +300,8 @@ class TunedGemm:
         )
         if batched:
             out = out.view(*inp.shape[:-1], weights.shape[0])
+        if otype is not None:
+            out = out.to(otype)
         return out
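One caveat worth checking with the proposed fix: out.to(otype) runs after the kernel has already written out. Assuming the skinny kernel stores its result in the input dtype (fp16/bf16), requesting otype=fp32 yields an fp32 tensor whose values still carry the fp16/bf16 rounding. The stdlib-only sketch below (no torch; `to_fp16` and `to_fp32` are hypothetical helpers built on the `struct` module's half- and single-precision formats) shows the effect on one dot product.

```python
import struct


def to_fp16(x: float) -> float:
    # Round to the nearest IEEE 754 half-precision value (struct format "e").
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    # Round to the nearest IEEE 754 single-precision value (struct format "f").
    return struct.unpack("<f", struct.pack("<f", x))[0]


k = 3000
a = [to_fp16(1.0 + 2.0 ** -10)] * k  # fp16 activations (exactly representable)
b = [1.0] * k                         # fp16 weights

exact = sum(x * y for x, y in zip(a, b))  # 3002.9296875, exact in double precision

# Assumed kernel stores fp16, then out.to(fp32): the rounding has already happened.
post_cast = to_fp32(to_fp16(exact))
# A kernel that stored fp32 directly would skip the intermediate fp16 rounding.
direct = to_fp32(exact)

print(post_cast, direct)  # prints "3002.0 3002.9296875"
```

fp16 values between 2048 and 4096 are spaced 2 apart, so the stored result lands on 3002.0, and the later cast to fp32 cannot recover the lost 0.93. Under that assumption, the post-hoc cast fixes the dtype but not the precision; getting true fp32 output would require the kernel itself to emit otype.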

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
