From 877cdcb38397ac8b14096a5bad8048878a4d86bd Mon Sep 17 00:00:00 2001 From: dingye Date: Sat, 16 Dec 2023 13:55:29 +0800 Subject: [PATCH 1/6] Update the path to training and validation data dir. --- examples/zinc_protein/zinc_se_a_mask.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/zinc_protein/zinc_se_a_mask.json b/examples/zinc_protein/zinc_se_a_mask.json index b23987cf5d..04f63aa4ed 100644 --- a/examples/zinc_protein/zinc_se_a_mask.json +++ b/examples/zinc_protein/zinc_se_a_mask.json @@ -68,14 +68,14 @@ "training": { "training_data": { "systems": [ - "example/zinc_protein/train_data_dp_mask/" + "examples/zinc_protein/train_data_dp_mask/" ], "batch_size": 2, "_comment7": "that's all" }, "validation_data": { "systems": [ - "example/zinc_protein/val_data_dp_mask/" + "examples/zinc_protein/val_data_dp_mask/" ], "batch_size": 2, "_comment8": "that's all" From fe739eb4967d1dd0109b6d0c4a8f9a6de7a05fbe Mon Sep 17 00:00:00 2001 From: dingye Date: Sun, 17 Dec 2023 12:44:46 +0800 Subject: [PATCH 2/6] Skip update_sel automatically when using se_a_mask descriptor. --- deepmd/entrypoints/train.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index 227aa13644..03802370b8 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -505,9 +505,15 @@ def update_one_sel(jdata, descriptor, one_type: bool = False): def update_sel(jdata): - log.info( - "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)" - ) - jdata_cpy = jdata.copy() - jdata_cpy["model"] = Model.update_sel(jdata, jdata["model"]) + if jdata['model']['descriptor']['type'] != "se_a_mask": + log.info( + "Skip neighbor statistics for se_a_mask descriptor." + ) + jdata_cpy = jdata.copy() + else: + log.info( + "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)" + ) + jdata_cpy = jdata.copy() + jdata_cpy["model"] = Model.update_sel(jdata, jdata["model"]) return jdata_cpy From 3545e76a6878252bc60d543501d0bd45b259edab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 17 Dec 2023 04:45:12 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/entrypoints/train.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index 03802370b8..d952f424a5 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -505,15 +505,13 @@ def update_one_sel(jdata, descriptor, one_type: bool = False): def update_sel(jdata): - if jdata['model']['descriptor']['type'] != "se_a_mask": - log.info( - "Skip neighbor statistics for se_a_mask descriptor." - ) + if jdata["model"]["descriptor"]["type"] != "se_a_mask": + log.info("Skip neighbor statistics for se_a_mask descriptor.") jdata_cpy = jdata.copy() else: log.info( "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)" ) jdata_cpy = jdata.copy() - jdata_cpy["model"] = Model.update_sel(jdata, jdata["model"]) + jdata_cpy["model"] = Model.update_sel(jdata, jdata["model"]) return jdata_cpy From 1b75a138198fc8f492feb9b691261a69a94b4912 Mon Sep 17 00:00:00 2001 From: dingye Date: Sun, 17 Dec 2023 20:05:13 +0800 Subject: [PATCH 4/6] Fix bug in neighbor stat calculation. --- deepmd/entrypoints/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index d952f424a5..a2bcb4a958 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -505,7 +505,7 @@ def update_one_sel(jdata, descriptor, one_type: bool = False): def update_sel(jdata): - if jdata["model"]["descriptor"]["type"] != "se_a_mask": + if jdata["model"]["descriptor"]["type"] == "se_a_mask": log.info("Skip neighbor statistics for se_a_mask descriptor.") jdata_cpy = jdata.copy() else: From b1435b038afd95ecf976b6aa833e297722dacd39 Mon Sep 17 00:00:00 2001 From: dingye Date: Mon, 18 Dec 2023 11:07:40 +0800 Subject: [PATCH 5/6] Overload update_sel method in descriptor se_a_mask. --- deepmd/descriptor/se_a_mask.py | 12 ++++++++++++ deepmd/entrypoints/train.py | 14 +++++--------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/deepmd/descriptor/se_a_mask.py b/deepmd/descriptor/se_a_mask.py index 780b34d294..618bbe0d3c 100644 --- a/deepmd/descriptor/se_a_mask.py +++ b/deepmd/descriptor/se_a_mask.py @@ -417,3 +417,15 @@ def prod_force_virial( atom_virial = tf.zeros([1, natoms[1], 9], dtype=force.dtype) return force, virial, atom_virial + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + return local_jdata \ No newline at end of file diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index a2bcb4a958..227aa13644 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -505,13 +505,9 @@ def update_one_sel(jdata, descriptor, one_type: bool = False): def update_sel(jdata): - if jdata["model"]["descriptor"]["type"] == "se_a_mask": - log.info("Skip neighbor statistics for se_a_mask descriptor.") - jdata_cpy = jdata.copy() - else: - log.info( - "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)" - ) - jdata_cpy = jdata.copy() - jdata_cpy["model"] = Model.update_sel(jdata, jdata["model"]) + log.info( + "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)" + ) + jdata_cpy = jdata.copy() + jdata_cpy["model"] = Model.update_sel(jdata, jdata["model"]) return jdata_cpy From 24ab62691df22c20e6c36b073af83ce5be5828e1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Dec 2023 03:08:12 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/descriptor/se_a_mask.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/descriptor/se_a_mask.py b/deepmd/descriptor/se_a_mask.py index 618bbe0d3c..cc2e6b4fc8 100644 --- a/deepmd/descriptor/se_a_mask.py +++ b/deepmd/descriptor/se_a_mask.py @@ -421,6 +421,7 @@ def prod_force_virial( @classmethod def update_sel(cls, global_jdata: dict, local_jdata: dict): """Update the selection and perform neighbor statistics. + Parameters ---------- global_jdata : dict @@ -428,4 +429,4 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): local_jdata : dict The local data refer to the current class """ - return local_jdata \ No newline at end of file + return local_jdata