From 22ead589bc6f62bd932a2d68107bc6958f998af6 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Fri, 13 Oct 2023 03:44:27 -0400
Subject: [PATCH 1/2] limit the filename length dumped by MultiSystems

Fix #553.

Signed-off-by: Jinzhe Zeng
---
 dpdata/system.py           | 39 ++++++++++++++++++++++++++++++++++++--
 dpdata/utils.py            |  5 +++++
 tests/test_multisystems.py | 34 +++++++++++++++++++++++++++++++++
 3 files changed, 76 insertions(+), 2 deletions(-)

diff --git a/dpdata/system.py b/dpdata/system.py
index 0748db284..4aea8679f 100644
--- a/dpdata/system.py
+++ b/dpdata/system.py
@@ -1,5 +1,6 @@
 # %%
 import glob
+import hashlib
 import os
 import warnings
 from copy import deepcopy
@@ -19,7 +20,7 @@
 from dpdata.driver import Driver, Minimizer
 from dpdata.format import Format
 from dpdata.plugin import Plugin
-from dpdata.utils import add_atom_names, elements_index_map, remove_pbc, sort_atom_names
+from dpdata.utils import add_atom_names, elements_index_map, remove_pbc, sort_atom_names, utf8len
 
 
 def load_format(fmt):
@@ -562,6 +563,40 @@ def uniq_formula(self):
             ]
         )
 
+    @property
+    def short_formula(self) -> str:
+        """Return the short formula of this system. Elements with zero number
+        will be removed."""
+        return "".join(
+            [
+                f"{symbol}{numb}"
+                for symbol, numb in zip(
+                    self.data["atom_names"], self.data["atom_numbs"]
+                ) if numb
+            ]
+        )
+
+    @property
+    def formula_hash(self) -> str:
+        """Return the hash of the formula of this system."""
+        return hashlib.sha256(self.formula.encode("utf-8")).hexdigest()
+
+    @property
+    def short_name(self) -> str:
+        """Return the short name of this system (no more than 255 bytes), in
+        the following order:
+        - formula
+        - short_formula
+        - formula_hash
+        """
+        formula = self.formula
+        if utf8len(formula) <= 255:
+            return formula
+        short_formula = self.short_formula
+        if utf8len(short_formula) <= 255:
+            return short_formula
+        return self.formula_hash
+
     def extend(self, systems):
         """Extend a system list to this system.
 
@@ -1247,7 +1282,7 @@ def from_fmt_obj(self, fmtobj, directory, labeled=True, **kwargs):
     def to_fmt_obj(self, fmtobj, directory, *args, **kwargs):
         if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat):
             for fn, ss in zip(
-                fmtobj.to_multi_systems(self.systems.keys(), directory, **kwargs),
+                fmtobj.to_multi_systems([ss.short_name for ss in self.systems.values()], directory, **kwargs),
                 self.systems.values(),
             ):
                 ss.to_fmt_obj(fmtobj, fn, *args, **kwargs)
diff --git a/dpdata/utils.py b/dpdata/utils.py
index da7261790..c626917a5 100644
--- a/dpdata/utils.py
+++ b/dpdata/utils.py
@@ -99,3 +99,8 @@ def uniq_atom_names(data):
         sum(ii == data["atom_types"]) for ii in range(len(data["atom_names"]))
     ]
     return data
+
+
+def utf8len(s: str) -> int:
+    """Return the byte length of a string."""
+    return len(s.encode('utf-8'))
\ No newline at end of file
diff --git a/tests/test_multisystems.py b/tests/test_multisystems.py
index 172c2ad48..f2957a3e5 100644
--- a/tests/test_multisystems.py
+++ b/tests/test_multisystems.py
@@ -1,7 +1,10 @@
 import os
+import tempfile
 import unittest
 from itertools import permutations
 
+import numpy as np
+
 from comp_sys import CompLabeledSys, IsNoPBC, MultiSystems
 from context import dpdata
 
@@ -200,5 +203,36 @@ def setUp(self):
         self.atom_names = ["C", "H", "O"]
 
 
+class TestLongFilename(unittest.TestCase):
+    def test_long_filename1(self):
+        system = dpdata.System(
+            data = {
+                "atom_names": [f"TYPE{ii}" for ii in range(200)],
+                "atom_numbs": [1] + [0 for _ in range(199)],
+                "atom_types": np.arange(1),
+                "coords": np.zeros((1, 1, 3)),
+                "orig": np.zeros(3),
+                "cells": np.zeros((1, 3, 3)),
+            }
+        )
+        ms = dpdata.MultiSystems(system)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ms.to_deepmd_npy(tmpdir)
+
+    def test_long_filename2(self):
+        system = dpdata.System(
+            data = {
+                "atom_names": [f"TYPE{ii}" for ii in range(200)],
+                "atom_numbs": [1 for _ in range(200)],
+                "atom_types": np.arange(200),
+                "coords": np.zeros((1, 200, 3)),
+                "orig": np.zeros(3),
+                "cells": np.zeros((1, 3, 3)),
+            }
+        )
+        ms = dpdata.MultiSystems(system)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ms.to_deepmd_npy(tmpdir)
+
 if __name__ == "__main__":
     unittest.main()

From 0b6f4ee90fb9405f20189805fec921b3e48eef63 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 14 Oct 2023 05:57:26 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 dpdata/system.py           | 20 +++++++++++++++-----
 dpdata/utils.py            |  2 +-
 tests/test_multisystems.py |  6 +++---
 3 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/dpdata/system.py b/dpdata/system.py
index 4aea8679f..f1273104e 100644
--- a/dpdata/system.py
+++ b/dpdata/system.py
@@ -20,7 +20,13 @@
 from dpdata.driver import Driver, Minimizer
 from dpdata.format import Format
 from dpdata.plugin import Plugin
-from dpdata.utils import add_atom_names, elements_index_map, remove_pbc, sort_atom_names, utf8len
+from dpdata.utils import (
+    add_atom_names,
+    elements_index_map,
+    remove_pbc,
+    sort_atom_names,
+    utf8len,
+)
 
 
 def load_format(fmt):
@@ -566,13 +572,15 @@ def uniq_formula(self):
     @property
     def short_formula(self) -> str:
         """Return the short formula of this system. Elements with zero number
