From 9afed6c00260da5e3085771594914bd4671e85d2 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 20 Apr 2026 09:59:33 -0400 Subject: [PATCH 1/5] fea: single atom interface test --- torch_sim/models/interface.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 98d28433..0dcee40b 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -384,6 +384,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 and primitive BCC iron) for validation. It tests both single and multi-batch processing capabilities. """ + from ase import Atoms from ase.build import bulk, molecule def _modify(state: SimState) -> SimState: @@ -566,3 +567,27 @@ def _modify(state: SimState) -> SimState: f"forces shape incorrect for benzene: " f"{benzene_output['forces'].shape=} != (12, 3)" ) + + # Test an isolated, non-periodic single atom. This catches models whose + # downstream post-processing squeezes length-1 batch/atom dimensions + # (e.g. ``torch.atleast_1d(out.squeeze())`` which collapses ``(1, 3)`` to + # ``(3,)``). An isolated atom has no edges, which often routes models + # through a different code path than a 1-atom periodic cell. + isolated_atoms = Atoms("H", positions=[[0.0, 0.0, 0.0]], pbc=False) + isolated_state = _modify(ts.io.atoms_to_state([isolated_atoms], device, dtype)) + isolated_output = model.forward(isolated_state) + if isolated_output["energy"].shape != (1,): + raise ValueError( + f"energy shape incorrect for isolated atom: " + f"{isolated_output['energy'].shape=} != (1,)" + ) + if force_computed and isolated_output["forces"].shape != (1, 3): + raise ValueError( + f"forces shape incorrect for isolated atom: " + f"{isolated_output['forces'].shape=} != (1, 3)" + ) + if stress_computed and isolated_output["stress"].shape != (1, 3, 3): + raise ValueError( + f"stress shape incorrect for isolated atom: " + f"{isolated_output['stress'].shape=} != (1, 3, 3)" + ) From 62df33d6883892193f08c83bbd9d84b180d8b187 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 20 Apr 2026 14:21:00 -0400 Subject: [PATCH 2/5] fix orb squeeze in wrapper --- tests/models/test_orb.py | 13 +++------- torch_sim/models/orb.py | 52 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index de23cd17..1e7483cc 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -72,18 +72,11 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator: energy_atol=5e-5, ) -test_validate_conservative_model_outputs = pytest.mark.xfail( - reason=( - "Upstream ORB conservative model incorrectly squeezes length-1 batch " - "dimensions; see https://github.com/orbital-materials/orb-models/pull/158" - ), - strict=False, -)( - make_validate_model_outputs_test( - model_fixture_name="orbv3_conservative_inf_omat_model", - ) +test_validate_conservative_model_outputs = make_validate_model_outputs_test( + model_fixture_name="orbv3_conservative_inf_omat_model", ) + test_validate_direct_model_outputs = pytest.mark.xfail( reason=( "Upstream ORB direct model shows batch-dependent leakage; " diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 5c910cf4..a115ced4 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -15,9 +15,11 @@ try: + from orb_models.forcefield.inference.d3_model import D3SumModel from orb_models.forcefield.inference.orb_torchsim import OrbTorchSimModel import torch_sim as ts + from torch_sim.elastic import voigt_6_to_full_3x3_stress # Re-export with backward-compatible name class OrbModel(OrbTorchSimModel): @@ -37,6 +39,46 @@ def _normalize_charge_spin(state: "ts.SimState") -> "ts.SimState": spin=spin if spin is not None else zeros, ) + def _get_results(self, out: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Parses the results into a final output dictionary.""" + results = {} + model = ( + self.model.xc_model if isinstance(self.model, D3SumModel) else self.model + ) + heads = getattr(model, "heads", {}) + no_direct_energy_head = "energy" not in heads + no_direct_force_head = "forces" not in heads + no_direct_stress_head = "stress" not in heads + for prop in self.implemented_properties: + if prop == "free_energy" and no_direct_energy_head: + continue + if prop == "forces" and no_direct_force_head: + continue + if prop == "stress" and no_direct_stress_head: + continue + _prop = "energy" if prop == "free_energy" else prop + + results[prop] = torch.atleast_1d(out[_prop]) + + # Rename certain keys for the conservative model + if self.conservative: + if model.forces_name in results: + results["direct_forces"] = results[model.forces_name] + results["forces"] = results[model.grad_forces_name] + + if model.has_stress: + if model.stress_name in results: + results["direct_stress"] = results[model.stress_name] + results["stress"] = results[model.grad_stress_name] + + # Ensure stress has shape [-1, 3, 3] + if "stress" in results and results["stress"].shape[-1] == 6: + results["stress"] = voigt_6_to_full_3x3_stress( + torch.atleast_2d(results["stress"]) + ) + + return results + def forward(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Run forward pass, detaching outputs unless retain_graph is True.""" if args and isinstance(args[0], ts.SimState): @@ -60,9 +102,13 @@ class OrbModel(ModelInterface): It raises an ImportError if accessed. """ - def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: - """Dummy init for type checking.""" - raise err + # Capture the original ImportError in a closure-safe default so the + # fallback always re-raises the real import failure, even when callers + # pass positional/keyword args (e.g. ``OrbModel(orb_ff, adapter, ...)``) + # that would otherwise shadow an ``err`` parameter. + def __init__(self, *_args: Any, _err: ImportError = exc, **_kwargs: Any) -> None: + """Dummy init that re-raises the original import failure.""" + raise _err def forward(self, *_args: Any, **_kwargs: Any) -> Any: """Unreachable — __init__ always raises.""" From 71e373b6e56fe402ea6b6af742301f6305f4bb96 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 20 Apr 2026 15:53:32 -0400 Subject: [PATCH 3/5] fix: bump min version to allow single atoms in fairchem uma models --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9f47f097..6c99bc27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ orb = ["orb-models>=0.6.2"] sevenn = ["sevenn[torchsim]>=0.12.1"] nequip = ["nequip>=0.17.1"] nequix = ["nequix[torch-sim]>=0.4.5"] -fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"] +fairchem = ["fairchem-core>=2.19.0", "scipy<1.17.0"] docs = [ "autodoc_pydantic==2.2.0", "furo==2024.8.6", From 761cb3ade3e526942e9b57a10678891903861c10 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 20 Apr 2026 18:41:37 -0400 Subject: [PATCH 4/5] remove isolated H test as single atom Fe already tests 1 atom unit cell --- torch_sim/models/interface.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 0dcee40b..98d28433 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -384,7 +384,6 @@ def validate_model_outputs( # noqa: C901, PLR0915 and primitive BCC iron) for validation. It tests both single and multi-batch processing capabilities. """ - from ase import Atoms from ase.build import bulk, molecule def _modify(state: SimState) -> SimState: @@ -567,27 +566,3 @@ def _modify(state: SimState) -> SimState: f"forces shape incorrect for benzene: " f"{benzene_output['forces'].shape=} != (12, 3)" ) - - # Test an isolated, non-periodic single atom. This catches models whose - # downstream post-processing squeezes length-1 batch/atom dimensions - # (e.g. ``torch.atleast_1d(out.squeeze())`` which collapses ``(1, 3)`` to - # ``(3,)``). An isolated atom has no edges, which often routes models - # through a different code path than a 1-atom periodic cell. - isolated_atoms = Atoms("H", positions=[[0.0, 0.0, 0.0]], pbc=False) - isolated_state = _modify(ts.io.atoms_to_state([isolated_atoms], device, dtype)) - isolated_output = model.forward(isolated_state) - if isolated_output["energy"].shape != (1,): - raise ValueError( - f"energy shape incorrect for isolated atom: " - f"{isolated_output['energy'].shape=} != (1,)" - ) - if force_computed and isolated_output["forces"].shape != (1, 3): - raise ValueError( - f"forces shape incorrect for isolated atom: " - f"{isolated_output['forces'].shape=} != (1, 3)" - ) - if stress_computed and isolated_output["stress"].shape != (1, 3, 3): - raise ValueError( - f"stress shape incorrect for isolated atom: " - f"{isolated_output['stress'].shape=} != (1, 3, 3)" - ) From 5305e171d57d672c26e8a6e7dee2f4b85438f953 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 21 Apr 2026 09:04:06 -0400 Subject: [PATCH 5/5] doc: add comment to highlight --- torch_sim/models/orb.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index a115ced4..e8736fe0 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -58,6 +58,8 @@ def _get_results(self, out: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: continue _prop = "energy" if prop == "free_energy" else prop + # Do not squeeze the output tensors in the case of single atom cells + # TODO: remove after https://github.com/orbital-materials/orb-models/pull/158 results[prop] = torch.atleast_1d(out[_prop]) # Rename certain keys for the conservative model