From b9a2ba368c20b875a4a7495d7a6ab78e12123f1a Mon Sep 17 00:00:00 2001 From: Reilly Osadchey Brown Date: Wed, 4 Mar 2026 15:50:53 -0500 Subject: [PATCH 1/3] Fix GPU memory leak from missing .detach() in model wrappers Several model wrappers were returning tensors still attached to the computation graph, causing the entire forward-pass graph to be retained in memory across simulation steps. - fairchem: detach energy, forces, stress predictions - orb: detach prediction outputs and conservative forces/stress - metatomic: detach forces and stress (energy was already detached) - fairchem_legacy: use detach().clone() on inputs to prevent graph retention via self.data_object - graphpes_framework: detach predictions from external library --- torch_sim/models/fairchem.py | 6 +++--- torch_sim/models/fairchem_legacy.py | 8 ++++---- torch_sim/models/graphpes_framework.py | 3 ++- torch_sim/models/metatomic.py | 4 ++-- torch_sim/models/orb.py | 6 +++--- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 32d30f558..312b53d85 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -243,12 +243,12 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] # Convert predictions to torch-sim format results: dict[str, torch.Tensor] = {} - results["energy"] = predictions["energy"].to(dtype=self._dtype) - results["forces"] = predictions["forces"].to(dtype=self._dtype) + results["energy"] = predictions["energy"].detach().to(dtype=self._dtype) + results["forces"] = predictions["forces"].detach().to(dtype=self._dtype) # Handle stress if requested and available if self._compute_stress and "stress" in predictions: - stress = predictions["stress"].to(dtype=self._dtype) + stress = predictions["stress"].detach().to(dtype=self._dtype) # Ensure stress has correct shape [batch_size, 3, 3] if stress.dim() == 2 and stress.shape[0] == len(atomic_data_list): stress = stress.view(-1, 3, 3) diff --git a/torch_sim/models/fairchem_legacy.py b/torch_sim/models/fairchem_legacy.py index a006cb201..4d80059d9 100644 --- a/torch_sim/models/fairchem_legacy.py +++ b/torch_sim/models/fairchem_legacy.py @@ -425,10 +425,10 @@ def forward( # noqa: C901 ): data_list.append( Data( - pos=sim_state.positions[c - n : c].clone(), - cell=sim_state.row_vector_cell[idx, None].clone(), - atomic_numbers=sim_state.atomic_numbers[c - n : c].clone(), - fixed=fixed[c - n : c].clone(), + pos=sim_state.positions[c - n : c].detach().clone(), + cell=sim_state.row_vector_cell[idx, None].detach().clone(), + atomic_numbers=sim_state.atomic_numbers[c - n : c].detach().clone(), + fixed=fixed[c - n : c].detach().clone(), natoms=n, pbc=sim_state.pbc, ) diff --git a/torch_sim/models/graphpes_framework.py b/torch_sim/models/graphpes_framework.py index 0bca0cdd5..5553840ab 100644 --- a/torch_sim/models/graphpes_framework.py +++ b/torch_sim/models/graphpes_framework.py @@ -185,4 +185,5 @@ def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tens if not isinstance(cutoff, torch.Tensor): raise TypeError("GraphPES model cutoff must be a tensor") atomic_graph = state_to_atomic_graph(state, cutoff) - return self._gp_model.predict(atomic_graph, self._properties) # type: ignore[return-value] + preds = self._gp_model.predict(atomic_graph, self._properties) + return {k: v.detach() for k, v in preds.items()} # type: ignore[return-value] diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index e4ed32a35..8abf0533a 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -277,12 +277,12 @@ def forward( # noqa: C901, PLR0915 # Concatenate/stack forces and stresses if self._compute_forces: if len(results_by_system["forces"]) > 0: - results["forces"] = torch.cat(results_by_system["forces"]) + results["forces"] = torch.cat(results_by_system["forces"]).detach() else: results["forces"] = torch.empty_like(positions) if self._compute_stress: if len(results_by_system["stress"]) > 0: - results["stress"] = torch.stack(results_by_system["stress"]) + results["stress"] = torch.stack(results_by_system["stress"]).detach() else: results["stress"] = torch.empty_like(cell) diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index a02025048..623da5f49 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -453,11 +453,11 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] if not model_has_direct_heads and prop == "stress": continue _property = "energy" if prop == "free_energy" else prop - results[prop] = predictions[_property] + results[prop] = predictions[_property].detach() if self.conservative: - results["forces"] = results[self.model.grad_forces_name] - results["stress"] = results[self.model.grad_stress_name] + results["forces"] = results[self.model.grad_forces_name].detach() + results["stress"] = results[self.model.grad_stress_name].detach() if "stress" in results and results["stress"].shape[-1] == 6: # NOTE: atleast_2d needed because orb internally gets rid of the batch From 02aceff66807d4e2fea17af39726d08132f7b023 Mon Sep 17 00:00:00 2001 From: Reilly Osadchey Brown Date: Wed, 4 Mar 2026 16:58:46 -0500 Subject: [PATCH 2/3] Fix lint: update type suppression and remove broken codespell comment --- tests/models/test_mattersim.py | 2 -- torch_sim/models/graphpes_framework.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index 91cd1ee8c..ee495aa7c 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -1,5 +1,3 @@ -# codespell-ignore: convertor - import traceback import ase.spacegroup diff --git a/torch_sim/models/graphpes_framework.py b/torch_sim/models/graphpes_framework.py index 5553840ab..ab8af45d0 100644 --- a/torch_sim/models/graphpes_framework.py +++ b/torch_sim/models/graphpes_framework.py @@ -185,5 +185,5 @@ def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tens if not isinstance(cutoff, torch.Tensor): raise TypeError("GraphPES model cutoff must be a tensor") atomic_graph = state_to_atomic_graph(state, cutoff) - preds = self._gp_model.predict(atomic_graph, self._properties) - return {k: v.detach() for k, v in preds.items()} # type: ignore[return-value] + preds = self._gp_model.predict(atomic_graph, self._properties) # ty: ignore[call-non-callable] + return {k: v.detach() for k, v in preds.items()} From 6b83fadd474bfe6f2b597b9c0e8cda6dea616ac6 Mon Sep 17 00:00:00 2001 From: Reilly Osadchey Brown Date: Wed, 4 Mar 2026 17:22:37 -0500 Subject: [PATCH 3/3] Refactor: unify detach pattern across all model wrappers Move .detach() calls to a single return statement in each model's forward method instead of detaching inline at each assignment. --- torch_sim/models/fairchem.py | 8 ++++---- torch_sim/models/fairchem_legacy.py | 4 ++-- torch_sim/models/metatomic.py | 8 ++++---- torch_sim/models/orb.py | 8 ++++---- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 312b53d85..cdc7fb246 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -243,15 +243,15 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] # Convert predictions to torch-sim format results: dict[str, torch.Tensor] = {} - results["energy"] = predictions["energy"].detach().to(dtype=self._dtype) - results["forces"] = predictions["forces"].detach().to(dtype=self._dtype) + results["energy"] = predictions["energy"].to(dtype=self._dtype) + results["forces"] = predictions["forces"].to(dtype=self._dtype) # Handle stress if requested and available if self._compute_stress and "stress" in predictions: - stress = predictions["stress"].detach().to(dtype=self._dtype) + stress = predictions["stress"].to(dtype=self._dtype) # Ensure stress has correct shape [batch_size, 3, 3] if stress.dim() == 2 and stress.shape[0] == len(atomic_data_list): stress = stress.view(-1, 3, 3) results["stress"] = stress - return results + return {k: v.detach() for k, v in results.items()} diff --git a/torch_sim/models/fairchem_legacy.py b/torch_sim/models/fairchem_legacy.py index 4d80059d9..05e452f1b 100644 --- a/torch_sim/models/fairchem_legacy.py +++ b/torch_sim/models/fairchem_legacy.py @@ -449,9 +449,9 @@ def forward( # noqa: C901 _pred = predictions[key] if key in self._reshaped_props: _pred = _pred.reshape(self._reshaped_props.get(key)).squeeze() - results[key] = _pred.detach() + results[key] = _pred results["energy"] = results["energy"].squeeze(dim=1) if results.get("stress") is not None and len(results["stress"].shape) == 2: results["stress"] = results["stress"].unsqueeze(dim=0) - return results + return {k: v.detach() for k, v in results.items()} diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index 8abf0533a..758c50d6b 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -237,7 +237,7 @@ def forward( # noqa: C901, PLR0915 ) results: dict[str, torch.Tensor] = {} - results["energy"] = model_outputs["energy"].block().values.detach().squeeze(-1) + results["energy"] = model_outputs["energy"].block().values.squeeze(-1) # Compute forces and/or stresses if requested tensors_for_autograd = [] @@ -277,13 +277,13 @@ def forward( # noqa: C901, PLR0915 # Concatenate/stack forces and stresses if self._compute_forces: if len(results_by_system["forces"]) > 0: - results["forces"] = torch.cat(results_by_system["forces"]).detach() + results["forces"] = torch.cat(results_by_system["forces"]) else: results["forces"] = torch.empty_like(positions) if self._compute_stress: if len(results_by_system["stress"]) > 0: - results["stress"] = torch.stack(results_by_system["stress"]).detach() + results["stress"] = torch.stack(results_by_system["stress"]) else: results["stress"] = torch.empty_like(cell) - return results + return {k: v.detach() for k, v in results.items()} diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 623da5f49..9e9d333b8 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -453,11 +453,11 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] if not model_has_direct_heads and prop == "stress": continue _property = "energy" if prop == "free_energy" else prop - results[prop] = predictions[_property].detach() + results[prop] = predictions[_property] if self.conservative: - results["forces"] = results[self.model.grad_forces_name].detach() - results["stress"] = results[self.model.grad_stress_name].detach() + results["forces"] = results[self.model.grad_forces_name] + results["stress"] = results[self.model.grad_stress_name] if "stress" in results and results["stress"].shape[-1] == 6: # NOTE: atleast_2d needed because orb internally gets rid of the batch @@ -466,4 +466,4 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] torch.atleast_2d(results["stress"]) ) - return results + return {k: v.detach() for k, v in results.items()}