From b2f4de4f3c910fd4712afb1804b976eb5a3cf11a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 24 Mar 2023 18:17:46 -0400 Subject: [PATCH] avoid decreasing precision for ASE atoms --- dpdata/plugins/ase.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index d6ed4f02a..f8dc4f8b2 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -52,8 +52,8 @@ def from_system(self, atoms: "ase.Atoms", **kwargs) -> dict: "atom_names": atom_names, "atom_numbs": atom_numbs, "atom_types": atom_types, - "cells": np.array([cells]).astype("float32"), - "coords": np.array([coords]).astype("float32"), + "cells": np.array([cells]), + "coords": np.array([coords]), "orig": np.zeros(3), "nopbc": not np.any(atoms.get_pbc()), } @@ -87,15 +87,15 @@ def from_labeled_system(self, atoms: "ase.Atoms", **kwargs) -> dict: forces = atoms.get_forces() info_dict = { **info_dict, - "energies": np.array([energies]).astype("float32"), - "forces": np.array([forces]).astype("float32"), + "energies": np.array([energies]), + "forces": np.array([forces]), } try: stress = atoms.get_stress(False) except PropertyNotImplementedError: pass else: - virials = np.array([-atoms.get_volume() * stress]).astype("float32") + virials = np.array([-atoms.get_volume() * stress]) info_dict["virials"] = virials return info_dict