Skip to content
Merged
56 changes: 55 additions & 1 deletion dpdata/plugins/xyz.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -9,7 +10,7 @@

if TYPE_CHECKING:
from dpdata.utils import FileType
from dpdata.xyz.quip_gap_xyz import QuipGapxyzSystems
from dpdata.xyz.quip_gap_xyz import QuipGapxyzSystems, format_single_frame
from dpdata.xyz.xyz import coord_to_xyz, xyz_to_coord


Expand Down Expand Up @@ -56,3 +57,56 @@ def from_labeled_system(self, data, **kwargs):
def from_multi_systems(self, file_name, **kwargs):
    """Build the QUIP/GAP XYZ multi-system reader for a file.

    Parameters
    ----------
    file_name : str
        path of the QUIP/GAP XYZ file; for this format the caller's
        "directory" argument is the file name itself
    **kwargs : dict
        additional arguments (not used here)

    Returns
    -------
    QuipGapxyzSystems
        reader object constructed from ``file_name``
    """
    # For this format the "directory" handed in by the caller is the file name.
    systems = QuipGapxyzSystems(file_name)
    return systems

def to_labeled_system(self, data, file_name: FileType, **kwargs):
    """Write LabeledSystem data to QUIP/GAP XYZ format file.

    Each frame is rendered by ``format_single_frame`` and the frames are
    written one after another.  Non-empty output always ends with a newline,
    whether the target is an already-open handle or a path, so that further
    frames can be appended safely.  When there are no frames nothing is
    written to a handle (a path target is still created/truncated).

    Parameters
    ----------
    data : dict
        system data; ``len(data["energies"])`` is used as the frame count
    file_name : FileType
        output file name or file handler
    **kwargs : dict
        additional arguments (not used here)
    """
    nframes = len(data["energies"])
    frames = [
        "\n".join(format_single_frame(data, frame_idx))
        for frame_idx in range(nframes)
    ]
    content = "\n".join(frames)

    # Terminate the last line exactly once, and only when there is output:
    # an empty system must not inject a blank line into a shared file.
    if content and not content.endswith("\n"):
        content += "\n"

    if isinstance(file_name, io.IOBase):
        # Caller owns the handle (e.g. one file shared by several systems).
        file_name.write(content)
    else:
        with open_file(file_name, "w") as fp:
            fp.write(content)

def to_multi_systems(self, formulas, directory, **kwargs):
    """Hand out one shared output handle per system for QUIP/GAP XYZ output.

    All systems of this format go into a single file, so the file is opened
    once and the very same handle is yielded once for every entry of
    ``formulas``.

    Parameters
    ----------
    formulas : list[str]
        list of system names/formulas; only its length matters here
    directory : str
        output filename
    **kwargs : dict
        additional arguments (not used here)

    Yields
    ------
    file handler
        the shared file handler, once per system
    """
    with open_file(directory, "w") as shared_handle:
        # The file is opened once in "w" mode up front; every system then
        # receives the same handle to write into.
        yield from (shared_handle for _ in formulas)
65 changes: 65 additions & 0 deletions dpdata/xyz/quip_gap_xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import numpy as np

from dpdata.periodic_table import Element


class QuipGapxyzSystems:
"""deal with QuipGapxyzFile."""
Expand Down Expand Up @@ -183,3 +185,66 @@ def handle_single_xyz_frame(lines):
info_dict["virials"] = virials
info_dict["orig"] = np.zeros(3)
return info_dict


def format_single_frame(data, frame_idx):
    """Format a single frame of system data into QUIP/GAP XYZ format lines.

    The returned lines are, in order: the atom count, one header line
    (``energy``, ``virial`` when ``"virials"`` is present in ``data``,
    ``Lattice`` and ``Properties``), and one line per atom holding species,
    position, atomic number and force.

    Parameters
    ----------
    data : dict
        system data; reads ``atom_names``, ``atom_types``, ``energies``,
        ``cells``, ``coords``, ``forces`` and, if present, ``virials``
    frame_idx : int
        frame index

    Returns
    -------
    list[str]
        lines for the frame
    """

    def _join_sci(values):
        # Space-separated values in the header's 12-digit scientific notation.
        return " ".join(f"{v:.12e}" for v in values)

    atom_types = data["atom_types"]
    natoms = len(atom_types)

    # Header line: energy, optional virial, lattice and the column layout.
    header_parts = [f"energy={data['energies'][frame_idx]:.12e}"]
    if "virials" in data:
        virial_str = _join_sci(data["virials"][frame_idx].flatten())
        header_parts.append(f'virial="{virial_str}"')
    lattice_str = _join_sci(data["cells"][frame_idx].flatten())
    header_parts.append(f'Lattice="{lattice_str}"')
    header_parts.append("Properties=species:S:1:pos:R:3:Z:I:1:force:R:3")
    header_line = " ".join(header_parts)

    coords = data["coords"][frame_idx]
    forces = data["forces"][frame_idx]
    atom_names = np.array(data["atom_names"])

    # The atomic number depends only on the species, so look it up once per
    # species instead of once per atom.  The cache is filled lazily so that
    # names listed in ``atom_names`` but not used by any atom are never
    # resolved.
    atomic_numbers = {}
    atom_lines = []
    for i, atom_type_idx in enumerate(atom_types):
        species = atom_names[atom_type_idx]
        if species not in atomic_numbers:
            atomic_numbers[species] = Element(species).Z
        x, y, z = coords[i]
        fx, fy, fz = forces[i]
        atom_lines.append(
            f"{species} {x:.11e} {y:.11e} {z:.11e} "
            f"{atomic_numbers[species]} {fx:.11e} {fy:.11e} {fz:.11e}"
        )

    return [str(natoms), header_line] + atom_lines
