From d845ba61243f275e652ef7cba0aea0791d41e040 Mon Sep 17 00:00:00 2001 From: Ivan Sidorenko Date: Tue, 16 Apr 2024 15:17:21 +0000 Subject: [PATCH] [CUBLAS] Set fp32 compute and scale dtypes in fp16 matmul This commit replaces the fp16 compute dtype and scale dtype with fp32 in the cuBLAS matmul. --- src/runtime/contrib/cublas/cublas.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 7a867f4bae18..77948649e0d3 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -150,8 +150,6 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, cudaDataType_t c_type = CUDA_R_32F; float one_fp32 = 1.0; float zero_fp32 = 0.0; - auto one_fp16 = __truncXfYf2__(1.0); - auto zero_fp16 = __truncXfYf2__(0.0); int32_t one_i32 = 1; int32_t zero_i32 = 0; void* alpha = &one_fp32; @@ -165,10 +163,6 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, if (TypeMatch(C->dtype, kDLFloat, 16)) { c_type = CUDA_R_16F; - compute_type = CUBLAS_COMPUTE_16F; - scale_type = CUDA_R_16F; - alpha = &one_fp16; - beta = &zero_fp16; } else if (TypeMatch(C->dtype, kDLInt, 32)) { c_type = CUDA_R_32I; compute_type = CUBLAS_COMPUTE_32I;