From b43ffad2b918b115fcc25cbc31bd7420c890ca39 Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Sat, 3 Apr 2021 21:23:03 +0100 Subject: [PATCH 1/3] Enabled partial checkpoint loading Allowing partial loading of a model via strict=False. Signed-off-by: Petru-Daniel Tudosiu --- monai/handlers/checkpoint_loader.py | 6 ++- tests/test_handler_checkpoint_loader.py | 53 +++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index bb67428bef..b8c2fc6764 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -44,6 +44,8 @@ class CheckpointLoader: first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. + strict: whether to strictly enforce that the keys in :attr:`state_dict` match the keys + returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` """ @@ -53,6 +55,7 @@ def __init__( load_dict: Dict, name: Optional[str] = None, map_location: Optional[Dict] = None, + strict: bool = True, ) -> None: if load_path is None: raise AssertionError("must provide clear path to load checkpoint.") @@ -63,6 +66,7 @@ def __init__( self.load_dict = load_dict self._name = name self.map_location = map_location + self.strict = strict def attach(self, engine: Engine) -> None: """ @@ -82,7 +86,7 @@ def __call__(self, engine: Engine) -> None: # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs - Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint) + Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, **{"strict": self.strict}) if engine.state.epoch > prior_max_epochs: raise ValueError( f"Epoch count ({engine.state.epoch}) in checkpoint is larger than " diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index 838cc3f4dd..f6937ad8fe 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -38,7 +38,7 @@ def test_one_save_one_load(self): engine1.run([0] * 8, max_epochs=5) path = tempdir + "/checkpoint_final_iteration=40.pt" engine2 = Engine(lambda e, b: None) - CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine2}).attach(engine2) + CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine2}, strict=True).attach(engine2) @engine2.on(Events.STARTED) def check_epoch(engine: Engine): @@ -49,7 +49,7 @@ def check_epoch(engine: Engine): # test bad case with max_epochs smaller than current epoch engine3 = Engine(lambda e, b: None) - CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine3}).attach(engine3) + CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine3}, strict=True).attach(engine3) try: engine3.run([0] * 8, max_epochs=3) @@ -75,7 +75,7 @@ def test_two_save_one_load(self): engine.run([0] * 8, max_epochs=5) path = tempdir + "/checkpoint_final_iteration=40.pt" engine = Engine(lambda e, b: None) - CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) + CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=True).attach(engine) engine.run([0] * 8, max_epochs=1) torch.testing.assert_allclose(net2.state_dict()["weight"], torch.tensor([0.1])) @@ -96,10 +96,55 @@ def test_save_single_device_load_multi_devices(self): engine.run([0] * 8, max_epochs=5) path = tempdir + "/net_final_iteration=40.pt" engine = Engine(lambda e, b: None) - CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) + CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=True).attach(engine) engine.run([0] * 8, max_epochs=1) torch.testing.assert_allclose(net2.state_dict()["module.weight"].cpu(), torch.tensor([0.1])) + def test_partial_under_load(self): + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data1 = net1.state_dict() + data1["0.weight"] = torch.tensor([0.1]) + data1["1.weight"] = torch.tensor([0.2]) + net1.load_state_dict(data1) + + net2 = torch.nn.Sequential(*[torch.nn.PReLU()]) + data2 = net2.state_dict() + data2["0.weight"] = torch.tensor([0.3]) + net2.load_state_dict(data2) + + with tempfile.TemporaryDirectory() as tempdir: + engine = Engine(lambda e, b: None) + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/net_final_iteration=40.pt" + engine = Engine(lambda e, b: None) + CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1])) + + def test_partial_over_load(self): + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + net1 = torch.nn.Sequential(*[torch.nn.PReLU()]) + data1 = net1.state_dict() + data1["0.weight"] = torch.tensor([0.1]) + net1.load_state_dict(data1) + + net2 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data2 = net2.state_dict() + data2["0.weight"] = torch.tensor([0.2]) + data2["1.weight"] = torch.tensor([0.3]) + net2.load_state_dict(data2) + + with tempfile.TemporaryDirectory() as tempdir: + engine = Engine(lambda e, b: None) + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/net_final_iteration=40.pt" + engine = Engine(lambda e, b: None) + CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1])) if __name__ == "__main__": unittest.main() From 52ca10567fac0bf6368a410baaacc0ebd810cdaa Mon Sep 17 00:00:00 2001 From: monai-bot Date: Sat, 3 Apr 2021 21:13:19 +0000 Subject: [PATCH 2/3] [MONAI] python code formatting Signed-off-by: monai-bot --- tests/test_handler_checkpoint_loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index f6937ad8fe..d58260ac8c 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -146,5 +146,6 @@ def test_partial_over_load(self): engine.run([0] * 8, max_epochs=1) torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1])) + if __name__ == "__main__": unittest.main() From 29b3374e5b285521401df31b18dcc6b7a32818cb Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 4 Apr 2021 08:41:46 +0800 Subject: [PATCH 3/3] [DLMED] simplify strict arg Signed-off-by: Nic Ma --- monai/handlers/checkpoint_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index b8c2fc6764..40483e8c85 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -86,7 +86,7 @@ def __call__(self, engine: Engine) -> None: # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs - Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, **{"strict": self.strict}) + Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict) if engine.state.epoch > prior_max_epochs: raise ValueError( f"Epoch count ({engine.state.epoch}) in checkpoint is larger than "