diff --git a/deepmd/entrypoints/compress.py b/deepmd/entrypoints/compress.py index c689481b96..1737ae4a75 100644 --- a/deepmd/entrypoints/compress.py +++ b/deepmd/entrypoints/compress.py @@ -101,6 +101,7 @@ def compress( 10 * step, int(frequency), ] + jdata["training"]["save_ckpt"] = "model-compression/model.ckpt" jdata = normalize(jdata) # check the descriptor info of the input file diff --git a/deepmd/entrypoints/main.py b/deepmd/entrypoints/main.py index bef0d2bf58..721eed357c 100644 --- a/deepmd/entrypoints/main.py +++ b/deepmd/entrypoints/main.py @@ -311,7 +311,7 @@ def parse_args(args: Optional[List[str]] = None): "-c", "--checkpoint-folder", type=str, - default=".", + default="model-compression", help="path to checkpoint folder", ) parser_compress.add_argument( diff --git a/source/tests/test_model_compression.py b/source/tests/test_model_compression.py index b183c6bde8..2f42134409 100644 --- a/source/tests/test_model_compression.py +++ b/source/tests/test_model_compression.py @@ -15,7 +15,9 @@ default_places = 10 def _file_delete(file) : - if os.path.exists(file): + if os.path.isdir(file): + os.rmdir(file) + elif os.path.isfile(file): os.remove(file) def _subprocess_run(command): @@ -318,10 +320,17 @@ def tearDownClass(self): _file_delete("out.json") _file_delete("compress.json") _file_delete("checkpoint") - _file_delete("lcurve.out") _file_delete("model.ckpt.meta") _file_delete("model.ckpt.index") _file_delete("model.ckpt.data-00000-of-00001") + _file_delete("model.ckpt-100.meta") + _file_delete("model.ckpt-100.index") + _file_delete("model.ckpt-100.data-00000-of-00001") + _file_delete("model-compression/checkpoint") + _file_delete("model-compression/model.ckpt.meta") + _file_delete("model-compression/model.ckpt.index") + _file_delete("model-compression/model.ckpt.data-00000-of-00001") + _file_delete("model-compression") def test_attrs(self): self.assertEqual(self.dp_original.get_ntypes(), 2)