From b9fb508c94cc47818facaea18216a78a6a9d0949 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 7 Apr 2025 07:19:56 -0700 Subject: [PATCH 01/10] scaling enum abstract Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/activation.py | 8 +- .../jax/cpp_extensions/normalization.py | 150 ++++++++---------- .../jax/cpp_extensions/quantization.py | 4 +- transformer_engine/jax/csrc/extensions.h | 22 ++- .../jax/csrc/extensions/activation.cpp | 65 ++++---- transformer_engine/jax/csrc/extensions/misc.h | 23 +++ .../jax/csrc/extensions/normalization.cpp | 21 ++- .../jax/csrc/extensions/pybind.cpp | 8 +- .../jax/csrc/extensions/quantization.cpp | 58 +++++-- transformer_engine/jax/quantize/quantizer.py | 9 +- .../jax/quantize/scaling_modes.py | 12 +- 11 files changed, 213 insertions(+), 167 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index d7676781c3..f0a586aadb 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -162,7 +162,7 @@ def lowering( assert scale_aval is None or scale_aval.dtype == jnp.float32 out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x + ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x ) return out @@ -545,7 +545,7 @@ def lowering( dz, x, scale, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), @@ -928,7 +928,7 @@ def act_lu( out_dtype=x.dtype, act_enum=act_type_id, act_len=act_len, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, + scaling_mode=ScalingMode.NVTE_NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1042,7 +1042,7 @@ def quantize_dact_dbias( # outputs float32 for dbias accumulation out_dtype=(jnp.float32 if is_dbias else x.dtype), # default value for no scaling, TE/common ignore this value when scale is unset - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, + scaling_mode=ScalingMode.NVTE_NO_SCALING.value, is_2x=False, # unused scale_dtype=jnp.float32, # unused scale_shapes=((), ()), # unused diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 74882c92db..1a482016e1 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -105,6 +105,26 @@ def abstract( if norm_type == NVTE_Norm_Type.LayerNorm: assert gamma_aval.size == beta_aval.size + out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) + if norm_type == NVTE_Norm_Type.RMSNorm: + mu_aval = mu_aval.update(shape=(1,)) + + updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + + colwise_out_shape = x_aval.shape if is_2x else (1,) + colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) + + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) + + scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) + colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) + (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size gamma_aval.size, # hidden size @@ -112,33 +132,13 @@ def abstract( jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype jax_dtype_to_te_dtype(out_dtype), norm_type, - scaling_mode.value, + scaling_mode, zero_centered_gamma, epsilon, get_forward_sm_margin(), is_2x, ) - - out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) - mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_aval = mu_aval.update(shape=(1,)) - - rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( - x_aval.shape, is_padded=not is_outer - ) - - scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray( - shape=colwise_scale_inv_shape, dtype=scale_dtype - ) - colwise_out_aval = jax.core.ShapedArray( - shape=x_aval.shape if is_2x else (1,), dtype=out_dtype - ) - - updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - - wkspace_aval = x_aval.update( + wkspace_aval = jax.core.ShapedArray( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) @@ -274,9 +274,9 @@ def impl( scale_shapes=scale_shapes, is_outer=False, ) - rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( - x.shape, is_padded=False - ) + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x.shape, is_padded=False) # slice out padding for mxfp8, noop for DelayedScaling scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( rowwise_scale_inv_shape @@ -364,6 +364,8 @@ def infer_sharding_from_operands( del zero_centered_gamma, epsilon, out_dtype, result_infos del scale_dtype, scale_shapes, is_outer x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) + out_spec = (*x_spec[:-1], None) if x_spec[-1] is not None: warnings.warn( f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " @@ -371,34 +373,27 @@ def infer_sharding_from_operands( "and hurt performance." ) - out_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out" + out_sharding = NamedSharding(mesh, PartitionSpec(out_spec), desc="NormFwdPrimitive.out") + colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" ) - if is_2x: - colwise_out_sharding = out_sharding.duplicate_with_new_description( - "NormFwdPrimitive.colwise_out" - ) - else: - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out" - ) - rsigma_sharding = NamedSharding( mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu") + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.RMSNorm else (None,) + mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") + + scale_inv_spec = amax_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = colwise_out_spec scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv" - ) - - amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax") output = ( out_sharding, colwise_out_sharding, @@ -427,8 +422,11 @@ def partition( ): del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) g_spec = get_padded_spec(arg_infos[2]) b_spec = get_padded_spec(arg_infos[3]) + out_spec = (*x_spec[:-1], None) + if x_spec[-1] is not None: warnings.warn( f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " @@ -445,43 +443,30 @@ def partition( f"{NormFwdPrimitive.name} does not support sharding of parameter beta " "Enforcing no sharding of parameters hidden dim! " ) - x_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x" - ) - g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma") - b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta") - out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out") - if is_2x: - colwise_out_sharding = out_sharding.duplicate_with_new_description( - "NormFwdPrimitive.colwise_out" - ) - else: - colwise_out_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out" - ) + out_sharding = NamedSharding(mesh, PartitionSpec(out_spec), desc="NormFwdPrimitive.out") + colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" + ) rsigma_sharding = NamedSharding( - mesh, - PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]), - desc="NormFwdPrimitive.rsigma", + mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") - if norm_type == NVTE_Norm_Type.RMSNorm: - mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu") + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.RMSNorm else (None,) + mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - scale_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale" - ) - scale_inv_sharding = scale_sharding.duplicate_with_new_description( - "NormFwdPrimitive.scale_inv" + scale_inv_spec = amax_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = colwise_out_spec + + scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" ) - amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv" - ) + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax") - arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = ( out_sharding, colwise_out_sharding, @@ -517,7 +502,7 @@ def sharded_impl(x, scale, gamma, beta): scale_shapes=scale_shapes, is_outer=True, ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -824,7 +809,6 @@ def layernorm_fwd( if isinstance(quantizer, DelayedScaleQuantizer) else jnp.ones((1,), dtype=jnp.float32) ) - if quantizer is None: output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( x, @@ -835,7 +819,7 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.NVTE_NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((1,), (1,)), @@ -864,7 +848,7 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=quantizer.q_dtype, - scaling_mode=quantizer.scaling_mode, + scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), scale_shapes=quantizer.get_scale_shapes(x.shape), @@ -1017,7 +1001,7 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.NVTE_NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1046,7 +1030,7 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=quantizer.q_dtype, - scaling_mode=quantizer.scaling_mode, + scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), scale_shapes=quantizer.get_scale_shapes(x.shape), diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 034e149c50..cd8aafedf1 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -114,6 +114,8 @@ def abstract( gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), + scaling_mode, + QuantizeLayout(q_layout) # For now until we have auto-decoding for QuantizeLayout enum ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -176,7 +178,7 @@ def lowering( ctx, x, scale, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, is_dbias=is_dbias, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1950d6cbab..aaaf57fab7 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -31,6 +31,9 @@ #include "transformer_engine/activation.h" #include "utils.h" +// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); + namespace transformer_engine { namespace jax { @@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); + +pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, bool is_2x); + // Normalization XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); @@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, - NVTE_Norm_Type norm_type, int scaling_mode, + NVTE_Norm_Type norm_type, + JAXX_Scaling_Mode scaling_mode, bool zero_centered_gamma, float epsilon, int sm_margin, bool is_training); @@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype); - -XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); - -pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype, - int scaling_mode, bool is_2x); + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, + QuantizeLayout q_layout); // Softmax XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index e71597e4b3..fc7f231f34 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -17,7 +17,7 @@ namespace jax { Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum, + Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -34,7 +34,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto n = input_dims.back(); auto act_type = static_cast(act_enum); auto act_len = input_dims[input_dims.size() - 2]; - auto scaling_mode = static_cast(scaling_mode_enum); auto is_2x = static_cast(is_2x_int); auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis @@ -42,11 +41,11 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto output_shape = std::vector{m, n}; auto output_trans_shape = std::vector{n, m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); - auto output_tensor = TensorWrapper(scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); cudaMemsetAsync(amax, 0, sizeof(float), stream); @@ -66,15 +65,17 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal } if (is_2x) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -138,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ret() // scale_inv colwise .Ret() // amax .Attr("act_enum") - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("is_2x"), FFI_CudaGraph_Traits); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, - int scaling_mode, bool is_2x) { + JAXX_Scaling_Mode scaling_mode, bool is_2x) { auto input_shape = std::vector{batch_size, hidden_size}; auto dact_input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; @@ -163,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid auto dact_input_tensor = TensorWrapper(reinterpret_cast(&temp), dact_input_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); - auto output_tensor = TensorWrapper(static_cast(scaling_mode)); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter if (is_fp8_dtype(out_dtype)) { @@ -172,9 +173,8 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } if (is_2x) { - auto &tmp_shape = scaling_mode == static_cast(NVTE_DELAYED_TENSOR_SCALING) - ? output_trans_shape - : output_shape; + auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter @@ -184,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } } - if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) { + if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, std::vector{1}); output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, @@ -205,8 +205,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, - Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x, - bool is_dbias, int64_t act_enum) { + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + int64_t act_enum, bool is_2x, bool is_dbias) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -216,7 +216,6 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, float *scale = reinterpret_cast(scale_buf.untyped_data()); float *amax = reinterpret_cast(amax_buf->untyped_data()); - auto scaling_mode = static_cast(scaling_mode_enum); auto act_type = static_cast(act_enum); auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis @@ -245,10 +244,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); - auto output_tensor = TensorWrapper(scaling_mode); + + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); cudaMemsetAsync(amax, 0, sizeof(float), stream); @@ -268,15 +268,17 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, } if (is_2x) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -295,9 +297,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); - NVTE_CHECK( - !(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2), - "TE/common does not support delayed scaling for 2x with gated activations."); + NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2), + "TE/common does not support delayed scaling for 2x with gated activations."); if (is_dbias) { switch (act_type) { @@ -384,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Ret() // amax .Ret() // dbias .Ret() // wkspace - .Attr("scaling_mode") + .Attr("scaling_mode") + .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias") - .Attr("act_enum"), + .Attr("is_dbias"), FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index c8526e20c0..f7577c24f3 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -40,5 +40,28 @@ enum class QuantizeLayout { ROWWISE_COLWISE, }; +enum class JAXX_Scaling_Mode : int64_t { + NO_SCALING = 0, + DELAYED_TENSOR_SCALING = 1, + MXFP8_1D_SCALING = 2, +}; + +static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { + switch (mode) { + case JAXX_Scaling_Mode::NO_SCALING: + return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; + break; + case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING: + return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; + break; + case JAXX_Scaling_Mode::MXFP8_1D_SCALING: + return NVTEScalingMode::NVTE_MXFP8_1D_SCALING; + break; + default: + NVTE_ERROR("Invalid Scaling Mode ", static_cast(mode)); + break; + } +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 03855753cf..e23e42f528 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -14,7 +14,8 @@ namespace jax { pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, - NVTE_Norm_Type norm_type, int scaling_mode, + NVTE_Norm_Type norm_type, + JAXX_Scaling_Mode scaling_mode, bool zero_centered_gamma, float epsilon, int sm_margin, bool is_training) { auto input_shape = std::vector{batch_size, hidden_size}; @@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype); auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); - auto _scaling_mode = static_cast(scaling_mode); - auto output_tensor = TensorWrapper(_scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); // WAR: NVTE Norms query the is_training from whereas columwise_data is allocated - if (is_training && _scaling_mode == NVTE_MXFP8_1D_SCALING) { + if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { int temp = 1; output_tensor.set_columnwise_data(static_cast(&temp), out_dtype, input_shape); } @@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr); } else { - NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(), rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, @@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, - int64_t sm_margin, int scaling_mode, bool is_2x) { + int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); @@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto *amax = reinterpret_cast(amax_buf->untyped_data()); auto *workspace = wkspace_buf->untyped_data(); - auto _scaling_mode = static_cast(scaling_mode); auto _norm_type = static_cast(norm_type); auto _is_2x = static_cast(is_2x); @@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); - auto output_tensor = TensorWrapper(_scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), input_shape); if (is_fp8_dtype(out_dtype)) { @@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc scale_inv_buf->dimensions().back()}); } - if (_scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); cudaMemsetAsync(amax, 0, sizeof(float), stream); output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); @@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, stream); } else { - NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(), rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, @@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("zero_centered_gamma") .Attr("epsilon") .Attr("sm_margin") - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("is_2x"), FFI_CudaGraph_Traits); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index ebdfe461c7..5c165cccb6 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -138,10 +138,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("RMSNorm", NVTE_Norm_Type::RMSNorm) .export_values(); - pybind11::enum_(m, "NVTE_Scaling_Mode", pybind11::module_local()) - .value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) - .value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) - .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) + pybind11::enum_(m, "JAXX_Scaling_Mode", pybind11::module_local()) + .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING) + .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) .export_values(); pybind11::enum_(m, "QuantizeLayout", diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index b48ee8a9b9..481dbd7cdf 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -13,7 +13,9 @@ namespace transformer_engine { namespace jax { pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype) { + DType in_dtype, DType out_dtype, + JAXX_Scaling_Mode scaling_mode, + QuantizeLayout q_layout) { auto input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; auto output_trans_shape = std::vector{hidden_size, batch_size}; @@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ int temp = 0; auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); - auto output_tensor = TensorWrapper(reinterpret_cast(&temp), output_shape, out_dtype); - output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + // Only the pointers will be checked for scale_inv, thus the shapes do not matter + if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { + output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); + if (is_fp8_dtype(out_dtype)) { + output_tensor.set_rowwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + } + + if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) { + auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape + : output_shape; + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); + + // Only the pointers will be checked for scale_inv, thus the shapes do not matter + if (is_fp8_dtype(out_dtype)) { + output_tensor.set_columnwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + } + + if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, + std::vector{1}); + } + TensorWrapper dummy_workspace; nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), @@ -44,8 +73,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T Result_Type output_buf, Result_Type output_trans_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, - int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias, - int64_t flatten_axis) { + JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, + bool is_dbias, int64_t flatten_axis) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -54,7 +83,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto *input = input_buf.untyped_data(); - auto scaling_mode = static_cast(scaling_mode_enum); auto const quantize_layout = static_cast(quantize_layout_enum); auto *output = output_buf->untyped_data(); @@ -77,14 +105,14 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T std::vector workspace_shape{workspace_dims.begin(), workspace_dims.end()}; auto input_tensor = TensorWrapper(input, input_shape, in_dtype); - auto output_tensor = TensorWrapper(scaling_mode); + auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); if (quantize_layout == QuantizeLayout::ROWWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { float *scale = reinterpret_cast(scale_buf.untyped_data()); float *amax = reinterpret_cast(amax_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); @@ -109,14 +137,16 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T if (quantize_layout == QuantizeLayout::COLWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { - auto &tmp_shape = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; + auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + ? scale_inv_buf + : colwise_scale_inv_buf; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); @@ -153,7 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ret() // amax .Ret() // dbias .Ret() // wkspace - .Attr("scaling_mode") + .Attr("scaling_mode") .Attr("q_layout") .Attr("is_dbias") .Attr("flatten_axis"), diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index bd7045453b..33d996d396 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -172,7 +172,7 @@ class DelayedScaleQuantizer(Quantizer): amax_history: History of maximum absolute values """ - scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) @@ -375,7 +375,7 @@ class BlockScaleQuantizer(Quantizer): q_layout: Quantization axis (default: ROWWISE_COLWISE) """ - scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING + scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING.value q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE def get_data_layout(self) -> str: @@ -556,8 +556,9 @@ def create( A single quantizer or tuple of quantizers """ # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted - # assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING - if scaling_mode in (ScalingMode.NVTE_NO_SCALING, ScalingMode.NVTE_INVALID_SCALING): + assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type" + # import pdb; pdb.set_trace() + if scaling_mode == ScalingMode.NVTE_NO_SCALING: quantizers = [None] * n_quantizers else: quantizers = [] diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 95bbc9bb41..e31b727f50 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -16,6 +16,8 @@ from functools import reduce import operator +from transformer_engine_jax import JAXX_Scaling_Mode + from jax.tree_util import register_pytree_node_class import jax.numpy as jnp @@ -227,14 +229,12 @@ class ScalingMode(Enum): This class defines the available scaling modes for tensor quantization: - NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - - NVTE_INVALID_SCALING: Invalid scaling mode - NVTE_NO_SCALING: No scaling applied """ - NVTE_DELAYED_TENSOR_SCALING = 0 - NVTE_MXFP8_1D_SCALING = 1 - NVTE_INVALID_SCALING = 100 - NVTE_NO_SCALING = 1000 + NVTE_NO_SCALING = JAXX_Scaling_Mode.NO_SCALING + NVTE_DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING + NVTE_MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. @@ -304,7 +304,7 @@ def __eq__(self, other): """ if not isinstance(other, ScalingMode): return False - return self.value == other.value + return self.value == other.value or self == other.value or self.value == other def tree_flatten(self): """Flatten this scaling mode for JAX tree operations. From 2e9a07e3b30def4269bf84ed241e9988446fc131 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 14:30:13 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/quantization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index cd8aafedf1..24679d17cb 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -115,7 +115,9 @@ def abstract( jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), scaling_mode, - QuantizeLayout(q_layout) # For now until we have auto-decoding for QuantizeLayout enum + QuantizeLayout( + q_layout + ), # For now until we have auto-decoding for QuantizeLayout enum ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) From f6bb911e47d3c90eae636ea897bdcdcee0d9ea25 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 8 Apr 2025 10:27:58 -0700 Subject: [PATCH 03/10] revert unneccessary changes Signed-off-by: Phuong Nguyen --- transformer_engine/jax/quantize/quantizer.py | 4 ++-- transformer_engine/jax/quantize/scaling_modes.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 33d996d396..ed747cf44d 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -172,7 +172,7 @@ class DelayedScaleQuantizer(Quantizer): amax_history: History of maximum absolute values """ - scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value + scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) @@ -375,7 +375,7 @@ class BlockScaleQuantizer(Quantizer): q_layout: Quantization axis (default: ROWWISE_COLWISE) """ - scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING.value + scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE def get_data_layout(self) -> str: diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index e31b727f50..adf7d72482 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -304,7 +304,7 @@ def __eq__(self, other): """ if not isinstance(other, ScalingMode): return False - return self.value == other.value or self == other.value or self.value == other + return self.value == other.value def tree_flatten(self): """Flatten this scaling mode for JAX tree operations. From 7052751add4006679d825a7a17d20382cd3f77fa Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 8 Apr 2025 10:50:03 -0700 Subject: [PATCH 04/10] rm NVTE_ from ScalingMode names Signed-off-by: Phuong Nguyen --- .../encoder/test_model_parallel_encoder.py | 4 +- examples/jax/encoder/test_multigpu_encoder.py | 4 +- .../jax/encoder/test_single_gpu_encoder.py | 4 +- examples/jax/mnist/test_single_gpu_mnist.py | 4 +- tests/jax/test_custom_call_compute.py | 40 +++++++++---------- tests/jax/test_distributed_layernorm.py | 2 +- tests/jax/test_distributed_layernorm_mlp.py | 2 +- tests/jax/test_layer.py | 4 +- .../jax/cpp_extensions/activation.py | 36 ++++++++--------- transformer_engine/jax/cpp_extensions/gemm.py | 28 ++++++------- transformer_engine/jax/cpp_extensions/misc.py | 2 +- .../jax/cpp_extensions/normalization.py | 26 ++++++------ .../jax/cpp_extensions/quantization.py | 18 ++++----- transformer_engine/jax/flax/module.py | 2 +- .../jax/quantize/dequantizer.py | 4 +- transformer_engine/jax/quantize/helper.py | 16 ++++---- transformer_engine/jax/quantize/quantizer.py | 12 +++--- .../jax/quantize/scaling_modes.py | 21 +++++----- transformer_engine/jax/quantize/tensor.py | 2 +- 19 files changed, 114 insertions(+), 117 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 7e6605c9fe..eabd1b2a3f 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -448,8 +448,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index ba62d964fa..839bc3175e 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -416,8 +416,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 1300be01bb..df78157cc5 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -327,8 +327,8 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 4022cb7493..435750a1db 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -306,8 +306,8 @@ def mnist_parser(args): class TestMNIST(unittest.TestCase): """MNIST unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 4dc07a2eea..8917e92465 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -48,21 +48,21 @@ LN_CASES = [(256, 128), (128, 256)] DTYPES = [jnp.bfloat16, jnp.float32] is_fp8_supported, reason = helper.is_fp8_available() -is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes = [] """ Find supported scaling modes""" if is_fp8_supported: - supported_scaling_modes.append(ScalingMode.NVTE_DELAYED_TENSOR_SCALING) + supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) if is_mxfp8_supported: - supported_scaling_modes.append(ScalingMode.NVTE_MXFP8_1D_SCALING) + supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) def is_shape_supported_by_mxfp8(input_shape): try: if isinstance(input_shape, type(pytest.param(0))): input_shape = input_shape.values[0] - ScalingMode.NVTE_MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) + ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) return True except: # get_scale_shapes will raise an exception if the shape is not supported @@ -170,7 +170,7 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, ) quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) @@ -198,7 +198,7 @@ def test_act_forward_with_delayed_scaling_fp8( te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=output_type, q_layout=q_layout, ) @@ -223,7 +223,7 @@ def test_act_forward_with_block_scaling_fp8( self.activation_type = activation_type quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout + scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) output = tex.act_lu(x, activation_type, quantizer) @@ -345,7 +345,7 @@ def test_norm_grad_with_delayed_scaling_fp8( pytest.skip("RMSNorm and zero_centered_gamma is not supported!") quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_layout=q_layout, ) @@ -420,7 +420,7 @@ def test_norm_forward_with_delayed_scaling_fp8( epsilon=epsilon, inp_dtype=inp_dtype, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_layout=q_layout, ) @@ -437,7 +437,7 @@ def test_norm_forward_with_block_scaling_fp8( epsilon=epsilon, inp_dtype=inp_dtype, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, + scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_layout=QuantizeLayout.ROWWISE_COLWISE, ) @@ -493,7 +493,7 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt if flatten_axis == -2: input_shape = input_shape[:-1] + (2,) + input_shape[-1:] - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): x = jax.random.uniform(key, input_shape, in_dtype) @@ -533,7 +533,7 @@ class TestFusedQuantize: def test_quantize_dbias( self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis ): - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( + if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( input_shape ): pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") @@ -618,7 +618,7 @@ def test_quantize_dact_dbias_no_quantization( in_dtype=in_dtype, input_shape=input_shape, out_dtype=in_dtype, - scaling_mode=ScalingMode.NVTE_NO_SCALING, + scaling_mode=ScalingMode.NO_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=QuantizeLayout.ROWWISE, @@ -639,7 +639,7 @@ def test_quantize_dact_dbias_delayed_scaling( in_dtype=in_dtype, input_shape=input_shape, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=q_layout, @@ -670,7 +670,7 @@ def test_quantize_dact_dbias_mxfp8_scaling( in_dtype=in_dtype, input_shape=input_shape, out_dtype=out_dtype, - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, + scaling_mode=ScalingMode.MXFP8_1D_SCALING, activation_type=activation_type, is_dbias=is_dbias, q_layout=q_layout, @@ -785,7 +785,7 @@ def ref_func(x, w, bias, data_layout): scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True ) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) @@ -830,7 +830,7 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): Test layernorm_dense VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in TestNorm @@ -886,7 +886,7 @@ def ref_func(x, w, gamma, beta): x, w, gamma, beta ) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): prim_out, ( prim_x_grad, @@ -916,7 +916,7 @@ def test_layernorm_mlp_grad( Test layernorm_mlp VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in TestNorm @@ -993,7 +993,7 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) - n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): prim_out, ( prim_x_grad, diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 6d4cde364f..476d455a6a 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -29,7 +29,7 @@ } is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) SUPPORTED_RECIPES = [] if is_fp8_supported: diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 4350d5e8f3..cf311ac404 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -36,7 +36,7 @@ is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) SUPPORTED_RECIPES = [] if is_fp8_supported: diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index b89530c19f..a21583a98c 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -39,7 +39,7 @@ def enable_fused_attn(): is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING) +is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) QUANTIZE_RECIPES = [] """ Find supported scaling modes""" @@ -313,7 +313,7 @@ def test_backward( test_others, test_layer, ) - if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: _, updated_quantize_meta = flax.core.pop( updated_state[0], QuantizeConfig.COLLECTION_NAME ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index f0a586aadb..c27f6f50f7 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -282,7 +282,7 @@ def infer_sharding_from_operands( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec @@ -293,9 +293,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec if is_2x: @@ -339,7 +339,7 @@ def partition( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec @@ -350,9 +350,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec if is_2x: @@ -391,7 +391,7 @@ def sharded_impl(x, scale): ) ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -463,7 +463,7 @@ def abstract( scaling_mode ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) else: colwise_out_shape = out_shape @@ -673,7 +673,7 @@ def infer_sharding_from_operands( mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" ) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec @@ -691,9 +691,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if is_2x: @@ -743,7 +743,7 @@ def partition( ) if is_2x: - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec @@ -761,9 +761,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if is_2x: @@ -810,7 +810,7 @@ def sharded_impl(dz, x, scale): else: global_dbias = local_dbias - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -928,7 +928,7 @@ def act_lu( out_dtype=x.dtype, act_enum=act_type_id, act_len=act_len, - scaling_mode=ScalingMode.NVTE_NO_SCALING.value, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1042,7 +1042,7 @@ def quantize_dact_dbias( # outputs float32 for dbias accumulation out_dtype=(jnp.float32 if is_dbias else x.dtype), # default value for no scaling, TE/common ignore this value when scale is unset - scaling_mode=ScalingMode.NVTE_NO_SCALING.value, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, # unused scale_dtype=jnp.float32, # unused scale_shapes=((), ()), # unused @@ -1095,7 +1095,7 @@ def quantize_dact_dbias( ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1df2bcc97f..187fa37317 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -198,7 +198,7 @@ def _jax_gemm_delayed_scaling_fp8( ): """FP8 GEMM for XLA pattern match""" assert ( - rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING + rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING ), "rhs does not have delayed tensor scaling mode" (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums @@ -230,7 +230,7 @@ def _jax_gemm_mxfp8_1d( JAX GEMM for MXFP8 via scaled_matmul """ assert ( - rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING + rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING ), "rhs does not have MXFP8 1D scaling mode" from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper @@ -291,10 +291,10 @@ def _jax_gemm( def _jax_gemm_fp8_impl(lhs, rhs): - if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums) - if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") @@ -403,7 +403,7 @@ def grouped_gemm( rhs_shape = rhs.data.shape out_dtype = lhs.dq_dtype # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout - if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: assert not ( lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 ), "FP8 GEMM does not support E5M2 * E5M2" @@ -415,7 +415,7 @@ def grouped_gemm( dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) else: # For jnp.ndarray, only consider contracting_dims, data_layout is always NN - scaling_mode = ScalingMode.NVTE_NO_SCALING + scaling_mode = ScalingMode.NO_SCALING lhs_shape = lhs.shape rhs_shape = rhs.shape out_dtype = lhs.dtype @@ -427,13 +427,13 @@ def grouped_gemm( lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + if scaling_mode == ScalingMode.NO_SCALING: lhs_3d = _shape_normalization(lhs, lhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn) lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) @@ -470,13 +470,13 @@ def grouped_gemm( dims.append((bm, bn, k)) lhs_contig_.append(lhs_3d.reshape(-1)) rhs_contig_.append(rhs_3d.reshape(-1)) - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + if scaling_mode == ScalingMode.NO_SCALING: lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1)) rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1)) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1)) rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1)) if bias_list is not None: @@ -493,8 +493,8 @@ def grouped_gemm( # TE/common does not support NVTE_NO_SCALING yet # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 - if scaling_mode == ScalingMode.NVTE_NO_SCALING: - scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + if scaling_mode == ScalingMode.NO_SCALING: + scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING # Perform batched GEMM on flattened inputs out_contig = GroupedGemmPrimitive.outer_primitive.bind( diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index c79eda5568..d64104ac27 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, """ should_apply_war = ( quantizer is not None - and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING + and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x() ) if not should_apply_war: diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 1a482016e1..63b1b797fd 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -385,9 +385,9 @@ def infer_sharding_from_operands( mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") scale_inv_spec = amax_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = colwise_out_spec scale_inv_sharding = NamedSharding( @@ -456,9 +456,9 @@ def partition( mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") scale_inv_spec = amax_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = colwise_out_spec scale_inv_sharding = NamedSharding( @@ -502,7 +502,7 @@ def sharded_impl(x, scale, gamma, beta): scale_shapes=scale_shapes, is_outer=True, ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -819,7 +819,7 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_NO_SCALING.value, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((1,), (1,)), @@ -829,7 +829,7 @@ def layernorm_fwd( is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: is_2x2x = False ( rowwise_casted_output, @@ -857,7 +857,7 @@ def layernorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -866,7 +866,7 @@ def layernorm_fwd( # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. # The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( x.shape, is_padded=False ) @@ -1001,7 +1001,7 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=x.dtype, - scaling_mode=ScalingMode.NVTE_NO_SCALING.value, + scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, scale_shapes=((), ()), @@ -1011,7 +1011,7 @@ def rmsnorm_fwd( is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: is_2x2x = False ( rowwise_casted_output, @@ -1039,7 +1039,7 @@ def rmsnorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -1048,7 +1048,7 @@ def rmsnorm_fwd( # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. # The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( x.shape, is_padded=False ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 24679d17cb..2911b5a420 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -93,7 +93,7 @@ def abstract( ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) else: colwise_out_shape = out_shape @@ -306,7 +306,7 @@ def infer_sharding_from_operands( desc="DBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -326,9 +326,9 @@ def infer_sharding_from_operands( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -378,7 +378,7 @@ def partition( desc="DBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -398,9 +398,9 @@ def partition( ) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -449,7 +449,7 @@ def sharded_impl(x, scale): is_outer=True, ) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) else: global_updated_amax = local_amax @@ -592,7 +592,7 @@ def _quantize_dbias_impl( is_outer=True, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index a944848881..45ff8d7ed9 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -361,7 +361,7 @@ def generate_quantize_meta(quantizer_name: str): ).value return QuantizeMeta(scale=scale, amax_history=amax_history) - if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: x_meta = generate_quantize_meta("x") kernel_meta = generate_quantize_meta("kernel") grad_meta = generate_quantize_meta("grad") diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index b1e9ba03b4..d68eb3c6c2 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -84,8 +84,8 @@ def _dq_func_block_scaling(scaled_tensor): ) funcs = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, - ScalingMode.NVTE_MXFP8_1D_SCALING: _dq_func_block_scaling, + ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, + ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling, } @staticmethod diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 7d144aa69d..98f280b9a9 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -94,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: A tuple of (bool, str) indicating support and any error message """ gpu_arch = get_device_compute_capability(gpu_id) - if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: return _check_delayed_scaling_fp8_support(gpu_arch) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _check_block_scaling_fp8_support(gpu_arch) return (False, "Unsupported scaling_mode!") def is_fp8_available( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, gpu_id=None, ) -> Tuple[bool, str]: """Check if FP8 is available for the given scaling mode and GPU. @@ -179,9 +179,9 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: ValueError: If the recipe type is not supported """ if isinstance(fp8_recipe, recipe.DelayedScaling): - return ScalingMode.NVTE_DELAYED_TENSOR_SCALING + return ScalingMode.DELAYED_TENSOR_SCALING if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - return ScalingMode.NVTE_MXFP8_1D_SCALING + return ScalingMode.MXFP8_1D_SCALING raise ValueError("Invalid fp8_recipe!") @@ -217,7 +217,7 @@ class QuantizeConfig: FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False IF_QUANTIZE_2X: bool = False - SCALING_MODE: ScalingMode = ScalingMode.NVTE_NO_SCALING + SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING # DelayedScaling AMAX_HISTORY_LEN: int = 1024 @@ -253,11 +253,11 @@ def finalize(cls) -> None: cls.MARGIN = 0.0 cls.FP8_FORMAT = recipe.Format.HYBRID cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) - cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING + cls.SCALING_MODE = ScalingMode.NO_SCALING cls.FP8_2X_ACC_FPROP = False cls.FP8_2X_ACC_DGRAD = False cls.FP8_2X_ACC_WGRAD = False - cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING + cls.SCALING_MODE = ScalingMode.NO_SCALING cls.IF_QUANTIZE_2X = False # DelayedScaling cls.AMAX_HISTORY_LEN = 1024 diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index ed747cf44d..b57043a034 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -172,7 +172,7 @@ class DelayedScaleQuantizer(Quantizer): amax_history: History of maximum absolute values """ - scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) @@ -375,7 +375,7 @@ class BlockScaleQuantizer(Quantizer): q_layout: Quantization axis (default: ROWWISE_COLWISE) """ - scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING + scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE def get_data_layout(self) -> str: @@ -530,8 +530,8 @@ class QuantizerFactory: """ quantizer_type_map = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, - ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScaleQuantizer, + ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, + ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer, } @staticmethod @@ -558,7 +558,7 @@ def create( # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type" # import pdb; pdb.set_trace() - if scaling_mode == ScalingMode.NVTE_NO_SCALING: + if scaling_mode == ScalingMode.NO_SCALING: quantizers = [None] * n_quantizers else: quantizers = [] @@ -652,4 +652,4 @@ def create_set( return q_set[0] if len(q_set) == 1 else tuple(q_set) -noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NVTE_NO_SCALING) +noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index adf7d72482..a47e78892c 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -218,23 +218,20 @@ def get_scale_shape( return (*first_dim_scale_shape, *last_dim_scale_shape) -# (Phuong: Map the NVTEScalingMode value to the ScalingMode - - @dataclass(frozen=True) @register_pytree_node_class class ScalingMode(Enum): """Enumeration of tensor scaling modes with their corresponding metadata implementations. This class defines the available scaling modes for tensor quantization: - - NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - - NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - - NVTE_NO_SCALING: No scaling applied + - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales + - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales + - NO_SCALING: No scaling applied """ - NVTE_NO_SCALING = JAXX_Scaling_Mode.NO_SCALING - NVTE_DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING - NVTE_MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING + NO_SCALING = JAXX_Scaling_Mode.NO_SCALING + DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING + MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. @@ -329,8 +326,8 @@ def tree_unflatten(cls, aux_data, _children): SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = { - ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), - ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), + ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), # WAR - ScalingMode.NVTE_NO_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(), } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c34a235d94..941f777f61 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -236,7 +236,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st data = with_sharding_constraint_by_logical_axes(self.data, axis_names) - if self.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING: # TODO(Phuong): Handle padding !? scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) else: From e6251a87f116fb9acbb7cf048b5c8e32d191c695 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 8 Apr 2025 18:25:23 -0700 Subject: [PATCH 05/10] rework scaling mode enum in grouped gemm Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 6 +++--- transformer_engine/jax/csrc/extensions/gemm.cpp | 17 +++++++++-------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 187fa37317..0327542c2f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -98,7 +98,7 @@ def lowering( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=int(scaling_mode), + scaling_mode=scaling_mode.value, ) @staticmethod @@ -123,7 +123,7 @@ def impl( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=scaling_mode.value, + scaling_mode=scaling_mode, out_dtype=out_dtype, out_flat_size=out_flat_size, ) @@ -505,7 +505,7 @@ def grouped_gemm( bias_contig, dim_list, num_gemms=num_gemms, - scaling_mode=scaling_mode, + scaling_mode=scaling_mode.value, out_dtype=out_dtype, out_flat_size=out_flat_size, ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e5ec160c91..03c1be551a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -23,7 +23,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr, const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype, uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms, - int32_t *dim_list_ptr, const int64_t &scaling_mode, + int32_t *dim_list_ptr, const JAXX_Scaling_Mode &scaling_mode, cudaStream_t stream) { size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); @@ -90,14 +90,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { auto lhs_i = TensorWrapper(static_cast(lhs_ptr), lhs_shape, lhs_dtype, nullptr, nullptr, reinterpret_cast(lhs_sinv_ptr)); auto rhs_i = TensorWrapper(static_cast(rhs_ptr), rhs_shape, rhs_dtype, nullptr, nullptr, reinterpret_cast(rhs_sinv_ptr)); lhs_wrapper_list.push_back(std::move(lhs_i)); rhs_wrapper_list.push_back(std::move(rhs_i)); - } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", MXFP8_BLOCK_SIZE, k); size_t sinv_k = k / MXFP8_BLOCK_SIZE; @@ -107,8 +107,8 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh rhs_sinv_shape[1] = sinv_k; // Note: the scale_inv array should have been swizzled in Python before lowering - TensorWrapper lhs_i(NVTE_MXFP8_1D_SCALING); - TensorWrapper rhs_i(NVTE_MXFP8_1D_SCALING); + TensorWrapper lhs_i(get_nvte_scaling_mode(scaling_mode)); + TensorWrapper rhs_i(get_nvte_scaling_mode(scaling_mode)); lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat8E8M0, @@ -119,7 +119,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh lhs_wrapper_list.push_back(std::move(lhs_i)); rhs_wrapper_list.push_back(std::move(rhs_i)); } else { - NVTE_ERROR("Unsupported scaling mode: ", scaling_mode); + NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); } auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); @@ -169,7 +169,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten, Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten, Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten, Buffer_Type dim_list, Result_Type out_flatten, - Result_Type workspace_flatten, int64_t num_gemms, int64_t scaling_mode) { + Result_Type workspace_flatten, int64_t num_gemms, + JAXX_Scaling_Mode scaling_mode) { // Inputs auto lhs_ptr = reinterpret_cast(lhs_flatten.untyped_data()); auto rhs_ptr = reinterpret_cast(rhs_flatten.untyped_data()); @@ -207,7 +208,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Ret() // out_flatten .Ret() // workspace_flatten .Attr("num_gemms") - .Attr("scaling_mode"), + .Attr("scaling_mode"), FFI_CudaGraph_Traits); } // namespace jax From 2ad77443712c05f16eb0410644730492c8057128 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 8 Apr 2025 18:32:45 -0700 Subject: [PATCH 06/10] lint fix Signed-off-by: Phuong Nguyen --- qa/L0_jax_distributed_unittest/test.sh | 2 +- transformer_engine/jax/quantize/scaling_modes.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index 3253861484..3fbfb9cf5c 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -24,7 +24,7 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" -. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" +. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "test_multiprocessing_encoder.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index a47e78892c..34f63a994c 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -16,11 +16,11 @@ from functools import reduce import operator -from transformer_engine_jax import JAXX_Scaling_Mode - from jax.tree_util import register_pytree_node_class import jax.numpy as jnp +from transformer_engine_jax import JAXX_Scaling_Mode + __all__ = ["ScalingMode"] From 1cd0367fab47a8fc92d4a731612ad473c9f96d47 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 9 Apr 2025 06:20:37 -0700 Subject: [PATCH 07/10] rework gemm Signed-off-by: Phuong Nguyen --- .../jax/csrc/extensions/gemm.cpp | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 03c1be551a..17eb6a6bbb 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -90,13 +90,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; + auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); + rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { - auto lhs_i = TensorWrapper(static_cast(lhs_ptr), lhs_shape, lhs_dtype, nullptr, - nullptr, reinterpret_cast(lhs_sinv_ptr)); - auto rhs_i = TensorWrapper(static_cast(rhs_ptr), rhs_shape, rhs_dtype, nullptr, - nullptr, reinterpret_cast(rhs_sinv_ptr)); - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); + lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat32, std::vector{1}); + rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat32, std::vector{1}); } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", MXFP8_BLOCK_SIZE, k); @@ -107,20 +108,15 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh rhs_sinv_shape[1] = sinv_k; // Note: the scale_inv array should have been swizzled in Python before lowering - TensorWrapper lhs_i(get_nvte_scaling_mode(scaling_mode)); - TensorWrapper rhs_i(get_nvte_scaling_mode(scaling_mode)); - lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); - rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat8E8M0, lhs_sinv_shape); rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat8E8M0, rhs_sinv_shape); - - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); } else { NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); } + lhs_wrapper_list.push_back(std::move(lhs_i)); + rhs_wrapper_list.push_back(std::move(rhs_i)); auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); lhs_ptr += m * k * lhs_dtype_bytes; From 5d0a49b484b607e99ce82ea6b26d28218b1a150c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Apr 2025 13:21:12 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/csrc/extensions/gemm.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 17eb6a6bbb..c46d72aa16 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -92,12 +92,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); - rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); + lhs_i.set_rowwise_data(static_cast(lhs_ptr), lhs_dtype, lhs_shape); + rhs_i.set_rowwise_data(static_cast(rhs_ptr), rhs_dtype, rhs_shape); if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { - lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat32, std::vector{1}); - rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat32, std::vector{1}); + lhs_i.set_rowwise_scale_inv(static_cast(lhs_sinv_ptr), DType::kFloat32, + std::vector{1}); + rhs_i.set_rowwise_scale_inv(static_cast(rhs_sinv_ptr), DType::kFloat32, + std::vector{1}); } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", MXFP8_BLOCK_SIZE, k); From 0a97d52174a5f957bdbbd84946d1bfbacc630440 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 9 Apr 2025 07:01:27 -0700 Subject: [PATCH 09/10] pass by copy Signed-off-by: Phuong Nguyen --- transformer_engine/jax/csrc/extensions/gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index c46d72aa16..d4b9bf720e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -23,7 +23,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr, const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype, uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms, - int32_t *dim_list_ptr, const JAXX_Scaling_Mode &scaling_mode, + int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode, cudaStream_t stream) { size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); From d134a3b48391372d6a431751e9f5663483122d8f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 9 Apr 2025 07:18:47 -0700 Subject: [PATCH 10/10] fix norm sharding Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/normalization.py | 12 ++++++------ transformer_engine/jax/quantize/tensor.py | 1 - 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 63b1b797fd..388d4f17ee 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -373,7 +373,7 @@ def infer_sharding_from_operands( "and hurt performance." ) - out_sharding = NamedSharding(mesh, PartitionSpec(out_spec), desc="NormFwdPrimitive.out") + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") colwise_out_spec = out_spec if is_2x else (None,) colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" @@ -381,14 +381,14 @@ def infer_sharding_from_operands( rsigma_sharding = NamedSharding( mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.RMSNorm else (None,) + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") scale_inv_spec = amax_spec = (None,) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: - scale_inv_spec = colwise_out_spec + scale_inv_spec = out_spec scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" @@ -444,7 +444,7 @@ def partition( "Enforcing no sharding of parameters hidden dim! " ) - out_sharding = NamedSharding(mesh, PartitionSpec(out_spec), desc="NormFwdPrimitive.out") + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") colwise_out_spec = out_spec if is_2x else (None,) colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" @@ -452,14 +452,14 @@ def partition( rsigma_sharding = NamedSharding( mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" ) - mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.RMSNorm else (None,) + mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") scale_inv_spec = amax_spec = (None,) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: scale_inv_spec = amax_spec = scale_spec elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: - scale_inv_spec = colwise_out_spec + scale_inv_spec = out_spec scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv" diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 941f777f61..0ef30f4728 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -242,7 +242,6 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st else: scale_inv = self.scale_inv - # TODO(Phuong): constaint padded scale_inv? return ScaledTensor1x( data=data, scale_inv=scale_inv,