diff --git a/deepmd/entrypoints/freeze.py b/deepmd/entrypoints/freeze.py
index 9f92c2f682..e2ca820218 100755
--- a/deepmd/entrypoints/freeze.py
+++ b/deepmd/entrypoints/freeze.py
@@ -158,6 +158,11 @@ def _modify_model_suffix(output_graph_def, out_suffix, freeze_type):
         loss_dict = jdata.pop("loss_dict")
         if out_suffix in loss_dict:
             jdata["loss"] = loss_dict[out_suffix]
+    # learning_rate
+    if "learning_rate_dict" in jdata:
+        learning_rate_dict = jdata.pop("learning_rate_dict")
+        if out_suffix in learning_rate_dict:
+            jdata["learning_rate"] = learning_rate_dict[out_suffix]
     # fitting weight
     if "fitting_weight" in jdata["training"]:
         jdata["training"].pop("fitting_weight")
diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py
index 9c2565a722..7b1aae2529 100644
--- a/deepmd/train/trainer.py
+++ b/deepmd/train/trainer.py
@@ -268,22 +268,37 @@ def fitting_net_init(fitting_type_, descrpt_type_, params):
                 model_param.get("sw_rmax"),
             )

+        def get_lr_and_coef(lr_param):
+            scale_by_worker = lr_param.get("scale_by_worker", "linear")
+            if scale_by_worker == "linear":
+                scale_lr_coef = float(self.run_opt.world_size)
+            elif scale_by_worker == "sqrt":
+                scale_lr_coef = np.sqrt(self.run_opt.world_size).real
+            else:
+                scale_lr_coef = 1.0
+            lr_type = lr_param.get("type", "exp")
+            if lr_type == "exp":
+                lr = LearningRateExp(
+                    lr_param["start_lr"], lr_param["stop_lr"], lr_param["decay_steps"]
+                )
+            else:
+                raise RuntimeError("unknown learning_rate type " + lr_type)
+            return lr, scale_lr_coef
+
         # learning rate
-        lr_param = j_must_have(jdata, "learning_rate")
-        scale_by_worker = lr_param.get("scale_by_worker", "linear")
-        if scale_by_worker == "linear":
-            self.scale_lr_coef = float(self.run_opt.world_size)
-        elif scale_by_worker == "sqrt":
-            self.scale_lr_coef = np.sqrt(self.run_opt.world_size).real
-        else:
-            self.scale_lr_coef = 1.0
-        lr_type = lr_param.get("type", "exp")
-        if lr_type == "exp":
-            self.lr = LearningRateExp(
-                lr_param["start_lr"], lr_param["stop_lr"], lr_param["decay_steps"]
-            )
+        if not self.multi_task_mode:
+            lr_param = j_must_have(jdata, "learning_rate")
+            self.lr, self.scale_lr_coef = get_lr_and_coef(lr_param)
         else:
-            raise RuntimeError("unknown learning_rate type " + lr_type)
+            self.lr_dict = {}
+            self.scale_lr_coef_dict = {}
+            lr_param_dict = jdata.get("learning_rate_dict", {})
+            for fitting_key in self.fitting_type_dict:
+                lr_param = lr_param_dict.get(fitting_key, {})
+                (
+                    self.lr_dict[fitting_key],
+                    self.scale_lr_coef_dict[fitting_key],
+                ) = get_lr_and_coef(lr_param)

         # loss
         # infer loss type by fitting_type
@@ -349,7 +364,7 @@ def loss_init(_loss_param, _fitting_type, _fitting, _lr):
                     loss_param,
                     self.fitting_type_dict[fitting_key],
                     self.fitting_dict[fitting_key],
-                    self.lr,
+                    self.lr_dict[fitting_key],
                 )

         # training
@@ -559,9 +574,51 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix=""):
     def _build_lr(self):
         self._extra_train_ops = []
         self.global_step = tf.train.get_or_create_global_step()
