Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
const int arch = cuda::sm_arch();

// Transpose mode with column-major ordering
bool transa_bool = transA == CUBLAS_OP_T;
bool transb_bool = transB == CUBLAS_OP_T;
bool is_A_transposed = transA == CUBLAS_OP_T;
bool is_B_transposed = transB == CUBLAS_OP_T;

// Configure A matrix
if (is_tensor_scaling(A.scaling_mode)) {
Expand All @@ -106,8 +106,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.transA = transA;
ret.Atype = A.data.dtype;
ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = transa_bool ? k : m;
if (arch < 100 && !transa_bool) {
ret.lda = is_A_transposed ? k : m;
if (arch < 100 && !is_A_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
ret.A = A.columnwise_data.dptr;
Expand All @@ -123,28 +123,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
// MXFP8
// Note: Row-wise and column-wise data are scaled along different
// dimensions (with matrix interpreted in row-major order).
if (transa_bool) {
if (is_A_transposed) {
NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
} else {
NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage");
NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
}
ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr;
ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
ret.transA = transA;
ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype;
ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
ret.lda = m;
ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
} else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
// FP8 block scaling
// Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (transa_bool) {
if (is_A_transposed) {
NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
} else {
NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage");
NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
}
ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr;
ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
ret.transA = CUBLAS_OP_T;
ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype;
ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
ret.lda = k;

// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
Expand All @@ -165,8 +165,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.transB = transB;
ret.Btype = B.data.dtype;
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = transb_bool ? n : k;
if (arch < 100 && transb_bool) {
ret.ldb = is_B_transposed ? n : k;
if (arch < 100 && is_B_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
ret.B = B.columnwise_data.dptr;
Expand All @@ -182,28 +182,28 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
// MXFP8
// Note: Row-wise and column-wise data are scaled along different
// dimensions (with matrix interpreted in row-major order).
if (transb_bool) {
if (is_B_transposed) {
NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
} else {
NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
}
ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr;
ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
ret.transB = transB;
ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype;
ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
ret.ldb = k;
ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
} else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
// FP8 block scaling
// Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (transb_bool) {
if (is_B_transposed) {
NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
} else {
NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
}
ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr;
ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
ret.transB = CUBLAS_OP_N;
ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype;
ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
ret.ldb = k;

// Requirements from
Expand Down Expand Up @@ -392,7 +392,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
&B_scale_inverse, sizeof(B_scale_inverse)));
NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
"Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D");
"Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
: CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
Expand Down