Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions dpdata/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,74 @@ def label(self, data: dict) -> dict:
labeled_data['energies'] += lb_data ['energies']
labeled_data['forces'] += lb_data ['forces']
return labeled_data


class Minimizer(ABC):
"""The base class for a minimizer plugin. A minimizer can
minimize geometry.
"""
__MinimizerPlugin = Plugin()

@staticmethod
def register(key: str) -> Callable:
"""Register a minimizer plugin. Used as decorators.

Parameter
---------
key: str
key of the plugin.

Returns
-------
Callable
decorator of a class

Examples
--------
>>> @Minimizer.register("some_minimizer")
... class SomeMinimizer(Minimizer):
... pass
"""
return Minimizer.__MinimizerPlugin.register(key)

@staticmethod
def get_minimizer(key: str) -> "Minimizer":
"""Get a minimizer plugin.

Parameter
---------
key: str
key of the plugin.

Returns
-------
Minimizer
the specific minimizer class

Raises
------
RuntimeError
if the requested minimizer is not implemented
"""
try:
return Minimizer.__MinimizerPlugin.plugins[key]
except KeyError as e:
raise RuntimeError('Unknown minimizer: ' + key) from e

def __init__(self, *args, **kwargs) -> None:
"""Setup the minimizer."""

@abstractmethod
def minimize(self, data: dict) -> dict:
"""Minimize the geometry.

Parameters
----------
data : dict
data with coordinates and atom types

Returns
-------
dict
labeled data with minimized coordinates, energies, and forces
"""
20 changes: 19 additions & 1 deletion dpdata/plugins/amber.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import dpdata.amber.md
import dpdata.amber.sqm
from dpdata.format import Format
from dpdata.driver import Driver
from dpdata.driver import Driver, Minimizer


@Format.register("amber/md")
Expand Down Expand Up @@ -122,3 +122,21 @@ def label(self, data: dict) -> dict:
) from e
labeled_system.append(dpdata.LabeledSystem(out_fn, fmt="sqm/out"))
return labeled_system.data


@Minimizer.register("sqm")
class SQMMinimizer(Minimizer):
"""SQM minimizer.

Parameters
----------
maxcyc : int, default=1000
maximun cycle to minimize
"""
def __init__(self, maxcyc=1000, *args, **kwargs) -> None:
assert maxcyc > 0, "maxcyc should be more than 0 to minimize"
self.driver = SQMDriver(maxcyc=maxcyc, **kwargs)

def minimize(self, data: dict) -> dict:
# sqm has minimize feature
return self.driver.label(data)
63 changes: 62 additions & 1 deletion dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from dpdata.driver import Driver
from typing import TYPE_CHECKING, Type
from dpdata.driver import Driver, Minimizer
from dpdata.format import Format
import numpy as np
import dpdata
try:
import ase.io
from ase.calculators.calculator import PropertyNotImplementedError
if TYPE_CHECKING:
from ase.optimize.optimize import Optimizer
except ImportError:
pass

Expand Down Expand Up @@ -204,3 +207,61 @@ def label(self, data: dict) -> dict:
ls = dpdata.LabeledSystem(atoms, fmt="ase/structure", type_map=data['atom_names'])
labeled_system.append(ls)
return labeled_system.data


@Minimizer.register("ase")
class ASEMinimizer(Minimizer):
"""ASE minimizer.

Parameters
----------
driver : Driver
dpdata driver
optimizer : type, optional
ase optimizer class
fmax : float, optional, default=5e-3
force convergence criterion
optimizer_kwargs : dict, optional
other parameters for optimizer
"""
def __init__(self,
driver: Driver,
optimizer: Type["Optimizer"] = None,
fmax: float = 5e-3,
optimizer_kwargs: dict = {}) -> None:
self.calculator = driver.ase_calculator
if optimizer is None:
from ase.optimize import LBFGS
self.optimizer = LBFGS
else:
self.optimizer = optimizer
self.optimizer_kwargs = {
"logfile": None,
**optimizer_kwargs.copy(),
}
self.fmax = fmax

def minimize(self, data: dict) -> dict:
"""Minimize the geometry.

Parameters
----------
data : dict
data with coordinates and atom types

Returns
-------
dict
labeled data with minimized coordinates, energies, and forces
"""
system = dpdata.System(data=data)
# list[Atoms]
structures = system.to_ase_structure()
labeled_system = dpdata.LabeledSystem()
for atoms in structures:
atoms.calc = self.calculator
dyn = self.optimizer(atoms, **self.optimizer_kwargs)
dyn.run(fmax=self.fmax)
ls = dpdata.LabeledSystem(atoms, fmt="ase/structure", type_map=data['atom_names'])
labeled_system.append(ls)
return labeled_system.data
61 changes: 58 additions & 3 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import dpdata.md.pbc
from copy import deepcopy
from enum import Enum, unique
from typing import Any, Tuple
from typing import Any, Tuple, Union
from monty.json import MSONable
from monty.serialization import loadfn,dumpfn
from dpdata.periodic_table import Element
Expand All @@ -17,7 +17,7 @@
import dpdata.plugins
from dpdata.plugin import Plugin
from dpdata.format import Format
from dpdata.driver import Driver
from dpdata.driver import Driver, Minimizer

