diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index 91cd1ee8c..ee495aa7c 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -1,5 +1,3 @@ -# codespell-ignore: convertor - import traceback import ase.spacegroup diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 32d30f558..cdc7fb246 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -254,4 +254,4 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] stress = stress.view(-1, 3, 3) results["stress"] = stress - return results + return {k: v.detach() for k, v in results.items()} diff --git a/torch_sim/models/fairchem_legacy.py b/torch_sim/models/fairchem_legacy.py index a006cb201..05e452f1b 100644 --- a/torch_sim/models/fairchem_legacy.py +++ b/torch_sim/models/fairchem_legacy.py @@ -425,10 +425,10 @@ def forward( # noqa: C901 ): data_list.append( Data( - pos=sim_state.positions[c - n : c].clone(), - cell=sim_state.row_vector_cell[idx, None].clone(), - atomic_numbers=sim_state.atomic_numbers[c - n : c].clone(), - fixed=fixed[c - n : c].clone(), + pos=sim_state.positions[c - n : c].detach().clone(), + cell=sim_state.row_vector_cell[idx, None].detach().clone(), + atomic_numbers=sim_state.atomic_numbers[c - n : c].detach().clone(), + fixed=fixed[c - n : c].detach().clone(), natoms=n, pbc=sim_state.pbc, ) @@ -449,9 +449,9 @@ def forward( # noqa: C901 _pred = predictions[key] if key in self._reshaped_props: _pred = _pred.reshape(self._reshaped_props.get(key)).squeeze() - results[key] = _pred.detach() + results[key] = _pred results["energy"] = results["energy"].squeeze(dim=1) if results.get("stress") is not None and len(results["stress"].shape) == 2: results["stress"] = results["stress"].unsqueeze(dim=0) - return results + return {k: v.detach() for k, v in results.items()} diff --git a/torch_sim/models/graphpes_framework.py b/torch_sim/models/graphpes_framework.py index 0bca0cdd5..ab8af45d0 100644 --- a/torch_sim/models/graphpes_framework.py +++ b/torch_sim/models/graphpes_framework.py @@ -185,4 +185,5 @@ def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tens if not isinstance(cutoff, torch.Tensor): raise TypeError("GraphPES model cutoff must be a tensor") atomic_graph = state_to_atomic_graph(state, cutoff) - return self._gp_model.predict(atomic_graph, self._properties) # type: ignore[return-value] + preds = self._gp_model.predict(atomic_graph, self._properties) # ty: ignore[call-non-callable] + return {k: v.detach() for k, v in preds.items()} diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index e4ed32a35..758c50d6b 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -237,7 +237,7 @@ def forward( # noqa: C901, PLR0915 ) results: dict[str, torch.Tensor] = {} - results["energy"] = model_outputs["energy"].block().values.detach().squeeze(-1) + results["energy"] = model_outputs["energy"].block().values.squeeze(-1) # Compute forces and/or stresses if requested tensors_for_autograd = [] @@ -286,4 +286,4 @@ def forward( # noqa: C901, PLR0915 else: results["stress"] = torch.empty_like(cell) - return results + return {k: v.detach() for k, v in results.items()} diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index a02025048..9e9d333b8 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -466,4 +466,4 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] torch.atleast_2d(results["stress"]) ) - return results + return {k: v.detach() for k, v in results.items()}