diff --git a/dpdata/driver.py b/dpdata/driver.py index 97583a101..670b03378 100644 --- a/dpdata/driver.py +++ b/dpdata/driver.py @@ -147,3 +147,74 @@ def label(self, data: dict) -> dict: labeled_data['energies'] += lb_data ['energies'] labeled_data['forces'] += lb_data ['forces'] return labeled_data + + +class Minimizer(ABC): + """The base class for a minimizer plugin. A minimizer can + minimize geometry. + """ + __MinimizerPlugin = Plugin() + + @staticmethod + def register(key: str) -> Callable: + """Register a minimizer plugin. Used as decorators. + + Parameter + --------- + key: str + key of the plugin. + + Returns + ------- + Callable + decorator of a class + + Examples + -------- + >>> @Minimizer.register("some_minimizer") + ... class SomeMinimizer(Minimizer): + ... pass + """ + return Minimizer.__MinimizerPlugin.register(key) + + @staticmethod + def get_minimizer(key: str) -> "Minimizer": + """Get a minimizer plugin. + + Parameter + --------- + key: str + key of the plugin. + + Returns + ------- + Minimizer + the specific minimizer class + + Raises + ------ + RuntimeError + if the requested minimizer is not implemented + """ + try: + return Minimizer.__MinimizerPlugin.plugins[key] + except KeyError as e: + raise RuntimeError('Unknown minimizer: ' + key) from e + + def __init__(self, *args, **kwargs) -> None: + """Setup the minimizer.""" + + @abstractmethod + def minimize(self, data: dict) -> dict: + """Minimize the geometry. + + Parameters + ---------- + data : dict + data with coordinates and atom types + + Returns + ------- + dict + labeled data with minimized coordinates, energies, and forces + """ diff --git a/dpdata/plugins/amber.py b/dpdata/plugins/amber.py index f9580f2fd..4fe41c1e4 100644 --- a/dpdata/plugins/amber.py +++ b/dpdata/plugins/amber.py @@ -5,7 +5,7 @@ import dpdata.amber.md import dpdata.amber.sqm from dpdata.format import Format -from dpdata.driver import Driver +from dpdata.driver import Driver, Minimizer @Format.register("amber/md") @@ -122,3 +122,21 @@ def label(self, data: dict) -> dict: ) from e labeled_system.append(dpdata.LabeledSystem(out_fn, fmt="sqm/out")) return labeled_system.data + + +@Minimizer.register("sqm") +class SQMMinimizer(Minimizer): + """SQM minimizer. + + Parameters + ---------- + maxcyc : int, default=1000 + maximun cycle to minimize + """ + def __init__(self, maxcyc=1000, *args, **kwargs) -> None: + assert maxcyc > 0, "maxcyc should be more than 0 to minimize" + self.driver = SQMDriver(maxcyc=maxcyc, **kwargs) + + def minimize(self, data: dict) -> dict: + # sqm has minimize feature + return self.driver.label(data) diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index 61bc14dc1..f61d46f6c 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -1,10 +1,13 @@ -from dpdata.driver import Driver +from typing import TYPE_CHECKING, Type +from dpdata.driver import Driver, Minimizer from dpdata.format import Format import numpy as np import dpdata try: import ase.io from ase.calculators.calculator import PropertyNotImplementedError + if TYPE_CHECKING: + from ase.optimize.optimize import Optimizer except ImportError: pass @@ -204,3 +207,61 @@ def label(self, data: dict) -> dict: ls = dpdata.LabeledSystem(atoms, fmt="ase/structure", type_map=data['atom_names']) labeled_system.append(ls) return labeled_system.data + + +@Minimizer.register("ase") +class ASEMinimizer(Minimizer): + """ASE minimizer. + + Parameters + ---------- + driver : Driver + dpdata driver + optimizer : type, optional + ase optimizer class + fmax : float, optional, default=5e-3 + force convergence criterion + optimizer_kwargs : dict, optional + other parameters for optimizer + """ + def __init__(self, + driver: Driver, + optimizer: Type["Optimizer"] = None, + fmax: float = 5e-3, + optimizer_kwargs: dict = {}) -> None: + self.calculator = driver.ase_calculator + if optimizer is None: + from ase.optimize import LBFGS + self.optimizer = LBFGS + else: + self.optimizer = optimizer + self.optimizer_kwargs = { + "logfile": None, + **optimizer_kwargs.copy(), + } + self.fmax = fmax + + def minimize(self, data: dict) -> dict: + """Minimize the geometry. + + Parameters + ---------- + data : dict + data with coordinates and atom types + + Returns + ------- + dict + labeled data with minimized coordinates, energies, and forces + """ + system = dpdata.System(data=data) + # list[Atoms] + structures = system.to_ase_structure() + labeled_system = dpdata.LabeledSystem() + for atoms in structures: + atoms.calc = self.calculator + dyn = self.optimizer(atoms, **self.optimizer_kwargs) + dyn.run(fmax=self.fmax) + ls = dpdata.LabeledSystem(atoms, fmt="ase/structure", type_map=data['atom_names']) + labeled_system.append(ls) + return labeled_system.data \ No newline at end of file diff --git a/dpdata/system.py b/dpdata/system.py index 0600f5634..a2226d030 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -6,7 +6,7 @@ import dpdata.md.pbc from copy import deepcopy from enum import Enum, unique -from typing import Any, Tuple +from typing import Any, Tuple, Union from monty.json import MSONable from monty.serialization import loadfn,dumpfn from dpdata.periodic_table import Element @@ -17,7 +17,7 @@ import dpdata.plugins from dpdata.plugin import Plugin from dpdata.format import Format -from dpdata.driver import Driver +from dpdata.driver import Driver, Minimizer from dpdata.utils import ( elements_index_map, @@ -869,6 +869,28 @@ def predict(self, *args: Any, driver: str="dp", **kwargs: Any) -> "LabeledSystem data = driver.label(self.data.copy()) return LabeledSystem(data=data) + def minimize(self, *args: Any, minimizer: Union[str, Minimizer], **kwargs: Any) -> "LabeledSystem": + """Minimize the geometry. + + Parameters + ---------- + *args : iterable + Arguments passing to the minimizer + minimizer : str or Minimizer + The assigned minimizer + **kwargs : dict + Other arguments passing to the minimizer + + Returns + ------- + labeled_sys : LabeledSystem + A new labeled system. + """ + if not isinstance(minimizer, Minimizer): + minimizer = Minimizer.get_minimizer(minimizer)(*args, **kwargs) + data = minimizer.minimize(self.data.copy()) + return LabeledSystem(data=data) + def pick_atom_idx(self, idx, nopbc=None): """Pick atom index @@ -1308,10 +1330,43 @@ def predict(self, *args: Any, driver="dp", **kwargs: Any) -> "MultiSystems": """ if not isinstance(driver, Driver): driver = Driver.get_driver(driver)(*args, **kwargs) - new_multisystems = dpdata.MultiSystems() + new_multisystems = dpdata.MultiSystems(type_map=self.atom_names) for ss in self: new_multisystems.append(ss.predict(*args, driver=driver, **kwargs)) return new_multisystems + + def minimize(self, *args: Any, minimizer: Union[str, Minimizer], **kwargs: Any) -> "MultiSystems": + """ + Minimize geometry by a minimizer. + + Parameters + ---------- + *args : iterable + Arguments passing to the minimizer + minimizer : str or Minimizer + The assigned minimizer + **kwargs : dict + Other arguments passing to the minimizer + + Returns + ------- + MultiSystems + A new labeled MultiSystems. + + Examples + -------- + Minimize a system using ASE BFGS along with a DP driver: + >>> from dpdata.driver import Driver + >>> from ase.optimize import BFGS + >>> driver = driver.get_driver("dp")("some_model.pb") + >>> some_system.minimize(minimizer="ase", driver=driver, optimizer=BFGS, fmax=1e-5) + """ + if not isinstance(minimizer, Minimizer): + minimizer = Minimizer.get_minimizer(minimizer)(*args, **kwargs) + new_multisystems = dpdata.MultiSystems(type_map=self.atom_names) + for ss in self: + new_multisystems.append(ss.minimize(*args, minimizer=minimizer, **kwargs)) + return new_multisystems def pick_atom_idx(self, idx, nopbc=None): """Pick atom index diff --git a/tests/test_predict.py b/tests/test_predict.py index 8535d02ea..3ba62ec23 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -78,7 +78,7 @@ def setUp(self) : @unittest.skipIf(skip_ase,"skip ase related test. install ase to fix") -class TestASEtraj1(unittest.TestCase, CompLabeledSys, IsPBC): +class TestASEDriver(unittest.TestCase, CompLabeledSys, IsPBC): def setUp (self) : ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md', fmt = 'deepmd/raw', @@ -90,3 +90,34 @@ def setUp (self) : self.e_places = 6 self.f_places = 6 self.v_places = 4 + + +@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix") +class TestMinimize(unittest.TestCase, CompLabeledSys, IsPBC): + def setUp (self) : + ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md', + fmt = 'deepmd/raw', + type_map = ['O', 'H']) + zero_driver = ZeroDriver() + self.system_1 = ori_sys.predict(driver=zero_driver) + self.system_2 = ori_sys.minimize(driver=zero_driver, minimizer="ase") + self.places = 6 + self.e_places = 6 + self.f_places = 6 + self.v_places = 4 + + +@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix") +class TestMinimizeMultiSystems(unittest.TestCase, CompLabeledSys, IsPBC): + def setUp (self) : + ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md', + fmt = 'deepmd/raw', + type_map = ['O', 'H']) + multi_sys = dpdata.MultiSystems(ori_sys) + zero_driver = ZeroDriver() + self.system_1 = list(multi_sys.predict(driver=zero_driver).systems.values())[0] + self.system_2 = list(multi_sys.minimize(driver=zero_driver, minimizer="ase").systems.values())[0] + self.places = 6 + self.e_places = 6 + self.f_places = 6 + self.v_places = 4