From 2f7342ad4a9bf180844ae224ef5d66eb247c7726 Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Wed, 2 Apr 2025 15:10:41 -0700 Subject: [PATCH 01/13] grouped_add example to test variadic args Signed-off-by: Hua Huang --- tests/jax/mytest.py | 29 +++++++++++ transformer_engine/jax/cpp_extensions/gemm.py | 37 ++++++++++++- transformer_engine/jax/csrc/extensions.h | 2 + .../jax/csrc/extensions/gemm.cpp | 52 +++++++++++++++++++ .../jax/csrc/extensions/pybind.cpp | 2 + 5 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 tests/jax/mytest.py diff --git a/tests/jax/mytest.py b/tests/jax/mytest.py new file mode 100644 index 0000000000..3e77181d8b --- /dev/null +++ b/tests/jax/mytest.py @@ -0,0 +1,29 @@ +import jax +import jax.numpy as jnp +from transformer_engine.jax import cpp_extensions as tex +from utils import assert_allclose +import pdb + +out_dtype = jnp.float32 + +shape_list = [[128, 256], [256, 256], [512, 128]] +A_list = [] +B_list = [] +ref_C_list = [] + +key = jax.random.PRNGKey(0) +subkeys = jax.random.split(key, len(shape_list * 2)) +for i in range(len(shape_list)): + shape = shape_list[i] + A_i = jax.random.normal(subkeys[2 * i], shape, dtype=out_dtype) + B_i = jax.random.normal(subkeys[2 * i + 1], shape, dtype=out_dtype) + ref_C_i = A_i + B_i + A_list.append(A_i) + B_list.append(B_i) + ref_C_list.append(ref_C_i) + +#pdb.set_trace() +C_list = tex.grouped_add(A_list, B_list, out_dtype) +for i in range(len(shape_list)): + assert_allclose(C_list[i], ref_C_list[i]) +print("Grouped add test passed.") diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0327542c2f..d67a37caf5 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -21,7 +21,7 @@ ) -__all__ = ["gemm", "grouped_gemm"] +__all__ = ["gemm", "grouped_gemm", "grouped_add"] num_cublas_streams = 4 @@ -518,3 +518,38 @@ def grouped_gemm( out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape)) return out_tensors + +class GroupedAddPrimitive(BasePrimitive): + multiple_results = True + impl_static_args = () + inner_primitive = None + outer_primitive = None + name = "te_grouped_add_ffi" + + @staticmethod + def abstract(*args, out_dtype): + num_pairs = len(args) // 2 + A_list = args[:num_pairs] + return tuple(jax.core.ShapedArray(A.shape, dtype=out_dtype) for A in A_list) + + @staticmethod + def outer_abstract(*args, **kwargs): + return GroupedAddPrimitive.abstract(*args, **kwargs) + + @staticmethod + def lowering(ctx, *args, out_dtype): + del out_dtype + num_pairs = len(args) // 2 + return jax.ffi.ffi_lowering(GroupedAddPrimitive.name)( + ctx, *args, num_pairs=num_pairs + ) + + @staticmethod + def impl(*args, out_dtype): + assert GroupedAddPrimitive.inner_primitive is not None + return GroupedAddPrimitive.inner_primitive.bind(*args, out_dtype=out_dtype) + +register_primitive(GroupedAddPrimitive) + +def grouped_add(A_list, B_list, out_dtype): + return GroupedAddPrimitive.outer_primitive.bind(*A_list, *B_list, out_dtype=out_dtype) \ No newline at end of file diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index aaaf57fab7..2fbd95f1cd 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -119,6 +119,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedAddHandler); + // Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d4b9bf720e..28d820218e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -209,5 +209,57 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode"), FFI_CudaGraph_Traits); +Error_Type GroupedAddFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, + Variadic_Result_Type output_list, int64_t num_pairs) { + for (size_t i = 0; i < static_cast(num_pairs); i++) { + auto A_i_get = input_list.get(i); + auto B_i_get = input_list.get(num_pairs + i); + auto C_i_get = output_list.get(i); + Buffer_Type A_i = A_i_get.value(); + Buffer_Type B_i = B_i_get.value(); + Result_Type C_i = C_i_get.value(); + auto A_ptr = reinterpret_cast(A_i.untyped_data()); + auto B_ptr = reinterpret_cast(B_i.untyped_data()); + auto C_ptr = reinterpret_cast(C_i->untyped_data()); + auto A_shape = A_i.dimensions(); + auto B_shape = B_i.dimensions(); + auto C_shape = C_i->dimensions(); + printf("Pair %ld: A shape ", i); + for (size_t j = 0; j < A_shape.size(); j++) printf("%ld ", A_shape[j]); + printf("; B shape "); + for (size_t j = 0; j < B_shape.size(); j++) printf("%ld ", B_shape[j]); + printf("; C shape "); + for (size_t j = 0; j < C_shape.size(); j++) printf("%ld ", C_shape[j]); + printf("\n"); + size_t A_size = product(A_shape); + size_t B_size = product(B_shape); + size_t C_size = product(C_shape); + float *A_ptr_host = (float *) malloc(A_size * sizeof(float)); + float *B_ptr_host = (float *) malloc(B_size * sizeof(float)); + float *C_ptr_host = (float *) malloc(C_size * sizeof(float)); + cudaMemcpyAsync(A_ptr_host, A_ptr, A_size * sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(B_ptr_host, B_ptr, B_size * sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(C_ptr_host, C_ptr, C_size * sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + for (size_t j = 0; j < A_size; j++) + C_ptr_host[j] = A_ptr_host[j] + B_ptr_host[j]; + cudaMemcpyAsync(C_ptr, C_ptr_host, C_size * sizeof(float), cudaMemcpyHostToDevice, stream); + cudaStreamSynchronize(stream); + free(A_ptr_host); + free(B_ptr_host); + free(C_ptr_host); + } + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedAddHandler, GroupedAddFFI, + FFI::Bind() + .Ctx() // stream + .RemainingArgs() // input list + .RemainingRets() // output list + .Attr("num_pairs"), + FFI_CudaGraph_Traits); + + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 5c165cccb6..5e1f64e40f 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -59,6 +59,8 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + dict["te_grouped_add_ffi"] = EncapsulateFFI(GroupedAddHandler); + return dict; } From 59b17fd903bdd23a12403aed7e475dc8236248a9 Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Mon, 7 Apr 2025 11:16:22 -0700 Subject: [PATCH 02/13] New GroupedGemmPrimitive using variadic args Signed-off-by: Hua Huang --- transformer_engine/jax/cpp_extensions/gemm.py | 190 +++++++++++++++++- transformer_engine/jax/csrc/extensions.h | 1 + .../jax/csrc/extensions/gemm.cpp | 188 +++++++++++++++++ .../jax/csrc/extensions/pybind.cpp | 5 + 4 files changed, 381 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d67a37caf5..eaddb139dd 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -132,6 +132,59 @@ def impl( register_primitive(GroupedGemmPrimitive) +class GroupedGemmPrimitiveNew(BasePrimitive): + """ + Primitive for grouped GEMM + """ + + name = "te_grouped_gemm_new_ffi" + multiple_results = True + impl_static_args = () + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): + del scaling_mode, has_bias + A_list = args[0 : num_gemms] + B_list = args[num_gemms : 2 * num_gemms] + out_list_aval = tuple(jax.core.ShapedArray((A.shape[0], B.shape[0]), dtype=out_dtype) for A, B in zip(A_list, B_list)) + workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams + workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + return (*out_list_aval, workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + (out_aval, _) = GroupedGemmPrimitiveNew.abstract(*args, **kwargs) + return out_aval + + @staticmethod + def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): + del out_dtype + workspace_size = get_cublas_workspace_size_bytes() + return jax.ffi.ffi_lowering(GroupedGemmPrimitiveNew.name)( + ctx, + *args, + num_gemms=num_gemms, + scaling_mode=int(scaling_mode), + has_bias=has_bias, + workspace_size=workspace_size + ) + + @staticmethod + def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias): + assert GroupedGemmPrimitiveNew.inner_primitive is not None + out = GroupedGemmPrimitiveNew.inner_primitive.bind( + *args, + num_gemms=num_gemms, + scaling_mode=scaling_mode.value, + out_dtype=out_dtype, + has_bias=has_bias, + ) + return out[:-1] # out is [out_list, wkspace], only return out_list + + +register_primitive(GroupedGemmPrimitiveNew) def _shape_normalization(x, dimension_numbers, already_transposed: bool = False): orig_order = list(range(x.ndim)) @@ -369,7 +422,7 @@ def swizzled_scale(scales): return scales -def grouped_gemm( +def grouped_gemm_old( lhs_list: List[Union[jnp.ndarray, ScaledTensor]], rhs_list: List[Union[jnp.ndarray, ScaledTensor]], contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], @@ -519,6 +572,134 @@ def grouped_gemm( return out_tensors + +def grouped_gemm_new( + lhs_list: List[Union[jnp.ndarray, ScaledTensor]], + rhs_list: List[Union[jnp.ndarray, ScaledTensor]], + contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], + bias_list: List[jnp.ndarray] = None, +) -> List[jnp.ndarray]: + """Grouped GEMM for multiple pairs of tensors.""" + assert ( + len(lhs_list) == len(rhs_list) == len(contracting_dims_list) + ), "lhs_list, rhs_list, contracting_dims_list must have the same length" + + num_gemms = len(lhs_list) + lhs_list_ = [] + rhs_list_ = [] + lhs_sinv_list_ = [] + rhs_sinv_list_ = [] + bias_list_ = [] + num_gemms = len(lhs_list) + for i in range(num_gemms): + lhs = lhs_list[i] + rhs = rhs_list[i] + contracting_dims = contracting_dims_list[i] + dim_nums = (contracting_dims, ((), ())) + if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): + scaling_mode = lhs.scaling_mode + lhs_shape = lhs.data.shape + rhs_shape = rhs.data.shape + out_dtype = lhs.dq_dtype + # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout + if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + assert not ( + lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 + ), "FP8 GEMM does not support E5M2 * E5M2" + ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims + if lhs.data_layout == "T": + lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim + if rhs.data_layout == "T": + rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim + dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) + else: + # For jnp.ndarray, only consider contracting_dims, data_layout is always NN + scaling_mode = ScalingMode.NVTE_NO_SCALING + lhs_shape = lhs.shape + rhs_shape = rhs.shape + out_dtype = lhs.dtype + + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums + lhs_dn = (lhs_contract, lhs_batch) + rhs_dn = (rhs_contract, rhs_batch) + + lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) + rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) + + if scaling_mode == ScalingMode.NVTE_NO_SCALING: + lhs_3d = _shape_normalization(lhs, lhs_dn) + rhs_3d = _shape_normalization(rhs, rhs_dn) + elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") + rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + lhs_3d = _shape_normalization(lhs.data, lhs_dn) + rhs_3d = _shape_normalization(rhs.data, rhs_dn) + lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) + rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) + lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) + rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) + else: + raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") + + # Note: if _shape_normalization() is updated to support non-TN, need to update here + # already_transposed doesn't matter for the output shape + # x.shape = [B, D1, D2] + # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] + # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1] + # x.shape = [D1, D2] + # contracting_dims = (1, ) --> output.shape = [1, D1, D2] + # contracting_dims = (0, ) --> output.shape = [1, D2, D1] + bm = lhs_remain_shape[0] + bn = rhs_remain_shape[0] + kl = lhs_3d.shape[-1] + kr = rhs_3d.shape[-1] + assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})" + k = kl + if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0): + print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ") + print( + f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples" + " of 16" + ) + assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0 + + lhs_list_.append(jnp.squeeze(lhs_3d, axis=0)) + rhs_list_.append(jnp.squeeze(rhs_3d, axis=0)) + if scaling_mode == ScalingMode.NVTE_NO_SCALING: + lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) + rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + lhs_sinv_list_.append(lhs.scale_inv) + rhs_sinv_list_.append(rhs.scale_inv) + if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + lhs_sinv_list_.append(lhs_scale_inv) + rhs_sinv_list_.append(rhs_scale_inv) + if bias_list is not None: + bias_list_.append(bias_list[i]) + + out_list = GroupedGemmPrimitiveNew.outer_primitive.bind( + *lhs_list_, + *rhs_list_, + *lhs_sinv_list_, + *rhs_sinv_list_, + *bias_list_, + num_gemms=num_gemms, + scaling_mode=scaling_mode, + out_dtype=out_dtype, + has_bias=1 if bias_list is not None else 0, + ) + + return out_list + +def grouped_gemm( + lhs_list: List[Union[jnp.ndarray, ScaledTensor]], + rhs_list: List[Union[jnp.ndarray, ScaledTensor]], + contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], + bias_list: List[jnp.ndarray] = None, +) -> List[jnp.ndarray]: + return grouped_gemm_new(lhs_list, rhs_list, contracting_dims_list, bias_list) + class GroupedAddPrimitive(BasePrimitive): multiple_results = True impl_static_args = () @@ -530,7 +711,9 @@ class GroupedAddPrimitive(BasePrimitive): def abstract(*args, out_dtype): num_pairs = len(args) // 2 A_list = args[:num_pairs] - return tuple(jax.core.ShapedArray(A.shape, dtype=out_dtype) for A in A_list) + workspace_aval = jax.core.ShapedArray(shape=(1024,), dtype=jnp.uint8) + out_aval = tuple(jax.core.ShapedArray(A.shape, dtype=out_dtype) for A in A_list) + return (*out_aval, workspace_aval) @staticmethod def outer_abstract(*args, **kwargs): @@ -547,7 +730,8 @@ def lowering(ctx, *args, out_dtype): @staticmethod def impl(*args, out_dtype): assert GroupedAddPrimitive.inner_primitive is not None - return GroupedAddPrimitive.inner_primitive.bind(*args, out_dtype=out_dtype) + out = GroupedAddPrimitive.inner_primitive.bind(*args, out_dtype=out_dtype) + return out[:-1] register_primitive(GroupedAddPrimitive) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2fbd95f1cd..3715f4bfc4 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -118,6 +118,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmNewHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedAddHandler); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 28d820218e..dc39a97be2 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -209,6 +209,194 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode"), FFI_CudaGraph_Traits); + +Error_Type GroupedGemmNewFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, + Variadic_Result_Type output_list, int64_t num_gemms, + int64_t scaling_mode, int64_t has_bias, int64_t workspace_size) { + // Notes on matrix layouts and transpose: + // Jax uses row-major data_layout, on entering this function, each input matrix pair: + // A: row-major with size [m, k], + // B: row-major with size [n, k], needs transpose, + // on exiting this function, JAX expect: + // C: row-major with size [m, n]. + // cuBLAS uses column-major data_layout, in this view, each input matrix pair: + // A: column-major with size [k, m], needs transpose, + // B: column-major with size [k, n]. + // If we call cuBLAS GEMM for A * B, the output will be: + // C: column-major with size [m, n] --> row-major with size [n, m]. + // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. + + bool trans_lhs = true; + bool trans_rhs = false; + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + bool grad = false; + bool accumulate = false; + bool use_split_accumulator = false; + + // These lists are to keep the TensorWrapper objects alive + std::vector lhs_wrapper_list; + std::vector rhs_wrapper_list; + std::vector bias_wrapper_list; + std::vector pre_gelu_wrapper_list; + std::vector out_wrapper_list; + std::vector workspace_wrapper_list; + + // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM + std::vector lhs_list; + std::vector rhs_list; + std::vector bias_list; + std::vector pre_gelu_list; + std::vector out_list; + std::vector workspace_list; + + //printf("[DEBUG] num_gemms, scaling_mode, has_bias, workspace_size: %ld %ld %ld %ld\n", + // num_gemms, scaling_mode, has_bias, workspace_size); + //fflush(stdout); + + int lhs_list_offset = 0; + int rhs_list_offset = num_gemms; + int lhs_sinv_list_offset = 2 * num_gemms; + int rhs_sinv_list_offset = 3 * num_gemms; + int bias_list_offset = 4 * num_gemms; + int out_list_offset = 0; + for (int i = 0; i < num_gemms; i++) { + auto lhs_i_get = input_list.get(lhs_list_offset + i); + auto rhs_i_get = input_list.get(rhs_list_offset + i); + auto lhs_sinv_i_get = input_list.get(lhs_sinv_list_offset + i); + auto rhs_sinv_i_get = input_list.get(rhs_sinv_list_offset + i); + auto out_i_get = output_list.get(out_list_offset + i); + Buffer_Type lhs_i = lhs_i_get.value(); + Buffer_Type rhs_i = rhs_i_get.value(); + Buffer_Type lhs_sinv_i = lhs_sinv_i_get.value(); + Buffer_Type rhs_sinv_i = rhs_sinv_i_get.value(); + Result_Type out_i = out_i_get.value(); + + DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type()); + DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type()); + DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type()); + + void *lhs_ptr = lhs_i.untyped_data(); + void *rhs_ptr = rhs_i.untyped_data(); + void *lhs_sinv_ptr = lhs_sinv_i.untyped_data(); + void *rhs_sinv_ptr = rhs_sinv_i.untyped_data(); + void *out_ptr = out_i->untyped_data(); + + // Placeholder for bias since it can be empty + DType bias_dtype = DType::kFloat32; + void *bias_ptr = nullptr; + + auto lhs_shape_ = lhs_i.dimensions(); + auto rhs_shape_ = rhs_i.dimensions(); + + size_t m = lhs_shape_[0]; + size_t n = rhs_shape_[0]; + size_t k = lhs_shape_[1]; + + auto lhs_shape = std::vector{m, k}; + auto rhs_shape = std::vector{n, k}; + auto out_shape = std::vector{n, m}; + auto lhs_sinv_shape = std::vector{1, 1}; + auto rhs_sinv_shape = std::vector{1, 1}; + + if (scaling_mode == NVTE_NO_SCALING || scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, nullptr, nullptr, + reinterpret_cast(lhs_sinv_ptr)); + auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, nullptr, nullptr, + reinterpret_cast(rhs_sinv_ptr)); + lhs_wrapper_list.push_back(std::move(lhs_i_)); + rhs_wrapper_list.push_back(std::move(rhs_i_)); + } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", + MXFP8_BLOCK_SIZE, k); + size_t sinv_k = k / MXFP8_BLOCK_SIZE; + lhs_sinv_shape[0] = m; + lhs_sinv_shape[1] = sinv_k; + rhs_sinv_shape[0] = n; + rhs_sinv_shape[1] = sinv_k; + + // Note: the scale_inv array should have been swizzled in Python before lowering + TensorWrapper lhs_i_(NVTE_MXFP8_1D_SCALING); + TensorWrapper rhs_i_(NVTE_MXFP8_1D_SCALING); + lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); + rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape); + lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape); + rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape); + + lhs_wrapper_list.push_back(std::move(lhs_i_)); + rhs_wrapper_list.push_back(std::move(rhs_i_)); + } else { + NVTE_ERROR("Unsupported scaling mode: ", scaling_mode); + } + + /* + printf("[DEBUG] i: %d, lhs_dtype, rhs_dtype, out_dtype = %d, %d, %d\n", i, lhs_dtype, rhs_dtype, out_dtype); + printf("[DEBUG] lhs_shape: %d %d, ptr = %p\n", lhs_shape_[0], lhs_shape_[1], lhs_ptr); + printf("[DEBUG] rhs_shape: %d %d, ptr = %p\n", rhs_shape_[0], rhs_shape_[1], rhs_ptr); + printf("[DEBUG] out_shape: %d %d, ptr = %p\n", out_i->dimensions()[0], out_i->dimensions()[1], out_ptr); + printf("[DEBUG] lhs_sinv_shape: %d %d, ptr = %p\n", lhs_sinv_shape[0], lhs_sinv_shape[1], lhs_sinv_ptr); + printf("[DEBUG] rhs_sinv_shape: %d %d, ptr = %p\n", rhs_sinv_shape[0], rhs_sinv_shape[1], rhs_sinv_ptr); + fflush(stdout); + */ + + auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype); + void *pre_gelu_ptr = nullptr; + auto bias_shape = std::vector{0}; + auto pre_gelu_shape = std::vector{0}; + if (has_bias) + { + auto bias_i_get = input_list.get(bias_list_offset + i); + Buffer_Type bias_i = bias_i_get.value(); + bias_ptr = bias_i.untyped_data(); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type()); + bias_shape[0] = n; + } + auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype); + + out_wrapper_list.push_back(std::move(out_i_)); + bias_wrapper_list.push_back(std::move(bias_i)); + pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); + + lhs_list.push_back(lhs_wrapper_list.back().data()); + rhs_list.push_back(rhs_wrapper_list.back().data()); + bias_list.push_back(bias_wrapper_list.back().data()); + pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data()); + out_list.push_back(out_wrapper_list.back().data()); + } + + auto workspace_get = output_list.get(num_gemms); + Result_Type workspace = workspace_get.value(); + uint8_t *workspace_ptr = reinterpret_cast(workspace->untyped_data()); + auto workspace_shape = std::vector{workspace_size}; + for (int i = 0; i < num_streams; i++) { + auto workspace_i = + TensorWrapper(static_cast(workspace_ptr), workspace_shape, DType::kByte); + workspace_wrapper_list.push_back(std::move(workspace_i)); + workspace_list.push_back(workspace_wrapper_list.back().data()); + workspace_ptr += workspace_size; + } + //printf("[DEBUG] workspace packing done\n"); + //fflush(stdout); + + nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), + pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad, + workspace_list.data(), accumulate, use_split_accumulator, + num_math_sm, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmNewHandler, GroupedGemmNewFFI, + FFI::Bind() + .Ctx() // stream + .RemainingArgs() // input list + .RemainingRets() // output list + .Attr("num_gemms") + .Attr("scaling_mode") + .Attr("has_bias") + .Attr("workspace_size"), + FFI_CudaGraph_Traits); + Error_Type GroupedAddFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, Variadic_Result_Type output_list, int64_t num_pairs) { for (size_t i = 0; i < static_cast(num_pairs); i++) { diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 5e1f64e40f..dcc218dced 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -59,6 +59,11 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + dict["te_grouped_gemm_new_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GroupedGemmNewHandler)); + + dict["te_grouped_add_ffi"] = EncapsulateFFI(GroupedAddHandler); return dict; From 937f98d2ee86aba85be4dff5791827dacfed0d0e Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Tue, 8 Apr 2025 10:39:08 -0700 Subject: [PATCH 03/13] Remove NVTE_NO_SCALING in new grouped_gemm Signed-off-by: Hua Huang --- transformer_engine/jax/cpp_extensions/gemm.py | 10 +++++++++- transformer_engine/jax/csrc/extensions/gemm.cpp | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index eaddb139dd..e5ab28750e 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -678,6 +678,11 @@ def grouped_gemm_new( if bias_list is not None: bias_list_.append(bias_list[i]) + # TE/common does not support NVTE_NO_SCALING yet + # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 + if scaling_mode == ScalingMode.NVTE_NO_SCALING: + scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + out_list = GroupedGemmPrimitiveNew.outer_primitive.bind( *lhs_list_, *rhs_list_, @@ -698,7 +703,10 @@ def grouped_gemm( contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], bias_list: List[jnp.ndarray] = None, ) -> List[jnp.ndarray]: - return grouped_gemm_new(lhs_list, rhs_list, contracting_dims_list, bias_list) + grouped_gemm_ = grouped_gemm_new + #grouped_gemm_ = grouped_gemm_old + #print(f"Using {grouped_gemm_.__name__} for grouped_gemm") + return grouped_gemm_(lhs_list, rhs_list, contracting_dims_list, bias_list) class GroupedAddPrimitive(BasePrimitive): multiple_results = True diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index dc39a97be2..a590755300 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -298,7 +298,7 @@ Error_Type GroupedGemmNewFFI(cudaStream_t stream, Variadic_Buffer_Type input_lis auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; - if (scaling_mode == NVTE_NO_SCALING || scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, nullptr, nullptr, reinterpret_cast(lhs_sinv_ptr)); auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, nullptr, nullptr, From 963fd4c4d73fc3f72464cb3f92b5b6daf5322ad7 Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Tue, 8 Apr 2025 12:50:11 -0700 Subject: [PATCH 04/13] Remove squeeze() to reduce D2D memcpy Signed-off-by: Hua Huang --- transformer_engine/jax/cpp_extensions/gemm.py | 120 +++++++++++------- .../jax/csrc/extensions/gemm.cpp | 28 ++-- 2 files changed, 87 insertions(+), 61 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index e5ab28750e..27e2941f1e 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -148,7 +148,8 @@ def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): del scaling_mode, has_bias A_list = args[0 : num_gemms] B_list = args[num_gemms : 2 * num_gemms] - out_list_aval = tuple(jax.core.ShapedArray((A.shape[0], B.shape[0]), dtype=out_dtype) for A, B in zip(A_list, B_list)) + # A and B have shapes [1, m, k] and [1, n, k] + out_list_aval = tuple(jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) for A, B in zip(A_list, B_list)) workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) return (*out_list_aval, workspace_aval) @@ -572,7 +573,14 @@ def grouped_gemm_old( return out_tensors +def _get_shape_normalization_contracting_dim_size(x, dimension_numbers): + orig_order = list(range(x.ndim)) + contracting_dims, batch_dims = dimension_numbers + contracting_order = [d for d in orig_order if d in contracting_dims] + cols_shape = [x.shape[d] for d in contracting_order] + return reduce(operator.mul, cols_shape, 1) +import nvtx def grouped_gemm_new( lhs_list: List[Union[jnp.ndarray, ScaledTensor]], rhs_list: List[Union[jnp.ndarray, ScaledTensor]], @@ -584,13 +592,12 @@ def grouped_gemm_new( len(lhs_list) == len(rhs_list) == len(contracting_dims_list) ), "lhs_list, rhs_list, contracting_dims_list must have the same length" + rng_gg = nvtx.start_range("grouped_gemm_new", color="blue") + + rng_shape_check = nvtx.start_range("shape_check", color="yellow") num_gemms = len(lhs_list) - lhs_list_ = [] - rhs_list_ = [] - lhs_sinv_list_ = [] - rhs_sinv_list_ = [] - bias_list_ = [] - num_gemms = len(lhs_list) + lhs_dn_list = [] + rhs_dn_list = [] for i in range(num_gemms): lhs = lhs_list[i] rhs = rhs_list[i] @@ -622,28 +629,13 @@ def grouped_gemm_new( (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums lhs_dn = (lhs_contract, lhs_batch) rhs_dn = (rhs_contract, rhs_batch) + lhs_dn_list.append(lhs_dn) + rhs_dn_list.append(rhs_dn) lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: - lhs_3d = _shape_normalization(lhs, lhs_dn) - rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: - lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") - rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - lhs_3d = _shape_normalization(lhs.data, lhs_dn) - rhs_3d = _shape_normalization(rhs.data, rhs_dn) - lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) - rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) - lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) - rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) - else: - raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") - - # Note: if _shape_normalization() is updated to support non-TN, need to update here - # already_transposed doesn't matter for the output shape + # Note: already_transposed doesn't matter for the output shape # x.shape = [B, D1, D2] # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1] @@ -652,37 +644,72 @@ def grouped_gemm_new( # contracting_dims = (0, ) --> output.shape = [1, D2, D1] bm = lhs_remain_shape[0] bn = rhs_remain_shape[0] - kl = lhs_3d.shape[-1] - kr = rhs_3d.shape[-1] - assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})" - k = kl - if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0): + _lhs_data = lhs if scaling_mode == ScalingMode.NVTE_NO_SCALING else lhs.data + _rhs_data = rhs if scaling_mode == ScalingMode.NVTE_NO_SCALING else rhs.data + kl = _get_shape_normalization_contracting_dim_size(_lhs_data, lhs_dn) + kr = _get_shape_normalization_contracting_dim_size(_rhs_data, rhs_dn) + assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" + if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ") print( - f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples" + f"m = {bm}, n = {bn}, k = {kl}; cuBLAS requires the problem shapes being multiples" " of 16" ) - assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0 + assert bm % 16 == 0 and bn % 16 == 0 and kl % 16 == 0 + nvtx.end_range(rng_shape_check) - lhs_list_.append(jnp.squeeze(lhs_3d, axis=0)) - rhs_list_.append(jnp.squeeze(rhs_3d, axis=0)) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: - lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) - rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: - lhs_sinv_list_.append(lhs.scale_inv) - rhs_sinv_list_.append(rhs.scale_inv) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - lhs_sinv_list_.append(lhs_scale_inv) - rhs_sinv_list_.append(rhs_scale_inv) - if bias_list is not None: - bias_list_.append(bias_list[i]) + rng_make_tuple = nvtx.start_range("make_tuple", color="purple") + # Note: do not .squeeze() for the output of _shape_normalization, it will trigger a D2D memcpy + if scaling_mode == ScalingMode.NVTE_NO_SCALING: + lhs_list_ = tuple( + _shape_normalization(lhs, lhs_dn) + for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) + ) + rhs_list_ = tuple( + _shape_normalization(rhs, rhs_dn) + for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) + ) + lhs_sinv_list_ = tuple(jnp.ones(1, dtype=jnp.float32) for _ in range(num_gemms)) + rhs_sinv_list_ = tuple(jnp.ones(1, dtype=jnp.float32) for _ in range(num_gemms)) + elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + lhs_list_ = tuple( + _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") + for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) + ) + rhs_list_ = tuple( + _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") + for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) + ) + lhs_sinv_list_ = tuple(lhs.scale_inv for lhs in lhs_list) + rhs_sinv_list_ = tuple(rhs.scale_inv for rhs in rhs_list) + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + lhs_list_ = tuple( + _shape_normalization(lhs.data, lhs_dn) + for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) + ) + rhs_list_ = tuple( + _shape_normalization(rhs.data, rhs_dn) + for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) + ) + lhs_sinv_list_ = tuple( + swizzled_scale(_shape_normalization(lhs.scale_inv, lhs_dn)) + for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) + ) + rhs_sinv_list_ = tuple( + swizzled_scale(_shape_normalization(rhs.scale_inv, rhs_dn)) + for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) + ) + else: + raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") + bias_list_ = [] if bias_list is None else tuple(bias_list) + nvtx.end_range(rng_make_tuple) # TE/common does not support NVTE_NO_SCALING yet # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 if scaling_mode == ScalingMode.NVTE_NO_SCALING: scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + rng_prim = nvtx.start_range("grouped_gemm_new_ffi", color="green") out_list = GroupedGemmPrimitiveNew.outer_primitive.bind( *lhs_list_, *rhs_list_, @@ -694,6 +721,9 @@ def grouped_gemm_new( out_dtype=out_dtype, has_bias=1 if bias_list is not None else 0, ) + nvtx.end_range(rng_prim) + + nvtx.end_range(rng_gg) return out_list diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index a590755300..34d1ba4009 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -260,16 +260,11 @@ Error_Type GroupedGemmNewFFI(cudaStream_t stream, Variadic_Buffer_Type input_lis int bias_list_offset = 4 * num_gemms; int out_list_offset = 0; for (int i = 0; i < num_gemms; i++) { - auto lhs_i_get = input_list.get(lhs_list_offset + i); - auto rhs_i_get = input_list.get(rhs_list_offset + i); - auto lhs_sinv_i_get = input_list.get(lhs_sinv_list_offset + i); - auto rhs_sinv_i_get = input_list.get(rhs_sinv_list_offset + i); - auto out_i_get = output_list.get(out_list_offset + i); - Buffer_Type lhs_i = lhs_i_get.value(); - Buffer_Type rhs_i = rhs_i_get.value(); - Buffer_Type lhs_sinv_i = lhs_sinv_i_get.value(); - Buffer_Type rhs_sinv_i = rhs_sinv_i_get.value(); - Result_Type out_i = out_i_get.value(); + Buffer_Type lhs_i = input_list.get(lhs_list_offset + i).value(); + Buffer_Type rhs_i = input_list.get(rhs_list_offset + i).value(); + Buffer_Type lhs_sinv_i = input_list.get(lhs_sinv_list_offset + i).value(); + Buffer_Type rhs_sinv_i = input_list.get(rhs_sinv_list_offset + i).value(); + Result_Type out_i = output_list.get(out_list_offset + i).value(); DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type()); DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type()); @@ -288,9 +283,10 @@ Error_Type GroupedGemmNewFFI(cudaStream_t stream, Variadic_Buffer_Type input_lis auto lhs_shape_ = lhs_i.dimensions(); auto rhs_shape_ = rhs_i.dimensions(); - size_t m = lhs_shape_[0]; - size_t n = rhs_shape_[0]; - size_t k = lhs_shape_[1]; + // lhs and rhs has shape [1, m, k] and [1, n, k] + size_t m = lhs_shape_[1]; + size_t n = rhs_shape_[1]; + size_t k = lhs_shape_[2]; auto lhs_shape = std::vector{m, k}; auto rhs_shape = std::vector{n, k}; @@ -330,9 +326,9 @@ Error_Type GroupedGemmNewFFI(cudaStream_t stream, Variadic_Buffer_Type input_lis /* printf("[DEBUG] i: %d, lhs_dtype, rhs_dtype, out_dtype = %d, %d, %d\n", i, lhs_dtype, rhs_dtype, out_dtype); - printf("[DEBUG] lhs_shape: %d %d, ptr = %p\n", lhs_shape_[0], lhs_shape_[1], lhs_ptr); - printf("[DEBUG] rhs_shape: %d %d, ptr = %p\n", rhs_shape_[0], rhs_shape_[1], rhs_ptr); - printf("[DEBUG] out_shape: %d %d, ptr = %p\n", out_i->dimensions()[0], out_i->dimensions()[1], out_ptr); + printf("[DEBUG] lhs_shape: %d %d, ptr = %p\n", lhs_shape_[1], lhs_shape_[2], lhs_ptr); + printf("[DEBUG] rhs_shape: %d %d, ptr = %p\n", rhs_shape_[1], rhs_shape_[2], rhs_ptr); + printf("[DEBUG] out_shape: %d %d, ptr = %p\n", out_i->dimensions()[1], out_i->dimensions()[2], out_ptr); printf("[DEBUG] lhs_sinv_shape: %d %d, ptr = %p\n", lhs_sinv_shape[0], lhs_sinv_shape[1], lhs_sinv_ptr); printf("[DEBUG] rhs_sinv_shape: %d %d, ptr = %p\n", rhs_sinv_shape[0], rhs_sinv_shape[1], rhs_sinv_ptr); fflush(stdout); From 408acde0a155f6dea66c3c6409d56ca28d012bc5 Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Tue, 8 Apr 2025 12:57:58 -0700 Subject: [PATCH 05/13] Clean up code Signed-off-by: Hua Huang --- transformer_engine/jax/cpp_extensions/gemm.py | 336 +----------------- transformer_engine/jax/csrc/extensions.h | 3 - .../jax/csrc/extensions/gemm.cpp | 265 +------------- .../jax/csrc/extensions/pybind.cpp | 7 - 4 files changed, 18 insertions(+), 593 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 27e2941f1e..e3a8c094fb 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -21,7 +21,7 @@ ) -__all__ = ["gemm", "grouped_gemm", "grouped_add"] +__all__ = ["gemm", "grouped_gemm"] num_cublas_streams = 4 @@ -41,104 +41,6 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9) - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract( - lhs_contig_aval, - lhs_scale_contig_aval, - rhs_contig_aval, - rhs_scale_contig_aval, - bias_contig_aval, - dim_list_aval, - *, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ): - del lhs_contig_aval, lhs_scale_contig_aval - del rhs_contig_aval, rhs_scale_contig_aval - del bias_contig_aval, dim_list_aval - del num_gemms, scaling_mode - out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype) - wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams - wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8) - return (out_flat_aval, wkspace_aval) - - @staticmethod - def outer_abstract(*args, **kwargs): - (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) - return out_aval - - @staticmethod - def lowering( - ctx, - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, - *, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ) -> jnp.ndarray: - del out_dtype, out_flat_size - return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( - ctx, - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, - num_gemms=num_gemms, - scaling_mode=scaling_mode.value, - ) - - @staticmethod - def impl( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ) -> jnp.ndarray: - assert GroupedGemmPrimitive.inner_primitive is not None - out = GroupedGemmPrimitive.inner_primitive.bind( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, - num_gemms=num_gemms, - scaling_mode=scaling_mode, - out_dtype=out_dtype, - out_flat_size=out_flat_size, - ) - return out[0] # out is [out_flat, wkspace], only return out_flat - - -register_primitive(GroupedGemmPrimitive) - -class GroupedGemmPrimitiveNew(BasePrimitive): - """ - Primitive for grouped GEMM - """ - - name = "te_grouped_gemm_new_ffi" - multiple_results = True impl_static_args = () inner_primitive = None outer_primitive = None @@ -149,21 +51,24 @@ def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): A_list = args[0 : num_gemms] B_list = args[num_gemms : 2 * num_gemms] # A and B have shapes [1, m, k] and [1, n, k] - out_list_aval = tuple(jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) for A, B in zip(A_list, B_list)) + out_list_aval = tuple( + jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) + for A, B in zip(A_list, B_list) + ) workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) return (*out_list_aval, workspace_aval) @staticmethod def outer_abstract(*args, **kwargs): - (out_aval, _) = GroupedGemmPrimitiveNew.abstract(*args, **kwargs) + (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) return out_aval @staticmethod def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): del out_dtype workspace_size = get_cublas_workspace_size_bytes() - return jax.ffi.ffi_lowering(GroupedGemmPrimitiveNew.name)( + return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( ctx, *args, num_gemms=num_gemms, @@ -174,8 +79,8 @@ def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): @staticmethod def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias): - assert GroupedGemmPrimitiveNew.inner_primitive is not None - out = GroupedGemmPrimitiveNew.inner_primitive.bind( + assert GroupedGemmPrimitive.inner_primitive is not None + out = GroupedGemmPrimitive.inner_primitive.bind( *args, num_gemms=num_gemms, scaling_mode=scaling_mode.value, @@ -185,7 +90,7 @@ def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias): return out[:-1] # out is [out_list, wkspace], only return out_list -register_primitive(GroupedGemmPrimitiveNew) +register_primitive(GroupedGemmPrimitive) def _shape_normalization(x, dimension_numbers, already_transposed: bool = False): orig_order = list(range(x.ndim)) @@ -423,165 +328,15 @@ def swizzled_scale(scales): return scales -def grouped_gemm_old( - lhs_list: List[Union[jnp.ndarray, ScaledTensor]], - rhs_list: List[Union[jnp.ndarray, ScaledTensor]], - contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], - bias_list: List[jnp.ndarray] = None, -) -> List[jnp.ndarray]: - """Grouped GEMM for multiple pairs of tensors.""" - assert ( - len(lhs_list) == len(rhs_list) == len(contracting_dims_list) - ), "lhs_list, rhs_list, contracting_dims_list must have the same length" - - # Flatten inputs and save their shapes - num_gemms = len(lhs_list) - out_flat_size = 0 - dims = [] - lhs_contig_ = [] - rhs_contig_ = [] - lhs_scale_inv_contig_ = [] - rhs_scale_inv_contig_ = [] - bias_contig_ = [] - out_offsets = [] - remain_shape_list = [] - num_gemms = len(lhs_list) - for i in range(num_gemms): - lhs = lhs_list[i] - rhs = rhs_list[i] - contracting_dims = contracting_dims_list[i] - dim_nums = (contracting_dims, ((), ())) - if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - scaling_mode = lhs.scaling_mode - lhs_shape = lhs.data.shape - rhs_shape = rhs.data.shape - out_dtype = lhs.dq_dtype - # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout - if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: - assert not ( - lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 - ), "FP8 GEMM does not support E5M2 * E5M2" - ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims - if lhs.data_layout == "T": - lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim - if rhs.data_layout == "T": - rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim - dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) - else: - # For jnp.ndarray, only consider contracting_dims, data_layout is always NN - scaling_mode = ScalingMode.NO_SCALING - lhs_shape = lhs.shape - rhs_shape = rhs.shape - out_dtype = lhs.dtype - - (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums - lhs_dn = (lhs_contract, lhs_batch) - rhs_dn = (rhs_contract, rhs_batch) - - lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) - rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) - - if scaling_mode == ScalingMode.NO_SCALING: - lhs_3d = _shape_normalization(lhs, lhs_dn) - rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: - lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") - rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_3d = _shape_normalization(lhs.data, lhs_dn) - rhs_3d = _shape_normalization(rhs.data, rhs_dn) - lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) - rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) - lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) - rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) - else: - raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") - - # Note: if _shape_normalization() is updated to support non-TN, need to update here - # already_transposed doesn't matter for the output shape - # x.shape = [B, D1, D2] - # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] - # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1] - # x.shape = [D1, D2] - # contracting_dims = (1, ) --> output.shape = [1, D1, D2] - # contracting_dims = (0, ) --> output.shape = [1, D2, D1] - bm = lhs_remain_shape[0] - bn = rhs_remain_shape[0] - kl = lhs_3d.shape[-1] - kr = rhs_3d.shape[-1] - remain_shape_list.append(((bm,), (bn,))) - assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})" - k = kl - - if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0): - print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ") - print( - f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples" - " of 16" - ) - assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0 - - dims.append((bm, bn, k)) - lhs_contig_.append(lhs_3d.reshape(-1)) - rhs_contig_.append(rhs_3d.reshape(-1)) - if scaling_mode == ScalingMode.NO_SCALING: - lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) - rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: - lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1)) - rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1)) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1)) - rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1)) - if bias_list is not None: - bias_contig_.append(bias_list[i].reshape(-1)) - out_flat_size += bm * bn - out_offsets.append(out_flat_size) - - lhs_contig = jnp.concatenate(lhs_contig_) - rhs_contig = jnp.concatenate(rhs_contig_) - lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_) - rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_) - bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_) - dim_list = jnp.array(dims, dtype=jnp.int32) - - # TE/common does not support NVTE_NO_SCALING yet - # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 - if scaling_mode == ScalingMode.NO_SCALING: - scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING - - # Perform batched GEMM on flattened inputs - out_contig = GroupedGemmPrimitive.outer_primitive.bind( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, - num_gemms=num_gemms, - scaling_mode=scaling_mode.value, - out_dtype=out_dtype, - out_flat_size=out_flat_size, - ) - - # Split the output back into tensors - out_offsets = jnp.array(out_offsets) - out_flat_list = jnp.split(out_contig, out_offsets[:-1]) - out_tensors = [] - for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list): - out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape)) - - return out_tensors - -def _get_shape_normalization_contracting_dim_size(x, dimension_numbers): +def get_shape_normalization_contracting_dim_size(x, dimension_numbers): orig_order = list(range(x.ndim)) contracting_dims, batch_dims = dimension_numbers contracting_order = [d for d in orig_order if d in contracting_dims] cols_shape = [x.shape[d] for d in contracting_order] return reduce(operator.mul, cols_shape, 1) -import nvtx -def grouped_gemm_new( + +def grouped_gemm( lhs_list: List[Union[jnp.ndarray, ScaledTensor]], rhs_list: List[Union[jnp.ndarray, ScaledTensor]], contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], @@ -592,9 +347,6 @@ def grouped_gemm_new( len(lhs_list) == len(rhs_list) == len(contracting_dims_list) ), "lhs_list, rhs_list, contracting_dims_list must have the same length" - rng_gg = nvtx.start_range("grouped_gemm_new", color="blue") - - rng_shape_check = nvtx.start_range("shape_check", color="yellow") num_gemms = len(lhs_list) lhs_dn_list = [] rhs_dn_list = [] @@ -646,8 +398,8 @@ def grouped_gemm_new( bn = rhs_remain_shape[0] _lhs_data = lhs if scaling_mode == ScalingMode.NVTE_NO_SCALING else lhs.data _rhs_data = rhs if scaling_mode == ScalingMode.NVTE_NO_SCALING else rhs.data - kl = _get_shape_normalization_contracting_dim_size(_lhs_data, lhs_dn) - kr = _get_shape_normalization_contracting_dim_size(_rhs_data, rhs_dn) + kl = get_shape_normalization_contracting_dim_size(_lhs_data, lhs_dn) + kr = get_shape_normalization_contracting_dim_size(_rhs_data, rhs_dn) assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ") @@ -656,9 +408,7 @@ def grouped_gemm_new( " of 16" ) assert bm % 16 == 0 and bn % 16 == 0 and kl % 16 == 0 - nvtx.end_range(rng_shape_check) - rng_make_tuple = nvtx.start_range("make_tuple", color="purple") # Note: do not .squeeze() for the output of _shape_normalization, it will trigger a D2D memcpy if scaling_mode == ScalingMode.NVTE_NO_SCALING: lhs_list_ = tuple( @@ -702,15 +452,13 @@ def grouped_gemm_new( else: raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") bias_list_ = [] if bias_list is None else tuple(bias_list) - nvtx.end_range(rng_make_tuple) # TE/common does not support NVTE_NO_SCALING yet # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 if scaling_mode == ScalingMode.NVTE_NO_SCALING: scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING - rng_prim = nvtx.start_range("grouped_gemm_new_ffi", color="green") - out_list = GroupedGemmPrimitiveNew.outer_primitive.bind( + out_list = GroupedGemmPrimitive.outer_primitive.bind( *lhs_list_, *rhs_list_, *lhs_sinv_list_, @@ -721,57 +469,5 @@ def grouped_gemm_new( out_dtype=out_dtype, has_bias=1 if bias_list is not None else 0, ) - nvtx.end_range(rng_prim) - - nvtx.end_range(rng_gg) return out_list - -def grouped_gemm( - lhs_list: List[Union[jnp.ndarray, ScaledTensor]], - rhs_list: List[Union[jnp.ndarray, ScaledTensor]], - contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], - bias_list: List[jnp.ndarray] = None, -) -> List[jnp.ndarray]: - grouped_gemm_ = grouped_gemm_new - #grouped_gemm_ = grouped_gemm_old - #print(f"Using {grouped_gemm_.__name__} for grouped_gemm") - return grouped_gemm_(lhs_list, rhs_list, contracting_dims_list, bias_list) - -class GroupedAddPrimitive(BasePrimitive): - multiple_results = True - impl_static_args = () - inner_primitive = None - outer_primitive = None - name = "te_grouped_add_ffi" - - @staticmethod - def abstract(*args, out_dtype): - num_pairs = len(args) // 2 - A_list = args[:num_pairs] - workspace_aval = jax.core.ShapedArray(shape=(1024,), dtype=jnp.uint8) - out_aval = tuple(jax.core.ShapedArray(A.shape, dtype=out_dtype) for A in A_list) - return (*out_aval, workspace_aval) - - @staticmethod - def outer_abstract(*args, **kwargs): - return GroupedAddPrimitive.abstract(*args, **kwargs) - - @staticmethod - def lowering(ctx, *args, out_dtype): - del out_dtype - num_pairs = len(args) // 2 - return jax.ffi.ffi_lowering(GroupedAddPrimitive.name)( - ctx, *args, num_pairs=num_pairs - ) - - @staticmethod - def impl(*args, out_dtype): - assert GroupedAddPrimitive.inner_primitive is not None - out = GroupedAddPrimitive.inner_primitive.bind(*args, out_dtype=out_dtype) - return out[:-1] - -register_primitive(GroupedAddPrimitive) - -def grouped_add(A_list, B_list, out_dtype): - return GroupedAddPrimitive.outer_primitive.bind(*A_list, *B_list, out_dtype=out_dtype) \ No newline at end of file diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 3715f4bfc4..aaaf57fab7 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -118,9 +118,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); -XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmNewHandler); - -XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedAddHandler); // Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 34d1ba4009..c64fc87a79 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -17,200 +17,7 @@ namespace jax { constexpr static size_t MXFP8_BLOCK_SIZE = 32; -// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX) -Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lhs_sinv_ptr, - const DType &lhs_sinv_dtype, uint8_t *rhs_ptr, const DType &rhs_dtype, - uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr, - const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype, - uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms, - int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode, - cudaStream_t stream) { - size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); - size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); - size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); - size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); - size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); - size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); - NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, - "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); - - size_t dim_list_bytes = sizeof(int32_t) * 3 * num_gemms; - std::unique_ptr dim_list_host = std::make_unique(3 * num_gemms); - - cudaMemcpyAsync(dim_list_host.get(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - - // Notes on matrix layouts and transpose: - // Jax uses row-major data_layout, on entering this function, each input matrix pair: - // A: row-major with size [m, k], - // B: row-major with size [n, k], needs transpose, - // on exiting this function, JAX expect: - // C: row-major with size [m, n]. - // cuBLAS uses column-major data_layout, in this view, each input matrix pair: - // A: column-major with size [k, m], needs transpose, - // B: column-major with size [k, n]. - // If we call cuBLAS GEMM for A * B, the output will be: - // C: column-major with size [m, n] --> row-major with size [n, m]. - // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. - - bool trans_lhs = true; - bool trans_rhs = false; - auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); - bool grad = false; - bool accumulate = false; - bool use_split_accumulator = false; - - // These lists are to keep the TensorWrapper objects alive - std::vector lhs_wrapper_list; - std::vector rhs_wrapper_list; - std::vector bias_wrapper_list; - std::vector pre_gelu_wrapper_list; - std::vector out_wrapper_list; - std::vector workspace_wrapper_list; - - // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM - std::vector lhs_list; - std::vector rhs_list; - std::vector bias_list; - std::vector pre_gelu_list; - std::vector out_list; - std::vector workspace_list; - - for (int i = 0; i < num_gemms; i++) { - size_t m = dim_list_host[i * 3]; - size_t n = dim_list_host[i * 3 + 1]; - size_t k = dim_list_host[i * 3 + 2]; - - auto lhs_shape = std::vector{m, k}; - auto rhs_shape = std::vector{n, k}; - auto out_shape = std::vector{n, m}; - auto lhs_sinv_shape = std::vector{1, 1}; - auto rhs_sinv_shape = std::vector{1, 1}; - - auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); - rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); - - if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { - lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat32, - std::vector{1}); - rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat32, - std::vector{1}); - } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", - MXFP8_BLOCK_SIZE, k); - size_t sinv_k = k / MXFP8_BLOCK_SIZE; - lhs_sinv_shape[0] = m; - lhs_sinv_shape[1] = sinv_k; - rhs_sinv_shape[0] = n; - rhs_sinv_shape[1] = sinv_k; - - // Note: the scale_inv array should have been swizzled in Python before lowering - lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat8E8M0, - lhs_sinv_shape); - rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat8E8M0, - rhs_sinv_shape); - } else { - NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); - } - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - - auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); - lhs_ptr += m * k * lhs_dtype_bytes; - rhs_ptr += n * k * rhs_dtype_bytes; - out_ptr += m * n * out_dtype_bytes; - lhs_sinv_ptr += lhs_sinv_shape[0] * lhs_sinv_shape[1] * lhs_sinv_dtype_bytes; - rhs_sinv_ptr += rhs_sinv_shape[0] * rhs_sinv_shape[1] * rhs_sinv_dtype_bytes; - - void *pre_gelu_ptr = nullptr; - auto bias_shape = std::vector{0}; - auto pre_gelu_shape = std::vector{0}; - if (bias_ptr != nullptr) bias_shape[0] = n; - auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - if (bias_ptr != nullptr) bias_ptr += n * bias_dtype_bytes; - auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype); - - out_wrapper_list.push_back(std::move(out_i)); - bias_wrapper_list.push_back(std::move(bias_i)); - pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); - - lhs_list.push_back(lhs_wrapper_list.back().data()); - rhs_list.push_back(rhs_wrapper_list.back().data()); - bias_list.push_back(bias_wrapper_list.back().data()); - pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data()); - out_list.push_back(out_wrapper_list.back().data()); - } - - auto workspace_shape = std::vector{workspace_size}; - for (int i = 0; i < num_streams; i++) { - auto workspace_i = - TensorWrapper(static_cast(workspace_ptr), workspace_shape, DType::kByte); - workspace_wrapper_list.push_back(std::move(workspace_i)); - workspace_list.push_back(workspace_wrapper_list.back().data()); - workspace_ptr += workspace_size; - } - - nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad, - workspace_list.data(), accumulate, use_split_accumulator, - num_math_sm, stream); - - return ffi_with_cuda_error_check(); -} - -Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten, - Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten, - Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten, - Buffer_Type dim_list, Result_Type out_flatten, - Result_Type workspace_flatten, int64_t num_gemms, - JAXX_Scaling_Mode scaling_mode) { - // Inputs - auto lhs_ptr = reinterpret_cast(lhs_flatten.untyped_data()); - auto rhs_ptr = reinterpret_cast(rhs_flatten.untyped_data()); - auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv_flatten.untyped_data()); - auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv_flatten.untyped_data()); - auto bias_ptr = reinterpret_cast(bias_flatten.untyped_data()); - auto dim_list_ptr = reinterpret_cast(dim_list.untyped_data()); - auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_flatten.element_type()); - auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_flatten.element_type()); - auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv_flatten.element_type()); - auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv_flatten.element_type()); - auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias_flatten.element_type()); - - // Outputs - auto out_ptr = reinterpret_cast(out_flatten->untyped_data()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(out_flatten->element_type()); - auto workspace_ptr = reinterpret_cast(workspace_flatten->untyped_data()); - auto workspace_size = workspace_flatten->dimensions().back() / num_streams; - - return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype, - rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype, - workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode, - stream); -} - -XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // lhs_flatten - .Arg() // lhs_sinv_flatten - .Arg() // rhs_flatten - .Arg() // rhs_sinv_flatten - .Arg() // bias_flatten - .Arg() // dim_list - .Ret() // out_flatten - .Ret() // workspace_flatten - .Attr("num_gemms") - .Attr("scaling_mode"), - FFI_CudaGraph_Traits); - - -Error_Type GroupedGemmNewFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, +Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, Variadic_Result_Type output_list, int64_t num_gemms, int64_t scaling_mode, int64_t has_bias, int64_t workspace_size) { // Notes on matrix layouts and transpose: @@ -249,10 +56,6 @@ Error_Type GroupedGemmNewFFI(cudaStream_t stream, Variadic_Buffer_Type input_lis std::vector out_list; std::vector workspace_list; - //printf("[DEBUG] num_gemms, scaling_mode, has_bias, workspace_size: %ld %ld %ld %ld\n", - // num_gemms, scaling_mode, has_bias, workspace_size); - //fflush(stdout); - int lhs_list_offset = 0; int rhs_list_offset = num_gemms; int lhs_sinv_list_offset = 2 * num_gemms; @@ -324,16 +127,6 @@ Error_Type GroupedGemmNewFFI(cudaStream_t stream, Variadic_Buffer_Type input_lis NVTE_ERROR("Unsupported scaling mode: ", scaling_mode); } - /* - printf("[DEBUG] i: %d, lhs_dtype, rhs_dtype, out_dtype = %d, %d, %d\n", i, lhs_dtype, rhs_dtype, out_dtype); - printf("[DEBUG] lhs_shape: %d %d, ptr = %p\n", lhs_shape_[1], lhs_shape_[2], lhs_ptr); - printf("[DEBUG] rhs_shape: %d %d, ptr = %p\n", rhs_shape_[1], rhs_shape_[2], rhs_ptr); - printf("[DEBUG] out_shape: %d %d, ptr = %p\n", out_i->dimensions()[1], out_i->dimensions()[2], out_ptr); - printf("[DEBUG] lhs_sinv_shape: %d %d, ptr = %p\n", lhs_sinv_shape[0], lhs_sinv_shape[1], lhs_sinv_ptr); - printf("[DEBUG] rhs_sinv_shape: %d %d, ptr = %p\n", rhs_sinv_shape[0], rhs_sinv_shape[1], rhs_sinv_ptr); - fflush(stdout); - */ - auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype); void *pre_gelu_ptr = nullptr; auto bias_shape = std::vector{0}; @@ -371,8 +164,6 @@ Error_Type GroupedGemmNewFFI(cudaStream_t stream, Variadic_Buffer_Type input_lis workspace_list.push_back(workspace_wrapper_list.back().data()); workspace_ptr += workspace_size; } - //printf("[DEBUG] workspace packing done\n"); - //fflush(stdout); nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad, @@ -382,7 +173,7 @@ Error_Type GroupedGemmNewFFI(cudaStream_t stream, Variadic_Buffer_Type input_lis return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmNewHandler, GroupedGemmNewFFI, +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream .RemainingArgs() // input list @@ -393,57 +184,5 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmNewHandler, GroupedGemmNewFFI, .Attr("workspace_size"), FFI_CudaGraph_Traits); -Error_Type GroupedAddFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, - Variadic_Result_Type output_list, int64_t num_pairs) { - for (size_t i = 0; i < static_cast(num_pairs); i++) { - auto A_i_get = input_list.get(i); - auto B_i_get = input_list.get(num_pairs + i); - auto C_i_get = output_list.get(i); - Buffer_Type A_i = A_i_get.value(); - Buffer_Type B_i = B_i_get.value(); - Result_Type C_i = C_i_get.value(); - auto A_ptr = reinterpret_cast(A_i.untyped_data()); - auto B_ptr = reinterpret_cast(B_i.untyped_data()); - auto C_ptr = reinterpret_cast(C_i->untyped_data()); - auto A_shape = A_i.dimensions(); - auto B_shape = B_i.dimensions(); - auto C_shape = C_i->dimensions(); - printf("Pair %ld: A shape ", i); - for (size_t j = 0; j < A_shape.size(); j++) printf("%ld ", A_shape[j]); - printf("; B shape "); - for (size_t j = 0; j < B_shape.size(); j++) printf("%ld ", B_shape[j]); - printf("; C shape "); - for (size_t j = 0; j < C_shape.size(); j++) printf("%ld ", C_shape[j]); - printf("\n"); - size_t A_size = product(A_shape); - size_t B_size = product(B_shape); - size_t C_size = product(C_shape); - float *A_ptr_host = (float *) malloc(A_size * sizeof(float)); - float *B_ptr_host = (float *) malloc(B_size * sizeof(float)); - float *C_ptr_host = (float *) malloc(C_size * sizeof(float)); - cudaMemcpyAsync(A_ptr_host, A_ptr, A_size * sizeof(float), cudaMemcpyDeviceToHost, stream); - cudaMemcpyAsync(B_ptr_host, B_ptr, B_size * sizeof(float), cudaMemcpyDeviceToHost, stream); - cudaMemcpyAsync(C_ptr_host, C_ptr, C_size * sizeof(float), cudaMemcpyDeviceToHost, stream); - cudaStreamSynchronize(stream); - for (size_t j = 0; j < A_size; j++) - C_ptr_host[j] = A_ptr_host[j] + B_ptr_host[j]; - cudaMemcpyAsync(C_ptr, C_ptr_host, C_size * sizeof(float), cudaMemcpyHostToDevice, stream); - cudaStreamSynchronize(stream); - free(A_ptr_host); - free(B_ptr_host); - free(C_ptr_host); - } - return ffi_with_cuda_error_check(); -} - -XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedAddHandler, GroupedAddFFI, - FFI::Bind() - .Ctx() // stream - .RemainingArgs() // input list - .RemainingRets() // output list - .Attr("num_pairs"), - FFI_CudaGraph_Traits); - - } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index dcc218dced..5c165cccb6 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -59,13 +59,6 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); - dict["te_grouped_gemm_new_ffi"] = - pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(GroupedGemmNewHandler)); - - - dict["te_grouped_add_ffi"] = EncapsulateFFI(GroupedAddHandler); - return dict; } From 167ef32a14a7513c3e0ff7c3c0d73124244905ee Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Tue, 8 Apr 2025 13:07:01 -0700 Subject: [PATCH 06/13] . Signed-off-by: Hua Huang --- transformer_engine/jax/csrc/extensions/gemm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index c64fc87a79..b4481673b3 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -18,8 +18,8 @@ namespace jax { constexpr static size_t MXFP8_BLOCK_SIZE = 32; Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, - Variadic_Result_Type output_list, int64_t num_gemms, - int64_t scaling_mode, int64_t has_bias, int64_t workspace_size) { + Variadic_Result_Type output_list, int64_t num_gemms, + int64_t scaling_mode, int64_t has_bias, int64_t workspace_size) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major with size [m, k], From eceff9014ccb4da5a05edc5d475d9dd3d37baa73 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 20:09:11 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/mytest.py | 2 +- transformer_engine/jax/cpp_extensions/gemm.py | 19 +++++------- .../jax/csrc/extensions/gemm.cpp | 29 +++++++++---------- 3 files changed, 23 insertions(+), 27 deletions(-) diff --git a/tests/jax/mytest.py b/tests/jax/mytest.py index 3e77181d8b..bcdf69ce7d 100644 --- a/tests/jax/mytest.py +++ b/tests/jax/mytest.py @@ -22,7 +22,7 @@ B_list.append(B_i) ref_C_list.append(ref_C_i) -#pdb.set_trace() +# pdb.set_trace() C_list = tex.grouped_add(A_list, B_list, out_dtype) for i in range(len(shape_list)): assert_allclose(C_list[i], ref_C_list[i]) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index e3a8c094fb..d528437a1d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -48,11 +48,11 @@ class GroupedGemmPrimitive(BasePrimitive): @staticmethod def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): del scaling_mode, has_bias - A_list = args[0 : num_gemms] + A_list = args[0:num_gemms] B_list = args[num_gemms : 2 * num_gemms] # A and B have shapes [1, m, k] and [1, n, k] out_list_aval = tuple( - jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) + jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) for A, B in zip(A_list, B_list) ) workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams @@ -74,7 +74,7 @@ def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): num_gemms=num_gemms, scaling_mode=int(scaling_mode), has_bias=has_bias, - workspace_size=workspace_size + workspace_size=workspace_size, ) @staticmethod @@ -92,6 +92,7 @@ def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias): register_primitive(GroupedGemmPrimitive) + def _shape_normalization(x, dimension_numbers, already_transposed: bool = False): orig_order = list(range(x.ndim)) contracting_dims, batch_dims = dimension_numbers @@ -412,12 +413,10 @@ def grouped_gemm( # Note: do not .squeeze() for the output of _shape_normalization, it will trigger a D2D memcpy if scaling_mode == ScalingMode.NVTE_NO_SCALING: lhs_list_ = tuple( - _shape_normalization(lhs, lhs_dn) - for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) + _shape_normalization(lhs, lhs_dn) for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) ) rhs_list_ = tuple( - _shape_normalization(rhs, rhs_dn) - for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) + _shape_normalization(rhs, rhs_dn) for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) ) lhs_sinv_list_ = tuple(jnp.ones(1, dtype=jnp.float32) for _ in range(num_gemms)) rhs_sinv_list_ = tuple(jnp.ones(1, dtype=jnp.float32) for _ in range(num_gemms)) @@ -434,12 +433,10 @@ def grouped_gemm( rhs_sinv_list_ = tuple(rhs.scale_inv for rhs in rhs_list) elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: lhs_list_ = tuple( - _shape_normalization(lhs.data, lhs_dn) - for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) + _shape_normalization(lhs.data, lhs_dn) for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) ) rhs_list_ = tuple( - _shape_normalization(rhs.data, rhs_dn) - for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) + _shape_normalization(rhs.data, rhs_dn) for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) ) lhs_sinv_list_ = tuple( swizzled_scale(_shape_normalization(lhs.scale_inv, lhs_dn)) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index b4481673b3..a6159fae5d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -18,8 +18,8 @@ namespace jax { constexpr static size_t MXFP8_BLOCK_SIZE = 32; Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, - Variadic_Result_Type output_list, int64_t num_gemms, - int64_t scaling_mode, int64_t has_bias, int64_t workspace_size) { + Variadic_Result_Type output_list, int64_t num_gemms, int64_t scaling_mode, + int64_t has_bias, int64_t workspace_size) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major with size [m, k], @@ -56,32 +56,32 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, std::vector out_list; std::vector workspace_list; - int lhs_list_offset = 0; - int rhs_list_offset = num_gemms; + int lhs_list_offset = 0; + int rhs_list_offset = num_gemms; int lhs_sinv_list_offset = 2 * num_gemms; int rhs_sinv_list_offset = 3 * num_gemms; - int bias_list_offset = 4 * num_gemms; - int out_list_offset = 0; + int bias_list_offset = 4 * num_gemms; + int out_list_offset = 0; for (int i = 0; i < num_gemms; i++) { - Buffer_Type lhs_i = input_list.get(lhs_list_offset + i).value(); - Buffer_Type rhs_i = input_list.get(rhs_list_offset + i).value(); + Buffer_Type lhs_i = input_list.get(lhs_list_offset + i).value(); + Buffer_Type rhs_i = input_list.get(rhs_list_offset + i).value(); Buffer_Type lhs_sinv_i = input_list.get(lhs_sinv_list_offset + i).value(); Buffer_Type rhs_sinv_i = input_list.get(rhs_sinv_list_offset + i).value(); - Result_Type out_i = output_list.get(out_list_offset + i).value(); + Result_Type out_i = output_list.get(out_list_offset + i).value(); DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type()); DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type()); DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type()); - void *lhs_ptr = lhs_i.untyped_data(); - void *rhs_ptr = rhs_i.untyped_data(); + void *lhs_ptr = lhs_i.untyped_data(); + void *rhs_ptr = rhs_i.untyped_data(); void *lhs_sinv_ptr = lhs_sinv_i.untyped_data(); void *rhs_sinv_ptr = rhs_sinv_i.untyped_data(); - void *out_ptr = out_i->untyped_data(); + void *out_ptr = out_i->untyped_data(); // Placeholder for bias since it can be empty DType bias_dtype = DType::kFloat32; - void *bias_ptr = nullptr; + void *bias_ptr = nullptr; auto lhs_shape_ = lhs_i.dimensions(); auto rhs_shape_ = rhs_i.dimensions(); @@ -131,8 +131,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, void *pre_gelu_ptr = nullptr; auto bias_shape = std::vector{0}; auto pre_gelu_shape = std::vector{0}; - if (has_bias) - { + if (has_bias) { auto bias_i_get = input_list.get(bias_list_offset + i); Buffer_Type bias_i = bias_i_get.value(); bias_ptr = bias_i.untyped_data(); From c1e33bda68c2036e53a9a6caf47c65a8a81deeae Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Tue, 8 Apr 2025 13:10:29 -0700 Subject: [PATCH 08/13] Remove unused file Signed-off-by: Hua Huang --- tests/jax/mytest.py | 29 ----------------------------- 1 file changed, 29 deletions(-) delete mode 100644 tests/jax/mytest.py diff --git a/tests/jax/mytest.py b/tests/jax/mytest.py deleted file mode 100644 index bcdf69ce7d..0000000000 --- a/tests/jax/mytest.py +++ /dev/null @@ -1,29 +0,0 @@ -import jax -import jax.numpy as jnp -from transformer_engine.jax import cpp_extensions as tex -from utils import assert_allclose -import pdb - -out_dtype = jnp.float32 - -shape_list = [[128, 256], [256, 256], [512, 128]] -A_list = [] -B_list = [] -ref_C_list = [] - -key = jax.random.PRNGKey(0) -subkeys = jax.random.split(key, len(shape_list * 2)) -for i in range(len(shape_list)): - shape = shape_list[i] - A_i = jax.random.normal(subkeys[2 * i], shape, dtype=out_dtype) - B_i = jax.random.normal(subkeys[2 * i + 1], shape, dtype=out_dtype) - ref_C_i = A_i + B_i - A_list.append(A_i) - B_list.append(B_i) - ref_C_list.append(ref_C_i) - -# pdb.set_trace() -C_list = tex.grouped_add(A_list, B_list, out_dtype) -for i in range(len(shape_list)): - assert_allclose(C_list[i], ref_C_list[i]) -print("Grouped add test passed.") From 538451d2edc906fca49ae14106599015ab162352 Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Tue, 8 Apr 2025 13:47:10 -0700 Subject: [PATCH 09/13] Revert to the list append fashion to simplify code Signed-off-by: Hua Huang --- transformer_engine/jax/cpp_extensions/gemm.py | 94 ++++++++----------- 1 file changed, 38 insertions(+), 56 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d528437a1d..d021f1a4cd 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -328,15 +328,6 @@ def swizzled_scale(scales): scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) return scales - -def get_shape_normalization_contracting_dim_size(x, dimension_numbers): - orig_order = list(range(x.ndim)) - contracting_dims, batch_dims = dimension_numbers - contracting_order = [d for d in orig_order if d in contracting_dims] - cols_shape = [x.shape[d] for d in contracting_order] - return reduce(operator.mul, cols_shape, 1) - - def grouped_gemm( lhs_list: List[Union[jnp.ndarray, ScaledTensor]], rhs_list: List[Union[jnp.ndarray, ScaledTensor]], @@ -349,8 +340,11 @@ def grouped_gemm( ), "lhs_list, rhs_list, contracting_dims_list must have the same length" num_gemms = len(lhs_list) - lhs_dn_list = [] - rhs_dn_list = [] + lhs_list_ = [] + rhs_list_ = [] + lhs_sinv_list_ = [] + rhs_sinv_list_ = [] + bias_list_ = [] for i in range(num_gemms): lhs = lhs_list[i] rhs = rhs_list[i] @@ -382,12 +376,28 @@ def grouped_gemm( (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums lhs_dn = (lhs_contract, lhs_batch) rhs_dn = (rhs_contract, rhs_batch) - lhs_dn_list.append(lhs_dn) - rhs_dn_list.append(rhs_dn) lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) + # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy + if scaling_mode == ScalingMode.NVTE_NO_SCALING: + lhs_3d = _shape_normalization(lhs, lhs_dn) + rhs_3d = _shape_normalization(rhs, rhs_dn) + elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") + rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + lhs_3d = _shape_normalization(lhs.data, lhs_dn) + rhs_3d = _shape_normalization(rhs.data, rhs_dn) + lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) + rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) + # swizzled_scale requires a matrix + lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) + rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) + else: + raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") + # Note: already_transposed doesn't matter for the output shape # x.shape = [B, D1, D2] # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] @@ -397,10 +407,8 @@ def grouped_gemm( # contracting_dims = (0, ) --> output.shape = [1, D2, D1] bm = lhs_remain_shape[0] bn = rhs_remain_shape[0] - _lhs_data = lhs if scaling_mode == ScalingMode.NVTE_NO_SCALING else lhs.data - _rhs_data = rhs if scaling_mode == ScalingMode.NVTE_NO_SCALING else rhs.data - kl = get_shape_normalization_contracting_dim_size(_lhs_data, lhs_dn) - kr = get_shape_normalization_contracting_dim_size(_rhs_data, rhs_dn) + kl = lhs_3d.shape[-1] + kr = rhs_3d.shape[-1] assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ") @@ -410,45 +418,19 @@ def grouped_gemm( ) assert bm % 16 == 0 and bn % 16 == 0 and kl % 16 == 0 - # Note: do not .squeeze() for the output of _shape_normalization, it will trigger a D2D memcpy - if scaling_mode == ScalingMode.NVTE_NO_SCALING: - lhs_list_ = tuple( - _shape_normalization(lhs, lhs_dn) for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) - ) - rhs_list_ = tuple( - _shape_normalization(rhs, rhs_dn) for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) - ) - lhs_sinv_list_ = tuple(jnp.ones(1, dtype=jnp.float32) for _ in range(num_gemms)) - rhs_sinv_list_ = tuple(jnp.ones(1, dtype=jnp.float32) for _ in range(num_gemms)) - elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: - lhs_list_ = tuple( - _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") - for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) - ) - rhs_list_ = tuple( - _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) - ) - lhs_sinv_list_ = tuple(lhs.scale_inv for lhs in lhs_list) - rhs_sinv_list_ = tuple(rhs.scale_inv for rhs in rhs_list) - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - lhs_list_ = tuple( - _shape_normalization(lhs.data, lhs_dn) for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) - ) - rhs_list_ = tuple( - _shape_normalization(rhs.data, rhs_dn) for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) - ) - lhs_sinv_list_ = tuple( - swizzled_scale(_shape_normalization(lhs.scale_inv, lhs_dn)) - for lhs, lhs_dn in zip(lhs_list, lhs_dn_list) - ) - rhs_sinv_list_ = tuple( - swizzled_scale(_shape_normalization(rhs.scale_inv, rhs_dn)) - for rhs, rhs_dn in zip(rhs_list, rhs_dn_list) - ) - else: - raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") - bias_list_ = [] if bias_list is None else tuple(bias_list) + lhs_list_.append(lhs_3d) + rhs_list_.append(rhs_3d) + if scaling_mode == ScalingMode.NVTE_NO_SCALING: + lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) + rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + lhs_sinv_list_.append(lhs.scale_inv) + rhs_sinv_list_.append(rhs.scale_inv) + if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + lhs_sinv_list_.append(lhs_scale_inv) + rhs_sinv_list_.append(rhs_scale_inv) + if bias_list is not None: + bias_list_.append(bias_list[i]) # TE/common does not support NVTE_NO_SCALING yet # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 From fe7c3bff7bc33c94038cf13fc61f0b57d499f2be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 20:48:10 +0000 Subject: [PATCH 10/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d021f1a4cd..8dc8df35c0 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -328,6 +328,7 @@ def swizzled_scale(scales): scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) return scales + def grouped_gemm( lhs_list: List[Union[jnp.ndarray, ScaledTensor]], rhs_list: List[Union[jnp.ndarray, ScaledTensor]], From e5bee41ca5ff80de742ed9b05930765d0b76528b Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Thu, 10 Apr 2025 09:52:22 -0700 Subject: [PATCH 11/13] Rebase and update; fix scale_inv shapes in C++ Signed-off-by: Hua Huang --- transformer_engine/jax/cpp_extensions/gemm.py | 57 ++++++++++++------- .../jax/csrc/extensions/gemm.cpp | 49 ++++++++++------ 2 files changed, 67 insertions(+), 39 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 8dc8df35c0..f713978684 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -47,7 +47,29 @@ class GroupedGemmPrimitive(BasePrimitive): @staticmethod def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): - del scaling_mode, has_bias + """ + Args: + *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias: + args[ 0 : num_gemms] are the lhs tensors, + args[ num_gemms : 2*num_gemms] are the rhs tensors, + args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors, + args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors, + args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True. + num_gemms: Number of GEMM operations to perform. + scaling_mode: Scaling mode for the GEMM operations. + out_dtype: Data type of the output tensors. + has_bias: Boolean indicating if bias tensors are provided. + + Returns: + A tuple of ShapedArray objects of size num_gemms+1: + ret[0 : num_gemms]: GEMM output tensors, + ret[num_gemms]:workspace tensor. + """ + del scaling_mode + expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms + assert len(args) == expected_num_args, ( + f"Expected {expected_num_args} input arguments, but got {len(args)}" + ) A_list = args[0:num_gemms] B_list = args[num_gemms : 2 * num_gemms] # A and B have shapes [1, m, k] and [1, n, k] @@ -356,8 +378,8 @@ def grouped_gemm( lhs_shape = lhs.data.shape rhs_shape = rhs.data.shape out_dtype = lhs.dq_dtype - # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout - if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout + if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: assert not ( lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 ), "FP8 GEMM does not support E5M2 * E5M2" @@ -369,7 +391,7 @@ def grouped_gemm( dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) else: # For jnp.ndarray, only consider contracting_dims, data_layout is always NN - scaling_mode = ScalingMode.NVTE_NO_SCALING + scaling_mode = ScalingMode.NO_SCALING lhs_shape = lhs.shape rhs_shape = rhs.shape out_dtype = lhs.dtype @@ -382,13 +404,13 @@ def grouped_gemm( rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + if scaling_mode == ScalingMode.NO_SCALING: lhs_3d = _shape_normalization(lhs, lhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn) lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) @@ -412,32 +434,25 @@ def grouped_gemm( kr = rhs_3d.shape[-1] assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): - print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ") - print( - f"m = {bm}, n = {bn}, k = {kl}; cuBLAS requires the problem shapes being multiples" - " of 16" - ) - assert bm % 16 == 0 and bn % 16 == 0 and kl % 16 == 0 + print("grouped_gemm input pair {i} has invalid problem shape for lowering: ") + print(f"m = {bm}, n = {bn}, k = {kl}; ") + print("cuBLAS requires the problem shapes being multiples of 16") + assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0) lhs_list_.append(lhs_3d) rhs_list_.append(rhs_3d) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + if scaling_mode == ScalingMode.NO_SCALING: lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: lhs_sinv_list_.append(lhs.scale_inv) rhs_sinv_list_.append(rhs.scale_inv) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_sinv_list_.append(lhs_scale_inv) rhs_sinv_list_.append(rhs_scale_inv) if bias_list is not None: bias_list_.append(bias_list[i]) - # TE/common does not support NVTE_NO_SCALING yet - # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 - if scaling_mode == ScalingMode.NVTE_NO_SCALING: - scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING - out_list = GroupedGemmPrimitive.outer_primitive.bind( *lhs_list_, *rhs_list_, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index a6159fae5d..cc5f90ff0b 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -15,11 +15,10 @@ namespace transformer_engine { namespace jax { -constexpr static size_t MXFP8_BLOCK_SIZE = 32; - Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, - Variadic_Result_Type output_list, int64_t num_gemms, int64_t scaling_mode, - int64_t has_bias, int64_t workspace_size) { + Variadic_Result_Type output_list, int64_t num_gemms, + JAXX_Scaling_Mode scaling_mode, int64_t has_bias, + int64_t workspace_size) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major with size [m, k], @@ -33,6 +32,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. + if (num_gemms <= 0) { + return ffi_with_cuda_error_check(); + } + size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms; + size_t expected_output_size = num_gemms + 1; + size_t actual_input_size = input_list.size(); + size_t actual_output_size = output_list.size(); + NVTE_CHECK(actual_input_size == expected_input_size, + "Expected %zu input tensors, got %zu", expected_input_size, actual_input_size); + NVTE_CHECK(actual_output_size == expected_output_size, + "Expected %zu output tensors, got %zu", expected_output_size, actual_output_size); + bool trans_lhs = true; bool trans_rhs = false; auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); @@ -97,23 +108,25 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, nullptr, nullptr, + if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + float *amax_dptr = nullptr; + float *scale_dptr = nullptr; + auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr, reinterpret_cast(lhs_sinv_ptr)); - auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, nullptr, nullptr, + auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr, reinterpret_cast(rhs_sinv_ptr)); lhs_wrapper_list.push_back(std::move(lhs_i_)); rhs_wrapper_list.push_back(std::move(rhs_i_)); - } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", - MXFP8_BLOCK_SIZE, k); - size_t sinv_k = k / MXFP8_BLOCK_SIZE; - lhs_sinv_shape[0] = m; - lhs_sinv_shape[1] = sinv_k; - rhs_sinv_shape[0] = n; - rhs_sinv_shape[1] = sinv_k; - + } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { // Note: the scale_inv array should have been swizzled in Python before lowering + auto lhs_sinv_shape_ = lhs_sinv_i.dimensions(); + auto rhs_sinv_shape_ = rhs_sinv_i.dimensions(); + for (int i = 0; i < 2; i++) { + lhs_sinv_shape[i] = lhs_sinv_shape_[i]; + rhs_sinv_shape[i] = rhs_sinv_shape_[i]; + } + TensorWrapper lhs_i_(NVTE_MXFP8_1D_SCALING); TensorWrapper rhs_i_(NVTE_MXFP8_1D_SCALING); lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); @@ -124,7 +137,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, lhs_wrapper_list.push_back(std::move(lhs_i_)); rhs_wrapper_list.push_back(std::move(rhs_i_)); } else { - NVTE_ERROR("Unsupported scaling mode: ", scaling_mode); + NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); } auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype); @@ -178,7 +191,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .RemainingArgs() // input list .RemainingRets() // output list .Attr("num_gemms") - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("has_bias") .Attr("workspace_size"), FFI_CudaGraph_Traits); From 23ae948ede774c00b80905d618ad3af937b2de68 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Apr 2025 16:53:34 +0000 Subject: [PATCH 12/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 8 ++++---- transformer_engine/jax/csrc/extensions/gemm.cpp | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index f713978684..3fc1b08223 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -59,7 +59,7 @@ def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): scaling_mode: Scaling mode for the GEMM operations. out_dtype: Data type of the output tensors. has_bias: Boolean indicating if bias tensors are provided. - + Returns: A tuple of ShapedArray objects of size num_gemms+1: ret[0 : num_gemms]: GEMM output tensors, @@ -67,9 +67,9 @@ def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): """ del scaling_mode expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms - assert len(args) == expected_num_args, ( - f"Expected {expected_num_args} input arguments, but got {len(args)}" - ) + assert ( + len(args) == expected_num_args + ), f"Expected {expected_num_args} input arguments, but got {len(args)}" A_list = args[0:num_gemms] B_list = args[num_gemms : 2 * num_gemms] # A and B have shapes [1, m, k] and [1, n, k] diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index cc5f90ff0b..69b247623a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -39,10 +39,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, size_t expected_output_size = num_gemms + 1; size_t actual_input_size = input_list.size(); size_t actual_output_size = output_list.size(); - NVTE_CHECK(actual_input_size == expected_input_size, - "Expected %zu input tensors, got %zu", expected_input_size, actual_input_size); - NVTE_CHECK(actual_output_size == expected_output_size, - "Expected %zu output tensors, got %zu", expected_output_size, actual_output_size); + NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu", + expected_input_size, actual_input_size); + NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu", + expected_output_size, actual_output_size); bool trans_lhs = true; bool trans_rhs = false; @@ -108,7 +108,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; - if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { float *amax_dptr = nullptr; float *scale_dptr = nullptr; @@ -126,7 +126,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, lhs_sinv_shape[i] = lhs_sinv_shape_[i]; rhs_sinv_shape[i] = rhs_sinv_shape_[i]; } - + TensorWrapper lhs_i_(NVTE_MXFP8_1D_SCALING); TensorWrapper rhs_i_(NVTE_MXFP8_1D_SCALING); lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); From c850c000f9c03ff8bd56948083ef6765d0da4b33 Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Fri, 11 Apr 2025 09:26:54 -0700 Subject: [PATCH 13/13] Bug fix and minor revision Signed-off-by: Hua Huang --- transformer_engine/jax/cpp_extensions/gemm.py | 3 +-- transformer_engine/jax/csrc/extensions/gemm.cpp | 12 ++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 3fc1b08223..588e7a469d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -89,14 +89,12 @@ def outer_abstract(*args, **kwargs): @staticmethod def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): del out_dtype - workspace_size = get_cublas_workspace_size_bytes() return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( ctx, *args, num_gemms=num_gemms, scaling_mode=int(scaling_mode), has_bias=has_bias, - workspace_size=workspace_size, ) @staticmethod @@ -348,6 +346,7 @@ def swizzled_scale(scales): rows, cols = scales.shape scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4) scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) + scales = scales.reshape(rows, cols) return scales diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 69b247623a..4318e19c75 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -17,8 +17,7 @@ namespace jax { Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, Variadic_Result_Type output_list, int64_t num_gemms, - JAXX_Scaling_Mode scaling_mode, int64_t has_bias, - int64_t workspace_size) { + JAXX_Scaling_Mode scaling_mode, int64_t has_bias) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major with size [m, k], @@ -127,8 +126,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, rhs_sinv_shape[i] = rhs_sinv_shape_[i]; } - TensorWrapper lhs_i_(NVTE_MXFP8_1D_SCALING); - TensorWrapper rhs_i_(NVTE_MXFP8_1D_SCALING); + NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode); + TensorWrapper lhs_i_(nvte_scaling_mode); + TensorWrapper rhs_i_(nvte_scaling_mode); lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape); lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape); @@ -168,6 +168,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, auto workspace_get = output_list.get(num_gemms); Result_Type workspace = workspace_get.value(); uint8_t *workspace_ptr = reinterpret_cast(workspace->untyped_data()); + size_t workspace_size = workspace->dimensions()[0] / num_streams; auto workspace_shape = std::vector{workspace_size}; for (int i = 0; i < num_streams; i++) { auto workspace_i = @@ -192,8 +193,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .RemainingRets() // output list .Attr("num_gemms") .Attr("scaling_mode") - .Attr("has_bias") - .Attr("workspace_size"), + .Attr("has_bias"), FFI_CudaGraph_Traits); } // namespace jax