diff --git a/tests/e2e_tests/utils/chain_interactions.py b/tests/e2e_tests/utils/chain_interactions.py index 0407f21c71..c63588d066 100644 --- a/tests/e2e_tests/utils/chain_interactions.py +++ b/tests/e2e_tests/utils/chain_interactions.py @@ -74,7 +74,7 @@ def sudo_set_hyperparameter_values( return response.is_success -async def wait_epoch(subtensor: "Subtensor", netuid: int = 1, times: int = 1): +async def wait_epoch(subtensor: "Subtensor", netuid: int = 1, **kwargs): """ Waits for the next epoch to start on a specific subnet. @@ -90,7 +90,7 @@ async def wait_epoch(subtensor: "Subtensor", netuid: int = 1, times: int = 1): raise Exception("could not determine tempo") tempo = q_tempo[0].value logging.info(f"tempo = {tempo}") - await wait_interval(tempo * times, subtensor, netuid) + await wait_interval(tempo, subtensor, netuid, **kwargs) def next_tempo(current_block: int, tempo: int, netuid: int) -> int: @@ -105,6 +105,7 @@ def next_tempo(current_block: int, tempo: int, netuid: int) -> int: Returns: int: The next tempo block number. """ + current_block += 1 interval = tempo + 1 last_epoch = current_block - 1 - (current_block + netuid + 1) % interval next_tempo_ = last_epoch + interval @@ -117,6 +118,7 @@ async def wait_interval( netuid: int = 1, reporting_interval: int = 1, sleep: float = 0.25, + times: int = 1, ): """ Waits until the next tempo interval starts for a specific subnet. @@ -126,7 +128,11 @@ async def wait_interval( the current block number until the next tempo interval starts. """ current_block = subtensor.get_current_block() - next_tempo_block_start = next_tempo(current_block, tempo, netuid) + next_tempo_block_start = current_block + + for _ in range(times): + next_tempo_block_start = next_tempo(next_tempo_block_start, tempo, netuid) + last_reported = None while current_block < next_tempo_block_start: