Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions
from deepmd.train.trainer import DPTrainer
from deepmd.utils.argcheck import normalize
from deepmd.utils.compat import convert_input_v0_v1
from deepmd.utils.compat import updata_deepmd_input
from deepmd.utils.data_system import DeepmdDataSystem

if TYPE_CHECKING:
Expand Down Expand Up @@ -168,8 +168,7 @@ def train(
# load json database
jdata = j_loader(INPUT)

if "model" not in jdata.keys():
jdata = convert_input_v0_v1(jdata, warning=True, dump="input_v1_compat.json")
jdata = updata_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")

jdata = normalize(jdata)
with open(output, "w") as fp:
Expand Down
87 changes: 64 additions & 23 deletions deepmd/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def convert_input_v0_v1(
jdata: Dict[str, Any], warning: bool = True, dump: Optional[Union[str, Path]] = None
) -> Dict[str, Any]:
"""Convert input from v0 format to v1.

Parameters
----------
jdata : Dict[str, Any]
Expand All @@ -21,12 +20,12 @@ def convert_input_v0_v1(
whether to show deprecation warning, by default True
dump : Optional[Union[str, Path]], optional
whether to dump converted file, by default None

Returns
-------
Dict[str, Any]
converted output
"""

output = {}
if "with_distrib" in jdata:
output["with_distrib"] = jdata["with_distrib"]
Expand All @@ -35,33 +34,29 @@ def convert_input_v0_v1(
output["loss"] = _loss(jdata)
output["training"] = _training(jdata)
if warning:
_warnning_input_v0_v1(dump)
_warning_input_v0_v1(dump)
if dump is not None:
with open(dump, "w") as fp:
json.dump(output, fp, indent=4)
return output


def _warnning_input_v0_v1(fname: Optional[Union[str, Path]]):
msg = (
"It seems that you are using a deepmd-kit input of version 0.x.x, "
"which is deprecated. we have converted the input to >1.0.0 compatible"
)
def _warning_input_v0_v1(fname: Optional[Union[str, Path]]):
msg = "It seems that you are using a deepmd-kit input of version 0.x.x, " \
"which is deprecated. we have converted the input to >2.0.0 compatible"
if fname is not None:
msg += f", and output it to file {fname}"
warnings.warn(msg)


def _model(jdata: Dict[str, Any], smooth: bool) -> Dict[str, Dict[str, Any]]:
"""Convert data to v1 input for non-smooth model.

Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data
smooth : bool
whether to use smooth or non-smooth descriptor version

Returns
-------
Dict[str, Dict[str, Any]]
Expand All @@ -78,12 +73,10 @@ def _model(jdata: Dict[str, Any], smooth: bool) -> Dict[str, Dict[str, Any]]:

def _nonsmth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for non-smooth descriptor.

Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data

Returns
-------
Dict[str, Any]
Expand All @@ -97,12 +90,10 @@ def _nonsmth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:

def _smth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for smooth descriptor.

Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data

Returns
-------
Dict[str, Any]
Expand All @@ -127,12 +118,10 @@ def _smth_descriptor(jdata: Dict[str, Any]) -> Dict[str, Any]:

def _fitting_net(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for fitting net.

Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data

Returns
-------
Dict[str, Any]
Expand All @@ -154,12 +143,10 @@ def _fitting_net(jdata: Dict[str, Any]) -> Dict[str, Any]:

def _learning_rate(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for learning rate section.

Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data

Returns
-------
Dict[str, Any]
Expand All @@ -173,12 +160,10 @@ def _learning_rate(jdata: Dict[str, Any]) -> Dict[str, Any]:

def _loss(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for loss function.

Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data

Returns
-------
Dict[str, Any]
Expand Down Expand Up @@ -206,12 +191,10 @@ def _loss(jdata: Dict[str, Any]) -> Dict[str, Any]:

def _training(jdata: Dict[str, Any]) -> Dict[str, Any]:
"""Convert data to v1 input for training.

Parameters
----------
jdata : Dict[str, Any]
parsed input json/yaml data

Returns
-------
Dict[str, Any]
Expand Down Expand Up @@ -241,7 +224,6 @@ def _training(jdata: Dict[str, Any]) -> Dict[str, Any]:

def _jcopy(src: Dict[str, Any], dst: Dict[str, Any], keys: Sequence[str]):
"""Copy specified keys from one dict to another.

Parameters
----------
src : Dict[str, Any]
Expand All @@ -255,3 +237,62 @@ def _jcopy(src: Dict[str, Any], dst: Dict[str, Any], keys: Sequence[str]):
"""
for k in keys:
dst[k] = src[k]


def convert_input_v1_v2(jdata: Dict[str, Any],
warning: bool = True,
dump: Optional[Union[str, Path]] = None) -> Dict[str, Any]:

tr_cfg = jdata["training"]
tr_data_keys = {
"systems",
"set_prefix",
"batch_size",
"sys_prob",
"auto_prob",
# alias included
"sys_weights",
"auto_prob_style"
}

tr_data_cfg = {k: v for k, v in tr_cfg.items() if k in tr_data_keys}
new_tr_cfg = {k: v for k, v in tr_cfg.items() if k not in tr_data_keys}
new_tr_cfg["training_data"] = tr_data_cfg

jdata["training"] = new_tr_cfg

if warning:
_warning_input_v1_v2(dump)
if dump is not None:
with open(dump, "w") as fp:
json.dump(jdata, fp, indent=4)

return jdata


def _warning_input_v1_v2(fname: Optional[Union[str, Path]]):
msg = "It seems that you are using a deepmd-kit input of version 1.x.x, " \
"which is deprecated. we have converted the input to >2.0.0 compatible"
if fname is not None:
msg += f", and output it to file {fname}"
warnings.warn(msg)


def updata_deepmd_input(jdata: Dict[str, Any],
warning: bool = True,
dump: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
def is_deepmd_v0_input(jdata):
return "model" not in jdata.keys()

def is_deepmd_v1_input(jdata):
return "systems" in j_must_have(jdata, "training").keys()

if is_deepmd_v0_input(jdata):
jdata = convert_input_v0_v1(jdata, warning, None)
jdata = convert_input_v1_v2(jdata, False, dump)
elif is_deepmd_v1_input(jdata):
jdata = convert_input_v1_v2(jdata, warning, dump)
else:
pass

return jdata
51 changes: 51 additions & 0 deletions source/tests/compat_inputs/water_v2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
{
"with_distrib": false,
"model":{
"descriptor": {
"type": "loc_frame",
"sel_a": [16, 32],
"sel_r": [30, 60],
"rcut": 6.00,
"axis_rule": [0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0]
},
"fitting_net": {
"neuron": [240, 120, 60, 30, 10],
"resnet_dt": true,
"seed": 1
}
},

"learning_rate" :{
"type": "exp",
"decay_steps": 5000,
"decay_rate": 0.95,
"start_lr": 0.001
},

"loss" : {
"start_pref_e": 0.02,
"limit_pref_e": 8,
"start_pref_f": 1000,
"limit_pref_f": 1,
"start_pref_v": 0,
"limit_pref_v": 0
},

"training": {
"training_data": {
"systems": ["../data/"],
"set_prefix": "set",
"batch_size": [4]
},
"stop_batch": 1000000,
"seed": 1,
"disp_file": "lcurve.out",
"disp_freq": 100,
"numb_test": 10,
"save_freq": 1000,
"save_ckpt": "model.ckpt",
"disp_training":true,
"time_training":true
}
}

8 changes: 7 additions & 1 deletion source/tests/test_compat_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import unittest

from deepmd.utils.compat import convert_input_v0_v1
from deepmd.utils.compat import convert_input_v0_v1, convert_input_v1_v2
from common import j_loader

class TestConvertInput (unittest.TestCase) :
Expand All @@ -18,6 +18,12 @@ def test_convert_nonsmth(self):
jdata = convert_input_v0_v1(jdata0, warning = False, dump = None)
self.assertEqual(jdata, jdata1)

def test_convert_v1_v2(self):
jdata0 = j_loader(os.path.join('compat_inputs', 'water_v1.json'))
jdata1 = j_loader(os.path.join('compat_inputs', 'water_v2.json'))
jdata = convert_input_v1_v2(jdata0, warning = False, dump = None)
self.assertEqual(jdata, jdata1)

def test_json_yaml_equal(self):

inputs = ("water_v1", "water_se_a_v1")
Expand Down