Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion monai/handlers/checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class CheckpointLoader:
first load the module to CPU and then copy each parameter to where it was
saved, which would result in all processes on the same machine using the
same set of devices.
strict: whether to strictly enforce that the keys in :attr:`state_dict` match the keys
returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

"""

Expand All @@ -53,6 +55,7 @@ def __init__(
load_dict: Dict,
name: Optional[str] = None,
map_location: Optional[Dict] = None,
strict: bool = True,
) -> None:
if load_path is None:
raise AssertionError("must provide clear path to load checkpoint.")
Expand All @@ -63,6 +66,7 @@ def __init__(
self.load_dict = load_dict
self._name = name
self.map_location = map_location
self.strict = strict

def attach(self, engine: Engine) -> None:
"""
Expand All @@ -82,7 +86,7 @@ def __call__(self, engine: Engine) -> None:

# save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint
prior_max_epochs = engine.state.max_epochs
Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint)
Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict)
if engine.state.epoch > prior_max_epochs:
raise ValueError(
f"Epoch count ({engine.state.epoch}) in checkpoint is larger than "
Expand Down
54 changes: 50 additions & 4 deletions tests/test_handler_checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_one_save_one_load(self):
engine1.run([0] * 8, max_epochs=5)
path = tempdir + "/checkpoint_final_iteration=40.pt"
engine2 = Engine(lambda e, b: None)
CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine2}).attach(engine2)
CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine2}, strict=True).attach(engine2)

@engine2.on(Events.STARTED)
def check_epoch(engine: Engine):
Expand All @@ -49,7 +49,7 @@ def check_epoch(engine: Engine):

# test bad case with max_epochs smaller than current epoch
engine3 = Engine(lambda e, b: None)
CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine3}).attach(engine3)
CheckpointLoader(load_path=path, load_dict={"net": net2, "eng": engine3}, strict=True).attach(engine3)

try:
engine3.run([0] * 8, max_epochs=3)
Expand All @@ -75,7 +75,7 @@ def test_two_save_one_load(self):
engine.run([0] * 8, max_epochs=5)
path = tempdir + "/checkpoint_final_iteration=40.pt"
engine = Engine(lambda e, b: None)
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=True).attach(engine)
engine.run([0] * 8, max_epochs=1)
torch.testing.assert_allclose(net2.state_dict()["weight"], torch.tensor([0.1]))

Expand All @@ -96,10 +96,56 @@ def test_save_single_device_load_multi_devices(self):
engine.run([0] * 8, max_epochs=5)
path = tempdir + "/net_final_iteration=40.pt"
engine = Engine(lambda e, b: None)
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=True).attach(engine)
engine.run([0] * 8, max_epochs=1)
torch.testing.assert_allclose(net2.state_dict()["module.weight"].cpu(), torch.tensor([0.1]))

def test_partial_under_load(self):
    """Load a 2-layer checkpoint into a 1-layer network with ``strict=False``.

    With non-strict loading, the overlapping parameter ("0.weight") is
    restored from the checkpoint and the extra key ("1.weight") is
    ignored instead of raising.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # source network: two PReLU layers with known weights 0.1 and 0.2
    source_net = torch.nn.Sequential(torch.nn.PReLU(), torch.nn.PReLU())
    source_state = source_net.state_dict()
    source_state["0.weight"] = torch.tensor([0.1])
    source_state["1.weight"] = torch.tensor([0.2])
    source_net.load_state_dict(source_state)

    # target network: a single PReLU layer with a different weight 0.3
    target_net = torch.nn.Sequential(torch.nn.PReLU())
    target_state = target_net.state_dict()
    target_state["0.weight"] = torch.tensor([0.3])
    target_net.load_state_dict(target_state)

    with tempfile.TemporaryDirectory() as tempdir:
        # save the larger network's checkpoint
        save_engine = Engine(lambda e, b: None)
        saver = CheckpointSaver(save_dir=tempdir, save_dict={"net": source_net}, save_final=True)
        saver.attach(save_engine)
        save_engine.run([0] * 8, max_epochs=5)

        # load it into the smaller network, tolerating the missing key
        ckpt_path = f"{tempdir}/net_final_iteration=40.pt"
        load_engine = Engine(lambda e, b: None)
        loader = CheckpointLoader(load_path=ckpt_path, load_dict={"net": target_net}, strict=False)
        loader.attach(load_engine)
        load_engine.run([0] * 8, max_epochs=1)

        # the overlapping key must now carry the checkpointed value
        torch.testing.assert_allclose(target_net.state_dict()["0.weight"].cpu(), torch.tensor([0.1]))

def test_partial_over_load(self):
    """Load a 1-layer checkpoint into a 2-layer network with ``strict=False``.

    With non-strict loading, the overlapping parameter ("0.weight") is
    restored from the checkpoint while the target's extra layer
    ("1.weight") keeps its current value instead of raising.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # source network: a single PReLU layer with known weight 0.1
    source_net = torch.nn.Sequential(torch.nn.PReLU())
    source_state = source_net.state_dict()
    source_state["0.weight"] = torch.tensor([0.1])
    source_net.load_state_dict(source_state)

    # target network: two PReLU layers with different weights 0.2 and 0.3
    target_net = torch.nn.Sequential(torch.nn.PReLU(), torch.nn.PReLU())
    target_state = target_net.state_dict()
    target_state["0.weight"] = torch.tensor([0.2])
    target_state["1.weight"] = torch.tensor([0.3])
    target_net.load_state_dict(target_state)

    with tempfile.TemporaryDirectory() as tempdir:
        # save the smaller network's checkpoint
        save_engine = Engine(lambda e, b: None)
        saver = CheckpointSaver(save_dir=tempdir, save_dict={"net": source_net}, save_final=True)
        saver.attach(save_engine)
        save_engine.run([0] * 8, max_epochs=5)

        # load it into the larger network, tolerating the unexpected key
        ckpt_path = f"{tempdir}/net_final_iteration=40.pt"
        load_engine = Engine(lambda e, b: None)
        loader = CheckpointLoader(load_path=ckpt_path, load_dict={"net": target_net}, strict=False)
        loader.attach(load_engine)
        load_engine.run([0] * 8, max_epochs=1)

        # the overlapping key must now carry the checkpointed value
        torch.testing.assert_allclose(target_net.state_dict()["0.weight"].cpu(), torch.tensor([0.1]))


# allow running this test module directly from the command line
if __name__ == "__main__":
    unittest.main()