from dpdata.utils import (
elements_index_map,
Expand Down Expand Up @@ -869,6 +869,28 @@ def predict(self, *args: Any, driver: str="dp", **kwargs: Any) -> "LabeledSystem
data = driver.label(self.data.copy())
return LabeledSystem(data=data)

def minimize(self, *args: Any, minimizer: Union[str, Minimizer], **kwargs: Any) -> "LabeledSystem":
"""Minimize the geometry.

Parameters
----------
*args : iterable
Arguments passing to the minimizer
minimizer : str or Minimizer
The assigned minimizer
**kwargs : dict
Other arguments passing to the minimizer

Returns
-------
labeled_sys : LabeledSystem
A new labeled system.
"""
if not isinstance(minimizer, Minimizer):
minimizer = Minimizer.get_minimizer(minimizer)(*args, **kwargs)
data = minimizer.minimize(self.data.copy())
return LabeledSystem(data=data)

def pick_atom_idx(self, idx, nopbc=None):
"""Pick atom index

Expand Down Expand Up @@ -1308,10 +1330,43 @@ def predict(self, *args: Any, driver="dp", **kwargs: Any) -> "MultiSystems":
"""
if not isinstance(driver, Driver):
driver = Driver.get_driver(driver)(*args, **kwargs)
new_multisystems = dpdata.MultiSystems()
new_multisystems = dpdata.MultiSystems(type_map=self.atom_names)
for ss in self:
new_multisystems.append(ss.predict(*args, driver=driver, **kwargs))
return new_multisystems

def minimize(self, *args: Any, minimizer: Union[str, Minimizer], **kwargs: Any) -> "MultiSystems":
"""
Minimize geometry by a minimizer.

Parameters
----------
*args : iterable
Arguments passing to the minimizer
minimizer : str or Minimizer
The assigned minimizer
**kwargs : dict
Other arguments passing to the minimizer

Returns
-------
MultiSystems
A new labeled MultiSystems.

Examples
--------
Minimize a system using ASE BFGS along with a DP driver:
>>> from dpdata.driver import Driver
>>> from ase.optimize import BFGS
>>> driver = driver.get_driver("dp")("some_model.pb")
>>> some_system.minimize(minimizer="ase", driver=driver, optimizer=BFGS, fmax=1e-5)
"""
if not isinstance(minimizer, Minimizer):
minimizer = Minimizer.get_minimizer(minimizer)(*args, **kwargs)
new_multisystems = dpdata.MultiSystems(type_map=self.atom_names)
for ss in self:
new_multisystems.append(ss.minimize(*args, minimizer=minimizer, **kwargs))
return new_multisystems

def pick_atom_idx(self, idx, nopbc=None):
"""Pick atom index
Expand Down
33 changes: 32 additions & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def setUp(self) :


@unittest.skipIf(skip_ase,"skip ase related test. install ase to fix")
class TestASEtraj1(unittest.TestCase, CompLabeledSys, IsPBC):
class TestASEDriver(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp (self) :
ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
fmt = 'deepmd/raw',
Expand All @@ -90,3 +90,34 @@ def setUp (self) :
self.e_places = 6
self.f_places = 6
self.v_places = 4


@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix")
class TestMinimize(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp (self) :
ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
fmt = 'deepmd/raw',
type_map = ['O', 'H'])
zero_driver = ZeroDriver()
self.system_1 = ori_sys.predict(driver=zero_driver)
self.system_2 = ori_sys.minimize(driver=zero_driver, minimizer="ase")
self.places = 6
self.e_places = 6
self.f_places = 6
self.v_places = 4


@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix")
class TestMinimizeMultiSystems(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp (self) :
ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
fmt = 'deepmd/raw',
type_map = ['O', 'H'])
multi_sys = dpdata.MultiSystems(ori_sys)
zero_driver = ZeroDriver()
self.system_1 = list(multi_sys.predict(driver=zero_driver).systems.values())[0]
self.system_2 = list(multi_sys.minimize(driver=zero_driver, minimizer="ase").systems.values())[0]
self.places = 6
self.e_places = 6
self.f_places = 6
self.v_places = 4