From e5b2094db57039a8d8d24639f83fca1a7885700a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 5 Apr 2021 22:15:33 +0800 Subject: [PATCH 1/4] [DLMED] add strict_shape option Signed-off-by: Nic Ma --- monai/handlers/checkpoint_loader.py | 30 +++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 40483e8c85..d75c3cbac4 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Dict, Optional import torch +import torch.nn as nn from monai.utils import exact_version, optional_import @@ -44,8 +45,12 @@ class CheckpointLoader: first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. - strict: whether to strictly enforce that the keys in :attr:`state_dict` match the keys - returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + strict: whether to strictly enforce that the keys in `state_dict` match the keys + returned by `torch.nn.Module.state_dict` function. default to `True`. + strict_shape: whether to enforce the data shape of the matched layers in the checkpoint, + `if `False`, it will skip the layers that have different data shape with checkpoint content. + This can be useful advance feature for transfer learning. users should totally + understand which layers will have different shape. default to `True`. """ @@ -56,6 +61,7 @@ def __init__( name: Optional[str] = None, map_location: Optional[Dict] = None, strict: bool = True, + strict_shape: bool = True, ) -> None: if load_path is None: raise AssertionError("must provide clear path to load checkpoint.") @@ -67,6 +73,7 @@ def __init__( self._name = name self.map_location = map_location self.strict = strict + self.strict_shape = strict_shape def attach(self, engine: Engine) -> None: """ @@ -84,6 +91,25 @@ def __call__(self, engine: Engine) -> None: """ checkpoint = torch.load(self.load_path, map_location=self.map_location) + if not self.strict_shape: + def _skip_mismatch_shape_keys(obj_state_dict, ckpt_state_dict): + return { + k: v for k, v in ckpt_state_dict.items() + if k in obj_state_dict and v.shape == obj_state_dict[k].shape + } + if len(self.load_dict) == 1: + key, obj = list(self.load_dict.items())[0] + # single object and checkpoint is directly a state_dict + if key not in checkpoint: + checkpoint = {key: checkpoint} + + # multiple objects to load + for k, obj in self.load_dict.items(): + if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): + obj = obj.module + if isinstance(obj, torch.nn.Module): + checkpoint[k] = _skip_mismatch_shape_keys(obj.state_dict(), checkpoint[k]) + # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict) From dfc76d9fe03ffd07bfb49bf46c26ca18b3dd9df7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 6 Apr 2021 00:04:04 +0800 Subject: [PATCH 2/4] [DLMED] add unit tests Signed-off-by: Nic Ma --- monai/handlers/checkpoint_loader.py | 23 +++++++++-------------- tests/test_handler_checkpoint_loader.py | 23 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index d75c3cbac4..6d8f065f1e 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -49,7 +49,7 @@ class CheckpointLoader: returned by `torch.nn.Module.state_dict` function. default to `True`. strict_shape: whether to enforce the data shape of the matched layers in the checkpoint, `if `False`, it will skip the layers that have different data shape with checkpoint content. - This can be useful advance feature for transfer learning. users should totally + This can be useful advanced feature for transfer learning. users should totally understand which layers will have different shape. default to `True`. """ @@ -92,23 +92,18 @@ def __call__(self, engine: Engine) -> None: checkpoint = torch.load(self.load_path, map_location=self.map_location) if not self.strict_shape: - def _skip_mismatch_shape_keys(obj_state_dict, ckpt_state_dict): - return { - k: v for k, v in ckpt_state_dict.items() - if k in obj_state_dict and v.shape == obj_state_dict[k].shape - } - if len(self.load_dict) == 1: - key, obj = list(self.load_dict.items())[0] - # single object and checkpoint is directly a state_dict - if key not in checkpoint: - checkpoint = {key: checkpoint} - - # multiple objects to load + k, _ = list(self.load_dict.items())[0] + # single object and checkpoint is directly a state_dict + if len(self.load_dict) == 1 and k not in checkpoint: + checkpoint = {k: checkpoint} + + # skip items that don't match data shape for k, obj in self.load_dict.items(): if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): obj = obj.module if isinstance(obj, torch.nn.Module): - checkpoint[k] = _skip_mismatch_shape_keys(obj.state_dict(), checkpoint[k]) + d = obj.state_dict() + checkpoint[k] = {k: v for k, v in checkpoint[k].items() if k in d and v.shape == d[k].shape} # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index d58260ac8c..1f2414d978 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -146,6 +146,29 @@ def test_partial_over_load(self): engine.run([0] * 8, max_epochs=1) torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1])) + def test_strict_shape(self): + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + net1 = torch.nn.Sequential(*[torch.nn.PReLU(num_parameters=5)]) + data1 = net1.state_dict() + data1["0.weight"] = torch.tensor([1, 2, 3, 4, 5]) + net1.load_state_dict(data1) + + net2 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data2 = net2.state_dict() + data2["0.weight"] = torch.tensor([0.2]) + data2["1.weight"] = torch.tensor([0.3]) + net2.load_state_dict(data2) + + with tempfile.TemporaryDirectory() as tempdir: + engine = Engine(lambda e, b: None) + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/net_final_iteration=40.pt" + engine = Engine(lambda e, b: None) + CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False, strict_shape=False).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.2])) + if __name__ == "__main__": unittest.main() From ac086b771fad08ccd0eb5b1a4ec08d65826ad803 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 5 Apr 2021 18:54:40 +0100 Subject: [PATCH 3/4] update test case Signed-off-by: Wenqi Li --- tests/test_handler_checkpoint_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index 1f2414d978..a69193c98c 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -151,7 +151,8 @@ def test_strict_shape(self): net1 = torch.nn.Sequential(*[torch.nn.PReLU(num_parameters=5)]) data1 = net1.state_dict() data1["0.weight"] = torch.tensor([1, 2, 3, 4, 5]) - net1.load_state_dict(data1) + data1["new"] = torch.tensor(0.1) + net1.load_state_dict(data1, strict=False) net2 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) data2 = net2.state_dict() From b406974e1f309e2062380d8293664d821fbb5f24 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 5 Apr 2021 18:55:05 +0100 Subject: [PATCH 4/4] fixes test config Signed-off-by: Wenqi Li --- .github/workflows/pythonapp.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 514301ad5b..30e6102965 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -285,7 +285,7 @@ jobs: - name: Run quick tests (GPU) run: | nvidia-smi - export LAUNCH_DELAY=$(( RANDOM % 30 * 5 )) + export LAUNCH_DELAY=$(python -c "import numpy; print(numpy.random.randint(30) * 5)") echo "Sleep $LAUNCH_DELAY" sleep $LAUNCH_DELAY export CUDA_VISIBLE_DEVICES=$(coverage run -m tests.utils) @@ -298,7 +298,7 @@ jobs: python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))' python -c "import monai; monai.config.print_config()" BUILD_MONAI=1 ./runtests.sh --quick --unittests - if [ ${{ matrix.environment }} == "PT18+CUDA112" ]; then + if [ ${{ matrix.environment }} = "PT18+CUDA112" ]; then # test the clang-format tool downloading once coverage run -m tests.clang_format_utils fi