diff --git a/.circleci/config.yml b/.circleci/config.yml
index 00f151d9e..2925c2b5c 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -96,6 +96,12 @@ run_transformer_benchmark: &run_transformer_benchmark
       command: |
         python benchmarks/transformer.py
 
+run_oss_benchmark: &run_oss_benchmark
+  - run:
+      name: Run OSS Benchmark
+      command: |
+        python benchmarks/oss.py
+
 # -------------------------------------------------------------------------------------
 # Jobs to run
 # -------------------------------------------------------------------------------------
@@ -244,6 +250,9 @@ jobs:
 
       - <<: *run_transformer_benchmark
 
+      - <<: *run_oss_benchmark
+
+
 workflows:
   version: 2
diff --git a/benchmarks/oss.py b/benchmarks/oss.py
new file mode 100755
index 000000000..bcc434b2c
--- /dev/null
+++ b/benchmarks/oss.py
@@ -0,0 +1,150 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+
+import argparse
+import math
+import os
+import time
+from typing import Any, List
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torchvision.datasets import FakeData
+from torchvision.models import resnet101
+from torchvision.transforms import ToTensor
+
+from fairscale.optim.oss import OSS
+
+BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
+
+
+def dist_init(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "29501"
+    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
+
+
+def train(
+    rank: int,
+    world_size: int,
+    num_epochs: int = 10,
+    batch_size: int = 32,
+    data_size: int = 200,
+    use_oss: bool = True,
+    check_regression: bool = True,
+    reference_speed: float = -1.0,
+):
+    # DDP
+    dist_init(rank, world_size)
+
+    # Standard RN101
+    model = resnet101(pretrained=False, progress=True).to(rank)
+
+    # Data setup, dummy data
+    def collate(inputs: List[Any]):
+        return {
+            "inputs": torch.stack([i[0] for i in inputs]).to(rank),
+            "label": torch.stack([i[1] for i in inputs]).to(rank),
+        }
+
+    def _print(msg):
+        if dist.get_rank() == 0:
+            print(msg)
+
+    dataloader = DataLoader(
+        dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
+    )
+    loss_fn = nn.CrossEntropyLoss()
+
+    # Shard the optimizer
+    optimizer = (
+        OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
+        if use_oss
+        else torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
+    )
+
+    # Dummy training loop
+    torch.cuda.synchronize(rank)
+    training_start = time.monotonic()
+    model.train()
+
+    measurements = []
+
+    for epoch in range(num_epochs):
+        epoch_start = time.monotonic()
+
+        for batch in dataloader:
+
+            def closure():
+                model.zero_grad()
+                outputs = model(batch["inputs"])
+                loss = loss_fn(outputs, batch["label"])
+                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
+                loss /= world_size
+                loss.backward()
+                return loss
+
+            optimizer.step(closure)
+
+        epoch_end = time.monotonic()
+        measurements.append(data_size / (epoch_end - epoch_start))
+        _print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
+
+    torch.cuda.synchronize(rank)
+    training_stop = time.monotonic()
+    img_per_sec = data_size / (training_stop - training_start) * num_epochs
+    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
+
+    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
+    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
+
+    if use_oss and check_regression and dist.get_rank() == 0:
+        # Compute the mean and standard deviation of the measured img per second
+        mean = sum(measurements) / len(measurements)
+        diff = map(lambda x: pow(x - mean, 2.0), measurements)
+        std = math.sqrt(sum(diff) / (len(measurements) - 1))
+        print(f"[Regression Test] Mean: {mean:.2f} +/- {std:.2f}")
+        assert (mean + 3.0 * std) > reference_speed, "Regression detected"
+        print("[Regression Test] VALID")
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(
+        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
+    )
+    parser.add_argument("--world_size", action="store", default=2, type=int)
+    parser.add_argument("--epochs", action="store", default=10, type=int)
+    parser.add_argument("--batch_size", action="store", default=32, type=int)
+    parser.add_argument("--data_size", action="store", default=512, type=int)
+    parser.add_argument("--check_regression", action="store", default=True, type=bool)
+    parser.add_argument("--reference_speed", action="store", default=39.82, type=float)
+
+    args = parser.parse_args()
+
+    print("\nBenchmark vanilla SGD")
+    mp.spawn(
+        train,
+        args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
+        nprocs=args.world_size,
+        join=True,
+    )
+
+    print("\nBenchmark OSS")
+    mp.spawn(
+        train,
+        args=(
+            args.world_size,
+            args.epochs,
+            args.batch_size,
+            args.data_size,
+            True,
+            args.check_regression,
+            args.reference_speed,
+        ),
+        nprocs=args.world_size,
+        join=True,
+    )
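For context, the regression check at the end of train() is a simple three-sigma test on per-epoch throughput: a run is flagged when the reference speed exceeds the measured mean by more than three standard deviations. A minimal, self-contained restatement of that logic (the helper name is ours, not part of the patch):

    import math

    def is_regression(measurements, reference_speed):
        # Sample mean and Bessel-corrected standard deviation of img/sec
        mean = sum(measurements) / len(measurements)
        var = sum((m - mean) ** 2 for m in measurements) / (len(measurements) - 1)
        # Regression: measured throughput is more than 3 sigma below the reference
        return reference_speed > mean + 3.0 * math.sqrt(var)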
diff --git a/benchmarks/transformer.py b/benchmarks/transformer.py
index 962871668..7c2cbfce8 100644
--- a/benchmarks/transformer.py
+++ b/benchmarks/transformer.py
@@ -135,7 +135,11 @@ def make_model(device, ntokens):
     criterion = nn.CrossEntropyLoss()
     lr = 0.01  # learning rate
 
-    optimizer = Adam(p.parameters(), lr=lr, precision=Precision.MIXED_PRECISION)
+
+    try:
+        optimizer = Adam(p.parameters(), lr=lr, precision=Precision.MIXED_PRECISION)
+    except NameError:
+        optimizer = Adam(p.parameters(), lr=lr)
 
     return p, criterion, optimizer
diff --git a/fairscale/optim/adam.py b/fairscale/optim/adam.py
index 905eec5c4..531ea7dce 100644
--- a/fairscale/optim/adam.py
+++ b/fairscale/optim/adam.py
@@ -147,6 +147,10 @@ def mixed_precision(self) -> bool:
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         super().load_state_dict(state_dict)
+
+        # TODO: Optimizer state gets cast to FP16 and back to FP32 for
+        # mixed-precision and memory-efficient mixed-precision. Eventually
+        # we want to fix this, as some precision may be lost
         for group in self.param_groups:
             for p in group["params"]:
                 self.state[p]["exp_avg"] = self.state[p]["exp_avg"].type(self.optim_type)
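The TODO above is easy to reproduce in isolation: FP16 has a 10-bit mantissa, so a single FP32 -> FP16 -> FP32 round trip quantizes most values. A two-line illustration, independent of fairscale:

    import torch

    x = torch.tensor(0.1, dtype=torch.float32)
    # 0.1 is not exactly representable in FP16; the round trip yields 0.0999755859375
    assert not torch.equal(x, x.half().float())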
diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py
index d8214ba57..de18f6d3b 100644
--- a/fairscale/optim/oss.py
+++ b/fairscale/optim/oss.py
@@ -49,7 +49,7 @@ class OSS(Optimizer):
     in_super_constructor: bool
 
     def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any = dist.group.WORLD, **defaults: Any):
-        # Hold all the nmodel params in the root .param_groups
+        # Hold all the model params in the root .param_groups
         self.in_super_constructor = True
         super().__init__(params, defaults)
         self.in_super_constructor = False
diff --git a/requirements-test.txt b/requirements-test.txt
index 66e3f1ba2..b3cacabe7 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -6,5 +6,6 @@ pytest == 5.4.1
 pytest-cov == 2.10.0
 torchtext == 0.6.0
 torch >= 1.5.1
+torchvision >= 0.6.0 # NOTE(msb) not a dependency of fairscale but needed by the benchmarks
 numpy == 1.17.4
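As a reminder of the API being exercised above, OSS wraps a vanilla optimizer class and shards its state across the ranks of a process group. A minimal sketch of the drop-in (the helper name is hypothetical; it assumes torch.distributed is already initialized, e.g. via a dist_init() like the one in benchmarks/oss.py):

    import torch
    from fairscale.optim.oss import OSS

    def build_sharded_optimizer(model: torch.nn.Module) -> OSS:
        # Same constructor call as in benchmarks/oss.py: OSS instantiates
        # torch.optim.SGD internally and keeps only this rank's share of state.
        return OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)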
diff --git a/tests/optim/test_adam.py b/tests/optim/test_adam.py
index 54ff8a207..cd047bc35 100644
--- a/tests/optim/test_adam.py
+++ b/tests/optim/test_adam.py
@@ -20,6 +20,12 @@
 skip_if_no_adam = pytest.mark.skipif(not imported_adam, reason="Fairscale Adam not available")
 
 
+@pytest.fixture(autouse=True)
+def set_torch_seed():
+    torch.manual_seed(1)
+    yield
+
+
 def make_full_precision_params():
     weight = torch.randn(2, 1).cuda().requires_grad_()
     bias = torch.randn(2).cuda().requires_grad_()
@@ -75,12 +81,26 @@ def fn_base(optimizer, weight, bias, input):
     # Load state dict
     state_dict = deepcopy(optimizer.state_dict())
     optimizer_c.load_state_dict(state_dict)
+
+    for group, group_c in zip(optimizer.param_groups, optimizer_c.param_groups):
+        for p, p_c in zip(group["params"], group_c["params"]):
+            assert torch.equal(optimizer.state[p]["exp_avg"], optimizer_c.state[p_c]["exp_avg"])
+            assert torch.equal(optimizer.state[p]["exp_avg_sq"], optimizer_c.state[p_c]["exp_avg_sq"])
+
+    if optimizer.fp32_param_groups:
+        # When using mixed precision, fp32_param_groups are made from FP16 params rather than
+        # copied via state_dict, introducing differences between the original optimizer and
+        # the copy. Because this test requires that they be the exact same, we copy the
+        # fp32 params from the original optimizer to the copy
+        optimizer_c.fp32_param_groups = deepcopy(optimizer.fp32_param_groups)
+
     # Run both optimizations in parallel
     for _i in range(5):
         optimizer.step(fn)
         optimizer_c.step(fn_c)
-        (weight - weight_c).to("cpu").detach().apply_(assert_almost_zero)
-        (bias - bias_c).to("cpu").detach().apply_(assert_almost_zero)
+
+        assert torch.equal(weight, weight_c)
+        assert torch.equal(bias, bias_c)
 
 
 def assert_almost_zero(x):
@@ -230,7 +250,12 @@ def test_state_dict_full_precision():
 
 @skip_if_no_cuda
 @skip_if_no_adam
+@pytest.mark.xfail
 def test_state_dict_mixed_precision():
+    # TODO: Optimizer state gets cast to FP16 and back to FP32 for
+    # mixed-precision and memory-efficient mixed-precision, resulting
+    # in a potential loss of precision. Thus, as training proceeds, we don't
+    # necessarily expect the parameters to remain the exact same.
     weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MIXED_PRECISION)
 
@@ -239,7 +264,12 @@
 
 @skip_if_no_cuda
 @skip_if_no_adam
+@pytest.mark.xfail
 def test_state_dict_memory_efficient():
+    # TODO: Optimizer state gets cast to FP16 and back to FP32 for
+    # mixed-precision and memory-efficient mixed-precision, resulting
+    # in a potential loss of precision. Thus, as training proceeds, we don't
+    # necessarily expect the parameters to remain the exact same.
     weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)
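A closing note on why the new autouse set_torch_seed fixture lets these tests assert torch.equal rather than approximate closeness: with a fixed seed, tensor initialization is bit-identical between the two optimizers under comparison. A standalone sketch:

    import torch

    torch.manual_seed(1)
    a = torch.randn(2, 1)
    torch.manual_seed(1)
    b = torch.randn(2, 1)
    # Bitwise-identical, not merely allclose
    assert torch.equal(a, b)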