From d46d5f0b76cfd3704f7bc1ebd2fd514d002b161c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 25 Oct 2024 04:46:13 -0400 Subject: [PATCH 01/16] export savedmodel Signed-off-by: Jinzhe Zeng --- deepmd/backend/jax.py | 2 +- deepmd/dpmodel/descriptor/se_e2_a.py | 2 +- deepmd/jax/utils/serialization.py | 39 ++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index bb2fba5a7c..3b2b0fcc56 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -38,7 +38,7 @@ class JAXBackend(Backend): # | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".jax"] + suffixes: ClassVar[list[str]] = [".jax", ".savedmodel"] """The suffixes of the backend.""" def is_available(self) -> bool: diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index d29ce8862e..ff8e401063 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -555,7 +555,7 @@ def call( coord_ext, atype_ext, nlist, self.davg, self.dstd ) nf, nloc, nnei, _ = rr.shape - sec = xp.asarray(self.sel_cumsum) + sec = self.sel_cumsum ng = self.neuron[-1] gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 43070f8a07..1a086326ce 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -39,6 +39,45 @@ def deserialize_to_file(model_file: str, data: dict) -> None: model_def_script=ocp.args.JsonSave(model_def_script), ), ) + elif model_file.endswith(".savedmodel"): + import tensorflow as tf + from jax.experimental import ( + jax2tf, + ) + + model = BaseModel.deserialize(data["model"]) + model_def_script = data["model_def_script"] + call_lower = model.call_lower + + my_model = tf.Module() + + # Save a function that can take scalar inputs. 
+ my_model.call_lower = tf.function( + jax2tf.convert( + call_lower, + polymorphic_shapes=[ + "(nf, nloc + nghost, 3)", + "(nf, nloc + nghost)", + f"(nf, nloc, {model.get_nnei()})", + "(nf, np)", + "(nf, na)", + ], + ), + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None], tf.float64), + tf.TensorSpec([None, None], tf.float64), + ], + ) + my_model.model_def_script = model_def_script + tf.saved_model.save( + my_model, + model_file, + options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), + ) else: raise ValueError("JAX backend only supports converting .jax directory") From ed1288e5c4b835de2cd7300e1ea42f913076f0cb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 02:44:17 -0500 Subject: [PATCH 02/16] move to a separated module Signed-off-by: Jinzhe Zeng --- deepmd/backend/jax.py | 2 +- deepmd/backend/jax2tf.py | 115 +++++++++++++++++++++++++++++ deepmd/jax/jax2tf/__init__.py | 1 + deepmd/jax/jax2tf/serialization.py | 55 ++++++++++++++ deepmd/jax/utils/serialization.py | 40 ---------- 5 files changed, 172 insertions(+), 41 deletions(-) create mode 100644 deepmd/backend/jax2tf.py create mode 100644 deepmd/jax/jax2tf/__init__.py create mode 100644 deepmd/jax/jax2tf/serialization.py diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index 7a714c2090..cfb0936bda 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -38,7 +38,7 @@ class JAXBackend(Backend): | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".hlo", ".jax", ".savedmodel"] + suffixes: ClassVar[list[str]] = [".hlo", ".jax"] """The suffixes of the backend.""" def is_available(self) -> bool: diff --git a/deepmd/backend/jax2tf.py b/deepmd/backend/jax2tf.py new file mode 100644 index 0000000000..cde65bebed --- /dev/null +++ b/deepmd/backend/jax2tf.py
@@ -0,0 +1,115 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from importlib.util import ( + find_spec, +) +from typing import ( + TYPE_CHECKING, + Callable, + ClassVar, +) + +from deepmd.backend.backend import ( + Backend, +) + +if TYPE_CHECKING: + from argparse import ( + Namespace, + ) + + from deepmd.infer.deep_eval import ( + DeepEvalBackend, + ) + from deepmd.utils.neighbor_stat import ( + NeighborStat, + ) + + +@Backend.register("jax2tf") +class JAXBackend(Backend): + """JAX to TensorFlow backend.""" + + name = "JAX2TF" + """The formal name of the backend.""" + features: ClassVar[Backend.Feature] = ( + Backend.Feature.IO + # | Backend.Feature.ENTRY_POINT + # | Backend.Feature.DEEP_EVAL + ) + """The features of the backend.""" + suffixes: ClassVar[list[str]] = [".savedmodel"] + """The suffixes of the backend.""" + + def is_available(self) -> bool: + """Check if the backend is available. + + Returns + ------- + bool + Whether the backend is available. + """ + return find_spec("jax") is not None and find_spec("tensorflow") is not None + + @property + def entry_point_hook(self) -> Callable[["Namespace"], None]: + """The entry point hook of the backend. + + Returns + ------- + Callable[[Namespace], None] + The entry point hook of the backend. + """ + raise NotImplementedError + + @property + def deep_eval(self) -> type["DeepEvalBackend"]: + """The Deep Eval backend of the backend. + + Returns + ------- + type[DeepEvalBackend] + The Deep Eval backend of the backend. + """ + raise NotImplementedError + # from deepmd.jax.infer.deep_eval import ( + # DeepEval, + # ) + + # return DeepEval + + @property + def neighbor_stat(self) -> type["NeighborStat"]: + """The neighbor statistics of the backend. + + Returns + ------- + type[NeighborStat] + The neighbor statistics of the backend. + """ + raise NotImplementedError + + @property + def serialize_hook(self) -> Callable[[str], dict]: + """The serialize hook to convert the model file to a dictionary. 
+ + Returns + ------- + Callable[[str], dict] + The serialize hook of the backend. + """ + raise NotImplementedError + + @property + def deserialize_hook(self) -> Callable[[str, dict], None]: + """The deserialize hook to convert the dictionary to a model file. + + Returns + ------- + Callable[[str, dict], None] + The deserialize hook of the backend. + """ + from deepmd.jax.jax2tf.serialization import ( + deserialize_to_file, + ) + + return deserialize_to_file diff --git a/deepmd/jax/jax2tf/__init__.py b/deepmd/jax/jax2tf/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/jax2tf/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py new file mode 100644 index 0000000000..4c248dd139 --- /dev/null +++ b/deepmd/jax/jax2tf/serialization.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf +from jax.experimental import ( + jax2tf, +) + +from deepmd.jax.model.base_model import ( + BaseModel, +) + + +def deserialize_to_file(model_file: str, data: dict) -> None: + """Deserialize the dictionary to a model file. + + Parameters + ---------- + model_file : str + The model file to be saved. + data : dict + The dictionary to be deserialized. + """ + if model_file.endswith(".savedmodel"): + model = BaseModel.deserialize(data["model"]) + model_def_script = data["model_def_script"] + call_lower = model.call_lower + + my_model = tf.Module() + + # Save a function that can take scalar inputs. 
+ my_model.call_lower = tf.function( + jax2tf.convert( + call_lower, + polymorphic_shapes=[ + "(nf, nloc + nghost, 3)", + "(nf, nloc + nghost)", + f"(nf, nloc, {model.get_nnei()})", + f"(nf, {model.get_dim_fparam()})", + f"(nf, nloc, {model.get_dim_aparam()})", + ], + ), + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + my_model.model_def_script = model_def_script + tf.saved_model.save( + my_model, + model_file, + options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), + ) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 6022c3bc64..ec2de3060e 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -46,46 +46,6 @@ def deserialize_to_file(model_file: str, data: dict) -> None: model_def_script=ocp.args.JsonSave(model_def_script), ), ) - elif model_file.endswith(".savedmodel"): - import tensorflow as tf - from jax.experimental import ( - jax2tf, - ) - - model = BaseModel.deserialize(data["model"]) - model_def_script = data["model_def_script"] - call_lower = model.call_lower - - my_model = tf.Module() - - # Save a function that can take scalar inputs. 
- my_model.call_lower = tf.function( - jax2tf.convert( - call_lower, - polymorphic_shapes=[ - "(nf, nloc + nghost, 3)", - "(nf, nloc + nghost)", - f"(nf, nloc, {model.get_nnei()})", - "(nf, np)", - "(nf, na)", - ], - ), - autograph=False, - input_signature=[ - tf.TensorSpec([None, None, 3], tf.float64), - tf.TensorSpec([None, None], tf.int64), - tf.TensorSpec([None, None, model.get_nnei()], tf.int64), - tf.TensorSpec([None, None], tf.float64), - tf.TensorSpec([None, None], tf.float64), - ], - ) - my_model.model_def_script = model_def_script - tf.saved_model.save( - my_model, - model_file, - options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), - ) - elif model_file.endswith(".hlo"): model = BaseModel.deserialize(data["model"]) model_def_script = data["model_def_script"] From ecba7096e5da394d5a4f6befcc3fde8396cf871a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 03:48:07 -0500 Subject: [PATCH 03/16] refactor Signed-off-by: Jinzhe Zeng --- deepmd/backend/jax.py | 2 +- deepmd/backend/jax2tf.py | 115 ----------- deepmd/jax/infer/deep_eval.py | 27 ++- deepmd/jax/jax2tf/serialization.py | 83 +++++--- deepmd/jax/jax2tf/tfmodel.py | 308 +++++++++++++++++++++++++++++ deepmd/jax/utils/serialization.py | 8 +- 6 files changed, 395 insertions(+), 148 deletions(-) delete mode 100644 deepmd/backend/jax2tf.py create mode 100644 deepmd/jax/jax2tf/tfmodel.py diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index cfb0936bda..7a714c2090 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -38,7 +38,7 @@ class JAXBackend(Backend): | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".hlo", ".jax"] + suffixes: ClassVar[list[str]] = [".hlo", ".jax", ".savedmodel"] """The suffixes of the backend.""" def is_available(self) -> bool: diff --git a/deepmd/backend/jax2tf.py b/deepmd/backend/jax2tf.py deleted file mode 100644 index cde65bebed..0000000000 --- a/deepmd/backend/jax2tf.py 
+++ /dev/null @@ -1,115 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -from importlib.util import ( - find_spec, -) -from typing import ( - TYPE_CHECKING, - Callable, - ClassVar, -) - -from deepmd.backend.backend import ( - Backend, -) - -if TYPE_CHECKING: - from argparse import ( - Namespace, - ) - - from deepmd.infer.deep_eval import ( - DeepEvalBackend, - ) - from deepmd.utils.neighbor_stat import ( - NeighborStat, - ) - - -@Backend.register("jax2tf") -class JAXBackend(Backend): - """JAX to TensorFlow backend.""" - - name = "JAX2TF" - """The formal name of the backend.""" - features: ClassVar[Backend.Feature] = ( - Backend.Feature.IO - # | Backend.Feature.ENTRY_POINT - # | Backend.Feature.DEEP_EVAL - ) - """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".savedmodel"] - """The suffixes of the backend.""" - - def is_available(self) -> bool: - """Check if the backend is available. - - Returns - ------- - bool - Whether the backend is available. - """ - return find_spec("jax") is not None and find_spec("tensorflow") is not None - - @property - def entry_point_hook(self) -> Callable[["Namespace"], None]: - """The entry point hook of the backend. - - Returns - ------- - Callable[[Namespace], None] - The entry point hook of the backend. - """ - raise NotImplementedError - - @property - def deep_eval(self) -> type["DeepEvalBackend"]: - """The Deep Eval backend of the backend. - - Returns - ------- - type[DeepEvalBackend] - The Deep Eval backend of the backend. - """ - raise NotImplementedError - # from deepmd.jax.infer.deep_eval import ( - # DeepEval, - # ) - - # return DeepEval - - @property - def neighbor_stat(self) -> type["NeighborStat"]: - """The neighbor statistics of the backend. - - Returns - ------- - type[NeighborStat] - The neighbor statistics of the backend. - """ - raise NotImplementedError - - @property - def serialize_hook(self) -> Callable[[str], dict]: - """The serialize hook to convert the model file to a dictionary. 
- - Returns - ------- - Callable[[str], dict] - The serialize hook of the backend. - """ - raise NotImplementedError - - @property - def deserialize_hook(self) -> Callable[[str, dict], None]: - """The deserialize hook to convert the dictionary to a model file. - - Returns - ------- - Callable[[str, dict], None] - The deserialize hook of the backend. - """ - from deepmd.jax.jax2tf.serialization import ( - deserialize_to_file, - ) - - return deserialize_to_file diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index b60076c68c..fc526a502e 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -90,15 +90,24 @@ def __init__( self.output_def = output_def self.model_path = model_file - model_data = load_dp_model(model_file) - self.dp = HLO( - stablehlo=model_data["@variables"]["stablehlo"].tobytes(), - stablehlo_atomic_virial=model_data["@variables"][ - "stablehlo_atomic_virial" - ].tobytes(), - model_def_script=model_data["model_def_script"], - **model_data["constants"], - ) + if model_file.endswith(".hlo"): + model_data = load_dp_model(model_file) + self.dp = HLO( + stablehlo=model_data["@variables"]["stablehlo"].tobytes(), + stablehlo_atomic_virial=model_data["@variables"][ + "stablehlo_atomic_virial" + ].tobytes(), + model_def_script=model_data["model_def_script"], + **model_data["constants"], + ) + elif model_file.endswith(".savedmodel"): + from deepmd.jax.jax2tf.tfmodel import ( + TFModelWrapper, + ) + + self.dp = TFModelWrapper(model_file) + else: + raise ValueError("Unsupported file extension") self.rcut = self.dp.get_rcut() self.type_map = self.dp.get_type_map() if isinstance(auto_batch_size, bool): diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index 4c248dd139..f6758a18aa 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import json + import tensorflow as tf from 
jax.experimental import ( jax2tf, @@ -24,32 +26,69 @@ def deserialize_to_file(model_file: str, data: dict) -> None: model_def_script = data["model_def_script"] call_lower = model.call_lower - my_model = tf.Module() + tf_model = tf.Module() - # Save a function that can take scalar inputs. - my_model.call_lower = tf.function( - jax2tf.convert( - call_lower, - polymorphic_shapes=[ - "(nf, nloc + nghost, 3)", - "(nf, nloc + nghost)", - f"(nf, nloc, {model.get_nnei()})", - f"(nf, {model.get_dim_fparam()})", - f"(nf, nloc, {model.get_dim_aparam()})", + def exported_whether_do_atomic_virial(do_atomic_virial): + def call_lower_with_fixed_do_atomic_virial( + coord, atype, nlist, nlist_start, fparam, aparam + ): + return call_lower( + coord, + atype, + nlist, + nlist_start, + fparam, + aparam, + do_atomic_virial=do_atomic_virial, + ) + + return tf.function( + jax2tf.convert( + call_lower, + polymorphic_shapes=[ + "(nf, nloc + nghost, 3)", + "(nf, nloc + nghost)", + f"(nf, nloc, {model.get_nnei()})", + f"(nf, {model.get_dim_fparam()})", + f"(nf, nloc, {model.get_dim_aparam()})", + ], + ), + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), ], - ), - autograph=False, - input_signature=[ - tf.TensorSpec([None, None, 3], tf.float64), - tf.TensorSpec([None, None], tf.int64), - tf.TensorSpec([None, None, model.get_nnei()], tf.int64), - tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), - tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), - ], + ) + + # Save a function that can take scalar inputs. 
+ tf_model.call_lower = exported_whether_do_atomic_virial(do_atomic_virial=False) + tf_model.call_lower_atomic_virial = exported_whether_do_atomic_virial( + do_atomic_virial=True + ) + # set other attributes + tf_model.type_map = tf.Variable(model.get_type_map(), dtype=tf.string) + tf_model.rcut = tf.Variable(model.get_rcut(), dtype=tf.double) + tf_model.dim_fparam = tf.Variable(model.get_dim_fparam(), dtype=tf.int64) + tf_model.dim_aparam = tf.Variable(model.get_dim_aparam(), dtype=tf.int64) + tf_model.sel_type = tf.Variable(model.get_sel_type(), dtype=tf.int64) + tf_model.is_aparam_nall = tf.Variable(model.is_aparam_nall(), dtype=tf.bool) + tf_model.model_output_type = tf.Variable( + model.model_output_type(), dtype=tf.string + ) + tf_model.mixed_types = tf.Variable(model.mixed_types(), dtype=tf.bool) + if model.get_min_nbor_dist() is not None: + tf_model.min_nbor_dist = tf.Variable( + model.get_min_nbor_dist(), dtype=tf.double + ) + tf_model.sel = tf.Variable(model.get_sel(), dtype=tf.int64) + tf_model.model_def_script = tf.Variable( + json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string ) - my_model.model_def_script = model_def_script tf.saved_model.save( - my_model, + tf_model, model_file, options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), ) diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py new file mode 100644 index 0000000000..ae60cad936 --- /dev/null +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +import jax.experimental.jax2tf as jax2tf +import tensorflow as tf + +from deepmd.dpmodel.model.make_model import ( + model_call_from_call_lower, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, +) +from deepmd.jax.env import ( + jnp, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) + +OUTPUT_DEFS = { + "energy": OutputVariableDef( + 
"energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + "mask": OutputVariableDef( + "mask", + shape=[1], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), +} + + +class TFModelWrapper(tf.Module): + def __init__( + self, + model, + ) -> None: + self.model = tf.saved_model.load(model) + self._call_lower = jax2tf.call_tf(self.model.call_lower) + self._call_lower_atomic_virial = jax2tf.call_tf( + self.model.call_lower_atomic_virial + ) + self.type_map = self.model.type_map.numpy().tolist() + self.rcut = self.model.rcut.numpy().item() + self.dim_fparam = self.model.dim_fparam.numpy().item() + self.dim_aparam = self.model.dim_aparam.numpy().item() + self.sel_type = self.model.sel_type.numpy().tolist() + self._is_aparam_nall = self.model.is_aparam_nall.numpy().item() + self._model_output_type = self.model.model_output_type.numpy().tolist() + self._mixed_types = self.model.mixed_types.numpy().item() + if hasattr(self.model, "min_nbor_dist"): + self.min_nbor_dist = self.model.min_nbor_dist.numpy().item() + else: + self.min_nbor_dist = None + self.sel = self.model.sel.numpy().tolist() + self.model_def_script = self.model.model_def_script.numpy().decode() + + def __call__( + self, + coord: jnp.ndarray, + atype: jnp.ndarray, + box: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ) -> Any: + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,jnp.ndarray]. + The keys are defined by the `ModelOutputDef`. 
+ + """ + return self.call(coord, atype, box, fparam, aparam, do_atomic_virial) + + def call( + self, + coord: jnp.ndarray, + atype: jnp.ndarray, + box: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ): + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,jnp.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + return model_call_from_call_lower( + call_lower=self.call_lower, + rcut=self.get_rcut(), + sel=self.get_sel(), + mixed_types=self.mixed_types(), + model_output_def=self.model_output_def(), + coord=coord, + atype=atype, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + def model_output_def(self): + return ModelOutputDef( + FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()]) + ) + + def call_lower( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ): + if do_atomic_virial: + call_lower = self._call_lower_atomic_virial + else: + call_lower = self._call_lower + return call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + ) + + def get_type_map(self) -> list[str]: + """Get the type map.""" + return self.type_map + + def get_rcut(self): + """Get the cut-off radius.""" + return self.rcut + + def get_dim_fparam(self): + """Get the number (dimension) of frame parameters of this atomic model.""" + return 
self.dim_fparam + + def get_dim_aparam(self): + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.dim_aparam + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.sel_type + + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + + If False, the shape is (nframes, nloc, ndim). + """ + return self._is_aparam_nall + + def model_output_type(self) -> list[str]: + """Get the output type for the model.""" + return self._model_output_type + + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + raise NotImplementedError("Not implemented") + + @classmethod + def deserialize(cls, data: dict) -> "TFModelWrapper": + """Deserialize the model. 
+ + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + BaseModel + The deserialized model + """ + raise NotImplementedError("Not implemented") + + def get_model_def_script(self) -> str: + """Get the model definition script.""" + return self.model_def_script + + def get_min_nbor_dist(self) -> Optional[float]: + """Get the minimum distance between two atoms.""" + return self.min_nbor_dist + + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.nsel + + def get_sel(self) -> list[int]: + return self.sel + + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return sum(self.sel) + + def mixed_types(self) -> bool: + return self._mixed_types + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + raise NotImplementedError("Not implemented") + + @classmethod + def get_model(cls, model_params: dict) -> "TFModelWrapper": + """Get the model by the parameters. + + By default, all the parameters are directly passed to the constructor. + If not, override this method.
+ + Parameters + ---------- + model_params : dict + The model parameters + + Returns + ------- + BaseModel + The model + """ + raise NotImplementedError("Not implemented") diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index ec2de3060e..f48f75f6d2 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -107,8 +107,14 @@ def call_lower_with_fixed_do_atomic_virial( "sel": model.get_sel(), } save_dp_model(filename=model_file, model_dict=data) + elif model_file.endswith(".savedmodel"): + from deepmd.jax.jax2tf.serialization import ( + deserialize_to_file as deserialize_to_savedmodel, + ) + + return deserialize_to_savedmodel(model_file, data) else: - raise ValueError("JAX backend only supports converting .jax directory") + raise ValueError("Unsupported file extension") def serialize_from_file(model_file: str) -> dict: From e1d609e351570071c5f692e5efa88a280d5541ef Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 04:27:29 -0500 Subject: [PATCH 04/16] fix bugs Signed-off-by: Jinzhe Zeng --- deepmd/jax/jax2tf/serialization.py | 11 +++++++---- deepmd/jax/jax2tf/tfmodel.py | 21 +++++++++++++++++++-- deepmd/jax/utils/serialization.py | 4 ++-- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index f6758a18aa..869aa8edeb 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -30,13 +30,13 @@ def deserialize_to_file(model_file: str, data: dict) -> None: def exported_whether_do_atomic_virial(do_atomic_virial): def call_lower_with_fixed_do_atomic_virial( - coord, atype, nlist, nlist_start, fparam, aparam + coord, atype, nlist, mapping, fparam, aparam ): return call_lower( coord, atype, nlist, - nlist_start, + mapping, fparam, aparam, do_atomic_virial=do_atomic_virial, @@ -44,20 +44,23 @@ def call_lower_with_fixed_do_atomic_virial( return tf.function( jax2tf.convert( -
call_lower, + call_lower_with_fixed_do_atomic_virial, polymorphic_shapes=[ "(nf, nloc + nghost, 3)", "(nf, nloc + nghost)", f"(nf, nloc, {model.get_nnei()})", + "(nf, nloc + nghost)", f"(nf, {model.get_dim_fparam()})", f"(nf, nloc, {model.get_dim_aparam()})", ], + with_gradient=True, ), autograph=False, input_signature=[ tf.TensorSpec([None, None, 3], tf.float64), - tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, None], tf.int32), tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None], tf.int64), tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), ], diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py index ae60cad936..d9bb4de293 100644 --- a/deepmd/jax/jax2tf/tfmodel.py +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -40,6 +40,11 @@ } +def decode_list_of_bytes(list_of_bytes: list[bytes]) -> list[str]: + """Decode a list of bytes to a list of strings.""" + return [x.decode() for x in list_of_bytes] + + class TFModelWrapper(tf.Module): def __init__( self, @@ -50,13 +55,15 @@ def __init__( self._call_lower_atomic_virial = jax2tf.call_tf( self.model.call_lower_atomic_virial ) - self.type_map = self.model.type_map.numpy().tolist() + self.type_map = decode_list_of_bytes(self.model.type_map.numpy().tolist()) self.rcut = self.model.rcut.numpy().item() self.dim_fparam = self.model.dim_fparam.numpy().item() self.dim_aparam = self.model.dim_aparam.numpy().item() self.sel_type = self.model.sel_type.numpy().tolist() self._is_aparam_nall = self.model.is_aparam_nall.numpy().item() - self._model_output_type = self.model.model_output_type.numpy().tolist() + self._model_output_type = decode_list_of_bytes( + self.model.model_output_type.numpy().tolist() + ) self._mixed_types = self.model.mixed_types.numpy().item() if hasattr(self.model, "min_nbor_dist"): self.min_nbor_dist = self.model.min_nbor_dist.numpy().item() @@ -168,6 +175,16 @@ def call_lower( 
call_lower = self._call_lower_atomic_virial else: call_lower = self._call_lower + # Attempt to convert a value (None) with an unsupported type () to a Tensor. + if fparam is None: + fparam = jnp.empty( + (extended_coord.shape[0], self.get_dim_fparam()), dtype=jnp.float64 + ) + if aparam is None: + aparam = jnp.empty( + (extended_coord.shape[0], nlist.shape[1], self.get_dim_aparam()), + dtype=jnp.float64, + ) return call_lower( extended_coord, extended_atype, diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index f48f75f6d2..6ab99a81f0 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -55,13 +55,13 @@ def deserialize_to_file(model_file: str, data: dict) -> None: def exported_whether_do_atomic_virial(do_atomic_virial): def call_lower_with_fixed_do_atomic_virial( - coord, atype, nlist, nlist_start, fparam, aparam + coord, atype, nlist, mapping, fparam, aparam ): return call_lower( coord, atype, nlist, - nlist_start, + mapping, fparam, aparam, do_atomic_virial=do_atomic_virial, From 6dd361730a3240a98cfe4e78ea483cdb9552bfed Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 04:30:36 -0500 Subject: [PATCH 05/16] test Signed-off-by: Jinzhe Zeng --- source/tests/consistent/io/test_io.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 91cd391322..ce834831ef 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -75,7 +75,7 @@ def tearDown(self): def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() for backend_name, suffix_idx in ( - ("tensorflow", 0), + # ("tensorflow", 0), ("pytorch", 0), ("dpmodel", 0), ("jax", 0), @@ -140,13 +140,22 @@ def test_deep_eval(self): nframes = self.atype.shape[0] prefix = "test_consistent_io_" + self.__class__.__name__.lower() rets = [] - for backend_name in 
("tensorflow", "pytorch", "dpmodel", "jax"): + for backend_name, suffix_idx in ( + ("tensorflow", 0), + ("pytorch", 0), + ("dpmodel", 0), + ("jax", 0), + # unfortunately, jax2tf cannot work with tf v1 behaviors + # ("jax", 2), + ): backend = Backend.get_backend(backend_name)() if not backend.is_available(): continue reference_data = copy.deepcopy(self.data) - self.save_data_to_model(prefix + backend.suffixes[0], reference_data) - deep_eval = DeepEval(prefix + backend.suffixes[0]) + self.save_data_to_model( + prefix + backend.suffixes[suffix_idx], reference_data + ) + deep_eval = DeepEval(prefix + backend.suffixes[suffix_idx]) if deep_eval.get_dim_fparam() > 0: fparam = np.ones((nframes, deep_eval.get_dim_fparam())) else: From dc1de12503b271d3463572ed6da8abcd4b728d41 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 04:40:57 -0500 Subject: [PATCH 06/16] add an error message Signed-off-by: Jinzhe Zeng --- deepmd/jax/jax2tf/__init__.py | 10 ++++++++++ doc/backend.md | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/deepmd/jax/jax2tf/__init__.py b/deepmd/jax/jax2tf/__init__.py index 6ceb116d85..88a928f04d 100644 --- a/deepmd/jax/jax2tf/__init__.py +++ b/deepmd/jax/jax2tf/__init__.py @@ -1 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf + +if not tf.executing_eagerly(): + # TF disallow temporary eager execution + raise RuntimeError( + "Unfortunatly, jax2tf (requires eager execution) cannot be used with the " + "TensorFlow backend (disables eager execution). " + "If you are converting a model between different backends, " + "considering converting to the `.dp` format first." 
+ ) diff --git a/doc/backend.md b/doc/backend.md index cf99eea9cb..3fb70bee90 100644 --- a/doc/backend.md +++ b/doc/backend.md @@ -25,11 +25,12 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different ### JAX {{ jax_icon }} -- Model filename extension: `.xlo` +- Model filename extension: `.xlo`, `.savedmodel` - Checkpoint filename extension: `.jax` [JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required. Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions. +`.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow. Currently, this backend is developed actively, and has no support for training and the C++ interface. ### DP {{ dpmodel_icon }} From 5a7fc4a59a3932a530bc637b76d5525d6d2bd247 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 04:41:42 -0500 Subject: [PATCH 07/16] revert Signed-off-by: Jinzhe Zeng --- source/tests/consistent/io/test_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index ce834831ef..bd48135acf 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -75,7 +75,7 @@ def tearDown(self): def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() for backend_name, suffix_idx in ( - # ("tensorflow", 0), + ("tensorflow", 0), ("pytorch", 0), ("dpmodel", 0), ("jax", 0), From 980b4a935a7d73b7e3e55ec5d9621d1a60036dcc Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 04:50:02 -0500 Subject: [PATCH 08/16] Update deepmd/jax/jax2tf/tfmodel.py Signed-off-by: Jinzhe Zeng --- deepmd/jax/jax2tf/tfmodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py index d9bb4de293..7339835a4b 100644 --- a/deepmd/jax/jax2tf/tfmodel.py +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -266,7 +266,7 @@ def get_min_nbor_dist(self) -> Optional[float]: def get_nnei(self) -> int: """Returns the total number of selected neighboring atoms in the cut-off radius.""" - return self.nsel + return self.get_nsel() def get_sel(self) -> list[int]: return self.sel From 8e216f5e082ccff398138f2d0a45ce7075dc05fe Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 16:54:35 -0500 Subject: [PATCH 09/16] use functions to store constants so it can be read by C++ Signed-off-by: Jinzhe Zeng --- deepmd/jax/jax2tf/serialization.py | 48 ++++++++++++++++++++---------- deepmd/jax/jax2tf/tfmodel.py | 24 +++++++-------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index 869aa8edeb..50e5f7c7be 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -71,24 +71,42 @@ def call_lower_with_fixed_do_atomic_virial( tf_model.call_lower_atomic_virial = exported_whether_do_atomic_virial( do_atomic_virial=True ) - # set other attributes - tf_model.type_map = tf.Variable(model.get_type_map(), dtype=tf.string) - tf_model.rcut = tf.Variable(model.get_rcut(), dtype=tf.double) - tf_model.dim_fparam = tf.Variable(model.get_dim_fparam(), dtype=tf.int64) - tf_model.dim_aparam = tf.Variable(model.get_dim_aparam(), dtype=tf.int64) - tf_model.sel_type = tf.Variable(model.get_sel_type(), dtype=tf.int64) - tf_model.is_aparam_nall = tf.Variable(model.is_aparam_nall(), dtype=tf.bool) - tf_model.model_output_type = tf.Variable( - model.model_output_type(), dtype=tf.string + # set functions to export other attributes + tf_model.get_type_map = tf.function( + lambda: tf.constant(model.get_type_map(), dtype=tf.string) + ) + tf_model.get_rcut = tf.function( + lambda: tf.constant(model.get_rcut(), 
dtype=tf.double) + ) + tf_model.get_dim_fparam = tf.function( + lambda: tf.constant(model.get_dim_fparam(), dtype=tf.int64) + ) + tf_model.get_dim_aparam = tf.function( + lambda: tf.constant(model.get_dim_aparam(), dtype=tf.int64) + ) + tf_model.get_sel_type = tf.function( + lambda: tf.constant(model.get_sel_type(), dtype=tf.int64) + ) + tf_model.is_aparam_nall = tf.function( + lambda: tf.constant(model.is_aparam_nall(), dtype=tf.bool) + ) + tf_model.model_output_type = tf.function( + lambda: tf.constant(model.model_output_type(), dtype=tf.string) + ) + tf_model.mixed_types = tf.function( + lambda: tf.constant(model.mixed_types(), dtype=tf.bool) ) - tf_model.mixed_types = tf.Variable(model.mixed_types(), dtype=tf.bool) if model.get_min_nbor_dist() is not None: - tf_model.min_nbor_dist = tf.Variable( - model.get_min_nbor_dist(), dtype=tf.double + tf_model.get_min_nbor_dist = tf.function( + lambda: tf.constant(model.get_min_nbor_dist(), dtype=tf.double) + ) + tf_model.get_sel = tf.function( + lambda: tf.constant(model.get_sel(), dtype=tf.int64) + ) + tf_model.get_model_def_script = tf.function( + lambda: tf.constant( + json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string ) - tf_model.sel = tf.Variable(model.get_sel(), dtype=tf.int64) - tf_model.model_def_script = tf.Variable( - json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string ) tf.saved_model.save( tf_model, diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py index 7339835a4b..8f04014a97 100644 --- a/deepmd/jax/jax2tf/tfmodel.py +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -55,22 +55,22 @@ def __init__( self._call_lower_atomic_virial = jax2tf.call_tf( self.model.call_lower_atomic_virial ) - self.type_map = decode_list_of_bytes(self.model.type_map.numpy().tolist()) - self.rcut = self.model.rcut.numpy().item() - self.dim_fparam = self.model.dim_fparam.numpy().item() - self.dim_aparam = self.model.dim_aparam.numpy().item() - self.sel_type = 
self.model.sel_type.numpy().tolist() - self._is_aparam_nall = self.model.is_aparam_nall.numpy().item() + self.type_map = decode_list_of_bytes(self.model.get_type_map().numpy().tolist()) + self.rcut = self.model.get_rcut().numpy().item() + self.dim_fparam = self.model.get_dim_fparam().numpy().item() + self.dim_aparam = self.model.get_dim_aparam().numpy().item() + self.sel_type = self.model.get_sel_type().numpy().tolist() + self._is_aparam_nall = self.model.is_aparam_nall().numpy().item() self._model_output_type = decode_list_of_bytes( - self.model.model_output_type.numpy().tolist() + self.model.model_output_type().numpy().tolist() ) - self._mixed_types = self.model.mixed_types.numpy().item() - if hasattr(self.model, "min_nbor_dist"): - self.min_nbor_dist = self.model.min_nbor_dist.numpy().item() + self._mixed_types = self.model.mixed_types().numpy().item() + if hasattr(self.model, "get_min_nbor_dist"): + self.min_nbor_dist = self.model.get_min_nbor_dist().numpy().item() else: self.min_nbor_dist = None - self.sel = self.model.sel.numpy().tolist() - self.model_def_script = self.model.model_def_script.numpy().decode() + self.sel = self.model.get_sel().numpy().tolist() + self.model_def_script = self.model.get_model_def_script().numpy().decode() def __call__( self, From 3e919a118b344d9fbc0bf34ecf4c8397a6316e83 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 18:39:45 -0500 Subject: [PATCH 10/16] name functions Signed-off-by: Jinzhe Zeng --- deepmd/jax/jax2tf/serialization.py | 171 +++++++++++++++++++---------- 1 file changed, 114 insertions(+), 57 deletions(-) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index 50e5f7c7be..dff43a11fc 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -42,72 +42,129 @@ def call_lower_with_fixed_do_atomic_virial( do_atomic_virial=do_atomic_virial, ) - return tf.function( - jax2tf.convert( - call_lower_with_fixed_do_atomic_virial, - 
polymorphic_shapes=[ - "(nf, nloc + nghost, 3)", - "(nf, nloc + nghost)", - f"(nf, nloc, {model.get_nnei()})", - "(nf, nloc + nghost)", - f"(nf, {model.get_dim_fparam()})", - f"(nf, nloc, {model.get_dim_aparam()})", - ], - with_gradient=True, - ), - autograph=False, - input_signature=[ - tf.TensorSpec([None, None, 3], tf.float64), - tf.TensorSpec([None, None], tf.int32), - tf.TensorSpec([None, None, model.get_nnei()], tf.int64), - tf.TensorSpec([None, None], tf.int64), - tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), - tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + return jax2tf.convert( + call_lower_with_fixed_do_atomic_virial, + polymorphic_shapes=[ + "(nf, nloc + nghost, 3)", + "(nf, nloc + nghost)", + f"(nf, nloc, {model.get_nnei()})", + "(nf, nloc + nghost)", + f"(nf, {model.get_dim_fparam()})", + f"(nf, nloc, {model.get_dim_aparam()})", ], + with_gradient=True, ) # Save a function that can take scalar inputs. - tf_model.call_lower = exported_whether_do_atomic_virial(do_atomic_virial=False) - tf_model.call_lower_atomic_virial = exported_whether_do_atomic_virial( - do_atomic_virial=True + # We need to explicit set the function name, so C++ can find it. 
+ @tf.function( + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], ) - # set functions to export other attributes - tf_model.get_type_map = tf.function( - lambda: tf.constant(model.get_type_map(), dtype=tf.string) - ) - tf_model.get_rcut = tf.function( - lambda: tf.constant(model.get_rcut(), dtype=tf.double) - ) - tf_model.get_dim_fparam = tf.function( - lambda: tf.constant(model.get_dim_fparam(), dtype=tf.int64) - ) - tf_model.get_dim_aparam = tf.function( - lambda: tf.constant(model.get_dim_aparam(), dtype=tf.int64) - ) - tf_model.get_sel_type = tf.function( - lambda: tf.constant(model.get_sel_type(), dtype=tf.int64) - ) - tf_model.is_aparam_nall = tf.function( - lambda: tf.constant(model.is_aparam_nall(), dtype=tf.bool) - ) - tf_model.model_output_type = tf.function( - lambda: tf.constant(model.model_output_type(), dtype=tf.string) - ) - tf_model.mixed_types = tf.function( - lambda: tf.constant(model.mixed_types(), dtype=tf.bool) - ) - if model.get_min_nbor_dist() is not None: - tf_model.get_min_nbor_dist = tf.function( - lambda: tf.constant(model.get_min_nbor_dist(), dtype=tf.double) + def call_lower_without_atomic_virial( + coord, atype, nlist, mapping, fparam, aparam + ): + return exported_whether_do_atomic_virial(do_atomic_virial=False)( + coord, atype, nlist, mapping, fparam, aparam ) - tf_model.get_sel = tf.function( - lambda: tf.constant(model.get_sel(), dtype=tf.int64) + + tf_model.call_lower = call_lower_without_atomic_virial + + @tf.function( + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, 
None], tf.int64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], ) - tf_model.get_model_def_script = tf.function( - lambda: tf.constant( + def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): + return exported_whether_do_atomic_virial(do_atomic_virial=True)( + coord, atype, nlist, mapping, fparam, aparam + ) + + tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial + + # set functions to export other attributes + @tf.function + def get_type_map(): + return tf.constant(model.get_type_map(), dtype=tf.string) + + tf_model.get_type_map = get_type_map + + @tf.function + def get_rcut(): + return tf.constant(model.get_rcut(), dtype=tf.double) + + tf_model.get_rcut = get_rcut + + @tf.function + def get_dim_fparam(): + return tf.constant(model.get_dim_fparam(), dtype=tf.int64) + + tf_model.get_dim_fparam = get_dim_fparam + + @tf.function + def get_dim_aparam(): + return tf.constant(model.get_dim_aparam(), dtype=tf.int64) + + tf_model.get_dim_aparam = get_dim_aparam + + @tf.function + def get_sel_type(): + return tf.constant(model.get_sel_type(), dtype=tf.int64) + + tf_model.get_sel_type = get_sel_type + + @tf.function + def is_aparam_nall(): + return tf.constant(model.is_aparam_nall(), dtype=tf.bool) + + tf_model.is_aparam_nall = is_aparam_nall + + @tf.function + def model_output_type(): + return tf.constant(model.model_output_type(), dtype=tf.string) + + tf_model.model_output_type = model_output_type + + @tf.function + def mixed_types(): + return tf.constant(model.mixed_types(), dtype=tf.bool) + + tf_model.mixed_types = mixed_types + + if model.get_min_nbor_dist() is not None: + + @tf.function + def get_min_nbor_dist(): + return tf.constant(model.get_min_nbor_dist(), dtype=tf.double) + + tf_model.get_min_nbor_dist = get_min_nbor_dist + + @tf.function + def get_sel(): + return tf.constant(model.get_sel(), dtype=tf.int64) + + tf_model.get_sel = 
get_sel + + @tf.function + def get_model_def_script(): + return tf.constant( json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string ) - ) + + tf_model.get_model_def_script = get_model_def_script tf.saved_model.save( tf_model, model_file, From 60694a8e188fe90b6a6de95645b885177daaba14 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 21:20:29 -0500 Subject: [PATCH 11/16] test test_io in a seperated run Signed-off-by: Jinzhe Zeng --- .github/workflows/test_python.yml | 6 ++++++ source/tests/consistent/io/test_io.py | 7 ++++--- source/tests/utils.py | 1 + 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index e46bddd98a..649efce05c 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -53,6 +53,12 @@ jobs: - run: pytest --cov=deepmd source/tests --durations=0 --splits 6 --group ${{ matrix.group }} --store-durations --durations-path=.test_durations --splitting-algorithm least_duration env: NUM_WORKERS: 0 + - name: Test TF2 eager mode + run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0 + env: + NUM_WORKERS: 0 + DP_TEST_TF2_ONLY: 1 + if: matrix.group == 1 - run: mv .test_durations .test_durations_${{ matrix.group }} - name: Upload partial durations uses: actions/upload-artifact@v4 diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index bd48135acf..9b59a7eb36 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -23,6 +23,7 @@ from ...utils import ( CI, + DP_TEST_TF2_ONLY, TEST_DEVICE, ) @@ -72,6 +73,7 @@ def tearDown(self): shutil.rmtree(ii) @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") + @unittest.skipIf(DP_TEST_TF2_ONLY, "Conflict with TF2 eager mode.") def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() for backend_name, suffix_idx in ( @@ -141,12 
+143,11 @@ def test_deep_eval(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() rets = [] for backend_name, suffix_idx in ( - ("tensorflow", 0), + # unfortunately, jax2tf cannot work with tf v1 behaviors + ("jax", 2) if DP_TEST_TF2_ONLY else ("tensorflow", 0), ("pytorch", 0), ("dpmodel", 0), ("jax", 0), - # unfortunately, jax2tf cannot work with tf v1 behaviors - # ("jax", 2), ): backend = Backend.get_backend(backend_name)() if not backend.is_available(): diff --git a/source/tests/utils.py b/source/tests/utils.py index bfb3d445af..a9bf0f11ea 100644 --- a/source/tests/utils.py +++ b/source/tests/utils.py @@ -8,3 +8,4 @@ # see https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables CI = os.environ.get("CI") == "true" +DP_TEST_TF2_ONLY = os.environ.get("DP_TEST_TF2_ONLY") == "1" From 15750a21ef1220eb52851f4be9801477f704fdeb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 22:52:25 -0500 Subject: [PATCH 12/16] bump tensorflow Signed-off-by: Jinzhe Zeng --- .github/workflows/test_python.yml | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 649efce05c..55fa283f9e 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -25,19 +25,22 @@ jobs: python-version: ${{ matrix.python }} - run: python -m pip install -U uv - run: | - source/install/uv_with_retry.sh pip install --system mpich + source/install/uv_with_retry.sh pip install --system mpich tensorflow-cpu source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu + export TENSORFLOW_ROOT=$(python -c 'import tensorflow;print(tensorflow.__path__[0])') export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') - source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test,jax] 
horovod[tensorflow-cpu] mpi4py + source/install/uv_with_retry.sh pip install --system -e .[test,jax] horovod mpi4py env: # Please note that uv has some issues with finding # existing TensorFlow package. Currently, it uses # TensorFlow in the build dependency, but if it # changes, setting `TENSORFLOW_ROOT`. - TENSORFLOW_VERSION: 2.16.1 DP_ENABLE_PYTORCH: 1 DP_BUILD_TESTING: 1 - UV_EXTRA_INDEX_URL: "https://pypi.anaconda.org/njzjz/simple https://pypi.anaconda.org/mpi4py/simple" + UV_EXTRA_INDEX_URL: "https://pypi.anaconda.org/mpi4py/simple" + HOROVOD_WITH_TENSORFLOW: 1 + HOROVOD_WITHOUT_PYTORCH: 1 + HOROVOD_WITH_MPI: 1 - run: dp --version - name: Get durations from cache uses: actions/cache@v4 From e2738cbb58e67a5b4596513caf3ff97f54d536ba Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 22:58:06 -0500 Subject: [PATCH 13/16] no-build-isolation Signed-off-by: Jinzhe Zeng --- .github/workflows/test_python.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 55fa283f9e..f5a44b2200 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -29,7 +29,8 @@ jobs: source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu export TENSORFLOW_ROOT=$(python -c 'import tensorflow;print(tensorflow.__path__[0])') export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') - source/install/uv_with_retry.sh pip install --system -e .[test,jax] horovod mpi4py + source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py + source/install/uv_with_retry.sh pip install --system -e horovod --no-build-isolation env: # Please note that uv has some issues with finding # existing TensorFlow package. 
Currently, it uses From de0caa1eaa39c0b6463e34dabd04413358e14820 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 23:03:02 -0500 Subject: [PATCH 14/16] typo Signed-off-by: Jinzhe Zeng --- .github/workflows/test_python.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index f5a44b2200..eea66086c9 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -30,7 +30,7 @@ jobs: export TENSORFLOW_ROOT=$(python -c 'import tensorflow;print(tensorflow.__path__[0])') export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py - source/install/uv_with_retry.sh pip install --system -e horovod --no-build-isolation + source/install/uv_with_retry.sh pip install --system horovod --no-build-isolation env: # Please note that uv has some issues with finding # existing TensorFlow package. 
Currently, it uses From 306eef319b5473e36e34e5bf62ffb123a85b6eef Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 23:17:23 -0500 Subject: [PATCH 15/16] fix eval argument in the test Signed-off-by: Jinzhe Zeng --- source/tests/consistent/io/test_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 9b59a7eb36..ca213da13c 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -179,7 +179,7 @@ def test_deep_eval(self): self.atype, fparam=fparam, aparam=aparam, - do_atomic_virial=True, + atomic=True, ) rets.append(ret) for ret in rets[1:]: From 9a941183fbd70e6ba0f25cb4b1dc810cd7e7c846 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 4 Nov 2024 00:53:14 -0500 Subject: [PATCH 16/16] change to openmpi Signed-off-by: Jinzhe Zeng --- .github/workflows/test_python.yml | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index eea66086c9..422dcb5f17 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -25,7 +25,7 @@ jobs: python-version: ${{ matrix.python }} - run: python -m pip install -U uv - run: | - source/install/uv_with_retry.sh pip install --system mpich tensorflow-cpu + source/install/uv_with_retry.sh pip install --system openmpi tensorflow-cpu source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu export TENSORFLOW_ROOT=$(python -c 'import tensorflow;print(tensorflow.__path__[0])') export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') diff --git a/pyproject.toml b/pyproject.toml index 1faacb973c..802e920014 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -444,6 +444,7 @@ select = [ [tool.uv.sources] mpich = { index = "mpi4py" } +openmpi = { index = "mpi4py" } [[tool.uv.index]] name = "mpi4py"