diff --git a/benchmarks/transformer.py b/benchmarks/transformer.py index 962871668..7c2cbfce8 100644 --- a/benchmarks/transformer.py +++ b/benchmarks/transformer.py @@ -135,7 +135,11 @@ def make_model(device, ntokens): criterion = nn.CrossEntropyLoss() lr = 0.01 # learning rate - optimizer = Adam(p.parameters(), lr=lr, precision=Precision.MIXED_PRECISION) + + try: + optimizer = Adam(p.parameters(), lr=lr, precision=Precision.MIXED_PRECISION) + except NameError: + optimizer = Adam(p.parameters(), lr=lr) return p, criterion, optimizer diff --git a/fairscale/optim/adam.py b/fairscale/optim/adam.py index 905eec5c4..531ea7dce 100644 --- a/fairscale/optim/adam.py +++ b/fairscale/optim/adam.py @@ -147,6 +147,10 @@ def mixed_precision(self) -> bool: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: super().load_state_dict(state_dict) + + # TODO: Optimizer state gets cast to FP16 and back to FP32 for + # mixed-precision and memory-efficient mixed-precision. Eventually + # we want to fix this, as some precision may be lost for group in self.param_groups: for p in group["params"]: self.state[p]["exp_avg"] = self.state[p]["exp_avg"].type(self.optim_type) diff --git a/tests/optim/test_adam.py b/tests/optim/test_adam.py index 54ff8a207..cd047bc35 100644 --- a/tests/optim/test_adam.py +++ b/tests/optim/test_adam.py @@ -20,6 +20,12 @@ skip_if_no_adam = pytest.mark.skipif(not imported_adam, reason="Fairscale Adam not available") +@pytest.fixture(autouse=True) +def set_torch_seed(): + torch.manual_seed(1) + yield + + def make_full_precision_params(): weight = torch.randn(2, 1).cuda().requires_grad_() bias = torch.randn(2).cuda().requires_grad_() @@ -75,12 +81,26 @@ def fn_base(optimizer, weight, bias, input): # Load state dict state_dict = deepcopy(optimizer.state_dict()) optimizer_c.load_state_dict(state_dict) + + for group, group_c in zip(optimizer.param_groups, optimizer_c.param_groups): + for p, p_c in zip(group["params"], group_c["params"]): + assert torch.equal(optimizer.state[p]["exp_avg"], optimizer_c.state[p_c]["exp_avg"]) + assert torch.equal(optimizer.state[p]["exp_avg_sq"], optimizer_c.state[p_c]["exp_avg_sq"]) + + if optimizer.fp32_param_groups: + # When using mixed precision, fp32_param_groups are made from FP16 params rather than + # copied via state_dict, introducing differences between the original optimizer and + # the copy. Because this test requires that they be the exact same, we copy the + # fp32 params from the original optimizer to the copy + optimizer_c.fp32_param_groups = deepcopy(optimizer.fp32_param_groups) + # Run both optimizations in parallel for _i in range(5): optimizer.step(fn) optimizer_c.step(fn_c) - (weight - weight_c).to("cpu").detach().apply_(assert_almost_zero) - (bias - bias_c).to("cpu").detach().apply_(assert_almost_zero) + + assert torch.equal(weight, weight_c) + assert torch.equal(bias, bias_c) def assert_almost_zero(x): @@ -230,7 +250,12 @@ def test_state_dict_full_precision(): @skip_if_no_cuda @skip_if_no_adam +@pytest.mark.xfail def test_state_dict_mixed_precision(): + # TODO: Optimizer state gets cast to FP16 and back to FP32 for + # mixed-precision and memory-efficient mixed-precision, resulting + # in a potential loss of precision. Thus, as training proceeds, we don't + # necessarily expect the parameters to remain the exact same. weight, bias, input = make_half_precision_params() optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MIXED_PRECISION) @@ -239,7 +264,12 @@ def test_state_dict_mixed_precision(): @skip_if_no_cuda @skip_if_no_adam +@pytest.mark.xfail def test_state_dict_memory_efficient(): + # TODO: Optimizer state gets cast to FP16 and back to FP32 for + # mixed-precision and memory-efficient mixed-precision, resulting + # in a potential loss of precision. Thus, as training proceeds, we don't + # necessarily expect the parameters to remain the exact same. weight, bias, input = make_half_precision_params() optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)