From de197b5734f3ed59dfbbd0bd95a716648f8188a3 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Thu, 23 Apr 2026 15:49:31 -0700 Subject: [PATCH 1/2] use bootstrap coro --- kafka/admin/client.py | 2 +- kafka/consumer/group.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/kafka/admin/client.py b/kafka/admin/client.py index 3d178fb36..5dfab3ea8 100644 --- a/kafka/admin/client.py +++ b/kafka/admin/client.py @@ -224,7 +224,7 @@ def __init__(self, **configs): self._manager.start() # Bootstrap on __init__ - self._manager.run(self._manager.bootstrap(timeout_ms=self.config['bootstrap_timeout_ms'])) + self._manager.run(self._manager.bootstrap, self.config['bootstrap_timeout_ms']) self._closed = False self._controller_id = None self._coordinator_cache = {} # {group_id: node_id} diff --git a/kafka/consumer/group.py b/kafka/consumer/group.py index fe36385b1..15606aa59 100644 --- a/kafka/consumer/group.py +++ b/kafka/consumer/group.py @@ -381,9 +381,14 @@ def __init__(self, *topics, **configs): self._metrics = None self._client = self.config['kafka_client'](metrics=self._metrics, **self.config) - - # Get auto-discovered / normalized version from client - self.config['api_version'] = self._client.get_broker_version(timeout_ms=self.config['api_version_auto_timeout_ms']) + self._manager = self._client._manager + + # If api_version was not passed explicitly, bootstrap to auto-discover + # it. bootstrap is passed as a deferred coroutine so that once the IO + # thread is introduced in a later phase it runs on the IO thread. + if self._manager.broker_version_data is None: + self._manager.run(self._manager.bootstrap, self.config['api_version_auto_timeout_ms']) + self.config['api_version'] = self._manager.broker_version # Coordinator configurations are different for older brokers # max_poll_interval_ms is not supported directly -- it must the be From 5dbc48ab700e874944cb417bbc9974997048fd63 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Thu, 23 Apr 2026 15:49:51 -0700 Subject: [PATCH 2/2] Use async defs for fetcher offsets --- kafka/consumer/fetcher.py | 47 +++++----- kafka/net/manager.py | 32 +++++++ test/consumer/test_fetcher.py | 166 ++++++++++++++-------------------- 3 files changed, 122 insertions(+), 123 deletions(-) diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py index 88e5f5599..50f07b764 100644 --- a/kafka/consumer/fetcher.py +++ b/kafka/consumer/fetcher.py @@ -221,13 +221,13 @@ def offsets_by_times(self, timestamps, timeout_ms=None): Raises: KafkaTimeoutError if timeout_ms provided """ - offsets = self._fetch_offsets_by_times(timestamps, timeout_ms) + offsets = self._client._manager.run(self._fetch_offsets_by_times_async, timestamps, timeout_ms) for tp in timestamps: if tp not in offsets: offsets[tp] = None return offsets - def _fetch_offsets_by_times(self, timestamps, timeout_ms=None): + async def _fetch_offsets_by_times_async(self, timestamps, timeout_ms=None): if not timestamps: return {} @@ -239,35 +239,30 @@ def _fetch_offsets_by_times(self, timestamps, timeout_ms=None): return {} future = self._send_list_offsets_requests(timestamps) - self._client.poll(future=future, timeout_ms=timer.timeout_ms) - - # Timeout w/o future completion - if not future.is_done: + try: + offsets, retry = await self._client._manager.wait_for(future, timer.timeout_ms) + except Errors.KafkaTimeoutError: break - - if future.succeeded(): - offsets, retry = future.value + except Exception as exc: + if not getattr(exc, 'retriable', False): + raise + if getattr(exc, 'invalid_metadata', False) or self._client._manager.cluster.need_update: + refresh_future = self._client._manager.update_metadata() + try: + await self._client._manager.wait_for(refresh_future, timer.timeout_ms) + except Errors.KafkaTimeoutError: + break + else: + delay = self.config['retry_backoff_ms'] / 1000 + if timer.timeout_ms is not None: + delay = min(delay, timer.timeout_ms / 1000) + await self._client._manager._net.sleep(delay) + else: fetched_offsets.update(offsets) if not retry: return fetched_offsets - timestamps = {tp: timestamps[tp] for tp in retry} - elif not future.retriable(): - raise future.exception # pylint: disable-msg=raising-bad-type - - elif future.exception.invalid_metadata or self._client.cluster.need_update: - refresh_future = self._client.cluster.request_update() - self._client.poll(future=refresh_future, timeout_ms=timer.timeout_ms) - - if not future.is_done: - break - else: - if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']: - time.sleep(self.config['retry_backoff_ms'] / 1000) - else: - time.sleep(timer.timeout_ms / 1000) - timer.maybe_raise() raise Errors.KafkaTimeoutError( @@ -283,7 +278,7 @@ def end_offsets(self, partitions, timeout_ms): def beginning_or_end_offset(self, partitions, timestamp, timeout_ms): timestamps = dict([(tp, timestamp) for tp in partitions]) - offsets = self._fetch_offsets_by_times(timestamps, timeout_ms) + offsets = self._client._manager.run(self._fetch_offsets_by_times_async, timestamps, timeout_ms) for tp in timestamps: offsets[tp] = offsets[tp].offset return offsets diff --git a/kafka/net/manager.py b/kafka/net/manager.py index fd99288bd..ecf953225 100644 --- a/kafka/net/manager.py +++ b/kafka/net/manager.py @@ -387,6 +387,38 @@ def stop(self, timeout=None): def poll(self, timeout_ms=None, future=None): return self._net.poll(timeout_ms=timeout_ms, future=future) + async def wait_for(self, future, timeout_ms): + """Await `future` with a timeout in ms. Raises KafkaTimeoutError on timeout. + + Must be awaited from a coroutine running on this loop. The underlying + future is not cancelled on timeout — it continues to run; the timeout + only unblocks the awaiter. + """ + if timeout_ms is None: + return await future + wrapper = Future() + def _on_success(value): + if not wrapper.is_done: + wrapper.success(value) + def _on_failure(exc): + if not wrapper.is_done: + wrapper.failure(exc) + future.add_callback(_on_success) + future.add_errback(_on_failure) + def _on_timeout(): + if not wrapper.is_done: + wrapper.failure(Errors.KafkaTimeoutError( + 'Timed out after %s ms' % timeout_ms)) + timer = self._net.call_later(timeout_ms / 1000, _on_timeout) + try: + return await wrapper + finally: + if not timer.is_done: + try: + self._net.unschedule(timer) + except ValueError: + pass + async def _invoke(self, coro, args): """Invoke coro/awaitable/function and fully resolve the result. diff --git a/test/consumer/test_fetcher.py b/test/consumer/test_fetcher.py index 166557f90..e3c7058a2 100644 --- a/test/consumer/test_fetcher.py +++ b/test/consumer/test_fetcher.py @@ -683,7 +683,7 @@ def test_reset_offsets_paused_with_valid(subscription_state, client, mocker): subscription_state.assignment[tp].position = OffsetAndMetadata(10, '', -1) subscription_state.pause(tp) # paused partition already has a valid position - mocker.patch.object(fetcher, '_fetch_offsets_by_times', return_value={tp: OffsetAndTimestamp(0, 1, -1)}) + mocker.patch.object(fetcher, '_reset_offsets_async') fetcher.reset_offsets_if_needed() assert not subscription_state.is_offset_reset_needed(tp) @@ -774,7 +774,21 @@ def test_seek_before_exception(client, mocker): class TestFetchOffsetsByTimes: - def _make_fetcher(self, client, mocker): + @pytest.fixture + def net_client(self): + from kafka.net.compat import KafkaNetClient + cli = KafkaNetClient( + bootstrap_servers=['localhost:1'], + socket_connection_timeout_ms=1000, + reconnect_backoff_ms=10, + reconnect_backoff_max_ms=100, + ) + try: + yield cli + finally: + cli.close() + + def _make_fetcher(self, client): subscription_state = SubscriptionState() subscription_state.subscribe(topics=['test']) tp = TopicPartition('test', 0) @@ -782,179 +796,137 @@ def _make_fetcher(self, client, mocker): subscription_state.seek(tp, 0) return Fetcher(client, subscription_state) - def test_empty_timestamps(self, client, metrics, mocker): - fetcher = self._make_fetcher(client, mocker) - assert fetcher._fetch_offsets_by_times({}) == {} + def test_empty_timestamps(self, net_client): + fetcher = self._make_fetcher(net_client) + assert fetcher.offsets_by_times({}) == {} - def test_success_no_retry(self, client, mocker): - fetcher = self._make_fetcher(client, mocker) + def test_success_no_retry(self, net_client, mocker): + fetcher = self._make_fetcher(net_client) tp = TopicPartition('test', 0) timestamps = {tp: 1000} expected_offset = OffsetAndTimestamp(10, 1000, -1) future = Future() + future.success(({tp: expected_offset}, set())) mocker.patch.object(fetcher, '_send_list_offsets_requests', return_value=future) - mocker.patch.object(fetcher._client, 'poll', side_effect=lambda **kw: future.success(({tp: expected_offset}, set()))) - result = fetcher._fetch_offsets_by_times(timestamps, timeout_ms=10000) + result = fetcher.offsets_by_times(timestamps, timeout_ms=10000) assert result == {tp: expected_offset} - def test_success_with_retry(self, client, mocker): - fetcher = self._make_fetcher(client, mocker) + def test_success_with_retry(self, net_client, mocker): + fetcher = self._make_fetcher(net_client) tp0 = TopicPartition('test', 0) tp1 = TopicPartition('test', 1) timestamps = {tp0: 1000, tp1: 2000} offset0 = OffsetAndTimestamp(10, 1000, -1) offset1 = OffsetAndTimestamp(20, 2000, -1) - # First call succeeds for tp0 but needs retry for tp1 - future1 = Future() - future2 = Future() + future1 = Future().success(({tp0: offset0}, {tp1})) + future2 = Future().success(({tp1: offset1}, set())) futures = iter([future1, future2]) mocker.patch.object(fetcher, '_send_list_offsets_requests', side_effect=lambda ts: next(futures)) - def poll_side_effect(**kw): - f = kw.get('future') - if f is future1: - f.success(({tp0: offset0}, {tp1})) - elif f is future2: - f.success(({tp1: offset1}, set())) - - mocker.patch.object(fetcher._client, 'poll', side_effect=poll_side_effect) - - result = fetcher._fetch_offsets_by_times(timestamps, timeout_ms=10000) + result = fetcher.offsets_by_times(timestamps, timeout_ms=10000) assert result == {tp0: offset0, tp1: offset1} - def test_timeout_raises(self, client, mocker): - fetcher = self._make_fetcher(client, mocker) + def test_timeout_raises(self, net_client, mocker): + fetcher = self._make_fetcher(net_client) tp = TopicPartition('test', 0) timestamps = {tp: 1000} + # Return a future that never completes future = Future() mocker.patch.object(fetcher, '_send_list_offsets_requests', return_value=future) - # poll does not complete the future - mocker.patch.object(fetcher._client, 'poll') with pytest.raises(Errors.KafkaTimeoutError): - fetcher._fetch_offsets_by_times(timestamps, timeout_ms=10000) + fetcher.offsets_by_times(timestamps, timeout_ms=50) - def test_non_retriable_error_raises(self, client, mocker): - fetcher = self._make_fetcher(client, mocker) + def test_non_retriable_error_raises(self, net_client, mocker): + fetcher = self._make_fetcher(net_client) tp = TopicPartition('test', 0) timestamps = {tp: 1000} - future = Future() - mocker.patch.object(fetcher, '_send_list_offsets_requests', return_value=future) # AuthorizationError is not retriable error = Errors.TopicAuthorizationFailedError() - mocker.patch.object(fetcher._client, 'poll', side_effect=lambda **kw: future.failure(error)) + future = Future().failure(error) + mocker.patch.object(fetcher, '_send_list_offsets_requests', return_value=future) with pytest.raises(Errors.TopicAuthorizationFailedError): - fetcher._fetch_offsets_by_times(timestamps, timeout_ms=10000) + fetcher.offsets_by_times(timestamps, timeout_ms=10000) - def test_retriable_invalid_metadata_triggers_refresh(self, client, mocker): - fetcher = self._make_fetcher(client, mocker) + def test_retriable_invalid_metadata_triggers_refresh(self, net_client, mocker): + fetcher = self._make_fetcher(net_client) tp = TopicPartition('test', 0) timestamps = {tp: 1000} expected_offset = OffsetAndTimestamp(10, 1000, -1) # First call fails with invalid_metadata error, second succeeds - future1 = Future() - future2 = Future() + future1 = Future().failure(NotLeaderForPartitionError()) + future2 = Future().success(({tp: expected_offset}, set())) futures = iter([future1, future2]) mocker.patch.object(fetcher, '_send_list_offsets_requests', side_effect=lambda ts: next(futures)) - refresh_future = Future() - mocker.patch.object(fetcher._client.cluster, 'request_update', return_value=refresh_future) + refresh_future = Future().success(None) + update_metadata_mock = mocker.patch.object( + fetcher._client._manager, 'update_metadata', return_value=refresh_future) - call_count = [0] - def poll_side_effect(**kw): - f = kw.get('future') - if f is future1: - f.failure(NotLeaderForPartitionError()) - elif f is refresh_future: - refresh_future.success(None) - elif f is future2: - f.success(({tp: expected_offset}, set())) - call_count[0] += 1 - - mocker.patch.object(fetcher._client, 'poll', side_effect=poll_side_effect) - - result = fetcher._fetch_offsets_by_times(timestamps, timeout_ms=10000) + result = fetcher.offsets_by_times(timestamps, timeout_ms=10000) assert result == {tp: expected_offset} - fetcher._client.cluster.request_update.assert_called_once() + update_metadata_mock.assert_called_once() - def test_retriable_non_metadata_error_sleeps(self, client, mocker): - fetcher = self._make_fetcher(client, mocker) + def test_retriable_non_metadata_error_sleeps(self, net_client, mocker): + fetcher = self._make_fetcher(net_client) tp = TopicPartition('test', 0) timestamps = {tp: 1000} expected_offset = OffsetAndTimestamp(10, 1000, -1) # RequestTimedOutError is retriable but not invalid_metadata - future1 = Future() - future2 = Future() + future1 = Future().failure(Errors.RequestTimedOutError()) + future2 = Future().success(({tp: expected_offset}, set())) futures = iter([future1, future2]) mocker.patch.object(fetcher, '_send_list_offsets_requests', side_effect=lambda ts: next(futures)) # Ensure cluster does not need update - mocker.patch.object(type(fetcher._client.cluster), 'need_update', new_callable=mocker.PropertyMock, return_value=False) - - def poll_side_effect(**kw): - f = kw.get('future') - if f is future1: - f.failure(Errors.RequestTimedOutError()) - elif f is future2: - f.success(({tp: expected_offset}, set())) + mocker.patch.object(type(fetcher._client._manager.cluster), 'need_update', + new_callable=mocker.PropertyMock, return_value=False) - mocker.patch.object(fetcher._client, 'poll', side_effect=poll_side_effect) - mock_sleep = mocker.patch('time.sleep') + # Spy on call_later to verify a backoff timer was scheduled + call_later_spy = mocker.spy(fetcher._client._manager._net, 'call_later') - result = fetcher._fetch_offsets_by_times(timestamps, timeout_ms=10000) + result = fetcher.offsets_by_times(timestamps, timeout_ms=10000) assert result == {tp: expected_offset} - mock_sleep.assert_called_once() + # At least one call_later was for the retry_backoff sleep + assert call_later_spy.call_count >= 1 - def test_success_does_not_check_exception(self, client, mocker): + def test_success_does_not_check_exception(self, net_client, mocker): """Regression: successful future should not fall through to check future.exception.""" - fetcher = self._make_fetcher(client, mocker) + fetcher = self._make_fetcher(net_client) tp0 = TopicPartition('test', 0) tp1 = TopicPartition('test', 1) timestamps = {tp0: 1000, tp1: 2000} offset0 = OffsetAndTimestamp(10, 1000, -1) offset1 = OffsetAndTimestamp(20, 2000, -1) - future1 = Future() - future2 = Future() + # Succeeds but has retry partitions -- the bug was that code + # would fall through to check future.exception (which is None), + # causing an AttributeError + future1 = Future().success(({tp0: offset0}, {tp1})) + future2 = Future().success(({tp1: offset1}, set())) futures = iter([future1, future2]) mocker.patch.object(fetcher, '_send_list_offsets_requests', side_effect=lambda ts: next(futures)) - def poll_side_effect(**kw): - f = kw.get('future') - if f is future1: - # Succeeds but has retry partitions -- the bug was that code - # would fall through to check future.exception (which is None), - # causing an AttributeError - f.success(({tp0: offset0}, {tp1})) - elif f is future2: - f.success(({tp1: offset1}, set())) - - mocker.patch.object(fetcher._client, 'poll', side_effect=poll_side_effect) - # Should not raise AttributeError - result = fetcher._fetch_offsets_by_times(timestamps, timeout_ms=10000) + result = fetcher.offsets_by_times(timestamps, timeout_ms=10000) assert result == {tp0: offset0, tp1: offset1} - def test_no_timeout_passes_none(self, client, mocker): - fetcher = self._make_fetcher(client, mocker) + def test_no_timeout_passes_none(self, net_client, mocker): + fetcher = self._make_fetcher(net_client) tp = TopicPartition('test', 0) timestamps = {tp: 1000} expected_offset = OffsetAndTimestamp(10, 1000, -1) - future = Future() + future = Future().success(({tp: expected_offset}, set())) mocker.patch.object(fetcher, '_send_list_offsets_requests', return_value=future) - mocker.patch.object(fetcher._client, 'poll', side_effect=lambda **kw: future.success(({tp: expected_offset}, set()))) - result = fetcher._fetch_offsets_by_times(timestamps, timeout_ms=None) + result = fetcher.offsets_by_times(timestamps, timeout_ms=None) assert result == {tp: expected_offset} - # With timeout_ms=None, poll should receive None timeout - fetcher._client.poll.assert_called_once() - assert fetcher._client.poll.call_args[1]['timeout_ms'] is None