From 6261933d18c8a22e38b5bf97b8328a969d050e79 Mon Sep 17 00:00:00 2001 From: Timo Reents Date: Fri, 4 Jul 2025 16:02:35 +0200 Subject: [PATCH 1/3] Fix memory scaling in `determine_max_batch_size` The current version results in an infinite loop when `scale_factor < 1.5` due to the rounding. This is fixed by increasing the batch size by at least `+1`. --- torch_sim/autobatching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index f2eb32c72..d436076af 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -268,7 +268,7 @@ def determine_max_batch_size( Defaults to 500,000. start_size (int): Initial batch size to test. Defaults to 1. scale_factor (float): Factor to multiply batch size by in each iteration. - Defaults to 1.3. + Defaults to 1.6. Returns: int: Maximum number of batches that fit in GPU memory. @@ -289,7 +289,7 @@ def determine_max_batch_size( """ # Create a geometric sequence of batch sizes sizes = [start_size] - while (next_size := round(sizes[-1] * scale_factor)) < max_atoms: + while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < max_atoms: sizes.append(next_size) for i in range(len(sizes)): From 2da557e96e6b5fa4d214b89c9cce5a70567b506f Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 4 Jul 2025 17:04:03 +0200 Subject: [PATCH 2/3] add `test_autobatching.py` check to ensure `determine_max_batch_size` does regress to infinite loop * remove outdated pymatviz extras 'export-figs' in `6.1_Phonons_MACE.py` and `6.2_QuasiHarmonic_MACE.py` --- .../scripts/6_Phonons/6.1_Phonons_MACE.py | 2 +- .../6_Phonons/6.2_QuasiHarmonic_MACE.py | 2 +- tests/test_autobatching.py | 26 +++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py index 785c67d53..5dbe3959b 100644 --- a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py +++ b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py @@ -4,7 +4,7 @@ # dependencies = [ # "mace-torch>=0.3.12", # "phonopy>=2.35", -# "pymatviz[export-figs]>=0.15.1", +# "pymatviz>=0.15.1", # "seekpath", # "ase", # ] diff --git a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py index 3abf83231..9075e4168 100644 --- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py +++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py @@ -6,7 +6,7 @@ # dependencies = [ # "mace-torch>=0.3.12", # "phonopy>=2.35", -# "pymatviz[export-figs]>=0.15.1", +# "pymatviz>=0.15.1", # ] # /// diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 2c974ef64..5d15d4a1c 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -376,6 +376,32 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float: assert max_size == 8 +@pytest.mark.parametrize("scale_factor", [1.1, 1.4]) +def test_determine_max_batch_size_small_scale_factor_no_infinite_loop( + si_sim_state: ts.SimState, + lj_model: LennardJonesModel, + monkeypatch: pytest.MonkeyPatch, + scale_factor: float, +) -> None: + """Test determine_max_batch_size doesn't infinite loop with small scale factors.""" + monkeypatch.setattr( + "torch_sim.autobatching.measure_model_memory_forward", lambda *_: 0.1 + ) + + max_size = determine_max_batch_size( + si_sim_state, lj_model, max_atoms=20, scale_factor=scale_factor + ) + assert 0 < max_size <= 20 + + # Verify sequence is strictly increasing (prevents infinite loop) + sizes = [1] + while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < 20: + sizes.append(next_size) + + assert all(sizes[idx] > sizes[idx - 1] for idx in range(1, len(sizes))) + assert max_size == sizes[-1] + + def test_in_flight_auto_batcher_restore_order( si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState, From c734c00dcfd321b1e1b25717e04c33621a8320d6 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 4 Jul 2025 17:11:45 +0200 Subject: [PATCH 3/3] pin plotly!=6.2.0 --- docs/_static/draw_pkg_treemap.py | 1 + examples/scripts/6_Phonons/6.1_Phonons_MACE.py | 3 ++- examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/_static/draw_pkg_treemap.py b/docs/_static/draw_pkg_treemap.py index 3e2775a3d..f339a604c 100644 --- a/docs/_static/draw_pkg_treemap.py +++ b/docs/_static/draw_pkg_treemap.py @@ -6,6 +6,7 @@ # /// script # dependencies = [ # "pymatviz @ git+https://github.com/janosh/pymatviz", +# "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 # ] # /// diff --git a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py index 5dbe3959b..f88fd351d 100644 --- a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py +++ b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py @@ -4,9 +4,10 @@ # dependencies = [ # "mace-torch>=0.3.12", # "phonopy>=2.35", -# "pymatviz>=0.15.1", +# "pymatviz>=0.16", # "seekpath", # "ase", +# "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 # ] # /// diff --git a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py index 9075e4168..0fdea6b46 100644 --- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py +++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py @@ -6,7 +6,8 @@ # dependencies = [ # "mace-torch>=0.3.12", # "phonopy>=2.35", -# "pymatviz>=0.15.1", +# "pymatviz>=0.16", +# "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 # ] # ///