From 4423fb9f765e0fbab276fe43869b1b945810887f Mon Sep 17 00:00:00 2001 From: kianpu34593 Date: Sun, 20 Jul 2025 14:17:54 -0400 Subject: [PATCH] add new states when the max_memory_scaler is updated --- torch_sim/autobatching.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 9cd556739..a691de366 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -952,6 +952,8 @@ def _get_first_batch(self) -> SimState: scale_factor=self.memory_scaling_factor, ) self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding + newer_states = self._get_next_states() + states = [*states, *newer_states] return concatenate_states([first_state, *states]) def next_batch( # noqa: C901