diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index bb67428bef..40483e8c85 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -44,6 +44,8 @@ class CheckpointLoader: first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. + strict: whether to strictly enforce that the keys in :attr:`state_dict` match the keys + returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` """ @@ -53,6 +55,7 @@ def __init__( load_dict: Dict, name: Optional[str] = None, map_location: Optional[Dict] = None, + strict: bool = True, ) -> None: if load_path is None: raise AssertionError("must provide clear path to load checkpoint.") @@ -63,6 +66,7 @@ def __init__( self.load_dict = load_dict self._name = name self.map_location = map_location + self.strict = strict def attach(self, engine: Engine) -> None: """ @@ -82,7 +86,7 @@ def __call__(self, engine: Engine) -> None: # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs - Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint) + Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict) if engine.state.epoch > prior_max_epochs: raise ValueError( f"Epoch count ({engine.state.epoch}) in checkpoint is larger than " diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index 838cc3f4dd..d58260ac8c 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -38,7 +38,7 @@ def test_one_save_one_load(self): engine1.run([0] * 8, max_epochs=5) path = tempdir + "/checkpoint_final_iteration=40.pt" engine2 = Engine(lambda e, b: None) - CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine2}).attach(engine2) + CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine2}, strict=True).attach(engine2) @engine2.on(Events.STARTED) def check_epoch(engine: Engine): @@ -49,7 +49,7 @@ def check_epoch(engine: Engine): # test bad case with max_epochs smaller than current epoch engine3 = Engine(lambda e, b: None) - CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine3}).attach(engine3) + CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine3}, strict=True).attach(engine3) try: engine3.run([0] * 8, max_epochs=3) @@ -75,7 +75,7 @@ def test_two_save_one_load(self): engine.run([0] * 8, max_epochs=5) path = tempdir + "/checkpoint_final_iteration=40.pt" engine = Engine(lambda e, b: None) - CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) + CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=True).attach(engine) engine.run([0] * 8, max_epochs=1) torch.testing.assert_allclose(net2.state_dict()["weight"], torch.tensor([0.1])) @@ -96,10 +96,56 @@ def test_save_single_device_load_multi_devices(self): engine.run([0] * 8, max_epochs=5) path = tempdir + "/net_final_iteration=40.pt" engine = Engine(lambda e, b: None) - CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) + CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=True).attach(engine) engine.run([0] * 8, max_epochs=1) torch.testing.assert_allclose(net2.state_dict()["module.weight"].cpu(), torch.tensor([0.1])) + def test_partial_under_load(self): + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data1 = net1.state_dict() + data1["0.weight"] = torch.tensor([0.1]) + data1["1.weight"] = torch.tensor([0.2]) + net1.load_state_dict(data1) + + net2 = torch.nn.Sequential(*[torch.nn.PReLU()]) + data2 = net2.state_dict() + data2["0.weight"] = torch.tensor([0.3]) + net2.load_state_dict(data2) + + with tempfile.TemporaryDirectory() as tempdir: + engine = Engine(lambda e, b: None) + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/net_final_iteration=40.pt" + engine = Engine(lambda e, b: None) + CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1])) + + def test_partial_over_load(self): + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + net1 = torch.nn.Sequential(*[torch.nn.PReLU()]) + data1 = net1.state_dict() + data1["0.weight"] = torch.tensor([0.1]) + net1.load_state_dict(data1) + + net2 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data2 = net2.state_dict() + data2["0.weight"] = torch.tensor([0.2]) + data2["1.weight"] = torch.tensor([0.3]) + net2.load_state_dict(data2) + + with tempfile.TemporaryDirectory() as tempdir: + engine = Engine(lambda e, b: None) + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/net_final_iteration=40.pt" + engine = Engine(lambda e, b: None) + CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1])) + if __name__ == "__main__": unittest.main()