Merged
13 changes: 13 additions & 0 deletions tests/test_autobatching.py
@@ -107,6 +107,19 @@ def test_calculate_scaling_metric(si_sim_state: ts.SimState) -> None:
         calculate_memory_scaler(si_sim_state, "invalid_metric")


+def test_calculate_scaling_metric_non_periodic(benzene_sim_state: ts.SimState) -> None:
+    """Test calculation of scaling metrics for a non-periodic state."""
+    # Test that the n_atoms metric works for a non-periodic state
+    n_atoms_metric = calculate_memory_scaler(benzene_sim_state, "n_atoms")
+    assert n_atoms_metric == benzene_sim_state.n_atoms
+
+    # Test that the n_atoms_x_density metric works for non-periodic systems
+    n_atoms_x_density_metric = calculate_memory_scaler(
+        benzene_sim_state, "n_atoms_x_density"
+    )
+    assert n_atoms_x_density_metric > 0
+
+
 def test_split_state(si_double_sim_state: ts.SimState) -> None:
     """Test splitting a batched state into individual states."""
     split_states = si_double_sim_state.split()
10 changes: 9 additions & 1 deletion torch_sim/autobatching.py
@@ -363,7 +363,15 @@ def calculate_memory_scaler(
     if memory_scales_with == "n_atoms":
         return state.n_atoms
     if memory_scales_with == "n_atoms_x_density":
-        volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000
+        if all(state.pbc):
+            volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000
Collaborator:

ik it's not added in this PR, but this 1000 feels like a magic number to me. we should at minimum have a comment explaining it - or make it configurable at most.

@orionarcher (Collaborator, Author), Jan 21, 2026:

ill add a comment, once I remember what the 1000 is for... I think it's an A^3 -> nm^3 conversion
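The recalled conversion checks out: 1 nm = 10 A, so 1 nm^3 = 10^3 = 1000 A^3, which is where the divisor comes from. A quick sanity check with a made-up example volume:

```python
# 1 nm = 10 angstrom, so 1 nm^3 = 10**3 = 1000 angstrom^3
ANGSTROM_PER_NM = 10.0

volume_A3 = 4000.0  # hypothetical example volume in angstrom^3
volume_nm3 = volume_A3 / ANGSTROM_PER_NM**3  # same as dividing by 1000
print(volume_nm3)  # 4.0
```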

+        else:
+            bbox = state.positions.max(dim=0).values - state.positions.min(dim=0).values
Member:

would using this as the general case fail? for example if we had a 2d system or surface with a lot of vacuum the cell is not useful for determining the density?

Collaborator (Author):

Good point. Neither is perfect in those cases but I agree a bounding box is better than cell. I am happy to make it the general case. Does a clamp value of 2 A make sense to you? Needed for flat molecules like benzene (though benzene isn't actually flat).

Member:

Clamp is harder to reason about because there are some systems, say metallic lithium, where the unit cell is ~1 A.

Member:

I guess the limiting case is something flat like anthracene in a vacuum. If you have two molecules far enough apart then this heuristic would fail.

n_atoms * density is just to say it scales as the number of nearest neighbors? why not call a neighbor-list algorithm to estimate and go based on that?

Collaborator (Author):

I think that's right, but I am not quite intuiting how that would resolve the 2D vs 3D tradeoff. What would that look like in practice: call the neighbor list with a 5-6 A cutoff and then calculate number_density from that?

A couple of drawbacks:

  • a switch wouldn't be backwards compatible; any users' saved n_atoms_x_density metrics would need to be recomputed (though that shouldn't stop us)
  • it's more expensive and needs to be executed on every system
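For illustration, a minimal sketch of what such a neighbor-list-based scaler might look like (the function name, the brute-force O(N^2) distance computation, and the 5 A cutoff are all hypothetical; a real version would reuse torch_sim's existing neighbor-list machinery):

```python
import numpy as np


def neighbor_based_scaler(positions: np.ndarray, cutoff: float = 5.0) -> float:
    """Hypothetical sketch: n_atoms * mean neighbor count within `cutoff`.

    The mean neighbor count is proportional to the local number density,
    so vacuum regions in the cell do not inflate the estimate.
    """
    # pairwise distances via broadcasting (O(N^2); fine for a sketch)
    diffs = positions[:, None, :] - positions[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    # count neighbors within the cutoff, excluding self-pairs on the diagonal
    within = (dists < cutoff) & (dists > 0)
    mean_neighbors = within.sum(axis=1).mean()
    return positions.shape[0] * mean_neighbors


# two identical small clusters 100 A apart: the vacuum between them
# does not change the per-atom neighbor count, only n_atoms doubles
cluster = np.random.default_rng(0).random((8, 3)) * 2.0
far_pair = np.vstack([cluster, cluster + 100.0])
```

Note this addresses the 2D-vs-3D concern above (the estimate never sees the empty cell volume), at the cost of the extra computation the drawbacks list mentions.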

Collaborator:

I think we should probably use a bounding box algorithm to start (it's very fast & simple logic).

If we do want to do something more complicated (using neighbor lists), we'd either have to calculate it twice or refactor a bit of code to get it to work well (since I think we determine the batches before we calculate the neighbors). If we do want to support neighbor-list-based memory scaling, we'd probably make a new memory_scales_with kind.

Collaborator (Author):

I agree, I'll just do BB and add 2 A in every non-periodic direction to account for 2D systems and slabs.

+            # add 2 A in non-periodic directions to account for 2D systems and slabs
+            for i, periodic in enumerate(state.pbc):
+                if not periodic:
+                    bbox[i] += 2.0
+            volume = bbox.prod() / 1000  # convert A^3 to nm^3
+        number_density = state.n_atoms / volume.item()
+        return state.n_atoms * number_density
     raise ValueError(
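The merged branch can be sketched standalone (a pure-numpy re-implementation for illustration only; `bbox_memory_scaler` is a hypothetical name, and the real `calculate_memory_scaler` operates on torch tensors inside torch_sim):

```python
import numpy as np


def bbox_memory_scaler(positions: np.ndarray, pbc: tuple[bool, bool, bool]) -> float:
    """Sketch of the non-periodic branch: n_atoms * number density,
    with the volume taken from a padded bounding box."""
    bbox = positions.max(axis=0) - positions.min(axis=0)
    # pad 2 A in non-periodic directions to handle 2D systems and slabs
    for i, periodic in enumerate(pbc):
        if not periodic:
            bbox[i] += 2.0
    volume = bbox.prod() / 1000  # A^3 -> nm^3
    n_atoms = positions.shape[0]
    return n_atoms * (n_atoms / volume)


# e.g. a flat 10 x 10 A sheet of 100 atoms, fully non-periodic: the
# zero-thickness direction is padded to 2 A, so the volume stays finite
xs, ys = np.meshgrid(np.linspace(0, 10, 10), np.linspace(0, 10, 10))
flat = np.column_stack([xs.ravel(), ys.ravel(), np.zeros(100)])
```

Without the padding, the flat sheet would have zero bounding-box volume and the density would diverge, which is exactly the benzene/slab case discussed above.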