Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/jax/encoder/test_model_parallel_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

@classmethod
def setUpClass(cls):
Expand Down
4 changes: 2 additions & 2 deletions examples/jax/encoder/test_multigpu_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

@classmethod
def setUpClass(cls):
Expand Down
4 changes: 2 additions & 2 deletions examples/jax/encoder/test_single_gpu_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

@classmethod
def setUpClass(cls):
Expand Down
4 changes: 2 additions & 2 deletions examples/jax/mnist/test_single_gpu_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""

is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

@classmethod
def setUpClass(cls):
Expand Down
2 changes: 1 addition & 1 deletion qa/L0_jax_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "test_multiprocessing_encoder.py"

if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
Expand Down
40 changes: 20 additions & 20 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,21 @@
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = helper.is_fp8_available()
is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

supported_scaling_modes = []
""" Find supported scaling modes"""
if is_fp8_supported:
supported_scaling_modes.append(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
if is_mxfp8_supported:
supported_scaling_modes.append(ScalingMode.NVTE_MXFP8_1D_SCALING)
supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)


def is_shape_supported_by_mxfp8(input_shape):
try:
if isinstance(input_shape, type(pytest.param(0))):
input_shape = input_shape.values[0]
ScalingMode.NVTE_MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
return True
except:
# get_scale_shapes will raise an exception if the shape is not supported
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type,
)

quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE,
)
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_act_forward_with_delayed_scaling_fp8(

te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=output_type,
q_layout=q_layout,
)
Expand All @@ -223,7 +223,7 @@ def test_act_forward_with_block_scaling_fp8(
self.activation_type = activation_type

quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
)

output = tex.act_lu(x, activation_type, quantizer)
Expand Down Expand Up @@ -345,7 +345,7 @@ def test_norm_grad_with_delayed_scaling_fp8(
pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=out_dtype,
q_layout=q_layout,
)
Expand Down Expand Up @@ -420,7 +420,7 @@ def test_norm_forward_with_delayed_scaling_fp8(
epsilon=epsilon,
inp_dtype=inp_dtype,
out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_layout=q_layout,
)

Expand All @@ -437,7 +437,7 @@ def test_norm_forward_with_block_scaling_fp8(
epsilon=epsilon,
inp_dtype=inp_dtype,
out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
scaling_mode=ScalingMode.MXFP8_1D_SCALING,
q_layout=QuantizeLayout.ROWWISE_COLWISE,
)

Expand Down Expand Up @@ -493,7 +493,7 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]

n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
x = jax.random.uniform(key, input_shape, in_dtype)

Expand Down Expand Up @@ -533,7 +533,7 @@ class TestFusedQuantize:
def test_quantize_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
):
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
input_shape
):
pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
Expand Down Expand Up @@ -618,7 +618,7 @@ def test_quantize_dact_dbias_no_quantization(
in_dtype=in_dtype,
input_shape=input_shape,
out_dtype=in_dtype,
scaling_mode=ScalingMode.NVTE_NO_SCALING,
scaling_mode=ScalingMode.NO_SCALING,
activation_type=activation_type,
is_dbias=is_dbias,
q_layout=QuantizeLayout.ROWWISE,
Expand All @@ -639,7 +639,7 @@ def test_quantize_dact_dbias_delayed_scaling(
in_dtype=in_dtype,
input_shape=input_shape,
out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
activation_type=activation_type,
is_dbias=is_dbias,
q_layout=q_layout,
Expand Down Expand Up @@ -670,7 +670,7 @@ def test_quantize_dact_dbias_mxfp8_scaling(
in_dtype=in_dtype,
input_shape=input_shape,
out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
scaling_mode=ScalingMode.MXFP8_1D_SCALING,
activation_type=activation_type,
is_dbias=is_dbias,
q_layout=q_layout,
Expand Down Expand Up @@ -785,7 +785,7 @@ def ref_func(x, w, bias, data_layout):
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
)

n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
Expand Down Expand Up @@ -830,7 +830,7 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
Test layernorm_dense VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
pytest.skip("E5M2 is not supported in normalization with TE Backend!")

# zero_centered_gamma is already tested in TestNorm
Expand Down Expand Up @@ -886,7 +886,7 @@ def ref_func(x, w, gamma, beta):
x, w, gamma, beta
)

n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
prim_out, (
prim_x_grad,
Expand Down Expand Up @@ -916,7 +916,7 @@ def test_layernorm_mlp_grad(
Test layernorm_mlp VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
pytest.skip("E5M2 is not supported in normalization with TE Backend!")

# zero_centered_gamma is already tested in TestNorm
Expand Down Expand Up @@ -993,7 +993,7 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
value_n_grad_prim_func = value_and_grad(prim_func, range(6))
value_n_grad_ref_func = value_and_grad(ref_func, range(6))

n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
prim_out, (
prim_x_grad,
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/test_distributed_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
}

is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

SUPPORTED_RECIPES = []
if is_fp8_supported:
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/test_distributed_layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@


is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

SUPPORTED_RECIPES = []
if is_fp8_supported:
Expand Down
4 changes: 2 additions & 2 deletions tests/jax/test_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def enable_fused_attn():


is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

QUANTIZE_RECIPES = []
""" Find supported scaling modes"""
Expand Down Expand Up @@ -313,7 +313,7 @@ def test_backward(
test_others,
test_layer,
)
if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
_, updated_quantize_meta = flax.core.pop(
updated_state[0], QuantizeConfig.COLLECTION_NAME
)
Expand Down
Loading