FusedLayerNorm corrupts data when switching devices #1022

@stas00

Description

I am porting transformers' BART to model parallel, so there is a lot of device switching, and I run into:

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemm(...)`

when using apex's FusedLayerNorm.

This problem doesn't exist with torch.nn.LayerNorm in pt-nightly (which was ported from FusedLayerNorm pytorch/pytorch#27634).

I was able to reproduce it quite simply:

# fused.py
import torch
from apex.normalization import FusedLayerNorm

class Norm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = torch.nn.LayerNorm(4)
        self.ln2 = FusedLayerNorm(4)

    def forward(self, x):
        y2 = self.ln2(x)
        print(f"apex  : {y2}")
        y1 = self.ln1(x)
        print(f"native: {y1}")

model = Norm()

x = torch.tensor([[1.5, 0.0, 0.0, 0.0]])
for id in [0, 1]:
    print(f"ID {id}")
    x = x.to(id)
    model.to(id)
    model(x)

Two GPUs are needed to see the problem:

python fused.py
ID 0
apex  : tensor([[ 1.7320, -0.5773, -0.5773, -0.5773]], device='cuda:0',
       grad_fn=<FusedLayerNormAffineFunctionBackward>)
native: tensor([[ 1.7320, -0.5773, -0.5773, -0.5773]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward>)
ID 1
apex  : tensor([[0., 0., 0., 0.]], device='cuda:1',
       grad_fn=<FusedLayerNormAffineFunctionBackward>)
native: tensor([[ 1.7320, -0.5773, -0.5773, -0.5773]], device='cuda:1',
       grad_fn=<NativeLayerNormBackward>)

As you can see, apex's norm zeroed out the tensor, and if I pass the corrupted output to some other PyTorch op, it blows up with CUBLAS_STATUS_EXECUTION_FAILED.

If, however, I flip the order of the two norm calls:

    def forward(self, x):
        y1 = self.ln1(x)
        print(f"native: {y1}")
        y2 = self.ln2(x)
        print(f"apex  : {y2}")

everything works:

ID 0
native: tensor([[ 1.7320, -0.5773, -0.5773, -0.5773]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward>)
apex  : tensor([[ 1.7320, -0.5773, -0.5773, -0.5773]], device='cuda:0',
       grad_fn=<FusedLayerNormAffineFunctionBackward>)
ID 1
native: tensor([[ 1.7320, -0.5773, -0.5773, -0.5773]], device='cuda:1',
       grad_fn=<NativeLayerNormBackward>)
apex  : tensor([[ 1.7320, -0.5773, -0.5773, -0.5773]], device='cuda:1',
       grad_fn=<FusedLayerNormAffineFunctionBackward>)

So apex's version is missing something w.r.t. switching to a new device.

The workaround I found in this simple script is to call torch.cuda.set_device(id) before invoking FusedLayerNorm:

x = torch.tensor([[1.5, 0.0, 0.0, 0.0]])
for id in [0, 1]:
    print(f"ID {id}")
    x = x.to(id)
    model.to(id)
    torch.cuda.set_device(id)  # make the input's device current before apex's kernel launches
    model(x)
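The set_device workaround above can also be wrapped in a small helper so callers don't have to remember to switch the current device at every call site. This is only a sketch: `call_on_device` is a hypothetical name, not part of apex or PyTorch, and it uses `torch.cuda.device` as a scoped alternative to `torch.cuda.set_device` (it restores the previous device on exit):

```python
import torch

def call_on_device(module, x):
    """Hypothetical helper: run `module(x)` with the current CUDA device
    set to x's device. Guards against kernels that launch on the current
    device instead of the input tensor's device."""
    if x.is_cuda:
        with torch.cuda.device(x.device):
            return module(x)
    return module(x)  # CPU tensors need no device guard

# Usage sketch (two GPUs are needed to exercise the original bug):
# model = Norm()
# x = torch.tensor([[1.5, 0.0, 0.0, 0.0]])
# for id in [0, 1]:
#     x = x.to(id)
#     model.to(id)
#     call_on_device(model, x)
```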

This is with pytorch-nightly, Python 3.8, and apex master.