-        self.learning_rate = self.lr.build(self.global_step, self.stop_batch)
+        if not self.multi_task_mode:
+            self.learning_rate = self.lr.build(self.global_step, self.stop_batch)
+        else:
+            self.learning_rate_dict = {}
+            for fitting_key in self.fitting_type_dict:
+                self.learning_rate_dict[fitting_key] = self.lr_dict[fitting_key].build(
+                    self.global_step, self.stop_batch
+                )
+
         log.info("built lr")

+    def _build_loss(self):
+        if not self.multi_task_mode:
+            l2_l, l2_more = self.loss.build(
+                self.learning_rate,
+                self.place_holders["natoms_vec"],
+                self.model_pred,
+                self.place_holders,
+                suffix="test",
+            )
+
+            if self.mixed_prec is not None:
+                l2_l = tf.cast(l2_l, get_precision(self.mixed_prec["output_prec"]))
+        else:
+            l2_l, l2_more = {}, {}
+            for fitting_key in self.fitting_type_dict:
+                lr = self.learning_rate_dict[fitting_key]
+                model = self.model_pred[fitting_key]
+                loss_dict = self.loss_dict[fitting_key]
+
+                l2_l[fitting_key], l2_more[fitting_key] = loss_dict.build(
+                    lr,
+                    self.place_holders["natoms_vec"],
+                    model,
+                    self.place_holders,
+                    suffix=fitting_key,
+                )
+
+                if self.mixed_prec is not None:
+                    l2_l[fitting_key] = tf.cast(
+                        l2_l[fitting_key], get_precision(self.mixed_prec["output_prec"])
+                    )
+
+        return l2_l, l2_more
+
     def _build_network(self, data, suffix=""):
         self.place_holders = {}
         if self.is_compress:
@@ -597,55 +654,46 @@ def _build_network(self, data, suffix=""):
             reuse=False,
         )

-        if not self.multi_task_mode:
-            self.l2_l, self.l2_more = self.loss.build(
-                self.learning_rate,
-                self.place_holders["natoms_vec"],
-                self.model_pred,
-                self.place_holders,
-                suffix="test",
-            )
-
-            if self.mixed_prec is not None:
-                self.l2_l = tf.cast(
-                    self.l2_l, get_precision(self.mixed_prec["output_prec"])
-                )
-        else:
-            self.l2_l, self.l2_more = {}, {}
-            for fitting_key in self.fitting_type_dict:
-                self.l2_l[fitting_key], self.l2_more[fitting_key] = self.loss_dict[
-                    fitting_key
-                ].build(
-                    self.learning_rate,
-                    self.place_holders["natoms_vec"],
-                    self.model_pred[fitting_key],
-                    self.place_holders,
-                    suffix=fitting_key,
-                )
-                if self.mixed_prec is not None:
-                    self.l2_l[fitting_key] = tf.cast(
-                        self.l2_l[fitting_key],
-                        get_precision(self.mixed_prec["output_prec"]),
-                    )
+        self.l2_l, self.l2_more = self._build_loss()

         log.info("built network")

-    def _build_training(self):
-        trainable_variables = tf.trainable_variables()
+    def _build_optimizer(self, fitting_key=None):
         if self.run_opt.is_distrib:
-            if self.scale_lr_coef > 1.0:
-                log.info("Scale learning rate by coef: %f", self.scale_lr_coef)
-                optimizer = tf.train.AdamOptimizer(
-                    self.learning_rate * self.scale_lr_coef
-                )
+            if fitting_key is None:
+                if self.scale_lr_coef > 1.0:
+                    log.info("Scale learning rate by coef: %f", self.scale_lr_coef)
+                    optimizer = tf.train.AdamOptimizer(
+                        self.learning_rate * self.scale_lr_coef
+                    )
+                else:
+                    optimizer = tf.train.AdamOptimizer(self.learning_rate)
+                optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer)
             else:
