diff --git a/dpdata/system.py b/dpdata/system.py index 802b352c5..c05cb0d1e 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1271,6 +1271,45 @@ def correction(self, hl_sys): ) return corrected_sys + def remove_outlier(self, threshold: float = 8.0) -> "LabeledSystem": + r"""Remove outlier frames from the system. + + Remove the frames whose energies satisfy the condition + + .. math:: + + \frac{\left \| E - \bar{E} \right \|}{\sigma(E)} \geq \text{threshold} + + where :math:`\bar{E}` and :math:`\sigma(E)` are the mean and standard deviation + of the energies in the system. + + Parameters + ---------- + threshold : float + The threshold of outlier detection. The default value is 8.0. + + Returns + ------- + LabeledSystem + The system without outlier frames. + + References + ---------- + .. [1] Gao, X.; Ramezanghorbani, F.; Isayev, O.; Smith, J. S.; + Roitberg, A. E. TorchANI: A Free and Open Source PyTorch-Based + Deep Learning Implementation of the ANI Neural Network + Potentials. J. Chem. Inf. Model. 2020, 60, 3408-3415. + .. [2] Zeng, J.; Tao, Y.; Giese, T. J.; York, D. M.. QDπ: A Quantum + Deep Potential Interaction Model for Drug Discovery. J. Comput. + Chem. 2023, 19, 1261-1275. + """ + energies = self.data["energies"] + std = np.std(energies) + if np.isclose(std, 0.0): + return self.copy() + idx = np.abs(energies - np.mean(energies)) / std < threshold + return self.sub_system(idx) + class MultiSystems: """A set containing several systems.""" diff --git a/tests/test_remove_outlier.py b/tests/test_remove_outlier.py new file mode 100644 index 000000000..192f4d8f0 --- /dev/null +++ b/tests/test_remove_outlier.py @@ -0,0 +1,55 @@ +import os +import unittest + +import numpy as np +from comp_sys import CompLabeledSys +from context import dpdata + + +class TestRemoveOutlier(unittest.TestCase, CompLabeledSys): + @classmethod + def setUpClass(cls): + system = dpdata.LabeledSystem( + data={ + "atom_names": ["H"], + "atom_numbs": [1], + "atom_types": np.zeros((1,), dtype=int), + "coords": np.zeros((100, 1, 3), dtype=np.float32), + "cells": np.zeros((100, 3, 3), dtype=np.float32), + "orig": np.zeros(3, dtype=np.float32), + "nopbc": True, + "energies": np.zeros((100,), dtype=np.float32), + "forces": np.zeros((100, 1, 3), dtype=np.float32), + } + ) + system.data["energies"][0] = 100.0 + cls.system_1 = system.remove_outlier() + cls.system_2 = system[1:] + cls.places = 6 + cls.e_places = 6 + cls.f_places = 6 + cls.v_places = 6 + + +class TestRemoveOutlierStdZero(unittest.TestCase, CompLabeledSys): + @classmethod + def setUpClass(cls): + system = dpdata.LabeledSystem( + data={ + "atom_names": ["H"], + "atom_numbs": [1], + "atom_types": np.zeros((1,), dtype=int), + "coords": np.zeros((100, 1, 3), dtype=np.float32), + "cells": np.zeros((100, 3, 3), dtype=np.float32), + "orig": np.zeros(3, dtype=np.float32), + "nopbc": True, + "energies": np.zeros((100,), dtype=np.float32), + "forces": np.zeros((100, 1, 3), dtype=np.float32), + } + ) + cls.system_1 = system.remove_outlier() + cls.system_2 = system + cls.places = 6 + cls.e_places = 6 + cls.f_places = 6 + cls.v_places = 6