# NOTE: both definitions below are methods added to the QUIP/GAP XYZ format
# class in dpdata/plugins/xyz.py; the class header lies outside this hunk.
def to_labeled_system(self, data, file_name: FileType, **kwargs):
    """Write LabeledSystem data to QUIP/GAP XYZ format file.

    Each frame is rendered by ``format_single_frame`` and every frame is
    terminated by a newline, whether the target is a path or an already
    open handler. The handler yielded by ``to_multi_systems`` is shared by
    several systems, so a missing final newline would glue the last atom
    line of one system to the atom count of the next one. A system without
    frames writes no text at all (no stray blank line).

    Parameters
    ----------
    data : dict
        system data
    file_name : FileType
        output file name or file handler
    **kwargs : dict
        additional arguments
    """
    frames = [
        "\n".join(format_single_frame(data, frame_idx))
        for frame_idx in range(len(data["energies"]))
    ]
    # One trailing newline per frame; an empty frame list yields "".
    content = "".join(f"{frame}\n" for frame in frames)

    if isinstance(file_name, io.IOBase):
        # The caller owns the handler: append to it and leave it open.
        file_name.write(content)
    else:
        with open_file(file_name, "w") as fp:
            fp.write(content)


def to_multi_systems(self, formulas, directory, **kwargs):
    """Yield one shared file handler for all systems in QUIP/GAP XYZ format.

    For QUIP/GAP XYZ format, all systems are written to a single file, so
    the same handler is yielded once per entry of ``formulas``. Being a
    generator, the file is opened on the first iteration and stays open
    until the generator is exhausted or closed.

    Parameters
    ----------
    formulas : list[str]
        list of system names/formulas
    directory : str
        output filename
    **kwargs : dict
        additional arguments

    Yields
    ------
    file handler
        file handler for all systems
    """
    with open_file(directory, "w") as f:
        # The formula itself is not needed: every system shares the handler.
        for _ in formulas:
            yield f
def format_single_frame(data, frame_idx):
    """Format a single frame of system data into QUIP/GAP XYZ format lines.

    The result holds the atom count, one header line (``energy=``, optional
    ``virial=``, ``Lattice=`` and ``Properties=`` fields) and one line per
    atom with species, position, atomic number and force.

    Parameters
    ----------
    data : dict
        system data; ``atom_types``, ``atom_names``, ``energies``,
        ``cells``, ``coords`` and ``forces`` are read, ``virials`` is
        written only when the key is present
    frame_idx : int
        frame index

    Returns
    -------
    list[str]
        lines for the frame
    """
    atom_types = data["atom_types"]
    atom_names = np.array(data["atom_names"])
    natoms = len(atom_types)

    # Header: key=value fields separated by single spaces.
    energy = data["energies"][frame_idx]
    header_parts = [f"energy={energy:.12e}"]

    if "virials" in data:
        virial = data["virials"][frame_idx]
        virial_str = " ".join(f"{v:.12e}" for v in virial.flatten())
        header_parts.append(f'virial="{virial_str}"')

    cell = data["cells"][frame_idx]
    lattice_str = " ".join(f"{c:.12e}" for c in cell.flatten())
    header_parts.append(f'Lattice="{lattice_str}"')

    header_parts.append("Properties=species:S:1:pos:R:3:Z:I:1:force:R:3")

    coords = data["coords"][frame_idx]
    forces = data["forces"][frame_idx]

    # Atomic numbers are resolved once per species, and only for species
    # that actually occur, instead of once per atom.
    atomic_numbers = {}
    atom_lines = []
    for i, atom_type_idx in enumerate(atom_types):
        species = atom_names[atom_type_idx]
        atomic_number = atomic_numbers.get(species)
        if atomic_number is None:
            atomic_number = atomic_numbers[species] = Element(species).Z
        x, y, z = coords[i]
        fx, fy, fz = forces[i]
        atom_lines.append(
            f"{species} {x:.11e} {y:.11e} {z:.11e} {atomic_number} "
            f"{fx:.11e} {fy:.11e} {fz:.11e}"
        )

    return [str(natoms), " ".join(header_parts), *atom_lines]
class TestQuipGapXYZToMethods(unittest.TestCase):
    """Test the to_labeled_system and to_multi_systems methods for QuipGapXYZFormat."""

    def setUp(self):
        """Load the reference systems and reserve a scratch output path."""
        # Load test multi-systems
        self.multi_systems = dpdata.MultiSystems.from_file(
            "xyz/xyz_unittest.xyz", "quip/gap/xyz"
        )
        self.system_b1c9 = self.multi_systems.systems["B1C9"]
        self.system_b5c7 = self.multi_systems.systems["B5C7"]

        # The output file must NOT exist before the writer runs, otherwise
        # the existence assertions below could never fail.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        self.output_file = os.path.join(tmp_dir.name, "output.xyz")

    def _assert_written(self):
        """Check that the writer created a non-empty output file."""
        self.assertTrue(os.path.exists(self.output_file))
        with open(self.output_file) as f:
            content = f.read()
        self.assertGreater(len(content), 0)

    def test_to_labeled_system(self):
        """Test writing a single labeled system to QUIP/GAP XYZ format."""
        self.system_b1c9.to("quip/gap/xyz", self.output_file)
        self._assert_written()

        # Read back and verify we can parse it (use MultiSystems.from_file for QUIP/GAP XYZ)
        reloaded_multi = dpdata.MultiSystems.from_file(
            self.output_file, "quip/gap/xyz"
        )
        self.assertEqual(len(reloaded_multi.systems), 1)

        # Verify the data matches (we should have the same system)
        reloaded_system = next(iter(reloaded_multi.systems.values()))
        self.assertEqual(len(reloaded_system), len(self.system_b1c9))

    def test_to_multi_systems(self):
        """Test writing multiple systems to a single QUIP/GAP XYZ format file."""
        self.multi_systems.to("quip/gap/xyz", self.output_file)
        self._assert_written()

        # Read back and verify we get the same number of systems
        reloaded_multi = dpdata.MultiSystems.from_file(
            self.output_file, "quip/gap/xyz"
        )
        self.assertEqual(
            len(reloaded_multi.systems), len(self.multi_systems.systems)
        )

        # Verify total number of frames is preserved
        original_frames = sum(
            len(system) for system in self.multi_systems.systems.values()
        )
        reloaded_frames = sum(
            len(system) for system in reloaded_multi.systems.values()
        )
        self.assertEqual(reloaded_frames, original_frames)

    def test_roundtrip_consistency(self):
        """Test that writing and reading back preserves data consistency."""
        self.multi_systems.to("quip/gap/xyz", self.output_file)
        reloaded_multi = dpdata.MultiSystems.from_file(
            self.output_file, "quip/gap/xyz"
        )

        # Compare original and reloaded data for each system
        for system_name, original in self.multi_systems.systems.items():
            with self.subTest(system=system_name):
                # A system lost in the roundtrip is a failure, not a skip.
                self.assertIn(system_name, reloaded_multi.systems)
                reloaded = reloaded_multi.systems[system_name]

                # Check basic properties
                self.assertEqual(len(original), len(reloaded))
                self.assertEqual(
                    len(original.data["atom_names"]),
                    len(reloaded.data["atom_names"]),
                )

                # Note: exact numerical equality is not checked because of
                # floating point precision and formatting differences; the
                # data should be structurally the same.


if __name__ == "__main__":
    unittest.main()