Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/metatomic_ase/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def create_version_number(version):
# No dependency on ASE itself until this package is no longer a direct
# dependency of metatomic-torch
# "ase >=3.22.0",
"vesin >=0.5.2,<0.6",
"vesin >=0.5.5,<0.6",
]

# when packaging a sdist for release, we should never use local dependencies
Expand Down
25 changes: 9 additions & 16 deletions python/metatomic_ase/src/metatomic_ase/_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
pick_output,
)

from ._neighbors import _compute_requested_neighbors
from ._neighbors import AllNeighborsCalculator


import ase # isort: skip
Expand Down Expand Up @@ -317,6 +317,11 @@ def __init__(
"be positive"
)

self._nl_calculators = AllNeighborsCalculator(
requested_options=self._model.requested_neighbor_lists(),
check_consistency=check_consistency,
)

# We do our own check to verify if a property is implemented in `calculate()`,
# so we pretend to be able to compute all properties ASE knows about.
self.implemented_properties = ALL_ASE_PROPERTIES
Expand Down Expand Up @@ -398,11 +403,7 @@ def run_model(
systems.append(system)

# Compute the neighbors lists requested by the model
input_systems = _compute_requested_neighbors(
systems=systems,
requested_options=self._model.requested_neighbor_lists(),
check_consistency=self.parameters["check_consistency"],
)
input_systems = self._nl_calculators.compute(systems=systems)

available_outputs = self._model.capabilities().outputs
for key in outputs:
Expand Down Expand Up @@ -538,11 +539,7 @@ def calculate(
with record_function("MetatomicCalculator::compute_neighbors"):
# convert from ase.Atoms to metatomic.torch.System
system = System(types, positions, cell, pbc)
input_system = _compute_requested_neighbors(
systems=[system],
requested_options=self._model.requested_neighbor_lists(),
check_consistency=self.parameters["check_consistency"],
)[0]
input_system = self._nl_calculators.compute(systems=[system])[0]

with record_function("MetatomicCalculator::get_model_inputs"):
for name, option in self._model.requested_inputs().items():
Expand Down Expand Up @@ -721,11 +718,7 @@ def compute_energy(
systems.append(system)

# Compute the neighbors lists requested by the model
input_systems = _compute_requested_neighbors(
systems=systems,
requested_options=self._model.requested_neighbor_lists(),
check_consistency=self.parameters["check_consistency"],
)
input_systems = self._nl_calculators.compute(systems=systems)

predictions = self._model(
systems=input_systems,
Expand Down
107 changes: 60 additions & 47 deletions python/metatomic_ase/src/metatomic_ase/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,57 +31,68 @@
HAS_NVALCHEMIOPS = False


def _compute_requested_neighbors(
systems: List[System],
requested_options: List[NeighborListOptions],
check_consistency=False,
) -> List[System]:
"""
Compute all neighbor lists requested by ``model`` and store them inside the systems.
"""
can_use_nvalchemi = HAS_NVALCHEMIOPS and all(
system.device.type == "cuda" for system in systems
)

if can_use_nvalchemi:
full_nl_options = []
half_nl_options = []
for options in requested_options:
if options.full_list:
full_nl_options.append(options)
else:
half_nl_options.append(options)

# Do the full neighbor lists with nvalchemi, and the rest with vesin
systems = _compute_requested_neighbors_nvalchemi(
systems=systems,
requested_options=full_nl_options,
)
systems = _compute_requested_neighbors_vesin(
systems=systems,
requested_options=half_nl_options,
check_consistency=check_consistency,
class AllNeighborsCalculator:
def __init__(
self,
requested_options: List[NeighborListOptions],
check_consistency=False,
):
self.check_consistency = check_consistency
self._full_nl_options = [
options for options in requested_options if options.full_list
]
self._full_vesin_calculators = [
vesin.metatomic.NeighborList(
options=options,
length_unit="angstrom",
check_consistency=check_consistency,
)
for options in requested_options
if options.full_list
]
self._half_vesin_calculators = [
vesin.metatomic.NeighborList(
options=options,
length_unit="angstrom",
check_consistency=check_consistency,
)
for options in requested_options
if not options.full_list
]

def compute(self, systems: List[System]) -> List[System]:
assert isinstance(systems, list)
assert isinstance(systems[0], torch.ScriptObject)

can_use_nvalchemi = HAS_NVALCHEMIOPS and all(
system.device.type == "cuda" for system in systems
)
else:

if can_use_nvalchemi:
# Do the full neighbor lists with nvalchemi
systems = _compute_requested_neighbors_nvalchemi(
systems=systems,
requested_options=self._full_nl_options,
)
else:
systems = _compute_requested_neighbors_vesin(
systems=systems,
calculators=self._full_vesin_calculators,
)

# always compute the half neighbor lists with vesin
systems = _compute_requested_neighbors_vesin(
systems=systems,
requested_options=requested_options,
check_consistency=check_consistency,
calculators=self._half_vesin_calculators,
)

return systems
return systems


def _compute_requested_neighbors_vesin(
systems: List[System],
requested_options: List[NeighborListOptions],
check_consistency=False,
calculators: List[vesin.metatomic.NeighborList],
) -> List[System]:
"""
Compute all neighbor lists requested by ``model`` and store them inside the systems,
using vesin.
"""

system_devices = []
moved_systems = []
for system in systems:
Expand All @@ -91,12 +102,13 @@ def _compute_requested_neighbors_vesin(
else:
moved_systems.append(system)

vesin.metatomic.compute_requested_neighbors_from_options(
systems=moved_systems,
system_length_unit="angstrom",
options=requested_options,
check_consistency=check_consistency,
)
for calculator in calculators:
calculator.add_neighbor_list(
systems=moved_systems,
# if we have more than one system, we can no keep the data as a reference
# to memory allocated in the calculator and we need to make a copy
copy=len(systems) > 1,
)

systems = []
for system, device in zip(moved_systems, system_devices, strict=True):
Expand Down Expand Up @@ -142,6 +154,7 @@ def _compute_requested_neighbors_nvalchemi(systems, requested_options):
"cell_shift_c",
],
values=torch.hstack([P, S]),
assume_unique=True,
),
components=[
Labels("xyz", torch.tensor([[0], [1], [2]], device=system.device))
Expand Down
7 changes: 7 additions & 0 deletions python/metatomic_torchsim/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/metatensor/metatomic/)

## [Version 0.1.2](https://github.com/metatensor/metatomic/releases/tag/metatomic-torchsim-v0.1.2) - 2026-04-22

### Changed

- Removed the upper-version pin on `torch-sim-atomistic` to make updating the
code in there that re-exports this package easier.

## [Version 0.1.1](https://github.com/metatensor/metatomic/releases/tag/metatomic-torchsim-v0.1.1) - 2026-04-01

### Fixed
Expand Down
14 changes: 7 additions & 7 deletions python/metatomic_torchsim/metatomic_torchsim/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
pick_output,
)

from ._neighbors import _compute_requested_neighbors
from ._neighbors import AllNeighborsCalculator


try:
Expand Down Expand Up @@ -249,7 +249,6 @@ def __init__(
"be positive"
)

self._requested_neighbor_lists = self._model.requested_neighbor_lists()
self._requested_inputs = self._model.requested_inputs()
if len(self._requested_inputs) != 0:
raise ValueError(
Expand Down Expand Up @@ -283,6 +282,11 @@ def __init__(
outputs=run_outputs,
)

self._nl_calculators = AllNeighborsCalculator(
requested_options=self._model.requested_neighbor_lists(),
check_consistency=check_consistency,
)

self.additional_outputs: Dict[str, TensorMap] = {}
"""
Additional outputs computed by :py:meth:`forward` are stored here.
Expand Down Expand Up @@ -355,11 +359,7 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]:
)

# Compute neighbor lists
systems = _compute_requested_neighbors(
systems=systems,
requested_options=self._requested_neighbor_lists,
check_consistency=self._check_consistency,
)
systems = self._nl_calculators.compute(systems=systems)

# Run the model (evaluation options precomputed in __init__)
model_outputs = self._model(
Expand Down
Loading
Loading