Calling backward(retain_graph=True) multiple times with TE Layer does not work #990

@kshitij12345

Description

import torch
from transformer_engine.pytorch import Linear as TELinear, fp8_autocast

# m = torch.nn.Linear(16, 16).to("cuda")  # This works
m = TELinear(16, 16)
x = torch.randn(16, 16, device='cuda')

with fp8_autocast(True):
    o = m(x).sum()

o.backward(retain_graph=True)

# The second backward call fails with:
# AssertionError: FP8 execution requires 2D input matrices with height divisible by 8 and width divisible by 16, but got tensor with dims=[0]
# It looks like TELinear's backward mutates the context object, so the retained graph is not reusable.
o.backward()

Supporting this would be useful for benchmarking just the backward pass, as in the sketch below.
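For reference, a minimal sketch of the backward-only benchmarking pattern this refers to, using a plain torch.nn.Linear (which, per the commented-out line in the repro, handles repeated backward calls). The warmup and iteration counts and the timing setup are illustrative assumptions, not part of the original report:

import time
import torch

# Benchmark only the backward pass: run the forward once, then time repeated
# backward() calls while retaining the autograd graph between them.
m = torch.nn.Linear(16, 16).to("cuda")
x = torch.randn(16, 16, device="cuda", requires_grad=True)
o = m(x).sum()

def run_backward():
    o.backward(retain_graph=True)
    m.zero_grad(set_to_none=True)  # discard accumulated grads between timings
    x.grad = None

for _ in range(3):  # warmup
    run_backward()

torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    run_backward()
torch.cuda.synchronize()
print(f"avg backward time: {(time.perf_counter() - start) / 100 * 1e3:.3f} ms")

With a TE layer under fp8_autocast, the second call inside run_backward currently raises the assertion shown above.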

    Labels

    enhancement (New feature or request)