120 changes: 120 additions & 0 deletions tests/test_quip_gap_xyz_to_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#!/usr/bin/env python3
from __future__ import annotations

import os
import tempfile
import unittest

from context import dpdata


class TestQuipGapXYZToMethods(unittest.TestCase):
    """Exercise the to_labeled_system and to_multi_systems writers of QuipGapXYZFormat."""

    FMT = "quip/gap/xyz"

    def setUp(self):
        """Load the reference multi-systems and pick out two of its systems."""
        loaded = dpdata.MultiSystems.from_file("xyz/xyz_unittest.xyz", "quip/gap/xyz")
        self.multi_systems = loaded
        self.system_b1c9 = loaded.systems["B1C9"]
        self.system_b5c7 = loaded.systems["B5C7"]

    @staticmethod
    def _reserve_xyz_path():
        """Create an empty temporary ``.xyz`` file and return its path."""
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".xyz", delete=False
        ) as scratch:
            return scratch.name

    @staticmethod
    def _discard(path):
        """Remove ``path`` if it still exists."""
        if os.path.exists(path):
            os.unlink(path)

    def _assert_written(self, path):
        """Check that ``path`` exists and holds a non-empty text payload."""
        self.assertTrue(os.path.exists(path))
        with open(path) as fp:
            payload = fp.read()
        self.assertTrue(len(payload) > 0)

    def test_to_labeled_system(self):
        """A single labeled system can be written and parsed back."""
        xyz_path = self._reserve_xyz_path()
        try:
            self.system_b1c9.to(self.FMT, xyz_path)
            self._assert_written(xyz_path)

            # QUIP/GAP XYZ files are read back through MultiSystems.from_file.
            roundtrip = dpdata.MultiSystems.from_file(xyz_path, self.FMT)
            self.assertEqual(len(roundtrip.systems), 1)

            # The only reloaded system must hold as many frames as the source.
            only_system = next(iter(roundtrip.systems.values()))
            self.assertEqual(len(only_system), len(self.system_b1c9))
        finally:
            self._discard(xyz_path)

    def test_to_multi_systems(self):
        """Several systems written to one file keep their system and frame counts."""
        xyz_path = self._reserve_xyz_path()
        try:
            self.multi_systems.to(self.FMT, xyz_path)
            self._assert_written(xyz_path)

            roundtrip = dpdata.MultiSystems.from_file(xyz_path, self.FMT)
            self.assertEqual(
                len(roundtrip.systems), len(self.multi_systems.systems)
            )

            # The total frame count over all systems has to survive the round trip.
            frames_before = sum(
                len(system) for system in self.multi_systems.systems.values()
            )
            frames_after = sum(len(system) for system in roundtrip.systems.values())
            self.assertEqual(frames_after, frames_before)
        finally:
            self._discard(xyz_path)

    def test_roundtrip_consistency(self):
        """Systems found under the same name after a round trip are structurally equal."""
        xyz_path = self._reserve_xyz_path()
        try:
            self.multi_systems.to(self.FMT, xyz_path)
            roundtrip = dpdata.MultiSystems.from_file(xyz_path, self.FMT)

            for name, before in self.multi_systems.systems.items():
                # Only systems that reappear under the same name are compared.
                if name not in roundtrip.systems:
                    continue
                after = roundtrip.systems[name]

                # Structural checks only: exact numerical equality is not
                # asserted because of floating point formatting precision.
                self.assertEqual(len(before), len(after))
                self.assertEqual(
                    len(before.data["atom_names"]),
                    len(after.data["atom_names"]),
                )
        finally:
            self._discard(xyz_path)


# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()