diff --git a/pyproject.toml b/pyproject.toml index 9f47f097..6c99bc27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ orb = ["orb-models>=0.6.2"] sevenn = ["sevenn[torchsim]>=0.12.1"] nequip = ["nequip>=0.17.1"] nequix = ["nequix[torch-sim]>=0.4.5"] -fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"] +fairchem = ["fairchem-core>=2.19.0", "scipy<1.17.0"] docs = [ "autodoc_pydantic==2.2.0", "furo==2024.8.6", diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index de23cd17..1e7483cc 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -72,18 +72,11 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator: energy_atol=5e-5, ) -test_validate_conservative_model_outputs = pytest.mark.xfail( - reason=( - "Upstream ORB conservative model incorrectly squeezes length-1 batch " - "dimensions; see https://github.com/orbital-materials/orb-models/pull/158" - ), - strict=False, -)( - make_validate_model_outputs_test( - model_fixture_name="orbv3_conservative_inf_omat_model", - ) +test_validate_conservative_model_outputs = make_validate_model_outputs_test( + model_fixture_name="orbv3_conservative_inf_omat_model", ) + test_validate_direct_model_outputs = pytest.mark.xfail( reason=( "Upstream ORB direct model shows batch-dependent leakage; " diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 5c910cf4..e8736fe0 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -15,9 +15,11 @@ try: + from orb_models.forcefield.inference.d3_model import D3SumModel from orb_models.forcefield.inference.orb_torchsim import OrbTorchSimModel import torch_sim as ts + from torch_sim.elastic import voigt_6_to_full_3x3_stress # Re-export with backward-compatible name class OrbModel(OrbTorchSimModel): @@ -37,6 +39,48 @@ def _normalize_charge_spin(state: "ts.SimState") -> "ts.SimState": spin=spin if spin is not None else zeros, ) + def _get_results(self, out: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Parses the results into a final output dictionary.""" + results = {} + model = ( + self.model.xc_model if isinstance(self.model, D3SumModel) else self.model + ) + heads = getattr(model, "heads", {}) + no_direct_energy_head = "energy" not in heads + no_direct_force_head = "forces" not in heads + no_direct_stress_head = "stress" not in heads + for prop in self.implemented_properties: + if prop == "free_energy" and no_direct_energy_head: + continue + if prop == "forces" and no_direct_force_head: + continue + if prop == "stress" and no_direct_stress_head: + continue + _prop = "energy" if prop == "free_energy" else prop + + # Do not squeeze the output tensors in the case of single atom cells + # TODO: remove after https://github.com/orbital-materials/orb-models/pull/158 + results[prop] = torch.atleast_1d(out[_prop]) + + # Rename certain keys for the conservative model + if self.conservative: + if model.forces_name in results: + results["direct_forces"] = results[model.forces_name] + results["forces"] = results[model.grad_forces_name] + + if model.has_stress: + if model.stress_name in results: + results["direct_stress"] = results[model.stress_name] + results["stress"] = results[model.grad_stress_name] + + # Ensure stress has shape [-1, 3, 3] + if "stress" in results and results["stress"].shape[-1] == 6: + results["stress"] = voigt_6_to_full_3x3_stress( + torch.atleast_2d(results["stress"]) + ) + + return results + def forward(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Run forward pass, detaching outputs unless retain_graph is True.""" if args and isinstance(args[0], ts.SimState): @@ -60,9 +104,13 @@ class OrbModel(ModelInterface): It raises an ImportError if accessed. """ - def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: - """Dummy init for type checking.""" - raise err + # Capture the original ImportError in a closure-safe default so the + # fallback always re-raises the real import failure, even when callers + # pass positional/keyword args (e.g. ``OrbModel(orb_ff, adapter, ...)``) + # that would otherwise shadow an ``err`` parameter. + def __init__(self, *_args: Any, _err: ImportError = exc, **_kwargs: Any) -> None: + """Dummy init that re-raises the original import failure.""" + raise _err def forward(self, *_args: Any, **_kwargs: Any) -> Any: """Unreachable — __init__ always raises."""