diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0327542c2f..588e7a469d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -41,32 +41,45 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9) + impl_static_args = () inner_primitive = None outer_primitive = None @staticmethod - def abstract( - lhs_contig_aval, - lhs_scale_contig_aval, - rhs_contig_aval, - rhs_scale_contig_aval, - bias_contig_aval, - dim_list_aval, - *, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ): - del lhs_contig_aval, lhs_scale_contig_aval - del rhs_contig_aval, rhs_scale_contig_aval - del bias_contig_aval, dim_list_aval - del num_gemms, scaling_mode - out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype) - wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams - wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8) - return (out_flat_aval, wkspace_aval) + def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): + """ + Args: + *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias: + args[ 0 : num_gemms] are the lhs tensors, + args[ num_gemms : 2*num_gemms] are the rhs tensors, + args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors, + args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors, + args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True. + num_gemms: Number of GEMM operations to perform. + scaling_mode: Scaling mode for the GEMM operations. + out_dtype: Data type of the output tensors. + has_bias: Boolean indicating if bias tensors are provided. + + Returns: + A tuple of ShapedArray objects of size num_gemms+1: + ret[0 : num_gemms]: GEMM output tensors, + ret[num_gemms]:workspace tensor. 
+ """ + del scaling_mode + expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms + assert ( + len(args) == expected_num_args + ), f"Expected {expected_num_args} input arguments, but got {len(args)}" + A_list = args[0:num_gemms] + B_list = args[num_gemms : 2 * num_gemms] + # A and B have shapes [1, m, k] and [1, n, k] + out_list_aval = tuple( + jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) + for A, B in zip(A_list, B_list) + ) + workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams + workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + return (*out_list_aval, workspace_aval) @staticmethod def outer_abstract(*args, **kwargs): @@ -74,60 +87,27 @@ def outer_abstract(*args, **kwargs): return out_aval @staticmethod - def lowering( - ctx, - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, - *, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ) -> jnp.ndarray: - del out_dtype, out_flat_size + def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): + del out_dtype return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( ctx, - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, + *args, num_gemms=num_gemms, - scaling_mode=scaling_mode.value, + scaling_mode=int(scaling_mode), + has_bias=has_bias, ) @staticmethod - def impl( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, - num_gemms, - scaling_mode, - out_dtype, - out_flat_size, - ) -> jnp.ndarray: + def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias): assert GroupedGemmPrimitive.inner_primitive is not None out = GroupedGemmPrimitive.inner_primitive.bind( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, + *args, num_gemms=num_gemms, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, 
out_dtype=out_dtype, - out_flat_size=out_flat_size, + has_bias=has_bias, ) - return out[0] # out is [out_flat, wkspace], only return out_flat + return out[:-1] # out is [out_list, wkspace], only return out_list register_primitive(GroupedGemmPrimitive) @@ -366,6 +346,7 @@ def swizzled_scale(scales): rows, cols = scales.shape scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4) scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) + scales = scales.reshape(rows, cols) return scales @@ -380,18 +361,12 @@ def grouped_gemm( len(lhs_list) == len(rhs_list) == len(contracting_dims_list) ), "lhs_list, rhs_list, contracting_dims_list must have the same length" - # Flatten inputs and save their shapes - num_gemms = len(lhs_list) - out_flat_size = 0 - dims = [] - lhs_contig_ = [] - rhs_contig_ = [] - lhs_scale_inv_contig_ = [] - rhs_scale_inv_contig_ = [] - bias_contig_ = [] - out_offsets = [] - remain_shape_list = [] num_gemms = len(lhs_list) + lhs_list_ = [] + rhs_list_ = [] + lhs_sinv_list_ = [] + rhs_sinv_list_ = [] + bias_list_ = [] for i in range(num_gemms): lhs = lhs_list[i] rhs = rhs_list[i] @@ -402,7 +377,7 @@ def grouped_gemm( lhs_shape = lhs.data.shape rhs_shape = rhs.data.shape out_dtype = lhs.dq_dtype - # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout + # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: assert not ( lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 @@ -427,6 +402,7 @@ def grouped_gemm( lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) + # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy if scaling_mode == ScalingMode.NO_SCALING: lhs_3d = _shape_normalization(lhs, lhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn) @@ -438,13 +414,13 @@ def grouped_gemm( rhs_3d = 
_shape_normalization(rhs.data, rhs_dn) lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) + # swizzled_scale requires a matrix lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) else: raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") - # Note: if _shape_normalization() is updated to support non-TN, need to update here - already_transposed doesn't matter for the output shape + # Note: already_transposed doesn't matter for the output shape # x.shape = [B, D1, D2] # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1] @@ -455,66 +431,37 @@ bn = rhs_remain_shape[0] kl = lhs_3d.shape[-1] kr = rhs_3d.shape[-1] - remain_shape_list.append(((bm,), (bn,))) - assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})" - k = kl - - if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0): - print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ") - print( - f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples" - " of 16" - ) - assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0 - - dims.append((bm, bn, k)) - lhs_contig_.append(lhs_3d.reshape(-1)) - rhs_contig_.append(rhs_3d.reshape(-1)) + assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" + if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): + print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ") + print(f"m = {bm}, n = {bn}, k = {kl}; ") + print("cuBLAS requires the problem shapes being multiples of 16") + assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0) + + lhs_list_.append(lhs_3d) + rhs_list_.append(rhs_3d) if scaling_mode == ScalingMode.NO_SCALING: - lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) - rhs_scale_inv_contig_.append(jnp.ones(1, 
dtype=jnp.float32)) + lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) + rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: - lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1)) - rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1)) + lhs_sinv_list_.append(lhs.scale_inv) + rhs_sinv_list_.append(rhs.scale_inv) if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1)) - rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1)) + lhs_sinv_list_.append(lhs_scale_inv) + rhs_sinv_list_.append(rhs_scale_inv) if bias_list is not None: - bias_contig_.append(bias_list[i].reshape(-1)) - out_flat_size += bm * bn - out_offsets.append(out_flat_size) - - lhs_contig = jnp.concatenate(lhs_contig_) - rhs_contig = jnp.concatenate(rhs_contig_) - lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_) - rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_) - bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_) - dim_list = jnp.array(dims, dtype=jnp.int32) - - # TE/common does not support NVTE_NO_SCALING yet - # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 - if scaling_mode == ScalingMode.NO_SCALING: - scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING - - # Perform batched GEMM on flattened inputs - out_contig = GroupedGemmPrimitive.outer_primitive.bind( - lhs_contig, - lhs_scale_inv_contig, - rhs_contig, - rhs_scale_inv_contig, - bias_contig, - dim_list, + bias_list_.append(bias_list[i]) + + out_list = GroupedGemmPrimitive.outer_primitive.bind( + *lhs_list_, + *rhs_list_, + *lhs_sinv_list_, + *rhs_sinv_list_, + *bias_list_, num_gemms=num_gemms, - scaling_mode=scaling_mode.value, + scaling_mode=scaling_mode, out_dtype=out_dtype, - out_flat_size=out_flat_size, + has_bias=1 if bias_list is not None else 0, ) - # Split the output back into tensors - out_offsets = jnp.array(out_offsets) - out_flat_list = 
jnp.split(out_contig, out_offsets[:-1]) - out_tensors = [] - for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list): - out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape)) - - return out_tensors + return out_list diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d4b9bf720e..4318e19c75 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -15,34 +15,9 @@ namespace transformer_engine { namespace jax { -constexpr static size_t MXFP8_BLOCK_SIZE = 32; - -// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX) -Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lhs_sinv_ptr, - const DType &lhs_sinv_dtype, uint8_t *rhs_ptr, const DType &rhs_dtype, - uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr, - const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype, - uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms, - int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode, - cudaStream_t stream) { - size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); - size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); - size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); - size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); - size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); - size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); - NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, - "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); - - size_t dim_list_bytes = sizeof(int32_t) * 3 * num_gemms; - std::unique_ptr dim_list_host = std::make_unique(3 * num_gemms); - - cudaMemcpyAsync(dim_list_host.get(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break 
cudaGraph. - cudaStreamSynchronize(stream); - +Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, + Variadic_Result_Type output_list, int64_t num_gemms, + JAXX_Scaling_Mode scaling_mode, int64_t has_bias) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major with size [m, k], @@ -56,6 +31,18 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. + if (num_gemms <= 0) { + return ffi_with_cuda_error_check(); + } + size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms; + size_t expected_output_size = num_gemms + 1; + size_t actual_input_size = input_list.size(); + size_t actual_output_size = output_list.size(); + NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu", + expected_input_size, actual_input_size); + NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu", + expected_output_size, actual_output_size); + bool trans_lhs = true; bool trans_rhs = false; auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); @@ -79,10 +66,40 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh std::vector out_list; std::vector workspace_list; + int lhs_list_offset = 0; + int rhs_list_offset = num_gemms; + int lhs_sinv_list_offset = 2 * num_gemms; + int rhs_sinv_list_offset = 3 * num_gemms; + int bias_list_offset = 4 * num_gemms; + int out_list_offset = 0; for (int i = 0; i < num_gemms; i++) { - size_t m = dim_list_host[i * 3]; - size_t n = dim_list_host[i * 3 + 1]; - size_t k = dim_list_host[i * 3 + 2]; + Buffer_Type lhs_i = input_list.get(lhs_list_offset + i).value(); + Buffer_Type rhs_i = input_list.get(rhs_list_offset + i).value(); + 
Buffer_Type lhs_sinv_i = input_list.get(lhs_sinv_list_offset + i).value(); + Buffer_Type rhs_sinv_i = input_list.get(rhs_sinv_list_offset + i).value(); + Result_Type out_i = output_list.get(out_list_offset + i).value(); + + DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type()); + DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type()); + DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type()); + + void *lhs_ptr = lhs_i.untyped_data(); + void *rhs_ptr = rhs_i.untyped_data(); + void *lhs_sinv_ptr = lhs_sinv_i.untyped_data(); + void *rhs_sinv_ptr = rhs_sinv_i.untyped_data(); + void *out_ptr = out_i->untyped_data(); + + // Placeholder for bias since it can be empty + DType bias_dtype = DType::kFloat32; + void *bias_ptr = nullptr; + + auto lhs_shape_ = lhs_i.dimensions(); + auto rhs_shape_ = rhs_i.dimensions(); + + // lhs and rhs has shape [1, m, k] and [1, n, k] + size_t m = lhs_shape_[1]; + size_t n = rhs_shape_[1]; + size_t k = lhs_shape_[2]; auto lhs_shape = std::vector{m, k}; auto rhs_shape = std::vector{n, k}; @@ -90,52 +107,54 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; - auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); - rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); - - if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { - lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat32, - std::vector{1}); - rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat32, - std::vector{1}); + if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + float *amax_dptr = nullptr; + float *scale_dptr = nullptr; + auto lhs_i_ = 
TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr, + reinterpret_cast(lhs_sinv_ptr)); + auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr, + reinterpret_cast(rhs_sinv_ptr)); + lhs_wrapper_list.push_back(std::move(lhs_i_)); + rhs_wrapper_list.push_back(std::move(rhs_i_)); } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", - MXFP8_BLOCK_SIZE, k); - size_t sinv_k = k / MXFP8_BLOCK_SIZE; - lhs_sinv_shape[0] = m; - lhs_sinv_shape[1] = sinv_k; - rhs_sinv_shape[0] = n; - rhs_sinv_shape[1] = sinv_k; - // Note: the scale_inv array should have been swizzled in Python before lowering - lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat8E8M0, - lhs_sinv_shape); - rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat8E8M0, - rhs_sinv_shape); + auto lhs_sinv_shape_ = lhs_sinv_i.dimensions(); + auto rhs_sinv_shape_ = rhs_sinv_i.dimensions(); + for (int i = 0; i < 2; i++) { + lhs_sinv_shape[i] = lhs_sinv_shape_[i]; + rhs_sinv_shape[i] = rhs_sinv_shape_[i]; + } + + NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode); + TensorWrapper lhs_i_(nvte_scaling_mode); + TensorWrapper rhs_i_(nvte_scaling_mode); + lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); + rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape); + lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape); + rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape); + + lhs_wrapper_list.push_back(std::move(lhs_i_)); + rhs_wrapper_list.push_back(std::move(rhs_i_)); } else { NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); } - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - - auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); - lhs_ptr += m * k * lhs_dtype_bytes; - rhs_ptr += n * k * 
rhs_dtype_bytes; - out_ptr += m * n * out_dtype_bytes; - lhs_sinv_ptr += lhs_sinv_shape[0] * lhs_sinv_shape[1] * lhs_sinv_dtype_bytes; - rhs_sinv_ptr += rhs_sinv_shape[0] * rhs_sinv_shape[1] * rhs_sinv_dtype_bytes; + auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype); void *pre_gelu_ptr = nullptr; auto bias_shape = std::vector{0}; auto pre_gelu_shape = std::vector{0}; - if (bias_ptr != nullptr) bias_shape[0] = n; + if (has_bias) { + auto bias_i_get = input_list.get(bias_list_offset + i); + Buffer_Type bias_i = bias_i_get.value(); + bias_ptr = bias_i.untyped_data(); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type()); + bias_shape[0] = n; + } auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - if (bias_ptr != nullptr) bias_ptr += n * bias_dtype_bytes; auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype); - out_wrapper_list.push_back(std::move(out_i)); + out_wrapper_list.push_back(std::move(out_i_)); bias_wrapper_list.push_back(std::move(bias_i)); pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); @@ -146,6 +165,10 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh out_list.push_back(out_wrapper_list.back().data()); } + auto workspace_get = output_list.get(num_gemms); + Result_Type workspace = workspace_get.value(); + uint8_t *workspace_ptr = reinterpret_cast(workspace->untyped_data()); + size_t workspace_size = workspace->dimensions()[0] / num_streams; auto workspace_shape = std::vector{workspace_size}; for (int i = 0; i < num_streams; i++) { auto workspace_i = @@ -163,50 +186,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh return ffi_with_cuda_error_check(); } -Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten, - Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten, - Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten, - Buffer_Type dim_list, Result_Type out_flatten, - Result_Type workspace_flatten, 
int64_t num_gemms, - JAXX_Scaling_Mode scaling_mode) { - // Inputs - auto lhs_ptr = reinterpret_cast(lhs_flatten.untyped_data()); - auto rhs_ptr = reinterpret_cast(rhs_flatten.untyped_data()); - auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv_flatten.untyped_data()); - auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv_flatten.untyped_data()); - auto bias_ptr = reinterpret_cast(bias_flatten.untyped_data()); - auto dim_list_ptr = reinterpret_cast(dim_list.untyped_data()); - auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_flatten.element_type()); - auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_flatten.element_type()); - auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv_flatten.element_type()); - auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv_flatten.element_type()); - auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias_flatten.element_type()); - - // Outputs - auto out_ptr = reinterpret_cast(out_flatten->untyped_data()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(out_flatten->element_type()); - auto workspace_ptr = reinterpret_cast(workspace_flatten->untyped_data()); - auto workspace_size = workspace_flatten->dimensions().back() / num_streams; - - return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype, - rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype, - workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode, - stream); -} - XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_flatten - .Arg() // lhs_sinv_flatten - .Arg() // rhs_flatten - .Arg() // rhs_sinv_flatten - .Arg() // bias_flatten - .Arg() // dim_list - .Ret() // out_flatten - .Ret() // workspace_flatten + .RemainingArgs() // input list + .RemainingRets() // output list .Attr("num_gemms") - .Attr("scaling_mode"), + .Attr("scaling_mode") + .Attr("has_bias"), FFI_CudaGraph_Traits); } // namespace jax