From 46de84c17264ef7500f1ed313ee8a758f16166c9 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Sun, 16 Feb 2025 14:28:55 +0800 Subject: [PATCH] add `trainable` to property fitting --- deepmd/pt/model/task/property.py | 3 +++ deepmd/utils/argcheck.py | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/deepmd/pt/model/task/property.py b/deepmd/pt/model/task/property.py index c15e60fe04..5ef0cd0233 100644 --- a/deepmd/pt/model/task/property.py +++ b/deepmd/pt/model/task/property.py @@ -2,6 +2,7 @@ import logging from typing import ( Optional, + Union, ) import torch @@ -88,6 +89,7 @@ def __init__( activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, + trainable: Union[bool, list[bool]] = True, seed: Optional[int] = None, **kwargs, ) -> None: @@ -107,6 +109,7 @@ def __init__( activation_function=activation_function, precision=precision, mixed_types=mixed_types, + trainable=trainable, seed=seed, **kwargs, ) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 1d897ceb57..a00cfb047a 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1580,6 +1580,9 @@ def fitting_property(): doc_task_dim = "The dimension of outputs of fitting net" doc_intensive = "Whether the fitting property is intensive" doc_property_name = "The names of fitting property, which should be consistent with the property name in the dataset." + doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n\n\ +- bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\ +- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of this list should be equal to len(`neuron`)+1." return [ Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam), Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), @@ -1616,6 +1619,13 @@ def fitting_property(): optional=False, doc=doc_property_name, ), + Argument( + "trainable", + [list[bool], bool], + optional=True, + default=True, + doc=doc_trainable, + ), ]