Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from copy import deepcopy
from enum import Enum, unique
from typing import Any, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from monty.json import MSONable
Expand Down Expand Up @@ -1584,6 +1584,63 @@ def correction(self, hl_sys: "MultiSystems"):
corrected_sys.append(ll_ss.correction(hl_ss))
return corrected_sys

def train_test_split(
self, test_size: Union[float, int], seed: Optional[int] = None
) -> Tuple["MultiSystems", "MultiSystems", Dict[str, np.ndarray]]:
"""Split systems into random train and test subsets.

Parameters
----------
test_size : float or int
If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split.
If int, represents the absolute number of test samples.
seed : int, default=None
Random seed

Returns
-------
MultiSystems
The training set
MultiSystems
The testing set
Dict[str, np.ndarray]
The bool array of training and testing sets for each system. False for training set and True for testing set.
"""
nframes = self.get_nframes()
if isinstance(test_size, float):
assert 0 <= test_size <= 1
test_size = int(np.floor(test_size * nframes))
elif isinstance(test_size, int):
assert 0 <= test_size <= nframes
else:
raise RuntimeError("test_size should be float or int")
# get random indices
rng = np.random.default_rng(seed=seed)
test_idx = rng.choice(nframes, test_size, replace=False)
select_test = np.zeros(nframes, dtype=bool)
select_test[test_idx] = True
select_train = np.logical_not(select_test)
# flatten systems dict
system_names, system_sizes = zip(
*((kk, len(vv)) for (kk, vv) in self.systems.items())
)
system_idx = np.empty(len(system_sizes) + 1, dtype=int)
system_idx[0] = 0
np.cumsum(system_sizes, out=system_idx[1:])
# make new systems
train_systems = MultiSystems(type_map=self.atom_names)
test_systems = MultiSystems(type_map=self.atom_names)
test_system_idx = {}
for ii, nn in enumerate(system_names):
sub_train = self[nn][select_train[system_idx[ii] : system_idx[ii + 1]]]
if len(sub_train):
train_systems.append(sub_train)
sub_test = self[nn][select_test[system_idx[ii] : system_idx[ii + 1]]]
if len(sub_test):
test_systems.append(sub_test)
test_system_idx[nn] = select_test[system_idx[ii] : system_idx[ii + 1]]
return train_systems, test_systems, test_system_idx


def get_cls_name(cls: object) -> str:
"""Returns the fully qualified name of a class, such as `np.ndarray`.
Expand Down
25 changes: 25 additions & 0 deletions tests/test_split_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import unittest

import numpy as np
from context import dpdata


class TestSplitDataset(unittest.TestCase):
def setUp(self):
self.systems = dpdata.MultiSystems()
sing_sys = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")
for ii in range(10):
self.systems.append(sing_sys.copy())

def test_split_dataset(self):
train, test, test_idx = self.systems.train_test_split(0.2)
self.assertEqual(
train.get_nframes(), int(np.floor(self.systems.get_nframes() * 0.8))
)
self.assertEqual(
test.get_nframes(), int(np.floor(self.systems.get_nframes() * 0.2))
)
self.assertEqual(
sum([np.count_nonzero(x) for x in test_idx.values()]),
int(np.floor(self.systems.get_nframes() * 0.2)),
)