Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ run_transformer_benchmark: &run_transformer_benchmark
command: |
python benchmarks/transformer.py

run_oss_benchmark: &run_oss_benchmark
- run:
name: Run OSS Benchmark
command: |
python benchmarks/oss.py

# -------------------------------------------------------------------------------------
# Jobs to run
# -------------------------------------------------------------------------------------
Expand Down Expand Up @@ -244,6 +250,9 @@ jobs:

- <<: *run_transformer_benchmark

- <<: *run_oss_benchmark



workflows:
version: 2
Expand Down
150 changes: 150 additions & 0 deletions benchmarks/oss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import argparse
import math
import os
import time
from typing import Any, List

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

from fairscale.optim.oss import OSS

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore


def dist_init(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)


def train(
rank: int,
world_size: int,
num_epochs: int = 10,
batch_size: int = 32,
data_size: int = 200,
use_oss: bool = True,
check_regression: bool = True,
reference_speed: float = -1.0,
):
# DDP
dist_init(rank, world_size)

# Standard RN101
model = resnet101(pretrained=False, progress=True).to(rank)

# Data setup, dummy data
def collate(inputs: List[Any]):
return {
"inputs": torch.stack([i[0] for i in inputs]).to(rank),
"label": torch.stack([i[1] for i in inputs]).to(rank),
}

def _print(msg):
if dist.get_rank() == 0:
print(msg)

dataloader = DataLoader(
dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
)
loss_fn = nn.CrossEntropyLoss()

# Shard the optimizer
optimizer = (
OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
if use_oss
else torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
)

# Dummy training loop
torch.cuda.synchronize(rank)
training_start = time.monotonic()
model.train()

measurements = []

for epoch in range(num_epochs):
epoch_start = time.monotonic()

for batch in dataloader:

def closure():
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
loss /= world_size
loss.backward()
return loss

optimizer.step(closure)

epoch_end = time.monotonic()
measurements.append(data_size / (epoch_end - epoch_start))
_print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

torch.cuda.synchronize(rank)
training_stop = time.monotonic()
img_per_sec = data_size / (training_stop - training_start) * num_epochs
max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

if use_oss and check_regression and dist.get_rank() == 0:
# Compute the mean and average img per second
mean = sum(measurements) / len(measurements)
diff = map(lambda x: pow(x - mean, 2.0), measurements)
std = math.sqrt(sum(diff) / (len(measurements) - 1))
print(f"[Regression Test] Mean: {mean:.2f} +/- {std:.2f}")
assert (mean - 3.0 * std) < reference_speed, "Regression detected"
print("[Regression Test] VALID")


if __name__ == "__main__":

parser = argparse.ArgumentParser(
description="Benchmark the optimizer state sharding, on a typical computer vision workload"
)
parser.add_argument("--world_size", action="store", default=2, type=int)
parser.add_argument("--epochs", action="store", default=10, type=int)
parser.add_argument("--batch_size", action="store", default=32, type=int)
parser.add_argument("--data_size", action="store", default=512, type=int)
parser.add_argument("--check_regression", action="store", default=True, type=bool)
parser.add_argument("--reference_speed", action="store", default=39.82, type=float)

args = parser.parse_args()

print("\nBenchmark vanilla SGD")
mp.spawn(
train,
args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
nprocs=args.world_size,
join=True,
)

print("\nBenchmark OSS")
mp.spawn(
train,
args=(
args.world_size,
args.epochs,
args.batch_size,
args.data_size,
True,
args.check_regression,
args.reference_speed,
),
nprocs=args.world_size,
join=True,
)
6 changes: 5 additions & 1 deletion benchmarks/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,11 @@ def make_model(device, ntokens):

criterion = nn.CrossEntropyLoss()
lr = 0.01 # learning rate
optimizer = Adam(p.parameters(), lr=lr, precision=Precision.MIXED_PRECISION)

try:
optimizer = Adam(p.parameters(), lr=lr, precision=Precision.MIXED_PRECISION)
except NameError:
optimizer = Adam(p.parameters(), lr=lr)

return p, criterion, optimizer