-                optimizer = tf.train.AdamOptimizer(self.learning_rate)
-            optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer)
+                if self.scale_lr_coef_dict[fitting_key] > 1.0:
+                    log.info(
+                        "Scale learning rate by coef: %f",
+                        self.scale_lr_coef_dict[fitting_key],
+                    )
+                    optimizer = tf.train.AdamOptimizer(
+                        self.learning_rate_dict[fitting_key]
+                        * self.scale_lr_coef_dict[fitting_key]
+                    )
+                else:
+                    optimizer = tf.train.AdamOptimizer(
+                        learning_rate=self.learning_rate_dict[fitting_key]
+                    )
+                optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer)
         else:
-            optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
+            if fitting_key is None:
+                optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
+            else:
+                optimizer = tf.train.AdamOptimizer(
+                    learning_rate=self.learning_rate_dict[fitting_key]
+                )
+
         if self.mixed_prec is not None:
             _TF_VERSION = Version(TF_VERSION)
-            # check the TF_VERSION, when TF < 1.12, mixed precision is not allowed
             if _TF_VERSION < Version("1.14.0"):
                 raise RuntimeError(
                     "TensorFlow version %s is not compatible with the mixed precision setting. Please consider upgrading your TF version!"
@@ -659,7 +707,13 @@ def _build_training(self):
             optimizer = tf.mixed_precision.enable_mixed_precision_graph_rewrite(
                 optimizer
             )
+        return optimizer
+
+    def _build_training(self):
+        trainable_variables = tf.trainable_variables()
+        if not self.multi_task_mode:
+            optimizer = self._build_optimizer()
         apply_op = optimizer.minimize(
             loss=self.l2_l,
             global_step=self.global_step,
@@ -671,6 +725,7 @@ def _build_training(self):
         else:
             self.train_op = {}
             for fitting_key in self.fitting_type_dict:
+                optimizer = self._build_optimizer(fitting_key=fitting_key)
                 apply_op = optimizer.minimize(
                     loss=self.l2_l[fitting_key],
                     global_step=self.global_step,
@@ -751,16 +806,30 @@ def train(self, train_data=None, valid_data=None):
         cur_batch = run_sess(self.sess, self.global_step)
         is_first_step = True
         self.cur_batch = cur_batch
-        log.info(
-            "start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
-            % (
-                run_sess(self.sess, self.learning_rate),
-                self.lr.value(cur_batch),
-                self.lr.decay_steps_,
-                self.lr.decay_rate_,
-                self.lr.value(stop_batch),
+        if not self.multi_task_mode:
+            log.info(
+                "start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
+                % (
+                    run_sess(self.sess, self.learning_rate),
+                    self.lr.value(cur_batch),
+                    self.lr.decay_steps_,
+                    self.lr.decay_rate_,
+                    self.lr.value(stop_batch),
+                )
             )
-        )
+        else:
+            for fitting_key in self.fitting_type_dict:
+                log.info(
+                    "%s: start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
+                    % (
+                        fitting_key,
+                        run_sess(self.sess, self.learning_rate_dict[fitting_key]),
+                        self.lr_dict[fitting_key].value(cur_batch),
+                        self.lr_dict[fitting_key].decay_steps_,
+                        self.lr_dict[fitting_key].decay_rate_,
+                        self.lr_dict[fitting_key].value(stop_batch),
+                    )
+                )

         prf_options = None
         prf_run_metadata = None
@@ -829,22 +898,27 @@ def train(self, train_data=None, valid_data=None):
                     train_batches = {}
                     valid_batches = {}
                     # valid_numb_batch_dict
-                    for fitting_key in train_data:
-                        train_batches[fitting_key] = [
-                            train_data[fitting_key].get_batch()
+                    for fitting_key_ii in train_data:
+                        # enumerate fitting key as fitting_key_ii
+                        train_batches[fitting_key_ii] = [
+                            train_data[fitting_key_ii].get_batch()
                         ]
-                        valid_batches[fitting_key] = (
+                        valid_batches[fitting_key_ii] = (
                             [
-                                valid_data[fitting_key].get_batch()
+                                valid_data[fitting_key_ii].get_batch()
                                 for ii in range(
-                                    self.valid_numb_batch_dict[fitting_key]
+                                    self.valid_numb_batch_dict[fitting_key_ii]
                                 )
                             ]
-                            if fitting_key in valid_data
+                            if fitting_key_ii in valid_data
                             else None
                         )
                     self.valid_on_the_fly(
-                        fp, train_batches, valid_batches, print_header=True
+                        fp,
+                        train_batches,
+                        valid_batches,
+                        print_header=True,
+                        fitting_key=fitting_key,
                     )
                 is_first_step = False

@@ -895,21 +969,23 @@ def train(self, train_data=None, valid_data=None):
                        else:
                            train_batches = {}
                            valid_batches = {}
-                            for fitting_key in train_data:
-                                train_batches[fitting_key] = [
-                                    train_data[fitting_key].get_batch()
+                            for fitting_key_ii in train_data:
+                                train_batches[fitting_key_ii] = [
+                                    train_data[fitting_key_ii].get_batch()
                                 ]
-                                valid_batches[fitting_key] = (
+                                valid_batches[fitting_key_ii] = (
                                     [
-                                        valid_data[fitting_key].get_batch()
+                                        valid_data[fitting_key_ii].get_batch()
                                         for ii in range(
-                                            self.valid_numb_batch_dict[fitting_key]
+                                            self.valid_numb_batch_dict[fitting_key_ii]
                                         )
                                     ]
-                                    if fitting_key in valid_data
+                                    if fitting_key_ii in valid_data
                                     else None
                                 )
-                            self.valid_on_the_fly(fp, train_batches, valid_batches)
+                            self.valid_on_the_fly(
+                                fp, train_batches, valid_batches, fitting_key=fitting_key
+                            )
                    if self.timing_in_training:
                        toc = time.time()
                        test_time = toc - tic
@@ -1013,22 +1089,50 @@ def get_global_step(self):
     #     fp.write(print_str)
     #     fp.close ()

-    def valid_on_the_fly(self, fp, train_batches, valid_batches, print_header=False):
+    def valid_on_the_fly(
+        self, fp, train_batches, valid_batches, print_header=False, fitting_key=None
+    ):
         train_results = self.get_evaluation_results(train_batches)
         valid_results = self.get_evaluation_results(valid_batches)
         cur_batch = self.cur_batch
-        current_lr = run_sess(self.sess, self.learning_rate)
+        if not self.multi_task_mode:
+            current_lr = run_sess(self.sess, self.learning_rate)
+        else:
+            assert (
+                fitting_key is not None
+            ), "Fitting key must be assigned in validation!"
+            current_lr = None
+            # current_lr can be used as the learning rate of descriptor in the future
+            current_lr_dict = {}
+            for fitting_key_ii in train_batches:
+                current_lr_dict[fitting_key_ii] = run_sess(
+                    self.sess, self.learning_rate_dict[fitting_key_ii]
+                )

         if print_header:
             self.print_header(fp, train_results, valid_results, self.multi_task_mode)
-        self.print_on_training(
-            fp,
-            train_results,
-            valid_results,
-            cur_batch,
-            current_lr,
-            self.multi_task_mode,
-        )
+        if not self.multi_task_mode:
+            self.print_on_training(
+                fp,
+                train_results,
+                valid_results,
+                cur_batch,
+                current_lr,
+                self.multi_task_mode,
+            )
+        else:
+            assert (
+                fitting_key is not None
+            ), "Fitting key must be assigned when printing learning rate!"
+            self.print_on_training(
+                fp,
+                train_results,
+                valid_results,
+                cur_batch,
+                current_lr,
+                self.multi_task_mode,
+                current_lr_dict,
+            )

     @staticmethod
     def print_header(fp, train_results, valid_results, multi_task_mode=False):
@@ -1043,6 +1147,7 @@ def print_header(fp, train_results, valid_results, multi_task_mode=False):
                prop_fmt = " %11s"
                for k in train_results.keys():
                    print_str += prop_fmt % (k + "_trn")
+            print_str += " %8s\n" % (k + "_lr")
         else:
             for fitting_key in train_results:
                 if valid_results[fitting_key] is not None:
@@ -1053,13 +1158,19 @@ def print_header(fp, train_results, valid_results, multi_task_mode=False):
                    prop_fmt = " %11s"
                    for k in train_results[fitting_key].keys():
                        print_str += prop_fmt % (k + "_trn")
-        print_str += " %8s\n" % "lr"
+                print_str += " %8s\n" % (fitting_key + "_lr")
         fp.write(print_str)
         fp.flush()

     @staticmethod
     def print_on_training(
-        fp, train_results, valid_results, cur_batch, cur_lr, multi_task_mode=False
+        fp,
+        train_results,
+        valid_results,
+        cur_batch,
+        cur_lr,
+        multi_task_mode=False,
+        cur_lr_dict=None,
     ):
         print_str = ""
         print_str += "%7d" % cur_batch
@@ -1073,6 +1184,7 @@ def print_on_training(
                prop_fmt = " %11.2e"
                for k in train_results.keys():
                    print_str += prop_fmt % (train_results[k])
+            print_str += " %8.1e\n" % cur_lr
         else:
             for fitting_key in train_results:
                 if valid_results[fitting_key] is not None:
@@ -1087,7 +1199,7 @@ def print_on_training(
                    prop_fmt = " %11.2e"
                    for k in train_results[fitting_key].keys():
                        print_str += prop_fmt % (train_results[fitting_key][k])
-        print_str += " %8.1e\n" % cur_lr
+                print_str += " %8.1e\n" % cur_lr_dict[fitting_key]

         fp.write(print_str)
         fp.flush()
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index 87ded2727c..76cb903835 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -803,10 +803,22 @@ def learning_rate_args():
            )
        ],
        [learning_rate_variant_type_args()],
+        optional=True,
        doc=doc_lr,
    )


+def learning_rate_dict_args():
+    doc_learning_rate_dict = (
+        "The dictionary of definitions of learning rates in multi-task mode. "
+        "Each learning_rate_dict[fitting_key], with user-defined name `fitting_key` in `model/fitting_net_dict`, is the single definition of learning rate.\n"
+    )
+    ca = Argument(
+        "learning_rate_dict", dict, [], [], optional=True, doc=doc_learning_rate_dict
+    )
+    return ca
+
+
 # --- Loss configurations: --- #
 def start_pref(item):
     return f"The prefactor of {item} loss at the start of the training. Should be larger than or equal to 0. If set to none-zero value, the {item} label should be provided by file {item}.npy in each data system. If both start_pref_{item} and limit_pref_{item} are set to 0, then the {item} will be ignored."
@@ -1239,6 +1251,7 @@ def gen_doc(*, make_anchor=True, make_link=True, **kwargs):
     make_anchor = True
     ma = model_args()
     lra = learning_rate_args()
+    lrda = learning_rate_dict_args()
     la = loss_args()
     lda = loss_dict_args()
     ta = training_args()
@@ -1248,6 +1261,7 @@ def gen_doc(*, make_anchor=True, make_link=True, **kwargs):
     ptr.append(la.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
     ptr.append(lda.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
     ptr.append(lra.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
+    ptr.append(lrda.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
     ptr.append(ta.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
     ptr.append(nvnmda.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))

@@ -1265,6 +1279,7 @@ def gen_json(**kwargs):
         (
             model_args(),
             learning_rate_args(),
+            learning_rate_dict_args(),
             loss_args(),
             loss_dict_args(),
             training_args(),
@@ -1278,6 +1293,7 @@ def gen_args(**kwargs):
     return [
         model_args(),
         learning_rate_args(),
+        learning_rate_dict_args(),
         loss_args(),
         loss_dict_args(),
         training_args(),
@@ -1301,10 +1317,12 @@ def normalize_multi_task(data):
     single_fitting_net = "fitting_net" in data["model"].keys()
     single_training_data = "training_data" in data["training"].keys()
     single_valid_data = "validation_data" in data["training"].keys()
     single_loss = "loss" in data.keys()
+    single_learning_rate = "learning_rate" in data.keys()
     multi_fitting_net = "fitting_net_dict" in data["model"].keys()
     multi_training_data = "data_dict" in data["training"].keys()
     multi_loss = "loss_dict" in data.keys()
     multi_fitting_weight = "fitting_weight" in data["training"].keys()
+    multi_learning_rate = "learning_rate_dict" in data.keys()
     assert (single_fitting_net == single_training_data) and (
         multi_fitting_net == multi_training_data
     ), (
@@ -1343,6 +1361,16 @@ def normalize_multi_task(data):
         if multi_loss
         else {}
     )
+    if multi_learning_rate:
+        data["learning_rate_dict"] = normalize_learning_rate_dict(
+            data["model"]["fitting_net_dict"].keys(), data["learning_rate_dict"]
+        )
+    elif single_learning_rate:
+        data[
+            "learning_rate_dict"
+        ] = normalize_learning_rate_dict_with_single_learning_rate(
+            data["model"]["fitting_net_dict"].keys(), data["learning_rate"]
+        )
     fitting_weight = (
         data["training"]["fitting_weight"] if multi_fitting_weight else None
     )
@@ -1355,6 +1383,9 @@ def normalize_multi_task(data):
         assert (
             not multi_loss
         ), "In single-task mode, please use 'model/loss' in stead of 'model/loss_dict'! "
+        assert (
+            not multi_learning_rate
+        ), "In single-task mode, please use 'model/learning_rate' in stead of 'model/learning_rate_dict'! "
     return data


@@ -1399,6 +1430,35 @@ def normalize_loss_dict(fitting_keys, loss_dict):
     return new_dict


+def normalize_learning_rate_dict(fitting_keys, learning_rate_dict):
+    # check the learning_rate dict
+    failed_learning_rate_keys = [
+        item for item in learning_rate_dict if item not in fitting_keys
+    ]
+    assert (
+        not failed_learning_rate_keys
+    ), "Learning rate dict key(s) {} not have corresponding fitting keys in {}! ".format(
+        str(failed_learning_rate_keys), str(list(fitting_keys))
+    )
+    new_dict = {}
+    base = Argument("base", dict, [], [learning_rate_variant_type_args()], doc="")
+    for item in learning_rate_dict:
+        data = base.normalize_value(learning_rate_dict[item], trim_pattern="_*")
+        base.check_value(data, strict=True)
+        new_dict[item] = data
+    return new_dict
+
+
+def normalize_learning_rate_dict_with_single_learning_rate(fitting_keys, learning_rate):
+    new_dict = {}
+    base = Argument("base", dict, [], [learning_rate_variant_type_args()], doc="")
+    data = base.normalize_value(learning_rate, trim_pattern="_*")
+    base.check_value(data, strict=True)
+    for fitting_key in fitting_keys:
+        new_dict[fitting_key] = data
+    return new_dict
+
+
 def normalize_fitting_weight(fitting_keys, data_keys, fitting_weight=None):
     # check the mapping
     failed_data_keys = [item for item in data_keys if item not in fitting_keys]
@@ -1473,12 +1533,13 @@ def normalize(data):
     data = normalize_multi_task(data)

     ma = model_args()
     lra = learning_rate_args()
+    lrda = learning_rate_dict_args()
     la = loss_args()
     lda = loss_dict_args()
     ta = training_args()
     nvnmda = nvnmd_args()

-    base = Argument("base", dict, [ma, lra, la, lda, ta, nvnmda])
+    base = Argument("base", dict, [ma, lra, lrda, la, lda, ta, nvnmda])
     data = base.normalize_value(data, trim_pattern="_*")
     base.check_value(data, strict=True)
diff --git a/deepmd/utils/multi_init.py b/deepmd/utils/multi_init.py
index f2984ce145..108dab103c 100644
--- a/deepmd/utils/multi_init.py
+++ b/deepmd/utils/multi_init.py
@@ -157,6 +157,30 @@ def replace_model_params_with_frz_multi_model(
                 f"Add '{config_key}/{fitting_key}' configurations from the pretrained frozen model."
             )

+    # learning rate dict keep backward compatibility
+    config_key = "learning_rate_dict"
+    single_config_key = "learning_rate"
+    cur_jdata = jdata
+    target_jdata = pretrained_jdata
+    if (single_config_key not in cur_jdata) and (config_key in cur_jdata):
+        cur_jdata = cur_jdata[config_key]
+        if config_key in target_jdata:
+            target_jdata = target_jdata[config_key]
+            for fitting_key in reused_fittings:
+                if fitting_key not in cur_jdata:
+                    target_para = target_jdata[fitting_key]
+                    cur_jdata[fitting_key] = target_para
+                    log.info(
+                        f"Add '{config_key}/{fitting_key}' configurations from the pretrained frozen model."
+                    )
+        else:
+            for fitting_key in reused_fittings:
+                if fitting_key not in cur_jdata:
+                    cur_jdata[fitting_key] = {}
+                    log.info(
+                        f"Add '{config_key}/{fitting_key}' configurations as default."
+                    )
+
     return jdata

diff --git a/source/tests/test_init_frz_model_multi.py b/source/tests/test_init_frz_model_multi.py
index 3405ab1544..99e2a7b3a6 100644
--- a/source/tests/test_init_frz_model_multi.py
+++ b/source/tests/test_init_frz_model_multi.py
@@ -61,6 +61,7 @@ def _init_models():
     jdata = j_loader(str(tests_path / os.path.join("init_frz_model", "input.json")))
     fitting_config = jdata["model"].pop("fitting_net")
     loss_config = jdata.pop("loss")
+    learning_rate_config = jdata.pop("learning_rate")
     training_data_config = jdata["training"].pop("training_data")
     validation_data_config = jdata["training"].pop("validation_data")
     jdata["training"]["data_dict"] = {}
@@ -78,6 +79,8 @@ def _init_models():
     jdata["model"]["fitting_net_dict"]["water_ener"] = fitting_config
     jdata["loss_dict"] = {}
     jdata["loss_dict"]["water_ener"] = loss_config
+    jdata["learning_rate_dict"] = {}
+    jdata["learning_rate_dict"]["water_ener"] = learning_rate_config
     with open(INPUT, "w") as fp:
         json.dump(jdata, fp, indent=4)
     ret = run_dp("dp train " + INPUT)
@@ -95,6 +98,8 @@ def _init_models():
     jdata["model"]["fitting_net_dict"]["water_ener_new"] = fitting_config
     jdata["loss_dict"] = {}
     jdata["loss_dict"]["water_ener_new"] = loss_config
+    jdata["learning_rate_dict"] = {}
+    jdata["learning_rate_dict"]["water_ener_new"] = learning_rate_config
     jdata["training"]["data_dict"] = {}
     jdata["training"]["data_dict"]["water_ener_new"] = {}
     jdata["training"]["data_dict"]["water_ener_new"][
diff --git a/source/tests/water_multi.json b/source/tests/water_multi.json
index f16e0eadac..e74b724157 100644
--- a/source/tests/water_multi.json
+++ b/source/tests/water_multi.json
@@ -45,13 +45,23 @@
            }
        }
    },
-    "learning_rate": {
-        "type": "exp",
-        "start_lr": 0.001,
-        "decay_steps": 5000,
-        "decay_rate": 0.95,
-        "_comment": "that's all"
-    },
+    "learning_rate_dict":
+    {
+        "water_ener": {
+            "type": "exp",
+            "start_lr": 0.001,
+            "decay_steps": 5000,
+            "decay_rate": 0.95,
+            "_comment": "that's all"
+        },
+        "water_dipole": {
+            "type": "exp",
+            "start_lr": 0.001,
+            "decay_steps": 5000,
+            "decay_rate": 0.95,
+            "_comment": "that's all"
+        }
+    },
     "loss_dict": {
         "water_ener": {