From 1777156d5c6a1dadeab2415c68d1b9ddbd382b86 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Mon, 26 Jan 2026 15:44:06 +0000 Subject: [PATCH 1/2] feat: Implement MainEventBus architecture for event queue processing Introduces centralized event processing with single background thread to guarantee event persistence before client visibility and eliminate race conditions in concurrent task updates. Key Changes: - MainEventBus: Central LinkedBlockingDeque for all events - MainEventBusProcessor: Single background thread ensuring serial processing (TaskStore.save() -> PushNotificationSender.send() -> distributeToChildren()) - Two-level queue cleanup protection via TaskStateProvider.isTaskFinalized() to prevent premature cleanup for fire-and-forget tasks - Deterministic blocking calls: waitForTaskFinalization() ensures TaskStore persistence completes before returning to client - Streaming closure: agentCompleted flag via EnhancedRunnable.DoneCallback for graceful drain when agent completes - SseFormatter utility: Framework-agnostic SSE formatting in server-common - Executor pool improvements: Bounded EventConsumerExecutor pool (size 15) prevents exhaustion during high concurrency Null TaskId Support: - QueueManager.switchKey(): Atomic key switching from temp to real task ID - Non-streaming path now handles null taskId like streaming (generates temp UUID) - Cloud deployment test reverted to use A2A.toUserMessage() (valid null taskId case) - No duplicate queue map entries, clean queue lifecycle management Additional Fixes: - Blocking calls always wait for task finalization (not just final events) - RequestContext uses message.taskId() for new tasks when task is null - PostgreSQL pod deployment: Wait loop for pod creation before readiness check - Javadoc: Remove invalid

  <p> tags around <pre> blocks
- ServerCallContext: EventConsumer cancellation via closeHandler for graceful disconnect
- Gemini feedback improvements

Architecture Impact:
- All events flow through MainEventBus before distribution to ChildQueues
- Clients never see unpersisted events (persistence-before-distribution)
- Fire-and-forget tasks supported: queues stay open for non-final states
- Late resubscription enabled: queues persist until task finalization
- Null taskId messages supported: temp IDs transition to real IDs via switchKey
- Test synchronization: MainEventBusProcessorCallback for deterministic testing
---
 examples/cloud-deployment/scripts/deploy.sh   |  16 +
 .../core/ReplicatedQueueManager.java          |  23 +-
 .../core/ReplicatedQueueManagerTest.java      | 234 +++++---
 .../io/a2a/server/events/EventQueueUtil.java  |  11 +
 .../server/apps/quarkus/A2AServerRoutes.java  | 120 ++--
 .../src/test/resources/application.properties |   5 +
 .../server/rest/quarkus/A2AServerRoutes.java  | 144 +++--
 .../java/io/a2a/server/ServerCallContext.java |  61 ++
 .../io/a2a/server/events/EventConsumer.java   |  42 ++
 .../java/io/a2a/server/events/EventQueue.java | 312 ++++++----
 .../server/events/InMemoryQueueManager.java   |  49 +-
 .../io/a2a/server/events/MainEventBus.java    |  42 ++
 .../server/events/MainEventBusContext.java    |  11 +
 .../server/events/MainEventBusProcessor.java  | 368 ++++++++++++
 .../events/MainEventBusProcessorCallback.java |  66 +++
 .../MainEventBusProcessorInitializer.java     |  43 ++
 .../io/a2a/server/events/QueueManager.java    |  45 +-
 .../DefaultRequestHandler.java                | 557 +++++++++++-------
 .../io/a2a/server/tasks/ResultAggregator.java | 128 ++--
 .../util/async/AsyncExecutorProducer.java     |  57 +-
 .../async/EventConsumerExecutorProducer.java  |  93 +++
 .../io/a2a/server/util/sse/SseFormatter.java  | 136 +++++
 .../io/a2a/server/util/sse/package-info.java  |  11 +
 .../META-INF/a2a-defaults.properties          |   4 +
 .../a2a/server/events/EventConsumerTest.java  |  99 +++-
 .../io/a2a/server/events/EventQueueTest.java  | 226 ++++---
 .../io/a2a/server/events/EventQueueUtil.java  |  37 +-
 .../events/InMemoryQueueManagerTest.java      |  35 +-
 .../AbstractA2ARequestHandlerTest.java        |  24 +-
 .../DefaultRequestHandlerTest.java            | 112 ++--
 .../server/tasks/ResultAggregatorTest.java    |  82 ++-
 .../io/a2a/server/tasks/TaskUpdaterTest.java  |  56 +-
 .../transport/grpc/handler/GrpcHandler.java   |  24 +-
 .../grpc/handler/GrpcHandlerTest.java         |  40 +-
 .../jsonrpc/handler/JSONRPCHandlerTest.java   | 332 +++++------
 .../transport/rest/handler/RestHandler.java   |  15 +
 36 files changed, 2713 insertions(+), 947 deletions(-)
 create mode 100644 extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java
 create mode 100644 server-common/src/main/java/io/a2a/server/events/MainEventBus.java
 create mode 100644 server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java
 create mode 100644 server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java
 create mode 100644 server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java
 create mode 100644 server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java
 create mode 100644 server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java
 create mode 100644 server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java
 create mode 100644 server-common/src/main/java/io/a2a/server/util/sse/package-info.java

diff --git a/examples/cloud-deployment/scripts/deploy.sh b/examples/cloud-deployment/scripts/deploy.sh
index e267f3302..fff2a6061 100755
--- a/examples/cloud-deployment/scripts/deploy.sh
+++ b/examples/cloud-deployment/scripts/deploy.sh
@@ -212,6 +212,22 @@ echo ""
 echo "Deploying PostgreSQL..."
 kubectl apply -f ../k8s/01-postgres.yaml
 echo "Waiting for PostgreSQL to be ready..."
+
+# Wait for pod to be created (StatefulSet takes time to create pod)
+for i in {1..30}; do
+    if kubectl get pod -l app=postgres -n a2a-demo 2>/dev/null | grep -q postgres; then
+        echo "PostgreSQL pod found, waiting for ready state..."
+        break
+    fi
+    if [ $i -eq 30 ]; then
+        echo -e "${RED}ERROR: PostgreSQL pod not created after 30 seconds${NC}"
+        kubectl get statefulset -n a2a-demo
+        exit 1
+    fi
+    sleep 1
+done
+
+# Now wait for pod to be ready
 kubectl wait --for=condition=Ready pod -l app=postgres -n a2a-demo --timeout=120s
 echo -e "${GREEN}✓ PostgreSQL deployed${NC}"
 
diff --git a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java
index 586ab11a7..f320362eb 100644
--- a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java
+++ b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java
@@ -13,6 +13,7 @@
 import io.a2a.server.events.EventQueueFactory;
 import io.a2a.server.events.EventQueueItem;
 import io.a2a.server.events.InMemoryQueueManager;
+import io.a2a.server.events.MainEventBus;
 import io.a2a.server.events.QueueManager;
 import io.a2a.server.tasks.TaskStateProvider;
 import org.slf4j.Logger;
@@ -45,10 +46,12 @@ protected ReplicatedQueueManager() {
     }
 
     @Inject
-    public ReplicatedQueueManager(ReplicationStrategy replicationStrategy, TaskStateProvider taskStateProvider) {
+    public ReplicatedQueueManager(ReplicationStrategy replicationStrategy,
+                                    TaskStateProvider taskStateProvider,
+                                    MainEventBus mainEventBus) {
         this.replicationStrategy = replicationStrategy;
         this.taskStateProvider = taskStateProvider;
-        this.delegate = new InMemoryQueueManager(new ReplicatingEventQueueFactory(), taskStateProvider);
+        this.delegate = new InMemoryQueueManager(new ReplicatingEventQueueFactory(), taskStateProvider, mainEventBus);
     }
 
 
@@ -57,6 +60,11 @@ public void add(String taskId, EventQueue queue) {
         delegate.add(taskId, queue);
     }
 
+    @Override
+    public void switchKey(String oldId, String newId) {
+        delegate.switchKey(oldId, newId);
+    }
+
     @Override
     public EventQueue get(String taskId) {
         return delegate.get(taskId);
@@ -152,12 +160,11 @@ public EventQueue.EventQueueBuilder builder(String taskId) {
             // which sends the QueueClosedEvent after the database transaction commits.
             // This ensures proper ordering and transactional guarantees.
 
-            // Return the builder with callbacks
-            return delegate.getEventQueueBuilder(taskId)
-                    .taskId(taskId)
-                    .hook(new ReplicationHook(taskId))
-                    .addOnCloseCallback(delegate.getCleanupCallback(taskId))
-                    .taskStateProvider(taskStateProvider);
+            // Call createBaseEventQueueBuilder() directly to avoid infinite recursion
+            // (getEventQueueBuilder() would delegate back to this factory, creating a loop)
+            // The base builder already includes: taskId, cleanup callback, taskStateProvider, mainEventBus
+            return delegate.createBaseEventQueueBuilder(taskId)
+                    .hook(new ReplicationHook(taskId));
         }
     }
 
diff --git a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java
index 43571cd30..a339be543 100644
--- a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java
+++ b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java
@@ -22,12 +22,18 @@
 import io.a2a.server.events.EventQueueClosedException;
 import io.a2a.server.events.EventQueueItem;
 import io.a2a.server.events.EventQueueTestHelper;
+import io.a2a.server.events.EventQueueUtil;
+import io.a2a.server.events.MainEventBus;
+import io.a2a.server.events.MainEventBusProcessor;
 import io.a2a.server.events.QueueClosedEvent;
+import io.a2a.server.tasks.InMemoryTaskStore;
+import io.a2a.server.tasks.PushNotificationSender;
 import io.a2a.spec.Event;
 import io.a2a.spec.StreamingEventKind;
 import io.a2a.spec.TaskState;
 import io.a2a.spec.TaskStatus;
 import io.a2a.spec.TaskStatusUpdateEvent;
+import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
@@ -35,10 +41,24 @@ class ReplicatedQueueManagerTest {
 
     private ReplicatedQueueManager queueManager;
     private StreamingEventKind testEvent;
+    private MainEventBus mainEventBus;
+    private MainEventBusProcessor mainEventBusProcessor;
+    private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {};
 
     @BeforeEach
     void setUp() {
-        queueManager = new ReplicatedQueueManager(new NoOpReplicationStrategy(), new MockTaskStateProvider(true));
+        // Create MainEventBus and MainEventBusProcessor for tests
+        InMemoryTaskStore taskStore = new InMemoryTaskStore();
+        mainEventBus = new MainEventBus();
+        mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER);
+        EventQueueUtil.start(mainEventBusProcessor);
+
+        queueManager = new ReplicatedQueueManager(
+            new NoOpReplicationStrategy(),
+            new MockTaskStateProvider(true),
+            mainEventBus
+        );
+
         testEvent = TaskStatusUpdateEvent.builder()
                 .taskId("test-task")
                 .contextId("test-context")
@@ -47,10 +67,65 @@ void setUp() {
                 .build();
     }
 
+    /**
+     * Helper to create a test event with the specified taskId.
+     * This ensures taskId consistency between queue creation and event creation.
+     */
+    private TaskStatusUpdateEvent createEventForTask(String taskId) {
+        return TaskStatusUpdateEvent.builder()
+                .taskId(taskId)
+                .contextId("test-context")
+                .status(new TaskStatus(TaskState.SUBMITTED))
+                .isFinal(false)
+                .build();
+    }
+
+    @AfterEach
+    void tearDown() {
+        if (mainEventBusProcessor != null) {
+            mainEventBusProcessor.setCallback(null);  // Clear any test callbacks
+            EventQueueUtil.stop(mainEventBusProcessor);
+        }
+        mainEventBusProcessor = null;
+        mainEventBus = null;
+        queueManager = null;
+    }
+
+    /**
+     * Helper to wait for MainEventBusProcessor to process an event.
+     * Replaces polling patterns with deterministic callback-based waiting.
+     *
+     * @param action the action that triggers event processing
+     * @throws InterruptedException if waiting is interrupted
+     * @throws AssertionError if processing doesn't complete within timeout
+     */
+    private void waitForEventProcessing(Runnable action) throws InterruptedException {
+        CountDownLatch processingLatch = new CountDownLatch(1);
+        mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() {
+            @Override
+            public void onEventProcessed(String taskId, io.a2a.spec.Event event) {
+                processingLatch.countDown();
+            }
+
+            @Override
+            public void onTaskFinalized(String taskId) {
+                // Not needed for basic event processing wait
+            }
+        });
+
+        try {
+            action.run();
+            assertTrue(processingLatch.await(5, TimeUnit.SECONDS),
+                    "MainEventBusProcessor should have processed the event within timeout");
+        } finally {
+            mainEventBusProcessor.setCallback(null);
+        }
+    }
+
     @Test
     void testReplicationStrategyTriggeredOnNormalEnqueue() throws InterruptedException {
         CountingReplicationStrategy strategy = new CountingReplicationStrategy();
-        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true));
+        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus);
 
         String taskId = "test-task-1";
         EventQueue queue = queueManager.createOrTap(taskId);
@@ -65,7 +140,7 @@ void testReplicationStrategyTriggeredOnNormalEnqueue() throws InterruptedExcepti
     @Test
     void testReplicationStrategyNotTriggeredOnReplicatedEvent() throws InterruptedException {
         CountingReplicationStrategy strategy = new CountingReplicationStrategy();
-        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true));
+        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus);
 
         String taskId = "test-task-2";
         EventQueue queue = queueManager.createOrTap(taskId);
@@ -79,7 +154,7 @@ void testReplicationStrategyNotTriggeredOnReplicatedEvent() throws InterruptedEx
     @Test
     void testReplicationStrategyWithCountingImplementation() throws InterruptedException {
         CountingReplicationStrategy countingStrategy = new CountingReplicationStrategy();
-        queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true));
+        queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true), mainEventBus);
 
         String taskId = "test-task-3";
         EventQueue queue = queueManager.createOrTap(taskId);
@@ -100,46 +175,45 @@ void testReplicationStrategyWithCountingImplementation() throws InterruptedExcep
     @Test
     void testReplicatedEventDeliveredToCorrectQueue() throws InterruptedException {
         String taskId = "test-task-4";
+        TaskStatusUpdateEvent eventForTask = createEventForTask(taskId);  // Use matching taskId
         EventQueue queue = queueManager.createOrTap(taskId);
 
-        ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent);
-        queueManager.onReplicatedEvent(replicatedEvent);
+        ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask);
 
-        Event dequeuedEvent;
-        try {
-            dequeuedEvent = queue.dequeueEventItem(100).getEvent();
-        } catch (EventQueueClosedException e) {
-            fail("Queue should not be closed");
-            return;
-        }
-        assertEquals(testEvent, dequeuedEvent);
+        // Use callback to wait for event processing
+        EventQueueItem item = dequeueEventWithRetry(queue, () -> queueManager.onReplicatedEvent(replicatedEvent));
+        assertNotNull(item, "Event should be available in queue");
+        Event dequeuedEvent = item.getEvent();
+        assertEquals(eventForTask, dequeuedEvent);
     }
 
     @Test
     void testReplicatedEventCreatesQueueIfNeeded() throws InterruptedException {
         String taskId = "non-existent-task";
+        TaskStatusUpdateEvent eventForTask = createEventForTask(taskId);  // Use matching taskId
 
         // Verify no queue exists initially
         assertNull(queueManager.get(taskId));
 
-        ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent);
-
-        // Process the replicated event
-        assertDoesNotThrow(() -> queueManager.onReplicatedEvent(replicatedEvent));
-
-        // Verify that a queue was created and the event was enqueued
-        EventQueue queue = queueManager.get(taskId);
-        assertNotNull(queue, "Queue should be created when processing replicated event for non-existent task");
-
-        // Verify the event was enqueued by dequeuing it
-        Event dequeuedEvent;
-        try {
-            dequeuedEvent = queue.dequeueEventItem(100).getEvent();
-        } catch (EventQueueClosedException e) {
-            fail("Queue should not be closed");
-            return;
-        }
-        assertEquals(testEvent, dequeuedEvent, "The replicated event should be enqueued in the newly created queue");
+        // Create a ChildQueue BEFORE processing the replicated event
+        // This ensures the ChildQueue exists when MainEventBusProcessor distributes the event
+        EventQueue childQueue = queueManager.createOrTap(taskId);
+        assertNotNull(childQueue, "ChildQueue should be created");
+
+        // Verify MainQueue was created
+        EventQueue mainQueue = queueManager.get(taskId);
+        assertNotNull(mainQueue, "MainQueue should exist after createOrTap");
+
+        ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask);
+
+        // Process the replicated event and wait for distribution
+        // Use callback to wait for event processing
+        EventQueueItem item = dequeueEventWithRetry(childQueue, () -> {
+            assertDoesNotThrow(() -> queueManager.onReplicatedEvent(replicatedEvent));
+        });
+        assertNotNull(item, "Event should be available in queue");
+        Event dequeuedEvent = item.getEvent();
+        assertEquals(eventForTask, dequeuedEvent, "The replicated event should be enqueued in the newly created queue");
     }
 
     @Test
@@ -170,7 +244,7 @@ void testBasicQueueManagerFunctionality() throws InterruptedException {
     void testQueueToTaskIdMappingMaintained() throws InterruptedException {
         String taskId = "test-task-6";
         CountingReplicationStrategy countingStrategy = new CountingReplicationStrategy();
-        queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true));
+        queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true), mainEventBus);
 
         EventQueue queue = queueManager.createOrTap(taskId);
         queue.enqueueEvent(testEvent);
@@ -217,7 +291,7 @@ void testReplicatedEventJsonSerialization() throws Exception {
     @Test
     void testParallelReplicationBehavior() throws InterruptedException {
         CountingReplicationStrategy strategy = new CountingReplicationStrategy();
-        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true));
+        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus);
 
         String taskId = "parallel-test-task";
         EventQueue queue = queueManager.createOrTap(taskId);
@@ -297,7 +371,7 @@ void testParallelReplicationBehavior() throws InterruptedException {
     void testReplicatedEventSkippedWhenTaskInactive() throws InterruptedException {
         // Create a task state provider that returns false (task is inactive)
         MockTaskStateProvider stateProvider = new MockTaskStateProvider(false);
-        queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider);
+        queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus);
 
         String taskId = "inactive-task";
 
@@ -316,30 +390,32 @@ void testReplicatedEventSkippedWhenTaskInactive() throws InterruptedException {
     void testReplicatedEventProcessedWhenTaskActive() throws InterruptedException {
         // Create a task state provider that returns true (task is active)
         MockTaskStateProvider stateProvider = new MockTaskStateProvider(true);
-        queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider);
+        queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus);
 
         String taskId = "active-task";
+        TaskStatusUpdateEvent eventForTask = createEventForTask(taskId);  // Use matching taskId
 
         // Verify no queue exists initially
         assertNull(queueManager.get(taskId));
 
-        // Process a replicated event for an active task
-        ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent);
-        queueManager.onReplicatedEvent(replicatedEvent);
+        // Create a ChildQueue BEFORE processing the replicated event
+        // This ensures the ChildQueue exists when MainEventBusProcessor distributes the event
+        EventQueue childQueue = queueManager.createOrTap(taskId);
+        assertNotNull(childQueue, "ChildQueue should be created");
 
-        // Queue should be created and event should be enqueued
-        EventQueue queue = queueManager.get(taskId);
-        assertNotNull(queue, "Queue should be created for active task");
+        // Verify MainQueue was created
+        EventQueue mainQueue = queueManager.get(taskId);
+        assertNotNull(mainQueue, "MainQueue should exist after createOrTap");
 
-        // Verify the event was enqueued
-        Event dequeuedEvent;
-        try {
-            dequeuedEvent = queue.dequeueEventItem(100).getEvent();
-        } catch (EventQueueClosedException e) {
-            fail("Queue should not be closed");
-            return;
-        }
-        assertEquals(testEvent, dequeuedEvent, "Event should be enqueued for active task");
+        // Process a replicated event for an active task
+        ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask);
+
+        // Verify the event was enqueued and distributed to our ChildQueue
+        // Use callback to wait for event processing
+        EventQueueItem item = dequeueEventWithRetry(childQueue, () -> queueManager.onReplicatedEvent(replicatedEvent));
+        assertNotNull(item, "Event should be available in queue");
+        Event dequeuedEvent = item.getEvent();
+        assertEquals(eventForTask, dequeuedEvent, "Event should be enqueued for active task");
     }
 
 
@@ -347,7 +423,7 @@ void testReplicatedEventProcessedWhenTaskActive() throws InterruptedException {
     void testReplicatedEventToExistingQueueWhenTaskBecomesInactive() throws InterruptedException {
         // Create a task state provider that returns true initially
         MockTaskStateProvider stateProvider = new MockTaskStateProvider(true);
-        queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider);
+        queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus);
 
         String taskId = "task-becomes-inactive";
 
@@ -387,7 +463,7 @@ void testReplicatedEventToExistingQueueWhenTaskBecomesInactive() throws Interrup
     @Test
     void testPoisonPillSentViaTransactionAwareEvent() throws InterruptedException {
         CountingReplicationStrategy strategy = new CountingReplicationStrategy();
-        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true));
+        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus);
 
         String taskId = "poison-pill-test";
         EventQueue queue = queueManager.createOrTap(taskId);
@@ -451,36 +527,21 @@ void testQueueClosedEventJsonSerialization() throws Exception {
     @Test
     void testReplicatedQueueClosedEventTerminatesConsumer() throws InterruptedException {
         String taskId = "remote-close-test";
+        TaskStatusUpdateEvent eventForTask = createEventForTask(taskId);  // Use matching taskId
         EventQueue queue = queueManager.createOrTap(taskId);
 
-        // Enqueue a normal event
-        queue.enqueueEvent(testEvent);
-
         // Simulate receiving QueueClosedEvent from remote node
         QueueClosedEvent closedEvent = new QueueClosedEvent(taskId);
         ReplicatedEventQueueItem replicatedClosedEvent = new ReplicatedEventQueueItem(taskId, closedEvent);
-        queueManager.onReplicatedEvent(replicatedClosedEvent);
 
-        // Dequeue the normal event first
-        EventQueueItem item1;
-        try {
-            item1 = queue.dequeueEventItem(100);
-        } catch (EventQueueClosedException e) {
-            fail("Should not throw on first dequeue");
-            return;
-        }
-        assertNotNull(item1);
-        assertEquals(testEvent, item1.getEvent());
+        // Dequeue the normal event first (use callback to wait for async processing)
+        EventQueueItem item1 = dequeueEventWithRetry(queue, () -> queue.enqueueEvent(eventForTask));
+        assertNotNull(item1, "First event should be available");
+        assertEquals(eventForTask, item1.getEvent());
 
-        // Next dequeue should get the QueueClosedEvent
-        EventQueueItem item2;
-        try {
-            item2 = queue.dequeueEventItem(100);
-        } catch (EventQueueClosedException e) {
-            fail("Should not throw on second dequeue, should return the event");
-            return;
-        }
-        assertNotNull(item2);
+        // Next dequeue should get the QueueClosedEvent (use callback to wait for async processing)
+        EventQueueItem item2 = dequeueEventWithRetry(queue, () -> queueManager.onReplicatedEvent(replicatedClosedEvent));
+        assertNotNull(item2, "QueueClosedEvent should be available");
         assertTrue(item2.getEvent() instanceof QueueClosedEvent,
                 "Second event should be QueueClosedEvent");
     }
@@ -539,4 +600,25 @@ public void setActive(boolean active) {
             this.active = active;
         }
     }
+
+    /**
+     * Helper method to dequeue an event after waiting for MainEventBusProcessor distribution.
+     * Uses callback-based waiting instead of polling for deterministic synchronization.
+     *
+     * @param queue the queue to dequeue from
+     * @param enqueueAction the action that enqueues the event (triggers event processing)
+     * @return the dequeued EventQueueItem, or null if queue is closed
+     */
+    private EventQueueItem dequeueEventWithRetry(EventQueue queue, Runnable enqueueAction) throws InterruptedException {
+        // Wait for event to be processed and distributed
+        waitForEventProcessing(enqueueAction);
+
+        // Event is now available, dequeue directly
+        try {
+            return queue.dequeueEventItem(100);
+        } catch (EventQueueClosedException e) {
+            // Queue closed, return null
+            return null;
+        }
+    }
 }
\ No newline at end of file
diff --git a/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java b/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java
new file mode 100644
index 000000000..a91575aaa
--- /dev/null
+++ b/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java
@@ -0,0 +1,11 @@
+package io.a2a.server.events;
+
+public class EventQueueUtil {
+    public static void start(MainEventBusProcessor processor) {
+        processor.start();
+    }
+
+    public static void stop(MainEventBusProcessor processor) {
+        processor.stop();
+    }
+}
diff --git a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java
index 18e18a2f1..cb5bdb25b 100644
--- a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java
+++ b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java
@@ -13,7 +13,6 @@
 import java.util.concurrent.Executor;
 import java.util.concurrent.Flow;
 import java.util.concurrent.atomic.AtomicLong;
-import java.util.function.Function;
 
 import jakarta.enterprise.inject.Instance;
 import jakarta.inject.Inject;
@@ -21,6 +20,7 @@
 
 import com.google.gson.JsonSyntaxException;
 import io.a2a.common.A2AHeaders;
+import io.a2a.server.util.sse.SseFormatter;
 import io.a2a.grpc.utils.JSONRPCUtils;
 import io.a2a.jsonrpc.common.json.IdJsonMappingException;
 import io.a2a.jsonrpc.common.json.InvalidParamsJsonMappingException;
@@ -65,7 +65,6 @@
 import io.a2a.transport.jsonrpc.handler.JSONRPCHandler;
 import io.quarkus.security.Authenticated;
 import io.quarkus.vertx.web.Body;
-import io.quarkus.vertx.web.ReactiveRoutes;
 import io.quarkus.vertx.web.Route;
 import io.smallrye.mutiny.Multi;
 import io.vertx.core.AsyncResult;