Expand Down
4 changes: 4 additions & 0 deletions fairscale/optim/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def mixed_precision(self) -> bool:

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
super().load_state_dict(state_dict)

# TODO: Optimizer state gets cast to FP16 and back to FP32 for
# mixed-precision and memory-efficient mixed-precision. Eventually
# we want to fix this, as some precision may be lost
for group in self.param_groups:
for p in group["params"]:
self.state[p]["exp_avg"] = self.state[p]["exp_avg"].type(self.optim_type)
Expand Down
2 changes: 1 addition & 1 deletion fairscale/optim/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class OSS(Optimizer):
in_super_constructor: bool

def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any = dist.group.WORLD, **defaults: Any):
# Hold all the nmodel params in the root .param_groups
# Hold all the model params in the root .param_groups
self.in_super_constructor = True
super().__init__(params, defaults)
self.in_super_constructor = False
Expand Down
1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ pytest == 5.4.1
pytest-cov == 2.10.0
torchtext == 0.6.0
torch >= 1.5.1
torchvision >= 0.6.0
# NOTE(msb) not a dependency but needed by torch
numpy == 1.17.4
34 changes: 32 additions & 2 deletions tests/optim/test_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
skip_if_no_adam = pytest.mark.skipif(not imported_adam, reason="Fairscale Adam not available")


@pytest.fixture(autouse=True)
def set_torch_seed():
torch.manual_seed(1)
yield


def make_full_precision_params():
weight = torch.randn(2, 1).cuda().requires_grad_()
bias = torch.randn(2).cuda().requires_grad_()
Expand Down Expand Up @@ -75,12 +81,26 @@ def fn_base(optimizer, weight, bias, input):
# Load state dict
state_dict = deepcopy(optimizer.state_dict())
optimizer_c.load_state_dict(state_dict)

for group, group_c in zip(optimizer.param_groups, optimizer_c.param_groups):
for p, p_c in zip(group["params"], group_c["params"]):
assert torch.equal(optimizer.state[p]["exp_avg"], optimizer_c.state[p_c]["exp_avg"])
assert torch.equal(optimizer.state[p]["exp_avg_sq"], optimizer_c.state[p_c]["exp_avg_sq"])

if optimizer.fp32_param_groups:
# When using mixed precision, fp32_param_groups are made from FP16 params rather than
# copied via state_dict, introducing differences between the original optimizer and
# the copy. Because this test requires that they be the exact same, we copy the
# fp32 params from the original optimizer to the copy
optimizer_c.fp32_param_groups = deepcopy(optimizer.fp32_param_groups)

# Run both optimizations in parallel
for _i in range(5):
optimizer.step(fn)
optimizer_c.step(fn_c)
(weight - weight_c).to("cpu").detach().apply_(assert_almost_zero)
(bias - bias_c).to("cpu").detach().apply_(assert_almost_zero)

assert torch.equal(weight, weight_c)
assert torch.equal(bias, bias_c)


def assert_almost_zero(x):
Expand Down Expand Up @@ -230,7 +250,12 @@ def test_state_dict_full_precision():

@skip_if_no_cuda
@skip_if_no_adam
@pytest.mark.xfail
def test_state_dict_mixed_precision():
# TODO: Optimizer state gets cast to FP16 and back to FP32 for
# mixed-precision and memory-efficient mixed-precision, resulting
# in a potential loss of precision. Thus, as training proceeds, we don't
# necessarily expect the parameters to remain the exact same.
weight, bias, input = make_half_precision_params()
optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MIXED_PRECISION)

Expand All @@ -239,7 +264,12 @@ def test_state_dict_mixed_precision():

@skip_if_no_cuda
@skip_if_no_adam
@pytest.mark.xfail
def test_state_dict_memory_efficient():
# TODO: Optimizer state gets cast to FP16 and back to FP32 for
# mixed-precision and memory-efficient mixed-precision, resulting
# in a potential loss of precision. Thus, as training proceeds, we don't
# necessarily expect the parameters to remain the exact same.
weight, bias, input = make_half_precision_params()
optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)

Expand Down