Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions deepmd_utils/model_format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
VariableDef,
fitting_check_output,
get_deriv_name,
get_reduce_name,
model_check_output,
)
from .se_e2_a import (
Expand All @@ -52,7 +53,8 @@
"ModelOutputDef",
"FittingOutputDef",
"OutputVariableDef",
"VariableDef",
"model_check_output",
"fitting_check_output",
"get_reduce_name",
"get_deriv_name",
]
115 changes: 59 additions & 56 deletions deepmd_utils/model_format/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,34 @@
Dict,
List,
Tuple,
Union,
)


def check_shape(
    shape: List[int],
    def_shape: List[int],
):
    """Check if the shape satisfies the defined shape.

    Parameters
    ----------
    shape
        The actual shape to be checked.
    def_shape
        The defined shape. A trailing ``-1`` is a wildcard: the last
        dimension may then take any size.

    Raises
    ------
    ValueError
        If the number of dimensions differs, or a non-wildcard
        dimension does not match the definition.
    """
    # Raise ValueError instead of assert so the check survives `python -O`
    # and is consistent with the other mismatch errors below.
    if len(shape) != len(def_shape):
        raise ValueError(f"{shape} length not matching def {def_shape}")
    if def_shape[-1] == -1:
        # Last dimension is free; compare only the leading dimensions.
        if list(shape[:-1]) != def_shape[:-1]:
            raise ValueError(f"{shape[:-1]} shape not matching def {def_shape[:-1]}")
    else:
        if list(shape) != def_shape:
            raise ValueError(f"{shape} shape not matching def {def_shape}")


def check_var(var, var_def):
    """Check that ``var``'s shape matches the variable definition.

    Parameters
    ----------
    var
        An array-like object exposing ``.shape``.
    var_def
        A variable definition exposing ``.atomic`` (bool) and
        ``.shape`` (list of int, trailing -1 acting as a wildcard).

    Raises
    ------
    ValueError
        If the rank or the per-dimension sizes do not match.
    """
    # NOTE(review): stale pre-refactor equality checks that duplicated
    # (and contradicted, for wildcard -1 dims) check_shape were removed.
    if var_def.atomic:
        # var.shape == [nf, nloc, *var_def.shape]
        if len(var.shape) != len(var_def.shape) + 2:
            raise ValueError(f"{var.shape[2:]} length not matching def {var_def.shape}")
        check_shape(list(var.shape[2:]), var_def.shape)
    else:
        # var.shape == [nf, *var_def.shape]
        if len(var.shape) != len(var_def.shape) + 1:
            raise ValueError(f"{var.shape[1:]} length not matching def {var_def.shape}")
        check_shape(list(var.shape[1:]), var_def.shape)


def model_check_output(cls):
Expand All @@ -38,7 +49,7 @@ def __init__(
**kwargs,
):
super().__init__(*args, **kwargs)
self.md = cls.output_def(self)
self.md = self.output_def()

