diff --git a/tests/test_runners.py b/tests/test_runners.py index 2c5e7d4fb..0b440c884 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -515,11 +515,9 @@ def test_optimize_with_default_autobatcher( """Test optimize with autobatcher.""" def mock_estimate(*args, **kwargs) -> float: # noqa: ARG001 - return 10_000.0 + return 200 - monkeypatch.setattr( - "torch_sim.autobatching.estimate_max_memory_scaler", mock_estimate - ) + monkeypatch.setattr("torch_sim.autobatching.determine_max_batch_size", mock_estimate) states = [ar_supercell_sim_state, fe_supercell_sim_state, ar_supercell_sim_state] triple_state = initialize_state( diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index e5e75c3d5..991230a49 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -853,6 +853,7 @@ def load_states( self.first_batch_returned = False self._first_batch = self._get_first_batch() + return self.max_memory_scaler def _get_next_states(self) -> list[SimState]: """Add states from the iterator until max_memory_scaler is reached. @@ -928,19 +929,17 @@ def _get_first_batch(self) -> SimState: self.current_idx += [0] self.swap_attempts.append(0) # Initialize attempt counter for first state self.iterator_idx += 1 - # self.total_metric += first_metric # if max_metric is not set, estimate it has_max_metric = bool(self.max_memory_scaler) if not has_max_metric: - self.max_memory_scaler = estimate_max_memory_scaler( + n_batches = determine_max_batch_size( + first_state, self.model, - [first_state], - [first_metric], max_atoms=self.max_atoms_to_try, scale_factor=self.memory_scaling_factor, ) - self.max_memory_scaler = self.max_memory_scaler * 0.8 + self.max_memory_scaler = n_batches * first_metric * 0.8 states = self._get_next_states() @@ -953,7 +952,6 @@ def _get_first_batch(self) -> SimState: scale_factor=self.memory_scaling_factor, ) self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding - print(f"Max metric calculated: {self.max_memory_scaler}") return concatenate_states([first_state, *states]) def next_batch( diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 148cafa14..be43147d4 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -89,7 +89,7 @@ def _configure_batches_iterator( elif autobatcher is False: batches = [(state, [])] else: - raise ValueError( + raise TypeError( f"Invalid autobatcher type: {type(autobatcher).__name__}, " "must be bool or BinningAutoBatcher." ) @@ -206,7 +206,7 @@ def _configure_in_flight_autobatcher( if isinstance(autobatcher, InFlightAutoBatcher): autobatcher.return_indices = True autobatcher.max_attempts = max_attempts - else: + elif isinstance(autobatcher, bool): if autobatcher: memory_scales_with = model.memory_scales_with max_memory_scaler = None @@ -221,6 +221,11 @@ def _configure_in_flight_autobatcher( max_iterations=max_attempts, max_memory_padding=0.9, ) + else: + raise TypeError( + f"Invalid autobatcher type: {type(autobatcher).__name__}, " + "must be bool or InFlightAutoBatcher." + ) return autobatcher