import torch
from transformer_engine.pytorch import Linear as TELinear, fp8_autocast
# m = torch.nn.Linear(16, 16).to("cuda") # This works
m = TELinear(16, 16)
x = torch.randn(16, 16, device='cuda')
with fp8_autocast(enabled=True):
    o = m(x).sum()
o.backward(retain_graph=True)
# The second backward call below fails with:
# AssertionError: FP8 execution requires 2D input matrices with height divisible by 8 and width divisible by 16, but got tensor with dims=[0]
# It looks like TELinear's backward mutates the autograd context object such that it is not reusable for a second backward.
o.backward()
Supporting repeated backward calls over the same graph (via retain_graph=True) would be useful for benchmarking just the backward pass.
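
In the meantime, a rough workaround for backward-only benchmarking is to rebuild the graph on every iteration by re-running the forward pass under fp8_autocast and timing only the backward. The sketch below is illustrative, not part of the report above: the benchmark_backward helper and its timing placement are assumptions, and it reuses the imports from the snippet above plus the standard time module.

import time

def benchmark_backward(module, inp, iters=10):
    timings = []
    for _ in range(iters):
        module.zero_grad(set_to_none=True)
        # Re-run the forward pass so each backward sees a fresh graph / FP8 context.
        with fp8_autocast(enabled=True):
            out = module(inp).sum()
        torch.cuda.synchronize()
        start = time.perf_counter()
        out.backward()
        torch.cuda.synchronize()
        timings.append(time.perf_counter() - start)
    return timings

# e.g. print(benchmark_backward(m, x))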