From c39bf14303920436f32386aea43881c6361be370 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 20 Sep 2021 16:49:23 -0400 Subject: [PATCH 01/10] make a draft --- deepmd/utils/path.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 deepmd/utils/path.py diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py new file mode 100644 index 0000000000..46683d6ba2 --- /dev/null +++ b/deepmd/utils/path.py @@ -0,0 +1,37 @@ +import os +from abc import ABC, abstractmethod + +import numpy as np + +class DPPath(ABC): + """The path class to data system (DeepmdData).""" + def __new__(cls, path: str): + if os.path.isdir(path): + return DPOSPath + elif os.path.isfile(path): + # assume h5 if it is not dir + # TODO: check if it is a real h5? + return DPH5Path + raise OSError("%s not exists" % path) + + @abstractmethod + def load_numpy_array(path: str) -> np.ndarray: + """Load NumPy array. + + Parameters + ---------- + path : str + Path to numpy array relative to the path system + + Returns + ------- + np.ndarray + The loaded NumPy array + """ + +class DPOSPath(DPPath): + pass + + +class DPH5Path(DPPath): + pass \ No newline at end of file From d0e737e68cfc6c8f2c22be6f39124a101337513b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 21 Sep 2021 02:51:31 -0400 Subject: [PATCH 02/10] add support for hdf5 --- deepmd/common.py | 8 +- deepmd/entrypoints/train.py | 6 +- deepmd/utils/data.py | 57 +++--- deepmd/utils/path.py | 292 +++++++++++++++++++++++++++++-- requirements.txt | 2 + source/tests/test.hdf5 | Bin 0 -> 6056 bytes source/tests/test_deepmd_data.py | 16 ++ 7 files changed, 334 insertions(+), 47 deletions(-) create mode 100644 source/tests/test.hdf5 diff --git a/deepmd/common.py b/deepmd/common.py index 03d7d8caf3..9968cff39c 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -23,6 +23,7 @@ from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION from deepmd.utils.sess import run_sess from deepmd.utils.errors import GraphWithoutTensorError +from deepmd.utils.path import DPPath if TYPE_CHECKING: _DICT_VAL = TypeVar("_DICT_VAL") @@ -429,9 +430,10 @@ def expand_sys_str(root_dir: Union[str, Path]) -> List[str]: List[str] list of string pointing to system directories """ - matches = [str(d) for d in Path(root_dir).rglob("*") if (d / "type.raw").is_file()] - if (Path(root_dir) / "type.raw").is_file(): - matches += [root_dir] + root_dir = DPPath(root_dir) + matches = [str(d) for d in root_dir.rglob("*") if (d / "type.raw").is_file()] + if (root_dir / "type.raw").is_file(): + matches.append(str(root_dir)) return matches diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index 817d603f3c..98090c18af 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -20,6 +20,7 @@ from deepmd.utils.data_system import DeepmdDataSystem from deepmd.utils.sess import run_sess from deepmd.utils.neighbor_stat import NeighborStat +from deepmd.utils.path import DPPath __all__ = ["train"] @@ -181,11 +182,12 @@ def get_data(jdata: Dict[str, Any], rcut, type_map, modifier): raise IOError(msg, help_msg) # rougly check all items in systems are valid for ii in systems: - if (not os.path.isdir(ii)): + ii = DPPath(ii) + if (not ii.is_dir()): msg = f'dir {ii} is not a valid dir' log.fatal(msg) raise IOError(msg, help_msg) - if (not os.path.isfile(os.path.join(ii, 'type.raw'))): + if (not (ii / 'type.raw').is_file()): msg = f'dir {ii} is not a valid data system dir' log.fatal(msg) raise IOError(msg, help_msg) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 34ef29400e..ecf4aaeaba 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -10,6 +10,7 @@ from deepmd.env import GLOBAL_NP_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION from deepmd.utils import random as dp_random +from deepmd.utils.path import DPPath log = logging.getLogger(__name__) @@ -44,17 +45,18 @@ def __init__ (self, """ Constructor """ - self.dirs = glob.glob (os.path.join(sys_path, set_prefix + ".*")) + root = DPPath(sys_path) + self.dirs = root.glob(set_prefix + ".*") self.dirs.sort() # load atom type - self.atom_type = self._load_type(sys_path) + self.atom_type = self._load_type(root) self.natoms = len(self.atom_type) # load atom type map - self.type_map = self._load_type_map(sys_path) + self.type_map = self._load_type_map(root) if self.type_map is not None: assert(len(self.type_map) >= max(self.atom_type)+1) # check pbc - self.pbc = self._check_pbc(sys_path) + self.pbc = self._check_pbc(root) # enforce type_map if necessary if type_map is not None and self.type_map is not None: atom_type_ = [type_map.index(self.type_map[ii]) for ii in self.atom_type] @@ -167,9 +169,9 @@ def check_batch_size (self, batch_size) : """ for ii in self.train_dirs : if self.data_dict['coord']['high_prec'] : - tmpe = np.load(os.path.join(ii, "coord.npy")).astype(GLOBAL_ENER_FLOAT_PRECISION) + tmpe = (ii / "coord.npy").load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION) else: - tmpe = np.load(os.path.join(ii, "coord.npy")).astype(GLOBAL_NP_FLOAT_PRECISION) + tmpe = (ii / "coord.npy").load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION) if tmpe.ndim == 1: tmpe = tmpe.reshape([1,-1]) if tmpe.shape[0] < batch_size : @@ -181,9 +183,9 @@ def check_test_size (self, test_size) : Check if the system can get a test dataset with `test_size` frames. """ if self.data_dict['coord']['high_prec'] : - tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(GLOBAL_ENER_FLOAT_PRECISION) + tmpe = (self.test_dir / "coord.npy").load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION) else: - tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(GLOBAL_NP_FLOAT_PRECISION) + tmpe = (self.test_dir / "coord.npy").load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION) if tmpe.ndim == 1: tmpe = tmpe.reshape([1,-1]) if tmpe.shape[0] < test_size : @@ -377,7 +379,7 @@ def _get_subdata(self, data, idx = None) : return new_data def _load_batch_set (self, - set_name) : + set_name: DPPath) : self.batch_set = self._load_set(set_name) self.batch_set, _ = self._shuffle_data(self.batch_set) self.reset_get_batch() @@ -386,7 +388,7 @@ def reset_get_batch(self): self.iterator = 0 def _load_test_set (self, - set_name, + set_name: DPPath, shuffle_test) : self.test_set = self._load_set(set_name) if shuffle_test : @@ -409,13 +411,15 @@ def _shuffle_data (self, ret[kk] = data[kk] return ret, idx - def _load_set(self, set_name) : + def _load_set(self, set_name: DPPath) : # get nframes - path = os.path.join(set_name, "coord.npy") + if not isinstance(set_name, DPPath): + set_name = DPPath(set_name) + path = set_name / "coord.npy" if self.data_dict['coord']['high_prec'] : - coord = np.load(path).astype(GLOBAL_ENER_FLOAT_PRECISION) + coord = path.load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION) else: - coord = np.load(path).astype(GLOBAL_NP_FLOAT_PRECISION) + coord = path.load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION) if coord.ndim == 1: coord = coord.reshape([1,-1]) nframes = coord.shape[0] @@ -459,12 +463,12 @@ def _load_data(self, set_name, key, nframes, ndof_, atomic = False, must = True, ndof = ndof_ * natoms else: ndof = ndof_ - path = os.path.join(set_name, key+".npy") - if os.path.isfile (path) : + path = set_name / (key+".npy") + if path.is_file() : if high_prec : - data = np.load(path).astype(GLOBAL_ENER_FLOAT_PRECISION) + data = path.load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION) else: - data = np.load(path).astype(GLOBAL_NP_FLOAT_PRECISION) + data = path.load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION) try: # YWolfeee: deal with data shape error if atomic : data = data.reshape([nframes, natoms, -1]) @@ -491,8 +495,8 @@ def _load_data(self, set_name, key, nframes, ndof_, atomic = False, must = True, return np.float32(0.0), data - def _load_type (self, sys_path) : - atom_type = np.loadtxt (os.path.join(sys_path, "type.raw"), dtype=np.int32, ndmin=1) + def _load_type (self, sys_path: DPPath) : + atom_type = (sys_path / "type.raw").load_txt(dtype=np.int32, ndmin=1) return atom_type def _make_idx_map(self, atom_type): @@ -501,17 +505,16 @@ def _make_idx_map(self, atom_type): idx_map = np.lexsort ((idx, atom_type)) return idx_map - def _load_type_map(self, sys_path) : - fname = os.path.join(sys_path, 'type_map.raw') - if os.path.isfile(fname) : - with open(os.path.join(sys_path, 'type_map.raw')) as fp: - return fp.read().split() + def _load_type_map(self, sys_path: DPPath) : + fname = sys_path / 'type_map.raw' + if fname.is_file() : + return fname.load_txt(dtype=str).tolist() else : return None - def _check_pbc(self, sys_path): + def _check_pbc(self, sys_path: DPPath): pbc = True - if os.path.isfile(os.path.join(sys_path, 'nopbc')) : + if (sys_path / 'nopbc').is_file() : pbc = False return pbc diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index 46683d6ba2..4b9619d968 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -1,37 +1,299 @@ import os from abc import ABC, abstractmethod +from typing import List +from pathlib import Path +from functools import lru_cache import numpy as np +import h5py +from wcmatch.glob import globfilter class DPPath(ABC): - """The path class to data system (DeepmdData).""" + """The path class to data system (DeepmdData). + + Parameters + ---------- + path : str + path + """ def __new__(cls, path: str): - if os.path.isdir(path): - return DPOSPath - elif os.path.isfile(path): - # assume h5 if it is not dir - # TODO: check if it is a real h5? - return DPH5Path - raise OSError("%s not exists" % path) + if cls is DPPath: + if os.path.isdir(path): + return super().__new__(DPOSPath) + elif os.path.isfile(path.split("#")[0]): + # assume h5 if it is not dir + # TODO: check if it is a real h5? or just check suffix? + return super().__new__(DPH5Path) + raise FileNotFoundError("%s not found" % path) + return super().__new__(cls) @abstractmethod - def load_numpy_array(path: str) -> np.ndarray: + def load_numpy(self) -> np.ndarray: """Load NumPy array. + Returns + ------- + np.ndarray + loaded NumPy array + """ + + @abstractmethod + def load_txt(self, **kwargs) -> np.ndarray: + """Load NumPy array from text. + + Returns + ------- + np.ndarray + loaded NumPy array + """ + + @abstractmethod + def glob(self, pattern: str) -> List["DPPath"]: + """Search path using the glob pattern. + Parameters ---------- - path : str - Path to numpy array relative to the path system + pattern : str + glob pattern Returns ------- - np.ndarray - The loaded NumPy array + List[DPPath] + list of paths + """ + + def rglob(self, pattern: str) -> List["DPPath"]: + """This is like calling :metd:`DPPath.glob()` with `**/` added in front + of the given relative pattern. + + Parameters + ---------- + pattern : str + glob pattern + + Returns + ------- + List[DPPath] + list of paths """ + return self.glob("**" + pattern) + + @abstractmethod + def is_file(self) -> bool: + """Check if self is file.""" + + @abstractmethod + def is_dir(self) -> bool: + """Check if self is directory.""" + + @abstractmethod + def __truediv__(self, key: str) -> "DPPath": + """Used for / operator.""" + + @abstractmethod + def __lt__(self, other: "DPPath") -> bool: + """whether this DPPath is less than other for sorting""" + + @abstractmethod + def __str__(self) -> str: + """Represent string""" + + def __repr__(self) -> str: + return "%s (%s)" % (type(self), str(self)) + + def __eq__(self, other) -> bool: + return str(self) == str(other) + + def __hash__(self): + return hash(str(self)) + class DPOSPath(DPPath): - pass + """The OS path class to data system (DeepmdData) for real directories. + + Parameters + ---------- + path : str + path + """ + def __init__(self, path: str) -> None: + super().__init__() + if isinstance(path, Path): + self.path = path + else: + self.path = Path(path) + + def load_numpy(self) -> np.ndarray: + """Load NumPy array. + + Returns + ------- + np.ndarray + loaded NumPy array + """ + return np.load(str(self.path)) + + def load_txt(self, **kwargs) -> np.ndarray: + """Load NumPy array from text. + + Returns + ------- + np.ndarray + loaded NumPy array + """ + return np.loadtxt(str(self.path), **kwargs) + + def glob(self, pattern: str) -> List["DPPath"]: + """Search path using the glob pattern. + + Parameters + ---------- + pattern : str + glob pattern + + Returns + ------- + List[DPPath] + list of paths + """ + # currently DPOSPath will only derivative DPOSPath + # TODO: discuss if we want to mix DPOSPath and DPH5Path? + return list([type(self)(p) for p in self.path.glob(pattern)]) + + def is_file(self) -> bool: + """Check if self is file.""" + return self.path.is_file() + + def is_dir(self) -> bool: + """Check if self is directory.""" + return self.path.is_dir() + + def __truediv__(self, key: str) -> "DPPath": + """Used for / operator.""" + return type(self)(self.path / key) + + def __lt__(self, other: "DPOSPath") -> bool: + """whether this DPPath is less than other for sorting""" + return self.path < other.path + + def __str__(self) -> str: + """Represent string""" + return str(self.path) class DPH5Path(DPPath): - pass \ No newline at end of file + """The path class to data system (DeepmdData) for HDF5 files. + + Notes + ----- + OS - HDF5 relationship: + directory - Group + file - Dataset + + Parameters + ---------- + path : str + path + """ + def __init__(self, path: str) -> None: + super().__init__() + # we use "#" to split path + # so we do not support file names containing #... + s = path.split("#") + self.root_path = s[0] + self.root = self._load_h5py(s[0]) + # h5 path: default is the root path + self.name = s[1] if len(s) > 1 else "/" + + @classmethod + @lru_cache(None) + def _load_h5py(cls, path: str) -> h5py.File: + """Load hdf5 file. + + Parameters + ---------- + path : str + path to hdf5 file + """ + # this method has cache to avoid duplicated + # loading from different DPH5Path + # However the file will be never closed? + return h5py.File(path, 'r') + + def load_numpy(self) -> np.ndarray: + """Load NumPy array. + + Returns + ------- + np.ndarray + loaded NumPy array + """ + return self.root[self.name][:] + + def load_txt(self, dtype: np.dtype = None, **kwargs) -> np.ndarray: + """Load NumPy array from text. + + Returns + ------- + np.ndarray + loaded NumPy array + """ + arr = self.load_numpy() + if dtype: + arr = arr.astype(dtype) + return arr + + def glob(self, pattern: str) -> List["DPPath"]: + """Search path using the glob pattern. + + Parameters + ---------- + pattern : str + glob pattern + + Returns + ------- + List[DPPath] + list of paths + """ + return list([type(self)("%s#%s"%(self.root_path, pp)) for pp in globfilter(self._keys, self._connect_path(pattern))]) + + @property + @lru_cache + def _keys(self): + """Walk all groups and dataset""" + l = [] + self.root.visititems(lambda x, _: l.append("/" + x)) + return l + + def is_file(self) -> bool: + """Check if self is file.""" + if self.name not in self._keys: + return False + return isinstance(self.root[self.name], h5py.Dataset) + + def is_dir(self) -> bool: + """Check if self is directory.""" + if self.name not in self._keys: + return False + return isinstance(self.root[self.name], h5py.Group) + + def __truediv__(self, key: str) -> "DPPath": + """Used for / operator.""" + return type(self)("%s#%s" % (self.root_path, self._connect_path(key))) + + def _connect_path(self, path: str) -> str: + """Connect self with path""" + if self.name.endswith("/"): + return "%s%s" % (self.name, path) + return "%s/%s" % (self.name, path) + + def __lt__(self, other: "DPH5Path") -> bool: + """whether this DPPath is less than other for sorting""" + if self.root_path == other.root_path: + return self.name < other.name + return self.root_path < other.root_path + + def __str__(self) -> str: + """returns path of self""" + return "%s#%s" % (self.root_path, self.name) diff --git a/requirements.txt b/requirements.txt index 50b597f2fe..f3ead805b8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,5 @@ pyyaml dargs >= 0.2.6 python-hostlist >= 1.21 typing_extensions; python_version < "3.7" +h5py +wcmatch diff --git a/source/tests/test.hdf5 b/source/tests/test.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..2849b1b896af9d3a4a0284b10ae325c01add067d GIT binary patch literal 6056 zcmeHLeN5D49DjcIVh4l@gpe7YMbp5sa||b&<@0<36(^RNlWK8Za)vK>bfO$JioiAy zb#tr6v>{?7>ksAD7+^2R700zUBp1j^)V!?rgY%>Ym$v-{<@C zJkRg@+2?ug_kC(sW-X77oE=G&W-<{2nW=Wg;}&nt37SRd!oQJsBrdDDY}RQVi6}op z61hEw&p$ryxo5Mpi5w~B7m^Pn7K7Fk^9Xxb{yz{{na!S2GR@=ZTDbIao^h3wloi|4 zoVLeeD#v9Rs5m!w%N7L}7q%mNV@L`r0u7Rw1PaVATWW_W;)6H+hpTk*^^$=gZw zQT1t*HN~#e>C}NHK964FVWdHg6OUMKR)RCjb8LiE2-l>_b!(o^l8F8`Zq?JY*HM#F zy8QTP;dTuJLd^QmuPyeq%NRQDzv8_cuLPb#>N$VnWP@T0&ZhsJ(@ zqSnARJazNQ%Z6Pr?yiTY=my*T{0aQ<#VO3%^*TPw_fTuqTJ&AI!cxzl#O#N>u*Owj zMxGS~y>r-AeF~};wK3^eE9{;qmU8V3yB<}`V!wTizLr|W@>0KNM~=-wQ{g#wV_pl+ zjXBuyzDE$V_5_-Jd*J(OABOMWh>o`gY2Q2Vvkf)vcyYf6r%oqhY~VfYPKd>i=PQvu z|53L1;4PM#ILw-=-o=u(y|}Psn0-BWJH0%VjKS7uc6{C2C>_egjNwj}KRp^fi{+iVqhBjXqL<;{bv+;o_h=N^Vf#^LAd zZq}Xj8*TBe!Ut)sC|vvkt}gqIZFSY6&D)5`EA7}meGkjYaiKTS%SOj~P|~yl@s~!} z*-RaCNA%OJoo4EK=~d+Q46@}T_3Y@O44l2_LC1kn8o%a2oNQ^pnQvAx+Es`9hSHIH z+6K4%7#g>Hh!l4nBD-=~Qj(V*7^#Fir;%MSb>rcM5m>l-FZJg@SPLT%Mj(tp7=e34 z;I4IdoVI|dx9)Md@bG^b0)chP0yP6kD`RheFGujaMAad5z5FGu*ka6c<+~-NnMs0) zB)qBiL)9@yc^#9?BQ+6N#}xYc|87OAe+{B&C+67DEAn!fxWolPkZ?1%SH}kl;AXK8 z>t=&JqpyHlMi d`Qh_VD}i|Y`MD$bAaD`-#PR9i_)IN+{0=mjRr&w` literal 0 HcmV?d00001 diff --git a/source/tests/test_deepmd_data.py b/source/tests/test_deepmd_data.py index 8532d8d9a4..bd57f0a090 100644 --- a/source/tests/test_deepmd_data.py +++ b/source/tests/test_deepmd_data.py @@ -257,3 +257,19 @@ def test_get_nbatch(self): def _comp_np_mat2(self, first, second) : np.testing.assert_almost_equal(first, second, places) + + +class TestH5Data (unittest.TestCase) : + def setUp (self) : + self.data_name = 'test.hdf5' + + def test_init (self) : + dd = DeepmdData(self.data_name) + self.assertEqual(dd.idx_map[0], 0) + self.assertEqual(dd.type_map, ['X']) + self.assertEqual(dd.test_dir, 'test.hdf5#/set.000') + self.assertEqual(dd.train_dirs, ['test.hdf5#/set.000']) + + def test_get_batch(self) : + dd = DeepmdData(self.data_name) + data = dd.get_batch(5) From f0ec2b839fea9a80f3a60ab37b0ee02e41de1aa4 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 21 Sep 2021 03:04:29 -0400 Subject: [PATCH 03/10] fix error with old python --- deepmd/utils/path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index 4b9619d968..883f73d7c8 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -259,7 +259,7 @@ def glob(self, pattern: str) -> List["DPPath"]: return list([type(self)("%s#%s"%(self.root_path, pp)) for pp in globfilter(self._keys, self._connect_path(pattern))]) @property - @lru_cache + @lru_cache(None) def _keys(self): """Walk all groups and dataset""" l = [] From ed71cd834ae35bf9b18ee1853ee2148b5018e804 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 21 Sep 2021 03:17:47 -0400 Subject: [PATCH 04/10] fix rglobal --- deepmd/utils/path.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index 4b9619d968..34f81c7b49 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -62,6 +62,7 @@ def glob(self, pattern: str) -> List["DPPath"]: list of paths """ + @abstractmethod def rglob(self, pattern: str) -> List["DPPath"]: """This is like calling :metd:`DPPath.glob()` with `**/` added in front of the given relative pattern. @@ -76,7 +77,6 @@ def rglob(self, pattern: str) -> List["DPPath"]: List[DPPath] list of paths """ - return self.glob("**" + pattern) @abstractmethod def is_file(self) -> bool: @@ -160,6 +160,22 @@ def glob(self, pattern: str) -> List["DPPath"]: # TODO: discuss if we want to mix DPOSPath and DPH5Path? return list([type(self)(p) for p in self.path.glob(pattern)]) + def rglob(self, pattern: str) -> List["DPPath"]: + """This is like calling :metd:`DPPath.glob()` with `**/` added in front + of the given relative pattern. + + Parameters + ---------- + pattern : str + glob pattern + + Returns + ------- + List[DPPath] + list of paths + """ + return list([type(self)(p) for p in self.path.rglob(pattern)]) + def is_file(self) -> bool: """Check if self is file.""" return self.path.is_file() @@ -258,6 +274,22 @@ def glob(self, pattern: str) -> List["DPPath"]: """ return list([type(self)("%s#%s"%(self.root_path, pp)) for pp in globfilter(self._keys, self._connect_path(pattern))]) + def rglob(self, pattern: str) -> List["DPPath"]: + """This is like calling :metd:`DPPath.glob()` with `**/` added in front + of the given relative pattern. + + Parameters + ---------- + pattern : str + glob pattern + + Returns + ------- + List[DPPath] + list of paths + """ + return self.glob("**" + pattern) + @property @lru_cache def _keys(self): From 035d942927a0f3ac11f5badabb556d8699c90dff Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 21 Sep 2021 03:38:39 -0400 Subject: [PATCH 05/10] fix tests_path --- source/tests/test_deepmd_data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/tests/test_deepmd_data.py b/source/tests/test_deepmd_data.py index bd57f0a090..81d6890747 100644 --- a/source/tests/test_deepmd_data.py +++ b/source/tests/test_deepmd_data.py @@ -4,6 +4,7 @@ from deepmd.utils.data import DeepmdData from deepmd.env import GLOBAL_NP_FLOAT_PRECISION +from common import tests_path if GLOBAL_NP_FLOAT_PRECISION == np.float32 : places = 6 @@ -261,7 +262,7 @@ def _comp_np_mat2(self, first, second) : class TestH5Data (unittest.TestCase) : def setUp (self) : - self.data_name = 'test.hdf5' + self.data_name = tests_path / 'test.hdf5' def test_init (self) : dd = DeepmdData(self.data_name) From 0c2097d2c1b8910e94a949cc8dedd7c1a5082db0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 21 Sep 2021 03:52:54 -0400 Subject: [PATCH 06/10] fix tests --- source/tests/test_deepmd_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/test_deepmd_data.py b/source/tests/test_deepmd_data.py index 81d6890747..8cf4fbc29a 100644 --- a/source/tests/test_deepmd_data.py +++ b/source/tests/test_deepmd_data.py @@ -262,7 +262,7 @@ def _comp_np_mat2(self, first, second) : class TestH5Data (unittest.TestCase) : def setUp (self) : - self.data_name = tests_path / 'test.hdf5' + self.data_name = str(tests_path / 'test.hdf5') def test_init (self) : dd = DeepmdData(self.data_name) From 4941023a2b02ff1a607115197fc910d849cf4cdb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 21 Sep 2021 04:08:33 -0400 Subject: [PATCH 07/10] Update test_deepmd_data.py --- source/tests/test_deepmd_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/tests/test_deepmd_data.py b/source/tests/test_deepmd_data.py index 8cf4fbc29a..b1256b67f5 100644 --- a/source/tests/test_deepmd_data.py +++ b/source/tests/test_deepmd_data.py @@ -268,8 +268,8 @@ def test_init (self) : dd = DeepmdData(self.data_name) self.assertEqual(dd.idx_map[0], 0) self.assertEqual(dd.type_map, ['X']) - self.assertEqual(dd.test_dir, 'test.hdf5#/set.000') - self.assertEqual(dd.train_dirs, ['test.hdf5#/set.000']) + self.assertEqual(dd.test_dir, self.data_name + '#/set.000') + self.assertEqual(dd.train_dirs, [self.data_name + '#/set.000']) def test_get_batch(self) : dd = DeepmdData(self.data_name) From 58e2bd5e560a888775bb151436c1fe743fd5bc21 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 21 Sep 2021 16:27:35 -0400 Subject: [PATCH 08/10] use `visit` instead of `visititems` --- deepmd/utils/path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index bca8cddece..d22887cedb 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -295,7 +295,7 @@ def rglob(self, pattern: str) -> List["DPPath"]: def _keys(self): """Walk all groups and dataset""" l = [] - self.root.visititems(lambda x, _: l.append("/" + x)) + self.root.visit(lambda x: l.append("/" + x)) return l def is_file(self) -> bool: From e7db22ed137be0bba44dcfbb5a9cbd2475b7d649 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 21 Sep 2021 16:40:22 -0400 Subject: [PATCH 09/10] cache file keys to prevent performance issues --- deepmd/utils/path.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index d22887cedb..08704193cb 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -291,11 +291,16 @@ def rglob(self, pattern: str) -> List["DPPath"]: return self.glob("**" + pattern) @property + def _keys(self) -> List[str]: + """Walk all groups and dataset""" + return self._file_keys(self.root) + + @classmethod @lru_cache(None) - def _keys(self): + def _file_keys(cls, file: h5py.File) -> List[str]: """Walk all groups and dataset""" l = [] - self.root.visit(lambda x: l.append("/" + x)) + file.visit(lambda x: l.append("/" + x)) return l def is_file(self) -> bool: From dc44906b9801608cd231e402c779acaf1769d282 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 21 Sep 2021 16:47:07 -0400 Subject: [PATCH 10/10] improve performance --- deepmd/utils/path.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index 08704193cb..b39fba7e3c 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -272,7 +272,9 @@ def glob(self, pattern: str) -> List["DPPath"]: List[DPPath] list of paths """ - return list([type(self)("%s#%s"%(self.root_path, pp)) for pp in globfilter(self._keys, self._connect_path(pattern))]) + # got paths starts with current path first, which is faster + subpaths = [ii for ii in self._keys if ii.startswith(self.name)] + return list([type(self)("%s#%s"%(self.root_path, pp)) for pp in globfilter(subpaths, self._connect_path(pattern))]) def rglob(self, pattern: str) -> List["DPPath"]: """This is like calling :metd:`DPPath.glob()` with `**/` added in front