Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION
from deepmd.utils.sess import run_sess
from deepmd.utils.errors import GraphWithoutTensorError
from deepmd.utils.path import DPPath

if TYPE_CHECKING:
_DICT_VAL = TypeVar("_DICT_VAL")
Expand Down Expand Up @@ -429,9 +430,10 @@ def expand_sys_str(root_dir: Union[str, Path]) -> List[str]:
List[str]
list of string pointing to system directories
"""
matches = [str(d) for d in Path(root_dir).rglob("*") if (d / "type.raw").is_file()]
if (Path(root_dir) / "type.raw").is_file():
matches += [root_dir]
root_dir = DPPath(root_dir)
matches = [str(d) for d in root_dir.rglob("*") if (d / "type.raw").is_file()]
if (root_dir / "type.raw").is_file():
matches.append(str(root_dir))
return matches


Expand Down
6 changes: 4 additions & 2 deletions deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from deepmd.utils.data_system import DeepmdDataSystem
from deepmd.utils.sess import run_sess
from deepmd.utils.neighbor_stat import NeighborStat
from deepmd.utils.path import DPPath

__all__ = ["train"]

Expand Down Expand Up @@ -181,11 +182,12 @@ def get_data(jdata: Dict[str, Any], rcut, type_map, modifier):
raise IOError(msg, help_msg)
# rougly check all items in systems are valid
for ii in systems:
if (not os.path.isdir(ii)):
ii = DPPath(ii)
if (not ii.is_dir()):
msg = f'dir {ii} is not a valid dir'
log.fatal(msg)
raise IOError(msg, help_msg)
if (not os.path.isfile(os.path.join(ii, 'type.raw'))):
if (not (ii / 'type.raw').is_file()):
msg = f'dir {ii} is not a valid data system dir'
log.fatal(msg)
raise IOError(msg, help_msg)
Expand Down
57 changes: 30 additions & 27 deletions deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION
from deepmd.utils import random as dp_random
from deepmd.utils.path import DPPath

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,17 +45,18 @@ def __init__ (self,
"""
Constructor
"""
self.dirs = glob.glob (os.path.join(sys_path, set_prefix + ".*"))
root = DPPath(sys_path)
self.dirs = root.glob(set_prefix + ".*")
self.dirs.sort()
# load atom type
self.atom_type = self._load_type(sys_path)
self.atom_type = self._load_type(root)
self.natoms = len(self.atom_type)
# load atom type map
self.type_map = self._load_type_map(sys_path)
self.type_map = self._load_type_map(root)
if self.type_map is not None:
assert(len(self.type_map) >= max(self.atom_type)+1)
# check pbc
self.pbc = self._check_pbc(sys_path)
self.pbc = self._check_pbc(root)
# enforce type_map if necessary
if type_map is not None and self.type_map is not None:
atom_type_ = [type_map.index(self.type_map[ii]) for ii in self.atom_type]
Expand Down Expand Up @@ -167,9 +169,9 @@ def check_batch_size (self, batch_size) :
"""
for ii in self.train_dirs :
if self.data_dict['coord']['high_prec'] :
tmpe = np.load(os.path.join(ii, "coord.npy")).astype(GLOBAL_ENER_FLOAT_PRECISION)
tmpe = (ii / "coord.npy").load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
else:
tmpe = np.load(os.path.join(ii, "coord.npy")).astype(GLOBAL_NP_FLOAT_PRECISION)
tmpe = (ii / "coord.npy").load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION)
if tmpe.ndim == 1:
tmpe = tmpe.reshape([1,-1])
if tmpe.shape[0] < batch_size :
Expand All @@ -181,9 +183,9 @@ def check_test_size (self, test_size) :
Check if the system can get a test dataset with `test_size` frames.
"""
if self.data_dict['coord']['high_prec'] :
tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(GLOBAL_ENER_FLOAT_PRECISION)
tmpe = (self.test_dir / "coord.npy").load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
else:
tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(GLOBAL_NP_FLOAT_PRECISION)
tmpe = (self.test_dir / "coord.npy").load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION)
if tmpe.ndim == 1:
tmpe = tmpe.reshape([1,-1])
if tmpe.shape[0] < test_size :
Expand Down Expand Up @@ -377,7 +379,7 @@ def _get_subdata(self, data, idx = None) :
return new_data

def _load_batch_set (self,
set_name) :
set_name: DPPath) :
self.batch_set = self._load_set(set_name)
self.batch_set, _ = self._shuffle_data(self.batch_set)
self.reset_get_batch()
Expand All @@ -386,7 +388,7 @@ def reset_get_batch(self):
self.iterator = 0

def _load_test_set (self,
set_name,
set_name: DPPath,
shuffle_test) :
self.test_set = self._load_set(set_name)
if shuffle_test :
Expand All @@ -409,13 +411,15 @@ def _shuffle_data (self,
ret[kk] = data[kk]
return ret, idx

def _load_set(self, set_name) :
def _load_set(self, set_name: DPPath) :
# get nframes
path = os.path.join(set_name, "coord.npy")
if not isinstance(set_name, DPPath):
set_name = DPPath(set_name)
path = set_name / "coord.npy"
if self.data_dict['coord']['high_prec'] :
coord = np.load(path).astype(GLOBAL_ENER_FLOAT_PRECISION)
coord = path.load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
else:
coord = np.load(path).astype(GLOBAL_NP_FLOAT_PRECISION)
coord = path.load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION)
if coord.ndim == 1:
coord = coord.reshape([1,-1])
nframes = coord.shape[0]
Expand Down Expand Up @@ -459,12 +463,12 @@ def _load_data(self, set_name, key, nframes, ndof_, atomic = False, must = True,
ndof = ndof_ * natoms
else:
ndof = ndof_
path = os.path.join(set_name, key+".npy")
if os.path.isfile (path) :
path = set_name / (key+".npy")
if path.is_file() :
if high_prec :
data = np.load(path).astype(GLOBAL_ENER_FLOAT_PRECISION)
data = path.load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
else:
data = np.load(path).astype(GLOBAL_NP_FLOAT_PRECISION)
data = path.load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION)
try: # YWolfeee: deal with data shape error
if atomic :
data = data.reshape([nframes, natoms, -1])
Expand All @@ -491,8 +495,8 @@ def _load_data(self, set_name, key, nframes, ndof_, atomic = False, must = True,
return np.float32(0.0), data


def _load_type (self, sys_path) :
atom_type = np.loadtxt (os.path.join(sys_path, "type.raw"), dtype=np.int32, ndmin=1)
def _load_type (self, sys_path: DPPath) :
atom_type = (sys_path / "type.raw").load_txt(dtype=np.int32, ndmin=1)
return atom_type

def _make_idx_map(self, atom_type):
Expand All @@ -501,17 +505,16 @@ def _make_idx_map(self, atom_type):
idx_map = np.lexsort ((idx, atom_type))
return idx_map

def _load_type_map(self, sys_path) :
fname = os.path.join(sys_path, 'type_map.raw')
if os.path.isfile(fname) :
with open(os.path.join(sys_path, 'type_map.raw')) as fp:
return fp.read().split()
def _load_type_map(self, sys_path: DPPath) :
fname = sys_path / 'type_map.raw'
if fname.is_file() :
return fname.load_txt(dtype=str).tolist()
else :
return None

def _check_pbc(self, sys_path):
def _check_pbc(self, sys_path: DPPath):
pbc = True
if os.path.isfile(os.path.join(sys_path, 'nopbc')) :
if (sys_path / 'nopbc').is_file() :
pbc = False
return pbc

Expand Down
Loading