diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 6fe3539257..0cd0762ee5 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -96,8 +96,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const int arch = cuda::sm_arch(); // Transpose mode with column-major ordering - bool transa_bool = transA == CUBLAS_OP_T; - bool transb_bool = transB == CUBLAS_OP_T; + bool is_A_transposed = transA == CUBLAS_OP_T; + bool is_B_transposed = transB == CUBLAS_OP_T; // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { @@ -106,8 +106,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.transA = transA; ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; - ret.lda = transa_bool ? k : m; - if (arch < 100 && !transa_bool) { + ret.lda = is_A_transposed ? k : m; + if (arch < 100 && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -123,28 +123,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). - if (transa_bool) { + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { - NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); } - ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; ret.transA = transA; - ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; - ret.lda = m; + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = is_A_transposed ? k : m; } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. - if (transa_bool) { + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { - NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); } - ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; ret.transA = CUBLAS_OP_T; - ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = k; // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage @@ -165,8 +165,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.transB = transB; ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; - ret.ldb = transb_bool ? n : k; - if (arch < 100 && transb_bool) { + ret.ldb = is_B_transposed ? n : k; + if (arch < 100 && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -182,28 +182,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). - if (transb_bool) { + if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); } - ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; ret.transB = transB; - ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; - ret.ldb = k; + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = is_B_transposed ? n : k; } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { // FP8 block scaling // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. - if (transb_bool) { + if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); } - ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; ret.transB = CUBLAS_OP_N; - ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = k; // Requirements from @@ -392,7 +392,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &B_scale_inverse, sizeof(B_scale_inverse))); NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), - "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D"); + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D"); scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;