Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ orb = ["orb-models>=0.6.2"]
sevenn = ["sevenn[torchsim]>=0.12.1"]
nequip = ["nequip>=0.17.1"]
nequix = ["nequix[torch-sim]>=0.4.5"]
fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"]
fairchem = ["fairchem-core>=2.19.0", "scipy<1.17.0"]
docs = [
"autodoc_pydantic==2.2.0",
"furo==2024.8.6",
Expand Down
13 changes: 3 additions & 10 deletions tests/models/test_orb.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,11 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator:
energy_atol=5e-5,
)

test_validate_conservative_model_outputs = pytest.mark.xfail(
reason=(
"Upstream ORB conservative model incorrectly squeezes length-1 batch "
"dimensions; see https://github.com/orbital-materials/orb-models/pull/158"
),
strict=False,
)(
make_validate_model_outputs_test(
model_fixture_name="orbv3_conservative_inf_omat_model",
)
test_validate_conservative_model_outputs = make_validate_model_outputs_test(
model_fixture_name="orbv3_conservative_inf_omat_model",
)


test_validate_direct_model_outputs = pytest.mark.xfail(
reason=(
"Upstream ORB direct model shows batch-dependent leakage; "
Expand Down
54 changes: 51 additions & 3 deletions torch_sim/models/orb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@


try:
from orb_models.forcefield.inference.d3_model import D3SumModel
from orb_models.forcefield.inference.orb_torchsim import OrbTorchSimModel

import torch_sim as ts
from torch_sim.elastic import voigt_6_to_full_3x3_stress

# Re-export with backward-compatible name
class OrbModel(OrbTorchSimModel):
Expand All @@ -37,6 +39,48 @@ def _normalize_charge_spin(state: "ts.SimState") -> "ts.SimState":
spin=spin if spin is not None else zeros,
)

def _get_results(self, out: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Parses the results into a final output dictionary."""
results = {}
model = (
self.model.xc_model if isinstance(self.model, D3SumModel) else self.model
)
heads = getattr(model, "heads", {})
no_direct_energy_head = "energy" not in heads
no_direct_force_head = "forces" not in heads
no_direct_stress_head = "stress" not in heads
for prop in self.implemented_properties:
if prop == "free_energy" and no_direct_energy_head:
continue
if prop == "forces" and no_direct_force_head:
continue
if prop == "stress" and no_direct_stress_head:
continue
_prop = "energy" if prop == "free_energy" else prop

# Do not squeeze the output tensors in the case of single atom cells
# TODO: remove after https://github.com/orbital-materials/orb-models/pull/158
results[prop] = torch.atleast_1d(out[_prop])

# Rename certain keys for the conservative model
if self.conservative:
if model.forces_name in results:
results["direct_forces"] = results[model.forces_name]
results["forces"] = results[model.grad_forces_name]

if model.has_stress:
if model.stress_name in results:
results["direct_stress"] = results[model.stress_name]
results["stress"] = results[model.grad_stress_name]

# Ensure stress has shape [-1, 3, 3]
if "stress" in results and results["stress"].shape[-1] == 6:
results["stress"] = voigt_6_to_full_3x3_stress(
torch.atleast_2d(results["stress"])
)

return results

def forward(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
"""Run forward pass, detaching outputs unless retain_graph is True."""
if args and isinstance(args[0], ts.SimState):
Expand All @@ -60,9 +104,13 @@ class OrbModel(ModelInterface):
It raises an ImportError if accessed.
"""

def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
"""Dummy init for type checking."""
raise err
# Capture the original ImportError in a closure-safe default so the
# fallback always re-raises the real import failure, even when callers
# pass positional/keyword args (e.g. ``OrbModel(orb_ff, adapter, ...)``)
# that would otherwise shadow an ``err`` parameter.
def __init__(self, *_args: Any, _err: ImportError = exc, **_kwargs: Any) -> None:
"""Dummy init that re-raises the original import failure."""
raise _err

def forward(self, *_args: Any, **_kwargs: Any) -> Any:
"""Unreachable — __init__ always raises."""
Expand Down
Loading