From 172d7c1d87110c186106d2916c093d88ba0854e4 Mon Sep 17 00:00:00 2001
From: Annanya
Date: Tue, 1 Apr 2025 05:07:34 -0400
Subject: [PATCH 1/3] [Cublas] Added support for bfloat16 while dispatching to
 cublas kernels

---
 python/tvm/relax/backend/cuda/cublas.py   | 1 +
 src/runtime/contrib/cublas/cublas.cc      | 4 ++++
 src/runtime/contrib/cublas/cublas_utils.h | 5 +++++
 3 files changed, 10 insertions(+)

diff --git a/python/tvm/relax/backend/cuda/cublas.py b/python/tvm/relax/backend/cuda/cublas.py
index 6828381e68e1..f8621d9b5621 100644
--- a/python/tvm/relax/backend/cuda/cublas.py
+++ b/python/tvm/relax/backend/cuda/cublas.py
@@ -43,6 +43,7 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
         (lhs_dtype == "float16" and rhs_dtype == "float16")
         or (lhs_dtype == "float32" and rhs_dtype == "float32")
         or (lhs_dtype == "int8" and rhs_dtype == "int8")
+        or (lhs_dtype == "bfloat16" and rhs_dtype == "bfloat16")
     )


diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index ba01f791d98a..79d94a818564 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -162,6 +162,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,

   if (TypeMatch(A->dtype, kDLFloat, 16)) {
     ab_type = CUDA_R_16F;
+  } else if(TypeMatch(A->dtype, kDLBfloat, 16)){
+    ab_type = CUDA_R_16BF;
   } else if (TypeMatch(A->dtype, kDLInt, 8)) {
     ab_type = CUDA_R_8I;
   } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
@@ -171,6 +173,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,

   if (TypeMatch(C->dtype, kDLFloat, 16)) {
     c_type = CUDA_R_16F;
+  } else if(TypeMatch(C->dtype, kDLBfloat, 16)){
+    c_type = CUDA_R_16BF;
   } else if (TypeMatch(C->dtype, kDLInt, 32)) {
     c_type = CUDA_R_32I;
     compute_type = CUBLAS_COMPUTE_32I;
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index 387065093eaa..316241915557 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -116,6 +116,11 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
       case 64:
         return CUDA_R_64F;
     }
+  } else if (type.code == kDLBfloat){
+    switch (type.bits) {
+      case 16:
+        return CUDA_R_16BF;
+    }
   }
   LOG(FATAL) << "Unsupported cuda type";
 }

From c24854b87d781dbde80fe030c08fab1b67d6a078 Mon Sep 17 00:00:00 2001
From: Annanya
Date: Mon, 7 Apr 2025 03:25:15 -0400
Subject: [PATCH 2/3] Added tests and fixed the code

---
 src/runtime/contrib/cublas/cublas.cc      |  2 +-
 tests/python/relax/test_codegen_cublas.py | 41 +++++++++++++++++++++++
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index 79d94a818564..7ed190c3d1ea 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -125,7 +125,7 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
   if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
     return TypeMatch(in_dtype, kDLInt, 8);
   } else if (TypeMatch(out_dtype, kDLFloat, 32)) {
-    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16);
+    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16) || TypeMatch(in_dtype, kDLBfloat, 16);
   } else {
     return false;
   }
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index dbcb25b69d52..152f04fc3ce7 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -393,6 +393,47 @@ def test_matmul_fp8_multiply_offload():
     tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)


+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+@pytest.mark.parametrize(
+    "x_shape, y_shape, transpose_y, out_dtype",
+    [
+        ((10, 32), (64, 32), True, "float32"),
+        ((32, 16), (32, 16), True, "float32"),
+        ((2, 10, 32), (2, 64, 32), True, "float32"),
+    ],
+)
+def test_matmul_bfloat16_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    out_dtype,
+):
+    in_dtype = "bfloat16"
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        out_dtype,
+        bias_shape=None,
+        transposed_y=transpose_y,
+        activation=None,
+    )
+    # Generate input data in float32 and then convert to bfloat16 using ml_dtypes.
+    x_float32 = np.random.uniform(low=0, high=5, size=x_shape).astype("float32")
+    y_float32 = np.random.uniform(low=0, high=5, size=y_shape).astype("float32")
+    x_bf16 = ml_dtypes.bfloat16(x_float32)
+    y_bf16 = ml_dtypes.bfloat16(y_float32)
+
+    # For the reference result, adjust y (if needed) in float32.
+    z = np.swapaxes(y_float32, -2, -1) if transpose_y else y_float32
+    args = (x_bf16, y_bf16)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref_out = np.matmul(x_float32, z).astype(out_dtype)
+
+    tvm.testing.assert_allclose(out, ref_out, rtol=1e-2, atol=1e-2)
+
+
 @pytest.mark.parametrize(
     "M, N, K, out_dtype, transposed_y, partition_done",
     [

From 78b863a992360bbb939e7b62c4822b7b8b9b7443 Mon Sep 17 00:00:00 2001
From: Annanya
Date: Mon, 7 Apr 2025 03:30:25 -0400
Subject: [PATCH 3/3] Some formatting changes

---
 src/runtime/contrib/cublas/cublas.cc      | 7 ++++---
 src/runtime/contrib/cublas/cublas_utils.h | 2 +-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index 7ed190c3d1ea..3fbda3ac945d 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -125,7 +125,8 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
   if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
     return TypeMatch(in_dtype, kDLInt, 8);
   } else if (TypeMatch(out_dtype, kDLFloat, 32)) {
-    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16) || TypeMatch(in_dtype, kDLBfloat, 16);
+    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16) ||
+           TypeMatch(in_dtype, kDLBfloat, 16);
   } else {
     return false;
   }
@@ -162,7 +163,7 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,

   if (TypeMatch(A->dtype, kDLFloat, 16)) {
     ab_type = CUDA_R_16F;
-  } else if(TypeMatch(A->dtype, kDLBfloat, 16)){
+  } else if (TypeMatch(A->dtype, kDLBfloat, 16)) {
     ab_type = CUDA_R_16BF;
   } else if (TypeMatch(A->dtype, kDLInt, 8)) {
     ab_type = CUDA_R_8I;
@@ -173,7 +174,7 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
   if (TypeMatch(C->dtype, kDLFloat, 16)) {
     c_type = CUDA_R_16F;
-  } else if(TypeMatch(C->dtype, kDLBfloat, 16)){
+  } else if (TypeMatch(C->dtype, kDLBfloat, 16)) {
     c_type = CUDA_R_16BF;
   } else if (TypeMatch(C->dtype, kDLInt, 32)) {
     c_type = CUDA_R_32I;
     compute_type = CUBLAS_COMPUTE_32I;
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index 316241915557..3e9ded08deb1 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -116,7 +116,7 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
       case 64:
         return CUDA_R_64F;
     }
-  } else if (type.code == kDLBfloat){
+  } else if (type.code == kDLBfloat) {
     switch (type.bits) {
       case 16:
         return CUDA_R_16BF;
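
A minimal end-to-end sketch of driving the new bfloat16 path, mirroring the flow of the
test added in PATCH 2/3. It assumes `partition_for_cublas` is exported from the patched
python/tvm/relax/backend/cuda/cublas.py, that the `relax.build` / `relax.VirtualMachine`
flow used by tests/python/relax/test_codegen_cublas.py is available, and that
`tvm.nd.array` accepts ml_dtypes bfloat16 host arrays as that test does; the module,
shapes, and inputs below are illustrative only, not part of the patches.

# Sketch only: requires TVM built with CUDA and cuBLAS enabled, plus the ml_dtypes package.
import ml_dtypes
import numpy as np

import tvm
from tvm import relax
from tvm.relax.backend.cuda.cublas import partition_for_cublas  # assumed export of the patched module
from tvm.script import ir_module
from tvm.script import relax as R


# A single bfloat16 matmul accumulating in float32, the mixed-precision case the new checks admit.
@ir_module
class MatmulBF16:
    @R.function
    def main(
        x: R.Tensor((10, 32), "bfloat16"), y: R.Tensor((32, 64), "bfloat16")
    ) -> R.Tensor((10, 64), "float32"):
        with R.dataflow():
            out = R.matmul(x, y, out_dtype="float32")
            R.output(out)
        return out


mod = partition_for_cublas(MatmulBF16)   # offload the matmul to the cuBLAS BYOC backend
mod = relax.transform.RunCodegen()(mod)  # lower the offloaded region to a cuBLAS runtime module
ex = relax.build(mod, target="cuda")
dev = tvm.cuda(0)
vm = relax.VirtualMachine(ex, dev)

# Host-side bfloat16 inputs via ml_dtypes, as in the new test.
x = ml_dtypes.bfloat16(np.random.uniform(0, 5, (10, 32)).astype("float32"))
y = ml_dtypes.bfloat16(np.random.uniform(0, 5, (32, 64)).astype("float32"))
out = vm["main"](tvm.nd.array(x, dev), tvm.nd.array(y, dev)).numpy()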