@@ -74,6 +73,8 @@
 import io.vertx.core.buffer.Buffer;
 import io.vertx.core.http.HttpServerResponse;
 import io.vertx.ext.web.RoutingContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 @Singleton
 public class A2AServerRoutes {
@@ -135,8 +136,12 @@ public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) {
             } else if (streaming) {
                 final Multi> finalStreamingResponse = streamingResponse;
                 executor.execute(() -> {
-                    MultiSseSupport.subscribeObject(
-                            finalStreamingResponse.map(i -> (Object) i), rc);
+                    // Convert Multi<A2AResponse> to Multi<String> with SSE formatting
+                    AtomicLong eventIdCounter = new AtomicLong(0);
+                    Multi sseEvents = finalStreamingResponse
+                            .map(response -> SseFormatter.formatResponseAsSSE(response, eventIdCounter.getAndIncrement()));
+                    // Write SSE-formatted strings to HTTP response
+                    MultiSseSupport.writeSseStrings(sseEvents, rc, context);
                 });
 
             } else {
@@ -295,34 +300,30 @@ private static com.google.protobuf.MessageOrBuilder convertToProto(A2AResponse
+     * This class only handles HTTP-specific concerns (writing to response, backpressure, disconnect).
+     * SSE formatting and JSON serialization are handled by {@link SseFormatter}.
+     */
     private static class MultiSseSupport {
+        private static final Logger logger = LoggerFactory.getLogger(MultiSseSupport.class);
 
         private MultiSseSupport() {
             // Avoid direct instantiation.
         }
 
-        private static void initialize(HttpServerResponse response) {
-            if (response.bytesWritten() == 0) {
-                MultiMap headers = response.headers();
-                if (headers.get(CONTENT_TYPE) == null) {
-                    headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
-                }
-                response.setChunked(true);
-            }
-        }
-
-        private static void onWriteDone(Flow.Subscription subscription, AsyncResult ar, RoutingContext rc) {
-            if (ar.failed()) {
-                rc.fail(ar.cause());
-            } else {
-                subscription.request(1);
-            }
-        }
-
-        public static void write(Multi multi, RoutingContext rc) {
+        /**
+         * Write SSE-formatted strings to HTTP response.
+         *
+         * @param sseStrings Multi stream of SSE-formatted strings (from SseFormatter)
+         * @param rc         Vert.x routing context
+         * @param context    A2A server call context (for EventConsumer cancellation)
+         */
+        public static void writeSseStrings(Multi sseStrings, RoutingContext rc, ServerCallContext context) {
             HttpServerResponse response = rc.response();
-            multi.subscribe().withSubscriber(new Flow.Subscriber() {
+
+            sseStrings.subscribe().withSubscriber(new Flow.Subscriber() {
                 Flow.Subscription upstream;
 
                 @Override
@@ -330,6 +331,13 @@ public void onSubscribe(Flow.Subscription subscription) {
                     this.upstream = subscription;
                     this.upstream.request(1);
 
+                    // On client disconnect, invoke the context's EventConsumer cancel callback and cancel the subscription
+                    response.closeHandler(v -> {
+                        logger.info("SSE connection closed by client, calling EventConsumer.cancel() to stop polling loop");
+                        context.invokeEventConsumerCancelCallback();
+                        subscription.cancel();
+                    });
+
                     // Notify tests that we are subscribed
                     Runnable runnable = streamingMultiSseSupportSubscribedRunnable;
                     if (runnable != null) {
@@ -338,54 +346,50 @@ public void onSubscribe(Flow.Subscription subscription) {
                 }
 
                 @Override
-                public void onNext(Buffer item) {
-                    initialize(response);
-                    response.write(item, new Handler>() {
+                public void onNext(String sseEvent) {
+                    // Set SSE headers on first event
+                    if (response.bytesWritten() == 0) {
+                        MultiMap headers = response.headers();
+                        if (headers.get(CONTENT_TYPE) == null) {
+                            headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
+                        }
+                        response.setChunked(true);
+                    }
+
+                    // Write SSE-formatted string to response
+                    response.write(Buffer.buffer(sseEvent), new Handler>() {
                         @Override
                         public void handle(AsyncResult ar) {
-                            onWriteDone(upstream, ar, rc);
+                            if (ar.failed()) {
+                                // Client disconnected or write failed - cancel upstream to stop EventConsumer
+                                upstream.cancel();
+                                rc.fail(ar.cause());
+                            } else {
+                                upstream.request(1);
+                            }
                         }
                     });
                 }
 
                 @Override
                 public void onError(Throwable throwable) {
+                    // Cancel upstream to stop EventConsumer when error occurs
+                    upstream.cancel();
                     rc.fail(throwable);
                 }
 
                 @Override
                 public void onComplete() {
-                    endOfStream(response);
-                }
-            });
-        }
-
-        public static void subscribeObject(Multi multi, RoutingContext rc) {
-            AtomicLong count = new AtomicLong();
-            write(multi.map(new Function() {
-                @Override
-                public Buffer apply(Object o) {
-                    if (o instanceof ReactiveRoutes.ServerSentEvent) {
-                        ReactiveRoutes.ServerSentEvent ev = (ReactiveRoutes.ServerSentEvent) o;
-                        long id = ev.id() != -1 ? ev.id() : count.getAndIncrement();
-                        String e = ev.event() == null ? "" : "event: " + ev.event() + "\n";
-                        String data = serializeResponse((A2AResponse) ev.data());
-                        return Buffer.buffer(e + "data: " + data + "\nid: " + id + "\n\n");
+                    if (response.bytesWritten() == 0) {
+                        // No events written - still set SSE content type
+                        MultiMap headers = response.headers();
+                        if (headers.get(CONTENT_TYPE) == null) {
+                            headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
+                        }
                     }
-                    String data = serializeResponse((A2AResponse) o);
-                    return Buffer.buffer("data: " + data + "\nid: " + count.getAndIncrement() + "\n\n");
-                }
-            }), rc);
-        }
-
-        private static void endOfStream(HttpServerResponse response) {
-            if (response.bytesWritten() == 0) { // No item
-                MultiMap headers = response.headers();
-                if (headers.get(CONTENT_TYPE) == null) {
-                    headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
+                    response.end();
                 }
-            }
-            response.end();
+            });
         }
     }
 }
diff --git a/reference/jsonrpc/src/test/resources/application.properties b/reference/jsonrpc/src/test/resources/application.properties
index 7b9cea9cc..e612925d4 100644
--- a/reference/jsonrpc/src/test/resources/application.properties
+++ b/reference/jsonrpc/src/test/resources/application.properties
@@ -1 +1,6 @@
 quarkus.arc.selected-alternatives=io.a2a.server.apps.common.TestHttpClient
+
+# Debug logging for event processing and request handling
+quarkus.log.category."io.a2a.server.events".level=DEBUG
+quarkus.log.category."io.a2a.server.requesthandlers".level=DEBUG
+quarkus.log.category."io.a2a.server.tasks".level=DEBUG
diff --git a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java
index 46d0d38e6..7a50f0afb 100644
--- a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java
+++ b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java
@@ -15,7 +15,8 @@
 import java.util.concurrent.Executor;
 import java.util.concurrent.Flow;
 import java.util.concurrent.atomic.AtomicLong;
-import java.util.function.Function;
+
+import io.a2a.server.util.sse.SseFormatter;
 
 import jakarta.annotation.security.PermitAll;
 import jakarta.enterprise.inject.Instance;
@@ -38,7 +39,6 @@
 import io.a2a.util.Utils;
 import io.quarkus.security.Authenticated;
 import io.quarkus.vertx.web.Body;
-import io.quarkus.vertx.web.ReactiveRoutes;
 import io.quarkus.vertx.web.Route;
 import io.smallrye.mutiny.Multi;
 import io.vertx.core.AsyncResult;
@@ -110,10 +110,14 @@ public void sendMessageStreaming(@Body String body, RoutingContext rc) {
             if (error != null) {
                 sendResponse(rc, error);
             } else if (streamingResponse != null) {
-                Multi events = Multi.createFrom().publisher(streamingResponse.getPublisher());
+                final HTTPRestStreamingResponse finalStreamingResponse = streamingResponse;
                 executor.execute(() -> {
-                    MultiSseSupport.subscribeObject(
-                            events.map(i -> (Object) i), rc);
+                    // Wrap the Flow.Publisher of JSON payloads in a Multi and map each payload to an SSE-formatted string
+                    AtomicLong eventIdCounter = new AtomicLong(0);
+                    Multi sseEvents = Multi.createFrom().publisher(finalStreamingResponse.getPublisher())
+                            .map(json -> SseFormatter.formatJsonAsSSE(json, eventIdCounter.getAndIncrement()));
+                    // Write SSE-formatted strings to HTTP response
+                    MultiSseSupport.writeSseStrings(sseEvents, rc, context);
                 });
             }
         }
@@ -243,10 +247,14 @@ public void subscribeToTask(RoutingContext rc) {
             if (error != null) {
                 sendResponse(rc, error);
             } else if (streamingResponse != null) {
-                Multi events = Multi.createFrom().publisher(streamingResponse.getPublisher());
+                final HTTPRestStreamingResponse finalStreamingResponse = streamingResponse;
                 executor.execute(() -> {
-                    MultiSseSupport.subscribeObject(
-                            events.map(i -> (Object) i), rc);
+                    // Wrap the Flow.Publisher of JSON payloads in a Multi and map each payload to an SSE-formatted string
+                    AtomicLong eventIdCounter = new AtomicLong(0);
+                    Multi sseEvents = Multi.createFrom().publisher(finalStreamingResponse.getPublisher())
+                            .map(json -> SseFormatter.formatJsonAsSSE(json, eventIdCounter.getAndIncrement()));
+                    // Write SSE-formatted strings to HTTP response
+                    MultiSseSupport.writeSseStrings(sseEvents, rc, context);
                 });
             }
         }
@@ -450,34 +458,30 @@ public String getUsername() {
         }
     }
 
-    // Port of import io.quarkus.vertx.web.runtime.MultiSseSupport, which is considered internal API
+    /**
+     * Simplified SSE support for Vert.x/Quarkus.
+     * 

+ * This class only handles HTTP-specific concerns (writing to response, backpressure, disconnect). + * SSE formatting and JSON serialization are handled by {@link SseFormatter}. + */ private static class MultiSseSupport { + private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MultiSseSupport.class); private MultiSseSupport() { // Avoid direct instantiation. } - private static void initialize(HttpServerResponse response) { - if (response.bytesWritten() == 0) { - MultiMap headers = response.headers(); - if (headers.get(CONTENT_TYPE) == null) { - headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); - } - response.setChunked(true); - } - } - - private static void onWriteDone(Flow.@Nullable Subscription subscription, AsyncResult ar, RoutingContext rc) { - if (ar.failed()) { - rc.fail(ar.cause()); - } else if (subscription != null) { - subscription.request(1); - } - } - - private static void write(Multi multi, RoutingContext rc) { + /** + * Write SSE-formatted strings to HTTP response. 
+ * + * @param sseStrings Multi stream of SSE-formatted strings (from SseFormatter) + * @param rc Vert.x routing context + * @param context A2A server call context (for EventConsumer cancellation) + */ + public static void writeSseStrings(Multi sseStrings, RoutingContext rc, ServerCallContext context) { HttpServerResponse response = rc.response(); - multi.subscribe().withSubscriber(new Flow.Subscriber() { + + sseStrings.subscribe().withSubscriber(new Flow.Subscriber() { Flow.@Nullable Subscription upstream; @Override @@ -485,6 +489,13 @@ public void onSubscribe(Flow.Subscription subscription) { this.upstream = subscription; this.upstream.request(1); + // Detect client disconnect and call EventConsumer.cancel() directly + response.closeHandler(v -> { + logger.debug("REST SSE connection closed by client, calling EventConsumer.cancel() to stop polling loop"); + context.invokeEventConsumerCancelCallback(); + subscription.cancel(); + }); + // Notify tests that we are subscribed Runnable runnable = streamingMultiSseSupportSubscribedRunnable; if (runnable != null) { @@ -493,53 +504,64 @@ public void onSubscribe(Flow.Subscription subscription) { } @Override - public void onNext(Buffer item) { - initialize(response); - response.write(item, new Handler>() { + public void onNext(String sseEvent) { + // Set SSE headers on first event + if (response.bytesWritten() == 0) { + MultiMap headers = response.headers(); + if (headers.get(CONTENT_TYPE) == null) { + headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + } + // Additional SSE headers to prevent buffering + headers.set("Cache-Control", "no-cache"); + headers.set("X-Accel-Buffering", "no"); // Disable nginx buffering + response.setChunked(true); + + // CRITICAL: Disable write queue max size to prevent buffering + // Vert.x buffers writes by default - we need immediate flushing for SSE + response.setWriteQueueMaxSize(1); // Force immediate flush + + // Send initial SSE comment to kickstart the stream + // This forces Vert.x to 
send headers and start the stream immediately + response.write(": SSE stream started\n\n"); + } + + // Write SSE-formatted string to response + response.write(Buffer.buffer(sseEvent), new Handler>() { @Override public void handle(AsyncResult ar) { - onWriteDone(upstream, ar, rc); + if (ar.failed()) { + // Client disconnected or write failed - cancel upstream to stop EventConsumer + // NullAway: upstream is guaranteed non-null after onSubscribe + java.util.Objects.requireNonNull(upstream).cancel(); + rc.fail(ar.cause()); + } else { + // NullAway: upstream is guaranteed non-null after onSubscribe + java.util.Objects.requireNonNull(upstream).request(1); + } } }); } @Override public void onError(Throwable throwable) { + // Cancel upstream to stop EventConsumer when error occurs + // NullAway: upstream is guaranteed non-null after onSubscribe + java.util.Objects.requireNonNull(upstream).cancel(); rc.fail(throwable); } @Override public void onComplete() { - endOfStream(response); - } - }); - } - - private static void subscribeObject(Multi multi, RoutingContext rc) { - AtomicLong count = new AtomicLong(); - write(multi.map(new Function() { - @Override - public Buffer apply(Object o) { - if (o instanceof ReactiveRoutes.ServerSentEvent) { - ReactiveRoutes.ServerSentEvent ev = (ReactiveRoutes.ServerSentEvent) o; - long id = ev.id() != -1 ? ev.id() : count.getAndIncrement(); - String e = ev.event() == null ? 
"" : "event: " + ev.event() + "\n"; - return Buffer.buffer(e + "data: " + ev.data() + "\nid: " + id + "\n\n"); - } else { - return Buffer.buffer("data: " + o + "\nid: " + count.getAndIncrement() + "\n\n"); + if (response.bytesWritten() == 0) { + // No events written - still set SSE content type + MultiMap headers = response.headers(); + if (headers.get(CONTENT_TYPE) == null) { + headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + } } + response.end(); } - }), rc); - } - - private static void endOfStream(HttpServerResponse response) { - if (response.bytesWritten() == 0) { // No item - MultiMap headers = response.headers(); - if (headers.get(CONTENT_TYPE) == null) { - headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); - } - } - response.end(); + }); } } diff --git a/server-common/src/main/java/io/a2a/server/ServerCallContext.java b/server-common/src/main/java/io/a2a/server/ServerCallContext.java index ba5c20b95..c12c60c21 100644 --- a/server-common/src/main/java/io/a2a/server/ServerCallContext.java +++ b/server-common/src/main/java/io/a2a/server/ServerCallContext.java @@ -16,6 +16,7 @@ public class ServerCallContext { private final Set requestedExtensions; private final Set activatedExtensions; private final @Nullable String requestedProtocolVersion; + private volatile @Nullable Runnable eventConsumerCancelCallback; public ServerCallContext(User user, Map state, Set requestedExtensions) { this(user, state, requestedExtensions, null); @@ -64,4 +65,64 @@ public boolean isExtensionRequested(String extensionUri) { public @Nullable String getRequestedProtocolVersion() { return requestedProtocolVersion; } + + /** + * Sets the callback to be invoked when the client disconnects or the call is cancelled. + *

+ * This callback is typically used to stop the EventConsumer polling loop when a client + * disconnects from a streaming endpoint. The callback is invoked by transport layers + * (JSON-RPC over HTTP/SSE, REST over HTTP/SSE, gRPC streaming) when they detect that + * the client has closed the connection. + *

+ *

+ * Thread Safety: The callback may be invoked from any thread, depending + * on the transport implementation. Implementations should be thread-safe. + *

+ * Example Usage: + *
{@code
+     * EventConsumer consumer = new EventConsumer(queue);
+     * context.setEventConsumerCancelCallback(consumer::cancel);
+     * }
+ * + * @param callback the callback to invoke on client disconnect, or null to clear any existing callback + * @see #invokeEventConsumerCancelCallback() + */ + public void setEventConsumerCancelCallback(@Nullable Runnable callback) { + this.eventConsumerCancelCallback = callback; + } + + /** + * Invokes the EventConsumer cancel callback if one has been set. + *

+ * This method is called by transport layers when a client disconnects or cancels a + * streaming request. It triggers the callback registered via + * {@link #setEventConsumerCancelCallback(Runnable)}, which typically stops the + * EventConsumer polling loop. + *

+ *

+ * Transport-Specific Behavior: + *

+ *
    + *
  • JSON-RPC/REST over HTTP/SSE: Called from Vert.x + * {@code HttpServerResponse.closeHandler()} when the SSE connection is closed
  • + *
  • gRPC streaming: Called from gRPC + * {@code Context.CancellationListener.cancelled()} when the call is cancelled
  • + *
+ *

+ * Thread Safety: This method is thread-safe. The callback is stored + * in a volatile field and null-checked before invocation to prevent race conditions. + *

+ *

+ * If no callback has been set, this method does nothing (no-op). + *

+ * + * @see #setEventConsumerCancelCallback(Runnable) + * @see io.a2a.server.events.EventConsumer#cancel() + */ + public void invokeEventConsumerCancelCallback() { + Runnable callback = this.eventConsumerCancelCallback; + if (callback != null) { + callback.run(); + } + } } diff --git a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java index d4fe5b395..83b4575ca 100644 --- a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java +++ b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java @@ -19,6 +19,8 @@ public class EventConsumer { private static final Logger LOGGER = LoggerFactory.getLogger(EventConsumer.class); private final EventQueue queue; private volatile @Nullable Throwable error; + private volatile boolean cancelled = false; + private volatile boolean agentCompleted = false; private static final String ERROR_MSG = "Agent did not return any response"; private static final int NO_WAIT = -1; @@ -45,6 +47,14 @@ public Flow.Publisher consumeAll() { boolean completed = false; try { while (true) { + // Check if cancelled by client disconnect + if (cancelled) { + LOGGER.debug("EventConsumer detected cancellation, exiting polling loop for queue {}", System.identityHashCode(queue)); + completed = true; + tube.complete(); + return; + } + if (error != null) { completed = true; tube.fail(error); @@ -60,13 +70,32 @@ public Flow.Publisher consumeAll() { EventQueueItem item; Event event; try { + LOGGER.debug("EventConsumer polling queue {} (error={}, agentCompleted={})", + System.identityHashCode(queue), error, agentCompleted); item = queue.dequeueEventItem(QUEUE_WAIT_MILLISECONDS); if (item == null) { + LOGGER.debug("EventConsumer poll timeout (null item), agentCompleted={}", agentCompleted); + // If agent completed, a poll timeout means no more events are coming + // MainEventBusProcessor has 500ms to distribute events from MainEventBus + // If we timeout 
with agentCompleted=true, all events have been distributed + if (agentCompleted) { + LOGGER.debug("Agent completed and poll timeout, closing queue for graceful completion (queue={})", + System.identityHashCode(queue)); + queue.close(); + completed = true; + tube.complete(); + return; + } continue; } event = item.getEvent(); + LOGGER.debug("EventConsumer received event: {} (queue={})", + event.getClass().getSimpleName(), System.identityHashCode(queue)); + // Defensive logging for error handling if (event instanceof Throwable thr) { + LOGGER.debug("EventConsumer detected Throwable event: {} - triggering tube.fail()", + thr.getClass().getSimpleName()); tube.fail(thr); return; } @@ -138,12 +167,25 @@ private boolean isStreamTerminatingTask(Task task) { public EnhancedRunnable.DoneCallback createAgentRunnableDoneCallback() { return agentRunnable -> { + LOGGER.debug("EventConsumer: Agent done callback invoked (hasError={}, queue={})", + agentRunnable.getError() != null, System.identityHashCode(queue)); if (agentRunnable.getError() != null) { error = agentRunnable.getError(); + LOGGER.debug("EventConsumer: Set error field from agent callback"); + } else { + agentCompleted = true; + LOGGER.debug("EventConsumer: Agent completed successfully, set agentCompleted=true, will close queue after draining"); } }; } + public void cancel() { + // Set cancellation flag to stop polling loop + // Called when client disconnects without completing stream + LOGGER.debug("EventConsumer cancelled (client disconnect), stopping polling for queue {}", System.identityHashCode(queue)); + cancelled = true; + } + public void close() { // Close the queue to stop the polling loop in consumeAll() // This will cause EventQueueClosedException and exit the while(true) loop diff --git a/server-common/src/main/java/io/a2a/server/events/EventQueue.java b/server-common/src/main/java/io/a2a/server/events/EventQueue.java index a08f63084..0dff01a31 100644 --- 
a/server-common/src/main/java/io/a2a/server/events/EventQueue.java +++ b/server-common/src/main/java/io/a2a/server/events/EventQueue.java @@ -1,6 +1,7 @@ package io.a2a.server.events; import java.util.List; +import java.util.Objects; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; @@ -23,7 +24,7 @@ * and hierarchical queue structures via MainQueue and ChildQueue implementations. *

*

- * Use {@link #builder()} to create configured instances or extend MainQueue/ChildQueue directly. + * Use {@link #builder(MainEventBus)} to create configured instances or extend MainQueue/ChildQueue directly. *

*/ public abstract class EventQueue implements AutoCloseable { @@ -36,14 +37,6 @@ public abstract class EventQueue implements AutoCloseable { public static final int DEFAULT_QUEUE_SIZE = 1000; private final int queueSize; - /** - * Internal blocking queue for storing event queue items. - */ - protected final BlockingQueue queue = new LinkedBlockingDeque<>(); - /** - * Semaphore for backpressure control, limiting the number of pending events. - */ - protected final Semaphore semaphore; private volatile boolean closed = false; /** @@ -64,7 +57,6 @@ protected EventQueue(int queueSize) { throw new IllegalArgumentException("Queue size must be greater than 0"); } this.queueSize = queueSize; - this.semaphore = new Semaphore(queueSize, true); LOGGER.trace("Creating {} with queue size: {}", this, queueSize); } @@ -78,8 +70,8 @@ protected EventQueue(EventQueue parent) { LOGGER.trace("Creating {}, parent: {}", this, parent); } - static EventQueueBuilder builder() { - return new EventQueueBuilder(); + static EventQueueBuilder builder(MainEventBus mainEventBus) { + return new EventQueueBuilder().mainEventBus(mainEventBus); } /** @@ -95,6 +87,7 @@ public static class EventQueueBuilder { private @Nullable String taskId; private List onCloseCallbacks = new java.util.ArrayList<>(); private @Nullable TaskStateProvider taskStateProvider; + private @Nullable MainEventBus mainEventBus; /** * Sets the maximum queue size. @@ -153,17 +146,31 @@ public EventQueueBuilder taskStateProvider(TaskStateProvider taskStateProvider) return this; } + /** + * Sets the main event bus + * + * @param mainEventBus the main event bus + * @return this builder + */ + public EventQueueBuilder mainEventBus(MainEventBus mainEventBus) { + this.mainEventBus = mainEventBus; + return this; + } + /** * Builds and returns the configured EventQueue. 
* * @return a new MainQueue instance */ public EventQueue build() { - if (hook != null || !onCloseCallbacks.isEmpty() || taskStateProvider != null) { - return new MainQueue(queueSize, hook, taskId, onCloseCallbacks, taskStateProvider); - } else { - return new MainQueue(queueSize); + // MainEventBus is REQUIRED - enforce single architectural path + if (mainEventBus == null) { + throw new IllegalStateException("MainEventBus is required for EventQueue creation"); + } + if (taskId == null) { + throw new IllegalStateException("taskId is required for EventQueue creation"); } + return new MainQueue(queueSize, hook, taskId, onCloseCallbacks, taskStateProvider, mainEventBus); } } @@ -209,22 +216,7 @@ public void enqueueEvent(Event event) { * @param item the event queue item to enqueue * @throws RuntimeException if interrupted while waiting to acquire the semaphore */ - public void enqueueItem(EventQueueItem item) { - Event event = item.getEvent(); - if (closed) { - LOGGER.warn("Queue is closed. Event will not be enqueued. {} {}", this, event); - return; - } - // Call toString() since for errors we don't really want the full stacktrace - try { - semaphore.acquire(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Unable to acquire the semaphore to enqueue the event", e); - } - queue.add(item); - LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); - } + public abstract void enqueueItem(EventQueueItem item); /** * Creates a child queue that shares events with this queue. @@ -244,48 +236,17 @@ public void enqueueItem(EventQueueItem item) { * This method returns the full EventQueueItem wrapper, allowing callers to check * metadata like whether the event is replicated via {@link EventQueueItem#isReplicated()}. *

+ *

+ * Note: MainQueue does not support dequeue operations - only ChildQueues can be consumed. + *

* * @param waitMilliSeconds the maximum time to wait in milliseconds * @return the EventQueueItem, or null if timeout occurs * @throws EventQueueClosedException if the queue is closed and empty + * @throws UnsupportedOperationException if called on MainQueue */ - public @Nullable EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { - if (closed && queue.isEmpty()) { - LOGGER.debug("Queue is closed, and empty. Sending termination message. {}", this); - throw new EventQueueClosedException(); - } - try { - if (waitMilliSeconds <= 0) { - EventQueueItem item = queue.poll(); - if (item != null) { - Event event = item.getEvent(); - // Call toString() since for errors we don't really want the full stacktrace - LOGGER.debug("Dequeued event item (no wait) {} {}", this, event instanceof Throwable ? event.toString() : event); - semaphore.release(); - } - return item; - } - try { - LOGGER.trace("Polling queue {} (wait={}ms)", System.identityHashCode(this), waitMilliSeconds); - EventQueueItem item = queue.poll(waitMilliSeconds, TimeUnit.MILLISECONDS); - if (item != null) { - Event event = item.getEvent(); - // Call toString() since for errors we don't really want the full stacktrace - LOGGER.debug("Dequeued event item (waiting) {} {}", this, event instanceof Throwable ? event.toString() : event); - semaphore.release(); - } else { - LOGGER.trace("Dequeue timeout (null) from queue {}", System.identityHashCode(this)); - } - return item; - } catch (InterruptedException e) { - LOGGER.debug("Interrupted dequeue (waiting) {}", this); - Thread.currentThread().interrupt(); - return null; - } - } finally { - signalQueuePollerStarted(); - } - } + @Nullable + public abstract EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException; /** * Placeholder method for task completion notification. @@ -295,6 +256,18 @@ public void taskDone() { // TODO Not sure if needed yet. BlockingQueue.poll()/.take() remove the events. 
} + /** + * Returns the current size of the queue. + *

+ * For MainQueue: returns the number of events in-flight (in MainEventBus queue + currently being processed). + * This reflects actual capacity usage tracked by the semaphore. + * For ChildQueue: returns the size of the local consumption queue. + *

+ * + * @return the number of events currently in the queue + */ + public abstract int size(); + /** * Closes this event queue gracefully, allowing pending events to be consumed. */ @@ -348,72 +321,64 @@ protected void doClose(boolean immediate) { LOGGER.debug("Closing {} (immediate={})", this, immediate); closed = true; } - - if (immediate) { - // Immediate close: clear pending events - queue.clear(); - LOGGER.debug("Cleared queue for immediate close: {}", this); - } - // For graceful close, let the queue drain naturally through normal consumption + // Subclasses handle immediate close logic (e.g., ChildQueue clears its local queue) } static class MainQueue extends EventQueue { private final List children = new CopyOnWriteArrayList<>(); + protected final Semaphore semaphore; private final CountDownLatch pollingStartedLatch = new CountDownLatch(1); private final AtomicBoolean pollingStarted = new AtomicBoolean(false); private final @Nullable EventEnqueueHook enqueueHook; - private final @Nullable String taskId; + private final String taskId; private final List onCloseCallbacks; private final @Nullable TaskStateProvider taskStateProvider; - - MainQueue() { - super(); - this.enqueueHook = null; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize) { - super(queueSize); - this.enqueueHook = null; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(EventEnqueueHook hook) { - super(); - this.enqueueHook = hook; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize, EventEnqueueHook hook) { - super(queueSize); - this.enqueueHook = hook; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize, @Nullable EventEnqueueHook hook, @Nullable String taskId, List onCloseCallbacks, @Nullable TaskStateProvider 
taskStateProvider) { + private final MainEventBus mainEventBus; + + MainQueue(int queueSize, + @Nullable EventEnqueueHook hook, + String taskId, + List onCloseCallbacks, + @Nullable TaskStateProvider taskStateProvider, + @Nullable MainEventBus mainEventBus) { super(queueSize); + this.semaphore = new Semaphore(queueSize, true); this.enqueueHook = hook; this.taskId = taskId; this.onCloseCallbacks = List.copyOf(onCloseCallbacks); // Defensive copy this.taskStateProvider = taskStateProvider; - LOGGER.debug("Created MainQueue for task {} with {} onClose callbacks and TaskStateProvider: {}", + this.mainEventBus = Objects.requireNonNull(mainEventBus, "MainEventBus is required"); + LOGGER.debug("Created MainQueue for task {} with {} onClose callbacks, TaskStateProvider: {}, MainEventBus configured", taskId, onCloseCallbacks.size(), taskStateProvider != null); } + public EventQueue tap() { ChildQueue child = new ChildQueue(this); children.add(child); return child; } + /** + * Returns the current number of child queues. + * Useful for debugging and logging event distribution. 
+ */ + public int getChildCount() { + return children.size(); + } + + @Override + public EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { + throw new UnsupportedOperationException("MainQueue cannot be consumed directly - use tap() to create a ChildQueue for consumption"); + } + + @Override + public int size() { + // Return total in-flight events (in MainEventBus + being processed) + // This aligns with semaphore's capacity tracking + return getQueueSize() - semaphore.availablePermits(); + } + @Override public void enqueueItem(EventQueueItem item) { // MainQueue must accept events even when closed to support: @@ -432,14 +397,13 @@ public void enqueueItem(EventQueueItem item) { throw new RuntimeException("Unable to acquire the semaphore to enqueue the event", e); } - // Add to this MainQueue's internal queue - queue.add(item); LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); - // Distribute to all ChildQueues (they will receive the event even if MainQueue is closed) - children.forEach(eq -> eq.internalEnqueueItem(item)); + // Submit to MainEventBus for centralized persistence + distribution + // MainEventBus is guaranteed non-null by constructor requirement + mainEventBus.submit(taskId, this, item); - // Trigger replication hook if configured + // Trigger replication hook if configured (for inter-process replication) if (enqueueHook != null) { enqueueHook.onEnqueue(item); } @@ -493,6 +457,36 @@ void childClosing(ChildQueue child, boolean immediate) { this.doClose(immediate); } + /** + * Distribute event to all ChildQueues. + * Called by MainEventBusProcessor after TaskStore persistence. 
+ */ + void distributeToChildren(EventQueueItem item) { + int childCount = children.size(); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("MainQueue[{}]: Distributing event {} to {} children", + taskId, item.getEvent().getClass().getSimpleName(), childCount); + } + children.forEach(child -> { + LOGGER.debug("MainQueue[{}]: Enqueueing event {} to child queue", + taskId, item.getEvent().getClass().getSimpleName()); + child.internalEnqueueItem(item); + }); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("MainQueue[{}]: Completed distribution of {} to {} children", + taskId, item.getEvent().getClass().getSimpleName(), childCount); + } + } + + /** + * Release the semaphore after event processing is complete. + * Called by MainEventBusProcessor in finally block to ensure release even on exceptions. + * Balances the acquire() in enqueueEvent() - protects MainEventBus throughput. + */ + void releaseSemaphore() { + semaphore.release(); + } + /** * Get the count of active child queues. * Used for testing to verify reference counting mechanism. 
@@ -543,6 +537,8 @@ public void close(boolean immediate, boolean notifyParent) { static class ChildQueue extends EventQueue { private final MainQueue parent; + private final BlockingQueue queue = new LinkedBlockingDeque<>(); + private volatile boolean immediateClose = false; public ChildQueue(MainQueue parent) { this.parent = parent; @@ -553,8 +549,69 @@ public void enqueueEvent(Event event) { parent.enqueueEvent(event); } + @Override + public void enqueueItem(EventQueueItem item) { + // ChildQueue delegates writes to parent MainQueue + parent.enqueueItem(item); + } + private void internalEnqueueItem(EventQueueItem item) { - super.enqueueItem(item); + // Internal method called by MainEventBusProcessor to add to local queue + // Note: Semaphore is managed by parent MainQueue (acquire/release), not ChildQueue + Event event = item.getEvent(); + // For graceful close: still accept events so they can be drained by EventConsumer + // For immediate close: reject events to stop distribution quickly + if (isClosed() && immediateClose) { + LOGGER.warn("ChildQueue is immediately closed. Event will not be enqueued. {} {}", this, event); + return; + } + if (!queue.offer(item)) { + LOGGER.warn("ChildQueue {} is full. Closing immediately.", this); + close(true); // immediate close + } else { + LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); + } + } + + @Override + @Nullable + public EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { + // For immediate close: exit immediately even if queue is not empty (race with MainEventBusProcessor) + // For graceful close: only exit when queue is empty (wait for all events to be consumed) + if (isClosed() && (queue.isEmpty() || immediateClose)) { + LOGGER.debug("ChildQueue is closed{}, sending termination message. {} (queueSize={})", + immediateClose ? 
" (immediate)" : " and empty", + this, + queue.size()); + throw new EventQueueClosedException(); + } + try { + if (waitMilliSeconds <= 0) { + EventQueueItem item = queue.poll(); + if (item != null) { + Event event = item.getEvent(); + LOGGER.debug("Dequeued event item (no wait) {} {}", this, event instanceof Throwable ? event.toString() : event); + } + return item; + } + try { + LOGGER.trace("Polling ChildQueue {} (wait={}ms)", System.identityHashCode(this), waitMilliSeconds); + EventQueueItem item = queue.poll(waitMilliSeconds, TimeUnit.MILLISECONDS); + if (item != null) { + Event event = item.getEvent(); + LOGGER.debug("Dequeued event item (waiting) {} {}", this, event instanceof Throwable ? event.toString() : event); + } else { + LOGGER.trace("Dequeue timeout (null) from ChildQueue {}", System.identityHashCode(this)); + } + return item; + } catch (InterruptedException e) { + LOGGER.debug("Interrupted dequeue (waiting) {}", this); + Thread.currentThread().interrupt(); + return null; + } + } finally { + signalQueuePollerStarted(); + } } @Override @@ -562,6 +619,12 @@ public EventQueue tap() { throw new IllegalStateException("Can only tap the main queue"); } + @Override + public int size() { + // Return size of local consumption queue + return queue.size(); + } + @Override public void awaitQueuePollerStart() throws InterruptedException { parent.awaitQueuePollerStart(); @@ -572,6 +635,19 @@ public void signalQueuePollerStarted() { parent.signalQueuePollerStarted(); } + @Override + protected void doClose(boolean immediate) { + super.doClose(immediate); // Sets closed flag + if (immediate) { + // Immediate close: clear pending events from local queue + this.immediateClose = true; + int clearedCount = queue.size(); + queue.clear(); + LOGGER.debug("Cleared {} events from ChildQueue for immediate close: {}", clearedCount, this); + } + // For graceful close, let the queue drain naturally through normal consumption + } + @Override public void close() { close(false); diff 
--git a/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java b/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java index e5a17e0e7..abd043614 100644 --- a/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java +++ b/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java @@ -34,16 +34,20 @@ protected InMemoryQueueManager() { this.taskStateProvider = null; } + MainEventBus mainEventBus; + @Inject - public InMemoryQueueManager(TaskStateProvider taskStateProvider) { + public InMemoryQueueManager(TaskStateProvider taskStateProvider, MainEventBus mainEventBus) { + this.mainEventBus = mainEventBus; this.factory = new DefaultEventQueueFactory(); this.taskStateProvider = taskStateProvider; } - // For testing with custom factory - public InMemoryQueueManager(EventQueueFactory factory, TaskStateProvider taskStateProvider) { + // For testing/extensions with custom factory and MainEventBus + public InMemoryQueueManager(EventQueueFactory factory, TaskStateProvider taskStateProvider, MainEventBus mainEventBus) { this.factory = factory; this.taskStateProvider = taskStateProvider; + this.mainEventBus = mainEventBus; } @Override @@ -54,6 +58,24 @@ public void add(String taskId, EventQueue queue) { } } + @Override + public void switchKey(String oldId, String newId) { + EventQueue queue = queues.remove(oldId); + if (queue == null) { + throw new IllegalStateException("No queue found for old ID: " + oldId); + } + + EventQueue existing = queues.putIfAbsent(newId, queue); + if (existing != null) { + // Rollback - put old one back + queues.putIfAbsent(oldId, queue); + throw new TaskQueueExistsException(); + } + + LOGGER.debug("Switched queue {} from temp ID {} to real task ID {}", + System.identityHashCode(queue), oldId, newId); + } + @Override public @Nullable EventQueue get(String taskId) { return queues.get(taskId); @@ -128,6 +150,12 @@ public void awaitQueuePollerStart(EventQueue eventQueue) throws 
InterruptedExcep eventQueue.awaitQueuePollerStart(); } + @Override + public EventQueue.EventQueueBuilder getEventQueueBuilder(String taskId) { + // Use the factory to ensure proper configuration (MainEventBus, callbacks, etc.) + return factory.builder(taskId); + } + @Override public int getActiveChildQueueCount(String taskId) { EventQueue queue = queues.get(taskId); @@ -142,6 +170,14 @@ public int getActiveChildQueueCount(String taskId) { return -1; } + @Override + public EventQueue.EventQueueBuilder createBaseEventQueueBuilder(String taskId) { + return EventQueue.builder(mainEventBus) + .taskId(taskId) + .addOnCloseCallback(getCleanupCallback(taskId)) + .taskStateProvider(taskStateProvider); + } + /** * Get the cleanup callback that removes a queue from the map when it closes. * This is exposed so that subclasses (like ReplicatedQueueManager) can reuse @@ -181,11 +217,8 @@ public Runnable getCleanupCallback(String taskId) { private class DefaultEventQueueFactory implements EventQueueFactory { @Override public EventQueue.EventQueueBuilder builder(String taskId) { - // Return builder with callback that removes queue from map when closed - return EventQueue.builder() - .taskId(taskId) - .addOnCloseCallback(getCleanupCallback(taskId)) - .taskStateProvider(taskStateProvider); + // Delegate to the base builder creation method + return createBaseEventQueueBuilder(taskId); } } } diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBus.java b/server-common/src/main/java/io/a2a/server/events/MainEventBus.java new file mode 100644 index 000000000..73500254e --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBus.java @@ -0,0 +1,42 @@ +package io.a2a.server.events; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +public class MainEventBus { + 
private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBus.class); + private final BlockingQueue queue; + + public MainEventBus() { + this.queue = new LinkedBlockingDeque<>(); + } + + public void submit(String taskId, EventQueue eventQueue, EventQueueItem item) { + try { + queue.put(new MainEventBusContext(taskId, eventQueue, item)); + LOGGER.debug("Submitted event for task {} to MainEventBus (queue size: {})", + taskId, queue.size()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted submitting to MainEventBus", e); + } + } + + public MainEventBusContext take() throws InterruptedException { + LOGGER.debug("MainEventBus: Waiting to take event (current queue size: {})...", queue.size()); + MainEventBusContext context = queue.take(); + LOGGER.debug("MainEventBus: Took event for task {} (remaining queue size: {})", + context.taskId(), queue.size()); + return context; + } + + public int size() { + return queue.size(); + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java new file mode 100644 index 000000000..f8e5e03ec --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java @@ -0,0 +1,11 @@ +package io.a2a.server.events; + +import java.util.Objects; + +record MainEventBusContext(String taskId, EventQueue eventQueue, EventQueueItem eventQueueItem) { + MainEventBusContext { + Objects.requireNonNull(taskId, "taskId cannot be null"); + Objects.requireNonNull(eventQueue, "eventQueue cannot be null"); + Objects.requireNonNull(eventQueueItem, "eventQueueItem cannot be null"); + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java new file mode 100644 index 000000000..91aaac3ef --- /dev/null +++ 
b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java @@ -0,0 +1,368 @@ +package io.a2a.server.events; + +import java.util.concurrent.CompletableFuture; + +import jakarta.annotation.Nullable; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import io.a2a.server.tasks.PushNotificationSender; +import io.a2a.server.tasks.TaskManager; +import io.a2a.server.tasks.TaskStore; +import io.a2a.spec.A2AServerException; +import io.a2a.spec.Event; +import io.a2a.spec.InternalError; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatusUpdateEvent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Background processor for the MainEventBus. + *

+ * This processor runs in a dedicated background thread, consuming events from the MainEventBus + * and performing two critical operations in order: + *

+ *
    + *
  1. Update TaskStore with event data (persistence FIRST)
  2. + *
  3. Distribute event to ChildQueues (clients see it AFTER persistence)
  4. + *
+ *

+ * This architecture ensures clients never receive events before they're persisted, + * eliminating race conditions and enabling reliable event replay. + *

+ *

+ * Note: This bean is eagerly initialized by {@link MainEventBusProcessorInitializer} + * to ensure the background thread starts automatically when the application starts. + *

+ */ +@ApplicationScoped +public class MainEventBusProcessor implements Runnable { + private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBusProcessor.class); + + /** + * Callback for testing synchronization with async event processing. + * Default is NOOP to avoid null checks in production code. + * Tests can inject their own callback via setCallback(). + */ + private volatile MainEventBusProcessorCallback callback = MainEventBusProcessorCallback.NOOP; + + /** + * Optional executor for push notifications. + * If null, uses default ForkJoinPool (async). + * Tests can inject a synchronous executor to ensure deterministic ordering. + */ + private volatile @Nullable java.util.concurrent.Executor pushNotificationExecutor = null; + + private final MainEventBus eventBus; + + private final TaskStore taskStore; + + private final PushNotificationSender pushSender; + + private volatile boolean running = true; + private @Nullable Thread processorThread; + + @Inject + public MainEventBusProcessor(MainEventBus eventBus, TaskStore taskStore, PushNotificationSender pushSender) { + this.eventBus = eventBus; + this.taskStore = taskStore; + this.pushSender = pushSender; + } + + /** + * Set a callback for testing synchronization with async event processing. + *

+ * This is primarily intended for tests that need to wait for event processing to complete. + * Pass null to reset to the default NOOP callback. + *

+ * + * @param callback the callback to invoke during event processing, or null for NOOP + */ + public void setCallback(MainEventBusProcessorCallback callback) { + this.callback = callback != null ? callback : MainEventBusProcessorCallback.NOOP; + } + + /** + * Set a custom executor for push notifications (primarily for testing). + *

+ * By default, push notifications are sent asynchronously using CompletableFuture.runAsync() + * with the default ForkJoinPool. For tests that need deterministic ordering of push + * notifications, inject a synchronous executor that runs tasks immediately on the calling thread. + *

+ * Example synchronous executor for tests: + *
{@code
+     * Executor syncExecutor = Runnable::run;
+     * mainEventBusProcessor.setPushNotificationExecutor(syncExecutor);
+     * }
+     *
+     * @param executor the executor to use for push notifications, or null to use default ForkJoinPool
+     */
+    public void setPushNotificationExecutor(java.util.concurrent.Executor executor) {
+        this.pushNotificationExecutor = executor;
+    }
+
+    /**
+     * Creates and starts the daemon thread that runs {@link #run()}.
+     */
+    @PostConstruct
+    void start() {
+        processorThread = new Thread(this, "MainEventBusProcessor");
+        processorThread.setDaemon(true); // Allow JVM to exit even if this thread is running
+        processorThread.start();
+        LOGGER.info("MainEventBusProcessor started");
+    }
+
+    /**
+     * No-op method to force CDI proxy resolution and ensure @PostConstruct has been called.
+     * Called by MainEventBusProcessorInitializer during application startup.
+     */
+    public void ensureStarted() {
+        // Method intentionally empty - just forces proxy resolution
+    }
+
+    /**
+     * Signals the processing loop to end, interrupts the processor thread and waits up to
+     * five seconds for it to terminate.
+     */
+    @PreDestroy
+    void stop() {
+        LOGGER.info("MainEventBusProcessor stopping...");
+        running = false;
+        Thread thread = processorThread;
+        if (thread != null) {
+            thread.interrupt();
+            try {
+                long start = System.currentTimeMillis();
+                thread.join(5000); // Wait up to 5 seconds
+                long elapsed = System.currentTimeMillis() - start;
+                // join(timeout) also returns normally when the timeout expires, so only
+                // report a clean stop when the thread has really terminated.
+                if (thread.isAlive()) {
+                    LOGGER.warn("MainEventBusProcessor thread still running after {}ms, giving up waiting", elapsed);
+                } else {
+                    LOGGER.info("MainEventBusProcessor thread stopped in {}ms", elapsed);
+                }
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                LOGGER.warn("Interrupted while waiting for MainEventBusProcessor thread to stop");
+            }
+        }
+        LOGGER.info("MainEventBusProcessor stopped");
+    }
+
+    /**
+     * Processing loop: takes one context at a time from the MainEventBus and processes it
+     * until {@code running} is cleared or the thread is interrupted.
+     */
+    @Override
+    public void run() {
+        LOGGER.info("MainEventBusProcessor processing loop started");
+        while (running) {
+            try {
+                LOGGER.debug("MainEventBusProcessor: Waiting for event from MainEventBus...");
+                MainEventBusContext context = eventBus.take();
+                LOGGER.debug("MainEventBusProcessor: Retrieved event for task {} from MainEventBus",
+                        context.taskId());
+                processEvent(context);
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                LOGGER.info("MainEventBusProcessor interrupted, shutting down");
+                break;
+            } catch (Exception e) {
+                LOGGER.error("Error processing event from MainEventBus", e);
+                // Continue processing despite errors
+            }
+        }
+        LOGGER.info("MainEventBusProcessor processing loop ended");
+    }
+
+    /**
+     * Processes a single event: persist, send push notification, distribute to the
+     * ChildQueues, notify the callback and finally release the MainQueue semaphore.
+     *
+     * @param context the task ID, originating queue and event item taken from the MainEventBus
+     */
+    private void processEvent(MainEventBusContext context) {
+        String taskId = context.taskId();
+        Event event = context.eventQueueItem().getEvent();
+        EventQueue eventQueue = context.eventQueue();
+
+        LOGGER.debug("MainEventBusProcessor: Processing event for task {}: {} (queue type: {})",
+                taskId, event.getClass().getSimpleName(), eventQueue.getClass().getSimpleName());
+
+        Event eventToDistribute = null;
+        try {
+            // Step 1: Update TaskStore FIRST (persistence before clients see it)
+            // If this throws, we distribute an error to ensure "persist before client visibility"
+
+            try {
+                updateTaskStore(taskId, event);
+                eventToDistribute = event; // Success - distribute original event
+            } catch (InternalError e) {
+                // Persistence failed - distribute the InternalError itself instead of the event
+                LOGGER.error("Failed to persist event for task {}, distributing error to clients", taskId, e);
+                eventToDistribute = e;
+            } catch (Exception e) {
+                LOGGER.error("Failed to persist event for task {}, distributing error to clients", taskId, e);
+                String errorMessage = "Failed to persist event: " + e.getMessage();
+                eventToDistribute = new InternalError(errorMessage);
+            }
+
+            // Step 2: Send push notification AFTER successful persistence
+            if (eventToDistribute == event) {
+                try {
+                    // Capture task state immediately after persistence, before going async
+                    // This ensures we send the task as it existed when THIS event was processed,
+                    // not whatever state might exist later when the async callback executes
+                    Task taskSnapshot = taskStore.get(taskId);
+                    if (taskSnapshot != null) {
+                        sendPushNotification(taskId, taskSnapshot);
+                    } else {
+                        LOGGER.warn("Task {} not found in TaskStore after successful persistence, skipping push notification", taskId);
+                    }
+                } catch (Exception e) {
+                    // Push notifications are best-effort: the event is already persisted, so a
+                    // failure while preparing the notification must not stop its distribution.
+                    LOGGER.error("Error preparing push notification for task {}", taskId, e);
+                }
+            }
+
+            // Step 3: Then distribute to ChildQueues (clients see either event or error AFTER persistence attempt)
+            if (eventToDistribute == null) {
+                LOGGER.error("MainEventBusProcessor: eventToDistribute is NULL for task {} - this should never happen!", taskId);
+                eventToDistribute = new InternalError("Internal error: event processing failed");
+            }
+
+            if (eventQueue instanceof EventQueue.MainQueue mainQueue) {
+                int childCount = mainQueue.getChildCount();
+                LOGGER.debug("MainEventBusProcessor: Distributing {} to {} children for task {}",
+                        eventToDistribute.getClass().getSimpleName(), childCount, taskId);
+                // Create new EventQueueItem with the event to distribute (original or error)
+                EventQueueItem itemToDistribute = new LocalEventQueueItem(eventToDistribute);
+                mainQueue.distributeToChildren(itemToDistribute);
+                LOGGER.debug("MainEventBusProcessor: Distributed {} to {} children for task {}",
+                        eventToDistribute.getClass().getSimpleName(), childCount, taskId);
+            } else {
+                LOGGER.warn("MainEventBusProcessor: Expected MainQueue but got {} for task {}",
+                        eventQueue.getClass().getSimpleName(), taskId);
+            }
+
+            LOGGER.debug("MainEventBusProcessor: Completed processing event for task {}", taskId);
+
+        } finally {
+            try {
+                // Step 4: Notify callback after all processing is complete
+                // Call callback with the distributed event (original or error)
+                if (eventToDistribute != null) {
+                    callback.onEventProcessed(taskId, eventToDistribute);
+
+                    // Step 5: If this is a final event, notify task finalization
+                    // Only for successful persistence (not for errors)
+                    if (eventToDistribute == event && isFinalEvent(event)) {
+                        callback.onTaskFinalized(taskId);
+                    }
+                }
+            } finally {
+                // ALWAYS release semaphore, even if processing fails
+                // Balances the acquire() in MainQueue.enqueueItem()
+                if (eventQueue instanceof EventQueue.MainQueue mainQueue) {
+                    mainQueue.releaseSemaphore();
+                }
+            }
+        }
+    }
+
+    /**
+     * Updates TaskStore using TaskManager.process().
+     *

+ * Creates a temporary TaskManager instance for this event and delegates to its process() method, + * which handles all event types (Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent). + * This leverages existing TaskManager logic for status updates, artifact appending, message history, etc. + *

+ *

+ * If persistence fails, the exception is propagated to processEvent() which distributes an + * InternalError to clients instead of the original event, ensuring "persist before visibility". + * See Gemini's comment: https://github.com/a2aproject/a2a-java/pull/515#discussion_r2604621833 + *

+ * + * @throws InternalError if persistence fails + */ + private void updateTaskStore(String taskId, Event event) throws InternalError { + try { + // Extract contextId from event (all relevant events have it) + String contextId = extractContextId(event); + + // Create temporary TaskManager instance for this event + TaskManager taskManager = new TaskManager(taskId, contextId, taskStore, null); + + // Use TaskManager.process() - handles all event types with existing logic + taskManager.process(event); + LOGGER.debug("TaskStore updated via TaskManager.process() for task {}: {}", + taskId, event.getClass().getSimpleName()); + } catch (InternalError e) { + LOGGER.error("Error updating TaskStore via TaskManager for task {}", taskId, e); + // Rethrow to prevent distributing unpersisted event to clients + throw e; + } catch (Exception e) { + LOGGER.error("Unexpected error updating TaskStore for task {}", taskId, e); + // Rethrow to prevent distributing unpersisted event to clients + throw new InternalError("TaskStore persistence failed: " + e.getMessage()); + } + } + + /** + * Sends push notification for the task AFTER persistence. + *

+ * This is called after updateTaskStore() to ensure the notification contains + * the latest persisted state, avoiding race conditions. + *

+ *

+ * CRITICAL: Push notifications are sent asynchronously in the background + * to avoid blocking event distribution to ChildQueues. The 83ms overhead from + * PushNotificationSender.sendNotification() was causing streaming delays. + *

+ *

+ * IMPORTANT: The task parameter is a snapshot captured immediately after + * persistence. This ensures we send the task state as it existed when THIS event + * was processed, not whatever state might exist in TaskStore when the async + * callback executes (subsequent events may have already updated the store). + *

+ *

+ * NOTE: Tests can inject a synchronous executor via setPushNotificationExecutor() + * to ensure deterministic ordering of push notifications in the test environment. + *

+     *
+     * @param taskId the task ID
+     * @param task the task snapshot to send (captured immediately after persistence)
+     */
+    private void sendPushNotification(String taskId, Task task) {
+        Runnable pushTask = () -> {
+            try {
+                if (task != null) {
+                    LOGGER.debug("Sending push notification for task {}", taskId);
+                    pushSender.sendNotification(task);
+                } else {
+                    LOGGER.debug("Skipping push notification - task snapshot is null for task {}", taskId);
+                }
+            } catch (Exception e) {
+                LOGGER.error("Error sending push notification for task {}", taskId, e);
+                // Don't rethrow - push notifications are best-effort
+            }
+        };
+
+        // Read the volatile field exactly once: setPushNotificationExecutor(null) may run on
+        // another thread between a null check and the use, which would otherwise cause an NPE.
+        java.util.concurrent.Executor executor = pushNotificationExecutor;
+
+        // Use custom executor if set (for tests), otherwise use default ForkJoinPool (async)
+        if (executor != null) {
+            executor.execute(pushTask);
+        } else {
+            CompletableFuture.runAsync(pushTask);
+        }
+    }
+
+    /**
+     * Extracts contextId from an event.
+     * Returns null if the event type doesn't have a contextId (e.g., Message).
+     *
+     * @param event the event to inspect
+     * @return the contextId of a Task, TaskStatusUpdateEvent or TaskArtifactUpdateEvent, otherwise null
+     */
+    @Nullable
+    private String extractContextId(Event event) {
+        if (event instanceof Task task) {
+            return task.contextId();
+        } else if (event instanceof TaskStatusUpdateEvent statusUpdate) {
+            return statusUpdate.contextId();
+        } else if (event instanceof TaskArtifactUpdateEvent artifactUpdate) {
+            return artifactUpdate.contextId();
+        }
+        // Message and other events don't have contextId
+        return null;
+    }
+
+    /**
+     * Checks if an event represents a final task state.
+ * + * @param event the event to check + * @return true if the event represents a final state (COMPLETED, FAILED, CANCELED, REJECTED, UNKNOWN) + */ + private boolean isFinalEvent(Event event) { + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); + } + return false; + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java new file mode 100644 index 000000000..b0a9adbce --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java @@ -0,0 +1,66 @@ +package io.a2a.server.events; + +import io.a2a.spec.Event; + +/** + * Callback interface for MainEventBusProcessor events. + *

+ * This interface is primarily intended for testing, allowing tests to synchronize + * with the asynchronous MainEventBusProcessor. Production code should not rely on this. + *

+ * Usage in tests: + *
+ * {@code
+ * @Inject
+ * MainEventBusProcessor processor;
+ *
+ * @BeforeEach
+ * void setUp() {
+ *     CountDownLatch latch = new CountDownLatch(3);
+ *     processor.setCallback(new MainEventBusProcessorCallback() {
+ *         public void onEventProcessed(String taskId, Event event) {
+ *             latch.countDown();
+ *         }
+ *     });
+ * }
+ *
+ * @AfterEach
+ * void tearDown() {
+ *     processor.setCallback(null); // Reset to NOOP
+ * }
+ * }
+ * 
+ */ +public interface MainEventBusProcessorCallback { + + /** + * Called after an event has been fully processed (persisted, notification sent, distributed to children). + * + * @param taskId the task ID + * @param event the event that was processed + */ + void onEventProcessed(String taskId, Event event); + + /** + * Called when a task reaches a final state (COMPLETED, FAILED, CANCELED, REJECTED). + * + * @param taskId the task ID that was finalized + */ + void onTaskFinalized(String taskId); + + /** + * No-op implementation that does nothing. + * Used as the default callback to avoid null checks. + */ + MainEventBusProcessorCallback NOOP = new MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + // No-op + } + + @Override + public void onTaskFinalized(String taskId) { + // No-op + } + }; +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java new file mode 100644 index 000000000..ba4b300be --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java @@ -0,0 +1,43 @@ +package io.a2a.server.events; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.Initialized; +import jakarta.enterprise.event.Observes; +import jakarta.inject.Inject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Portable CDI initializer for MainEventBusProcessor. + *

+ * This bean observes the ApplicationScoped initialization event and injects + * MainEventBusProcessor, which triggers its eager creation and starts the background thread. + *

+ *

+ * This approach is portable across all Jakarta CDI implementations (Weld, OpenWebBeans, Quarkus, etc.) + * and ensures MainEventBusProcessor starts automatically when the application starts. + *

+ */ +@ApplicationScoped +public class MainEventBusProcessorInitializer { + private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBusProcessorInitializer.class); + + @Inject + MainEventBusProcessor processor; + + /** + * Observes ApplicationScoped initialization to force eager creation of MainEventBusProcessor. + * The injection of MainEventBusProcessor in this bean triggers its creation, and calling + * ensureStarted() forces the CDI proxy to be resolved, which ensures @PostConstruct has been + * called and the background thread is running. + */ + void onStart(@Observes @Initialized(ApplicationScoped.class) Object event) { + if (processor != null) { + // Force proxy resolution to ensure @PostConstruct has been called + processor.ensureStarted(); + LOGGER.info("MainEventBusProcessor initialized and started"); + } else { + LOGGER.error("MainEventBusProcessor is null - initialization failed!"); + } + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/QueueManager.java b/server-common/src/main/java/io/a2a/server/events/QueueManager.java index 01e754fcb..b4ab24317 100644 --- a/server-common/src/main/java/io/a2a/server/events/QueueManager.java +++ b/server-common/src/main/java/io/a2a/server/events/QueueManager.java @@ -96,6 +96,25 @@ public interface QueueManager { */ void add(String taskId, EventQueue queue); + /** + * Switches a queue from an old key to a new key atomically. + *

+ * Used when transitioning from a temporary task ID (e.g., "temp-UUID") to the real task ID + * when the Task event arrives with the actual task.id. This prevents duplicate map entries + * and ensures clean queue lifecycle management. + *

+ *

+ * The operation is atomic: removes the old key and adds the new key. If the new key already + * exists, the operation is rolled back by restoring the old key. + *

+ * + * @param oldId the temporary/old task identifier to remove + * @param newId the real/new task identifier to add + * @throws IllegalStateException if no queue exists for oldId + * @throws TaskQueueExistsException if a queue already exists for newId + */ + void switchKey(String oldId, String newId); + /** * Retrieves the MainQueue for a task, if it exists. *

@@ -177,7 +196,31 @@ public interface QueueManager { * @return a builder for creating event queues */ default EventQueue.EventQueueBuilder getEventQueueBuilder(String taskId) { - return EventQueue.builder(); + throw new UnsupportedOperationException( + "QueueManager implementations must override getEventQueueBuilder() to provide MainEventBus" + ); + } + + /** + * Creates a base EventQueueBuilder with standard configuration for this QueueManager. + * This method provides the foundation for creating event queues with proper configuration + * (MainEventBus, TaskStateProvider, cleanup callbacks, etc.). + *

+ * QueueManager implementations that use custom factories can call this method directly + * to get the base builder without going through the factory (which could cause infinite + * recursion if the factory delegates back to getEventQueueBuilder()). + *

+ *

+ * Callers can then add additional configuration (hooks, callbacks) before building the queue. + *

+ * + * @param taskId the task ID for the queue + * @return a builder with base configuration specific to this QueueManager implementation + */ + default EventQueue.EventQueueBuilder createBaseEventQueueBuilder(String taskId) { + throw new UnsupportedOperationException( + "QueueManager implementations must override createBaseEventQueueBuilder() to provide MainEventBus" + ); } /** diff --git a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java index 002acbafd..07bba3a9b 100644 --- a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java +++ b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -11,13 +11,13 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.Flow; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; @@ -35,6 +35,8 @@ import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueue; import io.a2a.server.events.EventQueueItem; +import io.a2a.server.events.MainEventBusProcessor; +import io.a2a.server.events.MainEventBusProcessorCallback; import io.a2a.server.events.QueueManager; import io.a2a.server.events.TaskQueueExistsException; import io.a2a.server.tasks.PushNotificationConfigStore; @@ -42,6 +44,7 @@ import io.a2a.server.tasks.ResultAggregator; import io.a2a.server.tasks.TaskManager; import io.a2a.server.tasks.TaskStore; +import io.a2a.server.util.async.EventConsumerExecutorProducer.EventConsumerExecutor; import 
io.a2a.server.util.async.Internal; import io.a2a.spec.A2AError; import io.a2a.spec.DeleteTaskPushNotificationConfigParams; @@ -64,6 +67,7 @@ import io.a2a.spec.TaskPushNotificationConfig; import io.a2a.spec.TaskQueryParams; import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.UnsupportedOperationError; import org.jspecify.annotations.NonNull; import org.jspecify.annotations.Nullable; @@ -122,7 +126,6 @@ *
  • {@link EventConsumer} polls and processes events on Vert.x worker thread
  • *
  • Queue closes automatically on final event (COMPLETED/FAILED/CANCELED)
  • *
  • Cleanup waits for both agent execution AND event consumption to complete
  • - *
  • Background tasks tracked via {@link #trackBackgroundTask(CompletableFuture)}
  • * * *

    Threading Model

    @@ -179,6 +182,13 @@ public class DefaultRequestHandler implements RequestHandler { private static final Logger LOGGER = LoggerFactory.getLogger(DefaultRequestHandler.class); + /** + * Separate logger for thread statistics diagnostic logging. + * This allows independent control of verbose thread pool monitoring without affecting + * general request handler logging. Enable with: logging.level.io.a2a.server.diagnostics.ThreadStats=DEBUG + */ + private static final Logger THREAD_STATS_LOGGER = LoggerFactory.getLogger("io.a2a.server.diagnostics.ThreadStats"); + private static final String A2A_BLOCKING_AGENT_TIMEOUT_SECONDS = "a2a.blocking.agent.timeout.seconds"; private static final String A2A_BLOCKING_CONSUMPTION_TIMEOUT_SECONDS = "a2a.blocking.consumption.timeout.seconds"; @@ -214,13 +224,19 @@ public class DefaultRequestHandler implements RequestHandler { private TaskStore taskStore; private QueueManager queueManager; private PushNotificationConfigStore pushConfigStore; - private PushNotificationSender pushSender; + private MainEventBusProcessor mainEventBusProcessor; private Supplier requestContextBuilder; private final ConcurrentMap> runningAgents = new ConcurrentHashMap<>(); - private final Set> backgroundTasks = ConcurrentHashMap.newKeySet(); + + /** + * Map of taskId → CountDownLatch for tasks waiting for finalization. + * Used by blocking calls to wait for MainEventBusProcessor to persist final state to TaskStore. + */ + private final ConcurrentMap pendingFinalizations = new ConcurrentHashMap<>(); private Executor executor; + private Executor eventConsumerExecutor; /** * No-args constructor for CDI proxy creation. 
@@ -234,21 +250,25 @@ protected DefaultRequestHandler() { this.taskStore = null; this.queueManager = null; this.pushConfigStore = null; - this.pushSender = null; + this.mainEventBusProcessor = null; this.requestContextBuilder = null; this.executor = null; + this.eventConsumerExecutor = null; } @Inject public DefaultRequestHandler(AgentExecutor agentExecutor, TaskStore taskStore, QueueManager queueManager, PushNotificationConfigStore pushConfigStore, - PushNotificationSender pushSender, @Internal Executor executor) { + MainEventBusProcessor mainEventBusProcessor, + @Internal Executor executor, + @EventConsumerExecutor Executor eventConsumerExecutor) { this.agentExecutor = agentExecutor; this.taskStore = taskStore; this.queueManager = queueManager; this.pushConfigStore = pushConfigStore; - this.pushSender = pushSender; + this.mainEventBusProcessor = mainEventBusProcessor; this.executor = executor; + this.eventConsumerExecutor = eventConsumerExecutor; // TODO In Python this is also a constructor parameter defaulting to this SimpleRequestContextBuilder // implementation if the parameter is null. Skip that for now, since otherwise I get CDI errors, and // I am unsure about the correct scope. @@ -262,6 +282,97 @@ void initConfig() { configProvider.getValue(A2A_BLOCKING_AGENT_TIMEOUT_SECONDS)); consumptionCompletionTimeoutSeconds = Integer.parseInt( configProvider.getValue(A2A_BLOCKING_CONSUMPTION_TIMEOUT_SECONDS)); + + // Register permanent callback for task finalization events + registerFinalizationCallback(); + } + + /** + * Register the permanent callback for task finalization events. + *

    + * This single callback multiplexes for all concurrent requests via the + * {@link #pendingFinalizations} map. Called by both {@link #initConfig()} + * (@PostConstruct for CDI) and {@link #create(AgentExecutor, TaskStore, QueueManager, PushNotificationConfigStore, MainEventBusProcessor, Executor, Executor)} + * (static factory for tests). + *

    +     */
    +    private void registerFinalizationCallback() {
    +        mainEventBusProcessor.setCallback(new MainEventBusProcessorCallback() {
    +            @Override
    +            public void onEventProcessed(String taskId, Event event) {
    +                // Not used for task finalization wait
    +            }
    +
    +            @Override
    +            public void onTaskFinalized(String taskId) {
    +                // Signal any blocking call waiting for this task to finalize. remove() claims the latch
    +                // atomically, so a latch registered between a get() and a later remove() is never dropped unsignaled.
    +                CountDownLatch latch = pendingFinalizations.remove(taskId);
    +                if (latch != null) {
    +                    latch.countDown();
    +                    LOGGER.debug("Task {} finalization signaled to waiting thread", taskId);
    +                }
    +            }
    +        });
    +    }
    +
    +    /**
    +     * Wait for MainEventBusProcessor to finalize a task (reach final state and persist to TaskStore).
    +     *

    + * This method is used by blocking calls to ensure TaskStore is fully updated before returning + * to the client. It registers this task in the {@link #pendingFinalizations} map, which is + * monitored by the permanent callback registered in {@link #initConfig()}. + *

    + *

    + * Why this is needed: Events flow through MainEventBus → MainEventBusProcessor → + * TaskStore persistence. The consumption future completing only means ChildQueue is empty, + * NOT that MainEventBusProcessor has finished persisting to TaskStore. This creates a race + * condition where blocking calls might read stale state from TaskStore. + *

    + *

    + * Concurrency: Uses a single permanent callback that multiplexes for all concurrent + * requests via {@link #pendingFinalizations} map. This avoids callback overwrite issues + * when multiple requests execute simultaneously. + *

    + *

    + * Race Condition Prevention: Checks TaskStore FIRST before waiting. If the task is + * already finalized, returns immediately. This prevents waiting forever when the callback + * fires before the latch is registered in the map. + *

    +     *
    +     * @param taskId the task ID to wait for
    +     * @param timeoutSeconds maximum time to wait for finalization
    +     * @throws InterruptedException if interrupted while waiting
    +     * @throws TimeoutException if task doesn't finalize within timeout
    +     */
    +    private void waitForTaskFinalization(String taskId, int timeoutSeconds)
    +            throws InterruptedException, TimeoutException {
    +        CountDownLatch finalizationLatch = new CountDownLatch(1);
    +
    +        // CRITICAL: Register the latch BEFORE checking TaskStore. Checking first leaves a window in which
    +        // the permanent callback (registered in initConfig) fires, finds no latch, and the wait below times out.
    +        pendingFinalizations.put(taskId, finalizationLatch);
    +
    +        try {
    +            // With the latch in place, a task that is already final in TaskStore needs no wait.
    +            Task task = taskStore.get(taskId);
    +            if (task != null && task.status() != null && task.status().state() != null
    +                    && task.status().state().isFinal()) {
    +                LOGGER.debug("Task {} is finalized in TaskStore, skipping wait", taskId);
    +                return;
    +            }
    +
    +            // Wait for the callback to fire
    +            if (!finalizationLatch.await(timeoutSeconds, SECONDS)) {
    +                throw new TimeoutException(
    +                        String.format("Task %s finalization timeout after %d seconds", taskId, timeoutSeconds));
    +            }
    +            LOGGER.debug("Task {} finalized and persisted to TaskStore", taskId);
    +        } finally {
    +            // Always remove from map to avoid memory leaks, but only this call's own latch (two-arg remove):
    +            // a concurrent waiter on the same task may have registered a newer latch that must stay in the map.
    +            pendingFinalizations.remove(taskId, finalizationLatch);
    +        }
         }
     
         /**
    @@ -269,11 +380,17 @@ void initConfig() {
          */
         public static DefaultRequestHandler create(AgentExecutor agentExecutor, TaskStore taskStore, QueueManager queueManager, PushNotificationConfigStore pushConfigStore,
    -                                               PushNotificationSender pushSender, Executor executor) {
    +                                               MainEventBusProcessor mainEventBusProcessor,
    +                                               Executor executor, Executor eventConsumerExecutor) {
             DefaultRequestHandler handler =
    -                new DefaultRequestHandler(agentExecutor, taskStore, queueManager, pushConfigStore, pushSender, executor);
    +                new DefaultRequestHandler(agentExecutor, taskStore, queueManager, 
pushConfigStore, + mainEventBusProcessor, executor, eventConsumerExecutor); handler.agentCompletionTimeoutSeconds = 5; handler.consumptionCompletionTimeoutSeconds = 2; + + // Register permanent callback for task finalization (normally done in @PostConstruct) + handler.registerFinalizationCallback(); + return handler; } @@ -359,12 +476,9 @@ public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws taskStore, null); - ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor); + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor, eventConsumerExecutor); - EventQueue queue = queueManager.tap(task.id()); - if (queue == null) { - queue = queueManager.getEventQueueBuilder(task.id()).build(); - } + EventQueue queue = queueManager.createOrTap(task.id()); agentExecutor.cancel( requestContextBuilder.get() .setTaskId(task.id()) @@ -397,26 +511,36 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte LOGGER.debug("onMessageSend - task: {}; context {}", params.message().taskId(), params.message().contextId()); MessageSendSetup mss = initMessageSend(params, context); - String taskId = mss.requestContext.getTaskId(); - LOGGER.debug("Request context taskId: {}", taskId); + @Nullable String initialTaskId = mss.requestContext.getTaskId(); + // For non-streaming, taskId can be null initially (will be set when Task event arrives) + // Use a temporary ID for queue creation if needed (same pattern as streaming) + String queueTaskId = initialTaskId != null ? 
initialTaskId : "temp-" + java.util.UUID.randomUUID(); + LOGGER.debug("Request context taskId: {} (queue key: {})", initialTaskId, queueTaskId); - if (taskId == null) { - throw new io.a2a.spec.InternalError("Task ID is null in onMessageSend"); - } - EventQueue queue = queueManager.createOrTap(taskId); - ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor); + EventQueue queue = queueManager.createOrTap(queueTaskId); + final java.util.concurrent.atomic.AtomicReference<@NonNull String> taskId = new java.util.concurrent.atomic.AtomicReference<>(queueTaskId); + ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor, eventConsumerExecutor); + // Default to blocking=false per A2A spec (return after task creation) boolean blocking = params.configuration() != null && Boolean.TRUE.equals(params.configuration().blocking()); + // Log blocking behavior from client request + if (params.configuration() != null && params.configuration().blocking() != null) { + LOGGER.debug("DefaultRequestHandler: Client requested blocking={} for task {}", + params.configuration().blocking(), taskId.get()); + } else if (params.configuration() != null) { + LOGGER.debug("DefaultRequestHandler: Client sent configuration but blocking=null, using default blocking={} for task {}", blocking, taskId.get()); + } else { + LOGGER.debug("DefaultRequestHandler: Client sent no configuration, using default blocking={} for task {}", blocking, taskId.get()); + } + LOGGER.debug("DefaultRequestHandler: Final blocking decision: {} for task {}", blocking, taskId.get()); + boolean interruptedOrNonBlocking = false; - EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(taskId, mss.requestContext, queue); + EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(queueTaskId, mss.requestContext, queue); ResultAggregator.EventTypeAndInterrupt etai = null; EventKind kind = null; // Declare outside try block so it's in scope for return 
try { - // Create callback for push notifications during background event processing - Runnable pushNotificationCallback = () -> sendPushNotification(taskId, resultAggregator); - EventConsumer consumer = new EventConsumer(queue); // This callback must be added before we start consuming. Otherwise, @@ -424,7 +548,7 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback()); // Get agent future before consuming (for blocking calls to wait for agent completion) - CompletableFuture agentFuture = runningAgents.get(taskId); + CompletableFuture agentFuture = runningAgents.get(queueTaskId); etai = resultAggregator.consumeAndBreakOnInterrupt(consumer, blocking); if (etai == null) { @@ -432,7 +556,8 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte throw new InternalError("No result"); } interruptedOrNonBlocking = etai.interrupted(); - LOGGER.debug("Was interrupted or non-blocking: {}", interruptedOrNonBlocking); + LOGGER.debug("DefaultRequestHandler: interruptedOrNonBlocking={} (blocking={}, eventType={})", + interruptedOrNonBlocking, blocking, kind != null ? kind.getClass().getSimpleName() : null); // For blocking calls that were interrupted (returned on first event), // wait for agent execution and event processing BEFORE returning to client. @@ -441,30 +566,51 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte // during the consumption loop itself. 
kind = etai.eventType(); + // Switch from temporary ID to real task ID if they differ + if (kind instanceof Task createdTask) { + String currentId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if (!Objects.equals(currentId, createdTask.id())) { + try { + queueManager.switchKey(currentId, createdTask.id()); + taskId.set(createdTask.id()); + LOGGER.debug("Switched non-streaming queue from {} to real task ID {}", + currentId, createdTask.id()); + } catch (TaskQueueExistsException | IllegalStateException e) { + String msg = "Failed to switch queue key from " + currentId + " to " + createdTask.id() + ": " + e.getMessage(); + LOGGER.error(msg, e); + throw new InternalError(msg); + } + } + } + // Store push notification config for newly created tasks (mirrors streaming logic) // Only for NEW tasks - existing tasks are handled by initMessageSend() if (mss.task() == null && kind instanceof Task createdTask && shouldAddPushInfo(params)) { - LOGGER.debug("Storing push notification config for new task {}", createdTask.id()); + LOGGER.debug("Storing push notification config for new task {} (original taskId from params: {})", + createdTask.id(), params.message().taskId()); pushConfigStore.setInfo(createdTask.id(), params.configuration().pushNotificationConfig()); } if (blocking && interruptedOrNonBlocking) { - // For blocking calls: ensure all events are processed before returning - // Order of operations is critical to avoid circular dependency: + // For blocking calls: ensure all events are persisted to TaskStore before returning + // Order of operations is critical to avoid circular dependency and race conditions: // 1. Wait for agent to finish enqueueing events // 2. Close the queue to signal consumption can complete // 3. Wait for consumption to finish processing events - // 4. Fetch final task state from TaskStore + // 4. Wait for MainEventBusProcessor to persist final state to TaskStore + // 5. 
Fetch final task state from TaskStore (now guaranteed persisted) + LOGGER.debug("DefaultRequestHandler: Entering blocking fire-and-forget handling for task {}", taskId.get()); try { // Step 1: Wait for agent to finish (with configurable timeout) if (agentFuture != null) { try { agentFuture.get(agentCompletionTimeoutSeconds, SECONDS); - LOGGER.debug("Agent completed for task {}", taskId); + LOGGER.debug("DefaultRequestHandler: Step 1 - Agent completed for task {}", taskId.get()); } catch (java.util.concurrent.TimeoutException e) { // Agent still running after timeout - that's fine, events already being processed - LOGGER.debug("Agent still running for task {} after {}s", taskId, agentCompletionTimeoutSeconds); + LOGGER.debug("DefaultRequestHandler: Step 1 - Agent still running for task {} after {}s timeout", + taskId.get(), agentCompletionTimeoutSeconds); } } @@ -472,55 +618,87 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte // For fire-and-forget tasks, there's no final event, so we need to close the queue // This allows EventConsumer.consumeAll() to exit queue.close(false, false); // graceful close, don't notify parent yet - LOGGER.debug("Closed queue for task {} to allow consumption completion", taskId); + LOGGER.debug("DefaultRequestHandler: Step 2 - Closed queue for task {} to allow consumption completion", taskId.get()); // Step 3: Wait for consumption to complete (now that queue is closed) if (etai.consumptionFuture() != null) { etai.consumptionFuture().get(consumptionCompletionTimeoutSeconds, SECONDS); - LOGGER.debug("Consumption completed for task {}", taskId); + LOGGER.debug("DefaultRequestHandler: Step 3 - Consumption completed for task {}", taskId.get()); } + + // Step 4: Wait for MainEventBusProcessor to finalize task (persist to TaskStore) + // For blocking calls, ALWAYS try to wait for finalization. + // waitForTaskFinalization() checks TaskStore first, so if task is already finalized + // it returns immediately. 
+ // This is CRITICAL: consumption completing only means ChildQueue is empty, NOT that + // MainEventBusProcessor has finished persisting to TaskStore. The callback ensures + // we wait for the final state to be persisted before reading from TaskStore. + try { + String taskIdForFinalization = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + waitForTaskFinalization(taskIdForFinalization, consumptionCompletionTimeoutSeconds); + LOGGER.debug("DefaultRequestHandler: Step 4 - Task {} finalized and persisted to TaskStore", taskId.get()); + } catch (TimeoutException e) { + // Timeout is OK for fire-and-forget tasks that never reach final state + // Just log and continue - we'll return the current non-final state + LOGGER.debug("DefaultRequestHandler: Step 4 - Task {} finalization timeout (fire-and-forget task)", taskId.get()); + } + } catch (InterruptedException e) { Thread.currentThread().interrupt(); - String msg = String.format("Error waiting for task %s completion", taskId); + String msg = String.format("Error waiting for task %s completion", taskId.get()); LOGGER.warn(msg, e); throw new InternalError(msg); } catch (java.util.concurrent.ExecutionException e) { - String msg = String.format("Error during task %s execution", taskId); + String msg = String.format("Error during task %s execution", taskId.get()); LOGGER.warn(msg, e.getCause()); throw new InternalError(msg); - } catch (java.util.concurrent.TimeoutException e) { - String msg = String.format("Timeout waiting for consumption to complete for task %s", taskId); - LOGGER.warn(msg, taskId); + } catch (TimeoutException e) { + // Timeout from consumption future.get() - different from finalization timeout + String msg = String.format("Timeout waiting for task %s consumption", taskId.get()); + LOGGER.warn(msg, e); throw new InternalError(msg); } - // Step 4: Fetch the final task state from TaskStore (all events have been processed) - // taskId is guaranteed non-null here (checked earlier) - String 
nonNullTaskId = taskId; + // Step 5: Fetch the final task state from TaskStore (now guaranteed persisted) + String nonNullTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); Task updatedTask = taskStore.get(nonNullTaskId); if (updatedTask != null) { kind = updatedTask; - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Fetched final task for {} with state {} and {} artifacts", - nonNullTaskId, updatedTask.status().state(), - updatedTask.artifacts().size()); - } + LOGGER.debug("DefaultRequestHandler: Step 5 - Fetched final task for {} with state {} and {} artifacts", + taskId.get(), updatedTask.status().state(), + updatedTask.artifacts().size()); + } else { + LOGGER.warn("DefaultRequestHandler: Step 5 - Task {} not found in TaskStore!", taskId.get()); } } - if (kind instanceof Task taskResult && !taskId.equals(taskResult.id())) { + String finalTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if (kind instanceof Task taskResult && !finalTaskId.equals(taskResult.id())) { throw new InternalError("Task ID mismatch in agent response"); } - - // Send push notification after initial return (for both blocking and non-blocking) - pushNotificationCallback.run(); } finally { + // For non-blocking calls: close ChildQueue IMMEDIATELY to free EventConsumer thread + // CRITICAL: Must use immediate=true to clear the local queue, otherwise EventConsumer + // continues polling until queue drains naturally, holding executor thread. + // Immediate close clears pending events and triggers EventQueueClosedException on next poll. + // Events continue flowing through MainQueue → MainEventBus → TaskStore. 
+ if (!blocking && etai != null && etai.interrupted()) { + LOGGER.debug("DefaultRequestHandler: Non-blocking call in finally - closing ChildQueue IMMEDIATELY for task {} to free EventConsumer", taskId.get()); + queue.close(true); // immediate=true: clear queue and free EventConsumer + } + // Remove agent from map immediately to prevent accumulation - CompletableFuture agentFuture = runningAgents.remove(taskId); - LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", taskId, runningAgents.size()); + CompletableFuture agentFuture = runningAgents.remove(queueTaskId); + String cleanupTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", cleanupTaskId, runningAgents.size()); - // Track cleanup as background task to avoid blocking Vert.x threads + // Cleanup as background task to avoid blocking Vert.x threads // Pass the consumption future to ensure cleanup waits for background consumption to complete - trackBackgroundTask(cleanupProducer(agentFuture, etai != null ? etai.consumptionFuture() : null, taskId, queue, false)); + cleanupProducer(agentFuture, etai != null ? 
etai.consumptionFuture() : null, cleanupTaskId, queue, false) + .whenComplete((res, err) -> { + if (err != null) { + LOGGER.error("Error during async cleanup for task {}", taskId.get(), err); + } + }); } LOGGER.debug("Returning: {}", kind); @@ -530,8 +708,8 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte @Override public Flow.Publisher onMessageSendStream( MessageSendParams params, ServerCallContext context) throws A2AError { - LOGGER.debug("onMessageSendStream START - task: {}; context: {}; runningAgents: {}; backgroundTasks: {}", - params.message().taskId(), params.message().contextId(), runningAgents.size(), backgroundTasks.size()); + LOGGER.debug("onMessageSendStream START - task: {}; context: {}; runningAgents: {}", + params.message().taskId(), params.message().contextId(), runningAgents.size()); MessageSendSetup mss = initMessageSend(params, context); @Nullable String initialTaskId = mss.requestContext.getTaskId(); @@ -539,20 +717,33 @@ public Flow.Publisher onMessageSendStream( // Use a temporary ID for queue creation if needed String queueTaskId = initialTaskId != null ? 
initialTaskId : "temp-" + java.util.UUID.randomUUID(); - AtomicReference<@NonNull String> taskId = new AtomicReference<>(queueTaskId); + final AtomicReference<@NonNull String> taskId = new AtomicReference<>(queueTaskId); @SuppressWarnings("NullAway") EventQueue queue = queueManager.createOrTap(taskId.get()); LOGGER.debug("Created/tapped queue for task {}: {}", taskId.get(), queue); - ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor); + + // Store push notification config SYNCHRONOUSLY for new tasks before agent starts + // This ensures config is available when MainEventBusProcessor sends push notifications + // For existing tasks, config is stored in initMessageSend() + if (mss.task() == null && shouldAddPushInfo(params)) { + // Satisfy Nullaway + Objects.requireNonNull(taskId.get(), "taskId was null"); + LOGGER.debug("Storing push notification config for new streaming task {} EARLY (original taskId from params: {})", + taskId.get(), params.message().taskId()); + pushConfigStore.setInfo(taskId.get(), params.configuration().pushNotificationConfig()); + } + + ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor, eventConsumerExecutor); EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(queueTaskId, mss.requestContext, queue); // Move consumer creation and callback registration outside try block - // so consumer is available for background consumption on client disconnect EventConsumer consumer = new EventConsumer(queue); producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback()); - AtomicBoolean backgroundConsumeStarted = new AtomicBoolean(false); + // Store cancel callback in context for closeHandler to access + // When client disconnects, closeHandler can call this to stop EventConsumer polling loop + context.setEventConsumerCancelCallback(consumer::cancel); try { Flow.Publisher results = resultAggregator.consumeAndEmit(consumer); @@ -566,32 +757,21 @@ public 
Flow.Publisher onMessageSendStream( errorConsumer.accept(new InternalError("Task ID mismatch in agent response")); } - // TODO the Python implementation no longer has the following block but removing it causes - // failures here - try { - queueManager.add(createdTask.id(), queue); - taskId.set(createdTask.id()); - } catch (TaskQueueExistsException e) { - // TODO Log - } - if (pushConfigStore != null && - params.configuration() != null && - params.configuration().pushNotificationConfig() != null) { - - pushConfigStore.setInfo( - createdTask.id(), - params.configuration().pushNotificationConfig()); - } - - } - String currentTaskId = taskId.get(); - if (pushSender != null && currentTaskId != null) { - EventKind latest = resultAggregator.getCurrentResult(); - if (latest instanceof Task latestTask) { - pushSender.sendNotification(latestTask); + // Switch from temporary ID to real task ID if they differ + String currentId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if (!Objects.equals(currentId, createdTask.id())) { + try { + queueManager.switchKey(currentId, createdTask.id()); + taskId.set(createdTask.id()); + LOGGER.debug("Switched streaming queue from {} to real task ID {}", + currentId, createdTask.id()); + } catch (TaskQueueExistsException e) { + errorConsumer.accept(new InternalError("Queue already exists for task " + createdTask.id())); + } catch (IllegalStateException e) { + errorConsumer.accept(new InternalError("Failed to switch queue key: " + e.getMessage())); + } } } - return true; })); @@ -600,7 +780,8 @@ public Flow.Publisher onMessageSendStream( Flow.Publisher finalPublisher = convertingProcessor(eventPublisher, event -> (StreamingEventKind) event); - // Wrap publisher to detect client disconnect and continue background consumption + // Wrap publisher to detect client disconnect and immediately close ChildQueue + // This prevents ChildQueue backpressure from blocking MainEventBusProcessor return subscriber -> { String currentTaskId = 
taskId.get(); LOGGER.debug("Creating subscription wrapper for task {}", currentTaskId); @@ -621,8 +802,10 @@ public void request(long n) { @Override public void cancel() { - LOGGER.debug("Client cancelled subscription for task {}, starting background consumption", taskId.get()); - startBackgroundConsumption(); + LOGGER.debug("Client cancelled subscription for task {}, closing ChildQueue immediately", taskId.get()); + // Close ChildQueue immediately to prevent backpressure + // (clears queue and releases semaphore permits) + queue.close(true); // immediate=true subscription.cancel(); } }); @@ -647,8 +830,8 @@ public void onComplete() { subscriber.onComplete(); } catch (IllegalStateException e) { // Client already disconnected and response closed - this is expected - // for streaming responses where client disconnect triggers background - // consumption. Log and ignore. + // for streaming responses where client disconnect closes ChildQueue. + // Log and ignore. if (e.getMessage() != null && e.getMessage().contains("Response has already been written")) { LOGGER.debug("Client disconnected before onComplete, response already closed for task {}", taskId.get()); } else { @@ -656,36 +839,26 @@ public void onComplete() { } } } - - private void startBackgroundConsumption() { - if (backgroundConsumeStarted.compareAndSet(false, true)) { - LOGGER.debug("Starting background consumption for task {}", taskId.get()); - // Client disconnected: continue consuming and persisting events in background - CompletableFuture bgTask = CompletableFuture.runAsync(() -> { - try { - LOGGER.debug("Background consumption thread started for task {}", taskId.get()); - resultAggregator.consumeAll(consumer); - LOGGER.debug("Background consumption completed for task {}", taskId.get()); - } catch (Exception e) { - LOGGER.error("Error during background consumption for task {}", taskId.get(), e); - } - }, executor); - trackBackgroundTask(bgTask); - } else { - LOGGER.debug("Background consumption already 
started for task {}", taskId.get()); - } - } }); }; } finally { - LOGGER.debug("onMessageSendStream FINALLY - task: {}; runningAgents: {}; backgroundTasks: {}", - taskId.get(), runningAgents.size(), backgroundTasks.size()); - - // Remove agent from map immediately to prevent accumulation - CompletableFuture agentFuture = runningAgents.remove(taskId.get()); - LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", taskId.get(), runningAgents.size()); - - trackBackgroundTask(cleanupProducer(agentFuture, null, Objects.requireNonNull(taskId.get()), queue, true)); + // Needed to satisfy Nullaway + String idOfTask = taskId.get(); + if (idOfTask != null) { + LOGGER.debug("onMessageSendStream FINALLY - task: {}; runningAgents: {}", + idOfTask, runningAgents.size()); + + // Remove agent from map immediately to prevent accumulation + CompletableFuture agentFuture = runningAgents.remove(idOfTask); + LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", taskId.get(), runningAgents.size()); + + cleanupProducer(agentFuture, null, idOfTask, queue, true) + .whenComplete((res, err) -> { + if (err != null) { + LOGGER.error("Error during async cleanup for streaming task {}", taskId.get(), err); + } + }); + } } } @@ -746,7 +919,7 @@ public Flow.Publisher onResubscribeToTask( } TaskManager taskManager = new TaskManager(task.id(), task.contextId(), taskStore, null); - ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor); + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor, eventConsumerExecutor); EventQueue queue = queueManager.tap(task.id()); LOGGER.debug("onResubscribeToTask - tapped queue: {}", queue != null ? 
System.identityHashCode(queue) : "null"); @@ -819,8 +992,7 @@ public void run() { LOGGER.debug("Agent execution starting for task {}", taskId); agentExecutor.execute(requestContext, queue); LOGGER.debug("Agent execution completed for task {}", taskId); - // No longer wait for queue poller to start - the consumer (which is guaranteed - // to be running on the Vert.x worker thread) will handle queue lifecycle. + // The consumer (running on the Vert.x worker thread) handles queue lifecycle. // This avoids blocking agent-executor threads waiting for worker threads. } }; @@ -833,8 +1005,8 @@ public void run() { // Don't close queue here - let the consumer handle it via error callback // This ensures the consumer (which may not have started polling yet) gets the error } - // Queue lifecycle is now managed entirely by EventConsumer.consumeAll() - // which closes the queue on final events. No need to close here. + // Queue lifecycle is managed by EventConsumer.consumeAll() + // which closes the queue on final events. 
logThreadStats("AGENT COMPLETE END"); runnable.invokeDoneCallbacks(); }); @@ -843,47 +1015,6 @@ public void run() { return runnable; } - private void trackBackgroundTask(CompletableFuture task) { - backgroundTasks.add(task); - LOGGER.debug("Tracking background task (total: {}): {}", backgroundTasks.size(), task); - - task.whenComplete((result, throwable) -> { - try { - if (throwable != null) { - // Unwrap CompletionException to check for CancellationException - Throwable cause = throwable; - if (throwable instanceof java.util.concurrent.CompletionException && throwable.getCause() != null) { - cause = throwable.getCause(); - } - - if (cause instanceof java.util.concurrent.CancellationException) { - LOGGER.debug("Background task cancelled: {}", task); - } else { - LOGGER.error("Background task failed", throwable); - } - } - } finally { - backgroundTasks.remove(task); - LOGGER.debug("Removed background task (remaining: {}): {}", backgroundTasks.size(), task); - } - }); - } - - /** - * Wait for all background tasks to complete. - * Useful for testing to ensure cleanup completes before assertions. 
- * - * @return CompletableFuture that completes when all background tasks finish - */ - public CompletableFuture waitForBackgroundTasks() { - CompletableFuture[] tasks = backgroundTasks.toArray(new CompletableFuture[0]); - if (tasks.length == 0) { - return CompletableFuture.completedFuture(null); - } - LOGGER.debug("Waiting for {} background tasks to complete", tasks.length); - return CompletableFuture.allOf(tasks); - } - private CompletableFuture cleanupProducer(@Nullable CompletableFuture agentFuture, @Nullable CompletableFuture consumptionFuture, String taskId, EventQueue queue, boolean isStreaming) { LOGGER.debug("Starting cleanup for task {} (streaming={})", taskId, isStreaming); logThreadStats("CLEANUP START"); @@ -908,14 +1039,20 @@ private CompletableFuture cleanupProducer(@Nullable CompletableFuture + * Enable independently with: {@code logging.level.io.a2a.server.diagnostics.ThreadStats=DEBUG} + *

    */ @SuppressWarnings("unused") // Used for debugging private void logThreadStats(String label) { // Early return if debug logging is not enabled to avoid overhead - if (!LOGGER.isDebugEnabled()) { + if (!THREAD_STATS_LOGGER.isDebugEnabled()) { return; } @@ -982,28 +1114,57 @@ private void logThreadStats(String label) { } int activeThreads = rootGroup.activeCount(); - LOGGER.debug("=== THREAD STATS: {} ===", label); - LOGGER.debug("Active threads: {}", activeThreads); - LOGGER.debug("Running agents: {}", runningAgents.size()); - LOGGER.debug("Background tasks: {}", backgroundTasks.size()); - LOGGER.debug("Queue manager active queues: {}", queueManager.getClass().getSimpleName()); + // Count specific thread types + Thread[] threads = new Thread[activeThreads * 2]; + int count = rootGroup.enumerate(threads); + int eventConsumerThreads = 0; + int agentExecutorThreads = 0; + for (int i = 0; i < count; i++) { + if (threads[i] != null) { + String name = threads[i].getName(); + if (name.startsWith("a2a-event-consumer-")) { + eventConsumerThreads++; + } else if (name.startsWith("a2a-agent-executor-")) { + agentExecutorThreads++; + } + } + } + + THREAD_STATS_LOGGER.debug("=== THREAD STATS: {} ===", label); + THREAD_STATS_LOGGER.debug("Total active threads: {}", activeThreads); + THREAD_STATS_LOGGER.debug("EventConsumer threads: {}", eventConsumerThreads); + THREAD_STATS_LOGGER.debug("AgentExecutor threads: {}", agentExecutorThreads); + THREAD_STATS_LOGGER.debug("Running agents: {}", runningAgents.size()); + THREAD_STATS_LOGGER.debug("Queue manager active queues: {}", queueManager.getClass().getSimpleName()); // List running agents if (!runningAgents.isEmpty()) { - LOGGER.debug("Running agent tasks:"); + THREAD_STATS_LOGGER.debug("Running agent tasks:"); runningAgents.forEach((taskId, future) -> - LOGGER.debug(" - Task {}: {}", taskId, future.isDone() ? "DONE" : "RUNNING") + THREAD_STATS_LOGGER.debug(" - Task {}: {}", taskId, future.isDone() ? 
"DONE" : "RUNNING") ); } - // List background tasks - if (!backgroundTasks.isEmpty()) { - LOGGER.debug("Background tasks:"); - backgroundTasks.forEach(task -> - LOGGER.debug(" - {}: {}", task, task.isDone() ? "DONE" : "RUNNING") - ); + THREAD_STATS_LOGGER.debug("=== END THREAD STATS ==="); + } + + /** + * Check if an event represents a final task state. + * + * @param eventKind the event to check + * @return true if the event represents a final state (COMPLETED, FAILED, CANCELED, REJECTED, UNKNOWN) + */ + private boolean isFinalEvent(EventKind eventKind) { + if (!(eventKind instanceof Event event)) { + return false; + } + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); } - LOGGER.debug("=== END THREAD STATS ==="); + return false; } private record MessageSendSetup(TaskManager taskManager, @Nullable Task task, RequestContext requestContext) {} diff --git a/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java b/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java index 95684e199..506b3f3b6 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java +++ b/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java @@ -14,9 +14,9 @@ import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueueItem; import io.a2a.spec.A2AError; -import io.a2a.spec.A2AServerException; import io.a2a.spec.Event; import io.a2a.spec.EventKind; +import io.a2a.spec.InternalError; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskState; @@ -31,12 +31,14 @@ public class ResultAggregator { private final TaskManager taskManager; private final Executor executor; + private final Executor eventConsumerExecutor; private volatile @Nullable Message message; - public ResultAggregator(TaskManager taskManager, 
@Nullable Message message, Executor executor) { + public ResultAggregator(TaskManager taskManager, @Nullable Message message, Executor executor, Executor eventConsumerExecutor) { this.taskManager = taskManager; this.message = message; this.executor = executor; + this.eventConsumerExecutor = eventConsumerExecutor; } public @Nullable EventKind getCurrentResult() { @@ -49,20 +51,23 @@ public ResultAggregator(TaskManager taskManager, @Nullable Message message, Exec public Flow.Publisher consumeAndEmit(EventConsumer consumer) { Flow.Publisher allItems = consumer.consumeAll(); - // Process items conditionally - only save non-replicated events to database - return processor(createTubeConfig(), allItems, (errorConsumer, item) -> { - // Only process non-replicated events to avoid duplicate database writes - if (!item.isReplicated()) { - try { - callTaskManagerProcess(item.getEvent()); - } catch (A2AServerException e) { - errorConsumer.accept(e); - return false; - } - } - // Continue processing and emit (both replicated and non-replicated) + // Just stream events - no persistence needed + // TaskStore update moved to MainEventBusProcessor + Flow.Publisher processed = processor(createTubeConfig(), allItems, (errorConsumer, item) -> { + // Continue processing and emit all events return true; }); + + // Wrap the publisher to ensure subscription happens on eventConsumerExecutor + // This prevents EventConsumer polling loop from running on AgentExecutor threads + // which caused thread accumulation when those threads didn't timeout + return new Flow.Publisher() { + @Override + public void subscribe(Flow.Subscriber subscriber) { + // Submit subscription to eventConsumerExecutor to isolate polling work + eventConsumerExecutor.execute(() -> processed.subscribe(subscriber)); + } + }; } public EventKind consumeAll(EventConsumer consumer) throws A2AError { @@ -81,15 +86,7 @@ public EventKind consumeAll(EventConsumer consumer) throws A2AError { return false; } } - // Only process 
non-replicated events to avoid duplicate database writes - if (!item.isReplicated()) { - try { - callTaskManagerProcess(event); - } catch (A2AServerException e) { - error.set(e); - return false; - } - } + // TaskStore update moved to MainEventBusProcessor return true; }, error::set); @@ -113,18 +110,24 @@ public EventKind consumeAll(EventConsumer consumer) throws A2AError { public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, boolean blocking) throws A2AError { Flow.Publisher allItems = consumer.consumeAll(); AtomicReference message = new AtomicReference<>(); + AtomicReference capturedTask = new AtomicReference<>(); // Capture Task events AtomicBoolean interrupted = new AtomicBoolean(false); AtomicReference errorRef = new AtomicReference<>(); CompletableFuture completionFuture = new CompletableFuture<>(); // Separate future for tracking background consumption completion CompletableFuture consumptionCompletionFuture = new CompletableFuture<>(); + // Latch to ensure EventConsumer starts polling before we wait on completionFuture + java.util.concurrent.CountDownLatch pollingStarted = new java.util.concurrent.CountDownLatch(1); // CRITICAL: The subscription itself must run on a background thread to avoid blocking // the Vert.x worker thread. EventConsumer.consumeAll() starts a polling loop that // blocks in dequeueEventItem(), so we must subscribe from a background thread. - // Use the @Internal executor (not ForkJoinPool.commonPool) to avoid saturation - // during concurrent request bursts. + // Use the dedicated @EventConsumerExecutor (cached thread pool) which creates threads + // on demand for I/O-bound polling. Using the @Internal executor caused deadlock when + // pool exhausted (100+ concurrent queues but maxPoolSize=50). 
CompletableFuture.runAsync(() -> { + // Signal that polling is about to start + pollingStarted.countDown(); consumer( createTubeConfig(), allItems, @@ -146,25 +149,30 @@ public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, return false; } - // Process event through TaskManager - only for non-replicated events - if (!item.isReplicated()) { - try { - callTaskManagerProcess(event); - } catch (A2AServerException e) { - errorRef.set(e); - completionFuture.completeExceptionally(e); - return false; + // Capture Task events (especially for new tasks where taskManager.getTask() would return null) + // We capture the LATEST task to ensure we get the most up-to-date state + if (event instanceof Task t) { + Task previousTask = capturedTask.get(); + capturedTask.set(t); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Captured Task event: id={}, state={} (previous: {})", + t.id(), t.status().state(), + previousTask != null ? previousTask.id() + "/" + previousTask.status().state() : "none"); } } + // TaskStore update moved to MainEventBusProcessor + // Determine interrupt behavior boolean shouldInterrupt = false; - boolean continueInBackground = false; boolean isFinalEvent = (event instanceof Task task && task.status().state().isFinal()) || (event instanceof TaskStatusUpdateEvent tsue && tsue.isFinal()); boolean isAuthRequired = (event instanceof Task task && task.status().state() == TaskState.AUTH_REQUIRED) || (event instanceof TaskStatusUpdateEvent tsue && tsue.status().state() == TaskState.AUTH_REQUIRED); + LOGGER.debug("ResultAggregator: Evaluating interrupt (blocking={}, isFinal={}, isAuth={}, eventType={})", + blocking, isFinalEvent, isAuthRequired, event.getClass().getSimpleName()); + // Always interrupt on auth_required, as it needs external action. 
if (isAuthRequired) { // auth-required is a special state: the message should be @@ -174,20 +182,19 @@ public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, // new request is expected in order for the agent to make progress, // so the agent should exit. shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (AUTH_REQUIRED)"); } else if (!blocking) { // For non-blocking calls, interrupt as soon as a task is available. shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (non-blocking)"); } else if (blocking) { // For blocking calls: Interrupt to free Vert.x thread, but continue in background // Python's async consumption doesn't block threads, but Java's does // So we interrupt to return quickly, then rely on background consumption - // DefaultRequestHandler will fetch the final state from TaskStore shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (blocking, isFinal={})", isFinalEvent); if (LOGGER.isDebugEnabled()) { LOGGER.debug("Blocking call for task {}: {} event, returning with background consumption", taskIdForLogging(), isFinalEvent ? "final" : "non-final"); @@ -195,14 +202,14 @@ else if (blocking) { } if (shouldInterrupt) { + LOGGER.debug("ResultAggregator: Interrupting consumption (setting interrupted=true)"); // Complete the future to unblock the main thread interrupted.set(true); completionFuture.complete(null); // For blocking calls, DON'T complete consumptionCompletionFuture here. // Let it complete naturally when subscription finishes (onComplete callback below). - // This ensures all events are processed and persisted to TaskStore before - // DefaultRequestHandler.cleanupProducer() proceeds with cleanup. + // This ensures all events are fully processed before cleanup. 
// // For non-blocking and auth-required calls, complete immediately to allow // cleanup to proceed while consumption continues in background. @@ -237,7 +244,16 @@ else if (blocking) { } } ); - }, executor); + }, eventConsumerExecutor); + + // Wait for EventConsumer to start polling before we wait for events + // This prevents race where agent enqueues events before EventConsumer starts + try { + pollingStarted.await(5, java.util.concurrent.TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new io.a2a.spec.InternalError("Interrupted waiting for EventConsumer to start"); + } // Wait for completion or interruption try { @@ -261,28 +277,30 @@ else if (blocking) { Utils.rethrow(error); } - EventKind eventType; - Message msg = message.get(); - if (msg != null) { - eventType = msg; - } else { - Task task = taskManager.getTask(); - if (task == null) { - throw new io.a2a.spec.InternalError("No task or message available after consuming events"); + // Return Message if captured, otherwise Task if captured, otherwise fetch from TaskStore + EventKind eventKind = message.get(); + if (eventKind == null) { + eventKind = capturedTask.get(); + if (LOGGER.isDebugEnabled() && eventKind instanceof Task t) { + LOGGER.debug("Returning capturedTask: id={}, state={}", t.id(), t.status().state()); } - eventType = task; + } + if (eventKind == null) { + eventKind = taskManager.getTask(); + if (LOGGER.isDebugEnabled() && eventKind instanceof Task t) { + LOGGER.debug("Returning task from TaskStore: id={}, state={}", t.id(), t.status().state()); + } + } + if (eventKind == null) { + throw new InternalError("Could not find a Task/Message for " + taskManager.getTaskId()); } return new EventTypeAndInterrupt( - eventType, + eventKind, interrupted.get(), consumptionCompletionFuture); } - private void callTaskManagerProcess(Event event) throws A2AServerException { - taskManager.process(event); - } - private String taskIdForLogging() { Task task = 
taskManager.getTask(); return task != null ? task.id() : "unknown"; diff --git a/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java b/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java index e26dd55fb..eee254ba3 100644 --- a/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java +++ b/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java @@ -1,8 +1,8 @@ package io.a2a.server.util.async; +import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -26,6 +26,7 @@ public class AsyncExecutorProducer { private static final String A2A_EXECUTOR_CORE_POOL_SIZE = "a2a.executor.core-pool-size"; private static final String A2A_EXECUTOR_MAX_POOL_SIZE = "a2a.executor.max-pool-size"; private static final String A2A_EXECUTOR_KEEP_ALIVE_SECONDS = "a2a.executor.keep-alive-seconds"; + private static final String A2A_EXECUTOR_QUEUE_CAPACITY = "a2a.executor.queue-capacity"; @Inject A2AConfigProvider configProvider; @@ -57,6 +58,16 @@ public class AsyncExecutorProducer { */ long keepAliveSeconds; + /** + * Queue capacity for pending tasks. + *

    + * Property: {@code a2a.executor.queue-capacity}
    + * Default: 100
    + * Note: Must be bounded to allow pool growth to maxPoolSize. + * When queue is full, new threads are created up to maxPoolSize. + */ + int queueCapacity; + private @Nullable ExecutorService executor; @PostConstruct @@ -64,18 +75,34 @@ public void init() { corePoolSize = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_CORE_POOL_SIZE)); maxPoolSize = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_MAX_POOL_SIZE)); keepAliveSeconds = Long.parseLong(configProvider.getValue(A2A_EXECUTOR_KEEP_ALIVE_SECONDS)); - - LOGGER.info("Initializing async executor: corePoolSize={}, maxPoolSize={}, keepAliveSeconds={}", - corePoolSize, maxPoolSize, keepAliveSeconds); - - executor = new ThreadPoolExecutor( + queueCapacity = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_QUEUE_CAPACITY)); + + LOGGER.info("Initializing async executor: corePoolSize={}, maxPoolSize={}, keepAliveSeconds={}, queueCapacity={}", + corePoolSize, maxPoolSize, keepAliveSeconds, queueCapacity); + + // CRITICAL: Use ArrayBlockingQueue (bounded) instead of LinkedBlockingQueue (unbounded). + // With unbounded queue, ThreadPoolExecutor NEVER grows beyond corePoolSize because the + // queue never fills. This causes executor pool exhaustion during concurrent requests when + // EventConsumer polling threads hold all core threads and agent tasks queue indefinitely. + // Bounded queue enables pool growth: when queue is full, new threads are created up to + // maxPoolSize, preventing agent execution starvation. + ThreadPoolExecutor tpe = new ThreadPoolExecutor( corePoolSize, maxPoolSize, keepAliveSeconds, TimeUnit.SECONDS, - new LinkedBlockingQueue<>(), + new ArrayBlockingQueue<>(queueCapacity), new A2AThreadFactory() ); + + // CRITICAL: Allow core threads to timeout after keepAliveSeconds when idle. + // By default, ThreadPoolExecutor only times out threads above corePoolSize. + // Without this, core threads accumulate during testing and never clean up. 
+ // This is essential for streaming scenarios where many short-lived tasks create threads + // for agent execution and cleanup callbacks, but those threads remain idle afterward. + tpe.allowCoreThreadTimeOut(true); + + executor = tpe; } @PreDestroy @@ -106,6 +133,22 @@ public Executor produce() { return executor; } + /** + * Log current executor pool statistics for diagnostics. + * Useful for debugging pool exhaustion or sizing issues. + */ + public void logPoolStats() { + if (executor instanceof ThreadPoolExecutor tpe) { + LOGGER.info("Executor pool stats: active={}/{}, queued={}/{}, completed={}, total={}", + tpe.getActiveCount(), + tpe.getPoolSize(), + tpe.getQueue().size(), + queueCapacity, + tpe.getCompletedTaskCount(), + tpe.getTaskCount()); + } + } + private static class A2AThreadFactory implements ThreadFactory { private final AtomicInteger threadNumber = new AtomicInteger(1); private final String namePrefix = "a2a-agent-executor-"; diff --git a/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java b/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java new file mode 100644 index 000000000..24ff7f5d1 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java @@ -0,0 +1,93 @@ +package io.a2a.server.util.async; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Produces; +import jakarta.inject.Qualifier; + +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import 
java.util.concurrent.atomic.AtomicInteger; + +import static java.lang.annotation.ElementType.*; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +/** + * Produces a dedicated executor for EventConsumer polling threads. + *

    + * CRITICAL: EventConsumer polling must use a separate executor from AgentExecutor because:
    + * <ul>
    + * <li>EventConsumer threads are I/O-bound (blocking on queue.poll()), not CPU-bound</li>
    + * <li>One EventConsumer thread needed per active queue (can be 100+ concurrent)</li>
    + * <li>Threads are mostly idle, waiting for events</li>
    + * <li>Using the same bounded pool as AgentExecutor causes deadlock when pool exhausted</li>
    + * </ul>
    + * <p>
    + * Uses a cached thread pool (unbounded) with automatic thread reclamation:
    + * <ul>
    + * <li>Creates threads on demand as EventConsumers start</li>
    + * <li>Idle threads automatically terminated after 10 seconds</li>
    + * <li>No queue saturation since threads are created as needed</li>
    + * </ul>
    + */ +@ApplicationScoped +public class EventConsumerExecutorProducer { + private static final Logger LOGGER = LoggerFactory.getLogger(EventConsumerExecutorProducer.class); + + /** + * Qualifier annotation for EventConsumer executor injection. + */ + @Retention(RUNTIME) + @Target({METHOD, FIELD, PARAMETER, TYPE}) + @Qualifier + public @interface EventConsumerExecutor { + } + + /** + * Thread factory for EventConsumer threads. + */ + private static class EventConsumerThreadFactory implements ThreadFactory { + private final AtomicInteger threadNumber = new AtomicInteger(1); + + @Override + public Thread newThread(Runnable r) { + Thread thread = new Thread(r, "a2a-event-consumer-" + threadNumber.getAndIncrement()); + thread.setDaemon(true); + return thread; + } + } + + private @Nullable ExecutorService executor; + + @Produces + @EventConsumerExecutor + @ApplicationScoped + public Executor eventConsumerExecutor() { + // Cached thread pool with 10s idle timeout (reduced from default 60s): + // - Creates threads on demand as EventConsumers start + // - Reclaims idle threads after 10s to prevent accumulation during fast test execution + // - Perfect for I/O-bound EventConsumer polling which blocks on queue.poll() + // - 10s timeout balances thread reuse (production) vs cleanup (testing) + executor = new ThreadPoolExecutor( + 0, // corePoolSize - no core threads + Integer.MAX_VALUE, // maxPoolSize - unbounded + 10, TimeUnit.SECONDS, // keepAliveTime - 10s idle timeout + new SynchronousQueue<>(), // queue - same as cached pool + new EventConsumerThreadFactory() + ); + + LOGGER.info("Initialized EventConsumer executor: cached thread pool (unbounded, 10s idle timeout)"); + + return executor; + } +} diff --git a/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java b/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java new file mode 100644 index 000000000..737fbac23 --- /dev/null +++ 
b/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java @@ -0,0 +1,136 @@ +package io.a2a.server.util.sse; + +import io.a2a.grpc.utils.JSONRPCUtils; +import io.a2a.jsonrpc.common.wrappers.A2AErrorResponse; +import io.a2a.jsonrpc.common.wrappers.A2AResponse; +import io.a2a.jsonrpc.common.wrappers.CancelTaskResponse; +import io.a2a.jsonrpc.common.wrappers.DeleteTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.GetExtendedAgentCardResponse; +import io.a2a.jsonrpc.common.wrappers.GetTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.GetTaskResponse; +import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.ListTasksResponse; +import io.a2a.jsonrpc.common.wrappers.SendMessageResponse; +import io.a2a.jsonrpc.common.wrappers.SendStreamingMessageResponse; +import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigResponse; + +/** + * Framework-agnostic utility for formatting A2A responses as Server-Sent Events (SSE). + *

    + * Provides static methods to serialize A2A responses to JSON and format them as SSE events. + * This allows HTTP server frameworks (Vert.x, Jakarta/WildFly, etc.) to use their own + * reactive libraries for publisher mapping while sharing the serialization logic. + *

    + * Example usage (Quarkus/Vert.x with Mutiny): + *

    <pre>{@code
    + * Flow.Publisher<A2AResponse<?>> responses = handler.onMessageSendStream(request, context);
    + * AtomicLong eventId = new AtomicLong(0);
    + *
    + * Multi<String> sseEvents = Multi.createFrom().publisher(responses)
    + *     .map(response -> SseFormatter.formatResponseAsSSE(response, eventId.getAndIncrement()));
    + *
    + * sseEvents.subscribe().with(sseEvent -> httpResponse.write(Buffer.buffer(sseEvent)));
    + * }</pre>
    + *

    + * Example usage (Jakarta/WildFly with custom reactive library): + *

    <pre>{@code
    + * Flow.Publisher<String> jsonStrings = restHandler.getJsonPublisher();
    + * AtomicLong eventId = new AtomicLong(0);
    + *
    + * Flow.Publisher<String> sseEvents = mapPublisher(jsonStrings,
    + *     json -> SseFormatter.formatJsonAsSSE(json, eventId.getAndIncrement()));
    + * }</pre>
    + */ +public class SseFormatter { + + private SseFormatter() { + // Utility class - prevent instantiation + } + + /** + * Format an A2A response as an SSE event. + *

    + * Serializes the response to JSON and formats as: + *

    <pre>
    +     * data: {"jsonrpc":"2.0","result":{...},"id":123}
    +     * id: 0
    +     *
    +     * </pre>
    + * + * @param response the A2A response to format + * @param eventId the SSE event ID + * @return SSE-formatted string (ready to write to HTTP response) + */ + public static String formatResponseAsSSE(A2AResponse response, long eventId) { + String jsonData = serializeResponse(response); + return "data: " + jsonData + "\nid: " + eventId + "\n\n"; + } + + /** + * Format a pre-serialized JSON string as an SSE event. + *

    + * Wraps the JSON in SSE format as: + *

    <pre>
    +     * data: {"jsonrpc":"2.0","result":{...},"id":123}
    +     * id: 0
    +     *
    +     * </pre>
    + *

    + * Use this when you already have JSON strings (e.g., from REST transport) + * and just need to add SSE formatting. + * + * @param jsonString the JSON string to wrap + * @param eventId the SSE event ID + * @return SSE-formatted string (ready to write to HTTP response) + */ + public static String formatJsonAsSSE(String jsonString, long eventId) { + return "data: " + jsonString + "\nid: " + eventId + "\n\n"; + } + + /** + * Serialize an A2AResponse to JSON string. + */ + private static String serializeResponse(A2AResponse response) { + // For error responses, use standard JSON-RPC error format + if (response instanceof A2AErrorResponse error) { + return JSONRPCUtils.toJsonRPCErrorResponse(error.getId(), error.getError()); + } + if (response.getError() != null) { + return JSONRPCUtils.toJsonRPCErrorResponse(response.getId(), response.getError()); + } + + // Convert domain response to protobuf message and serialize + com.google.protobuf.MessageOrBuilder protoMessage = convertToProto(response); + return JSONRPCUtils.toJsonRPCResultResponse(response.getId(), protoMessage); + } + + /** + * Convert A2AResponse to protobuf message for serialization. 
+ */ + private static com.google.protobuf.MessageOrBuilder convertToProto(A2AResponse response) { + if (response instanceof GetTaskResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.task(r.getResult()); + } else if (response instanceof CancelTaskResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.task(r.getResult()); + } else if (response instanceof SendMessageResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.taskOrMessage(r.getResult()); + } else if (response instanceof ListTasksResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.listTasksResult(r.getResult()); + } else if (response instanceof SetTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.setTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof GetTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.getTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof ListTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.listTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof DeleteTaskPushNotificationConfigResponse) { + // DeleteTaskPushNotificationConfig has no result body, just return empty message + return com.google.protobuf.Empty.getDefaultInstance(); + } else if (response instanceof GetExtendedAgentCardResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.getExtendedCardResponse(r.getResult()); + } else if (response instanceof SendStreamingMessageResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.taskOrMessageStream(r.getResult()); + } else { + throw new IllegalArgumentException("Unknown response type: " + response.getClass().getName()); + } + } +} diff --git a/server-common/src/main/java/io/a2a/server/util/sse/package-info.java b/server-common/src/main/java/io/a2a/server/util/sse/package-info.java new file mode 100644 index 000000000..7e668b632 --- /dev/null +++ 
b/server-common/src/main/java/io/a2a/server/util/sse/package-info.java @@ -0,0 +1,11 @@ +/** + * Server-Sent Events (SSE) formatting utilities for A2A streaming responses. + *

    + * Provides framework-agnostic conversion of {@code Flow.Publisher<A2AResponse<?>>} to + * {@code Flow.Publisher<String>} with SSE formatting, enabling easy integration with + * any HTTP server framework (Vert.x, Jakarta Servlet, etc.). + */ +@NullMarked +package io.a2a.server.util.sse; + +import org.jspecify.annotations.NullMarked; diff --git a/server-common/src/main/resources/META-INF/a2a-defaults.properties b/server-common/src/main/resources/META-INF/a2a-defaults.properties index 280fd943b..719be9e7a 100644 --- a/server-common/src/main/resources/META-INF/a2a-defaults.properties +++ b/server-common/src/main/resources/META-INF/a2a-defaults.properties @@ -19,3 +19,7 @@ a2a.executor.max-pool-size=50 # Keep-alive time for idle threads (seconds) a2a.executor.keep-alive-seconds=60 + +# Queue capacity for pending tasks (must be bounded to enable pool growth) +# When queue is full, new threads are created up to max-pool-size +a2a.executor.queue-capacity=100 diff --git a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java index 4354f1639..3c84bb2ae 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java +++ b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java @@ -16,6 +16,8 @@ import java.util.concurrent.atomic.AtomicReference; import io.a2a.jsonrpc.common.json.JsonProcessingException; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.spec.A2AError; import io.a2a.spec.A2AServerException; import io.a2a.spec.Artifact; @@ -27,14 +29,19 @@ import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class EventConsumerTest { + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + private static 
final String TASK_ID = "123"; // Must match MINIMAL_TASK id + private EventQueue eventQueue; private EventConsumer eventConsumer; - + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; private static final String MINIMAL_TASK = """ { @@ -54,10 +61,58 @@ public class EventConsumerTest { @BeforeEach public void init() { - eventQueue = EventQueue.builder().build(); + // Set up MainEventBus and processor for production-like test environment + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER); + EventQueueUtil.start(mainEventBusProcessor); + + eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .mainEventBus(mainEventBus) + .build().tap(); eventConsumer = new EventConsumer(eventQueue); } + @AfterEach + public void cleanup() { + if (mainEventBusProcessor != null) { + mainEventBusProcessor.setCallback(null); // Clear any test callbacks + EventQueueUtil.stop(mainEventBusProcessor); + } + } + + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. 
+ * + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + mainEventBusProcessor.setCallback(new MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } + } + @Test public void testConsumeOneTaskEvent() throws Exception { Task event = fromJson(MINIMAL_TASK, Task.class); @@ -92,7 +147,7 @@ public void testConsumeAllMultipleEvents() throws JsonProcessingException { List events = List.of( fromJson(MINIMAL_TASK, Task.class), TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -100,7 +155,7 @@ public void testConsumeAllMultipleEvents() throws JsonProcessingException { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -128,7 +183,7 @@ public void testConsumeUntilMessage() throws Exception { List events = List.of( fromJson(MINIMAL_TASK, Task.class), TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -136,7 +191,7 @@ public void testConsumeUntilMessage() throws Exception { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + 
.taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -185,14 +240,14 @@ public void testConsumeMessageEvents() throws Exception { @Test public void testConsumeTaskInputRequired() { Task task = Task.builder() - .id("task-id") - .contextId("task-context") + .id(TASK_ID) + .contextId("session-xyz") .status(new TaskStatus(TaskState.INPUT_REQUIRED)) .build(); List events = List.of( task, TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -200,7 +255,7 @@ public void testConsumeTaskInputRequired() { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -332,7 +387,9 @@ public void onComplete() { @Test public void testConsumeAllStopsOnQueueClosed() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Close the queue immediately @@ -378,12 +435,16 @@ public void onComplete() { @Test public void testConsumeAllHandlesQueueClosedException() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Add a message event (which will complete the stream) Event message = fromJson(MESSAGE_PAYLOAD, Message.class); - queue.enqueueEvent(message); + + // Use callback to wait for event processing + waitForEventProcessing(() -> queue.enqueueEvent(message)); // Close the queue before consuming queue.close(); @@ -428,11 +489,13 @@ public void onComplete() { @Test public void testConsumeAllTerminatesOnQueueClosedEvent() throws Exception { - EventQueue 
queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Enqueue a QueueClosedEvent (poison pill) - QueueClosedEvent queueClosedEvent = new QueueClosedEvent("task-123"); + QueueClosedEvent queueClosedEvent = new QueueClosedEvent(TASK_ID); queue.enqueueEvent(queueClosedEvent); Flow.Publisher publisher = consumer.consumeAll(); @@ -477,8 +540,12 @@ public void onComplete() { } private void enqueueAndConsumeOneEvent(Event event) throws Exception { - eventQueue.enqueueEvent(event); + // Use callback to wait for event processing + waitForEventProcessing(() -> eventQueue.enqueueEvent(event)); + + // Event is now available, consume it directly Event result = eventConsumer.consumeOne(); + assertNotNull(result, "Event should be available"); assertSame(event, result); } diff --git a/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java b/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java index a3dc7d916..daf0c1dc9 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java +++ b/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java @@ -11,7 +11,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.spec.A2AError; import io.a2a.spec.Artifact; import io.a2a.spec.Event; @@ -23,12 +27,17 @@ import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class EventQueueTest { private EventQueue eventQueue; + private MainEventBus mainEventBus; + private MainEventBusProcessor 
mainEventBusProcessor; + + private static final String TASK_ID = "123"; // Must match MINIMAL_TASK id private static final String MINIMAL_TASK = """ { @@ -46,38 +55,95 @@ public class EventQueueTest { } """; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; @BeforeEach public void init() { - eventQueue = EventQueue.builder().build(); + // Set up MainEventBus and processor for production-like test environment + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER); + EventQueueUtil.start(mainEventBusProcessor); + + eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .mainEventBus(mainEventBus) + .build().tap(); + } + + @AfterEach + public void cleanup() { + if (mainEventBusProcessor != null) { + mainEventBusProcessor.setCallback(null); // Clear any test callbacks + EventQueueUtil.stop(mainEventBusProcessor); + } + } + /** + * Helper to create a queue with MainEventBus configured (for tests that need event distribution). + */ + private EventQueue createQueueWithEventBus(String taskId) { + return EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(taskId) + .build(); + } + + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. 
+ * + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, io.a2a.spec.Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } } @Test public void testConstructorDefaultQueueSize() { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); assertEquals(EventQueue.DEFAULT_QUEUE_SIZE, queue.getQueueSize()); } @Test public void testConstructorCustomQueueSize() { int customSize = 500; - EventQueue queue = EventQueue.builder().queueSize(customSize).build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(customSize).build(); assertEquals(customSize, queue.getQueueSize()); } @Test public void testConstructorInvalidQueueSize() { // Test zero queue size - assertThrows(IllegalArgumentException.class, () -> EventQueue.builder().queueSize(0).build()); + assertThrows(IllegalArgumentException.class, () -> EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(0).build()); // Test negative queue size - assertThrows(IllegalArgumentException.class, () -> EventQueue.builder().queueSize(-10).build()); + assertThrows(IllegalArgumentException.class, () -> 
EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(-10).build()); } @Test public void testTapCreatesChildQueue() { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue childQueue = parentQueue.tap(); assertNotNull(childQueue); @@ -87,7 +153,7 @@ public void testTapCreatesChildQueue() { @Test public void testTapOnChildQueueThrowsException() { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue childQueue = parentQueue.tap(); assertThrows(IllegalStateException.class, () -> childQueue.tap()); @@ -95,69 +161,74 @@ public void testTapOnChildQueueThrowsException() { @Test public void testEnqueueEventPropagagesToChildren() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - parentQueue.enqueueEvent(event); + mainQueue.enqueueEvent(event); - // Event should be available in both parent and child queues - Event parentEvent = parentQueue.dequeueEventItem(-1).getEvent(); - Event childEvent = childQueue.dequeueEventItem(-1).getEvent(); + // Event should be available in all child queues + // Note: MainEventBusProcessor runs async, so we use dequeueEventItem with timeout + Event child1Event = childQueue1.dequeueEventItem(5000).getEvent(); + Event child2Event = childQueue2.dequeueEventItem(5000).getEvent(); - assertSame(event, parentEvent); - assertSame(event, childEvent); + assertSame(event, child1Event); + assertSame(event, child2Event); } @Test public void testMultipleChildQueuesReceiveEvents() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue1 = 
parentQueue.tap(); - EventQueue childQueue2 = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); + EventQueue childQueue3 = mainQueue.tap(); Event event1 = fromJson(MINIMAL_TASK, Task.class); Event event2 = fromJson(MESSAGE_PAYLOAD, Message.class); - parentQueue.enqueueEvent(event1); - parentQueue.enqueueEvent(event2); + mainQueue.enqueueEvent(event1); + mainQueue.enqueueEvent(event2); - // All queues should receive both events - assertSame(event1, parentQueue.dequeueEventItem(-1).getEvent()); - assertSame(event2, parentQueue.dequeueEventItem(-1).getEvent()); + // All child queues should receive both events + // Note: Use timeout for async processing + assertSame(event1, childQueue1.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue1.dequeueEventItem(5000).getEvent()); - assertSame(event1, childQueue1.dequeueEventItem(-1).getEvent()); - assertSame(event2, childQueue1.dequeueEventItem(-1).getEvent()); + assertSame(event1, childQueue2.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue2.dequeueEventItem(5000).getEvent()); - assertSame(event1, childQueue2.dequeueEventItem(-1).getEvent()); - assertSame(event2, childQueue2.dequeueEventItem(-1).getEvent()); + assertSame(event1, childQueue3.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue3.dequeueEventItem(5000).getEvent()); } @Test public void testChildQueueDequeueIndependently() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue1 = parentQueue.tap(); - EventQueue childQueue2 = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); + EventQueue childQueue3 = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - parentQueue.enqueueEvent(event); + mainQueue.enqueueEvent(event); - // 
Dequeue from child1 first - Event child1Event = childQueue1.dequeueEventItem(-1).getEvent(); + // Dequeue from child1 first (use timeout for async processing) + Event child1Event = childQueue1.dequeueEventItem(5000).getEvent(); assertSame(event, child1Event); // child2 should still have the event available - Event child2Event = childQueue2.dequeueEventItem(-1).getEvent(); + Event child2Event = childQueue2.dequeueEventItem(5000).getEvent(); assertSame(event, child2Event); - // Parent should still have the event available - Event parentEvent = parentQueue.dequeueEventItem(-1).getEvent(); - assertSame(event, parentEvent); + // child3 should still have the event available + Event child3Event = childQueue3.dequeueEventItem(5000).getEvent(); + assertSame(event, child3Event); } @Test public void testCloseImmediatePropagationToChildren() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = createQueueWithEventBus(TASK_ID); EventQueue childQueue = parentQueue.tap(); // Add events to both parent and child @@ -166,7 +237,7 @@ public void testCloseImmediatePropagationToChildren() throws Exception { assertFalse(childQueue.isClosed()); try { - assertNotNull(childQueue.dequeueEventItem(-1)); // Child has the event + assertNotNull(childQueue.dequeueEventItem(5000)); // Child has the event (use timeout) } catch (EventQueueClosedException e) { // This is fine if queue closed before dequeue } @@ -187,27 +258,37 @@ public void testCloseImmediatePropagationToChildren() throws Exception { @Test public void testEnqueueEventWhenClosed() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .build(); + EventQueue childQueue = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - queue.close(); // Close the queue first - assertTrue(queue.isClosed()); + childQueue.close(); // Close the child queue first (removes from 
children list) + assertTrue(childQueue.isClosed()); + + // Create a new child queue BEFORE enqueuing (ensures it's in children list for distribution) + EventQueue newChildQueue = mainQueue.tap(); // MainQueue accepts events even when closed (for replication support) // This ensures late-arriving replicated events can be enqueued to closed queues - queue.enqueueEvent(event); + // Note: MainEventBusProcessor runs asynchronously, so we use dequeueEventItem with timeout + mainQueue.enqueueEvent(event); - // Event should be available for dequeuing - Event dequeuedEvent = queue.dequeueEventItem(-1).getEvent(); + // New child queue should receive the event (old closed child was removed from children list) + EventQueueItem item = newChildQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); - // Now queue is closed and empty, should throw exception - assertThrows(EventQueueClosedException.class, () -> queue.dequeueEventItem(-1)); + // Now new child queue is closed and empty, should throw exception + newChildQueue.close(); + assertThrows(EventQueueClosedException.class, () -> newChildQueue.dequeueEventItem(-1)); } @Test public void testDequeueEventWhenClosedAndEmpty() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build().tap(); queue.close(); assertTrue(queue.isClosed()); @@ -217,19 +298,27 @@ public void testDequeueEventWhenClosedAndEmpty() throws Exception { @Test public void testDequeueEventWhenClosedButHasEvents() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .build(); + EventQueue childQueue = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - queue.enqueueEvent(event); - queue.close(); // Graceful close - events should remain - assertTrue(queue.isClosed()); + // Use 
callback to wait for event processing instead of polling + waitForEventProcessing(() -> mainQueue.enqueueEvent(event)); - // Should still be able to dequeue existing events - Event dequeuedEvent = queue.dequeueEventItem(-1).getEvent(); + // At this point, event has been processed and distributed to childQueue + childQueue.close(); // Graceful close - events should remain + assertTrue(childQueue.isClosed()); + + // Should still be able to dequeue existing events from closed queue + EventQueueItem item = childQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); // Now queue is closed and empty, should throw exception - assertThrows(EventQueueClosedException.class, () -> queue.dequeueEventItem(-1)); + assertThrows(EventQueueClosedException.class, () -> childQueue.dequeueEventItem(-1)); } @Test @@ -244,7 +333,9 @@ public void testEnqueueAndDequeueEvent() throws Exception { public void testDequeueEventNoWait() throws Exception { Event event = fromJson(MINIMAL_TASK, Task.class); eventQueue.enqueueEvent(event); - Event dequeuedEvent = eventQueue.dequeueEventItem(-1).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); } @@ -257,7 +348,7 @@ public void testDequeueEventEmptyQueueNoWait() throws Exception { @Test public void testDequeueEventWait() throws Exception { Event event = TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -271,7 +362,7 @@ public void testDequeueEventWait() throws Exception { @Test public void testTaskDone() throws Exception { Event event = TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -347,7 +438,7 @@ public void testCloseIdempotent() throws Exception { 
assertTrue(eventQueue.isClosed()); // Test with immediate close as well - EventQueue eventQueue2 = EventQueue.builder().build(); + EventQueue eventQueue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); eventQueue2.close(true); assertTrue(eventQueue2.isClosed()); @@ -361,19 +452,20 @@ public void testCloseIdempotent() throws Exception { */ @Test public void testCloseChildQueues() throws Exception { - EventQueue childQueue = eventQueue.tap(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); + EventQueue childQueue = mainQueue.tap(); assertTrue(childQueue != null); // Graceful close - parent closes but children remain open - eventQueue.close(); - assertTrue(eventQueue.isClosed()); + mainQueue.close(); + assertTrue(mainQueue.isClosed()); assertFalse(childQueue.isClosed()); // Child NOT closed on graceful parent close // Immediate close - parent force-closes all children - EventQueue parentQueue2 = EventQueue.builder().build(); - EventQueue childQueue2 = parentQueue2.tap(); - parentQueue2.close(true); // immediate=true - assertTrue(parentQueue2.isClosed()); + EventQueue mainQueue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); + EventQueue childQueue2 = mainQueue2.tap(); + mainQueue2.close(true); // immediate=true + assertTrue(mainQueue2.isClosed()); assertTrue(childQueue2.isClosed()); // Child IS closed on immediate parent close } @@ -383,7 +475,7 @@ public void testCloseChildQueues() throws Exception { */ @Test public void testMainQueueReferenceCountingStaysOpenWithActiveChildren() throws Exception { - EventQueue mainQueue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue child1 = mainQueue.tap(); EventQueue child2 = mainQueue.tap(); diff --git a/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java b/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java index 39201c1f6..6c9ed4a17 100644 --- 
a/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java +++ b/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java @@ -1,8 +1,39 @@ package io.a2a.server.events; +import java.util.concurrent.atomic.AtomicInteger; + public class EventQueueUtil { - // Since EventQueue.builder() is package protected, add a method to expose it - public static EventQueue.EventQueueBuilder getEventQueueBuilder() { - return EventQueue.builder(); + // Counter for generating unique test taskIds + private static final AtomicInteger TASK_ID_COUNTER = new AtomicInteger(0); + + /** + * Get an EventQueue builder pre-configured with the given MainEventBus and a unique taskId. + *

    + * Note: Returns MainQueue - tests should call .tap() if they need to consume events. + *

    + * + * @return builder bound to the given eventBus with a unique taskId already set + */ + public static EventQueue.EventQueueBuilder getEventQueueBuilder(MainEventBus eventBus) { + return EventQueue.builder(eventBus) + .taskId("test-task-" + TASK_ID_COUNTER.incrementAndGet()); + } + + /** + * Start a MainEventBusProcessor instance. + * + * @param processor the processor to start + */ + public static void start(MainEventBusProcessor processor) { + processor.start(); + } + + /** + * Stop a MainEventBusProcessor instance. + * + * @param processor the processor to stop + */ + public static void stop(MainEventBusProcessor processor) { + processor.stop(); } } diff --git a/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java b/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java index 1eca1b739..808a1107a 100644 --- a/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java +++ b/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java @@ -14,7 +14,10 @@ import java.util.concurrent.ExecutionException; import java.util.stream.IntStream; +import io.a2a.server.tasks.InMemoryTaskStore; import io.a2a.server.tasks.MockTaskStateProvider; +import io.a2a.server.tasks.PushNotificationSender; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -22,17 +25,31 @@ public class InMemoryQueueManagerTest { private InMemoryQueueManager queueManager; private MockTaskStateProvider taskStateProvider; + private InMemoryTaskStore taskStore; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; @BeforeEach public void setUp() { taskStateProvider = new MockTaskStateProvider(); - queueManager = new InMemoryQueueManager(taskStateProvider); + taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + 
mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER); + EventQueueUtil.start(mainEventBusProcessor); + + queueManager = new InMemoryQueueManager(taskStateProvider, mainEventBus); + } + + @AfterEach + public void tearDown() { + EventQueueUtil.stop(mainEventBusProcessor); } @Test public void testAddNewQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); @@ -43,8 +60,8 @@ public void testAddNewQueue() { @Test public void testAddExistingQueueThrowsException() { String taskId = "test_task_id"; - EventQueue queue1 = EventQueue.builder().build(); - EventQueue queue2 = EventQueue.builder().build(); + EventQueue queue1 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); + EventQueue queue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue1); @@ -56,7 +73,7 @@ public void testAddExistingQueueThrowsException() { @Test public void testGetExistingQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); EventQueue result = queueManager.get(taskId); @@ -73,7 +90,7 @@ public void testGetNonexistentQueue() { @Test public void testTapExistingQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); EventQueue tappedQueue = queueManager.tap(taskId); @@ -94,7 +111,7 @@ public void testTapNonexistentQueue() { @Test public void testCloseExistingQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, 
queue); queueManager.close(taskId); @@ -129,7 +146,7 @@ public void testCreateOrTapNewQueue() { @Test public void testCreateOrTapExistingQueue() { String taskId = "test_task_id"; - EventQueue originalQueue = EventQueue.builder().build(); + EventQueue originalQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, originalQueue); EventQueue result = queueManager.createOrTap(taskId); @@ -151,7 +168,7 @@ public void testConcurrentOperations() throws InterruptedException, ExecutionExc // Add tasks concurrently List> addFutures = taskIds.stream() .map(taskId -> CompletableFuture.supplyAsync(() -> { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); return taskId; })) diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java index ea5bbe797..4535bbeb3 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java @@ -26,7 +26,10 @@ import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; import io.a2a.server.events.EventQueueItem; +import io.a2a.server.events.EventQueueUtil; import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; import io.a2a.server.tasks.BasePushNotificationSender; import io.a2a.server.tasks.InMemoryPushNotificationConfigStore; import io.a2a.server.tasks.InMemoryTaskStore; @@ -66,6 +69,8 @@ public class AbstractA2ARequestHandlerTest { private static final String PREFERRED_TRANSPORT = "preferred-transport"; private static final String A2A_REQUESTHANDLER_TEST_PROPERTIES = 
"/a2a-requesthandler-test.properties"; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + protected AgentExecutor executor; protected TaskStore taskStore; protected RequestHandler requestHandler; @@ -73,6 +78,8 @@ public class AbstractA2ARequestHandlerTest { protected AgentExecutorMethod agentExecutorCancel; protected InMemoryQueueManager queueManager; protected TestHttpClient httpClient; + protected MainEventBus mainEventBus; + protected MainEventBusProcessor mainEventBusProcessor; protected final Executor internalExecutor = Executors.newCachedThreadPool(); @@ -96,19 +103,32 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws A2AErro InMemoryTaskStore inMemoryTaskStore = new InMemoryTaskStore(); taskStore = inMemoryTaskStore; - queueManager = new InMemoryQueueManager(inMemoryTaskStore); + + // Create push notification components BEFORE MainEventBusProcessor httpClient = new TestHttpClient(); PushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); PushNotificationSender pushSender = new BasePushNotificationSender(pushConfigStore, httpClient); + // Create MainEventBus and MainEventBusProcessor (production code path) + mainEventBus = new MainEventBus(); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, pushSender); + EventQueueUtil.start(mainEventBusProcessor); + + queueManager = new InMemoryQueueManager(inMemoryTaskStore, mainEventBus); + requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, pushConfigStore, pushSender, internalExecutor); + executor, taskStore, queueManager, pushConfigStore, mainEventBusProcessor, internalExecutor, internalExecutor); } @AfterEach public void cleanup() { agentExecutorExecute = null; agentExecutorCancel = null; + + // Stop MainEventBusProcessor background thread + if (mainEventBusProcessor != null) { + EventQueueUtil.stop(mainEventBusProcessor); + } } protected static AgentCard 
createAgentCard(boolean streaming, boolean pushNotifications, boolean stateTransitionHistory) { diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java index 293babe4e..42a940fae 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java @@ -3,11 +3,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -17,9 +19,13 @@ import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.auth.UnauthenticatedUser; import io.a2a.server.events.EventQueue; +import io.a2a.server.events.EventQueueUtil; import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; import io.a2a.server.tasks.InMemoryPushNotificationConfigStore; import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.server.tasks.TaskUpdater; import io.a2a.spec.A2AError; import io.a2a.spec.ListTaskPushNotificationConfigParams; @@ -32,6 +38,7 @@ import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -50,26 +57,75 @@ public class 
DefaultRequestHandlerTest { private InMemoryQueueManager queueManager; private TestAgentExecutor agentExecutor; private ServerCallContext serverCallContext; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; + + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; @BeforeEach void setUp() { taskStore = new InMemoryTaskStore(); + + // Create MainEventBus and MainEventBusProcessor (production code path) + mainEventBus = new MainEventBus(); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER); + EventQueueUtil.start(mainEventBusProcessor); + // Pass taskStore as TaskStateProvider to queueManager for task-aware queue management - queueManager = new InMemoryQueueManager(taskStore); + queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + agentExecutor = new TestAgentExecutor(); + ExecutorService executor = Executors.newCachedThreadPool(); requestHandler = DefaultRequestHandler.create( agentExecutor, taskStore, queueManager, null, // pushConfigStore - null, // pushSender - Executors.newCachedThreadPool() + mainEventBusProcessor, + executor, + executor ); serverCallContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of(), Set.of()); } + @AfterEach + void tearDown() { + // Stop MainEventBusProcessor background thread + // Note: Don't clear callback here - DefaultRequestHandler has a permanent callback + if (mainEventBusProcessor != null) { + EventQueueUtil.stop(mainEventBusProcessor); + } + } + + /** + * Helper to wait for task finalization in background (for non-blocking tests). + *
<p>
    + * Note: Does NOT set callbacks - DefaultRequestHandler has a permanent callback. + * Simply polls TaskStore until task reaches final state. + *
</p>
    + * + * @param action the action that triggers task finalization (e.g., allowing agent to complete) + * @param taskId the task ID to wait for + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if finalization doesn't complete within timeout + */ + private void waitForTaskFinalization(Runnable action, String taskId) throws InterruptedException { + action.run(); + + // Poll TaskStore for final state (non-blocking tests complete in background) + for (int i = 0; i < 50; i++) { + Task task = taskStore.get(taskId); + if (task != null && task.status() != null && task.status().state() != null + && task.status().state().isFinal()) { + return; // Success! + } + Thread.sleep(100); + } + fail("Task " + taskId + " should have been finalized within timeout"); + } + /** * Test that multiple blocking messages to the same task work correctly * when agent doesn't emit final events (fire-and-forget pattern). @@ -576,32 +632,15 @@ void testNonBlockingMessagePersistsAllEventsInBackground() throws Exception { // At this point, the non-blocking call has returned, but the agent is still running - // Allow the agent to emit the final COMPLETED event - allowCompletion.countDown(); - - // Assertion 2: Poll for the final task state to be persisted in background - // Use polling loop instead of fixed sleep for faster and more reliable test - long timeoutMs = 5000; - long startTime = System.currentTimeMillis(); - Task persistedTask = null; - boolean completedStateFound = false; - - while (System.currentTimeMillis() - startTime < timeoutMs) { - persistedTask = taskStore.get(taskId); - if (persistedTask != null && persistedTask.status().state() == TaskState.COMPLETED) { - completedStateFound = true; - break; - } - Thread.sleep(100); // Poll every 100ms - } + // Assertion 2: Wait for the final task to be processed and finalized in background + // Poll TaskStore for finalization (background consumption) + waitForTaskFinalization(() -> 
allowCompletion.countDown(), taskId); - assertTrue(persistedTask != null, "Task should be persisted to store"); - assertTrue( - completedStateFound, - "Final task state should be COMPLETED (background consumption should have processed it), got: " + - (persistedTask != null ? persistedTask.status().state() : "null") + - " after " + (System.currentTimeMillis() - startTime) + "ms" - ); + // Verify the task was persisted with COMPLETED state + Task persistedTask = taskStore.get(taskId); + assertNotNull(persistedTask, "Task should be persisted to store"); + assertEquals(TaskState.COMPLETED, persistedTask.status().state(), + "Final task state should be COMPLETED (background consumption should have processed it)"); } /** @@ -779,13 +818,16 @@ void testBlockingCallReturnsCompleteTaskWithArtifacts() throws Exception { }); // Call blocking onMessageSend - should wait for ALL events + // DefaultRequestHandler now waits internally for task finalization before returning Object result = requestHandler.onMessageSend(params, serverCallContext); // The returned result should be a Task with ALL artifacts assertTrue(result instanceof Task, "Result should be a Task"); Task returnedTask = (Task) result; - // Verify task is completed + // Fetch final state from TaskStore (guaranteed to be persisted after blocking call) + returnedTask = taskStore.get(taskId); + assertEquals(TaskState.COMPLETED, returnedTask.status().state(), "Returned task should be COMPLETED"); @@ -817,13 +859,15 @@ void testBlockingMessageStoresPushNotificationConfigForNewTask() throws Exceptio InMemoryPushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); // Re-create request handler with pushConfigStore + ExecutorService pushTestExecutor = Executors.newCachedThreadPool(); requestHandler = DefaultRequestHandler.create( agentExecutor, taskStore, queueManager, pushConfigStore, // Add push config store - null, // pushSender - Executors.newCachedThreadPool() + mainEventBusProcessor, + 
pushTestExecutor, + pushTestExecutor ); // Create push notification config @@ -888,13 +932,15 @@ void testBlockingMessageStoresPushNotificationConfigForExistingTask() throws Exc InMemoryPushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); // Re-create request handler with pushConfigStore + ExecutorService pushTestExecutor = Executors.newCachedThreadPool(); requestHandler = DefaultRequestHandler.create( agentExecutor, taskStore, queueManager, pushConfigStore, // Add push config store - null, // pushSender - Executors.newCachedThreadPool() + mainEventBusProcessor, + pushTestExecutor, + pushTestExecutor ); // Create EXISTING task in store diff --git a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java index d64729077..b33fa4132 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java @@ -11,18 +11,25 @@ import static org.mockito.Mockito.when; import java.util.Collections; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueue; +import io.a2a.server.events.EventQueueUtil; import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; +import io.a2a.spec.Event; import io.a2a.spec.EventKind; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; @@ -49,7 +56,7 @@ public class ResultAggregatorTest { @BeforeEach void setUp() { 
MockitoAnnotations.openMocks(this); - aggregator = new ResultAggregator(mockTaskManager, null, testExecutor); + aggregator = new ResultAggregator(mockTaskManager, null, testExecutor, testExecutor); } // Helper methods for creating sample data @@ -69,13 +76,45 @@ private Task createSampleTask(String taskId, TaskState statusState, String conte .build(); } + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. + * + * @param processor the processor to set callback on + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(MainEventBusProcessor processor, Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + processor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + processor.setCallback(null); + } + } + // Basic functionality tests @Test void testConstructorWithMessage() { Message initialMessage = createSampleMessage("initial", "msg1", Message.Role.USER); - ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, initialMessage, testExecutor); + ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, initialMessage, testExecutor, testExecutor); // Test that the message is properly stored by checking getCurrentResult assertEquals(initialMessage, aggregatorWithMessage.getCurrentResult()); @@ 
-86,7 +125,7 @@ void testConstructorWithMessage() { @Test void testGetCurrentResultWithMessageSet() { Message sampleMessage = createSampleMessage("hola", "msg1", Message.Role.USER); - ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, sampleMessage, testExecutor); + ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, sampleMessage, testExecutor, testExecutor); EventKind result = aggregatorWithMessage.getCurrentResult(); @@ -121,7 +160,7 @@ void testConstructorStoresTaskManagerCorrectly() { @Test void testConstructorWithNullMessage() { - ResultAggregator aggregatorWithNullMessage = new ResultAggregator(mockTaskManager, null, testExecutor); + ResultAggregator aggregatorWithNullMessage = new ResultAggregator(mockTaskManager, null, testExecutor, testExecutor); Task expectedTask = createSampleTask("null_msg_task", TaskState.WORKING, "ctx1"); when(mockTaskManager.getTask()).thenReturn(expectedTask); @@ -181,7 +220,7 @@ void testMultipleGetCurrentResultCalls() { void testGetCurrentResultWithMessageTakesPrecedence() { // Test that when both message and task are available, message takes precedence Message message = createSampleMessage("priority message", "pri1", Message.Role.USER); - ResultAggregator messageAggregator = new ResultAggregator(mockTaskManager, message, testExecutor); + ResultAggregator messageAggregator = new ResultAggregator(mockTaskManager, message, testExecutor, testExecutor); // Even if we set up the task manager to return something, message should take precedence Task task = createSampleTask("should_not_be_returned", TaskState.WORKING, "ctx1"); @@ -197,17 +236,25 @@ void testGetCurrentResultWithMessageTakesPrecedence() { @Test void testConsumeAndBreakNonBlocking() throws Exception { // Test that with blocking=false, the method returns after the first event - Task firstEvent = createSampleTask("non_blocking_task", TaskState.WORKING, "ctx1"); + String taskId = "test-task"; + Task firstEvent = 
createSampleTask(taskId, TaskState.WORKING, "ctx1"); // After processing firstEvent, the current result will be that task when(mockTaskManager.getTask()).thenReturn(firstEvent); // Create an event queue using QueueManager (which has access to builder) + MainEventBus mainEventBus = new MainEventBus(); + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + MainEventBusProcessor processor = new MainEventBusProcessor(mainEventBus, taskStore, task -> {}); + EventQueueUtil.start(processor); + InMemoryQueueManager queueManager = - new InMemoryQueueManager(new MockTaskStateProvider()); + new InMemoryQueueManager(new MockTaskStateProvider(), mainEventBus); + + EventQueue queue = queueManager.getEventQueueBuilder(taskId).build().tap(); - EventQueue queue = queueManager.getEventQueueBuilder("test-task").build(); - queue.enqueueEvent(firstEvent); + // Use callback to wait for event processing (replaces polling) + waitForEventProcessing(processor, () -> queue.enqueueEvent(firstEvent)); // Create real EventConsumer with the queue EventConsumer eventConsumer = @@ -221,11 +268,16 @@ void testConsumeAndBreakNonBlocking() throws Exception { assertEquals(firstEvent, result.eventType()); assertTrue(result.interrupted()); - verify(mockTaskManager).process(firstEvent); - // getTask() is called at least once for the return value (line 255) - // May be called once more if debug logging executes in time (line 209) - // The async consumer may or may not execute before verification, so we accept 1-2 calls - verify(mockTaskManager, atLeast(1)).getTask(); - verify(mockTaskManager, atMost(2)).getTask(); + // NOTE: ResultAggregator no longer calls taskManager.process() + // That responsibility has moved to MainEventBusProcessor for centralized persistence + // + // NOTE: Since firstEvent is a Task, ResultAggregator captures it directly from the queue + // (capturedTask.get() at line 283 in ResultAggregator). 
Therefore, taskManager.getTask() + // is only called for debug logging in taskIdForLogging() (line 305), which may or may not + // execute depending on timing and log level. We expect 0-1 calls, not 1-2. + verify(mockTaskManager, atMost(1)).getTask(); + + // Cleanup: stop the processor + EventQueueUtil.stop(processor); } } diff --git a/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java b/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java index 40f763569..fd195e0a5 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java @@ -14,7 +14,10 @@ import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; +import io.a2a.server.events.EventQueueItem; import io.a2a.server.events.EventQueueUtil; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; import io.a2a.spec.Event; import io.a2a.spec.Message; import io.a2a.spec.Part; @@ -22,6 +25,7 @@ import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -38,14 +42,27 @@ public class TaskUpdaterTest { private static final List> SAMPLE_PARTS = List.of(new TextPart("Test message")); + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + EventQueue eventQueue; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; private TaskUpdater taskUpdater; @BeforeEach public void init() { - eventQueue = EventQueueUtil.getEventQueueBuilder().build(); + // Set up MainEventBus and processor for production-like test environment + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, 
NOOP_PUSHNOTIFICATION_SENDER); + EventQueueUtil.start(mainEventBusProcessor); + + eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TEST_TASK_ID) + .mainEventBus(mainEventBus) + .build().tap(); RequestContext context = new RequestContext.Builder() .setTaskId(TEST_TASK_ID) .setContextId(TEST_TASK_CONTEXT_ID) @@ -53,10 +70,19 @@ public void init() { taskUpdater = new TaskUpdater(context, eventQueue); } + @AfterEach + public void cleanup() { + if (mainEventBusProcessor != null) { + EventQueueUtil.stop(mainEventBusProcessor); + } + } + @Test public void testAddArtifactWithCustomIdAndName() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "custom-artifact-id", "Custom Artifact", null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -239,7 +265,9 @@ public void testNewAgentMessageWithMetadata() throws Exception { @Test public void testAddArtifactWithAppendTrue() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, true, null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -258,7 +286,9 @@ public void testAddArtifactWithAppendTrue() throws Exception { @Test public void testAddArtifactWithLastChunkTrue() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, null, true); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -273,7 +303,9 @@ public void 
testAddArtifactWithLastChunkTrue() throws Exception { @Test public void testAddArtifactWithAppendAndLastChunk() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, true, false); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -287,7 +319,9 @@ public void testAddArtifactWithAppendAndLastChunk() throws Exception { @Test public void testAddArtifactGeneratesIdWhenNull() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, null, "Test Artifact", null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -383,7 +417,9 @@ public void testConcurrentCompletionAttempts() throws Exception { thread2.join(); // Exactly one event should have been queued - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskStatusUpdateEvent.class, event); @@ -396,7 +432,10 @@ public void testConcurrentCompletionAttempts() throws Exception { } private TaskStatusUpdateEvent checkTaskStatusUpdateEventOnQueue(boolean isFinal, TaskState state, Message statusMessage) throws Exception { - Event event = eventQueue.dequeueEventItem(0).getEvent(); + // Wait up to 5 seconds for event (async MainEventBusProcessor needs time to distribute) + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskStatusUpdateEvent.class, event); @@ -408,6 +447,7 @@ private TaskStatusUpdateEvent 
checkTaskStatusUpdateEventOnQueue(boolean isFinal, assertEquals(state, tsue.status().state()); assertEquals(statusMessage, tsue.status().message()); + // Check no additional events (still use 0 timeout for this check) assertNull(eventQueue.dequeueEventItem(0)); return tsue; diff --git a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java index 408205aa2..439d97497 100644 --- a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java +++ b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java @@ -242,7 +242,7 @@ public void sendStreamingMessage(io.a2a.grpc.SendMessageRequest request, A2AExtensions.validateRequiredExtensions(getAgentCardInternal(), context); MessageSendParams params = FromProto.messageSendParams(request); Flow.Publisher publisher = getRequestHandler().onMessageSendStream(params, context); - convertToStreamResponse(publisher, responseObserver); + convertToStreamResponse(publisher, responseObserver, context); } catch (A2AError e) { handleError(responseObserver, e); } catch (SecurityException e) { @@ -264,7 +264,7 @@ public void subscribeToTask(io.a2a.grpc.SubscribeToTaskRequest request, ServerCallContext context = createCallContext(responseObserver); TaskIdParams params = FromProto.taskIdParams(request); Flow.Publisher publisher = getRequestHandler().onResubscribeToTask(params, context); - convertToStreamResponse(publisher, responseObserver); + convertToStreamResponse(publisher, responseObserver, context); } catch (A2AError e) { handleError(responseObserver, e); } catch (SecurityException e) { @@ -275,7 +275,8 @@ public void subscribeToTask(io.a2a.grpc.SubscribeToTaskRequest request, } private void convertToStreamResponse(Flow.Publisher publisher, - StreamObserver responseObserver) { + StreamObserver responseObserver, + ServerCallContext context) { CompletableFuture.runAsync(() -> { publisher.subscribe(new 
Flow.Subscriber() { private Flow.Subscription subscription; @@ -285,6 +286,18 @@ public void onSubscribe(Flow.Subscription subscription) { this.subscription = subscription; subscription.request(1); + // Detect gRPC client disconnect and call EventConsumer.cancel() directly + // This stops the polling loop without relying on subscription cancellation propagation + Context grpcContext = Context.current(); + grpcContext.addListener(new Context.CancellationListener() { + @Override + public void cancelled(Context ctx) { + LOGGER.fine(() -> "gRPC call cancelled by client, calling EventConsumer.cancel() to stop polling loop"); + context.invokeEventConsumerCancelCallback(); + subscription.cancel(); + } + }, getExecutor()); + // Notify tests that we are subscribed Runnable runnable = streamingSubscribedRunnable; if (runnable != null) { @@ -305,6 +318,8 @@ public void onNext(StreamingEventKind event) { @Override public void onError(Throwable throwable) { + // Cancel upstream to stop EventConsumer when error occurs + subscription.cancel(); if (throwable instanceof A2AError jsonrpcError) { handleError(responseObserver, jsonrpcError); } else { @@ -329,6 +344,9 @@ public void getExtendedAgentCard(io.a2a.grpc.GetExtendedAgentCardRequest request if (extendedAgentCard != null) { responseObserver.onNext(ToProto.agentCard(extendedAgentCard)); responseObserver.onCompleted(); + } else { + // Extended agent card not configured - return error instead of hanging + handleError(responseObserver, new ExtendedAgentCardNotConfiguredError(null, "Extended agent card not configured", null)); } } catch (Throwable t) { handleInternalError(responseObserver, t); diff --git a/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java b/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java index 690d69a87..f7d711382 100644 --- a/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java +++ 
b/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java @@ -281,8 +281,7 @@ public void testPushNotificationsNotSupportedError() throws Exception { @Test public void testOnGetPushNotificationNoPushNotifierConfig() throws Exception { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = - new DefaultRequestHandler(executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(false, true, false); GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor); String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(); @@ -293,8 +292,7 @@ public void testOnGetPushNotificationNoPushNotifierConfig() throws Exception { @Test public void testOnSetPushNotificationNoPushNotifierConfig() throws Exception { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(false, true, false); GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor); String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(); @@ -424,9 +422,14 @@ public void testOnMessageStreamNewMessageExistingTaskSuccessMocks() throws Excep @Test public void testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Exception { 
- GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - List events = List.of( - AbstractA2ARequestHandlerTest.MINIMAL_TASK, + // Use synchronous executor for push notifications to ensure deterministic ordering + // Without this, async push notifications can execute out of order, causing test flakiness + mainEventBusProcessor.setPushNotificationExecutor(Runnable::run); + + try { + GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); + List events = List.of( + AbstractA2ARequestHandlerTest.MINIMAL_TASK, TaskArtifactUpdateEvent.builder() .taskId(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id()) .contextId(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId()) @@ -493,13 +496,16 @@ public void onCompleted() { Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); - curr = httpClient.tasks.get(2); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), curr.id()); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), curr.contextId()); - Assertions.assertEquals(io.a2a.spec.TaskState.COMPLETED, curr.status().state()); - Assertions.assertEquals(1, curr.artifacts().size()); - Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); - Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); + curr = httpClient.tasks.get(2); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), curr.id()); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), curr.contextId()); + Assertions.assertEquals(io.a2a.spec.TaskState.COMPLETED, curr.status().state()); + Assertions.assertEquals(1, curr.artifacts().size()); + Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); + Assertions.assertEquals("text", 
((TextPart)curr.artifacts().get(0).parts().get(0)).text()); + } finally { + mainEventBusProcessor.setPushNotificationExecutor(null); + } } @Test @@ -668,8 +674,7 @@ public void testListPushNotificationConfigNotSupported() throws Exception { @Test public void testListPushNotificationConfigNoPushConfigStore() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); agentExecutorExecute = (context, eventQueue) -> { @@ -741,8 +746,7 @@ public void testDeletePushNotificationConfigNotSupported() throws Exception { @Test public void testDeletePushNotificationConfigNoPushConfigStore() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(); DeleteTaskPushNotificationConfigRequest request = DeleteTaskPushNotificationConfigRequest.newBuilder() diff --git a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java index b43c28029..f09e9a40f 100644 --- 
a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java +++ b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java @@ -3,6 +3,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.fail; import java.util.ArrayList; import java.util.Collections; @@ -30,6 +31,8 @@ import io.a2a.jsonrpc.common.wrappers.GetTaskResponse; import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigRequest; import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.ListTasksRequest; +import io.a2a.jsonrpc.common.wrappers.ListTasksResponse; import io.a2a.jsonrpc.common.wrappers.ListTasksResult; import io.a2a.jsonrpc.common.wrappers.SendMessageRequest; import io.a2a.jsonrpc.common.wrappers.SendMessageResponse; @@ -37,12 +40,11 @@ import io.a2a.jsonrpc.common.wrappers.SendStreamingMessageResponse; import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigRequest; import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigResponse; -import io.a2a.jsonrpc.common.wrappers.ListTasksRequest; -import io.a2a.jsonrpc.common.wrappers.ListTasksResponse; import io.a2a.jsonrpc.common.wrappers.SubscribeToTaskRequest; import io.a2a.server.ServerCallContext; import io.a2a.server.auth.UnauthenticatedUser; import io.a2a.server.events.EventConsumer; +import io.a2a.server.events.MainEventBusProcessorCallback; import io.a2a.server.requesthandlers.AbstractA2ARequestHandlerTest; import io.a2a.server.requesthandlers.DefaultRequestHandler; import io.a2a.server.tasks.ResultAggregator; @@ -52,16 +54,15 @@ import io.a2a.spec.AgentExtension; import io.a2a.spec.AgentInterface; import io.a2a.spec.Artifact; -import io.a2a.spec.ExtendedAgentCardNotConfiguredError; -import 
io.a2a.spec.ExtensionSupportRequiredError; -import io.a2a.spec.VersionNotSupportedError; import io.a2a.spec.DeleteTaskPushNotificationConfigParams; import io.a2a.spec.Event; +import io.a2a.spec.ExtendedAgentCardNotConfiguredError; +import io.a2a.spec.ExtensionSupportRequiredError; import io.a2a.spec.GetTaskPushNotificationConfigParams; import io.a2a.spec.InternalError; import io.a2a.spec.InvalidRequestError; -import io.a2a.spec.ListTasksParams; import io.a2a.spec.ListTaskPushNotificationConfigParams; +import io.a2a.spec.ListTasksParams; import io.a2a.spec.Message; import io.a2a.spec.MessageSendParams; import io.a2a.spec.PushNotificationConfig; @@ -78,6 +79,7 @@ import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; import io.a2a.spec.UnsupportedOperationError; +import io.a2a.spec.VersionNotSupportedError; import mutiny.zero.ZeroPublisher; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; @@ -174,38 +176,9 @@ public void testOnMessageNewMessageSuccess() { SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); SendMessageResponse response = handler.onMessageSend(request, callContext); assertNull(response.getError()); - // The Python implementation returns a Task here, but then again they are using hardcoded mocks and - // bypassing the whole EventQueue. 
- // If we were to send a Task in agentExecutorExecute EventConsumer.consumeAll() would not exit due to - // the Task not having a 'final' state - // - // See testOnMessageNewMessageSuccessMocks() for a test more similar to the Python implementation Assertions.assertSame(message, response.getResult()); } - @Test - public void testOnMessageNewMessageSuccessMocks() { - JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - - Message message = Message.builder(MESSAGE) - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .build(); - - SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response; - try (MockedConstruction mocked = Mockito.mockConstruction( - EventConsumer.class, - (mock, context) -> { - Mockito.doReturn(ZeroPublisher.fromItems(wrapEvent(MINIMAL_TASK))).when(mock).consumeAll(); - Mockito.doCallRealMethod().when(mock).createAgentRunnableDoneCallback(); - })) { - response = handler.onMessageSend(request, callContext); - } - assertNull(response.getError()); - Assertions.assertSame(MINIMAL_TASK, response.getResult()); - } - @Test public void testOnMessageNewMessageWithExistingTaskSuccess() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); @@ -220,38 +193,9 @@ public void testOnMessageNewMessageWithExistingTaskSuccess() { SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); SendMessageResponse response = handler.onMessageSend(request, callContext); assertNull(response.getError()); - // The Python implementation returns a Task here, but then again they are using hardcoded mocks and - // bypassing the whole EventQueue. 
- // If we were to send a Task in agentExecutorExecute EventConsumer.consumeAll() would not exit due to - // the Task not having a 'final' state - // - // See testOnMessageNewMessageWithExistingTaskSuccessMocks() for a test more similar to the Python implementation Assertions.assertSame(message, response.getResult()); } - @Test - public void testOnMessageNewMessageWithExistingTaskSuccessMocks() { - JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); - - Message message = Message.builder(MESSAGE) - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .build(); - SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response; - try (MockedConstruction mocked = Mockito.mockConstruction( - EventConsumer.class, - (mock, context) -> { - Mockito.doReturn(ZeroPublisher.fromItems(wrapEvent(MINIMAL_TASK))).when(mock).consumeAll(); - })) { - response = handler.onMessageSend(request, callContext); - } - assertNull(response.getError()); - Assertions.assertSame(MINIMAL_TASK, response.getResult()); - - } - @Test public void testOnMessageError() { // See testMessageOnErrorMocks() for a test more similar to the Python implementation, using mocks for @@ -352,9 +296,11 @@ public void onComplete() { @Test public void testOnMessageStreamNewMessageMultipleEventsSuccess() throws InterruptedException { + // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback + // We'll verify persistence by checking TaskStore after streaming completes JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - // Create multiple events to be sent during streaming + // Create multiple events to be sent during streaming Task taskEvent = Task.builder(MINIMAL_TASK) .status(new TaskStatus(TaskState.WORKING)) .build(); @@ -429,8 +375,8 @@ public void onComplete() { } }); - // Wait for all events to be received - 
Assertions.assertTrue(latch.await(2, TimeUnit.SECONDS), + // Wait for all events to be received (increased timeout for async processing) + Assertions.assertTrue(latch.await(10, TimeUnit.SECONDS), "Expected to receive 3 events within timeout"); // Assert no error occurred during streaming @@ -456,6 +402,17 @@ public void onComplete() { "Third event should be a TaskStatusUpdateEvent"); assertEquals(MINIMAL_TASK.id(), receivedStatus.taskId()); assertEquals(TaskState.COMPLETED, receivedStatus.status().state()); + + // Verify events were persisted to TaskStore (poll for final state) + for (int i = 0; i < 50; i++) { + Task storedTask = taskStore.get(MINIMAL_TASK.id()); + if (storedTask != null && storedTask.status() != null + && TaskState.COMPLETED.equals(storedTask.status().state())) { + return; // Success - task finalized in TaskStore + } + Thread.sleep(100); + } + fail("Task should have been finalized in TaskStore within timeout"); } @Test @@ -729,106 +686,118 @@ public void testGetPushNotificationConfigSuccess() { @Test public void testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Exception { - JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); - - List events = List.of( - MINIMAL_TASK, - TaskArtifactUpdateEvent.builder() - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .artifact(Artifact.builder() - .artifactId("11") - .parts(new TextPart("text")) - .build()) - .build(), - TaskStatusUpdateEvent.builder() - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .status(new TaskStatus(TaskState.COMPLETED)) - .build()); - - agentExecutorExecute = (context, eventQueue) -> { - // Hardcode the events to send here - for (Event event : events) { - eventQueue.enqueueEvent(event); - } - }; - - TaskPushNotificationConfig config = new TaskPushNotificationConfig( - MINIMAL_TASK.id(), - 
PushNotificationConfig.builder().id("c295ea44-7543-4f78-b524-7a38915ad6e4").url("http://example.com").build(), "tenant"); - - SetTaskPushNotificationConfigRequest stpnRequest = new SetTaskPushNotificationConfigRequest("1", config); - SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotificationConfig(stpnRequest, callContext); - assertNull(stpnResponse.getError()); - - Message msg = Message.builder(MESSAGE) - .taskId(MINIMAL_TASK.id()) - .build(); - SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(msg, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request, callContext); - - final List results = Collections.synchronizedList(new ArrayList<>()); - final AtomicReference subscriptionRef = new AtomicReference<>(); - final CountDownLatch latch = new CountDownLatch(6); - httpClient.latch = latch; - - Executors.newSingleThreadExecutor().execute(() -> { - response.subscribe(new Flow.Subscriber<>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscriptionRef.set(subscription); - subscription.request(1); - } - - @Override - public void onNext(SendStreamingMessageResponse item) { - System.out.println("-> " + item.getResult()); - results.add(item.getResult()); - System.out.println(results); - subscriptionRef.get().request(1); - latch.countDown(); - } - - @Override - public void onError(Throwable throwable) { - subscriptionRef.get().cancel(); - } - - @Override - public void onComplete() { - subscriptionRef.get().cancel(); + // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback + + // Use synchronous executor for push notifications to ensure deterministic ordering + // Without this, async push notifications can execute out of order, causing test flakiness + mainEventBusProcessor.setPushNotificationExecutor(Runnable::run); + + try { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); + 
taskStore.save(MINIMAL_TASK); + + List events = List.of( + MINIMAL_TASK, + TaskArtifactUpdateEvent.builder() + .taskId(MINIMAL_TASK.id()) + .contextId(MINIMAL_TASK.contextId()) + .artifact(Artifact.builder() + .artifactId("11") + .parts(new TextPart("text")) + .build()) + .build(), + TaskStatusUpdateEvent.builder() + .taskId(MINIMAL_TASK.id()) + .contextId(MINIMAL_TASK.contextId()) + .status(new TaskStatus(TaskState.COMPLETED)) + .build()); + + + agentExecutorExecute = (context, eventQueue) -> { + // Hardcode the events to send here + for (Event event : events) { + eventQueue.enqueueEvent(event); } + }; + + TaskPushNotificationConfig config = new TaskPushNotificationConfig( + MINIMAL_TASK.id(), + PushNotificationConfig.builder().id("c295ea44-7543-4f78-b524-7a38915ad6e4").url("http://example.com").build(), "tenant"); + + SetTaskPushNotificationConfigRequest stpnRequest = new SetTaskPushNotificationConfigRequest("1", config); + SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotificationConfig(stpnRequest, callContext); + assertNull(stpnResponse.getError()); + + Message msg = Message.builder(MESSAGE) + .taskId(MINIMAL_TASK.id()) + .build(); + SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(msg, null, null)); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); + + final List results = Collections.synchronizedList(new ArrayList<>()); + final AtomicReference subscriptionRef = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(6); + httpClient.latch = latch; + + Executors.newSingleThreadExecutor().execute(() -> { + response.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscriptionRef.set(subscription); + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + System.out.println("-> " + item.getResult()); + results.add(item.getResult()); + 
System.out.println(results); + subscriptionRef.get().request(1); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscriptionRef.get().cancel(); + } + + @Override + public void onComplete() { + subscriptionRef.get().cancel(); + } + }); }); - }); - Assertions.assertTrue(latch.await(5, TimeUnit.SECONDS)); - subscriptionRef.get().cancel(); - assertEquals(3, results.size()); - assertEquals(3, httpClient.tasks.size()); - - Task curr = httpClient.tasks.get(0); - assertEquals(MINIMAL_TASK.id(), curr.id()); - assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); - assertEquals(MINIMAL_TASK.status().state(), curr.status().state()); - assertEquals(0, curr.artifacts() == null ? 0 : curr.artifacts().size()); - - curr = httpClient.tasks.get(1); - assertEquals(MINIMAL_TASK.id(), curr.id()); - assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); - assertEquals(MINIMAL_TASK.status().state(), curr.status().state()); - assertEquals(1, curr.artifacts().size()); - assertEquals(1, curr.artifacts().get(0).parts().size()); - assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text()); - - curr = httpClient.tasks.get(2); - assertEquals(MINIMAL_TASK.id(), curr.id()); - assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); - assertEquals(TaskState.COMPLETED, curr.status().state()); - assertEquals(1, curr.artifacts().size()); - assertEquals(1, curr.artifacts().get(0).parts().size()); - assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text()); + Assertions.assertTrue(latch.await(5, TimeUnit.SECONDS)); + + subscriptionRef.get().cancel(); + assertEquals(3, results.size()); + assertEquals(3, httpClient.tasks.size()); + + Task curr = httpClient.tasks.get(0); + assertEquals(MINIMAL_TASK.id(), curr.id()); + assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); + assertEquals(MINIMAL_TASK.status().state(), curr.status().state()); + assertEquals(0, curr.artifacts() == null ? 
0 : curr.artifacts().size()); + + curr = httpClient.tasks.get(1); + assertEquals(MINIMAL_TASK.id(), curr.id()); + assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); + assertEquals(MINIMAL_TASK.status().state(), curr.status().state()); + assertEquals(1, curr.artifacts().size()); + assertEquals(1, curr.artifacts().get(0).parts().size()); + assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text()); + + curr = httpClient.tasks.get(2); + assertEquals(MINIMAL_TASK.id(), curr.id()); + assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); + assertEquals(TaskState.COMPLETED, curr.status().state()); + assertEquals(1, curr.artifacts().size()); + assertEquals(1, curr.artifacts().get(0).parts().size()); + assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text()); + } finally { + mainEventBusProcessor.setPushNotificationExecutor(null); + } } @Test @@ -1060,7 +1029,7 @@ public void onComplete() { if (results.get(0).getError() != null && results.get(0).getError() instanceof InvalidRequestError ire) { assertEquals("Streaming is not supported by the agent", ire.getMessage()); } else { - Assertions.fail("Expected a response containing an error"); + fail("Expected a response containing an error"); } } @@ -1107,7 +1076,7 @@ public void onComplete() { if (results.get(0).getError() != null && results.get(0).getError() instanceof InvalidRequestError ire) { assertEquals("Streaming is not supported by the agent", ire.getMessage()); } else { - Assertions.fail("Expected a response containing an error"); + fail("Expected a response containing an error"); } } @@ -1135,8 +1104,7 @@ public void testPushNotificationsNotSupportedError() { @Test public void testOnGetPushNotificationNoPushNotifierConfig() { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = 
DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = createAgentCard(false, true, false); JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor); @@ -1154,8 +1122,7 @@ public void testOnGetPushNotificationNoPushNotifierConfig() { @Test public void testOnSetPushNotificationNoPushNotifierConfig() { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = createAgentCard(false, true, false); JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor); @@ -1246,8 +1213,7 @@ public void testDefaultRequestHandlerWithCustomComponents() { @Test public void testOnMessageSendErrorHandling() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = createAgentCard(false, true, false); JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor); @@ -1293,16 +1259,17 @@ public void testOnMessageSendTaskIdMismatch() { } @Test - public void testOnMessageStreamTaskIdMismatch() { + public void testOnMessageStreamTaskIdMismatch() throws InterruptedException { + // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK); - agentExecutorExecute = ((context, eventQueue) -> { 
- eventQueue.enqueueEvent(MINIMAL_TASK); - }); + agentExecutorExecute = ((context, eventQueue) -> { + eventQueue.enqueueEvent(MINIMAL_TASK); + }); - SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request, callContext); + SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); CompletableFuture future = new CompletableFuture<>(); List results = new ArrayList<>(); @@ -1404,8 +1371,7 @@ public void testListPushNotificationConfigNotSupported() { @Test public void testListPushNotificationConfigNoPushConfigStore() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); taskStore.save(MINIMAL_TASK); agentExecutorExecute = (context, eventQueue) -> { @@ -1496,8 +1462,8 @@ public void testDeletePushNotificationConfigNotSupported() { @Test public void testDeletePushNotificationConfigNoPushConfigStore() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = + DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); taskStore.save(MINIMAL_TASK); agentExecutorExecute = (context, eventQueue) -> { diff --git a/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java 
b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java index 3ffb56c5f..3273d6119 100644 --- a/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java +++ b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java @@ -399,32 +399,46 @@ private Flow.Publisher convertToSendStreamingMessageResponse( Flow.Publisher publisher) { // We can't use the normal convertingProcessor since that propagates any errors as an error handled // via Subscriber.onError() rather than as part of the SendStreamingResponse payload + log.log(Level.FINE, "REST: convertToSendStreamingMessageResponse called, creating ZeroPublisher"); return ZeroPublisher.create(createTubeConfig(), tube -> { + log.log(Level.FINE, "REST: ZeroPublisher tube created, starting CompletableFuture.runAsync"); CompletableFuture.runAsync(() -> { + log.log(Level.FINE, "REST: Inside CompletableFuture, subscribing to EventKind publisher"); publisher.subscribe(new Flow.Subscriber() { Flow.@Nullable Subscription subscription; @Override public void onSubscribe(Flow.Subscription subscription) { + log.log(Level.FINE, "REST: onSubscribe called, storing subscription and requesting first event"); this.subscription = subscription; subscription.request(1); } @Override public void onNext(StreamingEventKind item) { + log.log(Level.FINE, "REST: onNext called with event: {0}", item.getClass().getSimpleName()); try { String payload = JsonFormat.printer().omittingInsignificantWhitespace().print(ProtoUtils.ToProto.taskOrMessageStream(item)); + log.log(Level.FINE, "REST: Converted to JSON, sending via tube: {0}", payload.substring(0, Math.min(100, payload.length()))); tube.send(payload); + log.log(Level.FINE, "REST: tube.send() completed, requesting next event from EventConsumer"); + // Request next event from EventConsumer (Chain 1: EventConsumer → RestHandler) + // This is safe because ZeroPublisher buffers items + // Chain 2 (ZeroPublisher → MultiSseSupport) controls actual 
delivery via request(1) in onWriteDone() if (subscription != null) { subscription.request(1); + } else { + log.log(Level.WARNING, "REST: subscription is null in onNext!"); } } catch (InvalidProtocolBufferException ex) { + log.log(Level.SEVERE, "REST: JSON conversion failed", ex); onError(ex); } } @Override public void onError(Throwable throwable) { + log.log(Level.SEVERE, "REST: onError called", throwable); if (throwable instanceof A2AError jsonrpcError) { tube.send(new HTTPRestErrorResponse(jsonrpcError).toJson()); } else { @@ -435,6 +449,7 @@ public void onError(Throwable throwable) { @Override public void onComplete() { + log.log(Level.FINE, "REST: onComplete called, calling tube.complete()"); tube.complete(); } }); From 0b728aa8c257b7ba29173fa6b760c06662245ad3 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Wed, 28 Jan 2026 16:22:03 +0000 Subject: [PATCH 2/2] fix: Allow Message with null taskId --- ...paDatabasePushNotificationConfigStore.java | 41 ++++ .../core/ReplicatedQueueManager.java | 12 +- .../core/ReplicatedQueueManagerTest.java | 60 ++++-- .../java/io/a2a/server/events/EventQueue.java | 102 ++++++++-- .../server/events/InMemoryQueueManager.java | 33 ++- .../io/a2a/server/events/MainEventBus.java | 6 +- .../server/events/MainEventBusContext.java | 2 +- .../server/events/MainEventBusProcessor.java | 106 +++++++--- .../io/a2a/server/events/QueueManager.java | 19 +- .../DefaultRequestHandler.java | 122 ++++++++---- .../InMemoryPushNotificationConfigStore.java | 8 + .../tasks/PushNotificationConfigStore.java | 16 ++ .../java/io/a2a/server/tasks/TaskManager.java | 42 +++- .../a2a/server/events/EventConsumerTest.java | 3 +- .../io/a2a/server/events/EventQueueTest.java | 3 +- .../events/InMemoryQueueManagerTest.java | 5 +- .../AbstractA2ARequestHandlerTest.java | 5 +- .../DefaultRequestHandlerTest.java | 188 +++++++++++++++++- .../server/tasks/ResultAggregatorTest.java | 5 +- .../io/a2a/server/tasks/TaskManagerTest.java | 30 +-- 
.../io/a2a/server/tasks/TaskUpdaterTest.java | 4 +- 21 files changed, 676 insertions(+), 136 deletions(-) diff --git a/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java b/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java index 36245e277..799d71733 100644 --- a/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java +++ b/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java @@ -164,4 +164,45 @@ public void deleteInfo(String taskId, String configId) { taskId, configId); } } + + @Transactional + @Override + public void switchKey(String oldTaskId, String newTaskId) { + LOGGER.debug("Switching PushNotificationConfigs from Task '{}' to Task '{}'", oldTaskId, newTaskId); + + // Find all configs for the old task ID + TypedQuery query = em.createQuery( + "SELECT c FROM JpaPushNotificationConfig c WHERE c.id.taskId = :taskId", + JpaPushNotificationConfig.class); + query.setParameter("taskId", oldTaskId); + List configs = query.getResultList(); + + if (configs.isEmpty()) { + LOGGER.debug("No PushNotificationConfigs found for Task '{}', nothing to switch", oldTaskId); + return; + } + + // For each config, create a new entity with the new task ID and remove the old one + for (JpaPushNotificationConfig oldConfig : configs) { + try { + // Create new config with new task ID + JpaPushNotificationConfig newConfig = JpaPushNotificationConfig.createFromConfig( + newTaskId, oldConfig.getConfig()); + + // Remove old config and persist new one + em.remove(oldConfig); + em.persist(newConfig); + + LOGGER.debug("Switched PushNotificationConfig ID '{}' 
from Task '{}' to Task '{}'", + oldConfig.getId().getConfigId(), oldTaskId, newTaskId); + } catch (JsonProcessingException e) { + LOGGER.error("Failed to switch PushNotificationConfig ID '{}' from Task '{}' to Task '{}'", + oldConfig.getId().getConfigId(), oldTaskId, newTaskId, e); + throw new RuntimeException("Failed to switch PushNotificationConfig", e); + } + } + + LOGGER.debug("Successfully switched {} PushNotificationConfigs from Task '{}' to Task '{}'", + configs.size(), oldTaskId, newTaskId); + } } diff --git a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java index f320362eb..0ec264c5a 100644 --- a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java +++ b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java @@ -16,6 +16,7 @@ import io.a2a.server.events.MainEventBus; import io.a2a.server.events.QueueManager; import io.a2a.server.tasks.TaskStateProvider; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -85,7 +86,12 @@ public void close(String taskId) { @Override public EventQueue createOrTap(String taskId) { - EventQueue queue = delegate.createOrTap(taskId); + return createOrTap(taskId, null); + } + + @Override + public EventQueue createOrTap(String taskId, @Nullable String tempId) { + EventQueue queue = delegate.createOrTap(taskId, tempId); return queue; } @@ -106,7 +112,9 @@ public void onReplicatedEvent(@Observes ReplicatedEventQueueItem replicatedEvent } // Get or create a ChildQueue for this task (creates MainQueue if it doesn't exist) - EventQueue childQueue = delegate.createOrTap(replicatedEvent.getTaskId()); + // Replicated events should always have real task IDs (not 
temp IDs) because + // replication now happens AFTER TaskStore persistence in MainEventBusProcessor + EventQueue childQueue = delegate.createOrTap(replicatedEvent.getTaskId(), null); try { // Get the MainQueue to enqueue the replicated event item diff --git a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java index a339be543..5fcc4126b 100644 --- a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java +++ b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java @@ -47,18 +47,21 @@ class ReplicatedQueueManagerTest { @BeforeEach void setUp() { - // Create MainEventBus and MainEventBusProcessor for tests + // Create MainEventBus first InMemoryTaskStore taskStore = new InMemoryTaskStore(); mainEventBus = new MainEventBus(); - mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER); - EventQueueUtil.start(mainEventBusProcessor); + // Create QueueManager before MainEventBusProcessor (processor needs it as parameter) queueManager = new ReplicatedQueueManager( new NoOpReplicationStrategy(), new MockTaskStateProvider(true), mainEventBus ); + // Create MainEventBusProcessor with QueueManager + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + testEvent = TaskStatusUpdateEvent.builder() .taskId("test-task") .contextId("test-context") @@ -129,12 +132,14 @@ void testReplicationStrategyTriggeredOnNormalEnqueue() throws InterruptedExcepti String taskId = "test-task-1"; EventQueue queue = queueManager.createOrTap(taskId); + TaskStatusUpdateEvent event = 
createEventForTask(taskId); - queue.enqueueEvent(testEvent); + // Wait for MainEventBusProcessor to process the event and trigger replication + waitForEventProcessing(() -> queue.enqueueEvent(event)); assertEquals(1, strategy.getCallCount()); assertEquals(taskId, strategy.getLastTaskId()); - assertEquals(testEvent, strategy.getLastEvent()); + assertEquals(event, strategy.getLastEvent()); } @Test @@ -158,13 +163,15 @@ void testReplicationStrategyWithCountingImplementation() throws InterruptedExcep String taskId = "test-task-3"; EventQueue queue = queueManager.createOrTap(taskId); + TaskStatusUpdateEvent event = createEventForTask(taskId); - queue.enqueueEvent(testEvent); - queue.enqueueEvent(testEvent); + // Wait for MainEventBusProcessor to process each event + waitForEventProcessing(() -> queue.enqueueEvent(event)); + waitForEventProcessing(() -> queue.enqueueEvent(event)); assertEquals(2, countingStrategy.getCallCount()); assertEquals(taskId, countingStrategy.getLastTaskId()); - assertEquals(testEvent, countingStrategy.getLastEvent()); + assertEquals(event, countingStrategy.getLastEvent()); ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent); queueManager.onReplicatedEvent(replicatedEvent); @@ -245,16 +252,17 @@ void testQueueToTaskIdMappingMaintained() throws InterruptedException { String taskId = "test-task-6"; CountingReplicationStrategy countingStrategy = new CountingReplicationStrategy(); queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true), mainEventBus); + TaskStatusUpdateEvent event = createEventForTask(taskId); EventQueue queue = queueManager.createOrTap(taskId); - queue.enqueueEvent(testEvent); + waitForEventProcessing(() -> queue.enqueueEvent(event)); assertEquals(taskId, countingStrategy.getLastTaskId()); queueManager.close(taskId); // Task is active, so NO poison pill is sent EventQueue newQueue = queueManager.createOrTap(taskId); - newQueue.enqueueEvent(testEvent); + 
waitForEventProcessing(() -> newQueue.enqueueEvent(event)); assertEquals(taskId, countingStrategy.getLastTaskId()); // 2 replication calls: 1 testEvent, 1 testEvent (no QueueClosedEvent because task is active) @@ -298,10 +306,25 @@ void testParallelReplicationBehavior() throws InterruptedException { int numThreads = 10; int eventsPerThread = 5; + int expectedEventCount = (numThreads / 2) * eventsPerThread; // Only normal enqueues ExecutorService executor = Executors.newFixedThreadPool(numThreads); CountDownLatch startLatch = new CountDownLatch(1); CountDownLatch doneLatch = new CountDownLatch(numThreads); + // Set up callback to wait for all events to be processed by MainEventBusProcessor + CountDownLatch processingLatch = new CountDownLatch(expectedEventCount); + mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String tid, io.a2a.spec.Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String tid) { + // Not needed for this test + } + }); + // Launch threads that will enqueue events normally (should trigger replication) for (int i = 0; i < numThreads / 2; i++) { final int threadId = i; @@ -310,7 +333,7 @@ void testParallelReplicationBehavior() throws InterruptedException { startLatch.await(); for (int j = 0; j < eventsPerThread; j++) { TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() - .taskId("normal-" + threadId + "-" + j) + .taskId(taskId) // Use same taskId as queue .contextId("test-context") .status(new TaskStatus(TaskState.WORKING)) .isFinal(false) @@ -334,7 +357,7 @@ void testParallelReplicationBehavior() throws InterruptedException { startLatch.await(); for (int j = 0; j < eventsPerThread; j++) { TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() - .taskId("replicated-" + threadId + "-" + j) + .taskId(taskId) // Use same taskId as queue .contextId("test-context") .status(new 
TaskStatus(TaskState.COMPLETED)) .isFinal(true) @@ -360,6 +383,14 @@ void testParallelReplicationBehavior() throws InterruptedException { executor.shutdown(); assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS), "Executor should shutdown within 5 seconds"); + // Wait for MainEventBusProcessor to process all events + try { + assertTrue(processingLatch.await(10, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed all events within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } + // Only the normal enqueue operations should have triggered replication // numThreads/2 threads * eventsPerThread events each = total expected replication calls int expectedReplicationCalls = (numThreads / 2) * eventsPerThread; @@ -467,9 +498,10 @@ void testPoisonPillSentViaTransactionAwareEvent() throws InterruptedException { String taskId = "poison-pill-test"; EventQueue queue = queueManager.createOrTap(taskId); + TaskStatusUpdateEvent event = createEventForTask(taskId); - // Enqueue a normal event first - queue.enqueueEvent(testEvent); + // Enqueue a normal event first and wait for processing + waitForEventProcessing(() -> queue.enqueueEvent(event)); // In the new architecture, QueueClosedEvent (poison pill) is sent via CDI events // when JpaDatabaseTaskStore.save() persists a final task and the transaction commits diff --git a/server-common/src/main/java/io/a2a/server/events/EventQueue.java b/server-common/src/main/java/io/a2a/server/events/EventQueue.java index 0dff01a31..3f1e21816 100644 --- a/server-common/src/main/java/io/a2a/server/events/EventQueue.java +++ b/server-common/src/main/java/io/a2a/server/events/EventQueue.java @@ -85,6 +85,7 @@ public static class EventQueueBuilder { private int queueSize = DEFAULT_QUEUE_SIZE; private @Nullable EventEnqueueHook hook; private @Nullable String taskId; + private @Nullable String tempId = null; private List onCloseCallbacks = new java.util.ArrayList<>(); private @Nullable TaskStateProvider 
taskStateProvider; private @Nullable MainEventBus mainEventBus; @@ -122,6 +123,17 @@ public EventQueueBuilder taskId(String taskId) { return this; } + /** + * Sets the temporary task ID if this queue is for a temporary task. + * + * @param tempId the temporary task ID, or null if not temporary + * @return this builder + */ + public EventQueueBuilder tempId(@Nullable String tempId) { + this.tempId = tempId; + return this; + } + /** * Adds a callback to be executed when the queue is closed. * @@ -170,7 +182,7 @@ public EventQueue build() { if (taskId == null) { throw new IllegalStateException("taskId is required for EventQueue creation"); } - return new MainQueue(queueSize, hook, taskId, onCloseCallbacks, taskStateProvider, mainEventBus); + return new MainQueue(queueSize, hook, taskId, tempId, onCloseCallbacks, taskStateProvider, mainEventBus); } } @@ -330,14 +342,16 @@ static class MainQueue extends EventQueue { private final CountDownLatch pollingStartedLatch = new CountDownLatch(1); private final AtomicBoolean pollingStarted = new AtomicBoolean(false); private final @Nullable EventEnqueueHook enqueueHook; - private final String taskId; + private volatile String taskId; // Volatile to allow switching from temp to real ID across threads private final List onCloseCallbacks; private final @Nullable TaskStateProvider taskStateProvider; private final MainEventBus mainEventBus; + private final @Nullable String tempId; MainQueue(int queueSize, @Nullable EventEnqueueHook hook, String taskId, + @Nullable String tempId, List onCloseCallbacks, @Nullable TaskStateProvider taskStateProvider, @Nullable MainEventBus mainEventBus) { @@ -345,11 +359,12 @@ static class MainQueue extends EventQueue { this.semaphore = new Semaphore(queueSize, true); this.enqueueHook = hook; this.taskId = taskId; + this.tempId = tempId; this.onCloseCallbacks = List.copyOf(onCloseCallbacks); // Defensive copy this.taskStateProvider = taskStateProvider; this.mainEventBus = 
Objects.requireNonNull(mainEventBus, "MainEventBus is required"); - LOGGER.debug("Created MainQueue for task {} with {} onClose callbacks, TaskStateProvider: {}, MainEventBus configured", - taskId, onCloseCallbacks.size(), taskStateProvider != null); + LOGGER.debug("Created MainQueue for task {} (tempId={}) with {} onClose callbacks, TaskStateProvider: {}, MainEventBus configured", + taskId, tempId, onCloseCallbacks.size(), taskStateProvider != null); } @@ -367,6 +382,30 @@ public int getChildCount() { return children.size(); } + /** + * Returns the enqueue hook for replication (package-protected for MainEventBusProcessor). + */ + @Nullable EventEnqueueHook getEnqueueHook() { + return enqueueHook; + } + + /** + * Returns the temporary task ID if this queue was created with one, null otherwise. + * Package-protected for MainEventBusProcessor access. + */ + @Nullable String getTempId() { + return tempId; + } + + /** + * Updates the task ID when switching from temporary to real ID. + * Package-protected for MainEventBusProcessor access. + * @param newTaskId the real task ID to use + */ + void setTaskId(String newTaskId) { + this.taskId = newTaskId; + } + @Override public EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { throw new UnsupportedOperationException("MainQueue cannot be consumed directly - use tap() to create a ChildQueue for consumption"); @@ -401,12 +440,44 @@ public void enqueueItem(EventQueueItem item) { // Submit to MainEventBus for centralized persistence + distribution // MainEventBus is guaranteed non-null by constructor requirement - mainEventBus.submit(taskId, this, item); + // Note: Replication now happens in MainEventBusProcessor AFTER persistence + + // For temp ID scenarios: if event contains a real task ID different from our temp ID, + // use the real ID when submitting to MainEventBus. 
This ensures subsequent events + // are submitted with the correct ID even before MainEventBusProcessor updates our taskId field. + // IMPORTANT: We don't update our taskId here - that happens in MainEventBusProcessor + // AFTER TaskManager validates the ID (e.g., checking for duplicate task IDs). + String submissionId = taskId; + if (tempId != null && taskId.equals(tempId)) { + // Not yet switched - try to extract real ID from first event + String eventTaskId = extractTaskId(event); + if (eventTaskId != null && !eventTaskId.equals(taskId)) { + // Event has a different (real) ID - use it for submission + // but keep our taskId as temp-UUID for now (until MainEventBusProcessor switches it) + submissionId = eventTaskId; + LOGGER.debug("MainQueue submitting event with real ID {} (current taskId: {})", eventTaskId, taskId); + } + } + // If already switched (tempId != null but taskId != tempId), submissionId stays as taskId (the real ID) + // This ensures subsequent events with stale temp IDs still get submitted with the correct real ID + mainEventBus.submit(submissionId, this, item); + } - // Trigger replication hook if configured (for inter-process replication) - if (enqueueHook != null) { - enqueueHook.onEnqueue(item); + /** + * Extracts taskId from an event (package-private helper). 
+ */ + @Nullable + private String extractTaskId(Event event) { + if (event instanceof io.a2a.spec.Task task) { + return task.id(); + } else if (event instanceof io.a2a.spec.TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.taskId(); + } else if (event instanceof io.a2a.spec.TaskArtifactUpdateEvent artifactUpdate) { + return artifactUpdate.taskId(); + } else if (event instanceof io.a2a.spec.Message message) { + return message.taskId(); } + return null; } @Override @@ -429,20 +500,15 @@ public void signalQueuePollerStarted() { void childClosing(ChildQueue child, boolean immediate) { children.remove(child); // Remove the closing child - // Close immediately if requested - if (immediate) { - LOGGER.debug("MainQueue closing immediately (immediate=true)"); - this.doClose(immediate); - return; - } - // If there are still children, keep queue open if (!children.isEmpty()) { LOGGER.debug("MainQueue staying open: {} children remaining", children.size()); return; } - // No children left - check if task is finalized before auto-closing + // No children left - check if task is finalized before closing + // IMPORTANT: This check must happen BEFORE the immediate flag check + // to prevent closing queues for non-final tasks (fire-and-forget, resubscription support) if (taskStateProvider != null && taskId != null) { boolean isFinalized = taskStateProvider.isTaskFinalized(taskId); if (!isFinalized) { @@ -533,6 +599,10 @@ public void close(boolean immediate) { public void close(boolean immediate, boolean notifyParent) { throw new UnsupportedOperationException("MainQueue does not support notifyParent parameter - use close(boolean) instead"); } + + String getTaskId() { + return taskId; + } } static class ChildQueue extends EventQueue { diff --git a/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java b/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java index abd043614..edcb55307 100644 --- 
a/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java +++ b/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java @@ -59,7 +59,23 @@ public void add(String taskId, EventQueue queue) { } @Override - public void switchKey(String oldId, String newId) { + public synchronized void switchKey(String oldId, String newId) { + // Check if already switched (idempotent operation) + // This check is now safe because the method is synchronized + EventQueue existingNew = queues.get(newId); + if (existingNew != null) { + EventQueue oldQueue = queues.get(oldId); + if (oldQueue == null) { + // Already switched - idempotent success + LOGGER.debug("Queue already switched from {} to {}, skipping", oldId, newId); + return; + } else { + // Different queue already at newId - error + throw new TaskQueueExistsException(); + } + } + + // Normal path: move queue from oldId to newId EventQueue queue = queues.remove(oldId); if (queue == null) { throw new IllegalStateException("No queue found for old ID: " + oldId); @@ -100,7 +116,12 @@ public void close(String taskId) { @Override public EventQueue createOrTap(String taskId) { - LOGGER.debug("createOrTap called for task {}, current map size: {}", taskId, queues.size()); + return createOrTap(taskId, null); + } + + @Override + public EventQueue createOrTap(String taskId, @Nullable String tempId) { + LOGGER.debug("createOrTap called for task {} (tempId={}), current map size: {}", taskId, tempId, queues.size()); EventQueue existing = queues.get(taskId); // Lazy cleanup: only remove closed queues if task is finalized @@ -123,8 +144,8 @@ public EventQueue createOrTap(String taskId) { EventQueue newQueue = null; if (existing == null) { // Use builder pattern for cleaner queue creation - // Use the new taskId-aware builder method if available - newQueue = factory.builder(taskId).build(); + // Pass tempId to the builder + newQueue = factory.builder(taskId).tempId(tempId).build(); // Make sure an existing queue 
has not been added in the meantime existing = queues.putIfAbsent(taskId, newQueue); } @@ -136,8 +157,8 @@ public EventQueue createOrTap(String taskId) { EventQueue result = main.tap(); // Always return ChildQueue if (existing == null) { - LOGGER.debug("Created new MainQueue {} for task {}, returning ChildQueue {} (map size: {})", - System.identityHashCode(main), taskId, System.identityHashCode(result), queues.size()); + LOGGER.debug("Created new MainQueue {} for task {} (tempId={}), returning ChildQueue {} (map size: {})", + System.identityHashCode(main), taskId, tempId, System.identityHashCode(result), queues.size()); } else { LOGGER.debug("Tapped existing MainQueue {} -> ChildQueue {} for task {}", System.identityHashCode(main), System.identityHashCode(result), taskId); diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBus.java b/server-common/src/main/java/io/a2a/server/events/MainEventBus.java index 73500254e..90080b1e2 100644 --- a/server-common/src/main/java/io/a2a/server/events/MainEventBus.java +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBus.java @@ -17,9 +17,9 @@ public MainEventBus() { this.queue = new LinkedBlockingDeque<>(); } - public void submit(String taskId, EventQueue eventQueue, EventQueueItem item) { + void submit(String taskId, EventQueue.MainQueue mainQueue, EventQueueItem item) { try { - queue.put(new MainEventBusContext(taskId, eventQueue, item)); + queue.put(new MainEventBusContext(taskId, mainQueue, item)); LOGGER.debug("Submitted event for task {} to MainEventBus (queue size: {})", taskId, queue.size()); } catch (InterruptedException e) { @@ -28,7 +28,7 @@ public void submit(String taskId, EventQueue eventQueue, EventQueueItem item) { } } - public MainEventBusContext take() throws InterruptedException { + MainEventBusContext take() throws InterruptedException { LOGGER.debug("MainEventBus: Waiting to take event (current queue size: {})...", queue.size()); MainEventBusContext context = queue.take(); 
LOGGER.debug("MainEventBus: Took event for task {} (remaining queue size: {})", diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java index f8e5e03ec..292a60f21 100644 --- a/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java @@ -2,7 +2,7 @@ import java.util.Objects; -record MainEventBusContext(String taskId, EventQueue eventQueue, EventQueueItem eventQueueItem) { +record MainEventBusContext(String taskId, EventQueue.MainQueue eventQueue, EventQueueItem eventQueueItem) { MainEventBusContext { Objects.requireNonNull(taskId, "taskId cannot be null"); Objects.requireNonNull(eventQueue, "eventQueue cannot be null"); diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java index 91aaac3ef..a511fedf1 100644 --- a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java @@ -14,6 +14,7 @@ import io.a2a.spec.A2AServerException; import io.a2a.spec.Event; import io.a2a.spec.InternalError; +import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskArtifactUpdateEvent; import io.a2a.spec.TaskStatusUpdateEvent; @@ -63,14 +64,17 @@ public class MainEventBusProcessor implements Runnable { private final PushNotificationSender pushSender; + private final QueueManager queueManager; + private volatile boolean running = true; private @Nullable Thread processorThread; @Inject - public MainEventBusProcessor(MainEventBus eventBus, TaskStore taskStore, PushNotificationSender pushSender) { + public MainEventBusProcessor(MainEventBus eventBus, TaskStore taskStore, PushNotificationSender pushSender, QueueManager queueManager) { this.eventBus = eventBus; 
this.taskStore = taskStore; this.pushSender = pushSender; + this.queueManager = queueManager; } /** @@ -165,10 +169,17 @@ public void run() { private void processEvent(MainEventBusContext context) { String taskId = context.taskId(); Event event = context.eventQueueItem().getEvent(); - EventQueue eventQueue = context.eventQueue(); - - LOGGER.debug("MainEventBusProcessor: Processing event for task {}: {} (queue type: {})", - taskId, event.getClass().getSimpleName(), eventQueue.getClass().getSimpleName()); + // MainEventBus.submit() guarantees this is always a MainQueue + EventQueue.MainQueue mainQueue = (EventQueue.MainQueue) context.eventQueue(); + + // Determine if this is a temp ID scenario + // If MainQueue was created with a tempId, then isTempId is true ONLY for events + // BEFORE the queue key is switched. After switching, mainQueue.taskId != tempId, + // so isTempId should be false for subsequent events. + String tempId = mainQueue.getTempId(); + boolean isTempId = (tempId != null && mainQueue.getTaskId().equals(tempId)); + LOGGER.debug("MainEventBusProcessor: Processing event for task {} (tempId={}, isTempId={}): {}", + taskId, tempId, isTempId, event.getClass().getSimpleName()); Event eventToDistribute = null; try { @@ -176,8 +187,17 @@ private void processEvent(MainEventBusContext context) { // If this throws, we distribute an error to ensure "persist before client visibility" try { - updateTaskStore(taskId, event); + updateTaskStore(taskId, event, isTempId, mainQueue); + eventToDistribute = event; // Success - distribute original event + + // Trigger replication AFTER successful persistence (moved from MainQueue.enqueueEvent) + // This ensures replicated events have real task IDs, not temp-UUIDs + EventEnqueueHook hook = mainQueue.getEnqueueHook(); + if (hook != null) { + LOGGER.debug("Triggering replication hook for task {} after successful persistence", taskId); + hook.onEnqueue(context.eventQueueItem()); + } } catch (InternalError e) { // Persistence 
failed - create error event to distribute instead LOGGER.error("Failed to persist event for task {}, distributing error to clients", taskId, e); @@ -208,19 +228,14 @@ private void processEvent(MainEventBusContext context) { eventToDistribute = new InternalError("Internal error: event processing failed"); } - if (eventQueue instanceof EventQueue.MainQueue mainQueue) { - int childCount = mainQueue.getChildCount(); - LOGGER.debug("MainEventBusProcessor: Distributing {} to {} children for task {}", - eventToDistribute.getClass().getSimpleName(), childCount, taskId); - // Create new EventQueueItem with the event to distribute (original or error) - EventQueueItem itemToDistribute = new LocalEventQueueItem(eventToDistribute); - mainQueue.distributeToChildren(itemToDistribute); - LOGGER.debug("MainEventBusProcessor: Distributed {} to {} children for task {}", - eventToDistribute.getClass().getSimpleName(), childCount, taskId); - } else { - LOGGER.warn("MainEventBusProcessor: Expected MainQueue but got {} for task {}", - eventQueue.getClass().getSimpleName(), taskId); - } + int childCount = mainQueue.getChildCount(); + LOGGER.debug("MainEventBusProcessor: Distributing {} to {} children for task {}", + eventToDistribute.getClass().getSimpleName(), childCount, taskId); + // Create new EventQueueItem with the event to distribute (original or error) + EventQueueItem itemToDistribute = new LocalEventQueueItem(eventToDistribute); + mainQueue.distributeToChildren(itemToDistribute); + LOGGER.debug("MainEventBusProcessor: Distributed {} to {} children for task {}", + eventToDistribute.getClass().getSimpleName(), childCount, taskId); LOGGER.debug("MainEventBusProcessor: Completed processing event for task {}", taskId); @@ -240,9 +255,7 @@ private void processEvent(MainEventBusContext context) { } finally { // ALWAYS release semaphore, even if processing fails // Balances the acquire() in MainQueue.enqueueEvent() - if (eventQueue instanceof EventQueue.MainQueue mainQueue) { - 
mainQueue.releaseSemaphore(); - } + mainQueue.releaseSemaphore(); } } } @@ -260,20 +273,46 @@ private void processEvent(MainEventBusContext context) { * See Gemini's comment: https://github.com/a2aproject/a2a-java/pull/515#discussion_r2604621833 *

    * + * @param taskId the task ID (may be temporary) + * @param event the event to persist + * @param isTempId whether the task ID is temporary (from MainQueue.tempId) + * @param mainQueue the main queue (for updating taskId after ID switch) * @throws InternalError if persistence fails */ - private void updateTaskStore(String taskId, Event event) throws InternalError { + private void updateTaskStore(String taskId, Event event, boolean isTempId, EventQueue.MainQueue mainQueue) throws InternalError { try { // Extract contextId from event (all relevant events have it) String contextId = extractContextId(event); + String eventTaskId = null; + if (isTempId) { + eventTaskId = extractTaskId(event); + LOGGER.debug("Temp ID scenario: taskId={}, event={}", taskId, eventTaskId); + } + + // For temp ID scenarios (before switch), use the MainQueue's current taskId (temp ID) for TaskManager. + // This ensures duplicate detection works when switching from temp to real ID. + // After switch (isTempId=false), use the submission taskId (real ID). + String taskIdForManager = isTempId ? mainQueue.getTaskId() : taskId; + // Create temporary TaskManager instance for this event - TaskManager taskManager = new TaskManager(taskId, contextId, taskStore, null); + // Use taskIdForManager (temp ID if switching, real ID otherwise) + // isTempId allows TaskManager.checkIdsAndUpdateIfNecessary() to switch to real ID + TaskManager taskManager = new TaskManager(taskIdForManager, contextId, taskStore, null, isTempId); // Use TaskManager.process() - handles all event types with existing logic taskManager.process(event); LOGGER.debug("TaskStore updated via TaskManager.process() for task {}: {}", taskId, event.getClass().getSimpleName()); + + // If this was a temp ID scenario and the event had a different task ID, + // then TaskManager switched from temp to real ID. Update the queue key and MainQueue taskId. 
+ if (isTempId && eventTaskId != null && !eventTaskId.equals(taskIdForManager)) { + LOGGER.debug("Switching queue key from temp {} to real {}", taskIdForManager, eventTaskId); + queueManager.switchKey(taskIdForManager, eventTaskId); + // Also update the MainQueue's taskId so subsequent events use the real ID + mainQueue.setTaskId(eventTaskId); + } } catch (InternalError e) { LOGGER.error("Error updating TaskStore via TaskManager for task {}", taskId, e); // Rethrow to prevent distributing unpersisted event to clients @@ -350,6 +389,25 @@ private String extractContextId(Event event) { return null; } + /** + * Extracts taskId from an event. + * Returns null if the event type doesn't have a taskId (e.g., Throwable). + */ + @Nullable + private String extractTaskId(Event event) { + if (event instanceof Task task) { + return task.id(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.taskId(); + } else if (event instanceof TaskArtifactUpdateEvent artifactUpdate) { + return artifactUpdate.taskId(); + } else if (event instanceof Message message) { + return message.taskId(); + } + // Other events (Throwable, etc.) don't have taskId + return null; + } + /** * Checks if an event represents a final task state. * diff --git a/server-common/src/main/java/io/a2a/server/events/QueueManager.java b/server-common/src/main/java/io/a2a/server/events/QueueManager.java index b4ab24317..2de0feb73 100644 --- a/server-common/src/main/java/io/a2a/server/events/QueueManager.java +++ b/server-common/src/main/java/io/a2a/server/events/QueueManager.java @@ -167,7 +167,24 @@ public interface QueueManager { * @param taskId the task identifier * @return a MainQueue (if new task) or ChildQueue (if tapping existing) */ - EventQueue createOrTap(String taskId); + default EventQueue createOrTap(String taskId) { + return createOrTap(taskId, null); + } + + /** + * Creates a MainQueue if none exists, or taps the existing queue to create a ChildQueue. + *

    <p> + * This is the primary method used by {@link io.a2a.server.requesthandlers.DefaultRequestHandler}: + * <ul> + * <li>New task: Creates and returns a MainQueue with tempId</li> + * <li>Resubscription: Taps existing MainQueue and returns a ChildQueue</li> + * </ul>
    + * + * @param taskId the task identifier (may be temporary) + * @param tempId the temporary task ID if taskId is temporary, null otherwise + * @return a MainQueue (if new task) or ChildQueue (if tapping existing) + */ + EventQueue createOrTap(String taskId, @Nullable String tempId); /** * Waits for the queue's consumer polling to start. diff --git a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java index 07bba3a9b..c96a461f5 100644 --- a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java +++ b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -474,7 +474,8 @@ public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws task.id(), task.contextId(), taskStore, - null); + null, + false); // Not a temp ID - task already exists ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor, eventConsumerExecutor); @@ -509,15 +510,18 @@ public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws @Override public EventKind onMessageSend(MessageSendParams params, ServerCallContext context) throws A2AError { LOGGER.debug("onMessageSend - task: {}; context {}", params.message().taskId(), params.message().contextId()); - MessageSendSetup mss = initMessageSend(params, context); - @Nullable String initialTaskId = mss.requestContext.getTaskId(); - // For non-streaming, taskId can be null initially (will be set when Task event arrives) - // Use a temporary ID for queue creation if needed (same pattern as streaming) - String queueTaskId = initialTaskId != null ? 
initialTaskId : "temp-" + java.util.UUID.randomUUID(); - LOGGER.debug("Request context taskId: {} (queue key: {})", initialTaskId, queueTaskId); + // Generate temp taskId BEFORE initMessageSend if client didn't provide one + // This ensures TaskManager is created with a valid taskId for ResultAggregator + @Nullable String messageTaskId = params.message().taskId(); + boolean isTempId = messageTaskId == null; + String queueTaskId = isTempId ? "temp-" + java.util.UUID.randomUUID() : messageTaskId; + LOGGER.debug("Message taskId: {} (queue key: {}, isTempId: {})", messageTaskId, queueTaskId, isTempId); - EventQueue queue = queueManager.createOrTap(queueTaskId); + MessageSendSetup mss = initMessageSend(params, context, queueTaskId, isTempId); + + // Pass the actual tempId string (queueTaskId) if this is a temp ID, null otherwise + EventQueue queue = queueManager.createOrTap(queueTaskId, isTempId ? queueTaskId : null); final java.util.concurrent.atomic.AtomicReference<@NonNull String> taskId = new java.util.concurrent.atomic.AtomicReference<>(queueTaskId); ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor, eventConsumerExecutor); @@ -571,12 +575,10 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte String currentId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); if (!Objects.equals(currentId, createdTask.id())) { try { - queueManager.switchKey(currentId, createdTask.id()); + switchFromTempToRealTaskId(currentId, createdTask.id(), mss.taskManager); taskId.set(createdTask.id()); - LOGGER.debug("Switched non-streaming queue from {} to real task ID {}", - currentId, createdTask.id()); } catch (TaskQueueExistsException | IllegalStateException e) { - String msg = "Failed to switch queue key from " + currentId + " to " + createdTask.id() + ": " + e.getMessage(); + String msg = "Failed to switch from temp ID " + currentId + " to real task ID " + createdTask.id() + ": " + e.getMessage(); 
LOGGER.error(msg, e); throw new InternalError(msg); } @@ -710,16 +712,19 @@ public Flow.Publisher onMessageSendStream( MessageSendParams params, ServerCallContext context) throws A2AError { LOGGER.debug("onMessageSendStream START - task: {}; context: {}; runningAgents: {}", params.message().taskId(), params.message().contextId(), runningAgents.size()); - MessageSendSetup mss = initMessageSend(params, context); - @Nullable String initialTaskId = mss.requestContext.getTaskId(); - // For streaming, taskId can be null initially (will be set when Task event arrives) - // Use a temporary ID for queue creation if needed - String queueTaskId = initialTaskId != null ? initialTaskId : "temp-" + java.util.UUID.randomUUID(); + // Generate temp taskId BEFORE initMessageSend if client didn't provide one + // This ensures TaskManager is created with a valid taskId for ResultAggregator + @Nullable String messageTaskId = params.message().taskId(); + boolean isTempId = messageTaskId == null; + String queueTaskId = isTempId ? "temp-" + java.util.UUID.randomUUID() : messageTaskId; + + MessageSendSetup mss = initMessageSend(params, context, queueTaskId, isTempId); final AtomicReference<@NonNull String> taskId = new AtomicReference<>(queueTaskId); + // Pass the actual tempId string (queueTaskId) if this is a temp ID, null otherwise @SuppressWarnings("NullAway") - EventQueue queue = queueManager.createOrTap(taskId.get()); + EventQueue queue = queueManager.createOrTap(taskId.get(), isTempId ? 
queueTaskId : null); LOGGER.debug("Created/tapped queue for task {}: {}", taskId.get(), queue); // Store push notification config SYNCHRONOUSLY for new tasks before agent starts @@ -753,24 +758,24 @@ public Flow.Publisher onMessageSendStream( processor(createTubeConfig(), results, ((errorConsumer, item) -> { Event event = item.getEvent(); if (event instanceof Task createdTask) { - if (!Objects.equals(taskId.get(), createdTask.id())) { - errorConsumer.accept(new InternalError("Task ID mismatch in agent response")); - } - // Switch from temporary ID to real task ID if they differ String currentId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); if (!Objects.equals(currentId, createdTask.id())) { try { - queueManager.switchKey(currentId, createdTask.id()); + switchFromTempToRealTaskId(currentId, createdTask.id(), mss.taskManager); taskId.set(createdTask.id()); - LOGGER.debug("Switched streaming queue from {} to real task ID {}", - currentId, createdTask.id()); - } catch (TaskQueueExistsException e) { - errorConsumer.accept(new InternalError("Queue already exists for task " + createdTask.id())); - } catch (IllegalStateException e) { - errorConsumer.accept(new InternalError("Failed to switch queue key: " + e.getMessage())); + } catch (TaskQueueExistsException | IllegalStateException e) { + errorConsumer.accept(new InternalError("Failed to switch from temp ID " + currentId + + " to real task ID " + createdTask.id() + ": " + e.getMessage())); + return true; // Don't proceed to final check if switch failed } } + + // Final verification AFTER switch attempt + String finalTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if (!finalTaskId.equals(createdTask.id())) { + errorConsumer.accept(new InternalError("Task ID mismatch in agent response")); + } } return true; })); @@ -918,7 +923,7 @@ public Flow.Publisher onResubscribeToTask( throw new TaskNotFoundError(); } - TaskManager taskManager = new TaskManager(task.id(), task.contextId(), 
taskStore, null); + TaskManager taskManager = new TaskManager(task.id(), task.contextId(), taskStore, null, false); // Not a temp ID - task already exists ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor, eventConsumerExecutor); EventQueue queue = queueManager.tap(task.id()); LOGGER.debug("onResubscribeToTask - tapped queue: {}", queue != null ? System.identityHashCode(queue) : "null"); @@ -1064,12 +1069,59 @@ private CompletableFuture cleanupProducer(@Nullable CompletableFuture agentFuture = runningAgents.remove(tempId); + if (agentFuture != null) { + runningAgents.put(realId, agentFuture); + LOGGER.debug("Moved runningAgents from {} to {}", tempId, realId); + } + + // 4. Switch push notification configs (if configured) + if (pushConfigStore != null) { + pushConfigStore.switchKey(tempId, realId); + } + + // 5. Update TaskManager's taskId + taskManager.setTaskId(realId); + + LOGGER.debug("Completed switch from temp ID {} to real task ID {}", tempId, realId); + } + + private MessageSendSetup initMessageSend(MessageSendParams params, ServerCallContext context, String taskId, boolean isTempId) { TaskManager taskManager = new TaskManager( - params.message().taskId(), + taskId, params.message().contextId(), taskStore, - params.message()); + params.message(), + isTempId); Task task = taskManager.getTask(); if (task != null) { @@ -1082,9 +1134,11 @@ private MessageSendSetup initMessageSend(MessageSendParams params, ServerCallCon } } + // For RequestContext, pass null as taskId when using temp ID to avoid validation error + // The temp UUID is only for queue management, not for the RequestContext RequestContext requestContext = requestContextBuilder.get() .setParams(params) - .setTaskId(task == null ? params.message().taskId() : task.id()) + .setTaskId(isTempId ? 
null : taskId) .setContextId(params.message().contextId()) .setTask(task) .setServerCallContext(context) diff --git a/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java b/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java index 093ff910d..d18610875 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java +++ b/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java @@ -118,4 +118,12 @@ public void deleteInfo(String taskId, String configId) { pushNotificationInfos.remove(taskId); } } + + @Override + public void switchKey(String oldTaskId, String newTaskId) { + List configs = pushNotificationInfos.remove(oldTaskId); + if (configs != null && !configs.isEmpty()) { + pushNotificationInfos.put(newTaskId, configs); + } + } } diff --git a/server-common/src/main/java/io/a2a/server/tasks/PushNotificationConfigStore.java b/server-common/src/main/java/io/a2a/server/tasks/PushNotificationConfigStore.java index 828b066a6..e6686b999 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/PushNotificationConfigStore.java +++ b/server-common/src/main/java/io/a2a/server/tasks/PushNotificationConfigStore.java @@ -120,4 +120,20 @@ public interface PushNotificationConfigStore { */ void deleteInfo(String taskId, String configId); + /** + * Switches push notification configuration from an old task ID to a new task ID. + *

    <p>
    + * Used when transitioning from a temporary task ID (e.g., "temp-UUID") to the real task ID
    + * when the Task event arrives with the actual task.id. Moves all push notification configs
    + * associated with the old task ID to the new task ID.
    + * </p>
    + * <p>
    + * If no configs exist for the old task ID, this method returns normally (no-op).
    + * </p>

    + * + * @param oldTaskId the temporary/old task identifier + * @param newTaskId the real/new task identifier + */ + void switchKey(String oldTaskId, String newTaskId); + } diff --git a/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java b/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java index fd3696a60..dbaf76ac2 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java +++ b/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java @@ -12,7 +12,7 @@ import io.a2a.spec.A2AServerException; import io.a2a.spec.Event; -import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.InternalError; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskArtifactUpdateEvent; @@ -31,13 +31,29 @@ public class TaskManager { private final TaskStore taskStore; private final @Nullable Message initialMessage; private volatile @Nullable Task currentTask; + private volatile boolean isTempId; - public TaskManager(@Nullable String taskId, @Nullable String contextId, TaskStore taskStore, @Nullable Message initialMessage) { + public TaskManager(@Nullable String taskId, @Nullable String contextId, TaskStore taskStore, @Nullable Message initialMessage, boolean isTempId) { checkNotNullParam("taskStore", taskStore); this.taskId = taskId; this.contextId = contextId; this.taskStore = taskStore; this.initialMessage = initialMessage; + this.isTempId = isTempId; + } + + /** + * Updates the taskId from a temporary ID to the real task ID. + * Only allowed when this TaskManager was created with isTempId=true. + * Called by DefaultRequestHandler when switching from temp-UUID to real task.id. 
+ */ + public void setTaskId(String newTaskId) { + if (!isTempId) { + throw new IllegalStateException("Cannot change taskId - not created with temporary ID"); + } + LOGGER.debug("TaskManager switching taskId from {} to {}", this.taskId, newTaskId); + this.taskId = newTaskId; + this.isTempId = false; // No longer temporary after switch } @Nullable String getTaskId() { @@ -131,9 +147,25 @@ public Task updateWithMessage(Message message, Task task) { private void checkIdsAndUpdateIfNecessary(String eventTaskId, String eventContextId) throws A2AServerException { if (taskId != null && !eventTaskId.equals(taskId)) { - throw new A2AServerException( - "Invalid task id", - new InvalidParamsError(String.format("Task in event doesn't match TaskManager "))); + // Allow switching from temporary ID to real task ID + // This happens when client sends message without taskId and agent creates Task with real ID + if (isTempId) { + // Verify the new task ID doesn't already exist in the store + // If it does, the agent is trying to return an existing task when it should create a new one + Task existingTask = taskStore.get(eventTaskId); + if (existingTask != null) { + throw new A2AServerException( + "Invalid task id", + new InternalError(String.format("Agent returned existing task ID %s when expecting new task", eventTaskId))); + } + LOGGER.debug("TaskManager allowing taskId switch from temp {} to real {}", taskId, eventTaskId); + taskId = eventTaskId; + isTempId = false; // No longer temporary after switch + } else { + throw new A2AServerException( + "Invalid task id", + new InternalError(String.format("Task in event doesn't match TaskManager "))); + } } if (taskId == null) { taskId = eventTaskId; diff --git a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java index 3c84bb2ae..146bfb10a 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java +++ 
b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java @@ -64,7 +64,8 @@ public void init() { // Set up MainEventBus and processor for production-like test environment InMemoryTaskStore taskStore = new InMemoryTaskStore(); mainEventBus = new MainEventBus(); - mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER); + InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); EventQueueUtil.start(mainEventBusProcessor); eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) diff --git a/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java b/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java index daf0c1dc9..2499a8173 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java +++ b/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java @@ -62,7 +62,8 @@ public void init() { // Set up MainEventBus and processor for production-like test environment InMemoryTaskStore taskStore = new InMemoryTaskStore(); mainEventBus = new MainEventBus(); - mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER); + InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); EventQueueUtil.start(mainEventBusProcessor); eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) diff --git a/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java b/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java index 808a1107a..3e09ff2af 100644 --- a/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java +++ 
b/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java @@ -35,10 +35,9 @@ public void setUp() { taskStateProvider = new MockTaskStateProvider(); taskStore = new InMemoryTaskStore(); mainEventBus = new MainEventBus(); - mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER); - EventQueueUtil.start(mainEventBusProcessor); - queueManager = new InMemoryQueueManager(taskStateProvider, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); } @AfterEach diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java index 4535bbeb3..9c64f03f9 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java @@ -111,10 +111,9 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws A2AErro // Create MainEventBus and MainEventBusProcessor (production code path) mainEventBus = new MainEventBus(); - mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, pushSender); - EventQueueUtil.start(mainEventBusProcessor); - queueManager = new InMemoryQueueManager(inMemoryTaskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, pushSender, queueManager); + EventQueueUtil.start(mainEventBusProcessor); requestHandler = DefaultRequestHandler.create( executor, taskStore, queueManager, pushConfigStore, mainEventBusProcessor, internalExecutor, internalExecutor); diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java 
b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java index 42a940fae..19df17aac 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java @@ -28,6 +28,7 @@ import io.a2a.server.tasks.PushNotificationSender; import io.a2a.server.tasks.TaskUpdater; import io.a2a.spec.A2AError; +import io.a2a.spec.Event; import io.a2a.spec.ListTaskPushNotificationConfigParams; import io.a2a.spec.ListTaskPushNotificationConfigResult; import io.a2a.spec.Message; @@ -37,6 +38,7 @@ import io.a2a.spec.Task; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -68,11 +70,10 @@ void setUp() { // Create MainEventBus and MainEventBusProcessor (production code path) mainEventBus = new MainEventBus(); - mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER); - EventQueueUtil.start(mainEventBusProcessor); - // Pass taskStore as TaskStateProvider to queueManager for task-aware queue management queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); agentExecutor = new TestAgentExecutor(); @@ -997,6 +998,187 @@ void testBlockingMessageStoresPushNotificationConfigForExistingTask() throws Exc assertEquals("https://example.com/existing-webhook", storedConfig.url()); } + /** + * Test that sending a message WITHOUT taskId works correctly (like TCK). + * Agent emits Task with the same ID it receives from context.getTaskId(). 
+ */ + @Test + @Timeout(10) + void testMessageSendWithoutTaskIdCreatesTask() throws Exception { + String contextId = "temp-id-ctx"; + + // Agent does same as TCK: emit SUBMITTED task with received taskId, then WORKING (fire-and-forget) + agentExecutor.setExecuteCallback((context, queue) -> { + Task task = context.getTask(); + if (task == null) { + // First message: create SUBMITTED task using context's taskId + // (RequestContext generates a real UUID when message.taskId is null) + task = Task.builder() + .id(context.getTaskId()) + .contextId(context.getContextId()) + .status(new TaskStatus(TaskState.SUBMITTED)) + .history(List.of(context.getMessage())) + .build(); + queue.enqueueEvent(task); + } + // Set to WORKING (fire-and-forget like TCK) + TaskUpdater updater = new TaskUpdater(context, queue); + updater.startWork(); + // Don't complete - just return (fire-and-forget) + }); + + // Send message WITHOUT taskId (null) in blocking mode + Message message = Message.builder() + .messageId("msg-no-taskid") + .role(Message.Role.USER) + .parts(new TextPart("message without taskId")) + .taskId(null) // No taskId! 
+ .contextId(contextId) + .build(); + + MessageSendConfiguration config = MessageSendConfiguration.builder() + .blocking(true) + .build(); + + MessageSendParams params = new MessageSendParams(message, config, null, ""); + + // Call blocking onMessageSend + Object result = requestHandler.onMessageSend(params, serverCallContext); + + // Verify result is a Task + assertTrue(result instanceof Task, "Result should be a Task"); + Task resultTask = (Task) result; + + // Task should have an ID (auto-generated) + assertNotNull(resultTask.id(), "Task should have an ID"); + + // ID should NOT start with "temp-" (that's just the queue key, not the task.id) + assertTrue(!resultTask.id().startsWith("temp-"), + "Task ID should not start with 'temp-', got: " + resultTask.id()); + + assertEquals(contextId, resultTask.contextId()); + assertEquals(TaskState.WORKING, resultTask.status().state()); + + // Verify task is persisted in TaskStore + Task storedTask = taskStore.get(resultTask.id()); + assertNotNull(storedTask, "Task should be stored"); + assertEquals(resultTask.id(), storedTask.id()); + } + + /** + * Test message send without taskId using streaming. + * This tests the same scenario as testMessageSendWithoutTaskIdCreatesTask but with streaming. + * TCK streaming tests were failing with similar temp-to-real ID switching issues. 
+ */ + @Test + @Timeout(10) + void testMessageSendStreamWithoutTaskIdCreatesTask() throws Exception { + String contextId = "temp-id-ctx-stream"; + + // Agent does same as TCK: emit SUBMITTED task with received taskId, then WORKING (fire-and-forget) + agentExecutor.setExecuteCallback((context, queue) -> { + Task task = context.getTask(); + if (task == null) { + // First message: create SUBMITTED task using context's taskId + // (RequestContext generates a real UUID when message.taskId is null) + task = Task.builder() + .id(context.getTaskId()) + .contextId(context.getContextId()) + .status(new TaskStatus(TaskState.SUBMITTED)) + .history(List.of(context.getMessage())) + .build(); + queue.enqueueEvent(task); + } + // Set to WORKING (fire-and-forget like TCK) + TaskUpdater updater = new TaskUpdater(context, queue); + updater.startWork(); + // Don't complete - just return (fire-and-forget) + }); + + // Send message WITHOUT taskId (null) in streaming mode + Message message = Message.builder() + .messageId("msg-no-taskid-stream") + .role(Message.Role.USER) + .parts(new TextPart("message without taskId streaming")) + .taskId(null) // No taskId! 
+ .contextId(contextId) + .build(); + + MessageSendParams params = new MessageSendParams(message, null, null, ""); + + // Call streaming onMessageSendStream + var publisher = requestHandler.onMessageSendStream(params, serverCallContext); + + // Collect events from stream + List events = new java.util.ArrayList<>(); + CountDownLatch completionLatch = new CountDownLatch(1); + AtomicBoolean hasError = new AtomicBoolean(false); + final Throwable[] error = new Throwable[1]; + + publisher.subscribe(new java.util.concurrent.Flow.Subscriber() { + private java.util.concurrent.Flow.Subscription subscription; + + @Override + public void onSubscribe(java.util.concurrent.Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(Long.MAX_VALUE); // Request all events + } + + @Override + public void onNext(Event event) { + events.add(event); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + hasError.set(true); + error[0] = throwable; + throwable.printStackTrace(); + completionLatch.countDown(); + } + + @Override + public void onComplete() { + completionLatch.countDown(); + } + }); + + // Wait for stream to complete + assertTrue(completionLatch.await(5, TimeUnit.SECONDS), "Stream should complete"); + if (hasError.get()) { + fail("Stream had error: " + error[0], error[0]); + } + + // Should have received at least 2 events: Task (SUBMITTED) and TaskStatusUpdateEvent (WORKING) + assertTrue(events.size() >= 2, "Should have at least 2 events, got: " + events.size()); + + // First event should be Task with SUBMITTED state + Event firstEvent = events.get(0); + assertTrue(firstEvent instanceof Task, "First event should be Task"); + Task firstTask = (Task) firstEvent; + + assertNotNull(firstTask.id(), "Task should have an ID"); + assertTrue(!firstTask.id().startsWith("temp-"), + "Task ID should not start with 'temp-', got: " + firstTask.id()); + assertEquals(contextId, firstTask.contextId()); + 
assertEquals(TaskState.SUBMITTED, firstTask.status().state()); + + // Second event should be TaskStatusUpdateEvent with WORKING state + Event secondEvent = events.get(1); + assertTrue(secondEvent instanceof TaskStatusUpdateEvent, + "Second event should be TaskStatusUpdateEvent, got: " + secondEvent.getClass().getSimpleName()); + TaskStatusUpdateEvent statusUpdate = (TaskStatusUpdateEvent) secondEvent; + assertEquals(firstTask.id(), statusUpdate.taskId(), "Status update should have same task ID"); + assertEquals(TaskState.WORKING, statusUpdate.status().state()); + + // Verify task is persisted in TaskStore with WORKING state + Task storedTask = taskStore.get(firstTask.id()); + assertNotNull(storedTask, "Task should be stored"); + assertEquals(firstTask.id(), storedTask.id()); + assertEquals(TaskState.WORKING, storedTask.status().state(), "Stored task should have WORKING state"); + } + /** * Simple test agent executor that allows controlling execution timing */ diff --git a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java index b33fa4132..0e25e9aad 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java @@ -245,11 +245,10 @@ void testConsumeAndBreakNonBlocking() throws Exception { // Create an event queue using QueueManager (which has access to builder) MainEventBus mainEventBus = new MainEventBus(); InMemoryTaskStore taskStore = new InMemoryTaskStore(); - MainEventBusProcessor processor = new MainEventBusProcessor(mainEventBus, taskStore, task -> {}); - EventQueueUtil.start(processor); - InMemoryQueueManager queueManager = new InMemoryQueueManager(new MockTaskStateProvider(), mainEventBus); + MainEventBusProcessor processor = new MainEventBusProcessor(mainEventBus, taskStore, task -> {}, queueManager); + EventQueueUtil.start(processor); EventQueue queue = 
queueManager.getEventQueueBuilder(taskId).build().tap(); diff --git a/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java b/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java index f14ebc0fe..abeeed859 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java @@ -42,7 +42,7 @@ public class TaskManagerTest { public void init() throws Exception { minimalTask = fromJson(TASK_JSON, Task.class); taskStore = new InMemoryTaskStore(); - taskManager = new TaskManager(minimalTask.id(), minimalTask.contextId(), taskStore, null); + taskManager = new TaskManager(minimalTask.id(), minimalTask.contextId(), taskStore, null, false); } @Test @@ -136,7 +136,7 @@ public void testEnsureTaskExisting() { @Test public void testEnsureTaskNonExistentForStatusUpdate() throws A2AServerException { // Tests that an update event instantiates a new task and that - TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null); + TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null, false); TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId("new-task") .contextId("some-context") @@ -157,7 +157,7 @@ public void testEnsureTaskNonExistentForStatusUpdate() throws A2AServerException @Test public void testSaveTaskEventNewTaskNoTaskId() throws A2AServerException { - TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null); + TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null, false); Task task = Task.builder() .id("new-task-id") .contextId("some-context") @@ -175,7 +175,7 @@ public void testSaveTaskEventNewTaskNoTaskId() throws A2AServerException { @Test public void testGetTaskNoTaskId() { - TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null); + TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null, false); 
Task retrieved = taskManagerWithoutId.getTask(); assertNull(retrieved); } @@ -321,7 +321,7 @@ public void testTaskArtifactUpdateEventAppendNullWithExistingArtifact() throws A @Test public void testAddingTaskWithDifferentIdFails() { // Test that adding a task with a different id from the taskmanager's taskId fails - TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null); + TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null, false); Task differentTask = Task.builder() .id("different-task-id") @@ -337,7 +337,7 @@ public void testAddingTaskWithDifferentIdFails() { @Test public void testAddingTaskWithDifferentIdViaStatusUpdateFails() { // Test that adding a status update with different taskId fails - TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null); + TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null, false); TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId("different-task-id") @@ -354,7 +354,7 @@ public void testAddingTaskWithDifferentIdViaStatusUpdateFails() { @Test public void testAddingTaskWithDifferentIdViaArtifactUpdateFails() { // Test that adding an artifact update with different taskId fails - TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null); + TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null, false); Artifact artifact = Artifact.builder() .artifactId("artifact-id") @@ -382,7 +382,7 @@ public void testTaskWithNoMessageUsesInitialMessage() throws A2AServerException .messageId("initial-msg-id") .build(); - TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage); + TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage, false); // Use a status update event instead of a Task to trigger createTask TaskStatusUpdateEvent 
event = TaskStatusUpdateEvent.builder() @@ -413,7 +413,7 @@ public void testTaskWithMessageDoesNotUseInitialMessage() throws A2AServerExcept .messageId("initial-msg-id") .build(); - TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage); + TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage, false); Message taskMessage = Message.builder() .role(Message.Role.AGENT) @@ -533,11 +533,11 @@ public void testMultipleArtifactsWithDifferentArtifactIds() throws A2AServerExce @Test public void testInvalidTaskIdValidation() { // Test that creating TaskManager with null taskId is allowed (Python allows None) - TaskManager taskManagerWithNullId = new TaskManager(null, "context", taskStore, null); + TaskManager taskManagerWithNullId = new TaskManager(null, "context", taskStore, null, false); assertNull(taskManagerWithNullId.getTaskId()); // Test that empty string task ID is handled (Java doesn't have explicit validation like Python) - TaskManager taskManagerWithEmptyId = new TaskManager("", "context", taskStore, null); + TaskManager taskManagerWithEmptyId = new TaskManager("", "context", taskStore, null, false); assertEquals("", taskManagerWithEmptyId.getTaskId()); } @@ -625,7 +625,7 @@ public void testCreateTaskWithInitialMessage() throws A2AServerException { .messageId("initial-msg-id") .build(); - TaskManager taskManagerWithMessage = new TaskManager(null, null, taskStore, initialMessage); + TaskManager taskManagerWithMessage = new TaskManager(null, null, taskStore, initialMessage, false); TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId("new-task-id") @@ -653,7 +653,7 @@ public void testCreateTaskWithInitialMessage() throws A2AServerException { @Test public void testCreateTaskWithoutInitialMessage() throws A2AServerException { // Test task creation without initial message - TaskManager taskManagerWithoutMessage = new TaskManager(null, null, taskStore, null); + 
TaskManager taskManagerWithoutMessage = new TaskManager(null, null, taskStore, null, false); TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId("new-task-id") @@ -677,7 +677,7 @@ public void testCreateTaskWithoutInitialMessage() throws A2AServerException { @Test public void testSaveTaskInternal() throws A2AServerException { // Test equivalent of _save_task functionality through saveTaskEvent - TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null); + TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null, false); Task newTask = Task.builder() .id("test-task-id") @@ -701,7 +701,7 @@ public void testUpdateWithMessage() throws A2AServerException { .messageId("initial-msg-id") .build(); - TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage); + TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage, false); Message taskMessage = Message.builder() .role(Message.Role.AGENT) diff --git a/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java b/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java index fd195e0a5..73da17824 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java @@ -16,6 +16,7 @@ import io.a2a.server.events.EventQueue; import io.a2a.server.events.EventQueueItem; import io.a2a.server.events.EventQueueUtil; +import io.a2a.server.events.InMemoryQueueManager; import io.a2a.server.events.MainEventBus; import io.a2a.server.events.MainEventBusProcessor; import io.a2a.spec.Event; @@ -56,7 +57,8 @@ public void init() { // Set up MainEventBus and processor for production-like test environment InMemoryTaskStore taskStore = new InMemoryTaskStore(); mainEventBus = new MainEventBus(); - mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, 
NOOP_PUSHNOTIFICATION_SENDER); + InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); EventQueueUtil.start(mainEventBusProcessor); eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus)