diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 514301ad5b..30e6102965 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -285,7 +285,7 @@ jobs: - name: Run quick tests (GPU) run: | nvidia-smi - export LAUNCH_DELAY=$(( RANDOM % 30 * 5 )) + export LAUNCH_DELAY=$(python -c "import numpy; print(numpy.random.randint(30) * 5)") echo "Sleep $LAUNCH_DELAY" sleep $LAUNCH_DELAY export CUDA_VISIBLE_DEVICES=$(coverage run -m tests.utils) @@ -298,7 +298,7 @@ jobs: python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))' python -c "import monai; monai.config.print_config()" BUILD_MONAI=1 ./runtests.sh --quick --unittests - if [ ${{ matrix.environment }} == "PT18+CUDA112" ]; then + if [ ${{ matrix.environment }} = "PT18+CUDA112" ]; then # test the clang-format tool downloading once coverage run -m tests.clang_format_utils fi diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 40483e8c85..6d8f065f1e 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Dict, Optional import torch +import torch.nn as nn from monai.utils import exact_version, optional_import @@ -44,8 +45,12 @@ class CheckpointLoader: first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. - strict: whether to strictly enforce that the keys in :attr:`state_dict` match the keys - returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + strict: whether to strictly enforce that the keys in `state_dict` match the keys + returned by `torch.nn.Module.state_dict` function. default to `True`. + strict_shape: whether to enforce the data shape of the matched layers in the checkpoint, + `if `False`, it will skip the layers that have different data shape with checkpoint content. + This can be useful advanced feature for transfer learning. users should totally + understand which layers will have different shape. default to `True`. """ @@ -56,6 +61,7 @@ def __init__( name: Optional[str] = None, map_location: Optional[Dict] = None, strict: bool = True, + strict_shape: bool = True, ) -> None: if load_path is None: raise AssertionError("must provide clear path to load checkpoint.") @@ -67,6 +73,7 @@ def __init__( self._name = name self.map_location = map_location self.strict = strict + self.strict_shape = strict_shape def attach(self, engine: Engine) -> None: """ @@ -84,6 +91,20 @@ def __call__(self, engine: Engine) -> None: """ checkpoint = torch.load(self.load_path, map_location=self.map_location) + if not self.strict_shape: + k, _ = list(self.load_dict.items())[0] + # single object and checkpoint is directly a state_dict + if len(self.load_dict) == 1 and k not in checkpoint: + checkpoint = {k: checkpoint} + + # skip items that don't match data shape + for k, obj in self.load_dict.items(): + if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): + obj = obj.module + if isinstance(obj, torch.nn.Module): + d = obj.state_dict() + checkpoint[k] = {k: v for k, v in checkpoint[k].items() if k in d and v.shape == d[k].shape} + # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict) diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index d58260ac8c..a69193c98c 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -146,6 +146,30 @@ def test_partial_over_load(self): engine.run([0] * 8, max_epochs=1) torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1])) + def test_strict_shape(self): + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + net1 = torch.nn.Sequential(*[torch.nn.PReLU(num_parameters=5)]) + data1 = net1.state_dict() + data1["0.weight"] = torch.tensor([1, 2, 3, 4, 5]) + data1["new"] = torch.tensor(0.1) + net1.load_state_dict(data1, strict=False) + + net2 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data2 = net2.state_dict() + data2["0.weight"] = torch.tensor([0.2]) + data2["1.weight"] = torch.tensor([0.3]) + net2.load_state_dict(data2) + + with tempfile.TemporaryDirectory() as tempdir: + engine = Engine(lambda e, b: None) + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/net_final_iteration=40.pt" + engine = Engine(lambda e, b: None) + CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False, strict_shape=False).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.2])) + if __name__ == "__main__": unittest.main()