diff --git a/setup.py b/setup.py index cf83e4955f..729d62650d 100644 --- a/setup.py +++ b/setup.py @@ -104,7 +104,7 @@ cmake_source_dir="source", cmake_minimum_required_version="3.0", extras_require={ - "test": ["dpdata>=0.1.9", "pytest", "pytest-cov", "pytest-sugar"], + "test": ["dpdata>=0.1.9", "ase", "pytest", "pytest-cov", "pytest-sugar"], "docs": ["sphinx", "recommonmark", "sphinx_rtd_theme"], **extras_require, }, diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 8541cb8fa3..f4e73682f3 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -324,4 +324,19 @@ def test_1frame_atm(self): for ii in range(nframes, 9): self.assertAlmostEqual(vv.reshape([-1])[ii], expected_sv.reshape([-1])[ii], places = default_places) + def test_ase(self): + from ase import Atoms + from deepmd.calculator import DP + water = Atoms('OHHOHH', + positions=self.coords.reshape((-1,3)), + cell=self.box.reshape((3,3)), + calculator=DP("deeppot.pb")) + ee = water.get_potential_energy() + ff = water.get_forces() + nframes = 1 + for ii in range(ff.size): + self.assertAlmostEqual(ff.reshape([-1])[ii], self.expected_f.reshape([-1])[ii], places = default_places) + expected_se = np.sum(self.expected_e.reshape([nframes, -1]), axis = 1) + for ii in range(nframes): + self.assertAlmostEqual(ee.reshape([-1])[ii], expected_se.reshape([-1])[ii], places = default_places) diff --git a/source/train/calculator.py b/source/train/calculator.py index e9c8e63c7d..7ca7943d7a 100644 --- a/source/train/calculator.py +++ b/source/train/calculator.py @@ -1,13 +1,12 @@ """ASE calculator interface module.""" from typing import TYPE_CHECKING, Dict, List, Optional, Union +from pathlib import Path from deepmd import DeepPotential from ase.calculators.calculator import Calculator, all_changes if TYPE_CHECKING: - from pathlib import Path - from ase import Atoms __all__ = ["DP"] @@ -62,7 +61,7 @@ def __init__( **kwargs ) -> None: Calculator.__init__(self, label=label, **kwargs) - self.dp = DeepPot(str(Path(model).resolve())) + self.dp = DeepPotential(str(Path(model).resolve())) if type_dict: self.type_dict = type_dict else: