v1.6: FP8GlobalStateManager seems to be preserving state in distributed setting #814

@kshitij12345

Description

Training two identical Linear models (same weights and same inputs) in the same script in a distributed setting leads to different amax_history values. This works fine in the current stable release (v1.5).

Also, calling FP8GlobalStateManager.reset() between the two training loops makes the results match again (see the sketch below).
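A minimal sketch of that workaround, assuming the init_and_train_model() helper defined in the repro script below; FP8GlobalStateManager.reset() is the call named above, everything else comes from the script:

import transformer_engine.pytorch as te

model_1 = init_and_train_model()
# Clear the global FP8 bookkeeping that v1.6 otherwise carries over
# into the next, logically independent training run.
te.fp8.FP8GlobalStateManager.reset()
model_2 = init_and_train_model()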

# torchrun --nproc-per-node=2 ddp_stateful.py
import os

import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

torch.distributed.init_process_group(backend="nccl")
all_gpus = torch.distributed.new_group(backend="nccl")
rank = int(os.environ["LOCAL_RANK"])
device = f"cuda:{rank}"
torch.cuda.set_device(rank)

dim = 64
n_iter = 2
fp8_recipe = DelayedScaling(amax_history_len=4)

# This function initializes the model with constant weights (all 4s)
# and trains it for n_iter steps on fixed data.
def init_and_train_model():
    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = te.Linear(dim, dim, bias=False)
        
        def forward(self, x):
            return torch.nn.functional.relu(self.fc1(x))

    with torch.device("cuda"):
        model = Model()
        model.fc1.weight.data = torch.ones(dim, dim) * 4
        x = torch.arange(dim * dim, dtype=torch.float32).view(dim, dim)
        if rank == 1:
            # Each rank sees different input data, as in typical data parallelism.
            x = torch.ones(dim, dim) * 100

    ddp_model = DDP(model)

    optim = torch.optim.SGD(model.parameters(), lr=1e-3)

    for _ in range(n_iter):
        # Only the forward pass runs under fp8_autocast; the backward pass
        # and the optimizer step happen outside the context.
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            o = ddp_model(x).sum()

        o.backward()
        optim.step()
        optim.zero_grad()

    return model

model_1 = init_and_train_model()
# In v1.6 it is necessary to reset the global FP8 state here,
# otherwise the values below differ.
# te.fp8.FP8GlobalStateManager.reset()
model_2 = init_and_train_model()

if rank == 0:
    # In v1.5 both will print
    # tensor([[4.0950e+03, 4.2804e+03, 0.0000e+00],
    #         [0.0000e+00, 0.0000e+00, 0.0000e+00],
    #         [0.0000e+00, 0.0000e+00, 0.0000e+00],
    #         [4.0950e+03, 4.0000e+00, 0.0000e+00]], device='cuda:0')

    # In v1.6, without the te.fp8.FP8GlobalStateManager.reset() call,
    # the values are
    # tensor([[   0.0000,    0.0000,    0.0000],
    #         [4095.0000, 4280.4160,    0.0000],
    #         [   0.0000,    0.0000,    0.0000],
    #         [   0.0000,    0.0000,    0.0000]], device='cuda:0')
    # tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
    #         [0.0000e+00, 0.0000e+00, 0.0000e+00],
    #         [4.0950e+03, 4.0000e+00, 0.0000e+00],
    #         [4.0950e+03, 4.2804e+03, 0.0000e+00]], device='cuda:0')
    print(model_1.fc1.fp8_meta["scaling_fwd"].amax_history)
    print(model_2.fc1.fp8_meta["scaling_fwd"].amax_history)
    torch.testing.assert_close(model_1.fc1.fp8_meta["scaling_fwd"].amax_history,
                               model_2.fc1.fp8_meta["scaling_fwd"].amax_history)
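
For completeness, a hedged variant of the final check that asserts on every rank and shuts the process group down cleanly; only standard torch.distributed calls are used, and nothing in this snippet comes from the original report:

# Optional: compare the histories on every rank, not just rank 0.
h1 = model_1.fc1.fp8_meta["scaling_fwd"].amax_history
h2 = model_2.fc1.fp8_meta["scaling_fwd"].amax_history
torch.testing.assert_close(h1, h2)  # fails on v1.6 unless reset() is called

torch.distributed.barrier()
torch.distributed.destroy_process_group()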

Labels

1.6.0, bug (Something isn't working)
