From ea49d23d192453dacabd2f25f3f850177fe8350b Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 11 May 2025 16:34:00 -0700 Subject: [PATCH] Fix types so pylance's type checker doesn't complain cast list of simstates more rename Self to SimState linter fixes --- .gitignore | 2 ++ torch_sim/state.py | 21 +++++++++++++-------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 29646ebf1..b294c6580 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ coverage.xml # env uv.lock + +.vscode/ diff --git a/torch_sim/state.py b/torch_sim/state.py index e5997cf4f..d5a223811 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -8,7 +8,7 @@ import importlib import warnings from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Literal, Self +from typing import TYPE_CHECKING, Literal, cast import torch @@ -152,7 +152,7 @@ def n_batches(self) -> int: return torch.unique(self.batch).shape[0] @property - def volume(self) -> torch.Tensor: + def volume(self) -> torch.Tensor | None: """Volume of the system.""" return torch.det(self.cell) if self.pbc else None @@ -184,7 +184,7 @@ def row_vector_cell(self, value: torch.Tensor) -> None: """ self.cell = value.transpose(-2, -1) - def clone(self) -> Self: + def clone(self) -> "SimState": """Create a deep copy of the SimState. Creates a new SimState object with identical but independent tensors, @@ -226,7 +226,7 @@ def to_phonopy(self) -> list["PhonopyAtoms"]: """ return ts.io.state_to_phonopy(self) - def split(self) -> list[Self]: + def split(self) -> list["SimState"]: """Split the SimState into a list of single-batch SimStates. Divides the current state into separate states, each containing a single batch, @@ -237,7 +237,9 @@ def split(self) -> list[Self]: """ return _split_state(self) - def pop(self, batch_indices: int | list[int] | slice | torch.Tensor) -> list[Self]: + def pop( + self, batch_indices: int | list[int] | slice | torch.Tensor + ) -> list["SimState"]: """Pop off states with the specified batch indices. This method modifies the original state object by removing the specified @@ -268,7 +270,7 @@ def pop(self, batch_indices: int | list[int] | slice | torch.Tensor) -> list[Sel def to( self, device: torch.device | None = None, dtype: torch.dtype | None = None - ) -> Self: + ) -> "SimState": """Convert the SimState to a new device and/or data type. Args: @@ -282,7 +284,9 @@ def to( """ return state_to_device(self, device, dtype) - def __getitem__(self, batch_indices: int | list[int] | slice | torch.Tensor) -> Self: + def __getitem__( + self, batch_indices: int | list[int] | slice | torch.Tensor + ) -> "SimState": """Enable standard Python indexing syntax for slicing batches. Args: @@ -387,7 +391,7 @@ def _normalize_batch_indices( def state_to_device( state: SimState, device: torch.device | None = None, dtype: torch.dtype | None = None -) -> Self: +) -> SimState: """Convert the SimState to a new device and dtype. Creates a new SimState with all tensors moved to the specified device and @@ -864,6 +868,7 @@ def initialize_state( return state_to_device(system, device, dtype) if isinstance(system, list) and all(isinstance(s, SimState) for s in system): + system = cast("list[SimState]", system) if not all(state.n_batches == 1 for state in system): raise ValueError( "When providing a list of states, to the initialize_state function, "