-        will be removed."""
+        will be removed.
+        """
         return "".join(
             [
                 f"{symbol}{numb}"
                 for symbol, numb in zip(
                     self.data["atom_names"], self.data["atom_numbs"]
-                ) if numb
+                )
+                if numb
             ]
         )
 
@@ -587,7 +595,7 @@ def short_name(self) -> str:
         the following order:
         - formula
         - short_formula
-        - formula_hash
+        - formula_hash.
         """
         formula = self.formula
         if utf8len(formula) <= 255:
@@ -1282,7 +1290,9 @@ def from_fmt_obj(self, fmtobj, directory, labeled=True, **kwargs):
     def to_fmt_obj(self, fmtobj, directory, *args, **kwargs):
         if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat):
             for fn, ss in zip(
-                fmtobj.to_multi_systems([ss.short_name for ss in self.systems.values()], directory, **kwargs),
+                fmtobj.to_multi_systems(
+                    [ss.short_name for ss in self.systems.values()], directory, **kwargs
+                ),
                 self.systems.values(),
             ):
                 ss.to_fmt_obj(fmtobj, fn, *args, **kwargs)
diff --git a/dpdata/utils.py b/dpdata/utils.py
index c626917a5..cf4a109ee 100644
--- a/dpdata/utils.py
+++ b/dpdata/utils.py
@@ -103,4 +103,4 @@ def uniq_atom_names(data):
 
 def utf8len(s: str) -> int:
     """Return the byte length of a string."""
-    return len(s.encode('utf-8'))
\ No newline at end of file
+    return len(s.encode("utf-8"))
diff --git a/tests/test_multisystems.py b/tests/test_multisystems.py
index f2957a3e5..2bda13a9b 100644
--- a/tests/test_multisystems.py
+++ b/tests/test_multisystems.py
@@ -4,7 +4,6 @@
 from itertools import permutations
 
 import numpy as np
-
 from comp_sys import CompLabeledSys, IsNoPBC, MultiSystems
 from context import dpdata
 
@@ -206,7 +205,7 @@ def setUp(self):
 class TestLongFilename(unittest.TestCase):
     def test_long_filename1(self):
         system = dpdata.System(
-            data = {
+            data={
                 "atom_names": [f"TYPE{ii}" for ii in range(200)],
                 "atom_numbs": [1] + [0 for _ in range(199)],
                 "atom_types": np.arange(1),
@@ -221,7 +220,7 @@ def test_long_filename1(self):
 
     def test_long_filename2(self):
         system = dpdata.System(
-            data = {
+            data={
                 "atom_names": [f"TYPE{ii}" for ii in range(200)],
                 "atom_numbs": [1 for _ in range(200)],
                 "atom_types": np.arange(200),
@@ -234,5 +233,6 @@ def test_long_filename2(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             ms.to_deepmd_npy(tmpdir)
 
+
 if __name__ == "__main__":
     unittest.main()