Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- name: Install rdkit
run: python -m pip install rdkit openbabel-wheel
- name: Install dependencies
run: python -m pip install .[amber,ase,pymatgen] coverage
run: python -m pip install .[amber,ase,pymatgen] coverage ./tests/plugin
- name: Test
run: cd tests && coverage run --source=../dpdata -m unittest && cd .. && coverage combine tests/.coverage && coverage report
- name: Run codecov
Expand Down
145 changes: 145 additions & 0 deletions dpdata/data_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from enum import Enum, unique
from typing import TYPE_CHECKING, Tuple

import numpy as np

from dpdata.plugin import Plugin

if TYPE_CHECKING:
from dpdata.system import System


@unique
class Axis(Enum):
"""Data axis."""

NFRAMES = "nframes"
NATOMS = "natoms"
NTYPES = "ntypes"
NBONDS = "nbonds"


class DataError(Exception):
"""Data is not correct."""


class DataType:
"""DataType represents a type of data, like coordinates, energies, etc.

Parameters
----------
name : str
name of data
dtype : type or tuple[type]
data type, e.g. np.ndarray
shape : tuple[int], optional
shape of data. Used when data is list or np.ndarray. Use Axis to
represents numbers
required : bool, default=True
whether this data is required
"""

def __init__(
self,
name: str,
dtype: type,
shape: Tuple[int, Axis] = None,
required: bool = True,
) -> None:
self.name = name
self.dtype = dtype
self.shape = shape
self.required = required

def real_shape(self, system: "System") -> Tuple[int]:
"""Returns expected real shape of a system."""
shape = []
for ii in self.shape:
if ii is Axis.NFRAMES:
shape.append(system.get_nframes())
elif ii is Axis.NTYPES:
shape.append(system.get_ntypes())
elif ii is Axis.NATOMS:
shape.append(system.get_natoms())
elif ii is Axis.NBONDS:
# BondOrderSystem
shape.append(system.get_nbonds())
elif isinstance(ii, int):
shape.append(ii)
else:
raise RuntimeError("Shape is not an int!")
return tuple(shape)

def check(self, system: "System"):
"""Check if a system has correct data of this type.

Parameters
----------
system : System
checked system

Raises
------
DataError
type or shape of data is not correct
"""
# check if exists
if self.name in system.data:
data = system.data[self.name]
# check dtype
# allow list for empty np.ndarray
if isinstance(data, list) and not len(data):
pass
elif not isinstance(data, self.dtype):
raise DataError(
f"Type of {self.name} is {type(data).__name__}, but expected {self.dtype.__name__}"
)
# check shape
if self.shape is not None:
shape = self.real_shape(system)
# skip checking empty list of np.ndarray
if isinstance(data, np.ndarray):
if data.size and shape != data.shape:
raise DataError(
f"Shape of {self.name} is {data.shape}, but expected {shape}"
)
elif isinstance(data, list):
if len(shape) and shape[0] != len(data):
raise DataError(
"Length of %s is %d, but expected %d"
% (self.name, len(data), shape[0])
)
else:
raise RuntimeError("Unsupported type to check shape")
elif self.required:
raise DataError("%s not found in data" % self.name)


__system_data_type_plugin = Plugin()
__labeled_system_data_type_plugin = Plugin()


def register_data_type(data_type: DataType, labeled: bool):
"""Register a data type.

Parameters
----------
data_type : DataType
data type to be registered
labeled : bool
whether this data type is for LabeledSystem
"""
plugin = __labeled_system_data_type_plugin if labeled else __system_data_type_plugin
plugin.register(data_type.name)(data_type)


def get_data_types(labeled: bool):
"""Get all registered data types.

Parameters
----------
labeled : bool
whether this data type is for LabeledSystem
"""
plugin = __labeled_system_data_type_plugin if labeled else __system_data_type_plugin
return tuple(plugin.plugins.values())
116 changes: 8 additions & 108 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import glob
import os
from copy import deepcopy
from enum import Enum, unique
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
Expand All @@ -15,6 +14,7 @@
# ensure all plugins are loaded!
import dpdata.plugins
from dpdata.amber.mask import load_param_file, pick_by_amber_mask
from dpdata.data_type import Axis, DataError, DataType, get_data_types
from dpdata.driver import Driver, Minimizer
from dpdata.format import Format
from dpdata.plugin import Plugin
Expand All @@ -33,112 +33,6 @@ def load_format(fmt):
)


@unique
class Axis(Enum):
"""Data axis."""

NFRAMES = "nframes"
NATOMS = "natoms"
NTYPES = "ntypes"
NBONDS = "nbonds"


