From a5c91c5394481284cd293bf63bc4920c7e37da85 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Thu, 23 Apr 2026 15:06:20 -0700 Subject: [PATCH] Network IO thread (pt1): Basic support; use in Admin Client --- kafka/admin/_acls.py | 6 +- kafka/admin/client.py | 4 ++ kafka/net/compat.py | 2 + kafka/net/manager.py | 97 +++++++++++++++++++++---- kafka/net/selector.py | 14 ++++ test/admin/test_admin_concurrent.py | 107 ++++++++++++++++++++++++++++ 6 files changed, 214 insertions(+), 16 deletions(-) create mode 100644 test/admin/test_admin_concurrent.py diff --git a/kafka/admin/_acls.py b/kafka/admin/_acls.py index 8576112b1..3b3da0c16 100644 --- a/kafka/admin/_acls.py +++ b/kafka/admin/_acls.py @@ -55,7 +55,7 @@ def describe_acls(self, acl_filter): operation=acl_filter.operation, permission_type=acl_filter.permission_type ) - response = self._manager.run(self._manager.send(request)) # pylint: disable=E0606 + response = self._manager.run(self._manager.send, request) return self._convert_describe_acls_response_to_acls(response) @staticmethod @@ -132,7 +132,7 @@ def create_acls(self, acls): creations = [self._convert_create_acls_resource_request(acl) for acl in acls] min_version = 3 if any(creation.resource_type == ResourceType.USER for creation in creations) else 0 request = CreateAclsRequest(creations=creations, min_version=min_version) - response = self._manager.run(self._manager.send(request)) # pylint: disable=E0606 + response = self._manager.run(self._manager.send, request) return self._convert_create_acls_response_to_acls(acls, response) @staticmethod @@ -191,7 +191,7 @@ def delete_acls(self, acl_filters): filters = [self._convert_delete_acls_resource_request(acl) for acl in acl_filters] min_version = 3 if any(_filter.resource_type_filter == ResourceType.USER for _filter in filters) else 0 request = DeleteAclsRequest(filters=filters, min_version=min_version) - response = self._manager.run(self._manager.send(request)) # pylint: disable=E0606 + response = self._manager.run(self._manager.send, request) return self._convert_delete_acls_response_to_matching_acls(acl_filters, response) diff --git a/kafka/admin/client.py b/kafka/admin/client.py index fc52a27d0..3d178fb36 100644 --- a/kafka/admin/client.py +++ b/kafka/admin/client.py @@ -219,6 +219,10 @@ def __init__(self, **configs): # Goal: migrate all self._client calls -> self._manager (skipping compat layer) self._manager = self._client._manager + # Run all IO on a dedicated background thread; public admin methods + # block on cross-thread Events via self._manager.run(...). + self._manager.start() + # Bootstrap on __init__ self._manager.run(self._manager.bootstrap(timeout_ms=self.config['bootstrap_timeout_ms'])) self._closed = False diff --git a/kafka/net/compat.py b/kafka/net/compat.py index aaef65c54..450182c9e 100644 --- a/kafka/net/compat.py +++ b/kafka/net/compat.py @@ -141,6 +141,8 @@ def poll(self, timeout_ms=None, future=None): return self._manager.poll(timeout_ms=timeout_ms, future=future) def close(self, node_id=None): + if node_id is None: + self._manager.stop() self._manager.close(node_id=node_id) if node_id is None: self._net.close() diff --git a/kafka/net/manager.py b/kafka/net/manager.py index 34d7df1c6..fd99288bd 100644 --- a/kafka/net/manager.py +++ b/kafka/net/manager.py @@ -4,6 +4,7 @@ import random import socket import ssl +import threading import time from .inet import create_connection @@ -69,6 +70,9 @@ def __init__(self, net, cluster, **configs): self.broker_version_data = None self._bootstrap_future = None self._metadata_future = None + self._io_thread = None + self._pending_waiters = {} # event -> state dict, for pending run() waiters + self._pending_waiters_lock = threading.Lock() if self.config['metrics']: self._sensors = KafkaManagerMetrics( self.config['metrics'], self.config['metric_group_prefix'], self._conns) @@ -354,12 +358,56 @@ def close(self, node_id=None): for conn in list(self._conns.values()): conn.close() + def start(self): + """Spawn a daemon IO thread that owns the event loop. Idempotent.""" + if self._io_thread is not None: + return + t = threading.Thread(target=self._net.run_forever, + name='kafka-io-%s' % self.config['client_id'], + daemon=True) + self._io_thread = t + t.start() + + def stop(self, timeout=None): + """Signal the IO thread to exit and join it. Fails any pending run() + waiters with KafkaConnectionError. Idempotent.""" + t = self._io_thread + if t is None: + return + self._io_thread = None + self._net.stop() + t.join(timeout) + with self._pending_waiters_lock: + waiters = list(self._pending_waiters.items()) + self._pending_waiters.clear() + for event, state in waiters: + state['exception'] = Errors.KafkaConnectionError('Manager stopped') + event.set() + def poll(self, timeout_ms=None, future=None): return self._net.poll(timeout_ms=timeout_ms, future=future) + async def _invoke(self, coro, args): + """Invoke coro/awaitable/function and fully resolve the result. + + If the result is itself a Future (e.g. send() returning an unresolved + Future), it is awaited so callers receive the resolved value. + """ + if inspect.iscoroutinefunction(coro): + result = await coro(*args) + elif hasattr(coro, '__await__'): + result = await coro + else: + result = coro(*args) + while isinstance(result, Future): + result = await result + return result + def call_soon(self, coro, *args): """Accepts a coroutine / awaitable / function and schedules it on the event loop. + Thread-safe. + Returns: Future """ if hasattr(coro, '__await__'): @@ -367,21 +415,44 @@ def call_soon(self, coro, *args): future = Future() async def wrapper(): try: - if inspect.iscoroutinefunction(coro): - future.success(await coro(*args)) - elif hasattr(coro, '__await__'): - future.success(await coro) - else: - future.success(coro(*args)) + future.success(await self._invoke(coro, args)) except Exception as exc: future.failure(exc) - self._net.call_soon(wrapper) + self._net.call_soon_threadsafe(wrapper) return future def run(self, coro, *args): - """Schedules coro on the event loop, blocks until complete, returns value or raises.""" - future = self.call_soon(coro, *args) - self.poll(future=future) - if future.exception is not None: - raise future.exception - return future.value + """Schedules coro on the event loop, blocks until complete, returns value or raises. + + If an IO thread is running (via start()), the caller thread blocks on + a cross-thread Event while the coroutine runs on the IO thread. Safe + to call concurrently from multiple caller threads. + + If no IO thread is running, falls back to driving the loop on the + caller thread (legacy behavior). + """ + if self._io_thread is None: + future = self.call_soon(coro, *args) + self.poll(future=future) + if future.exception is not None: + raise future.exception + return future.value + + event = threading.Event() + state = {'value': None, 'exception': None} + async def waiter(): + try: + state['value'] = await self._invoke(coro, args) + except BaseException as exc: + state['exception'] = exc + finally: + with self._pending_waiters_lock: + self._pending_waiters.pop(event, None) + event.set() + with self._pending_waiters_lock: + self._pending_waiters[event] = state + self._net.call_soon_threadsafe(waiter) + event.wait() + if state['exception'] is not None: + raise state['exception'] + return state['value'] diff --git a/kafka/net/selector.py b/kafka/net/selector.py index 45bfa7fc5..d6ae000a7 100644 --- a/kafka/net/selector.py +++ b/kafka/net/selector.py @@ -123,6 +123,7 @@ def __init__(self, **configs): self.config[key] = configs[key] self._closed = False + self._stop = False self._selector = self.config['selector']() self._scheduled = [] # managed by heapq self._ready = collections.deque() @@ -139,6 +140,19 @@ def run(self): while self._scheduled or self._ready: self._poll_once() + def run_forever(self): + """Run the event loop until stop() is called. Intended to be driven by + a dedicated IO thread. Wake-ups from other threads must go through + call_soon_threadsafe() so the select() loop returns promptly.""" + self._stop = False + while not self._stop: + self._poll_once() + + def stop(self): + """Signal run_forever() to exit. Safe to call from any thread.""" + self._stop = True + self.wakeup() + def run_until_done(self, task_or_future): if not isinstance(task_or_future, (Future, Task)): task_or_future = Task(task_or_future) diff --git a/test/admin/test_admin_concurrent.py b/test/admin/test_admin_concurrent.py new file mode 100644 index 000000000..deba175a2 --- /dev/null +++ b/test/admin/test_admin_concurrent.py @@ -0,0 +1,107 @@ +"""Concurrency test for KafkaAdminClient IO thread. + +Verifies that multiple caller threads can safely invoke admin methods +concurrently while a dedicated IO thread owns the event loop. Exercises +the thread-safety foundation in KafkaConnectionManager (start/stop, +cross-thread run via Event, call_soon_threadsafe). +""" +import threading + +import pytest + +from kafka.admin import KafkaAdminClient + +from test.mock_broker import MockBroker + + +@pytest.fixture +def broker(): + return MockBroker(broker_version=(4, 2)) + + +@pytest.fixture +def admin(broker): + admin = KafkaAdminClient( + kafka_client=broker.client_factory(), + bootstrap_servers='%s:%d' % (broker.host, broker.port), + request_timeout_ms=5000, + max_in_flight_requests_per_connection=32, + ) + try: + yield admin + finally: + admin.close() + + +def test_concurrent_describe_cluster(admin): + """Many threads calling describe_cluster at once all succeed with + consistent results and no deadlock.""" + N = 16 + iterations = 4 + errors = [] + results = [] + results_lock = threading.Lock() + barrier = threading.Barrier(N) + + def worker(): + try: + barrier.wait(timeout=5) + for _ in range(iterations): + r = admin.describe_cluster() + with results_lock: + results.append(r) + except BaseException as exc: + errors.append(exc) + + threads = [threading.Thread(target=worker) for _ in range(N)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=15) + assert not t.is_alive(), 'worker deadlocked' + + assert not errors, errors + assert len(results) == N * iterations + for r in results: + assert r['cluster_id'] == 'mock-cluster' + + +def test_close_unblocks_pending_callers(broker): + """If close() is called while a caller is blocked on run(), the + caller receives a KafkaConnectionError rather than hanging.""" + admin = KafkaAdminClient( + kafka_client=broker.client_factory(), + bootstrap_servers='%s:%d' % (broker.host, broker.port), + request_timeout_ms=5000, + ) + manager = admin._manager + + blocked = threading.Event() + released = threading.Event() + + async def blocker(): + blocked.set() + # Wait forever — only unblocks via manager.stop() + await manager._net.sleep(3600) + + result = {} + + def caller(): + try: + manager.run(blocker) + except BaseException as exc: + result['exception'] = exc + finally: + released.set() + + t = threading.Thread(target=caller, daemon=True) + t.start() + assert blocked.wait(timeout=5), 'blocker never scheduled' + + admin.close() + + assert released.wait(timeout=5), 'caller did not unblock on close' + assert 'exception' in result + # stop() fails pending waiters with KafkaConnectionError('Manager stopped') + from kafka.errors import KafkaConnectionError + assert isinstance(result['exception'], KafkaConnectionError)