diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py
index 9c1a842cd8..ddc79af426 100644
--- a/tests/pytorch/test_fusible_ops.py
+++ b/tests/pytorch/test_fusible_ops.py
@@ -1420,15 +1420,17 @@ def test_activation(
         test_device=device,
         test_is_fp8=quantized_compute,
     )
-    if quantized_compute:
-        with torch.no_grad():
-            x_test = x_test.dequantize().requires_grad_()
     dy_ref, dy_test = make_reference_and_test_tensors(
         out_shape,
         test_dtype=dtype,
         test_device=device,
+        test_is_fp8=quantized_compute,
         requires_grad=False,
     )
+    if quantized_compute:
+        with torch.no_grad():
+            x_test = x_test.dequantize().requires_grad_()
+            dy_test = dy_test.dequantize()
 
     # Plain PyTorch implementation
     y_ref: torch.Tensor
@@ -1459,6 +1461,7 @@ def test_activation(
         swiglu=te_ops.SwiGLU,
     )[activation]
     forward = te_ops.Sequential(
+        te_ops.Quantize(forward=False, backward=quantized_compute),
         make_op(),
         te_ops.Quantize(forward=quantized_compute, backward=False),
     )
diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py
index cb93eb5e6b..b451acea9a 100644
--- a/transformer_engine/pytorch/ops/basic/basic_linear.py
+++ b/transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -523,7 +523,7 @@ def _functional_forward(
 
         # Configure input tensor for backward pass
         if own_quantized_x_local:
-            x_local.update_usage(rowwise_usage=False)
+            x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
 
         # Detach input tensor if needed
         # Note: PyTorch autograd produces esoteric errors if we save
@@ -679,7 +679,9 @@ def _functional_backward(
                     quantizer=input_quantizer,
                 )
             else:
-                if not isinstance(x_local, QuantizedTensor):
+                if isinstance(x_local, QuantizedTensor):
+                    x_local.update_usage(columnwise_usage=True)
+                else:
                     x_local = input_quantizer(x_local)
             x = x_local
         else:
@@ -706,15 +708,19 @@ def _functional_backward(
                 raise ValueError("Weight tensor is required to compute input grad")
             w = weight
             w_is_quantized = isinstance(w, QuantizedTensor)
-            if with_quantized_compute and not w_is_quantized:
-                if weight_quantizer is None:
-                    raise ValueError("Missing quantizer for weight tensor")
-                weight_quantizer.set_usage(columnwise=True)
-                w = weight_quantizer(w)
-            elif not with_quantized_compute and w_is_quantized:
-                w = w.dequantize()
-            if not with_quantized_compute and w.dtype != dtype:
-                w = w.to(dtype=dtype)
+            if with_quantized_compute:
+                if w_is_quantized:
+                    w.update_usage(columnwise_usage=True)
+                else:
+                    if weight_quantizer is None:
+                        raise ValueError("Missing quantizer for weight tensor")
+                    weight_quantizer.set_usage(columnwise=True)
+                    w = weight_quantizer(w)
+            else:
+                if w_is_quantized:
+                    w = w.dequantize(dtype=dtype)
+                elif w.dtype != dtype:
+                    w = w.to(dtype=dtype)
 
             # Synchronize tensor-parallel communication
             _wait_async(dy_async)
@@ -867,8 +873,8 @@ def op_forward(
         # Configure quantizers
         # Note: We cache the quantized input for backward pass,
        # but discard the quantized weights.
-        input_quantizer.set_usage(columnwise=weight_requires_grad)
-        weight_quantizer.set_usage(columnwise=False)
+        input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
+        weight_quantizer.set_usage(rowwise=True, columnwise=False)
 
         # Get autocast dtype if needed
         dtype = None