diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 15640e599..7ec870d3c 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -107,6 +107,19 @@ def test_calculate_scaling_metric(si_sim_state: ts.SimState) -> None: calculate_memory_scaler(si_sim_state, "invalid_metric") +def test_calculate_scaling_metric_non_periodic(benzene_sim_state: ts.SimState) -> None: + """Test calculation of scaling metrics for a non-periodic state.""" + # Test that calculate passes + n_atoms_metric = calculate_memory_scaler(benzene_sim_state, "n_atoms") + assert n_atoms_metric == benzene_sim_state.n_atoms + + # Test n_atoms_x_density metric works for non-periodic systems + n_atoms_x_density_metric = calculate_memory_scaler( + benzene_sim_state, "n_atoms_x_density" + ) + assert n_atoms_x_density_metric > 0 + + def test_split_state(si_double_sim_state: ts.SimState) -> None: """Test splitting a batched state into individual states.""" split_states = si_double_sim_state.split() diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index cd6fca019..a0cf6aabb 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -363,7 +363,15 @@ def calculate_memory_scaler( if memory_scales_with == "n_atoms": return state.n_atoms if memory_scales_with == "n_atoms_x_density": - volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 + if all(state.pbc): + volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 + else: + bbox = state.positions.max(dim=0).values - state.positions.min(dim=0).values + # add 2 A in non-periodic directions to account for 2D systems and slabs + for i, periodic in enumerate(state.pbc): + if not periodic: + bbox[i] += 2.0 + volume = bbox.prod() / 1000 # convert A^3 to nm^3 number_density = state.n_atoms / volume.item() return state.n_atoms * number_density raise ValueError(