2 changes: 0 additions & 2 deletions tests/models/test_mattersim.py
@@ -1,5 +1,3 @@
-# codespell-ignore: convertor
-
import traceback

import ase.spacegroup
2 changes: 1 addition & 1 deletion torch_sim/models/fairchem.py
@@ -254,4 +254,4 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
stress = stress.view(-1, 3, 3)
results["stress"] = stress

-return results
+return {k: v.detach() for k, v in results.items()}
12 changes: 6 additions & 6 deletions torch_sim/models/fairchem_legacy.py
@@ -425,10 +425,10 @@ def forward( # noqa: C901
):
data_list.append(
Data(
-pos=sim_state.positions[c - n : c].clone(),
-cell=sim_state.row_vector_cell[idx, None].clone(),
-atomic_numbers=sim_state.atomic_numbers[c - n : c].clone(),
-fixed=fixed[c - n : c].clone(),
+pos=sim_state.positions[c - n : c].detach().clone(),
+cell=sim_state.row_vector_cell[idx, None].detach().clone(),
+atomic_numbers=sim_state.atomic_numbers[c - n : c].detach().clone(),
+fixed=fixed[c - n : c].detach().clone(),
natoms=n,
pbc=sim_state.pbc,
)
@@ -449,9 +449,9 @@ def forward( # noqa: C901
_pred = predictions[key]
if key in self._reshaped_props:
_pred = _pred.reshape(self._reshaped_props.get(key)).squeeze()
-results[key] = _pred.detach()
+results[key] = _pred

results["energy"] = results["energy"].squeeze(dim=1)
if results.get("stress") is not None and len(results["stress"].shape) == 2:
results["stress"] = results["stress"].unsqueeze(dim=0)
-return results
+return {k: v.detach() for k, v in results.items()}
3 changes: 2 additions & 1 deletion torch_sim/models/graphpes_framework.py
@@ -185,4 +185,5 @@ def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tensor]
if not isinstance(cutoff, torch.Tensor):
raise TypeError("GraphPES model cutoff must be a tensor")
atomic_graph = state_to_atomic_graph(state, cutoff)
-return self._gp_model.predict(atomic_graph, self._properties) # type: ignore[return-value]
+preds = self._gp_model.predict(atomic_graph, self._properties) # ty: ignore[call-non-callable]
+return {k: v.detach() for k, v in preds.items()}
Member:
It feels cleaner to me to detach everything like this, all in one line at the end. Could you update all the models to follow this pattern?
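A minimal, self-contained sketch of the requested pattern (the toy `forward_sketch` function and its quadratic energy are hypothetical; only the final dict-comprehension return mirrors the actual diffs):

```python
import torch


def forward_sketch() -> dict[str, torch.Tensor]:
    """Toy forward pass: compute outputs with autograd, detach once at the end."""
    positions = torch.randn(4, 3, requires_grad=True)
    energy = positions.pow(2).sum()
    # Forces come from autograd, so they carry graph history until detached.
    forces = -torch.autograd.grad(energy, positions, create_graph=True)[0]

    results = {"energy": energy, "forces": forces}
    # Detach every output in one place rather than sprinkling .detach() above.
    return {k: v.detach() for k, v in results.items()}


out = forward_sketch()
print({k: v.requires_grad for k, v in out.items()})  # {'energy': False, 'forces': False}
```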

Member:
If it's already detached, then the op is idempotent, so no harm.
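A quick standalone illustration of that point (not from the PR): calling `.detach()` on an already-detached tensor just returns another gradient-free tensor with the same values.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).detach()   # first detach: y is cut off from the autograd graph
z = y.detach()         # detaching again is harmless
print(y.requires_grad, z.requires_grad)  # False False
print(torch.equal(y, z))                 # True: same values either way
```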

Contributor (Author):
Good call.

4 changes: 2 additions & 2 deletions torch_sim/models/metatomic.py
@@ -237,7 +237,7 @@ def forward( # noqa: C901, PLR0915
)

results: dict[str, torch.Tensor] = {}
-results["energy"] = model_outputs["energy"].block().values.detach().squeeze(-1)
+results["energy"] = model_outputs["energy"].block().values.squeeze(-1)

# Compute forces and/or stresses if requested
tensors_for_autograd = []
@@ -286,4 +286,4 @@ def forward( # noqa: C901, PLR0915
else:
results["stress"] = torch.empty_like(cell)

-return results
+return {k: v.detach() for k, v in results.items()}
2 changes: 1 addition & 1 deletion torch_sim/models/orb.py
@@ -466,4 +466,4 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
torch.atleast_2d(results["stress"])
)

-return results
+return {k: v.detach() for k, v in results.items()}