class DataError(Exception):
"""Data is not correct."""


class DataType:
"""DataType represents a type of data, like coordinates, energies, etc.

Parameters
----------
name : str
name of data
dtype : type or tuple[type]
data type, e.g. np.ndarray
shape : tuple[int], optional
shape of data. Used when data is list or np.ndarray. Use Axis to
represents numbers
required : bool, default=True
whether this data is required
"""

def __init__(
self,
name: str,
dtype: type,
shape: Tuple[int, Axis] = None,
required: bool = True,
) -> None:
self.name = name
self.dtype = dtype
self.shape = shape
self.required = required

def real_shape(self, system: "System") -> Tuple[int]:
"""Returns expected real shape of a system."""
shape = []
for ii in self.shape:
if ii is Axis.NFRAMES:
shape.append(system.get_nframes())
elif ii is Axis.NTYPES:
shape.append(system.get_ntypes())
elif ii is Axis.NATOMS:
shape.append(system.get_natoms())
elif ii is Axis.NBONDS:
# BondOrderSystem
shape.append(system.get_nbonds())
elif isinstance(ii, int):
shape.append(ii)
else:
raise RuntimeError("Shape is not an int!")
return tuple(shape)

def check(self, system: "System"):
"""Check if a system has correct data of this type.

Parameters
----------
system : System
checked system

Raises
------
DataError
type or shape of data is not correct
"""
# check if exists
if self.name in system.data:
data = system.data[self.name]
# check dtype
# allow list for empty np.ndarray
if isinstance(data, list) and not len(data):
pass
elif not isinstance(data, self.dtype):
raise DataError(
f"Type of {self.name} is {type(data).__name__}, but expected {self.dtype.__name__}"
)
# check shape
if self.shape is not None:
shape = self.real_shape(system)
# skip checking empty list of np.ndarray
if isinstance(data, np.ndarray):
if data.size and shape != data.shape:
raise DataError(
f"Shape of {self.name} is {data.shape}, but expected {shape}"
)
elif isinstance(data, list):
if len(shape) and shape[0] != len(data):
raise DataError(
"Length of %s is %d, but expected %d"
% (self.name, len(data), shape[0])
)
else:
raise RuntimeError("Unsupported type to check shape")
elif self.required:
raise DataError("%s not found in data" % self.name)


class System(MSONable):
"""The data System.

Expand Down Expand Up @@ -1657,7 +1551,8 @@ def get_cls_name(cls: object) -> str:


def add_format_methods():
"""Add format methods to System, LabeledSystem, and MultiSystems.
"""Add format methods to System, LabeledSystem, and MultiSystems; add data types
to System and LabeledSystem.

Notes
-----
Expand Down Expand Up @@ -1701,5 +1596,10 @@ def to_format(self, *args, **kwargs):
setattr(LabeledSystem, method, get_func(formatcls))
setattr(MultiSystems, method, get_func(formatcls))

# at this point, System.DTYPES and LabeledSystem.DTYPES has been initialized
System.DTYPES = System.DTYPES + get_data_types(labeled=False)
LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=False)
LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=True)


add_format_methods()
11 changes: 11 additions & 0 deletions tests/plugin/dpdata_plugin_test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import numpy as np

from dpdata.data_type import Axis, DataType, register_data_type

# test data type

register_data_type(
DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False), labeled=True
)

ep = None
17 changes: 17 additions & 0 deletions tests/plugin/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[build-system]
requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"

[project]
name = "dpdata_plugin_test"
version = "0.0.0"
description = "A test for dpdata plugin"
dependencies = [
'numpy',
'dpdata',
]
readme = "README.md"
requires-python = ">=3.7"

[project.entry-points.'dpdata.plugins']
random = "dpdata_plugin_test:ep"
8 changes: 0 additions & 8 deletions tests/test_custom_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,15 @@
import numpy as np

import dpdata
from dpdata.system import Axis, DataType


class TestDeepmdLoadDumpComp(unittest.TestCase):
def setUp(self):
self.backup = dpdata.system.LabeledSystem.DTYPES
dpdata.system.LabeledSystem.DTYPES = dpdata.system.LabeledSystem.DTYPES + (
DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False),
)
self.system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")
self.foo = np.ones((len(self.system), 2, 4))
self.system.data["foo"] = self.foo
self.system.check_data()

def tearDown(self) -> None:
dpdata.system.LabeledSystem.DTYPES = self.backup

def test_to_deepmd_raw(self):
self.system.to_deepmd_raw("data_foo")
foo = np.loadtxt("data_foo/foo.raw")
Expand Down