diff --git a/docs/_static/draw_pkg_treemap.py b/docs/_static/draw_pkg_treemap.py
index 3e2775a3d..f339a604c 100644
--- a/docs/_static/draw_pkg_treemap.py
+++ b/docs/_static/draw_pkg_treemap.py
@@ -6,6 +6,7 @@
 # /// script
 # dependencies = [
 #     "pymatviz @ git+https://github.com/janosh/pymatviz",
+#     "plotly!=6.2.0",  # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635
 # ]
 # ///
diff --git a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py
index 785c67d53..f88fd351d 100644
--- a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py
+++ b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py
@@ -4,9 +4,10 @@
 # dependencies = [
 #     "mace-torch>=0.3.12",
 #     "phonopy>=2.35",
-#     "pymatviz[export-figs]>=0.15.1",
+#     "pymatviz>=0.16",
 #     "seekpath",
 #     "ase",
+#     "plotly!=6.2.0",  # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635
 # ]
 # ///
diff --git a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py
index 3abf83231..0fdea6b46 100644
--- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py
+++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py
@@ -6,7 +6,8 @@
 # dependencies = [
 #     "mace-torch>=0.3.12",
 #     "phonopy>=2.35",
-#     "pymatviz[export-figs]>=0.15.1",
+#     "pymatviz>=0.16",
+#     "plotly!=6.2.0",  # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635
 # ]
 # ///
diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py
index 2c974ef64..5d15d4a1c 100644
--- a/tests/test_autobatching.py
+++ b/tests/test_autobatching.py
@@ -376,6 +376,32 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float:
     assert max_size == 8
 
 
+@pytest.mark.parametrize("scale_factor", [1.1, 1.4])
+def test_determine_max_batch_size_small_scale_factor_no_infinite_loop(
+    si_sim_state: ts.SimState,
+    lj_model: LennardJonesModel,
+    monkeypatch: pytest.MonkeyPatch,
+    scale_factor: float,
+) -> None:
+    """Test determine_max_batch_size doesn't infinite loop with small scale factors."""
+    monkeypatch.setattr(
+        "torch_sim.autobatching.measure_model_memory_forward", lambda *_: 0.1
+    )
+
+    max_size = determine_max_batch_size(
+        si_sim_state, lj_model, max_atoms=20, scale_factor=scale_factor
+    )
+    assert 0 < max_size <= 20
+
+    # Verify sequence is strictly increasing (prevents infinite loop)
+    sizes = [1]
+    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < 20:
+        sizes.append(next_size)
+
+    assert all(sizes[idx] > sizes[idx - 1] for idx in range(1, len(sizes)))
+    assert max_size == sizes[-1]
+
+
 def test_in_flight_auto_batcher_restore_order(
     si_sim_state: ts.SimState,
     fe_supercell_sim_state: ts.SimState,
diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index f2eb32c72..d436076af 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -268,7 +268,7 @@ def determine_max_batch_size(
             Defaults to 500,000.
         start_size (int): Initial batch size to test. Defaults to 1.
         scale_factor (float): Factor to multiply batch size by in each iteration.
-            Defaults to 1.3.
+            Defaults to 1.6.
 
     Returns:
         int: Maximum number of batches that fit in GPU memory.
@@ -289,7 +289,7 @@ def determine_max_batch_size(
     """
     # Create a geometric sequence of batch sizes
     sizes = [start_size]
-    while (next_size := round(sizes[-1] * scale_factor)) < max_atoms:
+    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < max_atoms:
         sizes.append(next_size)
 
     for i in range(len(sizes)):
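
Why the one-line loop change above matters: the old expression `round(sizes[-1] * scale_factor)` can return the current size unchanged whenever the rounded product does not grow (e.g. `round(1 * 1.1) == 1`), so for any `scale_factor < 1.5` starting from size 1 the `while` loop spins forever without approaching `max_atoms`. Wrapping the product in `max(..., sizes[-1] + 1)` forces the size sequence to be strictly increasing. Below is a minimal standalone sketch of both behaviors; it is plain Python rather than torch_sim code, and the helper names `old_next`/`new_next` are made up for illustration.

def old_next(size: int, scale_factor: float) -> int:
    # Pre-fix expression: round(1 * 1.1) == 1, so the sequence can stall
    # and the surrounding `while next_size < max_atoms` loop never exits.
    return round(size * scale_factor)

def new_next(size: int, scale_factor: float) -> int:
    # Post-fix expression: taking max with size + 1 guarantees strict growth.
    return max(round(size * scale_factor), size + 1)

assert old_next(1, 1.1) == 1  # stuck at 1 -> infinite loop in the old code
assert new_next(1, 1.1) == 2  # always makes progress

sizes = [1]
while (next_size := new_next(sizes[-1], 1.1)) < 10:
    sizes.append(next_size)
print(sizes)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]

This strict-increase property is exactly the invariant the new parametrized test asserts: it regenerates the size sequence with the patched expression and checks that every element exceeds the previous one.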