Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,15 +1420,17 @@ def test_activation(
test_device=device,
test_is_fp8=quantized_compute,
)
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if quantized_compute:
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
dy_test = dy_test.dequantize()

# Plain PyTorch implementation
y_ref: torch.Tensor
Expand Down Expand Up @@ -1459,6 +1461,7 @@ def test_activation(
swiglu=te_ops.SwiGLU,
)[activation]
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantized_compute),
make_op(),
te_ops.Quantize(forward=quantized_compute, backward=False),
)
Expand Down
32 changes: 19 additions & 13 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def _functional_forward(

# Configure input tensor for backward pass
if own_quantized_x_local:
x_local.update_usage(rowwise_usage=False)
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
Expand Down Expand Up @@ -679,7 +679,9 @@ def _functional_backward(
quantizer=input_quantizer,
)
else:
if not isinstance(x_local, QuantizedTensor):
if isinstance(x_local, QuantizedTensor):
x_local.update_usage(columnwise_usage=True)
else:
x_local = input_quantizer(x_local)
x = x_local
else:
Expand All @@ -706,15 +708,19 @@ def _functional_backward(
raise ValueError("Weight tensor is required to compute input grad")
w = weight
w_is_quantized = isinstance(w, QuantizedTensor)
if with_quantized_compute and not w_is_quantized:
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
if not with_quantized_compute and w.dtype != dtype:
w = w.to(dtype=dtype)
if with_quantized_compute:
if w_is_quantized:
w.update_usage(columnwise_usage=True)
else:
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w)
else:
if w_is_quantized:
w = w.dequantize(dtype=dtype)
elif w.dtype != dtype:
w = w.to(dtype=dtype)

# Synchronize tensor-parallel communication
_wait_async(dy_async)
Expand Down Expand Up @@ -867,8 +873,8 @@ def op_forward(
# Configure quantizers
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
input_quantizer.set_usage(columnwise=weight_requires_grad)
weight_quantizer.set_usage(columnwise=False)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)

# Get autocast dtype if needed
dtype = None
Expand Down