Training two identical Linear models (same weights and same inputs) in the same script in a distributed setting leads to different amax_history values. This works fine in the current stable release (v1.5).
Calling FP8GlobalStateManager.reset() between the two training loops also avoids the problem.
# torchrun --nproc-per-node=2 ddp_stateful.py
import os

import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

world_group = torch.distributed.init_process_group(backend="nccl")
all_gpus = torch.distributed.new_group(backend="nccl")
rank = int(os.environ["LOCAL_RANK"])
device = f"cuda:{rank}"
torch.cuda.set_device(rank)
# print(device)

dim = 64
n_iter = 2
fp8_recipe = DelayedScaling(amax_history_len=4)


# This function initializes the model with all-ones weights
# and trains it for n_iter iterations on fixed data.
def init_and_train_model():
    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = te.Linear(dim, dim, bias=False)

        def forward(self, x):
            return torch.nn.functional.relu(self.fc1(x))

    with torch.device("cuda"):
        model = Model()
        model.fc1.weight.data = torch.ones(dim, dim) * 4
        x = torch.arange(dim * dim, dtype=torch.float32).view(dim, dim)
        if rank == 1:
            x = torch.ones(dim, dim) * 100

    ddp_model = DDP(model)
    optim = torch.optim.SGD(model.parameters())
    for _ in range(n_iter):
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            o = ddp_model(x).sum()
        o.backward()
        optim.step()
        optim.zero_grad()
    return model


model_1 = init_and_train_model()
# In v1.6, it is necessary to reset the state,
# otherwise the values differ.
# te.fp8.FP8GlobalStateManager.reset()
model_2 = init_and_train_model()

if rank == 0:
    # In v1.5 both prints show
    # tensor([[4.0950e+03, 4.2804e+03, 0.0000e+00],
    #         [0.0000e+00, 0.0000e+00, 0.0000e+00],
    #         [0.0000e+00, 0.0000e+00, 0.0000e+00],
    #         [4.0950e+03, 4.0000e+00, 0.0000e+00]], device='cuda:0')
    # In v1.6, without te.fp8.FP8GlobalStateManager.reset(),
    # the values are
    # tensor([[   0.0000,    0.0000, 0.0000],
    #         [4095.0000, 4280.4160, 0.0000],
    #         [   0.0000,    0.0000, 0.0000],
    #         [   0.0000,    0.0000, 0.0000]], device='cuda:0')
    # tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
    #         [0.0000e+00, 0.0000e+00, 0.0000e+00],
    #         [4.0950e+03, 4.0000e+00, 0.0000e+00],
    #         [4.0950e+03, 4.2804e+03, 0.0000e+00]], device='cuda:0')
    print(model_1.fc1.fp8_meta["scaling_fwd"].amax_history)
    print(model_2.fc1.fp8_meta["scaling_fwd"].amax_history)
    torch.testing.assert_close(
        model_1.fc1.fp8_meta["scaling_fwd"].amax_history,
        model_2.fc1.fp8_meta["scaling_fwd"].amax_history,
    )