def train_test_split(
    self, test_size: Union[float, int], seed: Optional[int] = None
) -> Tuple["MultiSystems", "MultiSystems", Dict[str, np.ndarray]]:
    """Split systems into random train and test subsets.

    Test frames are drawn without replacement from the pool of all frames
    of all systems, so ``test_size`` refers to the total number of frames
    and not to each system individually.

    Parameters
    ----------
    test_size : float or int
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split (rounded down).
        If int, represents the absolute number of test samples.
        NumPy floating and integer scalars are accepted as well.
    seed : int, default=None
        Random seed

    Returns
    -------
    MultiSystems
        The training set
    MultiSystems
        The testing set
    Dict[str, np.ndarray]
        The bool array of training and testing sets for each system.
        False for training set and True for testing set.

    Raises
    ------
    AssertionError
        If a float ``test_size`` is outside ``[0, 1]`` or an int
        ``test_size`` is outside ``[0, nframes]``.
    RuntimeError
        If ``test_size`` is neither a float nor an int.
    """
    nframes = self.get_nframes()
    # np.floating / np.integer are accepted next to the builtins so that sizes
    # computed with NumPy (e.g. np.int64) are not rejected.
    if isinstance(test_size, (float, np.floating)):
        assert 0 <= test_size <= 1, "float test_size should be between 0 and 1"
        test_size = int(np.floor(test_size * nframes))
    elif isinstance(test_size, (int, np.integer)):
        assert (
            0 <= test_size <= nframes
        ), "int test_size should be between 0 and the number of frames"
        test_size = int(test_size)
    else:
        raise RuntimeError("test_size should be float or int")

    # Mark the randomly chosen test frames in one flat mask over all frames.
    rng = np.random.default_rng(seed=seed)
    select_test = np.zeros(nframes, dtype=bool)
    if test_size:
        select_test[rng.choice(nframes, test_size, replace=False)] = True

    # offsets[i]:offsets[i + 1] is the range of the flat mask that belongs to
    # the i-th system. Building plain lists (instead of unpacking ``zip(*...)``)
    # keeps this working when there are no systems at all.
    system_names = list(self.systems)
    system_sizes = [len(self.systems[name]) for name in system_names]
    offsets = np.concatenate(([0], np.cumsum(system_sizes, dtype=int)))

    # make new systems
    train_systems = MultiSystems(type_map=self.atom_names)
    test_systems = MultiSystems(type_map=self.atom_names)
    test_system_idx = {}
    for name, start, stop in zip(system_names, offsets[:-1], offsets[1:]):
        test_mask = select_test[start:stop]
        system = self[name]
        sub_train = system[np.logical_not(test_mask)]
        # Sub-systems left without frames are not added to the result.
        if len(sub_train):
            train_systems.append(sub_train)
        sub_test = system[test_mask]
        if len(sub_test):
            test_systems.append(sub_test)
        test_system_idx[name] = test_mask
    return train_systems, test_systems, test_system_idx
class TestSplitDataset(unittest.TestCase):
    """Tests for ``MultiSystems.train_test_split``."""

    def setUp(self):
        # Fill a MultiSystems with ten copies of the same labeled system.
        self.systems = dpdata.MultiSystems()
        sing_sys = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")
        for _ in range(10):
            self.systems.append(sing_sys.copy())

    def test_split_dataset(self):
        """A 0.2 split puts floor(0.2 * n) frames in test and the rest in train."""
        nframes = self.systems.get_nframes()
        ntest = int(np.floor(nframes * 0.2))
        train, test, test_idx = self.systems.train_test_split(0.2)
        self.assertEqual(test.get_nframes(), ntest)
        # Train holds whatever is not in test; floor(0.8 * n) only matches
        # this for some n, so derive the expectation from the test size.
        self.assertEqual(train.get_nframes(), nframes - ntest)
        self.assertEqual(
            sum(np.count_nonzero(x) for x in test_idx.values()),
            ntest,
        )

    def test_split_dataset_seed(self):
        """The same seed selects the same test frames."""
        _, _, first_idx = self.systems.train_test_split(0.2, seed=1)
        _, _, second_idx = self.systems.train_test_split(0.2, seed=1)
        self.assertEqual(set(first_idx), set(second_idx))
        for name, mask in first_idx.items():
            np.testing.assert_array_equal(mask, second_idx[name])