Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ coverage.xml

# env
uv.lock

.vscode/
21 changes: 13 additions & 8 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import importlib
import warnings
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal, Self
from typing import TYPE_CHECKING, Literal, cast

import torch

Expand Down Expand Up @@ -152,7 +152,7 @@
return torch.unique(self.batch).shape[0]

@property
def volume(self) -> torch.Tensor:
def volume(self) -> torch.Tensor | None:
"""Volume of the system."""
return torch.det(self.cell) if self.pbc else None

Expand Down Expand Up @@ -184,7 +184,7 @@
"""
self.cell = value.transpose(-2, -1)

def clone(self) -> Self:
def clone(self) -> "SimState":
Copy link
Copy Markdown
Member

@CompRhys CompRhys May 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think typing as Self is correct per the pep? https://peps.python.org/pep-0673/

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree Self is better here (as in higher fidelity). it will correctly resolve to child classes of SimState like DeformState

https://github.com/Radical-AI/torch-sim/blob/8211e08d36b4ae73407564c78a1503f07f27b26b/tests/test_state.py#L489

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting. pylance gave me errors for it though. maybe ty might not complain

"""Create a deep copy of the SimState.

Creates a new SimState object with identical but independent tensors,
Expand Down Expand Up @@ -226,7 +226,7 @@
"""
return ts.io.state_to_phonopy(self)

def split(self) -> list[Self]:
def split(self) -> list["SimState"]:
"""Split the SimState into a list of single-batch SimStates.

Divides the current state into separate states, each containing a single batch,
Expand All @@ -237,7 +237,9 @@
"""
return _split_state(self)

def pop(self, batch_indices: int | list[int] | slice | torch.Tensor) -> list[Self]:
def pop(
self, batch_indices: int | list[int] | slice | torch.Tensor
) -> list["SimState"]:
"""Pop off states with the specified batch indices.

This method modifies the original state object by removing the specified
Expand Down Expand Up @@ -268,7 +270,7 @@

def to(
self, device: torch.device | None = None, dtype: torch.dtype | None = None
) -> Self:
) -> "SimState":
"""Convert the SimState to a new device and/or data type.

Args:
Expand All @@ -282,7 +284,9 @@
"""
return state_to_device(self, device, dtype)

def __getitem__(self, batch_indices: int | list[int] | slice | torch.Tensor) -> Self:
def __getitem__(
self, batch_indices: int | list[int] | slice | torch.Tensor
) -> "SimState":
"""Enable standard Python indexing syntax for slicing batches.

Args:
Expand Down Expand Up @@ -387,7 +391,7 @@

def state_to_device(
state: SimState, device: torch.device | None = None, dtype: torch.dtype | None = None
) -> Self:
) -> SimState:
"""Convert the SimState to a new device and dtype.

Creates a new SimState with all tensors moved to the specified device and
Expand Down Expand Up @@ -864,6 +868,7 @@
return state_to_device(system, device, dtype)

if isinstance(system, list) and all(isinstance(s, SimState) for s in system):
system = cast("list[SimState]", system)

Check warning on line 871 in torch_sim/state.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/state.py#L871

Added line #L871 was not covered by tests
if not all(state.n_batches == 1 for state in system):
raise ValueError(
"When providing a list of states, to the initialize_state function, "
Expand Down
Loading