diff --git a/deepmd/common.py b/deepmd/common.py index 03d7d8caf3..9968cff39c 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -23,6 +23,7 @@ from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION from deepmd.utils.sess import run_sess from deepmd.utils.errors import GraphWithoutTensorError +from deepmd.utils.path import DPPath if TYPE_CHECKING: _DICT_VAL = TypeVar("_DICT_VAL") @@ -429,9 +430,10 @@ def expand_sys_str(root_dir: Union[str, Path]) -> List[str]: List[str] list of string pointing to system directories """ - matches = [str(d) for d in Path(root_dir).rglob("*") if (d / "type.raw").is_file()] - if (Path(root_dir) / "type.raw").is_file(): - matches += [root_dir] + root_dir = DPPath(root_dir) + matches = [str(d) for d in root_dir.rglob("*") if (d / "type.raw").is_file()] + if (root_dir / "type.raw").is_file(): + matches.append(str(root_dir)) return matches diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index 817d603f3c..98090c18af 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -20,6 +20,7 @@ from deepmd.utils.data_system import DeepmdDataSystem from deepmd.utils.sess import run_sess from deepmd.utils.neighbor_stat import NeighborStat +from deepmd.utils.path import DPPath __all__ = ["train"] @@ -181,11 +182,12 @@ def get_data(jdata: Dict[str, Any], rcut, type_map, modifier): raise IOError(msg, help_msg) # rougly check all items in systems are valid for ii in systems: - if (not os.path.isdir(ii)): + ii = DPPath(ii) + if (not ii.is_dir()): msg = f'dir {ii} is not a valid dir' log.fatal(msg) raise IOError(msg, help_msg) - if (not os.path.isfile(os.path.join(ii, 'type.raw'))): + if (not (ii / 'type.raw').is_file()): msg = f'dir {ii} is not a valid data system dir' log.fatal(msg) raise IOError(msg, help_msg) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 34ef29400e..ecf4aaeaba 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -10,6 +10,7 @@ from 
deepmd.env import GLOBAL_NP_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION from deepmd.utils import random as dp_random +from deepmd.utils.path import DPPath log = logging.getLogger(__name__) @@ -44,17 +45,18 @@ def __init__ (self, """ Constructor """ - self.dirs = glob.glob (os.path.join(sys_path, set_prefix + ".*")) + root = DPPath(sys_path) + self.dirs = root.glob(set_prefix + ".*") self.dirs.sort() # load atom type - self.atom_type = self._load_type(sys_path) + self.atom_type = self._load_type(root) self.natoms = len(self.atom_type) # load atom type map - self.type_map = self._load_type_map(sys_path) + self.type_map = self._load_type_map(root) if self.type_map is not None: assert(len(self.type_map) >= max(self.atom_type)+1) # check pbc - self.pbc = self._check_pbc(sys_path) + self.pbc = self._check_pbc(root) # enforce type_map if necessary if type_map is not None and self.type_map is not None: atom_type_ = [type_map.index(self.type_map[ii]) for ii in self.atom_type] @@ -167,9 +169,9 @@ def check_batch_size (self, batch_size) : """ for ii in self.train_dirs : if self.data_dict['coord']['high_prec'] : - tmpe = np.load(os.path.join(ii, "coord.npy")).astype(GLOBAL_ENER_FLOAT_PRECISION) + tmpe = (ii / "coord.npy").load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION) else: - tmpe = np.load(os.path.join(ii, "coord.npy")).astype(GLOBAL_NP_FLOAT_PRECISION) + tmpe = (ii / "coord.npy").load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION) if tmpe.ndim == 1: tmpe = tmpe.reshape([1,-1]) if tmpe.shape[0] < batch_size : @@ -181,9 +183,9 @@ def check_test_size (self, test_size) : Check if the system can get a test dataset with `test_size` frames. 
""" if self.data_dict['coord']['high_prec'] : - tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(GLOBAL_ENER_FLOAT_PRECISION) + tmpe = (self.test_dir / "coord.npy").load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION) else: - tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(GLOBAL_NP_FLOAT_PRECISION) + tmpe = (self.test_dir / "coord.npy").load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION) if tmpe.ndim == 1: tmpe = tmpe.reshape([1,-1]) if tmpe.shape[0] < test_size : @@ -377,7 +379,7 @@ def _get_subdata(self, data, idx = None) : return new_data def _load_batch_set (self, - set_name) : + set_name: DPPath) : self.batch_set = self._load_set(set_name) self.batch_set, _ = self._shuffle_data(self.batch_set) self.reset_get_batch() @@ -386,7 +388,7 @@ def reset_get_batch(self): self.iterator = 0 def _load_test_set (self, - set_name, + set_name: DPPath, shuffle_test) : self.test_set = self._load_set(set_name) if shuffle_test : @@ -409,13 +411,15 @@ def _shuffle_data (self, ret[kk] = data[kk] return ret, idx - def _load_set(self, set_name) : + def _load_set(self, set_name: DPPath) : # get nframes - path = os.path.join(set_name, "coord.npy") + if not isinstance(set_name, DPPath): + set_name = DPPath(set_name) + path = set_name / "coord.npy" if self.data_dict['coord']['high_prec'] : - coord = np.load(path).astype(GLOBAL_ENER_FLOAT_PRECISION) + coord = path.load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION) else: - coord = np.load(path).astype(GLOBAL_NP_FLOAT_PRECISION) + coord = path.load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION) if coord.ndim == 1: coord = coord.reshape([1,-1]) nframes = coord.shape[0] @@ -459,12 +463,12 @@ def _load_data(self, set_name, key, nframes, ndof_, atomic = False, must = True, ndof = ndof_ * natoms else: ndof = ndof_ - path = os.path.join(set_name, key+".npy") - if os.path.isfile (path) : + path = set_name / (key+".npy") + if path.is_file() : if high_prec : - data = np.load(path).astype(GLOBAL_ENER_FLOAT_PRECISION) + data = 
import os
from abc import ABC, abstractmethod
from typing import List
from pathlib import Path
from functools import lru_cache

import numpy as np
import h5py
from wcmatch.glob import globfilter


class DPPath(ABC):
    """The path class to data system (DeepmdData).

    Constructing ``DPPath(path)`` dispatches to a concrete subclass:
    an on-disk directory yields :class:`DPOSPath`, while an existing
    file (assumed to be HDF5, optionally followed by ``#`` and an
    in-file path) yields :class:`DPH5Path`.

    Parameters
    ----------
    path : str
        path

    Raises
    ------
    FileNotFoundError
        if `path` is neither an existing directory nor an existing file
    """

    def __new__(cls, path: str):
        if cls is DPPath:
            if os.path.isdir(path):
                return super().__new__(DPOSPath)
            elif os.path.isfile(path.split("#")[0]):
                # assume h5 if it is not dir
                # TODO: check if it is a real h5? or just check suffix?
                return super().__new__(DPH5Path)
            raise FileNotFoundError("%s not found" % path)
        return super().__new__(cls)

    @abstractmethod
    def load_numpy(self) -> np.ndarray:
        """Load NumPy array.

        Returns
        -------
        np.ndarray
            loaded NumPy array
        """

    @abstractmethod
    def load_txt(self, **kwargs) -> np.ndarray:
        """Load NumPy array from text.

        Returns
        -------
        np.ndarray
            loaded NumPy array
        """

    @abstractmethod
    def glob(self, pattern: str) -> List["DPPath"]:
        """Search path using the glob pattern.

        Parameters
        ----------
        pattern : str
            glob pattern

        Returns
        -------
        List[DPPath]
            list of paths
        """

    @abstractmethod
    def rglob(self, pattern: str) -> List["DPPath"]:
        """This is like calling :meth:`DPPath.glob` with `**/` added in front
        of the given relative pattern.

        Parameters
        ----------
        pattern : str
            glob pattern

        Returns
        -------
        List[DPPath]
            list of paths
        """

    @abstractmethod
    def is_file(self) -> bool:
        """Check if self is file."""

    @abstractmethod
    def is_dir(self) -> bool:
        """Check if self is directory."""

    @abstractmethod
    def __truediv__(self, key: str) -> "DPPath":
        """Used for / operator."""

    @abstractmethod
    def __lt__(self, other: "DPPath") -> bool:
        """Whether this DPPath is less than other, for sorting."""

    @abstractmethod
    def __str__(self) -> str:
        """Represent string."""

    def __repr__(self) -> str:
        return "%s (%s)" % (type(self), str(self))

    def __eq__(self, other) -> bool:
        # equality is defined on the string representation, so a
        # DPOSPath never equals a DPH5Path (the latter contains "#")
        return str(self) == str(other)

    def __hash__(self):
        return hash(str(self))


class DPOSPath(DPPath):
    """The OS path class to data system (DeepmdData) for real directories.

    Parameters
    ----------
    path : str
        path
    """

    def __init__(self, path: str) -> None:
        super().__init__()
        if isinstance(path, Path):
            self.path = path
        else:
            self.path = Path(path)

    def load_numpy(self) -> np.ndarray:
        """Load NumPy array.

        Returns
        -------
        np.ndarray
            loaded NumPy array
        """
        return np.load(str(self.path))

    def load_txt(self, **kwargs) -> np.ndarray:
        """Load NumPy array from text.

        Returns
        -------
        np.ndarray
            loaded NumPy array
        """
        return np.loadtxt(str(self.path), **kwargs)

    def glob(self, pattern: str) -> List["DPPath"]:
        """Search path using the glob pattern.

        Parameters
        ----------
        pattern : str
            glob pattern

        Returns
        -------
        List[DPPath]
            list of paths
        """
        # currently DPOSPath will only yield DPOSPath
        # TODO: discuss if we want to mix DPOSPath and DPH5Path?
        return [type(self)(p) for p in self.path.glob(pattern)]

    def rglob(self, pattern: str) -> List["DPPath"]:
        """This is like calling :meth:`DPPath.glob` with `**/` added in front
        of the given relative pattern.

        Parameters
        ----------
        pattern : str
            glob pattern

        Returns
        -------
        List[DPPath]
            list of paths
        """
        return [type(self)(p) for p in self.path.rglob(pattern)]

    def is_file(self) -> bool:
        """Check if self is file."""
        return self.path.is_file()

    def is_dir(self) -> bool:
        """Check if self is directory."""
        return self.path.is_dir()

    def __truediv__(self, key: str) -> "DPPath":
        """Used for / operator."""
        return type(self)(self.path / key)

    def __lt__(self, other: "DPOSPath") -> bool:
        """Whether this DPPath is less than other, for sorting."""
        return self.path < other.path

    def __str__(self) -> str:
        """Represent string."""
        return str(self.path)


class DPH5Path(DPPath):
    """The path class to data system (DeepmdData) for HDF5 files.

    Notes
    -----
    OS - HDF5 relationship:
        directory - Group
        file - Dataset

    Parameters
    ----------
    path : str
        path; the part before ``#`` is the HDF5 file on disk, the part
        after ``#`` is the path inside the HDF5 file
    """

    def __init__(self, path: str) -> None:
        super().__init__()
        # we use "#" to split path
        # so we do not support file names containing #...
        s = path.split("#")
        self.root_path = s[0]
        self.root = self._load_h5py(s[0])
        # h5 path: default is the root path
        self.name = s[1] if len(s) > 1 else "/"

    @classmethod
    @lru_cache(None)
    def _load_h5py(cls, path: str) -> h5py.File:
        """Load hdf5 file.

        Parameters
        ----------
        path : str
            path to hdf5 file
        """
        # this method has cache to avoid duplicated
        # loading from different DPH5Path
        # NOTE(review): the cached file handle is never closed; it lives
        # for the lifetime of the process
        return h5py.File(path, 'r')

    def load_numpy(self) -> np.ndarray:
        """Load NumPy array.

        Returns
        -------
        np.ndarray
            loaded NumPy array
        """
        return self.root[self.name][:]

    def load_txt(self, dtype: np.dtype = None, **kwargs) -> np.ndarray:
        """Load NumPy array from text.

        Returns
        -------
        np.ndarray
            loaded NumPy array
        """
        # an HDF5 dataset is already an array; only the dtype cast from
        # the np.loadtxt signature is honoured, other kwargs are ignored
        arr = self.load_numpy()
        if dtype is not None:
            arr = arr.astype(dtype)
        return arr

    def glob(self, pattern: str) -> List["DPPath"]:
        """Search path using the glob pattern.

        Parameters
        ----------
        pattern : str
            glob pattern

        Returns
        -------
        List[DPPath]
            list of paths
        """
        # get paths that start with the current path first, which is faster
        subpaths = [ii for ii in self._keys if ii.startswith(self.name)]
        return [type(self)("%s#%s" % (self.root_path, pp))
                for pp in globfilter(subpaths, self._connect_path(pattern))]

    def rglob(self, pattern: str) -> List["DPPath"]:
        """This is like calling :meth:`DPPath.glob` with `**/` added in front
        of the given relative pattern.

        Parameters
        ----------
        pattern : str
            glob pattern

        Returns
        -------
        List[DPPath]
            list of paths
        """
        return self.glob("**" + pattern)

    @property
    def _keys(self) -> List[str]:
        """Walk all groups and datasets."""
        return self._file_keys(self.root)

    @classmethod
    @lru_cache(None)
    def _file_keys(cls, file: h5py.File) -> List[str]:
        """Walk all groups and datasets."""
        keys = []
        file.visit(lambda x: keys.append("/" + x))
        return keys

    def is_file(self) -> bool:
        """Check if self is file."""
        if self.name not in self._keys:
            return False
        return isinstance(self.root[self.name], h5py.Dataset)

    def is_dir(self) -> bool:
        """Check if self is directory."""
        if self.name not in self._keys:
            return False
        return isinstance(self.root[self.name], h5py.Group)

    def __truediv__(self, key: str) -> "DPPath":
        """Used for / operator."""
        return type(self)("%s#%s" % (self.root_path, self._connect_path(key)))

    def _connect_path(self, path: str) -> str:
        """Connect self with path."""
        if self.name.endswith("/"):
            return "%s%s" % (self.name, path)
        return "%s/%s" % (self.name, path)

    def __lt__(self, other: "DPH5Path") -> bool:
        """Whether this DPPath is less than other, for sorting."""
        if self.root_path == other.root_path:
            return self.name < other.name
        return self.root_path < other.root_path

    def __str__(self) -> str:
        """Returns path of self."""
        return "%s#%s" % (self.root_path, self.name)
from common import tests_path  # test-suite helper: Path to source/tests


class TestH5Data(unittest.TestCase):
    """Smoke tests for loading a DeepmdData system from an HDF5 file.

    Uses the ``test.hdf5`` fixture shipped with the test suite; the
    ``#/set.000`` suffix is the in-file path convention of DPH5Path.
    """

    def setUp(self):
        self.data_name = str(tests_path / 'test.hdf5')

    def test_init(self):
        dd = DeepmdData(self.data_name)
        self.assertEqual(dd.idx_map[0], 0)
        self.assertEqual(dd.type_map, ['X'])
        # set directories are addressed as "<file>#<h5 path>"
        self.assertEqual(dd.test_dir, self.data_name + '#/set.000')
        self.assertEqual(dd.train_dirs, [self.data_name + '#/set.000'])

    def test_get_batch(self):
        # only checks that batching from HDF5 does not raise
        dd = DeepmdData(self.data_name)
        data = dd.get_batch(5)