def __call__(
self,
Expand Down Expand Up @@ -77,7 +88,7 @@ def __init__(
**kwargs,
):
super().__init__(*args, **kwargs)
self.md = cls.output_def(self)
self.md = self.output_def()

def __call__(
self,
Expand All @@ -93,35 +104,7 @@ def __call__(
return wrapper


class VariableDef:
"""Defines the shape and other properties of a variable.

Parameters
----------
name
Name of the output variable. Notice that the xxxx_redu,
xxxx_derv_c, xxxx_derv_r are reserved names that should
not be used to define variables.
shape
The shape of the variable. e.g. energy should be [1],
dipole should be [3], polarizabilty should be [3,3].
atomic
If the variable is defined for each atom.

"""

def __init__(
self,
name: str,
shape: Union[List[int], Tuple[int]],
atomic: bool = True,
):
self.name = name
self.shape = list(shape)
self.atomic = atomic


class OutputVariableDef(VariableDef):
class OutputVariableDef:
"""Defines the shape and other properties of the one output variable.

It is assume that the fitting network output variables for each
Expand Down Expand Up @@ -149,12 +132,14 @@ class OutputVariableDef(VariableDef):
def __init__(
self,
name: str,
shape: Union[List[int], Tuple[int]],
shape: List[int],
reduciable: bool = False,
differentiable: bool = False,
atomic: bool = True,
):
# fitting output must be atomic
super().__init__(name, shape, atomic=True)
self.name = name
self.shape = list(shape)
self.atomic = atomic
self.reduciable = reduciable
self.differentiable = differentiable
if not self.reduciable and self.differentiable:
Expand All @@ -176,13 +161,13 @@ class FittingOutputDef:

def __init__(
    self,
    var_defs: List[OutputVariableDef],
):
    """Store the output variable definitions, indexed by name.

    Parameters
    ----------
    var_defs
        The definitions of the fitting network outputs. Required:
        a mutable default argument must not be used here.
    """
    # Index by name for O(1) lookup in __getitem__.
    self.var_defs = {vv.name: vv for vv in var_defs}

def __getitem__(
    self,
    key: str,
) -> OutputVariableDef:
    """Return the output variable definition named ``key``.

    Raises KeyError if ``key`` is not defined.
    """
    return self.var_defs[key]

Expand Down Expand Up @@ -215,7 +200,7 @@ def __init__(
self.def_outp = fit_defs
self.def_redu = do_reduce(self.def_outp)
self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp)
self.var_defs = {}
self.var_defs: Dict[str, OutputVariableDef] = {}
for ii in [
self.def_outp.get_data(),
self.def_redu,
Expand All @@ -224,10 +209,16 @@ def __init__(
]:
self.var_defs.update(ii)

def __getitem__(
    self,
    key: str,
) -> OutputVariableDef:
    """Return the variable definition named ``key`` (fitting output,
    reduced, or derivative entry).

    Raises KeyError if ``key`` is not defined.
    """
    return self.var_defs[key]

def get_data(
    self,
    key: str,
) -> Dict[str, OutputVariableDef]:
    """Return the full mapping of variable definitions.

    NOTE(review): ``key`` is currently ignored and the whole mapping
    is returned — confirm this is the intended contract.
    """
    return self.var_defs

def keys(self):
Expand All @@ -246,33 +237,45 @@ def keys_derv_c(self):
return self.def_derv_c.keys()


def get_reduce_name(name: str) -> str:
    """Return the reserved name of the reduced counterpart of ``name``."""
    return name + "_redu"


def get_deriv_name(name: str) -> Tuple[str, str]:
    """Return the reserved names of the coordinate (``_derv_r``) and
    cell (``_derv_c``) derivatives of ``name``."""
    return name + "_derv_r", name + "_derv_c"


def do_reduce(
    def_outp: FittingOutputDef,
) -> Dict[str, OutputVariableDef]:
    """Build the reduced (non-atomic) variable definitions.

    For every reduciable fitting output, create a companion definition
    named ``<name>_redu`` with the same shape but ``atomic=False``;
    reduced outputs are themselves neither reduciable nor differentiable.
    """
    def_redu: Dict[str, OutputVariableDef] = {}
    for kk, vv in def_outp.get_data().items():
        if vv.reduciable:
            rk = get_reduce_name(kk)
            def_redu[rk] = OutputVariableDef(
                rk, vv.shape, reduciable=False, differentiable=False, atomic=False
            )
    return def_redu


def do_derivative(
    def_outp: FittingOutputDef,
) -> Tuple[Dict[str, OutputVariableDef], Dict[str, OutputVariableDef]]:
    """Build the derivative variable definitions.

    For every differentiable fitting output, create:
    - ``<name>_derv_r``: coordinate derivative, shape ``[*shape, 3]``,
      atomic (default ``atomic=True``);
    - ``<name>_derv_c``: cell derivative, shape ``[*shape, 3, 3]``,
      reduciable and atomic (default ``atomic=True``).
    """
    def_derv_r: Dict[str, OutputVariableDef] = {}
    def_derv_c: Dict[str, OutputVariableDef] = {}
    for kk, vv in def_outp.get_data().items():
        if vv.differentiable:
            rkr, rkc = get_deriv_name(kk)
            def_derv_r[rkr] = OutputVariableDef(
                rkr,
                vv.shape + [3],  # noqa: RUF005
                reduciable=False,
                differentiable=False,
            )
            def_derv_c[rkc] = OutputVariableDef(
                rkc,
                vv.shape + [3, 3],  # noqa: RUF005
                reduciable=True,
                differentiable=False,
            )
    return def_derv_r, def_derv_c
71 changes: 65 additions & 6 deletions source/tests/test_output_def.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest
from typing import (
List,
)

import numpy as np

Expand All @@ -11,6 +14,21 @@
fitting_check_output,
model_check_output,
)
from deepmd_utils.model_format.output_def import (
check_var,
)


class VariableDef:
    """Lightweight stand-in for a variable definition used by the
    shape-checking tests: just a name, a shape, and an atomic flag."""

    def __init__(self, name: str, shape: List[int], atomic: bool = True):
        # Copy the shape so the caller's sequence is never shared.
        self.name = name
        self.shape = list(shape)
        self.atomic = atomic


class TestDef(unittest.TestCase):
Expand Down Expand Up @@ -81,7 +99,7 @@ def test_model_output_def(self):
self.assertEqual(md["foo"].atomic, True)
self.assertEqual(md["energy_redu"].atomic, False)
self.assertEqual(md["energy_derv_r"].atomic, True)
self.assertEqual(md["energy_derv_c"].atomic, False)
self.assertEqual(md["energy_derv_c"].atomic, True)

def test_raise_no_redu_deriv(self):
with self.assertRaises(ValueError) as context:
Expand All @@ -90,6 +108,7 @@ def test_raise_no_redu_deriv(self):
def test_model_decorator(self):
nf = 2
nloc = 3
nall = 4

@model_check_output
class Foo(NativeOP):
Expand All @@ -103,8 +122,8 @@ def call(self):
return {
"energy": np.zeros([nf, nloc, 1]),
"energy_redu": np.zeros([nf, 1]),
"energy_derv_r": np.zeros([nf, nloc, 1, 3]),
"energy_derv_c": np.zeros([nf, 1, 3, 3]),
"energy_derv_r": np.zeros([nf, nall, 1, 3]),
"energy_derv_c": np.zeros([nf, nall, 1, 3, 3]),
}

ff = Foo()
Expand All @@ -113,6 +132,7 @@ def call(self):
def test_model_decorator_keyerror(self):
nf = 2
nloc = 3
nall = 4

@model_check_output
class Foo(NativeOP):
Expand All @@ -129,7 +149,7 @@ def call(self):
return {
"energy": np.zeros([nf, nloc, 1]),
"energy_redu": np.zeros([nf, 1]),
"energy_derv_c": np.zeros([nf, 1, 3, 3]),
"energy_derv_c": np.zeros([nf, nall, 1, 3, 3]),
}

ff = Foo()
Expand All @@ -140,13 +160,14 @@ def call(self):
def test_model_decorator_shapeerror(self):
nf = 2
nloc = 3
nall = 4

@model_check_output
class Foo(NativeOP):
def __init__(
self,
shape_rd=[nf, 1],
shape_dr=[nf, nloc, 1, 3],
shape_dr=[nf, nall, 1, 3],
):
self.shape_rd, self.shape_dr = shape_rd, shape_dr

Expand All @@ -161,7 +182,7 @@ def call(self):
"energy": np.zeros([nf, nloc, 1]),
"energy_redu": np.zeros(self.shape_rd),
"energy_derv_r": np.zeros(self.shape_dr),
"energy_derv_c": np.zeros([nf, 1, 3, 3]),
"energy_derv_c": np.zeros([nf, nall, 1, 3, 3]),
}

ff = Foo()
Expand Down Expand Up @@ -192,6 +213,7 @@ def call(self):
def test_fitting_decorator(self):
nf = 2
nloc = 3
nall = 4

Check notice

Code scanning / CodeQL

Unused local variable

Variable nall is not used.

@fitting_check_output
class Foo(NativeOP):
Expand Down Expand Up @@ -243,3 +265,40 @@ def call(self):
ff = Foo(shape=[nf, nloc, 2])
ff()
self.assertIn("not matching", context.exception)

def test_check_var(self):
    """check_var must validate both rank (length) and per-dimension
    shape, with a trailing -1 in the definition as a wildcard dim.

    Fix: ``assertIn`` needs a string container — ``context.exception``
    is the exception object (not iterable, so ``in`` raises TypeError);
    it must be wrapped with ``str()``.
    """
    var_def = VariableDef("foo", [2, 3], atomic=True)
    with self.assertRaises(ValueError) as context:
        check_var(np.zeros([2, 3, 4, 5, 6]), var_def)
    self.assertIn("length not matching", str(context.exception))
    with self.assertRaises(ValueError) as context:
        check_var(np.zeros([2, 3, 4, 5]), var_def)
    self.assertIn("shape not matching", str(context.exception))
    check_var(np.zeros([2, 3, 2, 3]), var_def)

    var_def = VariableDef("foo", [2, 3], atomic=False)
    with self.assertRaises(ValueError) as context:
        check_var(np.zeros([2, 3, 4, 5]), var_def)
    self.assertIn("length not matching", str(context.exception))
    with self.assertRaises(ValueError) as context:
        check_var(np.zeros([2, 3, 4]), var_def)
    self.assertIn("shape not matching", str(context.exception))
    check_var(np.zeros([2, 2, 3]), var_def)

    var_def = VariableDef("foo", [2, -1], atomic=True)
    with self.assertRaises(ValueError) as context:
        check_var(np.zeros([2, 3, 4, 5, 6]), var_def)
    self.assertIn("length not matching", str(context.exception))
    with self.assertRaises(ValueError) as context:
        check_var(np.zeros([2, 3, 4, 5]), var_def)
    self.assertIn("shape not matching", str(context.exception))
    check_var(np.zeros([2, 3, 2, 8]), var_def)

    var_def = VariableDef("foo", [2, -1], atomic=False)
    with self.assertRaises(ValueError) as context:
        check_var(np.zeros([2, 3, 4, 5]), var_def)
    self.assertIn("length not matching", str(context.exception))
    with self.assertRaises(ValueError) as context:
        check_var(np.zeros([2, 3, 4]), var_def)
    self.assertIn("shape not matching", str(context.exception))
    check_var(np.zeros([2, 2, 8]), var_def)