From ebcf1eb5aaf09e3c5298a629607d70d511f32698 Mon Sep 17 00:00:00 2001 From: denghuilu Date: Tue, 31 Aug 2021 14:38:15 +0800 Subject: [PATCH 1/3] change the default checkpoint path for model compression --- deepmd/entrypoints/compress.py | 1 + deepmd/entrypoints/main.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/entrypoints/compress.py b/deepmd/entrypoints/compress.py index c689481b96..1737ae4a75 100644 --- a/deepmd/entrypoints/compress.py +++ b/deepmd/entrypoints/compress.py @@ -101,6 +101,7 @@ def compress( 10 * step, int(frequency), ] + jdata["training"]["save_ckpt"] = "model-compression/model.ckpt" jdata = normalize(jdata) # check the descriptor info of the input file diff --git a/deepmd/entrypoints/main.py b/deepmd/entrypoints/main.py index bef0d2bf58..721eed357c 100644 --- a/deepmd/entrypoints/main.py +++ b/deepmd/entrypoints/main.py @@ -311,7 +311,7 @@ def parse_args(args: Optional[List[str]] = None): "-c", "--checkpoint-folder", type=str, - default=".", + default="model-compression", help="path to checkpoint folder", ) parser_compress.add_argument( From 7348a5f8990e557a1a530fd17c495d2bb8d7a601 Mon Sep 17 00:00:00 2001 From: denghuilu Date: Tue, 31 Aug 2021 14:45:13 +0800 Subject: [PATCH 2/3] update UT for model compression --- source/tests/test_model_compression.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/source/tests/test_model_compression.py b/source/tests/test_model_compression.py index b183c6bde8..e1f5ed3b05 100644 --- a/source/tests/test_model_compression.py +++ b/source/tests/test_model_compression.py @@ -317,11 +317,10 @@ def tearDownClass(self): _file_delete(COMPRESSED_MODEL) _file_delete("out.json") _file_delete("compress.json") - _file_delete("checkpoint") - _file_delete("lcurve.out") - _file_delete("model.ckpt.meta") - _file_delete("model.ckpt.index") - _file_delete("model.ckpt.data-00000-of-00001") + _file_delete("model-compression/checkpoint") + _file_delete("model-compression/model.ckpt.meta") + _file_delete("model-compression/model.ckpt.index") + _file_delete("model-compression/model.ckpt.data-00000-of-00001") def test_attrs(self): self.assertEqual(self.dp_original.get_ntypes(), 2) From 701fbfffff5a91dd0a9f7554a0b79a7b88a3a91c Mon Sep 17 00:00:00 2001 From: denghuilu Date: Tue, 31 Aug 2021 15:35:30 +0800 Subject: [PATCH 3/3] clean up files --- source/tests/test_model_compression.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/source/tests/test_model_compression.py b/source/tests/test_model_compression.py index e1f5ed3b05..2f42134409 100644 --- a/source/tests/test_model_compression.py +++ b/source/tests/test_model_compression.py @@ -15,7 +15,9 @@ default_places = 10 def _file_delete(file) : - if os.path.exists(file): + if os.path.isdir(file): + os.rmdir(file) + elif os.path.isfile(file): os.remove(file) def _subprocess_run(command): @@ -317,10 +319,18 @@ def tearDownClass(self): _file_delete(COMPRESSED_MODEL) _file_delete("out.json") _file_delete("compress.json") + _file_delete("checkpoint") + _file_delete("model.ckpt.meta") + _file_delete("model.ckpt.index") + _file_delete("model.ckpt.data-00000-of-00001") + _file_delete("model.ckpt-100.meta") + _file_delete("model.ckpt-100.index") + _file_delete("model.ckpt-100.data-00000-of-00001") _file_delete("model-compression/checkpoint") _file_delete("model-compression/model.ckpt.meta") _file_delete("model-compression/model.ckpt.index") _file_delete("model-compression/model.ckpt.data-00000-of-00001") + _file_delete("model-compression") def test_attrs(self): self.assertEqual(self.dp_original.get_ntypes(), 2)