From f65c434975eb488f46bfa7128ddaac002fee81c9 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 20 Jan 2026 11:34:17 -0500 Subject: [PATCH 1/4] Fix memory scaling calculation for non-periodic boundary conditions --- torch_sim/autobatching.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index cd6fca019..3189182d1 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -363,7 +363,11 @@ def calculate_memory_scaler( if memory_scales_with == "n_atoms": return state.n_atoms if memory_scales_with == "n_atoms_x_density": - volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 + if all(state.pbc): + volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 + else: + bbox = state.positions.max(dim=0).values - state.positions.min(dim=0).values + volume = bbox.prod() / 1000 number_density = state.n_atoms / volume.item() return state.n_atoms * number_density raise ValueError( From 0585acfa3541c373ae1785cedadb9b1abb608a89 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 20 Jan 2026 11:40:39 -0500 Subject: [PATCH 2/4] add test that is fixed by change --- tests/test_autobatching.py | 13 +++++++++++++ torch_sim/autobatching.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 15640e599..7ec870d3c 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -107,6 +107,19 @@ def test_calculate_scaling_metric(si_sim_state: ts.SimState) -> None: calculate_memory_scaler(si_sim_state, "invalid_metric") +def test_calculate_scaling_metric_non_periodic(benzene_sim_state: ts.SimState) -> None: + """Test calculation of scaling metrics for a non-periodic state.""" + # Test that calculate passes + n_atoms_metric = calculate_memory_scaler(benzene_sim_state, "n_atoms") + assert n_atoms_metric == benzene_sim_state.n_atoms + + # Test n_atoms_x_density metric works for non-periodic systems + n_atoms_x_density_metric = calculate_memory_scaler( + benzene_sim_state, "n_atoms_x_density" + ) + assert n_atoms_x_density_metric > 0 + + def test_split_state(si_double_sim_state: ts.SimState) -> None: """Test splitting a batched state into individual states.""" split_states = si_double_sim_state.split() diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 3189182d1..c7ebe1139 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -367,7 +367,7 @@ def calculate_memory_scaler( volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 else: bbox = state.positions.max(dim=0).values - state.positions.min(dim=0).values - volume = bbox.prod() / 1000 + volume = bbox.clamp(min=1.0).prod() / 1000 # min 1 Å for planar molecules number_density = state.n_atoms / volume.item() return state.n_atoms * number_density raise ValueError( From 5cc306942b4140f4d82e76c0c6c9f7459d2e9cca Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 20 Jan 2026 11:42:18 -0500 Subject: [PATCH 3/4] clamp at 2 --- torch_sim/autobatching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index c7ebe1139..ed082c0ac 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -367,7 +367,7 @@ def calculate_memory_scaler( volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 else: bbox = state.positions.max(dim=0).values - state.positions.min(dim=0).values - volume = bbox.clamp(min=1.0).prod() / 1000 # min 1 Å for planar molecules + volume = bbox.clamp(min=2.0).prod() / 1000 # min 1 Å for planar molecules number_density = state.n_atoms / volume.item() return state.n_atoms * number_density raise ValueError( From b3e4d83135baec0b86421cdfdb8723cf53a92b2f Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 21 Jan 2026 13:45:21 -0500 Subject: [PATCH 4/4] respond to rhys and curtis PRs --- torch_sim/autobatching.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index ed082c0ac..a0cf6aabb 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -367,7 +367,11 @@ def calculate_memory_scaler( volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 else: bbox = state.positions.max(dim=0).values - state.positions.min(dim=0).values - volume = bbox.clamp(min=2.0).prod() / 1000 # min 1 Å for planar molecules + # add 2 A in non-periodic directions to account for 2D systems and slabs + for i, periodic in enumerate(state.pbc): + if not periodic: + bbox[i] += 2.0 + volume = bbox.prod() / 1000 # convert A^3 to nm^3 number_density = state.n_atoms / volume.item() return state.n_atoms * number_density raise ValueError(