diff --git a/examples/cloud-deployment/scripts/deploy.sh b/examples/cloud-deployment/scripts/deploy.sh index e267f3302..fff2a6061 100755 --- a/examples/cloud-deployment/scripts/deploy.sh +++ b/examples/cloud-deployment/scripts/deploy.sh @@ -212,6 +212,22 @@ echo "" echo "Deploying PostgreSQL..." kubectl apply -f ../k8s/01-postgres.yaml echo "Waiting for PostgreSQL to be ready..." + +# Wait for pod to be created (StatefulSet takes time to create pod) +for i in {1..30}; do + if kubectl get pod -l app=postgres -n a2a-demo 2>/dev/null | grep -q postgres; then + echo "PostgreSQL pod found, waiting for ready state..." + break + fi + if [ $i -eq 30 ]; then + echo -e "${RED}ERROR: PostgreSQL pod not created after 30 seconds${NC}" + kubectl get statefulset -n a2a-demo + exit 1 + fi + sleep 1 +done + +# Now wait for pod to be ready kubectl wait --for=condition=Ready pod -l app=postgres -n a2a-demo --timeout=120s echo -e "${GREEN}✓ PostgreSQL deployed${NC}" diff --git a/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java b/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java index 36245e277..799d71733 100644 --- a/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java +++ b/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java @@ -164,4 +164,45 @@ public void deleteInfo(String taskId, String configId) { taskId, configId); } } + + @Transactional + @Override + public void switchKey(String oldTaskId, String newTaskId) { + LOGGER.debug("Switching PushNotificationConfigs from Task '{}' to Task '{}'", oldTaskId, newTaskId); + + // Find all configs for the old task ID + TypedQuery query = em.createQuery( + "SELECT c FROM JpaPushNotificationConfig c WHERE c.id.taskId = :taskId", + JpaPushNotificationConfig.class); + query.setParameter("taskId", oldTaskId); + List configs = query.getResultList(); + + if (configs.isEmpty()) { + LOGGER.debug("No PushNotificationConfigs found for Task '{}', nothing to switch", oldTaskId); + return; + } + + // For each config, create a new entity with the new task ID and remove the old one + for (JpaPushNotificationConfig oldConfig : configs) { + try { + // Create new config with new task ID + JpaPushNotificationConfig newConfig = JpaPushNotificationConfig.createFromConfig( + newTaskId, oldConfig.getConfig()); + + // Remove old config and persist new one + em.remove(oldConfig); + em.persist(newConfig); + + LOGGER.debug("Switched PushNotificationConfig ID '{}' from Task '{}' to Task '{}'", + oldConfig.getId().getConfigId(), oldTaskId, newTaskId); + } catch (JsonProcessingException e) { + LOGGER.error("Failed to switch PushNotificationConfig ID '{}' from Task '{}' to Task '{}'", + oldConfig.getId().getConfigId(), oldTaskId, newTaskId, e); + throw new RuntimeException("Failed to switch PushNotificationConfig", e); + } + } + + LOGGER.debug("Successfully switched {} PushNotificationConfigs from Task '{}' to Task '{}'", + configs.size(), oldTaskId, newTaskId); + } } diff --git a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java index 586ab11a7..0ec264c5a 100644 --- a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java +++ b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java @@ -13,8 +13,10 @@ import io.a2a.server.events.EventQueueFactory; import io.a2a.server.events.EventQueueItem; import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; import io.a2a.server.events.QueueManager; import io.a2a.server.tasks.TaskStateProvider; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,10 +47,12 @@ protected ReplicatedQueueManager() { } @Inject - public ReplicatedQueueManager(ReplicationStrategy replicationStrategy, TaskStateProvider taskStateProvider) { + public ReplicatedQueueManager(ReplicationStrategy replicationStrategy, + TaskStateProvider taskStateProvider, + MainEventBus mainEventBus) { this.replicationStrategy = replicationStrategy; this.taskStateProvider = taskStateProvider; - this.delegate = new InMemoryQueueManager(new ReplicatingEventQueueFactory(), taskStateProvider); + this.delegate = new InMemoryQueueManager(new ReplicatingEventQueueFactory(), taskStateProvider, mainEventBus); } @@ -57,6 +61,11 @@ public void add(String taskId, EventQueue queue) { delegate.add(taskId, queue); } + @Override + public void switchKey(String oldId, String newId) { + delegate.switchKey(oldId, newId); + } + @Override public EventQueue get(String taskId) { return delegate.get(taskId); @@ -77,7 +86,12 @@ public void close(String taskId) { @Override public EventQueue createOrTap(String taskId) { - EventQueue queue = delegate.createOrTap(taskId); + return createOrTap(taskId, null); + } + + @Override + public EventQueue createOrTap(String taskId, @Nullable String tempId) { + EventQueue queue = delegate.createOrTap(taskId, tempId); return queue; } @@ -98,7 +112,9 @@ public void onReplicatedEvent(@Observes ReplicatedEventQueueItem replicatedEvent } // Get or create a ChildQueue for this task (creates MainQueue if it doesn't exist) - EventQueue childQueue = delegate.createOrTap(replicatedEvent.getTaskId()); + // Replicated events should always have real task IDs (not temp IDs) because + // replication now happens AFTER TaskStore persistence in MainEventBusProcessor + EventQueue childQueue = delegate.createOrTap(replicatedEvent.getTaskId(), null); try { // Get the MainQueue to enqueue the replicated event item @@ -152,12 +168,11 @@ public EventQueue.EventQueueBuilder builder(String taskId) { // which sends the QueueClosedEvent after the database transaction commits. // This ensures proper ordering and transactional guarantees. - // Return the builder with callbacks - return delegate.getEventQueueBuilder(taskId) - .taskId(taskId) - .hook(new ReplicationHook(taskId)) - .addOnCloseCallback(delegate.getCleanupCallback(taskId)) - .taskStateProvider(taskStateProvider); + // Call createBaseEventQueueBuilder() directly to avoid infinite recursion + // (getEventQueueBuilder() would delegate back to this factory, creating a loop) + // The base builder already includes: taskId, cleanup callback, taskStateProvider, mainEventBus + return delegate.createBaseEventQueueBuilder(taskId) + .hook(new ReplicationHook(taskId)); } } diff --git a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java index 43571cd30..5fcc4126b 100644 --- a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java +++ b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java @@ -22,12 +22,18 @@ import io.a2a.server.events.EventQueueClosedException; import io.a2a.server.events.EventQueueItem; import io.a2a.server.events.EventQueueTestHelper; +import io.a2a.server.events.EventQueueUtil; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; import io.a2a.server.events.QueueClosedEvent; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.spec.Event; import io.a2a.spec.StreamingEventKind; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -35,10 +41,27 @@ class ReplicatedQueueManagerTest { private ReplicatedQueueManager queueManager; private StreamingEventKind testEvent; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; @BeforeEach void setUp() { - queueManager = new ReplicatedQueueManager(new NoOpReplicationStrategy(), new MockTaskStateProvider(true)); + // Create MainEventBus first + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + + // Create QueueManager before MainEventBusProcessor (processor needs it as parameter) + queueManager = new ReplicatedQueueManager( + new NoOpReplicationStrategy(), + new MockTaskStateProvider(true), + mainEventBus + ); + + // Create MainEventBusProcessor with QueueManager + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + testEvent = TaskStatusUpdateEvent.builder() .taskId("test-task") .contextId("test-context") @@ -47,25 +70,82 @@ void setUp() { .build(); } + /** + * Helper to create a test event with the specified taskId. + * This ensures taskId consistency between queue creation and event creation. + */ + private TaskStatusUpdateEvent createEventForTask(String taskId) { + return TaskStatusUpdateEvent.builder() + .taskId(taskId) + .contextId("test-context") + .status(new TaskStatus(TaskState.SUBMITTED)) + .isFinal(false) + .build(); + } + + @AfterEach + void tearDown() { + if (mainEventBusProcessor != null) { + mainEventBusProcessor.setCallback(null); // Clear any test callbacks + EventQueueUtil.stop(mainEventBusProcessor); + } + mainEventBusProcessor = null; + mainEventBus = null; + queueManager = null; + } + + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. + * + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, io.a2a.spec.Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } + } + @Test void testReplicationStrategyTriggeredOnNormalEnqueue() throws InterruptedException { CountingReplicationStrategy strategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus); String taskId = "test-task-1"; EventQueue queue = queueManager.createOrTap(taskId); + TaskStatusUpdateEvent event = createEventForTask(taskId); - queue.enqueueEvent(testEvent); + // Wait for MainEventBusProcessor to process the event and trigger replication + waitForEventProcessing(() -> queue.enqueueEvent(event)); assertEquals(1, strategy.getCallCount()); assertEquals(taskId, strategy.getLastTaskId()); - assertEquals(testEvent, strategy.getLastEvent()); + assertEquals(event, strategy.getLastEvent()); } @Test void testReplicationStrategyNotTriggeredOnReplicatedEvent() throws InterruptedException { CountingReplicationStrategy strategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus); String taskId = "test-task-2"; EventQueue queue = queueManager.createOrTap(taskId); @@ -79,17 +159,19 @@ void testReplicationStrategyNotTriggeredOnReplicatedEvent() throws InterruptedEx @Test void testReplicationStrategyWithCountingImplementation() throws InterruptedException { CountingReplicationStrategy countingStrategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true), mainEventBus); String taskId = "test-task-3"; EventQueue queue = queueManager.createOrTap(taskId); + TaskStatusUpdateEvent event = createEventForTask(taskId); - queue.enqueueEvent(testEvent); - queue.enqueueEvent(testEvent); + // Wait for MainEventBusProcessor to process each event + waitForEventProcessing(() -> queue.enqueueEvent(event)); + waitForEventProcessing(() -> queue.enqueueEvent(event)); assertEquals(2, countingStrategy.getCallCount()); assertEquals(taskId, countingStrategy.getLastTaskId()); - assertEquals(testEvent, countingStrategy.getLastEvent()); + assertEquals(event, countingStrategy.getLastEvent()); ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent); queueManager.onReplicatedEvent(replicatedEvent); @@ -100,46 +182,45 @@ void testReplicationStrategyWithCountingImplementation() throws InterruptedExcep @Test void testReplicatedEventDeliveredToCorrectQueue() throws InterruptedException { String taskId = "test-task-4"; + TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId EventQueue queue = queueManager.createOrTap(taskId); - ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent); - queueManager.onReplicatedEvent(replicatedEvent); + ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask); - Event dequeuedEvent; - try { - dequeuedEvent = queue.dequeueEventItem(100).getEvent(); - } catch (EventQueueClosedException e) { - fail("Queue should not be closed"); - return; - } - assertEquals(testEvent, dequeuedEvent); + // Use callback to wait for event processing + EventQueueItem item = dequeueEventWithRetry(queue, () -> queueManager.onReplicatedEvent(replicatedEvent)); + assertNotNull(item, "Event should be available in queue"); + Event dequeuedEvent = item.getEvent(); + assertEquals(eventForTask, dequeuedEvent); } @Test void testReplicatedEventCreatesQueueIfNeeded() throws InterruptedException { String taskId = "non-existent-task"; + TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId // Verify no queue exists initially assertNull(queueManager.get(taskId)); - ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent); - - // Process the replicated event - assertDoesNotThrow(() -> queueManager.onReplicatedEvent(replicatedEvent)); - - // Verify that a queue was created and the event was enqueued - EventQueue queue = queueManager.get(taskId); - assertNotNull(queue, "Queue should be created when processing replicated event for non-existent task"); - - // Verify the event was enqueued by dequeuing it - Event dequeuedEvent; - try { - dequeuedEvent = queue.dequeueEventItem(100).getEvent(); - } catch (EventQueueClosedException e) { - fail("Queue should not be closed"); - return; - } - assertEquals(testEvent, dequeuedEvent, "The replicated event should be enqueued in the newly created queue"); + // Create a ChildQueue BEFORE processing the replicated event + // This ensures the ChildQueue exists when MainEventBusProcessor distributes the event + EventQueue childQueue = queueManager.createOrTap(taskId); + assertNotNull(childQueue, "ChildQueue should be created"); + + // Verify MainQueue was created + EventQueue mainQueue = queueManager.get(taskId); + assertNotNull(mainQueue, "MainQueue should exist after createOrTap"); + + ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask); + + // Process the replicated event and wait for distribution + // Use callback to wait for event processing + EventQueueItem item = dequeueEventWithRetry(childQueue, () -> { + assertDoesNotThrow(() -> queueManager.onReplicatedEvent(replicatedEvent)); + }); + assertNotNull(item, "Event should be available in queue"); + Event dequeuedEvent = item.getEvent(); + assertEquals(eventForTask, dequeuedEvent, "The replicated event should be enqueued in the newly created queue"); } @Test @@ -170,17 +251,18 @@ void testBasicQueueManagerFunctionality() throws InterruptedException { void testQueueToTaskIdMappingMaintained() throws InterruptedException { String taskId = "test-task-6"; CountingReplicationStrategy countingStrategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true), mainEventBus); + TaskStatusUpdateEvent event = createEventForTask(taskId); EventQueue queue = queueManager.createOrTap(taskId); - queue.enqueueEvent(testEvent); + waitForEventProcessing(() -> queue.enqueueEvent(event)); assertEquals(taskId, countingStrategy.getLastTaskId()); queueManager.close(taskId); // Task is active, so NO poison pill is sent EventQueue newQueue = queueManager.createOrTap(taskId); - newQueue.enqueueEvent(testEvent); + waitForEventProcessing(() -> newQueue.enqueueEvent(event)); assertEquals(taskId, countingStrategy.getLastTaskId()); // 2 replication calls: 1 testEvent, 1 testEvent (no QueueClosedEvent because task is active) @@ -217,17 +299,32 @@ void testReplicatedEventJsonSerialization() throws Exception { @Test void testParallelReplicationBehavior() throws InterruptedException { CountingReplicationStrategy strategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus); String taskId = "parallel-test-task"; EventQueue queue = queueManager.createOrTap(taskId); int numThreads = 10; int eventsPerThread = 5; + int expectedEventCount = (numThreads / 2) * eventsPerThread; // Only normal enqueues ExecutorService executor = Executors.newFixedThreadPool(numThreads); CountDownLatch startLatch = new CountDownLatch(1); CountDownLatch doneLatch = new CountDownLatch(numThreads); + // Set up callback to wait for all events to be processed by MainEventBusProcessor + CountDownLatch processingLatch = new CountDownLatch(expectedEventCount); + mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String tid, io.a2a.spec.Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String tid) { + // Not needed for this test + } + }); + // Launch threads that will enqueue events normally (should trigger replication) for (int i = 0; i < numThreads / 2; i++) { final int threadId = i; @@ -236,7 +333,7 @@ void testParallelReplicationBehavior() throws InterruptedException { startLatch.await(); for (int j = 0; j < eventsPerThread; j++) { TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() - .taskId("normal-" + threadId + "-" + j) + .taskId(taskId) // Use same taskId as queue .contextId("test-context") .status(new TaskStatus(TaskState.WORKING)) .isFinal(false) @@ -260,7 +357,7 @@ void testParallelReplicationBehavior() throws InterruptedException { startLatch.await(); for (int j = 0; j < eventsPerThread; j++) { TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() - .taskId("replicated-" + threadId + "-" + j) + .taskId(taskId) // Use same taskId as queue .contextId("test-context") .status(new TaskStatus(TaskState.COMPLETED)) .isFinal(true) @@ -286,6 +383,14 @@ void testParallelReplicationBehavior() throws InterruptedException { executor.shutdown(); assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS), "Executor should shutdown within 5 seconds"); + // Wait for MainEventBusProcessor to process all events + try { + assertTrue(processingLatch.await(10, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed all events within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } + // Only the normal enqueue operations should have triggered replication // numThreads/2 threads * eventsPerThread events each = total expected replication calls int expectedReplicationCalls = (numThreads / 2) * eventsPerThread; @@ -297,7 +402,7 @@ void testParallelReplicationBehavior() throws InterruptedException { void testReplicatedEventSkippedWhenTaskInactive() throws InterruptedException { // Create a task state provider that returns false (task is inactive) MockTaskStateProvider stateProvider = new MockTaskStateProvider(false); - queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider); + queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus); String taskId = "inactive-task"; @@ -316,30 +421,32 @@ void testReplicatedEventSkippedWhenTaskInactive() throws InterruptedException { void testReplicatedEventProcessedWhenTaskActive() throws InterruptedException { // Create a task state provider that returns true (task is active) MockTaskStateProvider stateProvider = new MockTaskStateProvider(true); - queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider); + queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus); String taskId = "active-task"; + TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId // Verify no queue exists initially assertNull(queueManager.get(taskId)); - // Process a replicated event for an active task - ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent); - queueManager.onReplicatedEvent(replicatedEvent); + // Create a ChildQueue BEFORE processing the replicated event + // This ensures the ChildQueue exists when MainEventBusProcessor distributes the event + EventQueue childQueue = queueManager.createOrTap(taskId); + assertNotNull(childQueue, "ChildQueue should be created"); - // Queue should be created and event should be enqueued - EventQueue queue = queueManager.get(taskId); - assertNotNull(queue, "Queue should be created for active task"); + // Verify MainQueue was created + EventQueue mainQueue = queueManager.get(taskId); + assertNotNull(mainQueue, "MainQueue should exist after createOrTap"); - // Verify the event was enqueued - Event dequeuedEvent; - try { - dequeuedEvent = queue.dequeueEventItem(100).getEvent(); - } catch (EventQueueClosedException e) { - fail("Queue should not be closed"); - return; - } - assertEquals(testEvent, dequeuedEvent, "Event should be enqueued for active task"); + // Process a replicated event for an active task + ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask); + + // Verify the event was enqueued and distributed to our ChildQueue + // Use callback to wait for event processing + EventQueueItem item = dequeueEventWithRetry(childQueue, () -> queueManager.onReplicatedEvent(replicatedEvent)); + assertNotNull(item, "Event should be available in queue"); + Event dequeuedEvent = item.getEvent(); + assertEquals(eventForTask, dequeuedEvent, "Event should be enqueued for active task"); } @@ -347,7 +454,7 @@ void testReplicatedEventProcessedWhenTaskActive() throws InterruptedException { void testReplicatedEventToExistingQueueWhenTaskBecomesInactive() throws InterruptedException { // Create a task state provider that returns true initially MockTaskStateProvider stateProvider = new MockTaskStateProvider(true); - queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider); + queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus); String taskId = "task-becomes-inactive"; @@ -387,13 +494,14 @@ void testReplicatedEventToExistingQueueWhenTaskBecomesInactive() throws Interrup @Test void testPoisonPillSentViaTransactionAwareEvent() throws InterruptedException { CountingReplicationStrategy strategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus); String taskId = "poison-pill-test"; EventQueue queue = queueManager.createOrTap(taskId); + TaskStatusUpdateEvent event = createEventForTask(taskId); - // Enqueue a normal event first - queue.enqueueEvent(testEvent); + // Enqueue a normal event first and wait for processing + waitForEventProcessing(() -> queue.enqueueEvent(event)); // In the new architecture, QueueClosedEvent (poison pill) is sent via CDI events // when JpaDatabaseTaskStore.save() persists a final task and the transaction commits @@ -451,36 +559,21 @@ void testQueueClosedEventJsonSerialization() throws Exception { @Test void testReplicatedQueueClosedEventTerminatesConsumer() throws InterruptedException { String taskId = "remote-close-test"; + TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId EventQueue queue = queueManager.createOrTap(taskId); - // Enqueue a normal event - queue.enqueueEvent(testEvent); - // Simulate receiving QueueClosedEvent from remote node QueueClosedEvent closedEvent = new QueueClosedEvent(taskId); ReplicatedEventQueueItem replicatedClosedEvent = new ReplicatedEventQueueItem(taskId, closedEvent); - queueManager.onReplicatedEvent(replicatedClosedEvent); - // Dequeue the normal event first - EventQueueItem item1; - try { - item1 = queue.dequeueEventItem(100); - } catch (EventQueueClosedException e) { - fail("Should not throw on first dequeue"); - return; - } - assertNotNull(item1); - assertEquals(testEvent, item1.getEvent()); + // Dequeue the normal event first (use callback to wait for async processing) + EventQueueItem item1 = dequeueEventWithRetry(queue, () -> queue.enqueueEvent(eventForTask)); + assertNotNull(item1, "First event should be available"); + assertEquals(eventForTask, item1.getEvent()); - // Next dequeue should get the QueueClosedEvent - EventQueueItem item2; - try { - item2 = queue.dequeueEventItem(100); - } catch (EventQueueClosedException e) { - fail("Should not throw on second dequeue, should return the event"); - return; - } - assertNotNull(item2); + // Next dequeue should get the QueueClosedEvent (use callback to wait for async processing) + EventQueueItem item2 = dequeueEventWithRetry(queue, () -> queueManager.onReplicatedEvent(replicatedClosedEvent)); + assertNotNull(item2, "QueueClosedEvent should be available"); assertTrue(item2.getEvent() instanceof QueueClosedEvent, "Second event should be QueueClosedEvent"); } @@ -539,4 +632,25 @@ public void setActive(boolean active) { this.active = active; } } + + /** + * Helper method to dequeue an event after waiting for MainEventBusProcessor distribution. + * Uses callback-based waiting instead of polling for deterministic synchronization. + * + * @param queue the queue to dequeue from + * @param enqueueAction the action that enqueues the event (triggers event processing) + * @return the dequeued EventQueueItem, or null if queue is closed + */ + private EventQueueItem dequeueEventWithRetry(EventQueue queue, Runnable enqueueAction) throws InterruptedException { + // Wait for event to be processed and distributed + waitForEventProcessing(enqueueAction); + + // Event is now available, dequeue directly + try { + return queue.dequeueEventItem(100); + } catch (EventQueueClosedException e) { + // Queue closed, return null + return null; + } + } } \ No newline at end of file diff --git a/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java b/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java new file mode 100644 index 000000000..a91575aaa --- /dev/null +++ b/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java @@ -0,0 +1,11 @@ +package io.a2a.server.events; + +public class EventQueueUtil { + public static void start(MainEventBusProcessor processor) { + processor.start(); + } + + public static void stop(MainEventBusProcessor processor) { + processor.stop(); + } +} diff --git a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java index 18e18a2f1..cb5bdb25b 100644 --- a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java +++ b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java @@ -13,7 +13,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Function; import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; @@ -21,6 +20,7 @@ import com.google.gson.JsonSyntaxException; import io.a2a.common.A2AHeaders; +import io.a2a.server.util.sse.SseFormatter; import io.a2a.grpc.utils.JSONRPCUtils; import io.a2a.jsonrpc.common.json.IdJsonMappingException; import io.a2a.jsonrpc.common.json.InvalidParamsJsonMappingException; @@ -65,7 +65,6 @@ import io.a2a.transport.jsonrpc.handler.JSONRPCHandler; import io.quarkus.security.Authenticated; import io.quarkus.vertx.web.Body; -import io.quarkus.vertx.web.ReactiveRoutes; import io.quarkus.vertx.web.Route; import io.smallrye.mutiny.Multi; import io.vertx.core.AsyncResult; @@ -74,6 +73,8 @@ import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpServerResponse; import io.vertx.ext.web.RoutingContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @Singleton public class A2AServerRoutes { @@ -135,8 +136,12 @@ public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) { } else if (streaming) { final Multi> finalStreamingResponse = streamingResponse; executor.execute(() -> { - MultiSseSupport.subscribeObject( - finalStreamingResponse.map(i -> (Object) i), rc); + // Convert Multi to Multi with SSE formatting + AtomicLong eventIdCounter = new AtomicLong(0); + Multi sseEvents = finalStreamingResponse + .map(response -> SseFormatter.formatResponseAsSSE(response, eventIdCounter.getAndIncrement())); + // Write SSE-formatted strings to HTTP response + MultiSseSupport.writeSseStrings(sseEvents, rc, context); }); } else { @@ -295,34 +300,30 @@ private static com.google.protobuf.MessageOrBuilder convertToProto(A2AResponse + * This class only handles HTTP-specific concerns (writing to response, backpressure, disconnect). + * SSE formatting and JSON serialization are handled by {@link SseFormatter}. + */ private static class MultiSseSupport { + private static final Logger logger = LoggerFactory.getLogger(MultiSseSupport.class); private MultiSseSupport() { // Avoid direct instantiation. } - private static void initialize(HttpServerResponse response) { - if (response.bytesWritten() == 0) { - MultiMap headers = response.headers(); - if (headers.get(CONTENT_TYPE) == null) { - headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); - } - response.setChunked(true); - } - } - - private static void onWriteDone(Flow.Subscription subscription, AsyncResult ar, RoutingContext rc) { - if (ar.failed()) { - rc.fail(ar.cause()); - } else { - subscription.request(1); - } - } - - public static void write(Multi multi, RoutingContext rc) { + /** + * Write SSE-formatted strings to HTTP response. + * + * @param sseStrings Multi stream of SSE-formatted strings (from SseFormatter) + * @param rc Vert.x routing context + * @param context A2A server call context (for EventConsumer cancellation) + */ + public static void writeSseStrings(Multi sseStrings, RoutingContext rc, ServerCallContext context) { HttpServerResponse response = rc.response(); - multi.subscribe().withSubscriber(new Flow.Subscriber() { + + sseStrings.subscribe().withSubscriber(new Flow.Subscriber() { Flow.Subscription upstream; @Override @@ -330,6 +331,13 @@ public void onSubscribe(Flow.Subscription subscription) { this.upstream = subscription; this.upstream.request(1); + // Detect client disconnect and call EventConsumer.cancel() directly + response.closeHandler(v -> { + logger.info("SSE connection closed by client, calling EventConsumer.cancel() to stop polling loop"); + context.invokeEventConsumerCancelCallback(); + subscription.cancel(); + }); + // Notify tests that we are subscribed Runnable runnable = streamingMultiSseSupportSubscribedRunnable; if (runnable != null) { @@ -338,54 +346,50 @@ public void onSubscribe(Flow.Subscription subscription) { } @Override - public void onNext(Buffer item) { - initialize(response); - response.write(item, new Handler>() { + public void onNext(String sseEvent) { + // Set SSE headers on first event + if (response.bytesWritten() == 0) { + MultiMap headers = response.headers(); + if (headers.get(CONTENT_TYPE) == null) { + headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + } + response.setChunked(true); + } + + // Write SSE-formatted string to response + response.write(Buffer.buffer(sseEvent), new Handler>() { @Override public void handle(AsyncResult ar) { - onWriteDone(upstream, ar, rc); + if (ar.failed()) { + // Client disconnected or write failed - cancel upstream to stop EventConsumer + upstream.cancel(); + rc.fail(ar.cause()); + } else { + upstream.request(1); + } } }); } @Override public void onError(Throwable throwable) { + // Cancel upstream to stop EventConsumer when error occurs + upstream.cancel(); rc.fail(throwable); } @Override public void onComplete() { - endOfStream(response); - } - }); - } - - public static void subscribeObject(Multi multi, RoutingContext rc) { - AtomicLong count = new AtomicLong(); - write(multi.map(new Function() { - @Override - public Buffer apply(Object o) { - if (o instanceof ReactiveRoutes.ServerSentEvent) { - ReactiveRoutes.ServerSentEvent ev = (ReactiveRoutes.ServerSentEvent) o; - long id = ev.id() != -1 ? ev.id() : count.getAndIncrement(); - String e = ev.event() == null ? "" : "event: " + ev.event() + "\n"; - String data = serializeResponse((A2AResponse) ev.data()); - return Buffer.buffer(e + "data: " + data + "\nid: " + id + "\n\n"); + if (response.bytesWritten() == 0) { + // No events written - still set SSE content type + MultiMap headers = response.headers(); + if (headers.get(CONTENT_TYPE) == null) { + headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + } } - String data = serializeResponse((A2AResponse) o); - return Buffer.buffer("data: " + data + "\nid: " + count.getAndIncrement() + "\n\n"); - } - }), rc); - } - - private static void endOfStream(HttpServerResponse response) { - if (response.bytesWritten() == 0) { // No item - MultiMap headers = response.headers(); - if (headers.get(CONTENT_TYPE) == null) { - headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + response.end(); } - } - response.end(); + }); } } } diff --git a/reference/jsonrpc/src/test/resources/application.properties b/reference/jsonrpc/src/test/resources/application.properties index 7b9cea9cc..e612925d4 100644 --- a/reference/jsonrpc/src/test/resources/application.properties +++ b/reference/jsonrpc/src/test/resources/application.properties @@ -1 +1,6 @@ quarkus.arc.selected-alternatives=io.a2a.server.apps.common.TestHttpClient + +# Debug logging for event processing and request handling +quarkus.log.category."io.a2a.server.events".level=DEBUG +quarkus.log.category."io.a2a.server.requesthandlers".level=DEBUG +quarkus.log.category."io.a2a.server.tasks".level=DEBUG diff --git a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java index 46d0d38e6..7a50f0afb 100644 --- a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java +++ b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java @@ -15,7 +15,8 @@ import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Function; + +import io.a2a.server.util.sse.SseFormatter; import jakarta.annotation.security.PermitAll; import jakarta.enterprise.inject.Instance; @@ -38,7 +39,6 @@ import io.a2a.util.Utils; import io.quarkus.security.Authenticated; import io.quarkus.vertx.web.Body; -import io.quarkus.vertx.web.ReactiveRoutes; import io.quarkus.vertx.web.Route; import io.smallrye.mutiny.Multi; import io.vertx.core.AsyncResult; @@ -110,10 +110,14 @@ public void sendMessageStreaming(@Body String body, RoutingContext rc) { if (error != null) { sendResponse(rc, error); } else if (streamingResponse != null) { - Multi events = Multi.createFrom().publisher(streamingResponse.getPublisher()); + final HTTPRestStreamingResponse finalStreamingResponse = streamingResponse; executor.execute(() -> { - MultiSseSupport.subscribeObject( - events.map(i -> (Object) i), rc); + // Convert Flow.Publisher (JSON) to Multi (SSE-formatted) + AtomicLong eventIdCounter = new AtomicLong(0); + Multi sseEvents = Multi.createFrom().publisher(finalStreamingResponse.getPublisher()) + .map(json -> SseFormatter.formatJsonAsSSE(json, eventIdCounter.getAndIncrement())); + // Write SSE-formatted strings to HTTP response + MultiSseSupport.writeSseStrings(sseEvents, rc, context); }); } } @@ -243,10 +247,14 @@ public void subscribeToTask(RoutingContext rc) { if (error != null) { sendResponse(rc, error); } else if (streamingResponse != null) { - Multi events = Multi.createFrom().publisher(streamingResponse.getPublisher()); + final HTTPRestStreamingResponse finalStreamingResponse = streamingResponse; executor.execute(() -> { - MultiSseSupport.subscribeObject( - events.map(i -> (Object) i), rc); + // Convert Flow.Publisher (JSON) to Multi (SSE-formatted) + AtomicLong eventIdCounter = new AtomicLong(0); + Multi sseEvents = Multi.createFrom().publisher(finalStreamingResponse.getPublisher()) + .map(json -> SseFormatter.formatJsonAsSSE(json, eventIdCounter.getAndIncrement())); + // Write SSE-formatted strings to HTTP response + MultiSseSupport.writeSseStrings(sseEvents, rc, context); }); } } @@ -450,34 +458,30 @@ public String getUsername() { } } - // Port of import io.quarkus.vertx.web.runtime.MultiSseSupport, which is considered internal API + /** + * Simplified SSE support for Vert.x/Quarkus. + *

+ * This class only handles HTTP-specific concerns (writing to response, backpressure, disconnect). + * SSE formatting and JSON serialization are handled by {@link SseFormatter}. + */ private static class MultiSseSupport { + private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MultiSseSupport.class); private MultiSseSupport() { // Avoid direct instantiation. } - private static void initialize(HttpServerResponse response) { - if (response.bytesWritten() == 0) { - MultiMap headers = response.headers(); - if (headers.get(CONTENT_TYPE) == null) { - headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); - } - response.setChunked(true); - } - } - - private static void onWriteDone(Flow.@Nullable Subscription subscription, AsyncResult ar, RoutingContext rc) { - if (ar.failed()) { - rc.fail(ar.cause()); - } else if (subscription != null) { - subscription.request(1); - } - } - - private static void write(Multi multi, RoutingContext rc) { + /** + * Write SSE-formatted strings to HTTP response. + * + * @param sseStrings Multi stream of SSE-formatted strings (from SseFormatter) + * @param rc Vert.x routing context + * @param context A2A server call context (for EventConsumer cancellation) + */ + public static void writeSseStrings(Multi sseStrings, RoutingContext rc, ServerCallContext context) { HttpServerResponse response = rc.response(); - multi.subscribe().withSubscriber(new Flow.Subscriber() { + + sseStrings.subscribe().withSubscriber(new Flow.Subscriber() { Flow.@Nullable Subscription upstream; @Override @@ -485,6 +489,13 @@ public void onSubscribe(Flow.Subscription subscription) { this.upstream = subscription; this.upstream.request(1); + // Detect client disconnect and call EventConsumer.cancel() directly + response.closeHandler(v -> { + logger.debug("REST SSE connection closed by client, calling EventConsumer.cancel() to stop polling loop"); + context.invokeEventConsumerCancelCallback(); + subscription.cancel(); + }); + // Notify tests that we are subscribed Runnable runnable = streamingMultiSseSupportSubscribedRunnable; if (runnable != null) { @@ -493,53 +504,64 @@ public void onSubscribe(Flow.Subscription subscription) { } @Override - public void onNext(Buffer item) { - initialize(response); - response.write(item, new Handler>() { + public void onNext(String sseEvent) { + // Set SSE headers on first event + if (response.bytesWritten() == 0) { + MultiMap headers = response.headers(); + if (headers.get(CONTENT_TYPE) == null) { + headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + } + // Additional SSE headers to prevent buffering + headers.set("Cache-Control", "no-cache"); + headers.set("X-Accel-Buffering", "no"); // Disable nginx buffering + response.setChunked(true); + + // CRITICAL: Disable write queue max size to prevent buffering + // Vert.x buffers writes by default - we need immediate flushing for SSE + response.setWriteQueueMaxSize(1); // Force immediate flush + + // Send initial SSE comment to kickstart the stream + // This forces Vert.x to send headers and start the stream immediately + response.write(": SSE stream started\n\n"); + } + + // Write SSE-formatted string to response + response.write(Buffer.buffer(sseEvent), new Handler>() { @Override public void handle(AsyncResult ar) { - onWriteDone(upstream, ar, rc); + if (ar.failed()) { + // Client disconnected or write failed - cancel upstream to stop EventConsumer + // NullAway: upstream is guaranteed non-null after onSubscribe + java.util.Objects.requireNonNull(upstream).cancel(); + rc.fail(ar.cause()); + } else { + // NullAway: upstream is guaranteed non-null after onSubscribe + java.util.Objects.requireNonNull(upstream).request(1); + } } }); } @Override public void onError(Throwable throwable) { + // Cancel upstream to stop EventConsumer when error occurs + // NullAway: upstream is guaranteed non-null after onSubscribe + java.util.Objects.requireNonNull(upstream).cancel(); rc.fail(throwable); } @Override public void onComplete() { - endOfStream(response); - } - }); - } - - private static void subscribeObject(Multi multi, RoutingContext rc) { - AtomicLong count = new AtomicLong(); - write(multi.map(new Function() { - @Override - public Buffer apply(Object o) { - if (o instanceof ReactiveRoutes.ServerSentEvent) { - ReactiveRoutes.ServerSentEvent ev = (ReactiveRoutes.ServerSentEvent) o; - long id = ev.id() != -1 ? ev.id() : count.getAndIncrement(); - String e = ev.event() == null ? "" : "event: " + ev.event() + "\n"; - return Buffer.buffer(e + "data: " + ev.data() + "\nid: " + id + "\n\n"); - } else { - return Buffer.buffer("data: " + o + "\nid: " + count.getAndIncrement() + "\n\n"); + if (response.bytesWritten() == 0) { + // No events written - still set SSE content type + MultiMap headers = response.headers(); + if (headers.get(CONTENT_TYPE) == null) { + headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + } } + response.end(); } - }), rc); - } - - private static void endOfStream(HttpServerResponse response) { - if (response.bytesWritten() == 0) { // No item - MultiMap headers = response.headers(); - if (headers.get(CONTENT_TYPE) == null) { - headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); - } - } - response.end(); + }); } } diff --git a/server-common/src/main/java/io/a2a/server/ServerCallContext.java b/server-common/src/main/java/io/a2a/server/ServerCallContext.java index ba5c20b95..c12c60c21 100644 --- a/server-common/src/main/java/io/a2a/server/ServerCallContext.java +++ b/server-common/src/main/java/io/a2a/server/ServerCallContext.java @@ -16,6 +16,7 @@ public class ServerCallContext { private final Set requestedExtensions; private final Set activatedExtensions; private final @Nullable String requestedProtocolVersion; + private volatile @Nullable Runnable eventConsumerCancelCallback; public ServerCallContext(User user, Map state, Set requestedExtensions) { this(user, state, requestedExtensions, null); @@ -64,4 +65,64 @@ public boolean isExtensionRequested(String extensionUri) { public @Nullable String getRequestedProtocolVersion() { return requestedProtocolVersion; } + + /** + * Sets the callback to be invoked when the client disconnects or the call is cancelled. + *

+ * This callback is typically used to stop the EventConsumer polling loop when a client + * disconnects from a streaming endpoint. The callback is invoked by transport layers + * (JSON-RPC over HTTP/SSE, REST over HTTP/SSE, gRPC streaming) when they detect that + * the client has closed the connection. + *

+ *

+ * Thread Safety: The callback may be invoked from any thread, depending + * on the transport implementation. Implementations should be thread-safe. + *

+ * Example Usage: + *
{@code
+     * EventConsumer consumer = new EventConsumer(queue);
+     * context.setEventConsumerCancelCallback(consumer::cancel);
+     * }
+ * + * @param callback the callback to invoke on client disconnect, or null to clear any existing callback + * @see #invokeEventConsumerCancelCallback() + */ + public void setEventConsumerCancelCallback(@Nullable Runnable callback) { + this.eventConsumerCancelCallback = callback; + } + + /** + * Invokes the EventConsumer cancel callback if one has been set. + *

+ * This method is called by transport layers when a client disconnects or cancels a + * streaming request. It triggers the callback registered via + * {@link #setEventConsumerCancelCallback(Runnable)}, which typically stops the + * EventConsumer polling loop. + *

+ *

+ * Transport-Specific Behavior: + *

+ *
    + *
  • JSON-RPC/REST over HTTP/SSE: Called from Vert.x + * {@code HttpServerResponse.closeHandler()} when the SSE connection is closed
  • + *
  • gRPC streaming: Called from gRPC + * {@code Context.CancellationListener.cancelled()} when the call is cancelled
  • + *
+ *

+ * Thread Safety: This method is thread-safe. The callback is stored + * in a volatile field and null-checked before invocation to prevent race conditions. + *

+ *

+ * If no callback has been set, this method does nothing (no-op). + *

+ * + * @see #setEventConsumerCancelCallback(Runnable) + * @see io.a2a.server.events.EventConsumer#cancel() + */ + public void invokeEventConsumerCancelCallback() { + Runnable callback = this.eventConsumerCancelCallback; + if (callback != null) { + callback.run(); + } + } } diff --git a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java index d4fe5b395..83b4575ca 100644 --- a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java +++ b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java @@ -19,6 +19,8 @@ public class EventConsumer { private static final Logger LOGGER = LoggerFactory.getLogger(EventConsumer.class); private final EventQueue queue; private volatile @Nullable Throwable error; + private volatile boolean cancelled = false; + private volatile boolean agentCompleted = false; private static final String ERROR_MSG = "Agent did not return any response"; private static final int NO_WAIT = -1; @@ -45,6 +47,14 @@ public Flow.Publisher consumeAll() { boolean completed = false; try { while (true) { + // Check if cancelled by client disconnect + if (cancelled) { + LOGGER.debug("EventConsumer detected cancellation, exiting polling loop for queue {}", System.identityHashCode(queue)); + completed = true; + tube.complete(); + return; + } + if (error != null) { completed = true; tube.fail(error); @@ -60,13 +70,32 @@ public Flow.Publisher consumeAll() { EventQueueItem item; Event event; try { + LOGGER.debug("EventConsumer polling queue {} (error={}, agentCompleted={})", + System.identityHashCode(queue), error, agentCompleted); item = queue.dequeueEventItem(QUEUE_WAIT_MILLISECONDS); if (item == null) { + LOGGER.debug("EventConsumer poll timeout (null item), agentCompleted={}", agentCompleted); + // If agent completed, a poll timeout means no more events are coming + // MainEventBusProcessor has 500ms to distribute events from MainEventBus + // If we timeout with agentCompleted=true, all events have been distributed + if (agentCompleted) { + LOGGER.debug("Agent completed and poll timeout, closing queue for graceful completion (queue={})", + System.identityHashCode(queue)); + queue.close(); + completed = true; + tube.complete(); + return; + } continue; } event = item.getEvent(); + LOGGER.debug("EventConsumer received event: {} (queue={})", + event.getClass().getSimpleName(), System.identityHashCode(queue)); + // Defensive logging for error handling if (event instanceof Throwable thr) { + LOGGER.debug("EventConsumer detected Throwable event: {} - triggering tube.fail()", + thr.getClass().getSimpleName()); tube.fail(thr); return; } @@ -138,12 +167,25 @@ private boolean isStreamTerminatingTask(Task task) { public EnhancedRunnable.DoneCallback createAgentRunnableDoneCallback() { return agentRunnable -> { + LOGGER.debug("EventConsumer: Agent done callback invoked (hasError={}, queue={})", + agentRunnable.getError() != null, System.identityHashCode(queue)); if (agentRunnable.getError() != null) { error = agentRunnable.getError(); + LOGGER.debug("EventConsumer: Set error field from agent callback"); + } else { + agentCompleted = true; + LOGGER.debug("EventConsumer: Agent completed successfully, set agentCompleted=true, will close queue after draining"); } }; } + public void cancel() { + // Set cancellation flag to stop polling loop + // Called when client disconnects without completing stream + LOGGER.debug("EventConsumer cancelled (client disconnect), stopping polling for queue {}", System.identityHashCode(queue)); + cancelled = true; + } + public void close() { // Close the queue to stop the polling loop in consumeAll() // This will cause EventQueueClosedException and exit the while(true) loop diff --git a/server-common/src/main/java/io/a2a/server/events/EventQueue.java b/server-common/src/main/java/io/a2a/server/events/EventQueue.java index a08f63084..3f1e21816 100644 --- a/server-common/src/main/java/io/a2a/server/events/EventQueue.java +++ b/server-common/src/main/java/io/a2a/server/events/EventQueue.java @@ -1,6 +1,7 @@ package io.a2a.server.events; import java.util.List; +import java.util.Objects; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; @@ -23,7 +24,7 @@ * and hierarchical queue structures via MainQueue and ChildQueue implementations. *

*

- * Use {@link #builder()} to create configured instances or extend MainQueue/ChildQueue directly. + * Use {@link #builder(MainEventBus)} to create configured instances or extend MainQueue/ChildQueue directly. *

*/ public abstract class EventQueue implements AutoCloseable { @@ -36,14 +37,6 @@ public abstract class EventQueue implements AutoCloseable { public static final int DEFAULT_QUEUE_SIZE = 1000; private final int queueSize; - /** - * Internal blocking queue for storing event queue items. - */ - protected final BlockingQueue queue = new LinkedBlockingDeque<>(); - /** - * Semaphore for backpressure control, limiting the number of pending events. - */ - protected final Semaphore semaphore; private volatile boolean closed = false; /** @@ -64,7 +57,6 @@ protected EventQueue(int queueSize) { throw new IllegalArgumentException("Queue size must be greater than 0"); } this.queueSize = queueSize; - this.semaphore = new Semaphore(queueSize, true); LOGGER.trace("Creating {} with queue size: {}", this, queueSize); } @@ -78,8 +70,8 @@ protected EventQueue(EventQueue parent) { LOGGER.trace("Creating {}, parent: {}", this, parent); } - static EventQueueBuilder builder() { - return new EventQueueBuilder(); + static EventQueueBuilder builder(MainEventBus mainEventBus) { + return new EventQueueBuilder().mainEventBus(mainEventBus); } /** @@ -93,8 +85,10 @@ public static class EventQueueBuilder { private int queueSize = DEFAULT_QUEUE_SIZE; private @Nullable EventEnqueueHook hook; private @Nullable String taskId; + private @Nullable String tempId = null; private List onCloseCallbacks = new java.util.ArrayList<>(); private @Nullable TaskStateProvider taskStateProvider; + private @Nullable MainEventBus mainEventBus; /** * Sets the maximum queue size. @@ -129,6 +123,17 @@ public EventQueueBuilder taskId(String taskId) { return this; } + /** + * Sets the temporary task ID if this queue is for a temporary task. + * + * @param tempId the temporary task ID, or null if not temporary + * @return this builder + */ + public EventQueueBuilder tempId(@Nullable String tempId) { + this.tempId = tempId; + return this; + } + /** * Adds a callback to be executed when the queue is closed. * @@ -153,17 +158,31 @@ public EventQueueBuilder taskStateProvider(TaskStateProvider taskStateProvider) return this; } + /** + * Sets the main event bus + * + * @param mainEventBus the main event bus + * @return this builder + */ + public EventQueueBuilder mainEventBus(MainEventBus mainEventBus) { + this.mainEventBus = mainEventBus; + return this; + } + /** * Builds and returns the configured EventQueue. * * @return a new MainQueue instance */ public EventQueue build() { - if (hook != null || !onCloseCallbacks.isEmpty() || taskStateProvider != null) { - return new MainQueue(queueSize, hook, taskId, onCloseCallbacks, taskStateProvider); - } else { - return new MainQueue(queueSize); + // MainEventBus is REQUIRED - enforce single architectural path + if (mainEventBus == null) { + throw new IllegalStateException("MainEventBus is required for EventQueue creation"); + } + if (taskId == null) { + throw new IllegalStateException("taskId is required for EventQueue creation"); } + return new MainQueue(queueSize, hook, taskId, tempId, onCloseCallbacks, taskStateProvider, mainEventBus); } } @@ -209,22 +228,7 @@ public void enqueueEvent(Event event) { * @param item the event queue item to enqueue * @throws RuntimeException if interrupted while waiting to acquire the semaphore */ - public void enqueueItem(EventQueueItem item) { - Event event = item.getEvent(); - if (closed) { - LOGGER.warn("Queue is closed. Event will not be enqueued. {} {}", this, event); - return; - } - // Call toString() since for errors we don't really want the full stacktrace - try { - semaphore.acquire(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Unable to acquire the semaphore to enqueue the event", e); - } - queue.add(item); - LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); - } + public abstract void enqueueItem(EventQueueItem item); /** * Creates a child queue that shares events with this queue. @@ -244,48 +248,17 @@ public void enqueueItem(EventQueueItem item) { * This method returns the full EventQueueItem wrapper, allowing callers to check * metadata like whether the event is replicated via {@link EventQueueItem#isReplicated()}. *

+ *

+ * Note: MainQueue does not support dequeue operations - only ChildQueues can be consumed. + *

* * @param waitMilliSeconds the maximum time to wait in milliseconds * @return the EventQueueItem, or null if timeout occurs * @throws EventQueueClosedException if the queue is closed and empty + * @throws UnsupportedOperationException if called on MainQueue */ - public @Nullable EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { - if (closed && queue.isEmpty()) { - LOGGER.debug("Queue is closed, and empty. Sending termination message. {}", this); - throw new EventQueueClosedException(); - } - try { - if (waitMilliSeconds <= 0) { - EventQueueItem item = queue.poll(); - if (item != null) { - Event event = item.getEvent(); - // Call toString() since for errors we don't really want the full stacktrace - LOGGER.debug("Dequeued event item (no wait) {} {}", this, event instanceof Throwable ? event.toString() : event); - semaphore.release(); - } - return item; - } - try { - LOGGER.trace("Polling queue {} (wait={}ms)", System.identityHashCode(this), waitMilliSeconds); - EventQueueItem item = queue.poll(waitMilliSeconds, TimeUnit.MILLISECONDS); - if (item != null) { - Event event = item.getEvent(); - // Call toString() since for errors we don't really want the full stacktrace - LOGGER.debug("Dequeued event item (waiting) {} {}", this, event instanceof Throwable ? event.toString() : event); - semaphore.release(); - } else { - LOGGER.trace("Dequeue timeout (null) from queue {}", System.identityHashCode(this)); - } - return item; - } catch (InterruptedException e) { - LOGGER.debug("Interrupted dequeue (waiting) {}", this); - Thread.currentThread().interrupt(); - return null; - } - } finally { - signalQueuePollerStarted(); - } - } + @Nullable + public abstract EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException; /** * Placeholder method for task completion notification. @@ -295,6 +268,18 @@ public void taskDone() { // TODO Not sure if needed yet. BlockingQueue.poll()/.take() remove the events. } + /** + * Returns the current size of the queue. + *

+ * For MainQueue: returns the number of events in-flight (in MainEventBus queue + currently being processed). + * This reflects actual capacity usage tracked by the semaphore. + * For ChildQueue: returns the size of the local consumption queue. + *

+ * + * @return the number of events currently in the queue + */ + public abstract int size(); + /** * Closes this event queue gracefully, allowing pending events to be consumed. */ @@ -348,72 +333,91 @@ protected void doClose(boolean immediate) { LOGGER.debug("Closing {} (immediate={})", this, immediate); closed = true; } - - if (immediate) { - // Immediate close: clear pending events - queue.clear(); - LOGGER.debug("Cleared queue for immediate close: {}", this); - } - // For graceful close, let the queue drain naturally through normal consumption + // Subclasses handle immediate close logic (e.g., ChildQueue clears its local queue) } static class MainQueue extends EventQueue { private final List children = new CopyOnWriteArrayList<>(); + protected final Semaphore semaphore; private final CountDownLatch pollingStartedLatch = new CountDownLatch(1); private final AtomicBoolean pollingStarted = new AtomicBoolean(false); private final @Nullable EventEnqueueHook enqueueHook; - private final @Nullable String taskId; + private volatile String taskId; // Volatile to allow switching from temp to real ID across threads private final List onCloseCallbacks; private final @Nullable TaskStateProvider taskStateProvider; - - MainQueue() { - super(); - this.enqueueHook = null; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize) { - super(queueSize); - this.enqueueHook = null; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(EventEnqueueHook hook) { - super(); - this.enqueueHook = hook; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize, EventEnqueueHook hook) { - super(queueSize); - this.enqueueHook = hook; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize, @Nullable EventEnqueueHook hook, @Nullable String taskId, List onCloseCallbacks, @Nullable TaskStateProvider taskStateProvider) { + private final MainEventBus mainEventBus; + private final @Nullable String tempId; + + MainQueue(int queueSize, + @Nullable EventEnqueueHook hook, + String taskId, + @Nullable String tempId, + List onCloseCallbacks, + @Nullable TaskStateProvider taskStateProvider, + @Nullable MainEventBus mainEventBus) { super(queueSize); + this.semaphore = new Semaphore(queueSize, true); this.enqueueHook = hook; this.taskId = taskId; + this.tempId = tempId; this.onCloseCallbacks = List.copyOf(onCloseCallbacks); // Defensive copy this.taskStateProvider = taskStateProvider; - LOGGER.debug("Created MainQueue for task {} with {} onClose callbacks and TaskStateProvider: {}", - taskId, onCloseCallbacks.size(), taskStateProvider != null); + this.mainEventBus = Objects.requireNonNull(mainEventBus, "MainEventBus is required"); + LOGGER.debug("Created MainQueue for task {} (tempId={}) with {} onClose callbacks, TaskStateProvider: {}, MainEventBus configured", + taskId, tempId, onCloseCallbacks.size(), taskStateProvider != null); } + public EventQueue tap() { ChildQueue child = new ChildQueue(this); children.add(child); return child; } + /** + * Returns the current number of child queues. + * Useful for debugging and logging event distribution. + */ + public int getChildCount() { + return children.size(); + } + + /** + * Returns the enqueue hook for replication (package-protected for MainEventBusProcessor). + */ + @Nullable EventEnqueueHook getEnqueueHook() { + return enqueueHook; + } + + /** + * Returns the temporary task ID if this queue was created with one, null otherwise. + * Package-protected for MainEventBusProcessor access. + */ + @Nullable String getTempId() { + return tempId; + } + + /** + * Updates the task ID when switching from temporary to real ID. + * Package-protected for MainEventBusProcessor access. + * @param newTaskId the real task ID to use + */ + void setTaskId(String newTaskId) { + this.taskId = newTaskId; + } + + @Override + public EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { + throw new UnsupportedOperationException("MainQueue cannot be consumed directly - use tap() to create a ChildQueue for consumption"); + } + + @Override + public int size() { + // Return total in-flight events (in MainEventBus + being processed) + // This aligns with semaphore's capacity tracking + return getQueueSize() - semaphore.availablePermits(); + } + @Override public void enqueueItem(EventQueueItem item) { // MainQueue must accept events even when closed to support: @@ -432,17 +436,48 @@ public void enqueueItem(EventQueueItem item) { throw new RuntimeException("Unable to acquire the semaphore to enqueue the event", e); } - // Add to this MainQueue's internal queue - queue.add(item); LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); - // Distribute to all ChildQueues (they will receive the event even if MainQueue is closed) - children.forEach(eq -> eq.internalEnqueueItem(item)); + // Submit to MainEventBus for centralized persistence + distribution + // MainEventBus is guaranteed non-null by constructor requirement + // Note: Replication now happens in MainEventBusProcessor AFTER persistence + + // For temp ID scenarios: if event contains a real task ID different from our temp ID, + // use the real ID when submitting to MainEventBus. This ensures subsequent events + // are submitted with the correct ID even before MainEventBusProcessor updates our taskId field. + // IMPORTANT: We don't update our taskId here - that happens in MainEventBusProcessor + // AFTER TaskManager validates the ID (e.g., checking for duplicate task IDs). + String submissionId = taskId; + if (tempId != null && taskId.equals(tempId)) { + // Not yet switched - try to extract real ID from first event + String eventTaskId = extractTaskId(event); + if (eventTaskId != null && !eventTaskId.equals(taskId)) { + // Event has a different (real) ID - use it for submission + // but keep our taskId as temp-UUID for now (until MainEventBusProcessor switches it) + submissionId = eventTaskId; + LOGGER.debug("MainQueue submitting event with real ID {} (current taskId: {})", eventTaskId, taskId); + } + } + // If already switched (tempId != null but taskId != tempId), submissionId stays as taskId (the real ID) + // This ensures subsequent events with stale temp IDs still get submitted with the correct real ID + mainEventBus.submit(submissionId, this, item); + } - // Trigger replication hook if configured - if (enqueueHook != null) { - enqueueHook.onEnqueue(item); + /** + * Extracts taskId from an event (package-private helper). + */ + @Nullable + private String extractTaskId(Event event) { + if (event instanceof io.a2a.spec.Task task) { + return task.id(); + } else if (event instanceof io.a2a.spec.TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.taskId(); + } else if (event instanceof io.a2a.spec.TaskArtifactUpdateEvent artifactUpdate) { + return artifactUpdate.taskId(); + } else if (event instanceof io.a2a.spec.Message message) { + return message.taskId(); } + return null; } @Override @@ -465,20 +500,15 @@ public void signalQueuePollerStarted() { void childClosing(ChildQueue child, boolean immediate) { children.remove(child); // Remove the closing child - // Close immediately if requested - if (immediate) { - LOGGER.debug("MainQueue closing immediately (immediate=true)"); - this.doClose(immediate); - return; - } - // If there are still children, keep queue open if (!children.isEmpty()) { LOGGER.debug("MainQueue staying open: {} children remaining", children.size()); return; } - // No children left - check if task is finalized before auto-closing + // No children left - check if task is finalized before closing + // IMPORTANT: This check must happen BEFORE the immediate flag check + // to prevent closing queues for non-final tasks (fire-and-forget, resubscription support) if (taskStateProvider != null && taskId != null) { boolean isFinalized = taskStateProvider.isTaskFinalized(taskId); if (!isFinalized) { @@ -493,6 +523,36 @@ void childClosing(ChildQueue child, boolean immediate) { this.doClose(immediate); } + /** + * Distribute event to all ChildQueues. + * Called by MainEventBusProcessor after TaskStore persistence. + */ + void distributeToChildren(EventQueueItem item) { + int childCount = children.size(); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("MainQueue[{}]: Distributing event {} to {} children", + taskId, item.getEvent().getClass().getSimpleName(), childCount); + } + children.forEach(child -> { + LOGGER.debug("MainQueue[{}]: Enqueueing event {} to child queue", + taskId, item.getEvent().getClass().getSimpleName()); + child.internalEnqueueItem(item); + }); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("MainQueue[{}]: Completed distribution of {} to {} children", + taskId, item.getEvent().getClass().getSimpleName(), childCount); + } + } + + /** + * Release the semaphore after event processing is complete. + * Called by MainEventBusProcessor in finally block to ensure release even on exceptions. + * Balances the acquire() in enqueueEvent() - protects MainEventBus throughput. + */ + void releaseSemaphore() { + semaphore.release(); + } + /** * Get the count of active child queues. * Used for testing to verify reference counting mechanism. @@ -539,10 +599,16 @@ public void close(boolean immediate) { public void close(boolean immediate, boolean notifyParent) { throw new UnsupportedOperationException("MainQueue does not support notifyParent parameter - use close(boolean) instead"); } + + String getTaskId() { + return taskId; + } } static class ChildQueue extends EventQueue { private final MainQueue parent; + private final BlockingQueue queue = new LinkedBlockingDeque<>(); + private volatile boolean immediateClose = false; public ChildQueue(MainQueue parent) { this.parent = parent; @@ -553,8 +619,69 @@ public void enqueueEvent(Event event) { parent.enqueueEvent(event); } + @Override + public void enqueueItem(EventQueueItem item) { + // ChildQueue delegates writes to parent MainQueue + parent.enqueueItem(item); + } + private void internalEnqueueItem(EventQueueItem item) { - super.enqueueItem(item); + // Internal method called by MainEventBusProcessor to add to local queue + // Note: Semaphore is managed by parent MainQueue (acquire/release), not ChildQueue + Event event = item.getEvent(); + // For graceful close: still accept events so they can be drained by EventConsumer + // For immediate close: reject events to stop distribution quickly + if (isClosed() && immediateClose) { + LOGGER.warn("ChildQueue is immediately closed. Event will not be enqueued. {} {}", this, event); + return; + } + if (!queue.offer(item)) { + LOGGER.warn("ChildQueue {} is full. Closing immediately.", this); + close(true); // immediate close + } else { + LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); + } + } + + @Override + @Nullable + public EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { + // For immediate close: exit immediately even if queue is not empty (race with MainEventBusProcessor) + // For graceful close: only exit when queue is empty (wait for all events to be consumed) + if (isClosed() && (queue.isEmpty() || immediateClose)) { + LOGGER.debug("ChildQueue is closed{}, sending termination message. {} (queueSize={})", + immediateClose ? " (immediate)" : " and empty", + this, + queue.size()); + throw new EventQueueClosedException(); + } + try { + if (waitMilliSeconds <= 0) { + EventQueueItem item = queue.poll(); + if (item != null) { + Event event = item.getEvent(); + LOGGER.debug("Dequeued event item (no wait) {} {}", this, event instanceof Throwable ? event.toString() : event); + } + return item; + } + try { + LOGGER.trace("Polling ChildQueue {} (wait={}ms)", System.identityHashCode(this), waitMilliSeconds); + EventQueueItem item = queue.poll(waitMilliSeconds, TimeUnit.MILLISECONDS); + if (item != null) { + Event event = item.getEvent(); + LOGGER.debug("Dequeued event item (waiting) {} {}", this, event instanceof Throwable ? event.toString() : event); + } else { + LOGGER.trace("Dequeue timeout (null) from ChildQueue {}", System.identityHashCode(this)); + } + return item; + } catch (InterruptedException e) { + LOGGER.debug("Interrupted dequeue (waiting) {}", this); + Thread.currentThread().interrupt(); + return null; + } + } finally { + signalQueuePollerStarted(); + } } @Override @@ -562,6 +689,12 @@ public EventQueue tap() { throw new IllegalStateException("Can only tap the main queue"); } + @Override + public int size() { + // Return size of local consumption queue + return queue.size(); + } + @Override public void awaitQueuePollerStart() throws InterruptedException { parent.awaitQueuePollerStart(); @@ -572,6 +705,19 @@ public void signalQueuePollerStarted() { parent.signalQueuePollerStarted(); } + @Override + protected void doClose(boolean immediate) { + super.doClose(immediate); // Sets closed flag + if (immediate) { + // Immediate close: clear pending events from local queue + this.immediateClose = true; + int clearedCount = queue.size(); + queue.clear(); + LOGGER.debug("Cleared {} events from ChildQueue for immediate close: {}", clearedCount, this); + } + // For graceful close, let the queue drain naturally through normal consumption + } + @Override public void close() { close(false); diff --git a/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java b/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java index e5a17e0e7..edcb55307 100644 --- a/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java +++ b/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java @@ -34,16 +34,20 @@ protected InMemoryQueueManager() { this.taskStateProvider = null; } + MainEventBus mainEventBus; + @Inject - public InMemoryQueueManager(TaskStateProvider taskStateProvider) { + public InMemoryQueueManager(TaskStateProvider taskStateProvider, MainEventBus mainEventBus) { + this.mainEventBus = mainEventBus; this.factory = new DefaultEventQueueFactory(); this.taskStateProvider = taskStateProvider; } - // For testing with custom factory - public InMemoryQueueManager(EventQueueFactory factory, TaskStateProvider taskStateProvider) { + // For testing/extensions with custom factory and MainEventBus + public InMemoryQueueManager(EventQueueFactory factory, TaskStateProvider taskStateProvider, MainEventBus mainEventBus) { this.factory = factory; this.taskStateProvider = taskStateProvider; + this.mainEventBus = mainEventBus; } @Override @@ -54,6 +58,40 @@ public void add(String taskId, EventQueue queue) { } } + @Override + public synchronized void switchKey(String oldId, String newId) { + // Check if already switched (idempotent operation) + // This check is now safe because the method is synchronized + EventQueue existingNew = queues.get(newId); + if (existingNew != null) { + EventQueue oldQueue = queues.get(oldId); + if (oldQueue == null) { + // Already switched - idempotent success + LOGGER.debug("Queue already switched from {} to {}, skipping", oldId, newId); + return; + } else { + // Different queue already at newId - error + throw new TaskQueueExistsException(); + } + } + + // Normal path: move queue from oldId to newId + EventQueue queue = queues.remove(oldId); + if (queue == null) { + throw new IllegalStateException("No queue found for old ID: " + oldId); + } + + EventQueue existing = queues.putIfAbsent(newId, queue); + if (existing != null) { + // Rollback - put old one back + queues.putIfAbsent(oldId, queue); + throw new TaskQueueExistsException(); + } + + LOGGER.debug("Switched queue {} from temp ID {} to real task ID {}", + System.identityHashCode(queue), oldId, newId); + } + @Override public @Nullable EventQueue get(String taskId) { return queues.get(taskId); @@ -78,7 +116,12 @@ public void close(String taskId) { @Override public EventQueue createOrTap(String taskId) { - LOGGER.debug("createOrTap called for task {}, current map size: {}", taskId, queues.size()); + return createOrTap(taskId, null); + } + + @Override + public EventQueue createOrTap(String taskId, @Nullable String tempId) { + LOGGER.debug("createOrTap called for task {} (tempId={}), current map size: {}", taskId, tempId, queues.size()); EventQueue existing = queues.get(taskId); // Lazy cleanup: only remove closed queues if task is finalized @@ -101,8 +144,8 @@ public EventQueue createOrTap(String taskId) { EventQueue newQueue = null; if (existing == null) { // Use builder pattern for cleaner queue creation - // Use the new taskId-aware builder method if available - newQueue = factory.builder(taskId).build(); + // Pass tempId to the builder + newQueue = factory.builder(taskId).tempId(tempId).build(); // Make sure an existing queue has not been added in the meantime existing = queues.putIfAbsent(taskId, newQueue); } @@ -114,8 +157,8 @@ public EventQueue createOrTap(String taskId) { EventQueue result = main.tap(); // Always return ChildQueue if (existing == null) { - LOGGER.debug("Created new MainQueue {} for task {}, returning ChildQueue {} (map size: {})", - System.identityHashCode(main), taskId, System.identityHashCode(result), queues.size()); + LOGGER.debug("Created new MainQueue {} for task {} (tempId={}), returning ChildQueue {} (map size: {})", + System.identityHashCode(main), taskId, tempId, System.identityHashCode(result), queues.size()); } else { LOGGER.debug("Tapped existing MainQueue {} -> ChildQueue {} for task {}", System.identityHashCode(main), System.identityHashCode(result), taskId); @@ -128,6 +171,12 @@ public void awaitQueuePollerStart(EventQueue eventQueue) throws InterruptedExcep eventQueue.awaitQueuePollerStart(); } + @Override + public EventQueue.EventQueueBuilder getEventQueueBuilder(String taskId) { + // Use the factory to ensure proper configuration (MainEventBus, callbacks, etc.) + return factory.builder(taskId); + } + @Override public int getActiveChildQueueCount(String taskId) { EventQueue queue = queues.get(taskId); @@ -142,6 +191,14 @@ public int getActiveChildQueueCount(String taskId) { return -1; } + @Override + public EventQueue.EventQueueBuilder createBaseEventQueueBuilder(String taskId) { + return EventQueue.builder(mainEventBus) + .taskId(taskId) + .addOnCloseCallback(getCleanupCallback(taskId)) + .taskStateProvider(taskStateProvider); + } + /** * Get the cleanup callback that removes a queue from the map when it closes. * This is exposed so that subclasses (like ReplicatedQueueManager) can reuse @@ -181,11 +238,8 @@ public Runnable getCleanupCallback(String taskId) { private class DefaultEventQueueFactory implements EventQueueFactory { @Override public EventQueue.EventQueueBuilder builder(String taskId) { - // Return builder with callback that removes queue from map when closed - return EventQueue.builder() - .taskId(taskId) - .addOnCloseCallback(getCleanupCallback(taskId)) - .taskStateProvider(taskStateProvider); + // Delegate to the base builder creation method + return createBaseEventQueueBuilder(taskId); } } } diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBus.java b/server-common/src/main/java/io/a2a/server/events/MainEventBus.java new file mode 100644 index 000000000..90080b1e2 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBus.java @@ -0,0 +1,42 @@ +package io.a2a.server.events; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +public class MainEventBus { + private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBus.class); + private final BlockingQueue queue; + + public MainEventBus() { + this.queue = new LinkedBlockingDeque<>(); + } + + void submit(String taskId, EventQueue.MainQueue mainQueue, EventQueueItem item) { + try { + queue.put(new MainEventBusContext(taskId, mainQueue, item)); + LOGGER.debug("Submitted event for task {} to MainEventBus (queue size: {})", + taskId, queue.size()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted submitting to MainEventBus", e); + } + } + + MainEventBusContext take() throws InterruptedException { + LOGGER.debug("MainEventBus: Waiting to take event (current queue size: {})...", queue.size()); + MainEventBusContext context = queue.take(); + LOGGER.debug("MainEventBus: Took event for task {} (remaining queue size: {})", + context.taskId(), queue.size()); + return context; + } + + public int size() { + return queue.size(); + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java new file mode 100644 index 000000000..292a60f21 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java @@ -0,0 +1,11 @@ +package io.a2a.server.events; + +import java.util.Objects; + +record MainEventBusContext(String taskId, EventQueue.MainQueue eventQueue, EventQueueItem eventQueueItem) { + MainEventBusContext { + Objects.requireNonNull(taskId, "taskId cannot be null"); + Objects.requireNonNull(eventQueue, "eventQueue cannot be null"); + Objects.requireNonNull(eventQueueItem, "eventQueueItem cannot be null"); + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java new file mode 100644 index 000000000..a511fedf1 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java @@ -0,0 +1,426 @@ +package io.a2a.server.events; + +import java.util.concurrent.CompletableFuture; + +import jakarta.annotation.Nullable; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import io.a2a.server.tasks.PushNotificationSender; +import io.a2a.server.tasks.TaskManager; +import io.a2a.server.tasks.TaskStore; +import io.a2a.spec.A2AServerException; +import io.a2a.spec.Event; +import io.a2a.spec.InternalError; +import io.a2a.spec.Message; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatusUpdateEvent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Background processor for the MainEventBus. + *

+ * This processor runs in a dedicated background thread, consuming events from the MainEventBus + * and performing two critical operations in order: + *

+ *
    + *
  1. Update TaskStore with event data (persistence FIRST)
  2. + *
  3. Distribute event to ChildQueues (clients see it AFTER persistence)
  4. + *
+ *

+ * This architecture ensures clients never receive events before they're persisted, + * eliminating race conditions and enabling reliable event replay. + *

+ *

+ * Note: This bean is eagerly initialized by {@link MainEventBusProcessorInitializer} + * to ensure the background thread starts automatically when the application starts. + *

+ */ +@ApplicationScoped +public class MainEventBusProcessor implements Runnable { + private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBusProcessor.class); + + /** + * Callback for testing synchronization with async event processing. + * Default is NOOP to avoid null checks in production code. + * Tests can inject their own callback via setCallback(). + */ + private volatile MainEventBusProcessorCallback callback = MainEventBusProcessorCallback.NOOP; + + /** + * Optional executor for push notifications. + * If null, uses default ForkJoinPool (async). + * Tests can inject a synchronous executor to ensure deterministic ordering. + */ + private volatile @Nullable java.util.concurrent.Executor pushNotificationExecutor = null; + + private final MainEventBus eventBus; + + private final TaskStore taskStore; + + private final PushNotificationSender pushSender; + + private final QueueManager queueManager; + + private volatile boolean running = true; + private @Nullable Thread processorThread; + + @Inject + public MainEventBusProcessor(MainEventBus eventBus, TaskStore taskStore, PushNotificationSender pushSender, QueueManager queueManager) { + this.eventBus = eventBus; + this.taskStore = taskStore; + this.pushSender = pushSender; + this.queueManager = queueManager; + } + + /** + * Set a callback for testing synchronization with async event processing. + *

+ * This is primarily intended for tests that need to wait for event processing to complete. + * Pass null to reset to the default NOOP callback. + *

+ * + * @param callback the callback to invoke during event processing, or null for NOOP + */ + public void setCallback(MainEventBusProcessorCallback callback) { + this.callback = callback != null ? callback : MainEventBusProcessorCallback.NOOP; + } + + /** + * Set a custom executor for push notifications (primarily for testing). + *

+ * By default, push notifications are sent asynchronously using CompletableFuture.runAsync() + * with the default ForkJoinPool. For tests that need deterministic ordering of push + * notifications, inject a synchronous executor that runs tasks immediately on the calling thread. + *

+ * Example synchronous executor for tests: + *
{@code
+     * Executor syncExecutor = Runnable::run;
+     * mainEventBusProcessor.setPushNotificationExecutor(syncExecutor);
+     * }
+ * + * @param executor the executor to use for push notifications, or null to use default ForkJoinPool + */ + public void setPushNotificationExecutor(java.util.concurrent.Executor executor) { + this.pushNotificationExecutor = executor; + } + + @PostConstruct + void start() { + processorThread = new Thread(this, "MainEventBusProcessor"); + processorThread.setDaemon(true); // Allow JVM to exit even if this thread is running + processorThread.start(); + LOGGER.info("MainEventBusProcessor started"); + } + + /** + * No-op method to force CDI proxy resolution and ensure @PostConstruct has been called. + * Called by MainEventBusProcessorInitializer during application startup. + */ + public void ensureStarted() { + // Method intentionally empty - just forces proxy resolution + } + + @PreDestroy + void stop() { + LOGGER.info("MainEventBusProcessor stopping..."); + running = false; + if (processorThread != null) { + processorThread.interrupt(); + try { + long start = System.currentTimeMillis(); + processorThread.join(5000); // Wait up to 5 seconds + long elapsed = System.currentTimeMillis() - start; + LOGGER.info("MainEventBusProcessor thread stopped in {}ms", elapsed); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOGGER.warn("Interrupted while waiting for MainEventBusProcessor thread to stop"); + } + } + LOGGER.info("MainEventBusProcessor stopped"); + } + + @Override + public void run() { + LOGGER.info("MainEventBusProcessor processing loop started"); + while (running) { + try { + LOGGER.debug("MainEventBusProcessor: Waiting for event from MainEventBus..."); + MainEventBusContext context = eventBus.take(); + LOGGER.debug("MainEventBusProcessor: Retrieved event for task {} from MainEventBus", + context.taskId()); + processEvent(context); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOGGER.info("MainEventBusProcessor interrupted, shutting down"); + break; + } catch (Exception e) { + LOGGER.error("Error processing event from MainEventBus", e); + // Continue processing despite errors + } + } + LOGGER.info("MainEventBusProcessor processing loop ended"); + } + + private void processEvent(MainEventBusContext context) { + String taskId = context.taskId(); + Event event = context.eventQueueItem().getEvent(); + // MainEventBus.submit() guarantees this is always a MainQueue + EventQueue.MainQueue mainQueue = (EventQueue.MainQueue) context.eventQueue(); + + // Determine if this is a temp ID scenario + // If MainQueue was created with a tempId, then isTempId is true ONLY for events + // BEFORE the queue key is switched. After switching, mainQueue.taskId != tempId, + // so isTempId should be false for subsequent events. + String tempId = mainQueue.getTempId(); + boolean isTempId = (tempId != null && mainQueue.getTaskId().equals(tempId)); + LOGGER.debug("MainEventBusProcessor: Processing event for task {} (tempId={}, isTempId={}): {}", + taskId, tempId, isTempId, event.getClass().getSimpleName()); + + Event eventToDistribute = null; + try { + // Step 1: Update TaskStore FIRST (persistence before clients see it) + // If this throws, we distribute an error to ensure "persist before client visibility" + + try { + updateTaskStore(taskId, event, isTempId, mainQueue); + + eventToDistribute = event; // Success - distribute original event + + // Trigger replication AFTER successful persistence (moved from MainQueue.enqueueEvent) + // This ensures replicated events have real task IDs, not temp-UUIDs + EventEnqueueHook hook = mainQueue.getEnqueueHook(); + if (hook != null) { + LOGGER.debug("Triggering replication hook for task {} after successful persistence", taskId); + hook.onEnqueue(context.eventQueueItem()); + } + } catch (InternalError e) { + // Persistence failed - create error event to distribute instead + LOGGER.error("Failed to persist event for task {}, distributing error to clients", taskId, e); + String errorMessage = "Failed to persist event: " + e.getMessage(); + eventToDistribute = e; + } catch (Exception e) { + LOGGER.error("Failed to persist event for task {}, distributing error to clients", taskId, e); + String errorMessage = "Failed to persist event: " + e.getMessage(); + eventToDistribute = new InternalError(errorMessage); + } + + // Step 2: Send push notification AFTER successful persistence + if (eventToDistribute == event) { + // Capture task state immediately after persistence, before going async + // This ensures we send the task as it existed when THIS event was processed, + // not whatever state might exist later when the async callback executes + Task taskSnapshot = taskStore.get(taskId); + if (taskSnapshot != null) { + sendPushNotification(taskId, taskSnapshot); + } else { + LOGGER.warn("Task {} not found in TaskStore after successful persistence, skipping push notification", taskId); + } + } + + // Step 3: Then distribute to ChildQueues (clients see either event or error AFTER persistence attempt) + if (eventToDistribute == null) { + LOGGER.error("MainEventBusProcessor: eventToDistribute is NULL for task {} - this should never happen!", taskId); + eventToDistribute = new InternalError("Internal error: event processing failed"); + } + + int childCount = mainQueue.getChildCount(); + LOGGER.debug("MainEventBusProcessor: Distributing {} to {} children for task {}", + eventToDistribute.getClass().getSimpleName(), childCount, taskId); + // Create new EventQueueItem with the event to distribute (original or error) + EventQueueItem itemToDistribute = new LocalEventQueueItem(eventToDistribute); + mainQueue.distributeToChildren(itemToDistribute); + LOGGER.debug("MainEventBusProcessor: Distributed {} to {} children for task {}", + eventToDistribute.getClass().getSimpleName(), childCount, taskId); + + LOGGER.debug("MainEventBusProcessor: Completed processing event for task {}", taskId); + + } finally { + try { + // Step 4: Notify callback after all processing is complete + // Call callback with the distributed event (original or error) + if (eventToDistribute != null) { + callback.onEventProcessed(taskId, eventToDistribute); + + // Step 5: If this is a final event, notify task finalization + // Only for successful persistence (not for errors) + if (eventToDistribute == event && isFinalEvent(event)) { + callback.onTaskFinalized(taskId); + } + } + } finally { + // ALWAYS release semaphore, even if processing fails + // Balances the acquire() in MainQueue.enqueueEvent() + mainQueue.releaseSemaphore(); + } + } + } + + /** + * Updates TaskStore using TaskManager.process(). + *

+ * Creates a temporary TaskManager instance for this event and delegates to its process() method, + * which handles all event types (Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent). + * This leverages existing TaskManager logic for status updates, artifact appending, message history, etc. + *

+ *

+ * If persistence fails, the exception is propagated to processEvent() which distributes an + * InternalError to clients instead of the original event, ensuring "persist before visibility". + * See Gemini's comment: https://github.com/a2aproject/a2a-java/pull/515#discussion_r2604621833 + *

+ * + * @param taskId the task ID (may be temporary) + * @param event the event to persist + * @param isTempId whether the task ID is temporary (from MainQueue.tempId) + * @param mainQueue the main queue (for updating taskId after ID switch) + * @throws InternalError if persistence fails + */ + private void updateTaskStore(String taskId, Event event, boolean isTempId, EventQueue.MainQueue mainQueue) throws InternalError { + try { + // Extract contextId from event (all relevant events have it) + String contextId = extractContextId(event); + + String eventTaskId = null; + if (isTempId) { + eventTaskId = extractTaskId(event); + LOGGER.debug("Temp ID scenario: taskId={}, event={}", taskId, eventTaskId); + } + + // For temp ID scenarios (before switch), use the MainQueue's current taskId (temp ID) for TaskManager. + // This ensures duplicate detection works when switching from temp to real ID. + // After switch (isTempId=false), use the submission taskId (real ID). + String taskIdForManager = isTempId ? mainQueue.getTaskId() : taskId; + + // Create temporary TaskManager instance for this event + // Use taskIdForManager (temp ID if switching, real ID otherwise) + // isTempId allows TaskManager.checkIdsAndUpdateIfNecessary() to switch to real ID + TaskManager taskManager = new TaskManager(taskIdForManager, contextId, taskStore, null, isTempId); + + // Use TaskManager.process() - handles all event types with existing logic + taskManager.process(event); + LOGGER.debug("TaskStore updated via TaskManager.process() for task {}: {}", + taskId, event.getClass().getSimpleName()); + + // If this was a temp ID scenario and the event had a different task ID, + // then TaskManager switched from temp to real ID. Update the queue key and MainQueue taskId. + if (isTempId && eventTaskId != null && !eventTaskId.equals(taskIdForManager)) { + LOGGER.debug("Switching queue key from temp {} to real {}", taskIdForManager, eventTaskId); + queueManager.switchKey(taskIdForManager, eventTaskId); + // Also update the MainQueue's taskId so subsequent events use the real ID + mainQueue.setTaskId(eventTaskId); + } + } catch (InternalError e) { + LOGGER.error("Error updating TaskStore via TaskManager for task {}", taskId, e); + // Rethrow to prevent distributing unpersisted event to clients + throw e; + } catch (Exception e) { + LOGGER.error("Unexpected error updating TaskStore for task {}", taskId, e); + // Rethrow to prevent distributing unpersisted event to clients + throw new InternalError("TaskStore persistence failed: " + e.getMessage()); + } + } + + /** + * Sends push notification for the task AFTER persistence. + *

+ * This is called after updateTaskStore() to ensure the notification contains + * the latest persisted state, avoiding race conditions. + *

+ *

+ * CRITICAL: Push notifications are sent asynchronously in the background + * to avoid blocking event distribution to ChildQueues. The 83ms overhead from + * PushNotificationSender.sendNotification() was causing streaming delays. + *

+ *

+ * IMPORTANT: The task parameter is a snapshot captured immediately after + * persistence. This ensures we send the task state as it existed when THIS event + * was processed, not whatever state might exist in TaskStore when the async + * callback executes (subsequent events may have already updated the store). + *

+ *

+ * NOTE: Tests can inject a synchronous executor via setPushNotificationExecutor() + * to ensure deterministic ordering of push notifications in the test environment. + *

+ * + * @param taskId the task ID + * @param task the task snapshot to send (captured immediately after persistence) + */ + private void sendPushNotification(String taskId, Task task) { + Runnable pushTask = () -> { + try { + if (task != null) { + LOGGER.debug("Sending push notification for task {}", taskId); + pushSender.sendNotification(task); + } else { + LOGGER.debug("Skipping push notification - task snapshot is null for task {}", taskId); + } + } catch (Exception e) { + LOGGER.error("Error sending push notification for task {}", taskId, e); + // Don't rethrow - push notifications are best-effort + } + }; + + // Use custom executor if set (for tests), otherwise use default ForkJoinPool (async) + if (pushNotificationExecutor != null) { + pushNotificationExecutor.execute(pushTask); + } else { + CompletableFuture.runAsync(pushTask); + } + } + + /** + * Extracts contextId from an event. + * Returns null if the event type doesn't have a contextId (e.g., Message). + */ + @Nullable + private String extractContextId(Event event) { + if (event instanceof Task task) { + return task.contextId(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.contextId(); + } else if (event instanceof TaskArtifactUpdateEvent artifactUpdate) { + return artifactUpdate.contextId(); + } + // Message and other events don't have contextId + return null; + } + + /** + * Extracts taskId from an event. + * Returns null if the event type doesn't have a taskId (e.g., Throwable). + */ + @Nullable + private String extractTaskId(Event event) { + if (event instanceof Task task) { + return task.id(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.taskId(); + } else if (event instanceof TaskArtifactUpdateEvent artifactUpdate) { + return artifactUpdate.taskId(); + } else if (event instanceof Message message) { + return message.taskId(); + } + // Other events (Throwable, etc.) don't have taskId + return null; + } + + /** + * Checks if an event represents a final task state. + * + * @param event the event to check + * @return true if the event represents a final state (COMPLETED, FAILED, CANCELED, REJECTED, UNKNOWN) + */ + private boolean isFinalEvent(Event event) { + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); + } + return false; + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java new file mode 100644 index 000000000..b0a9adbce --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java @@ -0,0 +1,66 @@ +package io.a2a.server.events; + +import io.a2a.spec.Event; + +/** + * Callback interface for MainEventBusProcessor events. + *

+ * This interface is primarily intended for testing, allowing tests to synchronize + * with the asynchronous MainEventBusProcessor. Production code should not rely on this. + *

+ * Usage in tests: + *
+ * {@code
+ * @Inject
+ * MainEventBusProcessor processor;
+ *
+ * @BeforeEach
+ * void setUp() {
+ *     CountDownLatch latch = new CountDownLatch(3);
+ *     processor.setCallback(new MainEventBusProcessorCallback() {
+ *         public void onEventProcessed(String taskId, Event event) {
+ *             latch.countDown();
+ *         }
+ *     });
+ * }
+ *
+ * @AfterEach
+ * void tearDown() {
+ *     processor.setCallback(null); // Reset to NOOP
+ * }
+ * }
+ * 
+ */ +public interface MainEventBusProcessorCallback { + + /** + * Called after an event has been fully processed (persisted, notification sent, distributed to children). + * + * @param taskId the task ID + * @param event the event that was processed + */ + void onEventProcessed(String taskId, Event event); + + /** + * Called when a task reaches a final state (COMPLETED, FAILED, CANCELED, REJECTED). + * + * @param taskId the task ID that was finalized + */ + void onTaskFinalized(String taskId); + + /** + * No-op implementation that does nothing. + * Used as the default callback to avoid null checks. + */ + MainEventBusProcessorCallback NOOP = new MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + // No-op + } + + @Override + public void onTaskFinalized(String taskId) { + // No-op + } + }; +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java new file mode 100644 index 000000000..ba4b300be --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java @@ -0,0 +1,43 @@ +package io.a2a.server.events; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.Initialized; +import jakarta.enterprise.event.Observes; +import jakarta.inject.Inject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Portable CDI initializer for MainEventBusProcessor. + *

+ * This bean observes the ApplicationScoped initialization event and injects + * MainEventBusProcessor, which triggers its eager creation and starts the background thread. + *

+ *

+ * This approach is portable across all Jakarta CDI implementations (Weld, OpenWebBeans, Quarkus, etc.) + * and ensures MainEventBusProcessor starts automatically when the application starts. + *

+ */ +@ApplicationScoped +public class MainEventBusProcessorInitializer { + private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBusProcessorInitializer.class); + + @Inject + MainEventBusProcessor processor; + + /** + * Observes ApplicationScoped initialization to force eager creation of MainEventBusProcessor. + * The injection of MainEventBusProcessor in this bean triggers its creation, and calling + * ensureStarted() forces the CDI proxy to be resolved, which ensures @PostConstruct has been + * called and the background thread is running. + */ + void onStart(@Observes @Initialized(ApplicationScoped.class) Object event) { + if (processor != null) { + // Force proxy resolution to ensure @PostConstruct has been called + processor.ensureStarted(); + LOGGER.info("MainEventBusProcessor initialized and started"); + } else { + LOGGER.error("MainEventBusProcessor is null - initialization failed!"); + } + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/QueueManager.java b/server-common/src/main/java/io/a2a/server/events/QueueManager.java index 01e754fcb..2de0feb73 100644 --- a/server-common/src/main/java/io/a2a/server/events/QueueManager.java +++ b/server-common/src/main/java/io/a2a/server/events/QueueManager.java @@ -96,6 +96,25 @@ public interface QueueManager { */ void add(String taskId, EventQueue queue); + /** + * Switches a queue from an old key to a new key atomically. + *

+ * Used when transitioning from a temporary task ID (e.g., "temp-UUID") to the real task ID + * when the Task event arrives with the actual task.id. This prevents duplicate map entries + * and ensures clean queue lifecycle management. + *

+ *

+ * The operation is atomic: removes the old key and adds the new key. If the new key already + * exists, the operation is rolled back by restoring the old key. + *

+ * + * @param oldId the temporary/old task identifier to remove + * @param newId the real/new task identifier to add + * @throws IllegalStateException if no queue exists for oldId + * @throws TaskQueueExistsException if a queue already exists for newId + */ + void switchKey(String oldId, String newId); + /** * Retrieves the MainQueue for a task, if it exists. *

@@ -148,7 +167,24 @@ public interface QueueManager { * @param taskId the task identifier * @return a MainQueue (if new task) or ChildQueue (if tapping existing) */ - EventQueue createOrTap(String taskId); + default EventQueue createOrTap(String taskId) { + return createOrTap(taskId, null); + } + + /** + * Creates a MainQueue if none exists, or taps the existing queue to create a ChildQueue. + *

+ * This is the primary method used by {@link io.a2a.server.requesthandlers.DefaultRequestHandler}: + *

    + *
  • New task: Creates and returns a MainQueue with tempId
  • + *
  • Resubscription: Taps existing MainQueue and returns a ChildQueue
  • + *
+ * + * @param taskId the task identifier (may be temporary) + * @param tempId the temporary task ID if taskId is temporary, null otherwise + * @return a MainQueue (if new task) or ChildQueue (if tapping existing) + */ + EventQueue createOrTap(String taskId, @Nullable String tempId); /** * Waits for the queue's consumer polling to start. @@ -177,7 +213,31 @@ public interface QueueManager { * @return a builder for creating event queues */ default EventQueue.EventQueueBuilder getEventQueueBuilder(String taskId) { - return EventQueue.builder(); + throw new UnsupportedOperationException( + "QueueManager implementations must override getEventQueueBuilder() to provide MainEventBus" + ); + } + + /** + * Creates a base EventQueueBuilder with standard configuration for this QueueManager. + * This method provides the foundation for creating event queues with proper configuration + * (MainEventBus, TaskStateProvider, cleanup callbacks, etc.). + *

+ * QueueManager implementations that use custom factories can call this method directly + * to get the base builder without going through the factory (which could cause infinite + * recursion if the factory delegates back to getEventQueueBuilder()). + *

+ *

+ * Callers can then add additional configuration (hooks, callbacks) before building the queue. + *

+ * + * @param taskId the task ID for the queue + * @return a builder with base configuration specific to this QueueManager implementation + */ + default EventQueue.EventQueueBuilder createBaseEventQueueBuilder(String taskId) { + throw new UnsupportedOperationException( + "QueueManager implementations must override createBaseEventQueueBuilder() to provide MainEventBus" + ); } /** diff --git a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java index 002acbafd..c96a461f5 100644 --- a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java +++ b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -11,13 +11,13 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.Flow; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; @@ -35,6 +35,8 @@ import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueue; import io.a2a.server.events.EventQueueItem; +import io.a2a.server.events.MainEventBusProcessor; +import io.a2a.server.events.MainEventBusProcessorCallback; import io.a2a.server.events.QueueManager; import io.a2a.server.events.TaskQueueExistsException; import io.a2a.server.tasks.PushNotificationConfigStore; @@ -42,6 +44,7 @@ import io.a2a.server.tasks.ResultAggregator; import io.a2a.server.tasks.TaskManager; import io.a2a.server.tasks.TaskStore; +import io.a2a.server.util.async.EventConsumerExecutorProducer.EventConsumerExecutor; import io.a2a.server.util.async.Internal; import io.a2a.spec.A2AError; import io.a2a.spec.DeleteTaskPushNotificationConfigParams; @@ -64,6 +67,7 @@ import io.a2a.spec.TaskPushNotificationConfig; import io.a2a.spec.TaskQueryParams; import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.UnsupportedOperationError; import org.jspecify.annotations.NonNull; import org.jspecify.annotations.Nullable; @@ -122,7 +126,6 @@ *
  • {@link EventConsumer} polls and processes events on Vert.x worker thread
  • *
  • Queue closes automatically on final event (COMPLETED/FAILED/CANCELED)
  • *
  • Cleanup waits for both agent execution AND event consumption to complete
  • - *
  • Background tasks tracked via {@link #trackBackgroundTask(CompletableFuture)}
  • * * *

    Threading Model

    @@ -179,6 +182,13 @@ public class DefaultRequestHandler implements RequestHandler { private static final Logger LOGGER = LoggerFactory.getLogger(DefaultRequestHandler.class); + /** + * Separate logger for thread statistics diagnostic logging. + * This allows independent control of verbose thread pool monitoring without affecting + * general request handler logging. Enable with: logging.level.io.a2a.server.diagnostics.ThreadStats=DEBUG + */ + private static final Logger THREAD_STATS_LOGGER = LoggerFactory.getLogger("io.a2a.server.diagnostics.ThreadStats"); + private static final String A2A_BLOCKING_AGENT_TIMEOUT_SECONDS = "a2a.blocking.agent.timeout.seconds"; private static final String A2A_BLOCKING_CONSUMPTION_TIMEOUT_SECONDS = "a2a.blocking.consumption.timeout.seconds"; @@ -214,13 +224,19 @@ public class DefaultRequestHandler implements RequestHandler { private TaskStore taskStore; private QueueManager queueManager; private PushNotificationConfigStore pushConfigStore; - private PushNotificationSender pushSender; + private MainEventBusProcessor mainEventBusProcessor; private Supplier requestContextBuilder; private final ConcurrentMap> runningAgents = new ConcurrentHashMap<>(); - private final Set> backgroundTasks = ConcurrentHashMap.newKeySet(); + + /** + * Map of taskId → CountDownLatch for tasks waiting for finalization. + * Used by blocking calls to wait for MainEventBusProcessor to persist final state to TaskStore. + */ + private final ConcurrentMap pendingFinalizations = new ConcurrentHashMap<>(); private Executor executor; + private Executor eventConsumerExecutor; /** * No-args constructor for CDI proxy creation. @@ -234,21 +250,25 @@ protected DefaultRequestHandler() { this.taskStore = null; this.queueManager = null; this.pushConfigStore = null; - this.pushSender = null; + this.mainEventBusProcessor = null; this.requestContextBuilder = null; this.executor = null; + this.eventConsumerExecutor = null; } @Inject public DefaultRequestHandler(AgentExecutor agentExecutor, TaskStore taskStore, QueueManager queueManager, PushNotificationConfigStore pushConfigStore, - PushNotificationSender pushSender, @Internal Executor executor) { + MainEventBusProcessor mainEventBusProcessor, + @Internal Executor executor, + @EventConsumerExecutor Executor eventConsumerExecutor) { this.agentExecutor = agentExecutor; this.taskStore = taskStore; this.queueManager = queueManager; this.pushConfigStore = pushConfigStore; - this.pushSender = pushSender; + this.mainEventBusProcessor = mainEventBusProcessor; this.executor = executor; + this.eventConsumerExecutor = eventConsumerExecutor; // TODO In Python this is also a constructor parameter defaulting to this SimpleRequestContextBuilder // implementation if the parameter is null. Skip that for now, since otherwise I get CDI errors, and // I am unsure about the correct scope. @@ -262,6 +282,97 @@ void initConfig() { configProvider.getValue(A2A_BLOCKING_AGENT_TIMEOUT_SECONDS)); consumptionCompletionTimeoutSeconds = Integer.parseInt( configProvider.getValue(A2A_BLOCKING_CONSUMPTION_TIMEOUT_SECONDS)); + + // Register permanent callback for task finalization events + registerFinalizationCallback(); + } + + /** + * Register the permanent callback for task finalization events. + *

    + * This single callback multiplexes for all concurrent requests via the + * {@link #pendingFinalizations} map. Called by both {@link #initConfig()} + * (@PostConstruct for CDI) and {@link #create(AgentExecutor, TaskStore, QueueManager, PushNotificationConfigStore, MainEventBusProcessor, Executor, Executor)} + * (static factory for tests). + *

    + */ + private void registerFinalizationCallback() { + mainEventBusProcessor.setCallback(new MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + // Not used for task finalization wait + } + + @Override + public void onTaskFinalized(String taskId) { + // Signal any blocking call waiting for this task to finalize + CountDownLatch latch = pendingFinalizations.get(taskId); + if (latch != null) { + latch.countDown(); + pendingFinalizations.remove(taskId); + LOGGER.debug("Task {} finalization signaled to waiting thread", taskId); + } + } + }); + } + + /** + * Wait for MainEventBusProcessor to finalize a task (reach final state and persist to TaskStore). + *

    + * This method is used by blocking calls to ensure TaskStore is fully updated before returning + * to the client. It registers this task in the {@link #pendingFinalizations} map, which is + * monitored by the permanent callback registered in {@link #initConfig()}. + *

    + *

    + * Why this is needed: Events flow through MainEventBus → MainEventBusProcessor → + * TaskStore persistence. The consumption future completing only means ChildQueue is empty, + * NOT that MainEventBusProcessor has finished persisting to TaskStore. This creates a race + * condition where blocking calls might read stale state from TaskStore. + *

    + *

    + * Concurrency: Uses a single permanent callback that multiplexes for all concurrent + * requests via {@link #pendingFinalizations} map. This avoids callback overwrite issues + * when multiple requests execute simultaneously. + *

    + *

    + * Race Condition Prevention: Checks TaskStore FIRST before waiting. If the task is + * already finalized, returns immediately. This prevents waiting forever when the callback + * fires before the latch is registered in the map. + *

    + * + * @param taskId the task ID to wait for + * @param timeoutSeconds maximum time to wait for finalization + * @throws InterruptedException if interrupted while waiting + * @throws TimeoutException if task doesn't finalize within timeout + */ + private void waitForTaskFinalization(String taskId, int timeoutSeconds) + throws InterruptedException, TimeoutException { + // CRITICAL: Check TaskStore FIRST to avoid race condition where callback fires + // before latch is registered. If task is finalized, return immediately. + Task task = taskStore.get(taskId); + if (task != null && task.status() != null && task.status().state() != null + && task.status().state().isFinal()) { + LOGGER.debug("Task {} is finalized in TaskStore, skipping wait", taskId); + return; + } + + CountDownLatch finalizationLatch = new CountDownLatch(1); + + // Register this task's latch in the map + // The permanent callback (registered in initConfig) will signal it + pendingFinalizations.put(taskId, finalizationLatch); + + try { + // Wait for the callback to fire + if (!finalizationLatch.await(timeoutSeconds, SECONDS)) { + throw new TimeoutException( + String.format("Task %s finalization timeout after %d seconds", taskId, timeoutSeconds)); + } + LOGGER.debug("Task {} finalized and persisted to TaskStore", taskId); + } finally { + // Always remove from map to avoid memory leaks + pendingFinalizations.remove(taskId); + } } /** @@ -269,11 +380,17 @@ void initConfig() { */ public static DefaultRequestHandler create(AgentExecutor agentExecutor, TaskStore taskStore, QueueManager queueManager, PushNotificationConfigStore pushConfigStore, - PushNotificationSender pushSender, Executor executor) { + MainEventBusProcessor mainEventBusProcessor, + Executor executor, Executor eventConsumerExecutor) { DefaultRequestHandler handler = - new DefaultRequestHandler(agentExecutor, taskStore, queueManager, pushConfigStore, pushSender, executor); + new DefaultRequestHandler(agentExecutor, taskStore, queueManager, pushConfigStore, + mainEventBusProcessor, executor, eventConsumerExecutor); handler.agentCompletionTimeoutSeconds = 5; handler.consumptionCompletionTimeoutSeconds = 2; + + // Register permanent callback for task finalization (normally done in @PostConstruct) + handler.registerFinalizationCallback(); + return handler; } @@ -357,14 +474,12 @@ public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws task.id(), task.contextId(), taskStore, - null); + null, + false); // Not a temp ID - task already exists - ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor); + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor, eventConsumerExecutor); - EventQueue queue = queueManager.tap(task.id()); - if (queue == null) { - queue = queueManager.getEventQueueBuilder(task.id()).build(); - } + EventQueue queue = queueManager.createOrTap(task.id()); agentExecutor.cancel( requestContextBuilder.get() .setTaskId(task.id()) @@ -395,28 +510,41 @@ public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws @Override public EventKind onMessageSend(MessageSendParams params, ServerCallContext context) throws A2AError { LOGGER.debug("onMessageSend - task: {}; context {}", params.message().taskId(), params.message().contextId()); - MessageSendSetup mss = initMessageSend(params, context); - String taskId = mss.requestContext.getTaskId(); - LOGGER.debug("Request context taskId: {}", taskId); + // Generate temp taskId BEFORE initMessageSend if client didn't provide one + // This ensures TaskManager is created with a valid taskId for ResultAggregator + @Nullable String messageTaskId = params.message().taskId(); + boolean isTempId = messageTaskId == null; + String queueTaskId = isTempId ? "temp-" + java.util.UUID.randomUUID() : messageTaskId; + LOGGER.debug("Message taskId: {} (queue key: {}, isTempId: {})", messageTaskId, queueTaskId, isTempId); - if (taskId == null) { - throw new io.a2a.spec.InternalError("Task ID is null in onMessageSend"); - } - EventQueue queue = queueManager.createOrTap(taskId); - ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor); + MessageSendSetup mss = initMessageSend(params, context, queueTaskId, isTempId); + + // Pass the actual tempId string (queueTaskId) if this is a temp ID, null otherwise + EventQueue queue = queueManager.createOrTap(queueTaskId, isTempId ? queueTaskId : null); + final java.util.concurrent.atomic.AtomicReference<@NonNull String> taskId = new java.util.concurrent.atomic.AtomicReference<>(queueTaskId); + ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor, eventConsumerExecutor); + // Default to blocking=false per A2A spec (return after task creation) boolean blocking = params.configuration() != null && Boolean.TRUE.equals(params.configuration().blocking()); + // Log blocking behavior from client request + if (params.configuration() != null && params.configuration().blocking() != null) { + LOGGER.debug("DefaultRequestHandler: Client requested blocking={} for task {}", + params.configuration().blocking(), taskId.get()); + } else if (params.configuration() != null) { + LOGGER.debug("DefaultRequestHandler: Client sent configuration but blocking=null, using default blocking={} for task {}", blocking, taskId.get()); + } else { + LOGGER.debug("DefaultRequestHandler: Client sent no configuration, using default blocking={} for task {}", blocking, taskId.get()); + } + LOGGER.debug("DefaultRequestHandler: Final blocking decision: {} for task {}", blocking, taskId.get()); + boolean interruptedOrNonBlocking = false; - EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(taskId, mss.requestContext, queue); + EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(queueTaskId, mss.requestContext, queue); ResultAggregator.EventTypeAndInterrupt etai = null; EventKind kind = null; // Declare outside try block so it's in scope for return try { - // Create callback for push notifications during background event processing - Runnable pushNotificationCallback = () -> sendPushNotification(taskId, resultAggregator); - EventConsumer consumer = new EventConsumer(queue); // This callback must be added before we start consuming. Otherwise, @@ -424,7 +552,7 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback()); // Get agent future before consuming (for blocking calls to wait for agent completion) - CompletableFuture agentFuture = runningAgents.get(taskId); + CompletableFuture agentFuture = runningAgents.get(queueTaskId); etai = resultAggregator.consumeAndBreakOnInterrupt(consumer, blocking); if (etai == null) { @@ -432,7 +560,8 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte throw new InternalError("No result"); } interruptedOrNonBlocking = etai.interrupted(); - LOGGER.debug("Was interrupted or non-blocking: {}", interruptedOrNonBlocking); + LOGGER.debug("DefaultRequestHandler: interruptedOrNonBlocking={} (blocking={}, eventType={})", + interruptedOrNonBlocking, blocking, kind != null ? kind.getClass().getSimpleName() : null); // For blocking calls that were interrupted (returned on first event), // wait for agent execution and event processing BEFORE returning to client. @@ -441,30 +570,49 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte // during the consumption loop itself. kind = etai.eventType(); + // Switch from temporary ID to real task ID if they differ + if (kind instanceof Task createdTask) { + String currentId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if (!Objects.equals(currentId, createdTask.id())) { + try { + switchFromTempToRealTaskId(currentId, createdTask.id(), mss.taskManager); + taskId.set(createdTask.id()); + } catch (TaskQueueExistsException | IllegalStateException e) { + String msg = "Failed to switch from temp ID " + currentId + " to real task ID " + createdTask.id() + ": " + e.getMessage(); + LOGGER.error(msg, e); + throw new InternalError(msg); + } + } + } + // Store push notification config for newly created tasks (mirrors streaming logic) // Only for NEW tasks - existing tasks are handled by initMessageSend() if (mss.task() == null && kind instanceof Task createdTask && shouldAddPushInfo(params)) { - LOGGER.debug("Storing push notification config for new task {}", createdTask.id()); + LOGGER.debug("Storing push notification config for new task {} (original taskId from params: {})", + createdTask.id(), params.message().taskId()); pushConfigStore.setInfo(createdTask.id(), params.configuration().pushNotificationConfig()); } if (blocking && interruptedOrNonBlocking) { - // For blocking calls: ensure all events are processed before returning - // Order of operations is critical to avoid circular dependency: + // For blocking calls: ensure all events are persisted to TaskStore before returning + // Order of operations is critical to avoid circular dependency and race conditions: // 1. Wait for agent to finish enqueueing events // 2. Close the queue to signal consumption can complete // 3. Wait for consumption to finish processing events - // 4. Fetch final task state from TaskStore + // 4. Wait for MainEventBusProcessor to persist final state to TaskStore + // 5. Fetch final task state from TaskStore (now guaranteed persisted) + LOGGER.debug("DefaultRequestHandler: Entering blocking fire-and-forget handling for task {}", taskId.get()); try { // Step 1: Wait for agent to finish (with configurable timeout) if (agentFuture != null) { try { agentFuture.get(agentCompletionTimeoutSeconds, SECONDS); - LOGGER.debug("Agent completed for task {}", taskId); + LOGGER.debug("DefaultRequestHandler: Step 1 - Agent completed for task {}", taskId.get()); } catch (java.util.concurrent.TimeoutException e) { // Agent still running after timeout - that's fine, events already being processed - LOGGER.debug("Agent still running for task {} after {}s", taskId, agentCompletionTimeoutSeconds); + LOGGER.debug("DefaultRequestHandler: Step 1 - Agent still running for task {} after {}s timeout", + taskId.get(), agentCompletionTimeoutSeconds); } } @@ -472,55 +620,87 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte // For fire-and-forget tasks, there's no final event, so we need to close the queue // This allows EventConsumer.consumeAll() to exit queue.close(false, false); // graceful close, don't notify parent yet - LOGGER.debug("Closed queue for task {} to allow consumption completion", taskId); + LOGGER.debug("DefaultRequestHandler: Step 2 - Closed queue for task {} to allow consumption completion", taskId.get()); // Step 3: Wait for consumption to complete (now that queue is closed) if (etai.consumptionFuture() != null) { etai.consumptionFuture().get(consumptionCompletionTimeoutSeconds, SECONDS); - LOGGER.debug("Consumption completed for task {}", taskId); + LOGGER.debug("DefaultRequestHandler: Step 3 - Consumption completed for task {}", taskId.get()); + } + + // Step 4: Wait for MainEventBusProcessor to finalize task (persist to TaskStore) + // For blocking calls, ALWAYS try to wait for finalization. + // waitForTaskFinalization() checks TaskStore first, so if task is already finalized + // it returns immediately. + // This is CRITICAL: consumption completing only means ChildQueue is empty, NOT that + // MainEventBusProcessor has finished persisting to TaskStore. The callback ensures + // we wait for the final state to be persisted before reading from TaskStore. + try { + String taskIdForFinalization = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + waitForTaskFinalization(taskIdForFinalization, consumptionCompletionTimeoutSeconds); + LOGGER.debug("DefaultRequestHandler: Step 4 - Task {} finalized and persisted to TaskStore", taskId.get()); + } catch (TimeoutException e) { + // Timeout is OK for fire-and-forget tasks that never reach final state + // Just log and continue - we'll return the current non-final state + LOGGER.debug("DefaultRequestHandler: Step 4 - Task {} finalization timeout (fire-and-forget task)", taskId.get()); } + } catch (InterruptedException e) { Thread.currentThread().interrupt(); - String msg = String.format("Error waiting for task %s completion", taskId); + String msg = String.format("Error waiting for task %s completion", taskId.get()); LOGGER.warn(msg, e); throw new InternalError(msg); } catch (java.util.concurrent.ExecutionException e) { - String msg = String.format("Error during task %s execution", taskId); + String msg = String.format("Error during task %s execution", taskId.get()); LOGGER.warn(msg, e.getCause()); throw new InternalError(msg); - } catch (java.util.concurrent.TimeoutException e) { - String msg = String.format("Timeout waiting for consumption to complete for task %s", taskId); - LOGGER.warn(msg, taskId); + } catch (TimeoutException e) { + // Timeout from consumption future.get() - different from finalization timeout + String msg = String.format("Timeout waiting for task %s consumption", taskId.get()); + LOGGER.warn(msg, e); throw new InternalError(msg); } - // Step 4: Fetch the final task state from TaskStore (all events have been processed) - // taskId is guaranteed non-null here (checked earlier) - String nonNullTaskId = taskId; + // Step 5: Fetch the final task state from TaskStore (now guaranteed persisted) + String nonNullTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); Task updatedTask = taskStore.get(nonNullTaskId); if (updatedTask != null) { kind = updatedTask; - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Fetched final task for {} with state {} and {} artifacts", - nonNullTaskId, updatedTask.status().state(), - updatedTask.artifacts().size()); - } + LOGGER.debug("DefaultRequestHandler: Step 5 - Fetched final task for {} with state {} and {} artifacts", + taskId.get(), updatedTask.status().state(), + updatedTask.artifacts().size()); + } else { + LOGGER.warn("DefaultRequestHandler: Step 5 - Task {} not found in TaskStore!", taskId.get()); } } - if (kind instanceof Task taskResult && !taskId.equals(taskResult.id())) { + String finalTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if (kind instanceof Task taskResult && !finalTaskId.equals(taskResult.id())) { throw new InternalError("Task ID mismatch in agent response"); } - - // Send push notification after initial return (for both blocking and non-blocking) - pushNotificationCallback.run(); } finally { + // For non-blocking calls: close ChildQueue IMMEDIATELY to free EventConsumer thread + // CRITICAL: Must use immediate=true to clear the local queue, otherwise EventConsumer + // continues polling until queue drains naturally, holding executor thread. + // Immediate close clears pending events and triggers EventQueueClosedException on next poll. + // Events continue flowing through MainQueue → MainEventBus → TaskStore. + if (!blocking && etai != null && etai.interrupted()) { + LOGGER.debug("DefaultRequestHandler: Non-blocking call in finally - closing ChildQueue IMMEDIATELY for task {} to free EventConsumer", taskId.get()); + queue.close(true); // immediate=true: clear queue and free EventConsumer + } + // Remove agent from map immediately to prevent accumulation - CompletableFuture agentFuture = runningAgents.remove(taskId); - LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", taskId, runningAgents.size()); + CompletableFuture agentFuture = runningAgents.remove(queueTaskId); + String cleanupTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", cleanupTaskId, runningAgents.size()); - // Track cleanup as background task to avoid blocking Vert.x threads + // Cleanup as background task to avoid blocking Vert.x threads // Pass the consumption future to ensure cleanup waits for background consumption to complete - trackBackgroundTask(cleanupProducer(agentFuture, etai != null ? etai.consumptionFuture() : null, taskId, queue, false)); + cleanupProducer(agentFuture, etai != null ? etai.consumptionFuture() : null, cleanupTaskId, queue, false) + .whenComplete((res, err) -> { + if (err != null) { + LOGGER.error("Error during async cleanup for task {}", taskId.get(), err); + } + }); } LOGGER.debug("Returning: {}", kind); @@ -530,29 +710,45 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte @Override public Flow.Publisher onMessageSendStream( MessageSendParams params, ServerCallContext context) throws A2AError { - LOGGER.debug("onMessageSendStream START - task: {}; context: {}; runningAgents: {}; backgroundTasks: {}", - params.message().taskId(), params.message().contextId(), runningAgents.size(), backgroundTasks.size()); - MessageSendSetup mss = initMessageSend(params, context); + LOGGER.debug("onMessageSendStream START - task: {}; context: {}; runningAgents: {}", + params.message().taskId(), params.message().contextId(), runningAgents.size()); + + // Generate temp taskId BEFORE initMessageSend if client didn't provide one + // This ensures TaskManager is created with a valid taskId for ResultAggregator + @Nullable String messageTaskId = params.message().taskId(); + boolean isTempId = messageTaskId == null; + String queueTaskId = isTempId ? "temp-" + java.util.UUID.randomUUID() : messageTaskId; - @Nullable String initialTaskId = mss.requestContext.getTaskId(); - // For streaming, taskId can be null initially (will be set when Task event arrives) - // Use a temporary ID for queue creation if needed - String queueTaskId = initialTaskId != null ? initialTaskId : "temp-" + java.util.UUID.randomUUID(); + MessageSendSetup mss = initMessageSend(params, context, queueTaskId, isTempId); - AtomicReference<@NonNull String> taskId = new AtomicReference<>(queueTaskId); + final AtomicReference<@NonNull String> taskId = new AtomicReference<>(queueTaskId); + // Pass the actual tempId string (queueTaskId) if this is a temp ID, null otherwise @SuppressWarnings("NullAway") - EventQueue queue = queueManager.createOrTap(taskId.get()); + EventQueue queue = queueManager.createOrTap(taskId.get(), isTempId ? queueTaskId : null); LOGGER.debug("Created/tapped queue for task {}: {}", taskId.get(), queue); - ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor); + + // Store push notification config SYNCHRONOUSLY for new tasks before agent starts + // This ensures config is available when MainEventBusProcessor sends push notifications + // For existing tasks, config is stored in initMessageSend() + if (mss.task() == null && shouldAddPushInfo(params)) { + // Satisfy Nullaway + Objects.requireNonNull(taskId.get(), "taskId was null"); + LOGGER.debug("Storing push notification config for new streaming task {} EARLY (original taskId from params: {})", + taskId.get(), params.message().taskId()); + pushConfigStore.setInfo(taskId.get(), params.configuration().pushNotificationConfig()); + } + + ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor, eventConsumerExecutor); EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(queueTaskId, mss.requestContext, queue); // Move consumer creation and callback registration outside try block - // so consumer is available for background consumption on client disconnect EventConsumer consumer = new EventConsumer(queue); producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback()); - AtomicBoolean backgroundConsumeStarted = new AtomicBoolean(false); + // Store cancel callback in context for closeHandler to access + // When client disconnects, closeHandler can call this to stop EventConsumer polling loop + context.setEventConsumerCancelCallback(consumer::cancel); try { Flow.Publisher results = resultAggregator.consumeAndEmit(consumer); @@ -562,36 +758,25 @@ public Flow.Publisher onMessageSendStream( processor(createTubeConfig(), results, ((errorConsumer, item) -> { Event event = item.getEvent(); if (event instanceof Task createdTask) { - if (!Objects.equals(taskId.get(), createdTask.id())) { - errorConsumer.accept(new InternalError("Task ID mismatch in agent response")); - } - - // TODO the Python implementation no longer has the following block but removing it causes - // failures here - try { - queueManager.add(createdTask.id(), queue); - taskId.set(createdTask.id()); - } catch (TaskQueueExistsException e) { - // TODO Log - } - if (pushConfigStore != null && - params.configuration() != null && - params.configuration().pushNotificationConfig() != null) { - - pushConfigStore.setInfo( - createdTask.id(), - params.configuration().pushNotificationConfig()); + // Switch from temporary ID to real task ID if they differ + String currentId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if (!Objects.equals(currentId, createdTask.id())) { + try { + switchFromTempToRealTaskId(currentId, createdTask.id(), mss.taskManager); + taskId.set(createdTask.id()); + } catch (TaskQueueExistsException | IllegalStateException e) { + errorConsumer.accept(new InternalError("Failed to switch from temp ID " + currentId + + " to real task ID " + createdTask.id() + ": " + e.getMessage())); + return true; // Don't proceed to final check if switch failed + } } - } - String currentTaskId = taskId.get(); - if (pushSender != null && currentTaskId != null) { - EventKind latest = resultAggregator.getCurrentResult(); - if (latest instanceof Task latestTask) { - pushSender.sendNotification(latestTask); + // Final verification AFTER switch attempt + String finalTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if (!finalTaskId.equals(createdTask.id())) { + errorConsumer.accept(new InternalError("Task ID mismatch in agent response")); } } - return true; })); @@ -600,7 +785,8 @@ public Flow.Publisher onMessageSendStream( Flow.Publisher finalPublisher = convertingProcessor(eventPublisher, event -> (StreamingEventKind) event); - // Wrap publisher to detect client disconnect and continue background consumption + // Wrap publisher to detect client disconnect and immediately close ChildQueue + // This prevents ChildQueue backpressure from blocking MainEventBusProcessor return subscriber -> { String currentTaskId = taskId.get(); LOGGER.debug("Creating subscription wrapper for task {}", currentTaskId); @@ -621,8 +807,10 @@ public void request(long n) { @Override public void cancel() { - LOGGER.debug("Client cancelled subscription for task {}, starting background consumption", taskId.get()); - startBackgroundConsumption(); + LOGGER.debug("Client cancelled subscription for task {}, closing ChildQueue immediately", taskId.get()); + // Close ChildQueue immediately to prevent backpressure + // (clears queue and releases semaphore permits) + queue.close(true); // immediate=true subscription.cancel(); } }); @@ -647,8 +835,8 @@ public void onComplete() { subscriber.onComplete(); } catch (IllegalStateException e) { // Client already disconnected and response closed - this is expected - // for streaming responses where client disconnect triggers background - // consumption. Log and ignore. + // for streaming responses where client disconnect closes ChildQueue. + // Log and ignore. if (e.getMessage() != null && e.getMessage().contains("Response has already been written")) { LOGGER.debug("Client disconnected before onComplete, response already closed for task {}", taskId.get()); } else { @@ -656,36 +844,26 @@ public void onComplete() { } } } - - private void startBackgroundConsumption() { - if (backgroundConsumeStarted.compareAndSet(false, true)) { - LOGGER.debug("Starting background consumption for task {}", taskId.get()); - // Client disconnected: continue consuming and persisting events in background - CompletableFuture bgTask = CompletableFuture.runAsync(() -> { - try { - LOGGER.debug("Background consumption thread started for task {}", taskId.get()); - resultAggregator.consumeAll(consumer); - LOGGER.debug("Background consumption completed for task {}", taskId.get()); - } catch (Exception e) { - LOGGER.error("Error during background consumption for task {}", taskId.get(), e); - } - }, executor); - trackBackgroundTask(bgTask); - } else { - LOGGER.debug("Background consumption already started for task {}", taskId.get()); - } - } }); }; } finally { - LOGGER.debug("onMessageSendStream FINALLY - task: {}; runningAgents: {}; backgroundTasks: {}", - taskId.get(), runningAgents.size(), backgroundTasks.size()); - - // Remove agent from map immediately to prevent accumulation - CompletableFuture agentFuture = runningAgents.remove(taskId.get()); - LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", taskId.get(), runningAgents.size()); - - trackBackgroundTask(cleanupProducer(agentFuture, null, Objects.requireNonNull(taskId.get()), queue, true)); + // Needed to satisfy Nullaway + String idOfTask = taskId.get(); + if (idOfTask != null) { + LOGGER.debug("onMessageSendStream FINALLY - task: {}; runningAgents: {}", + idOfTask, runningAgents.size()); + + // Remove agent from map immediately to prevent accumulation + CompletableFuture agentFuture = runningAgents.remove(idOfTask); + LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", taskId.get(), runningAgents.size()); + + cleanupProducer(agentFuture, null, idOfTask, queue, true) + .whenComplete((res, err) -> { + if (err != null) { + LOGGER.error("Error during async cleanup for streaming task {}", taskId.get(), err); + } + }); + } } } @@ -745,8 +923,8 @@ public Flow.Publisher onResubscribeToTask( throw new TaskNotFoundError(); } - TaskManager taskManager = new TaskManager(task.id(), task.contextId(), taskStore, null); - ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor); + TaskManager taskManager = new TaskManager(task.id(), task.contextId(), taskStore, null, false); // Not a temp ID - task already exists + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor, eventConsumerExecutor); EventQueue queue = queueManager.tap(task.id()); LOGGER.debug("onResubscribeToTask - tapped queue: {}", queue != null ? System.identityHashCode(queue) : "null"); @@ -819,8 +997,7 @@ public void run() { LOGGER.debug("Agent execution starting for task {}", taskId); agentExecutor.execute(requestContext, queue); LOGGER.debug("Agent execution completed for task {}", taskId); - // No longer wait for queue poller to start - the consumer (which is guaranteed - // to be running on the Vert.x worker thread) will handle queue lifecycle. + // The consumer (running on the Vert.x worker thread) handles queue lifecycle. // This avoids blocking agent-executor threads waiting for worker threads. } }; @@ -833,8 +1010,8 @@ public void run() { // Don't close queue here - let the consumer handle it via error callback // This ensures the consumer (which may not have started polling yet) gets the error } - // Queue lifecycle is now managed entirely by EventConsumer.consumeAll() - // which closes the queue on final events. No need to close here. + // Queue lifecycle is managed by EventConsumer.consumeAll() + // which closes the queue on final events. logThreadStats("AGENT COMPLETE END"); runnable.invokeDoneCallbacks(); }); @@ -843,47 +1020,6 @@ public void run() { return runnable; } - private void trackBackgroundTask(CompletableFuture task) { - backgroundTasks.add(task); - LOGGER.debug("Tracking background task (total: {}): {}", backgroundTasks.size(), task); - - task.whenComplete((result, throwable) -> { - try { - if (throwable != null) { - // Unwrap CompletionException to check for CancellationException - Throwable cause = throwable; - if (throwable instanceof java.util.concurrent.CompletionException && throwable.getCause() != null) { - cause = throwable.getCause(); - } - - if (cause instanceof java.util.concurrent.CancellationException) { - LOGGER.debug("Background task cancelled: {}", task); - } else { - LOGGER.error("Background task failed", throwable); - } - } - } finally { - backgroundTasks.remove(task); - LOGGER.debug("Removed background task (remaining: {}): {}", backgroundTasks.size(), task); - } - }); - } - - /** - * Wait for all background tasks to complete. - * Useful for testing to ensure cleanup completes before assertions. - * - * @return CompletableFuture that completes when all background tasks finish - */ - public CompletableFuture waitForBackgroundTasks() { - CompletableFuture[] tasks = backgroundTasks.toArray(new CompletableFuture[0]); - if (tasks.length == 0) { - return CompletableFuture.completedFuture(null); - } - LOGGER.debug("Waiting for {} background tasks to complete", tasks.length); - return CompletableFuture.allOf(tasks); - } - private CompletableFuture cleanupProducer(@Nullable CompletableFuture agentFuture, @Nullable CompletableFuture consumptionFuture, String taskId, EventQueue queue, boolean isStreaming) { LOGGER.debug("Starting cleanup for task {} (streaming={})", taskId, isStreaming); logThreadStats("CLEANUP START"); @@ -908,14 +1044,20 @@ private CompletableFuture cleanupProducer(@Nullable CompletableFuture cleanupProducer(@Nullable CompletableFuture agentFuture = runningAgents.remove(tempId); + if (agentFuture != null) { + runningAgents.put(realId, agentFuture); + LOGGER.debug("Moved runningAgents from {} to {}", tempId, realId); + } + + // 4. Switch push notification configs (if configured) + if (pushConfigStore != null) { + pushConfigStore.switchKey(tempId, realId); + } + + // 5. Update TaskManager's taskId + taskManager.setTaskId(realId); + + LOGGER.debug("Completed switch from temp ID {} to real task ID {}", tempId, realId); + } + + private MessageSendSetup initMessageSend(MessageSendParams params, ServerCallContext context, String taskId, boolean isTempId) { TaskManager taskManager = new TaskManager( - params.message().taskId(), + taskId, params.message().contextId(), taskStore, - params.message()); + params.message(), + isTempId); Task task = taskManager.getTask(); if (task != null) { @@ -945,9 +1134,11 @@ private MessageSendSetup initMessageSend(MessageSendParams params, ServerCallCon } } + // For RequestContext, pass null as taskId when using temp ID to avoid validation error + // The temp UUID is only for queue management, not for the RequestContext RequestContext requestContext = requestContextBuilder.get() .setParams(params) - .setTaskId(task == null ? null : task.id()) + .setTaskId(isTempId ? null : taskId) .setContextId(params.message().contextId()) .setTask(task) .setServerCallContext(context) @@ -955,24 +1146,19 @@ private MessageSendSetup initMessageSend(MessageSendParams params, ServerCallCon return new MessageSendSetup(taskManager, task, requestContext); } - private void sendPushNotification(String taskId, ResultAggregator resultAggregator) { - if (pushSender != null) { - EventKind latest = resultAggregator.getCurrentResult(); - if (latest instanceof Task latestTask) { - pushSender.sendNotification(latestTask); - } - } - } - /** * Log current thread and resource statistics for debugging. + * Uses dedicated {@link #THREAD_STATS_LOGGER} for independent logging control. * Only logs when DEBUG level is enabled. Call this from debugger or add strategic * calls during investigation. In production with INFO logging, this is a no-op. + *

    + * Enable independently with: {@code logging.level.io.a2a.server.diagnostics.ThreadStats=DEBUG} + *

    */ @SuppressWarnings("unused") // Used for debugging private void logThreadStats(String label) { // Early return if debug logging is not enabled to avoid overhead - if (!LOGGER.isDebugEnabled()) { + if (!THREAD_STATS_LOGGER.isDebugEnabled()) { return; } @@ -982,28 +1168,57 @@ private void logThreadStats(String label) { } int activeThreads = rootGroup.activeCount(); - LOGGER.debug("=== THREAD STATS: {} ===", label); - LOGGER.debug("Active threads: {}", activeThreads); - LOGGER.debug("Running agents: {}", runningAgents.size()); - LOGGER.debug("Background tasks: {}", backgroundTasks.size()); - LOGGER.debug("Queue manager active queues: {}", queueManager.getClass().getSimpleName()); + // Count specific thread types + Thread[] threads = new Thread[activeThreads * 2]; + int count = rootGroup.enumerate(threads); + int eventConsumerThreads = 0; + int agentExecutorThreads = 0; + for (int i = 0; i < count; i++) { + if (threads[i] != null) { + String name = threads[i].getName(); + if (name.startsWith("a2a-event-consumer-")) { + eventConsumerThreads++; + } else if (name.startsWith("a2a-agent-executor-")) { + agentExecutorThreads++; + } + } + } + + THREAD_STATS_LOGGER.debug("=== THREAD STATS: {} ===", label); + THREAD_STATS_LOGGER.debug("Total active threads: {}", activeThreads); + THREAD_STATS_LOGGER.debug("EventConsumer threads: {}", eventConsumerThreads); + THREAD_STATS_LOGGER.debug("AgentExecutor threads: {}", agentExecutorThreads); + THREAD_STATS_LOGGER.debug("Running agents: {}", runningAgents.size()); + THREAD_STATS_LOGGER.debug("Queue manager active queues: {}", queueManager.getClass().getSimpleName()); // List running agents if (!runningAgents.isEmpty()) { - LOGGER.debug("Running agent tasks:"); + THREAD_STATS_LOGGER.debug("Running agent tasks:"); runningAgents.forEach((taskId, future) -> - LOGGER.debug(" - Task {}: {}", taskId, future.isDone() ? "DONE" : "RUNNING") + THREAD_STATS_LOGGER.debug(" - Task {}: {}", taskId, future.isDone() ? "DONE" : "RUNNING") ); } - // List background tasks - if (!backgroundTasks.isEmpty()) { - LOGGER.debug("Background tasks:"); - backgroundTasks.forEach(task -> - LOGGER.debug(" - {}: {}", task, task.isDone() ? "DONE" : "RUNNING") - ); + THREAD_STATS_LOGGER.debug("=== END THREAD STATS ==="); + } + + /** + * Check if an event represents a final task state. + * + * @param eventKind the event to check + * @return true if the event represents a final state (COMPLETED, FAILED, CANCELED, REJECTED, UNKNOWN) + */ + private boolean isFinalEvent(EventKind eventKind) { + if (!(eventKind instanceof Event event)) { + return false; + } + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); } - LOGGER.debug("=== END THREAD STATS ==="); + return false; } private record MessageSendSetup(TaskManager taskManager, @Nullable Task task, RequestContext requestContext) {} diff --git a/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java b/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java index 093ff910d..d18610875 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java +++ b/server-common/src/main/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStore.java @@ -118,4 +118,12 @@ public void deleteInfo(String taskId, String configId) { pushNotificationInfos.remove(taskId); } } + + @Override + public void switchKey(String oldTaskId, String newTaskId) { + List configs = pushNotificationInfos.remove(oldTaskId); + if (configs != null && !configs.isEmpty()) { + pushNotificationInfos.put(newTaskId, configs); + } + } } diff --git a/server-common/src/main/java/io/a2a/server/tasks/PushNotificationConfigStore.java b/server-common/src/main/java/io/a2a/server/tasks/PushNotificationConfigStore.java index 828b066a6..e6686b999 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/PushNotificationConfigStore.java +++ b/server-common/src/main/java/io/a2a/server/tasks/PushNotificationConfigStore.java @@ -120,4 +120,20 @@ public interface PushNotificationConfigStore { */ void deleteInfo(String taskId, String configId); + /** + * Switches push notification configuration from an old task ID to a new task ID. + *

    + * Used when transitioning from a temporary task ID (e.g., "temp-UUID") to the real task ID + * when the Task event arrives with the actual task.id. Moves all push notification configs + * associated with the old task ID to the new task ID. + *

    + *

    + * If no configs exist for the old task ID, this method returns normally (no-op). + *

    + * + * @param oldTaskId the temporary/old task identifier + * @param newTaskId the real/new task identifier + */ + void switchKey(String oldTaskId, String newTaskId); + } diff --git a/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java b/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java index 95684e199..506b3f3b6 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java +++ b/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java @@ -14,9 +14,9 @@ import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueueItem; import io.a2a.spec.A2AError; -import io.a2a.spec.A2AServerException; import io.a2a.spec.Event; import io.a2a.spec.EventKind; +import io.a2a.spec.InternalError; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskState; @@ -31,12 +31,14 @@ public class ResultAggregator { private final TaskManager taskManager; private final Executor executor; + private final Executor eventConsumerExecutor; private volatile @Nullable Message message; - public ResultAggregator(TaskManager taskManager, @Nullable Message message, Executor executor) { + public ResultAggregator(TaskManager taskManager, @Nullable Message message, Executor executor, Executor eventConsumerExecutor) { this.taskManager = taskManager; this.message = message; this.executor = executor; + this.eventConsumerExecutor = eventConsumerExecutor; } public @Nullable EventKind getCurrentResult() { @@ -49,20 +51,23 @@ public ResultAggregator(TaskManager taskManager, @Nullable Message message, Exec public Flow.Publisher consumeAndEmit(EventConsumer consumer) { Flow.Publisher allItems = consumer.consumeAll(); - // Process items conditionally - only save non-replicated events to database - return processor(createTubeConfig(), allItems, (errorConsumer, item) -> { - // Only process non-replicated events to avoid duplicate database writes - if (!item.isReplicated()) { - try { - callTaskManagerProcess(item.getEvent()); - } catch (A2AServerException e) { - errorConsumer.accept(e); - return false; - } - } - // Continue processing and emit (both replicated and non-replicated) + // Just stream events - no persistence needed + // TaskStore update moved to MainEventBusProcessor + Flow.Publisher processed = processor(createTubeConfig(), allItems, (errorConsumer, item) -> { + // Continue processing and emit all events return true; }); + + // Wrap the publisher to ensure subscription happens on eventConsumerExecutor + // This prevents EventConsumer polling loop from running on AgentExecutor threads + // which caused thread accumulation when those threads didn't timeout + return new Flow.Publisher() { + @Override + public void subscribe(Flow.Subscriber subscriber) { + // Submit subscription to eventConsumerExecutor to isolate polling work + eventConsumerExecutor.execute(() -> processed.subscribe(subscriber)); + } + }; } public EventKind consumeAll(EventConsumer consumer) throws A2AError { @@ -81,15 +86,7 @@ public EventKind consumeAll(EventConsumer consumer) throws A2AError { return false; } } - // Only process non-replicated events to avoid duplicate database writes - if (!item.isReplicated()) { - try { - callTaskManagerProcess(event); - } catch (A2AServerException e) { - error.set(e); - return false; - } - } + // TaskStore update moved to MainEventBusProcessor return true; }, error::set); @@ -113,18 +110,24 @@ public EventKind consumeAll(EventConsumer consumer) throws A2AError { public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, boolean blocking) throws A2AError { Flow.Publisher allItems = consumer.consumeAll(); AtomicReference message = new AtomicReference<>(); + AtomicReference capturedTask = new AtomicReference<>(); // Capture Task events AtomicBoolean interrupted = new AtomicBoolean(false); AtomicReference errorRef = new AtomicReference<>(); CompletableFuture completionFuture = new CompletableFuture<>(); // Separate future for tracking background consumption completion CompletableFuture consumptionCompletionFuture = new CompletableFuture<>(); + // Latch to ensure EventConsumer starts polling before we wait on completionFuture + java.util.concurrent.CountDownLatch pollingStarted = new java.util.concurrent.CountDownLatch(1); // CRITICAL: The subscription itself must run on a background thread to avoid blocking // the Vert.x worker thread. EventConsumer.consumeAll() starts a polling loop that // blocks in dequeueEventItem(), so we must subscribe from a background thread. - // Use the @Internal executor (not ForkJoinPool.commonPool) to avoid saturation - // during concurrent request bursts. + // Use the dedicated @EventConsumerExecutor (cached thread pool) which creates threads + // on demand for I/O-bound polling. Using the @Internal executor caused deadlock when + // pool exhausted (100+ concurrent queues but maxPoolSize=50). CompletableFuture.runAsync(() -> { + // Signal that polling is about to start + pollingStarted.countDown(); consumer( createTubeConfig(), allItems, @@ -146,25 +149,30 @@ public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, return false; } - // Process event through TaskManager - only for non-replicated events - if (!item.isReplicated()) { - try { - callTaskManagerProcess(event); - } catch (A2AServerException e) { - errorRef.set(e); - completionFuture.completeExceptionally(e); - return false; + // Capture Task events (especially for new tasks where taskManager.getTask() would return null) + // We capture the LATEST task to ensure we get the most up-to-date state + if (event instanceof Task t) { + Task previousTask = capturedTask.get(); + capturedTask.set(t); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Captured Task event: id={}, state={} (previous: {})", + t.id(), t.status().state(), + previousTask != null ? previousTask.id() + "/" + previousTask.status().state() : "none"); } } + // TaskStore update moved to MainEventBusProcessor + // Determine interrupt behavior boolean shouldInterrupt = false; - boolean continueInBackground = false; boolean isFinalEvent = (event instanceof Task task && task.status().state().isFinal()) || (event instanceof TaskStatusUpdateEvent tsue && tsue.isFinal()); boolean isAuthRequired = (event instanceof Task task && task.status().state() == TaskState.AUTH_REQUIRED) || (event instanceof TaskStatusUpdateEvent tsue && tsue.status().state() == TaskState.AUTH_REQUIRED); + LOGGER.debug("ResultAggregator: Evaluating interrupt (blocking={}, isFinal={}, isAuth={}, eventType={})", + blocking, isFinalEvent, isAuthRequired, event.getClass().getSimpleName()); + // Always interrupt on auth_required, as it needs external action. if (isAuthRequired) { // auth-required is a special state: the message should be @@ -174,20 +182,19 @@ public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, // new request is expected in order for the agent to make progress, // so the agent should exit. shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (AUTH_REQUIRED)"); } else if (!blocking) { // For non-blocking calls, interrupt as soon as a task is available. shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (non-blocking)"); } else if (blocking) { // For blocking calls: Interrupt to free Vert.x thread, but continue in background // Python's async consumption doesn't block threads, but Java's does // So we interrupt to return quickly, then rely on background consumption - // DefaultRequestHandler will fetch the final state from TaskStore shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (blocking, isFinal={})", isFinalEvent); if (LOGGER.isDebugEnabled()) { LOGGER.debug("Blocking call for task {}: {} event, returning with background consumption", taskIdForLogging(), isFinalEvent ? "final" : "non-final"); @@ -195,14 +202,14 @@ else if (blocking) { } if (shouldInterrupt) { + LOGGER.debug("ResultAggregator: Interrupting consumption (setting interrupted=true)"); // Complete the future to unblock the main thread interrupted.set(true); completionFuture.complete(null); // For blocking calls, DON'T complete consumptionCompletionFuture here. // Let it complete naturally when subscription finishes (onComplete callback below). - // This ensures all events are processed and persisted to TaskStore before - // DefaultRequestHandler.cleanupProducer() proceeds with cleanup. + // This ensures all events are fully processed before cleanup. // // For non-blocking and auth-required calls, complete immediately to allow // cleanup to proceed while consumption continues in background. @@ -237,7 +244,16 @@ else if (blocking) { } } ); - }, executor); + }, eventConsumerExecutor); + + // Wait for EventConsumer to start polling before we wait for events + // This prevents race where agent enqueues events before EventConsumer starts + try { + pollingStarted.await(5, java.util.concurrent.TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new io.a2a.spec.InternalError("Interrupted waiting for EventConsumer to start"); + } // Wait for completion or interruption try { @@ -261,28 +277,30 @@ else if (blocking) { Utils.rethrow(error); } - EventKind eventType; - Message msg = message.get(); - if (msg != null) { - eventType = msg; - } else { - Task task = taskManager.getTask(); - if (task == null) { - throw new io.a2a.spec.InternalError("No task or message available after consuming events"); + // Return Message if captured, otherwise Task if captured, otherwise fetch from TaskStore + EventKind eventKind = message.get(); + if (eventKind == null) { + eventKind = capturedTask.get(); + if (LOGGER.isDebugEnabled() && eventKind instanceof Task t) { + LOGGER.debug("Returning capturedTask: id={}, state={}", t.id(), t.status().state()); } - eventType = task; + } + if (eventKind == null) { + eventKind = taskManager.getTask(); + if (LOGGER.isDebugEnabled() && eventKind instanceof Task t) { + LOGGER.debug("Returning task from TaskStore: id={}, state={}", t.id(), t.status().state()); + } + } + if (eventKind == null) { + throw new InternalError("Could not find a Task/Message for " + taskManager.getTaskId()); } return new EventTypeAndInterrupt( - eventType, + eventKind, interrupted.get(), consumptionCompletionFuture); } - private void callTaskManagerProcess(Event event) throws A2AServerException { - taskManager.process(event); - } - private String taskIdForLogging() { Task task = taskManager.getTask(); return task != null ? task.id() : "unknown"; diff --git a/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java b/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java index fd3696a60..dbaf76ac2 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java +++ b/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java @@ -12,7 +12,7 @@ import io.a2a.spec.A2AServerException; import io.a2a.spec.Event; -import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.InternalError; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskArtifactUpdateEvent; @@ -31,13 +31,29 @@ public class TaskManager { private final TaskStore taskStore; private final @Nullable Message initialMessage; private volatile @Nullable Task currentTask; + private volatile boolean isTempId; - public TaskManager(@Nullable String taskId, @Nullable String contextId, TaskStore taskStore, @Nullable Message initialMessage) { + public TaskManager(@Nullable String taskId, @Nullable String contextId, TaskStore taskStore, @Nullable Message initialMessage, boolean isTempId) { checkNotNullParam("taskStore", taskStore); this.taskId = taskId; this.contextId = contextId; this.taskStore = taskStore; this.initialMessage = initialMessage; + this.isTempId = isTempId; + } + + /** + * Updates the taskId from a temporary ID to the real task ID. + * Only allowed when this TaskManager was created with isTempId=true. + * Called by DefaultRequestHandler when switching from temp-UUID to real task.id. + */ + public void setTaskId(String newTaskId) { + if (!isTempId) { + throw new IllegalStateException("Cannot change taskId - not created with temporary ID"); + } + LOGGER.debug("TaskManager switching taskId from {} to {}", this.taskId, newTaskId); + this.taskId = newTaskId; + this.isTempId = false; // No longer temporary after switch } @Nullable String getTaskId() { @@ -131,9 +147,25 @@ public Task updateWithMessage(Message message, Task task) { private void checkIdsAndUpdateIfNecessary(String eventTaskId, String eventContextId) throws A2AServerException { if (taskId != null && !eventTaskId.equals(taskId)) { - throw new A2AServerException( - "Invalid task id", - new InvalidParamsError(String.format("Task in event doesn't match TaskManager "))); + // Allow switching from temporary ID to real task ID + // This happens when client sends message without taskId and agent creates Task with real ID + if (isTempId) { + // Verify the new task ID doesn't already exist in the store + // If it does, the agent is trying to return an existing task when it should create a new one + Task existingTask = taskStore.get(eventTaskId); + if (existingTask != null) { + throw new A2AServerException( + "Invalid task id", + new InternalError(String.format("Agent returned existing task ID %s when expecting new task", eventTaskId))); + } + LOGGER.debug("TaskManager allowing taskId switch from temp {} to real {}", taskId, eventTaskId); + taskId = eventTaskId; + isTempId = false; // No longer temporary after switch + } else { + throw new A2AServerException( + "Invalid task id", + new InternalError(String.format("Task in event doesn't match TaskManager "))); + } } if (taskId == null) { taskId = eventTaskId; diff --git a/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java b/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java index e26dd55fb..eee254ba3 100644 --- a/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java +++ b/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java @@ -1,8 +1,8 @@ package io.a2a.server.util.async; +import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -26,6 +26,7 @@ public class AsyncExecutorProducer { private static final String A2A_EXECUTOR_CORE_POOL_SIZE = "a2a.executor.core-pool-size"; private static final String A2A_EXECUTOR_MAX_POOL_SIZE = "a2a.executor.max-pool-size"; private static final String A2A_EXECUTOR_KEEP_ALIVE_SECONDS = "a2a.executor.keep-alive-seconds"; + private static final String A2A_EXECUTOR_QUEUE_CAPACITY = "a2a.executor.queue-capacity"; @Inject A2AConfigProvider configProvider; @@ -57,6 +58,16 @@ public class AsyncExecutorProducer { */ long keepAliveSeconds; + /** + * Queue capacity for pending tasks. + *

    + * Property: {@code a2a.executor.queue-capacity}
    + * Default: 100
    + * Note: Must be bounded to allow pool growth to maxPoolSize. + * When queue is full, new threads are created up to maxPoolSize. + */ + int queueCapacity; + private @Nullable ExecutorService executor; @PostConstruct @@ -64,18 +75,34 @@ public void init() { corePoolSize = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_CORE_POOL_SIZE)); maxPoolSize = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_MAX_POOL_SIZE)); keepAliveSeconds = Long.parseLong(configProvider.getValue(A2A_EXECUTOR_KEEP_ALIVE_SECONDS)); - - LOGGER.info("Initializing async executor: corePoolSize={}, maxPoolSize={}, keepAliveSeconds={}", - corePoolSize, maxPoolSize, keepAliveSeconds); - - executor = new ThreadPoolExecutor( + queueCapacity = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_QUEUE_CAPACITY)); + + LOGGER.info("Initializing async executor: corePoolSize={}, maxPoolSize={}, keepAliveSeconds={}, queueCapacity={}", + corePoolSize, maxPoolSize, keepAliveSeconds, queueCapacity); + + // CRITICAL: Use ArrayBlockingQueue (bounded) instead of LinkedBlockingQueue (unbounded). + // With unbounded queue, ThreadPoolExecutor NEVER grows beyond corePoolSize because the + // queue never fills. This causes executor pool exhaustion during concurrent requests when + // EventConsumer polling threads hold all core threads and agent tasks queue indefinitely. + // Bounded queue enables pool growth: when queue is full, new threads are created up to + // maxPoolSize, preventing agent execution starvation. + ThreadPoolExecutor tpe = new ThreadPoolExecutor( corePoolSize, maxPoolSize, keepAliveSeconds, TimeUnit.SECONDS, - new LinkedBlockingQueue<>(), + new ArrayBlockingQueue<>(queueCapacity), new A2AThreadFactory() ); + + // CRITICAL: Allow core threads to timeout after keepAliveSeconds when idle. + // By default, ThreadPoolExecutor only times out threads above corePoolSize. + // Without this, core threads accumulate during testing and never clean up. + // This is essential for streaming scenarios where many short-lived tasks create threads + // for agent execution and cleanup callbacks, but those threads remain idle afterward. + tpe.allowCoreThreadTimeOut(true); + + executor = tpe; } @PreDestroy @@ -106,6 +133,22 @@ public Executor produce() { return executor; } + /** + * Log current executor pool statistics for diagnostics. + * Useful for debugging pool exhaustion or sizing issues. + */ + public void logPoolStats() { + if (executor instanceof ThreadPoolExecutor tpe) { + LOGGER.info("Executor pool stats: active={}/{}, queued={}/{}, completed={}, total={}", + tpe.getActiveCount(), + tpe.getPoolSize(), + tpe.getQueue().size(), + queueCapacity, + tpe.getCompletedTaskCount(), + tpe.getTaskCount()); + } + } + private static class A2AThreadFactory implements ThreadFactory { private final AtomicInteger threadNumber = new AtomicInteger(1); private final String namePrefix = "a2a-agent-executor-"; diff --git a/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java b/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java new file mode 100644 index 000000000..24ff7f5d1 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java @@ -0,0 +1,93 @@ +package io.a2a.server.util.async; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Produces; +import jakarta.inject.Qualifier; + +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static java.lang.annotation.ElementType.*; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +/** + * Produces a dedicated executor for EventConsumer polling threads. + *

    + * CRITICAL: EventConsumer polling must use a separate executor from AgentExecutor because: + *

      + *
    • EventConsumer threads are I/O-bound (blocking on queue.poll()), not CPU-bound
    • + *
    • One EventConsumer thread needed per active queue (can be 100+ concurrent)
    • + *
    • Threads are mostly idle, waiting for events
    • + *
    • Using the same bounded pool as AgentExecutor causes deadlock when pool exhausted
    • + *
    + *

    + * Uses a cached thread pool (unbounded) with automatic thread reclamation: + *

      + *
    • Creates threads on demand as EventConsumers start
    • + *
    • Idle threads automatically terminated after 10 seconds
    • + *
    • No queue saturation since threads are created as needed
    • + *
    + */ +@ApplicationScoped +public class EventConsumerExecutorProducer { + private static final Logger LOGGER = LoggerFactory.getLogger(EventConsumerExecutorProducer.class); + + /** + * Qualifier annotation for EventConsumer executor injection. + */ + @Retention(RUNTIME) + @Target({METHOD, FIELD, PARAMETER, TYPE}) + @Qualifier + public @interface EventConsumerExecutor { + } + + /** + * Thread factory for EventConsumer threads. + */ + private static class EventConsumerThreadFactory implements ThreadFactory { + private final AtomicInteger threadNumber = new AtomicInteger(1); + + @Override + public Thread newThread(Runnable r) { + Thread thread = new Thread(r, "a2a-event-consumer-" + threadNumber.getAndIncrement()); + thread.setDaemon(true); + return thread; + } + } + + private @Nullable ExecutorService executor; + + @Produces + @EventConsumerExecutor + @ApplicationScoped + public Executor eventConsumerExecutor() { + // Cached thread pool with 10s idle timeout (reduced from default 60s): + // - Creates threads on demand as EventConsumers start + // - Reclaims idle threads after 10s to prevent accumulation during fast test execution + // - Perfect for I/O-bound EventConsumer polling which blocks on queue.poll() + // - 10s timeout balances thread reuse (production) vs cleanup (testing) + executor = new ThreadPoolExecutor( + 0, // corePoolSize - no core threads + Integer.MAX_VALUE, // maxPoolSize - unbounded + 10, TimeUnit.SECONDS, // keepAliveTime - 10s idle timeout + new SynchronousQueue<>(), // queue - same as cached pool + new EventConsumerThreadFactory() + ); + + LOGGER.info("Initialized EventConsumer executor: cached thread pool (unbounded, 10s idle timeout)"); + + return executor; + } +} diff --git a/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java b/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java new file mode 100644 index 000000000..737fbac23 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java @@ -0,0 +1,136 @@ +package io.a2a.server.util.sse; + +import io.a2a.grpc.utils.JSONRPCUtils; +import io.a2a.jsonrpc.common.wrappers.A2AErrorResponse; +import io.a2a.jsonrpc.common.wrappers.A2AResponse; +import io.a2a.jsonrpc.common.wrappers.CancelTaskResponse; +import io.a2a.jsonrpc.common.wrappers.DeleteTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.GetExtendedAgentCardResponse; +import io.a2a.jsonrpc.common.wrappers.GetTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.GetTaskResponse; +import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.ListTasksResponse; +import io.a2a.jsonrpc.common.wrappers.SendMessageResponse; +import io.a2a.jsonrpc.common.wrappers.SendStreamingMessageResponse; +import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigResponse; + +/** + * Framework-agnostic utility for formatting A2A responses as Server-Sent Events (SSE). + *

    + * Provides static methods to serialize A2A responses to JSON and format them as SSE events. + * This allows HTTP server frameworks (Vert.x, Jakarta/WildFly, etc.) to use their own + * reactive libraries for publisher mapping while sharing the serialization logic. + *

    + * Example usage (Quarkus/Vert.x with Mutiny): + *

    {@code
    + * Flow.Publisher> responses = handler.onMessageSendStream(request, context);
    + * AtomicLong eventId = new AtomicLong(0);
    + *
    + * Multi sseEvents = Multi.createFrom().publisher(responses)
    + *     .map(response -> SseFormatter.formatResponseAsSSE(response, eventId.getAndIncrement()));
    + *
    + * sseEvents.subscribe().with(sseEvent -> httpResponse.write(Buffer.buffer(sseEvent)));
    + * }
    + *

    + * Example usage (Jakarta/WildFly with custom reactive library): + *

    {@code
    + * Flow.Publisher jsonStrings = restHandler.getJsonPublisher();
    + * AtomicLong eventId = new AtomicLong(0);
    + *
    + * Flow.Publisher sseEvents = mapPublisher(jsonStrings,
    + *     json -> SseFormatter.formatJsonAsSSE(json, eventId.getAndIncrement()));
    + * }
    + */ +public class SseFormatter { + + private SseFormatter() { + // Utility class - prevent instantiation + } + + /** + * Format an A2A response as an SSE event. + *

    + * Serializes the response to JSON and formats as: + *

    +     * data: {"jsonrpc":"2.0","result":{...},"id":123}
    +     * id: 0
    +     *
    +     * 
    + * + * @param response the A2A response to format + * @param eventId the SSE event ID + * @return SSE-formatted string (ready to write to HTTP response) + */ + public static String formatResponseAsSSE(A2AResponse response, long eventId) { + String jsonData = serializeResponse(response); + return "data: " + jsonData + "\nid: " + eventId + "\n\n"; + } + + /** + * Format a pre-serialized JSON string as an SSE event. + *

    + * Wraps the JSON in SSE format as: + *

    +     * data: {"jsonrpc":"2.0","result":{...},"id":123}
    +     * id: 0
    +     *
    +     * 
    + *

    + * Use this when you already have JSON strings (e.g., from REST transport) + * and just need to add SSE formatting. + * + * @param jsonString the JSON string to wrap + * @param eventId the SSE event ID + * @return SSE-formatted string (ready to write to HTTP response) + */ + public static String formatJsonAsSSE(String jsonString, long eventId) { + return "data: " + jsonString + "\nid: " + eventId + "\n\n"; + } + + /** + * Serialize an A2AResponse to JSON string. + */ + private static String serializeResponse(A2AResponse response) { + // For error responses, use standard JSON-RPC error format + if (response instanceof A2AErrorResponse error) { + return JSONRPCUtils.toJsonRPCErrorResponse(error.getId(), error.getError()); + } + if (response.getError() != null) { + return JSONRPCUtils.toJsonRPCErrorResponse(response.getId(), response.getError()); + } + + // Convert domain response to protobuf message and serialize + com.google.protobuf.MessageOrBuilder protoMessage = convertToProto(response); + return JSONRPCUtils.toJsonRPCResultResponse(response.getId(), protoMessage); + } + + /** + * Convert A2AResponse to protobuf message for serialization. + */ + private static com.google.protobuf.MessageOrBuilder convertToProto(A2AResponse response) { + if (response instanceof GetTaskResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.task(r.getResult()); + } else if (response instanceof CancelTaskResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.task(r.getResult()); + } else if (response instanceof SendMessageResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.taskOrMessage(r.getResult()); + } else if (response instanceof ListTasksResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.listTasksResult(r.getResult()); + } else if (response instanceof SetTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.setTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof GetTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.getTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof ListTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.listTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof DeleteTaskPushNotificationConfigResponse) { + // DeleteTaskPushNotificationConfig has no result body, just return empty message + return com.google.protobuf.Empty.getDefaultInstance(); + } else if (response instanceof GetExtendedAgentCardResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.getExtendedCardResponse(r.getResult()); + } else if (response instanceof SendStreamingMessageResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.taskOrMessageStream(r.getResult()); + } else { + throw new IllegalArgumentException("Unknown response type: " + response.getClass().getName()); + } + } +} diff --git a/server-common/src/main/java/io/a2a/server/util/sse/package-info.java b/server-common/src/main/java/io/a2a/server/util/sse/package-info.java new file mode 100644 index 000000000..7e668b632 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/util/sse/package-info.java @@ -0,0 +1,11 @@ +/** + * Server-Sent Events (SSE) formatting utilities for A2A streaming responses. + *

    + * Provides framework-agnostic conversion of {@code Flow.Publisher>} to + * {@code Flow.Publisher} with SSE formatting, enabling easy integration with + * any HTTP server framework (Vert.x, Jakarta Servlet, etc.). + */ +@NullMarked +package io.a2a.server.util.sse; + +import org.jspecify.annotations.NullMarked; diff --git a/server-common/src/main/resources/META-INF/a2a-defaults.properties b/server-common/src/main/resources/META-INF/a2a-defaults.properties index 280fd943b..719be9e7a 100644 --- a/server-common/src/main/resources/META-INF/a2a-defaults.properties +++ b/server-common/src/main/resources/META-INF/a2a-defaults.properties @@ -19,3 +19,7 @@ a2a.executor.max-pool-size=50 # Keep-alive time for idle threads (seconds) a2a.executor.keep-alive-seconds=60 + +# Queue capacity for pending tasks (must be bounded to enable pool growth) +# When queue is full, new threads are created up to max-pool-size +a2a.executor.queue-capacity=100 diff --git a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java index 4354f1639..146bfb10a 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java +++ b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java @@ -16,6 +16,8 @@ import java.util.concurrent.atomic.AtomicReference; import io.a2a.jsonrpc.common.json.JsonProcessingException; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.spec.A2AError; import io.a2a.spec.A2AServerException; import io.a2a.spec.Artifact; @@ -27,14 +29,19 @@ import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class EventConsumerTest { + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + private static final String TASK_ID = "123"; // Must match MINIMAL_TASK id + private EventQueue eventQueue; private EventConsumer eventConsumer; - + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; private static final String MINIMAL_TASK = """ { @@ -54,10 +61,59 @@ public class EventConsumerTest { @BeforeEach public void init() { - eventQueue = EventQueue.builder().build(); + // Set up MainEventBus and processor for production-like test environment + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + + eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .mainEventBus(mainEventBus) + .build().tap(); eventConsumer = new EventConsumer(eventQueue); } + @AfterEach + public void cleanup() { + if (mainEventBusProcessor != null) { + mainEventBusProcessor.setCallback(null); // Clear any test callbacks + EventQueueUtil.stop(mainEventBusProcessor); + } + } + + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. + * + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + mainEventBusProcessor.setCallback(new MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } + } + @Test public void testConsumeOneTaskEvent() throws Exception { Task event = fromJson(MINIMAL_TASK, Task.class); @@ -92,7 +148,7 @@ public void testConsumeAllMultipleEvents() throws JsonProcessingException { List events = List.of( fromJson(MINIMAL_TASK, Task.class), TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -100,7 +156,7 @@ public void testConsumeAllMultipleEvents() throws JsonProcessingException { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -128,7 +184,7 @@ public void testConsumeUntilMessage() throws Exception { List events = List.of( fromJson(MINIMAL_TASK, Task.class), TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -136,7 +192,7 @@ public void testConsumeUntilMessage() throws Exception { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -185,14 +241,14 @@ public void testConsumeMessageEvents() throws Exception { @Test public void testConsumeTaskInputRequired() { Task task = Task.builder() - .id("task-id") - .contextId("task-context") + .id(TASK_ID) + .contextId("session-xyz") .status(new TaskStatus(TaskState.INPUT_REQUIRED)) .build(); List events = List.of( task, TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -200,7 +256,7 @@ public void testConsumeTaskInputRequired() { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -332,7 +388,9 @@ public void onComplete() { @Test public void testConsumeAllStopsOnQueueClosed() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Close the queue immediately @@ -378,12 +436,16 @@ public void onComplete() { @Test public void testConsumeAllHandlesQueueClosedException() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Add a message event (which will complete the stream) Event message = fromJson(MESSAGE_PAYLOAD, Message.class); - queue.enqueueEvent(message); + + // Use callback to wait for event processing + waitForEventProcessing(() -> queue.enqueueEvent(message)); // Close the queue before consuming queue.close(); @@ -428,11 +490,13 @@ public void onComplete() { @Test public void testConsumeAllTerminatesOnQueueClosedEvent() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Enqueue a QueueClosedEvent (poison pill) - QueueClosedEvent queueClosedEvent = new QueueClosedEvent("task-123"); + QueueClosedEvent queueClosedEvent = new QueueClosedEvent(TASK_ID); queue.enqueueEvent(queueClosedEvent); Flow.Publisher publisher = consumer.consumeAll(); @@ -477,8 +541,12 @@ public void onComplete() { } private void enqueueAndConsumeOneEvent(Event event) throws Exception { - eventQueue.enqueueEvent(event); + // Use callback to wait for event processing + waitForEventProcessing(() -> eventQueue.enqueueEvent(event)); + + // Event is now available, consume it directly Event result = eventConsumer.consumeOne(); + assertNotNull(result, "Event should be available"); assertSame(event, result); } diff --git a/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java b/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java index a3dc7d916..2499a8173 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java +++ b/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java @@ -11,7 +11,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.spec.A2AError; import io.a2a.spec.Artifact; import io.a2a.spec.Event; @@ -23,12 +27,17 @@ import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class EventQueueTest { private EventQueue eventQueue; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; + + private static final String TASK_ID = "123"; // Must match MINIMAL_TASK id private static final String MINIMAL_TASK = """ { @@ -46,38 +55,96 @@ public class EventQueueTest { } """; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; @BeforeEach public void init() { - eventQueue = EventQueue.builder().build(); + // Set up MainEventBus and processor for production-like test environment + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + + eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .mainEventBus(mainEventBus) + .build().tap(); + } + + @AfterEach + public void cleanup() { + if (mainEventBusProcessor != null) { + mainEventBusProcessor.setCallback(null); // Clear any test callbacks + EventQueueUtil.stop(mainEventBusProcessor); + } + } + /** + * Helper to create a queue with MainEventBus configured (for tests that need event distribution). + */ + private EventQueue createQueueWithEventBus(String taskId) { + return EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(taskId) + .build(); + } + + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. + * + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, io.a2a.spec.Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } } @Test public void testConstructorDefaultQueueSize() { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); assertEquals(EventQueue.DEFAULT_QUEUE_SIZE, queue.getQueueSize()); } @Test public void testConstructorCustomQueueSize() { int customSize = 500; - EventQueue queue = EventQueue.builder().queueSize(customSize).build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(customSize).build(); assertEquals(customSize, queue.getQueueSize()); } @Test public void testConstructorInvalidQueueSize() { // Test zero queue size - assertThrows(IllegalArgumentException.class, () -> EventQueue.builder().queueSize(0).build()); + assertThrows(IllegalArgumentException.class, () -> EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(0).build()); // Test negative queue size - assertThrows(IllegalArgumentException.class, () -> EventQueue.builder().queueSize(-10).build()); + assertThrows(IllegalArgumentException.class, () -> EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(-10).build()); } @Test public void testTapCreatesChildQueue() { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue childQueue = parentQueue.tap(); assertNotNull(childQueue); @@ -87,7 +154,7 @@ public void testTapCreatesChildQueue() { @Test public void testTapOnChildQueueThrowsException() { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue childQueue = parentQueue.tap(); assertThrows(IllegalStateException.class, () -> childQueue.tap()); @@ -95,69 +162,74 @@ public void testTapOnChildQueueThrowsException() { @Test public void testEnqueueEventPropagagesToChildren() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - parentQueue.enqueueEvent(event); + mainQueue.enqueueEvent(event); - // Event should be available in both parent and child queues - Event parentEvent = parentQueue.dequeueEventItem(-1).getEvent(); - Event childEvent = childQueue.dequeueEventItem(-1).getEvent(); + // Event should be available in all child queues + // Note: MainEventBusProcessor runs async, so we use dequeueEventItem with timeout + Event child1Event = childQueue1.dequeueEventItem(5000).getEvent(); + Event child2Event = childQueue2.dequeueEventItem(5000).getEvent(); - assertSame(event, parentEvent); - assertSame(event, childEvent); + assertSame(event, child1Event); + assertSame(event, child2Event); } @Test public void testMultipleChildQueuesReceiveEvents() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue1 = parentQueue.tap(); - EventQueue childQueue2 = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); + EventQueue childQueue3 = mainQueue.tap(); Event event1 = fromJson(MINIMAL_TASK, Task.class); Event event2 = fromJson(MESSAGE_PAYLOAD, Message.class); - parentQueue.enqueueEvent(event1); - parentQueue.enqueueEvent(event2); + mainQueue.enqueueEvent(event1); + mainQueue.enqueueEvent(event2); - // All queues should receive both events - assertSame(event1, parentQueue.dequeueEventItem(-1).getEvent()); - assertSame(event2, parentQueue.dequeueEventItem(-1).getEvent()); + // All child queues should receive both events + // Note: Use timeout for async processing + assertSame(event1, childQueue1.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue1.dequeueEventItem(5000).getEvent()); - assertSame(event1, childQueue1.dequeueEventItem(-1).getEvent()); - assertSame(event2, childQueue1.dequeueEventItem(-1).getEvent()); + assertSame(event1, childQueue2.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue2.dequeueEventItem(5000).getEvent()); - assertSame(event1, childQueue2.dequeueEventItem(-1).getEvent()); - assertSame(event2, childQueue2.dequeueEventItem(-1).getEvent()); + assertSame(event1, childQueue3.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue3.dequeueEventItem(5000).getEvent()); } @Test public void testChildQueueDequeueIndependently() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue1 = parentQueue.tap(); - EventQueue childQueue2 = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); + EventQueue childQueue3 = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - parentQueue.enqueueEvent(event); + mainQueue.enqueueEvent(event); - // Dequeue from child1 first - Event child1Event = childQueue1.dequeueEventItem(-1).getEvent(); + // Dequeue from child1 first (use timeout for async processing) + Event child1Event = childQueue1.dequeueEventItem(5000).getEvent(); assertSame(event, child1Event); // child2 should still have the event available - Event child2Event = childQueue2.dequeueEventItem(-1).getEvent(); + Event child2Event = childQueue2.dequeueEventItem(5000).getEvent(); assertSame(event, child2Event); - // Parent should still have the event available - Event parentEvent = parentQueue.dequeueEventItem(-1).getEvent(); - assertSame(event, parentEvent); + // child3 should still have the event available + Event child3Event = childQueue3.dequeueEventItem(5000).getEvent(); + assertSame(event, child3Event); } @Test public void testCloseImmediatePropagationToChildren() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = createQueueWithEventBus(TASK_ID); EventQueue childQueue = parentQueue.tap(); // Add events to both parent and child @@ -166,7 +238,7 @@ public void testCloseImmediatePropagationToChildren() throws Exception { assertFalse(childQueue.isClosed()); try { - assertNotNull(childQueue.dequeueEventItem(-1)); // Child has the event + assertNotNull(childQueue.dequeueEventItem(5000)); // Child has the event (use timeout) } catch (EventQueueClosedException e) { // This is fine if queue closed before dequeue } @@ -187,27 +259,37 @@ public void testCloseImmediatePropagationToChildren() throws Exception { @Test public void testEnqueueEventWhenClosed() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .build(); + EventQueue childQueue = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - queue.close(); // Close the queue first - assertTrue(queue.isClosed()); + childQueue.close(); // Close the child queue first (removes from children list) + assertTrue(childQueue.isClosed()); + + // Create a new child queue BEFORE enqueuing (ensures it's in children list for distribution) + EventQueue newChildQueue = mainQueue.tap(); // MainQueue accepts events even when closed (for replication support) // This ensures late-arriving replicated events can be enqueued to closed queues - queue.enqueueEvent(event); + // Note: MainEventBusProcessor runs asynchronously, so we use dequeueEventItem with timeout + mainQueue.enqueueEvent(event); - // Event should be available for dequeuing - Event dequeuedEvent = queue.dequeueEventItem(-1).getEvent(); + // New child queue should receive the event (old closed child was removed from children list) + EventQueueItem item = newChildQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); - // Now queue is closed and empty, should throw exception - assertThrows(EventQueueClosedException.class, () -> queue.dequeueEventItem(-1)); + // Now new child queue is closed and empty, should throw exception + newChildQueue.close(); + assertThrows(EventQueueClosedException.class, () -> newChildQueue.dequeueEventItem(-1)); } @Test public void testDequeueEventWhenClosedAndEmpty() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build().tap(); queue.close(); assertTrue(queue.isClosed()); @@ -217,19 +299,27 @@ public void testDequeueEventWhenClosedAndEmpty() throws Exception { @Test public void testDequeueEventWhenClosedButHasEvents() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .build(); + EventQueue childQueue = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - queue.enqueueEvent(event); - queue.close(); // Graceful close - events should remain - assertTrue(queue.isClosed()); + // Use callback to wait for event processing instead of polling + waitForEventProcessing(() -> mainQueue.enqueueEvent(event)); - // Should still be able to dequeue existing events - Event dequeuedEvent = queue.dequeueEventItem(-1).getEvent(); + // At this point, event has been processed and distributed to childQueue + childQueue.close(); // Graceful close - events should remain + assertTrue(childQueue.isClosed()); + + // Should still be able to dequeue existing events from closed queue + EventQueueItem item = childQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); // Now queue is closed and empty, should throw exception - assertThrows(EventQueueClosedException.class, () -> queue.dequeueEventItem(-1)); + assertThrows(EventQueueClosedException.class, () -> childQueue.dequeueEventItem(-1)); } @Test @@ -244,7 +334,9 @@ public void testEnqueueAndDequeueEvent() throws Exception { public void testDequeueEventNoWait() throws Exception { Event event = fromJson(MINIMAL_TASK, Task.class); eventQueue.enqueueEvent(event); - Event dequeuedEvent = eventQueue.dequeueEventItem(-1).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); } @@ -257,7 +349,7 @@ public void testDequeueEventEmptyQueueNoWait() throws Exception { @Test public void testDequeueEventWait() throws Exception { Event event = TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -271,7 +363,7 @@ public void testDequeueEventWait() throws Exception { @Test public void testTaskDone() throws Exception { Event event = TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -347,7 +439,7 @@ public void testCloseIdempotent() throws Exception { assertTrue(eventQueue.isClosed()); // Test with immediate close as well - EventQueue eventQueue2 = EventQueue.builder().build(); + EventQueue eventQueue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); eventQueue2.close(true); assertTrue(eventQueue2.isClosed()); @@ -361,19 +453,20 @@ public void testCloseIdempotent() throws Exception { */ @Test public void testCloseChildQueues() throws Exception { - EventQueue childQueue = eventQueue.tap(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); + EventQueue childQueue = mainQueue.tap(); assertTrue(childQueue != null); // Graceful close - parent closes but children remain open - eventQueue.close(); - assertTrue(eventQueue.isClosed()); + mainQueue.close(); + assertTrue(mainQueue.isClosed()); assertFalse(childQueue.isClosed()); // Child NOT closed on graceful parent close // Immediate close - parent force-closes all children - EventQueue parentQueue2 = EventQueue.builder().build(); - EventQueue childQueue2 = parentQueue2.tap(); - parentQueue2.close(true); // immediate=true - assertTrue(parentQueue2.isClosed()); + EventQueue mainQueue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); + EventQueue childQueue2 = mainQueue2.tap(); + mainQueue2.close(true); // immediate=true + assertTrue(mainQueue2.isClosed()); assertTrue(childQueue2.isClosed()); // Child IS closed on immediate parent close } @@ -383,7 +476,7 @@ public void testCloseChildQueues() throws Exception { */ @Test public void testMainQueueReferenceCountingStaysOpenWithActiveChildren() throws Exception { - EventQueue mainQueue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue child1 = mainQueue.tap(); EventQueue child2 = mainQueue.tap(); diff --git a/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java b/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java index 39201c1f6..6c9ed4a17 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java +++ b/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java @@ -1,8 +1,39 @@ package io.a2a.server.events; +import java.util.concurrent.atomic.AtomicInteger; + public class EventQueueUtil { - // Since EventQueue.builder() is package protected, add a method to expose it - public static EventQueue.EventQueueBuilder getEventQueueBuilder() { - return EventQueue.builder(); + // Counter for generating unique test taskIds + private static final AtomicInteger TASK_ID_COUNTER = new AtomicInteger(0); + + /** + * Get an EventQueue builder pre-configured with the shared test MainEventBus and a unique taskId. + *

    + * Note: Returns MainQueue - tests should call .tap() if they need to consume events. + *

    + * + * @return builder with TEST_EVENT_BUS and unique taskId already set + */ + public static EventQueue.EventQueueBuilder getEventQueueBuilder(MainEventBus eventBus) { + return EventQueue.builder(eventBus) + .taskId("test-task-" + TASK_ID_COUNTER.incrementAndGet()); + } + + /** + * Start a MainEventBusProcessor instance. + * + * @param processor the processor to start + */ + public static void start(MainEventBusProcessor processor) { + processor.start(); + } + + /** + * Stop a MainEventBusProcessor instance. + * + * @param processor the processor to stop + */ + public static void stop(MainEventBusProcessor processor) { + processor.stop(); } } diff --git a/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java b/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java index 1eca1b739..3e09ff2af 100644 --- a/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java +++ b/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java @@ -14,7 +14,10 @@ import java.util.concurrent.ExecutionException; import java.util.stream.IntStream; +import io.a2a.server.tasks.InMemoryTaskStore; import io.a2a.server.tasks.MockTaskStateProvider; +import io.a2a.server.tasks.PushNotificationSender; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -22,17 +25,30 @@ public class InMemoryQueueManagerTest { private InMemoryQueueManager queueManager; private MockTaskStateProvider taskStateProvider; + private InMemoryTaskStore taskStore; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; @BeforeEach public void setUp() { taskStateProvider = new MockTaskStateProvider(); - queueManager = new InMemoryQueueManager(taskStateProvider); + taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + queueManager = new InMemoryQueueManager(taskStateProvider, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + } + + @AfterEach + public void tearDown() { + EventQueueUtil.stop(mainEventBusProcessor); } @Test public void testAddNewQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); @@ -43,8 +59,8 @@ public void testAddNewQueue() { @Test public void testAddExistingQueueThrowsException() { String taskId = "test_task_id"; - EventQueue queue1 = EventQueue.builder().build(); - EventQueue queue2 = EventQueue.builder().build(); + EventQueue queue1 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); + EventQueue queue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue1); @@ -56,7 +72,7 @@ public void testAddExistingQueueThrowsException() { @Test public void testGetExistingQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); EventQueue result = queueManager.get(taskId); @@ -73,7 +89,7 @@ public void testGetNonexistentQueue() { @Test public void testTapExistingQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); EventQueue tappedQueue = queueManager.tap(taskId); @@ -94,7 +110,7 @@ public void testTapNonexistentQueue() { @Test public void testCloseExistingQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); queueManager.close(taskId); @@ -129,7 +145,7 @@ public void testCreateOrTapNewQueue() { @Test public void testCreateOrTapExistingQueue() { String taskId = "test_task_id"; - EventQueue originalQueue = EventQueue.builder().build(); + EventQueue originalQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, originalQueue); EventQueue result = queueManager.createOrTap(taskId); @@ -151,7 +167,7 @@ public void testConcurrentOperations() throws InterruptedException, ExecutionExc // Add tasks concurrently List> addFutures = taskIds.stream() .map(taskId -> CompletableFuture.supplyAsync(() -> { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); return taskId; })) diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java index ea5bbe797..9c64f03f9 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java @@ -26,7 +26,10 @@ import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; import io.a2a.server.events.EventQueueItem; +import io.a2a.server.events.EventQueueUtil; import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; import io.a2a.server.tasks.BasePushNotificationSender; import io.a2a.server.tasks.InMemoryPushNotificationConfigStore; import io.a2a.server.tasks.InMemoryTaskStore; @@ -66,6 +69,8 @@ public class AbstractA2ARequestHandlerTest { private static final String PREFERRED_TRANSPORT = "preferred-transport"; private static final String A2A_REQUESTHANDLER_TEST_PROPERTIES = "/a2a-requesthandler-test.properties"; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + protected AgentExecutor executor; protected TaskStore taskStore; protected RequestHandler requestHandler; @@ -73,6 +78,8 @@ public class AbstractA2ARequestHandlerTest { protected AgentExecutorMethod agentExecutorCancel; protected InMemoryQueueManager queueManager; protected TestHttpClient httpClient; + protected MainEventBus mainEventBus; + protected MainEventBusProcessor mainEventBusProcessor; protected final Executor internalExecutor = Executors.newCachedThreadPool(); @@ -96,19 +103,31 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws A2AErro InMemoryTaskStore inMemoryTaskStore = new InMemoryTaskStore(); taskStore = inMemoryTaskStore; - queueManager = new InMemoryQueueManager(inMemoryTaskStore); + + // Create push notification components BEFORE MainEventBusProcessor httpClient = new TestHttpClient(); PushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); PushNotificationSender pushSender = new BasePushNotificationSender(pushConfigStore, httpClient); + // Create MainEventBus and MainEventBusProcessor (production code path) + mainEventBus = new MainEventBus(); + queueManager = new InMemoryQueueManager(inMemoryTaskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, pushSender, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, pushConfigStore, pushSender, internalExecutor); + executor, taskStore, queueManager, pushConfigStore, mainEventBusProcessor, internalExecutor, internalExecutor); } @AfterEach public void cleanup() { agentExecutorExecute = null; agentExecutorCancel = null; + + // Stop MainEventBusProcessor background thread + if (mainEventBusProcessor != null) { + EventQueueUtil.stop(mainEventBusProcessor); + } } protected static AgentCard createAgentCard(boolean streaming, boolean pushNotifications, boolean stateTransitionHistory) { diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java index 293babe4e..19df17aac 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java @@ -3,11 +3,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -17,11 +19,16 @@ import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.auth.UnauthenticatedUser; import io.a2a.server.events.EventQueue; +import io.a2a.server.events.EventQueueUtil; import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; import io.a2a.server.tasks.InMemoryPushNotificationConfigStore; import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.server.tasks.TaskUpdater; import io.a2a.spec.A2AError; +import io.a2a.spec.Event; import io.a2a.spec.ListTaskPushNotificationConfigParams; import io.a2a.spec.ListTaskPushNotificationConfigResult; import io.a2a.spec.Message; @@ -31,7 +38,9 @@ import io.a2a.spec.Task; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -50,26 +59,74 @@ public class DefaultRequestHandlerTest { private InMemoryQueueManager queueManager; private TestAgentExecutor agentExecutor; private ServerCallContext serverCallContext; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; + + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; @BeforeEach void setUp() { taskStore = new InMemoryTaskStore(); + + // Create MainEventBus and MainEventBusProcessor (production code path) + mainEventBus = new MainEventBus(); // Pass taskStore as TaskStateProvider to queueManager for task-aware queue management - queueManager = new InMemoryQueueManager(taskStore); + queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + agentExecutor = new TestAgentExecutor(); + ExecutorService executor = Executors.newCachedThreadPool(); requestHandler = DefaultRequestHandler.create( agentExecutor, taskStore, queueManager, null, // pushConfigStore - null, // pushSender - Executors.newCachedThreadPool() + mainEventBusProcessor, + executor, + executor ); serverCallContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of(), Set.of()); } + @AfterEach + void tearDown() { + // Stop MainEventBusProcessor background thread + // Note: Don't clear callback here - DefaultRequestHandler has a permanent callback + if (mainEventBusProcessor != null) { + EventQueueUtil.stop(mainEventBusProcessor); + } + } + + /** + * Helper to wait for task finalization in background (for non-blocking tests). + *

    + * Note: Does NOT set callbacks - DefaultRequestHandler has a permanent callback. + * Simply polls TaskStore until task reaches final state. + *

    + * + * @param action the action that triggers task finalization (e.g., allowing agent to complete) + * @param taskId the task ID to wait for + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if finalization doesn't complete within timeout + */ + private void waitForTaskFinalization(Runnable action, String taskId) throws InterruptedException { + action.run(); + + // Poll TaskStore for final state (non-blocking tests complete in background) + for (int i = 0; i < 50; i++) { + Task task = taskStore.get(taskId); + if (task != null && task.status() != null && task.status().state() != null + && task.status().state().isFinal()) { + return; // Success! + } + Thread.sleep(100); + } + fail("Task " + taskId + " should have been finalized within timeout"); + } + /** * Test that multiple blocking messages to the same task work correctly * when agent doesn't emit final events (fire-and-forget pattern). @@ -576,32 +633,15 @@ void testNonBlockingMessagePersistsAllEventsInBackground() throws Exception { // At this point, the non-blocking call has returned, but the agent is still running - // Allow the agent to emit the final COMPLETED event - allowCompletion.countDown(); - - // Assertion 2: Poll for the final task state to be persisted in background - // Use polling loop instead of fixed sleep for faster and more reliable test - long timeoutMs = 5000; - long startTime = System.currentTimeMillis(); - Task persistedTask = null; - boolean completedStateFound = false; - - while (System.currentTimeMillis() - startTime < timeoutMs) { - persistedTask = taskStore.get(taskId); - if (persistedTask != null && persistedTask.status().state() == TaskState.COMPLETED) { - completedStateFound = true; - break; - } - Thread.sleep(100); // Poll every 100ms - } + // Assertion 2: Wait for the final task to be processed and finalized in background + // Poll TaskStore for finalization (background consumption) + waitForTaskFinalization(() -> allowCompletion.countDown(), taskId); - assertTrue(persistedTask != null, "Task should be persisted to store"); - assertTrue( - completedStateFound, - "Final task state should be COMPLETED (background consumption should have processed it), got: " + - (persistedTask != null ? persistedTask.status().state() : "null") + - " after " + (System.currentTimeMillis() - startTime) + "ms" - ); + // Verify the task was persisted with COMPLETED state + Task persistedTask = taskStore.get(taskId); + assertNotNull(persistedTask, "Task should be persisted to store"); + assertEquals(TaskState.COMPLETED, persistedTask.status().state(), + "Final task state should be COMPLETED (background consumption should have processed it)"); } /** @@ -779,13 +819,16 @@ void testBlockingCallReturnsCompleteTaskWithArtifacts() throws Exception { }); // Call blocking onMessageSend - should wait for ALL events + // DefaultRequestHandler now waits internally for task finalization before returning Object result = requestHandler.onMessageSend(params, serverCallContext); // The returned result should be a Task with ALL artifacts assertTrue(result instanceof Task, "Result should be a Task"); Task returnedTask = (Task) result; - // Verify task is completed + // Fetch final state from TaskStore (guaranteed to be persisted after blocking call) + returnedTask = taskStore.get(taskId); + assertEquals(TaskState.COMPLETED, returnedTask.status().state(), "Returned task should be COMPLETED"); @@ -817,13 +860,15 @@ void testBlockingMessageStoresPushNotificationConfigForNewTask() throws Exceptio InMemoryPushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); // Re-create request handler with pushConfigStore + ExecutorService pushTestExecutor = Executors.newCachedThreadPool(); requestHandler = DefaultRequestHandler.create( agentExecutor, taskStore, queueManager, pushConfigStore, // Add push config store - null, // pushSender - Executors.newCachedThreadPool() + mainEventBusProcessor, + pushTestExecutor, + pushTestExecutor ); // Create push notification config @@ -888,13 +933,15 @@ void testBlockingMessageStoresPushNotificationConfigForExistingTask() throws Exc InMemoryPushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); // Re-create request handler with pushConfigStore + ExecutorService pushTestExecutor = Executors.newCachedThreadPool(); requestHandler = DefaultRequestHandler.create( agentExecutor, taskStore, queueManager, pushConfigStore, // Add push config store - null, // pushSender - Executors.newCachedThreadPool() + mainEventBusProcessor, + pushTestExecutor, + pushTestExecutor ); // Create EXISTING task in store @@ -951,6 +998,187 @@ void testBlockingMessageStoresPushNotificationConfigForExistingTask() throws Exc assertEquals("https://example.com/existing-webhook", storedConfig.url()); } + /** + * Test that sending a message WITHOUT taskId works correctly (like TCK). + * Agent emits Task with the same ID it receives from context.getTaskId(). + */ + @Test + @Timeout(10) + void testMessageSendWithoutTaskIdCreatesTask() throws Exception { + String contextId = "temp-id-ctx"; + + // Agent does same as TCK: emit SUBMITTED task with received taskId, then WORKING (fire-and-forget) + agentExecutor.setExecuteCallback((context, queue) -> { + Task task = context.getTask(); + if (task == null) { + // First message: create SUBMITTED task using context's taskId + // (RequestContext generates a real UUID when message.taskId is null) + task = Task.builder() + .id(context.getTaskId()) + .contextId(context.getContextId()) + .status(new TaskStatus(TaskState.SUBMITTED)) + .history(List.of(context.getMessage())) + .build(); + queue.enqueueEvent(task); + } + // Set to WORKING (fire-and-forget like TCK) + TaskUpdater updater = new TaskUpdater(context, queue); + updater.startWork(); + // Don't complete - just return (fire-and-forget) + }); + + // Send message WITHOUT taskId (null) in blocking mode + Message message = Message.builder() + .messageId("msg-no-taskid") + .role(Message.Role.USER) + .parts(new TextPart("message without taskId")) + .taskId(null) // No taskId! + .contextId(contextId) + .build(); + + MessageSendConfiguration config = MessageSendConfiguration.builder() + .blocking(true) + .build(); + + MessageSendParams params = new MessageSendParams(message, config, null, ""); + + // Call blocking onMessageSend + Object result = requestHandler.onMessageSend(params, serverCallContext); + + // Verify result is a Task + assertTrue(result instanceof Task, "Result should be a Task"); + Task resultTask = (Task) result; + + // Task should have an ID (auto-generated) + assertNotNull(resultTask.id(), "Task should have an ID"); + + // ID should NOT start with "temp-" (that's just the queue key, not the task.id) + assertTrue(!resultTask.id().startsWith("temp-"), + "Task ID should not start with 'temp-', got: " + resultTask.id()); + + assertEquals(contextId, resultTask.contextId()); + assertEquals(TaskState.WORKING, resultTask.status().state()); + + // Verify task is persisted in TaskStore + Task storedTask = taskStore.get(resultTask.id()); + assertNotNull(storedTask, "Task should be stored"); + assertEquals(resultTask.id(), storedTask.id()); + } + + /** + * Test message send without taskId using streaming. + * This tests the same scenario as testMessageSendWithoutTaskIdCreatesTask but with streaming. + * TCK streaming tests were failing with similar temp-to-real ID switching issues. + */ + @Test + @Timeout(10) + void testMessageSendStreamWithoutTaskIdCreatesTask() throws Exception { + String contextId = "temp-id-ctx-stream"; + + // Agent does same as TCK: emit SUBMITTED task with received taskId, then WORKING (fire-and-forget) + agentExecutor.setExecuteCallback((context, queue) -> { + Task task = context.getTask(); + if (task == null) { + // First message: create SUBMITTED task using context's taskId + // (RequestContext generates a real UUID when message.taskId is null) + task = Task.builder() + .id(context.getTaskId()) + .contextId(context.getContextId()) + .status(new TaskStatus(TaskState.SUBMITTED)) + .history(List.of(context.getMessage())) + .build(); + queue.enqueueEvent(task); + } + // Set to WORKING (fire-and-forget like TCK) + TaskUpdater updater = new TaskUpdater(context, queue); + updater.startWork(); + // Don't complete - just return (fire-and-forget) + }); + + // Send message WITHOUT taskId (null) in streaming mode + Message message = Message.builder() + .messageId("msg-no-taskid-stream") + .role(Message.Role.USER) + .parts(new TextPart("message without taskId streaming")) + .taskId(null) // No taskId! + .contextId(contextId) + .build(); + + MessageSendParams params = new MessageSendParams(message, null, null, ""); + + // Call streaming onMessageSendStream + var publisher = requestHandler.onMessageSendStream(params, serverCallContext); + + // Collect events from stream + List events = new java.util.ArrayList<>(); + CountDownLatch completionLatch = new CountDownLatch(1); + AtomicBoolean hasError = new AtomicBoolean(false); + final Throwable[] error = new Throwable[1]; + + publisher.subscribe(new java.util.concurrent.Flow.Subscriber() { + private java.util.concurrent.Flow.Subscription subscription; + + @Override + public void onSubscribe(java.util.concurrent.Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(Long.MAX_VALUE); // Request all events + } + + @Override + public void onNext(Event event) { + events.add(event); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + hasError.set(true); + error[0] = throwable; + throwable.printStackTrace(); + completionLatch.countDown(); + } + + @Override + public void onComplete() { + completionLatch.countDown(); + } + }); + + // Wait for stream to complete + assertTrue(completionLatch.await(5, TimeUnit.SECONDS), "Stream should complete"); + if (hasError.get()) { + fail("Stream had error: " + error[0], error[0]); + } + + // Should have received at least 2 events: Task (SUBMITTED) and TaskStatusUpdateEvent (WORKING) + assertTrue(events.size() >= 2, "Should have at least 2 events, got: " + events.size()); + + // First event should be Task with SUBMITTED state + Event firstEvent = events.get(0); + assertTrue(firstEvent instanceof Task, "First event should be Task"); + Task firstTask = (Task) firstEvent; + + assertNotNull(firstTask.id(), "Task should have an ID"); + assertTrue(!firstTask.id().startsWith("temp-"), + "Task ID should not start with 'temp-', got: " + firstTask.id()); + assertEquals(contextId, firstTask.contextId()); + assertEquals(TaskState.SUBMITTED, firstTask.status().state()); + + // Second event should be TaskStatusUpdateEvent with WORKING state + Event secondEvent = events.get(1); + assertTrue(secondEvent instanceof TaskStatusUpdateEvent, + "Second event should be TaskStatusUpdateEvent, got: " + secondEvent.getClass().getSimpleName()); + TaskStatusUpdateEvent statusUpdate = (TaskStatusUpdateEvent) secondEvent; + assertEquals(firstTask.id(), statusUpdate.taskId(), "Status update should have same task ID"); + assertEquals(TaskState.WORKING, statusUpdate.status().state()); + + // Verify task is persisted in TaskStore with WORKING state + Task storedTask = taskStore.get(firstTask.id()); + assertNotNull(storedTask, "Task should be stored"); + assertEquals(firstTask.id(), storedTask.id()); + assertEquals(TaskState.WORKING, storedTask.status().state(), "Stored task should have WORKING state"); + } + /** * Simple test agent executor that allows controlling execution timing */ diff --git a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java index d64729077..0e25e9aad 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java @@ -11,18 +11,25 @@ import static org.mockito.Mockito.when; import java.util.Collections; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueue; +import io.a2a.server.events.EventQueueUtil; import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; +import io.a2a.spec.Event; import io.a2a.spec.EventKind; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; @@ -49,7 +56,7 @@ public class ResultAggregatorTest { @BeforeEach void setUp() { MockitoAnnotations.openMocks(this); - aggregator = new ResultAggregator(mockTaskManager, null, testExecutor); + aggregator = new ResultAggregator(mockTaskManager, null, testExecutor, testExecutor); } // Helper methods for creating sample data @@ -69,13 +76,45 @@ private Task createSampleTask(String taskId, TaskState statusState, String conte .build(); } + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. + * + * @param processor the processor to set callback on + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(MainEventBusProcessor processor, Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + processor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + processor.setCallback(null); + } + } + // Basic functionality tests @Test void testConstructorWithMessage() { Message initialMessage = createSampleMessage("initial", "msg1", Message.Role.USER); - ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, initialMessage, testExecutor); + ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, initialMessage, testExecutor, testExecutor); // Test that the message is properly stored by checking getCurrentResult assertEquals(initialMessage, aggregatorWithMessage.getCurrentResult()); @@ -86,7 +125,7 @@ void testConstructorWithMessage() { @Test void testGetCurrentResultWithMessageSet() { Message sampleMessage = createSampleMessage("hola", "msg1", Message.Role.USER); - ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, sampleMessage, testExecutor); + ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, sampleMessage, testExecutor, testExecutor); EventKind result = aggregatorWithMessage.getCurrentResult(); @@ -121,7 +160,7 @@ void testConstructorStoresTaskManagerCorrectly() { @Test void testConstructorWithNullMessage() { - ResultAggregator aggregatorWithNullMessage = new ResultAggregator(mockTaskManager, null, testExecutor); + ResultAggregator aggregatorWithNullMessage = new ResultAggregator(mockTaskManager, null, testExecutor, testExecutor); Task expectedTask = createSampleTask("null_msg_task", TaskState.WORKING, "ctx1"); when(mockTaskManager.getTask()).thenReturn(expectedTask); @@ -181,7 +220,7 @@ void testMultipleGetCurrentResultCalls() { void testGetCurrentResultWithMessageTakesPrecedence() { // Test that when both message and task are available, message takes precedence Message message = createSampleMessage("priority message", "pri1", Message.Role.USER); - ResultAggregator messageAggregator = new ResultAggregator(mockTaskManager, message, testExecutor); + ResultAggregator messageAggregator = new ResultAggregator(mockTaskManager, message, testExecutor, testExecutor); // Even if we set up the task manager to return something, message should take precedence Task task = createSampleTask("should_not_be_returned", TaskState.WORKING, "ctx1"); @@ -197,17 +236,24 @@ void testGetCurrentResultWithMessageTakesPrecedence() { @Test void testConsumeAndBreakNonBlocking() throws Exception { // Test that with blocking=false, the method returns after the first event - Task firstEvent = createSampleTask("non_blocking_task", TaskState.WORKING, "ctx1"); + String taskId = "test-task"; + Task firstEvent = createSampleTask(taskId, TaskState.WORKING, "ctx1"); // After processing firstEvent, the current result will be that task when(mockTaskManager.getTask()).thenReturn(firstEvent); // Create an event queue using QueueManager (which has access to builder) + MainEventBus mainEventBus = new MainEventBus(); + InMemoryTaskStore taskStore = new InMemoryTaskStore(); InMemoryQueueManager queueManager = - new InMemoryQueueManager(new MockTaskStateProvider()); + new InMemoryQueueManager(new MockTaskStateProvider(), mainEventBus); + MainEventBusProcessor processor = new MainEventBusProcessor(mainEventBus, taskStore, task -> {}, queueManager); + EventQueueUtil.start(processor); + + EventQueue queue = queueManager.getEventQueueBuilder(taskId).build().tap(); - EventQueue queue = queueManager.getEventQueueBuilder("test-task").build(); - queue.enqueueEvent(firstEvent); + // Use callback to wait for event processing (replaces polling) + waitForEventProcessing(processor, () -> queue.enqueueEvent(firstEvent)); // Create real EventConsumer with the queue EventConsumer eventConsumer = @@ -221,11 +267,16 @@ void testConsumeAndBreakNonBlocking() throws Exception { assertEquals(firstEvent, result.eventType()); assertTrue(result.interrupted()); - verify(mockTaskManager).process(firstEvent); - // getTask() is called at least once for the return value (line 255) - // May be called once more if debug logging executes in time (line 209) - // The async consumer may or may not execute before verification, so we accept 1-2 calls - verify(mockTaskManager, atLeast(1)).getTask(); - verify(mockTaskManager, atMost(2)).getTask(); + // NOTE: ResultAggregator no longer calls taskManager.process() + // That responsibility has moved to MainEventBusProcessor for centralized persistence + // + // NOTE: Since firstEvent is a Task, ResultAggregator captures it directly from the queue + // (capturedTask.get() at line 283 in ResultAggregator). Therefore, taskManager.getTask() + // is only called for debug logging in taskIdForLogging() (line 305), which may or may not + // execute depending on timing and log level. We expect 0-1 calls, not 1-2. + verify(mockTaskManager, atMost(1)).getTask(); + + // Cleanup: stop the processor + EventQueueUtil.stop(processor); } } diff --git a/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java b/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java index f14ebc0fe..abeeed859 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java @@ -42,7 +42,7 @@ public class TaskManagerTest { public void init() throws Exception { minimalTask = fromJson(TASK_JSON, Task.class); taskStore = new InMemoryTaskStore(); - taskManager = new TaskManager(minimalTask.id(), minimalTask.contextId(), taskStore, null); + taskManager = new TaskManager(minimalTask.id(), minimalTask.contextId(), taskStore, null, false); } @Test @@ -136,7 +136,7 @@ public void testEnsureTaskExisting() { @Test public void testEnsureTaskNonExistentForStatusUpdate() throws A2AServerException { // Tests that an update event instantiates a new task and that - TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null); + TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null, false); TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId("new-task") .contextId("some-context") @@ -157,7 +157,7 @@ public void testEnsureTaskNonExistentForStatusUpdate() throws A2AServerException @Test public void testSaveTaskEventNewTaskNoTaskId() throws A2AServerException { - TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null); + TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null, false); Task task = Task.builder() .id("new-task-id") .contextId("some-context") @@ -175,7 +175,7 @@ public void testSaveTaskEventNewTaskNoTaskId() throws A2AServerException { @Test public void testGetTaskNoTaskId() { - TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null); + TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null, false); Task retrieved = taskManagerWithoutId.getTask(); assertNull(retrieved); } @@ -321,7 +321,7 @@ public void testTaskArtifactUpdateEventAppendNullWithExistingArtifact() throws A @Test public void testAddingTaskWithDifferentIdFails() { // Test that adding a task with a different id from the taskmanager's taskId fails - TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null); + TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null, false); Task differentTask = Task.builder() .id("different-task-id") @@ -337,7 +337,7 @@ public void testAddingTaskWithDifferentIdFails() { @Test public void testAddingTaskWithDifferentIdViaStatusUpdateFails() { // Test that adding a status update with different taskId fails - TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null); + TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null, false); TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId("different-task-id") @@ -354,7 +354,7 @@ public void testAddingTaskWithDifferentIdViaStatusUpdateFails() { @Test public void testAddingTaskWithDifferentIdViaArtifactUpdateFails() { // Test that adding an artifact update with different taskId fails - TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null); + TaskManager taskManagerWithId = new TaskManager("task-abc", "session-xyz", taskStore, null, false); Artifact artifact = Artifact.builder() .artifactId("artifact-id") @@ -382,7 +382,7 @@ public void testTaskWithNoMessageUsesInitialMessage() throws A2AServerException .messageId("initial-msg-id") .build(); - TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage); + TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage, false); // Use a status update event instead of a Task to trigger createTask TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() @@ -413,7 +413,7 @@ public void testTaskWithMessageDoesNotUseInitialMessage() throws A2AServerExcept .messageId("initial-msg-id") .build(); - TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage); + TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage, false); Message taskMessage = Message.builder() .role(Message.Role.AGENT) @@ -533,11 +533,11 @@ public void testMultipleArtifactsWithDifferentArtifactIds() throws A2AServerExce @Test public void testInvalidTaskIdValidation() { // Test that creating TaskManager with null taskId is allowed (Python allows None) - TaskManager taskManagerWithNullId = new TaskManager(null, "context", taskStore, null); + TaskManager taskManagerWithNullId = new TaskManager(null, "context", taskStore, null, false); assertNull(taskManagerWithNullId.getTaskId()); // Test that empty string task ID is handled (Java doesn't have explicit validation like Python) - TaskManager taskManagerWithEmptyId = new TaskManager("", "context", taskStore, null); + TaskManager taskManagerWithEmptyId = new TaskManager("", "context", taskStore, null, false); assertEquals("", taskManagerWithEmptyId.getTaskId()); } @@ -625,7 +625,7 @@ public void testCreateTaskWithInitialMessage() throws A2AServerException { .messageId("initial-msg-id") .build(); - TaskManager taskManagerWithMessage = new TaskManager(null, null, taskStore, initialMessage); + TaskManager taskManagerWithMessage = new TaskManager(null, null, taskStore, initialMessage, false); TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId("new-task-id") @@ -653,7 +653,7 @@ public void testCreateTaskWithInitialMessage() throws A2AServerException { @Test public void testCreateTaskWithoutInitialMessage() throws A2AServerException { // Test task creation without initial message - TaskManager taskManagerWithoutMessage = new TaskManager(null, null, taskStore, null); + TaskManager taskManagerWithoutMessage = new TaskManager(null, null, taskStore, null, false); TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId("new-task-id") @@ -677,7 +677,7 @@ public void testCreateTaskWithoutInitialMessage() throws A2AServerException { @Test public void testSaveTaskInternal() throws A2AServerException { // Test equivalent of _save_task functionality through saveTaskEvent - TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null); + TaskManager taskManagerWithoutId = new TaskManager(null, null, taskStore, null, false); Task newTask = Task.builder() .id("test-task-id") @@ -701,7 +701,7 @@ public void testUpdateWithMessage() throws A2AServerException { .messageId("initial-msg-id") .build(); - TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage); + TaskManager taskManagerWithInitialMessage = new TaskManager(null, null, taskStore, initialMessage, false); Message taskMessage = Message.builder() .role(Message.Role.AGENT) diff --git a/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java b/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java index 40f763569..73da17824 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java @@ -14,7 +14,11 @@ import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; +import io.a2a.server.events.EventQueueItem; import io.a2a.server.events.EventQueueUtil; +import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; import io.a2a.spec.Event; import io.a2a.spec.Message; import io.a2a.spec.Part; @@ -22,6 +26,7 @@ import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -38,14 +43,28 @@ public class TaskUpdaterTest { private static final List> SAMPLE_PARTS = List.of(new TextPart("Test message")); + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + EventQueue eventQueue; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; private TaskUpdater taskUpdater; @BeforeEach public void init() { - eventQueue = EventQueueUtil.getEventQueueBuilder().build(); + // Set up MainEventBus and processor for production-like test environment + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + + eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TEST_TASK_ID) + .mainEventBus(mainEventBus) + .build().tap(); RequestContext context = new RequestContext.Builder() .setTaskId(TEST_TASK_ID) .setContextId(TEST_TASK_CONTEXT_ID) @@ -53,10 +72,19 @@ public void init() { taskUpdater = new TaskUpdater(context, eventQueue); } + @AfterEach + public void cleanup() { + if (mainEventBusProcessor != null) { + EventQueueUtil.stop(mainEventBusProcessor); + } + } + @Test public void testAddArtifactWithCustomIdAndName() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "custom-artifact-id", "Custom Artifact", null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -239,7 +267,9 @@ public void testNewAgentMessageWithMetadata() throws Exception { @Test public void testAddArtifactWithAppendTrue() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, true, null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -258,7 +288,9 @@ public void testAddArtifactWithAppendTrue() throws Exception { @Test public void testAddArtifactWithLastChunkTrue() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, null, true); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -273,7 +305,9 @@ public void testAddArtifactWithLastChunkTrue() throws Exception { @Test public void testAddArtifactWithAppendAndLastChunk() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, true, false); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -287,7 +321,9 @@ public void testAddArtifactWithAppendAndLastChunk() throws Exception { @Test public void testAddArtifactGeneratesIdWhenNull() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, null, "Test Artifact", null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -383,7 +419,9 @@ public void testConcurrentCompletionAttempts() throws Exception { thread2.join(); // Exactly one event should have been queued - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskStatusUpdateEvent.class, event); @@ -396,7 +434,10 @@ public void testConcurrentCompletionAttempts() throws Exception { } private TaskStatusUpdateEvent checkTaskStatusUpdateEventOnQueue(boolean isFinal, TaskState state, Message statusMessage) throws Exception { - Event event = eventQueue.dequeueEventItem(0).getEvent(); + // Wait up to 5 seconds for event (async MainEventBusProcessor needs time to distribute) + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskStatusUpdateEvent.class, event); @@ -408,6 +449,7 @@ private TaskStatusUpdateEvent checkTaskStatusUpdateEventOnQueue(boolean isFinal, assertEquals(state, tsue.status().state()); assertEquals(statusMessage, tsue.status().message()); + // Check no additional events (still use 0 timeout for this check) assertNull(eventQueue.dequeueEventItem(0)); return tsue; diff --git a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java index 408205aa2..439d97497 100644 --- a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java +++ b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java @@ -242,7 +242,7 @@ public void sendStreamingMessage(io.a2a.grpc.SendMessageRequest request, A2AExtensions.validateRequiredExtensions(getAgentCardInternal(), context); MessageSendParams params = FromProto.messageSendParams(request); Flow.Publisher publisher = getRequestHandler().onMessageSendStream(params, context); - convertToStreamResponse(publisher, responseObserver); + convertToStreamResponse(publisher, responseObserver, context); } catch (A2AError e) { handleError(responseObserver, e); } catch (SecurityException e) { @@ -264,7 +264,7 @@ public void subscribeToTask(io.a2a.grpc.SubscribeToTaskRequest request, ServerCallContext context = createCallContext(responseObserver); TaskIdParams params = FromProto.taskIdParams(request); Flow.Publisher publisher = getRequestHandler().onResubscribeToTask(params, context); - convertToStreamResponse(publisher, responseObserver); + convertToStreamResponse(publisher, responseObserver, context); } catch (A2AError e) { handleError(responseObserver, e); } catch (SecurityException e) { @@ -275,7 +275,8 @@ public void subscribeToTask(io.a2a.grpc.SubscribeToTaskRequest request, } private void convertToStreamResponse(Flow.Publisher publisher, - StreamObserver responseObserver) { + StreamObserver responseObserver, + ServerCallContext context) { CompletableFuture.runAsync(() -> { publisher.subscribe(new Flow.Subscriber() { private Flow.Subscription subscription; @@ -285,6 +286,18 @@ public void onSubscribe(Flow.Subscription subscription) { this.subscription = subscription; subscription.request(1); + // Detect gRPC client disconnect and call EventConsumer.cancel() directly + // This stops the polling loop without relying on subscription cancellation propagation + Context grpcContext = Context.current(); + grpcContext.addListener(new Context.CancellationListener() { + @Override + public void cancelled(Context ctx) { + LOGGER.fine(() -> "gRPC call cancelled by client, calling EventConsumer.cancel() to stop polling loop"); + context.invokeEventConsumerCancelCallback(); + subscription.cancel(); + } + }, getExecutor()); + // Notify tests that we are subscribed Runnable runnable = streamingSubscribedRunnable; if (runnable != null) { @@ -305,6 +318,8 @@ public void onNext(StreamingEventKind event) { @Override public void onError(Throwable throwable) { + // Cancel upstream to stop EventConsumer when error occurs + subscription.cancel(); if (throwable instanceof A2AError jsonrpcError) { handleError(responseObserver, jsonrpcError); } else { @@ -329,6 +344,9 @@ public void getExtendedAgentCard(io.a2a.grpc.GetExtendedAgentCardRequest request if (extendedAgentCard != null) { responseObserver.onNext(ToProto.agentCard(extendedAgentCard)); responseObserver.onCompleted(); + } else { + // Extended agent card not configured - return error instead of hanging + handleError(responseObserver, new ExtendedAgentCardNotConfiguredError(null, "Extended agent card not configured", null)); } } catch (Throwable t) { handleInternalError(responseObserver, t); diff --git a/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java b/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java index 690d69a87..f7d711382 100644 --- a/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java +++ b/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java @@ -281,8 +281,7 @@ public void testPushNotificationsNotSupportedError() throws Exception { @Test public void testOnGetPushNotificationNoPushNotifierConfig() throws Exception { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = - new DefaultRequestHandler(executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(false, true, false); GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor); String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(); @@ -293,8 +292,7 @@ public void testOnGetPushNotificationNoPushNotifierConfig() throws Exception { @Test public void testOnSetPushNotificationNoPushNotifierConfig() throws Exception { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(false, true, false); GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor); String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(); @@ -424,9 +422,14 @@ public void testOnMessageStreamNewMessageExistingTaskSuccessMocks() throws Excep @Test public void testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Exception { - GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - List events = List.of( - AbstractA2ARequestHandlerTest.MINIMAL_TASK, + // Use synchronous executor for push notifications to ensure deterministic ordering + // Without this, async push notifications can execute out of order, causing test flakiness + mainEventBusProcessor.setPushNotificationExecutor(Runnable::run); + + try { + GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); + List events = List.of( + AbstractA2ARequestHandlerTest.MINIMAL_TASK, TaskArtifactUpdateEvent.builder() .taskId(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id()) .contextId(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId()) @@ -493,13 +496,16 @@ public void onCompleted() { Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); - curr = httpClient.tasks.get(2); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), curr.id()); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), curr.contextId()); - Assertions.assertEquals(io.a2a.spec.TaskState.COMPLETED, curr.status().state()); - Assertions.assertEquals(1, curr.artifacts().size()); - Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); - Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); + curr = httpClient.tasks.get(2); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), curr.id()); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), curr.contextId()); + Assertions.assertEquals(io.a2a.spec.TaskState.COMPLETED, curr.status().state()); + Assertions.assertEquals(1, curr.artifacts().size()); + Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); + Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); + } finally { + mainEventBusProcessor.setPushNotificationExecutor(null); + } } @Test @@ -668,8 +674,7 @@ public void testListPushNotificationConfigNotSupported() throws Exception { @Test public void testListPushNotificationConfigNoPushConfigStore() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); agentExecutorExecute = (context, eventQueue) -> { @@ -741,8 +746,7 @@ public void testDeletePushNotificationConfigNotSupported() throws Exception { @Test public void testDeletePushNotificationConfigNoPushConfigStore() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(); DeleteTaskPushNotificationConfigRequest request = DeleteTaskPushNotificationConfigRequest.newBuilder() diff --git a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java index b43c28029..f09e9a40f 100644 --- a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java +++ b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java @@ -3,6 +3,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.fail; import java.util.ArrayList; import java.util.Collections; @@ -30,6 +31,8 @@ import io.a2a.jsonrpc.common.wrappers.GetTaskResponse; import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigRequest; import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.ListTasksRequest; +import io.a2a.jsonrpc.common.wrappers.ListTasksResponse; import io.a2a.jsonrpc.common.wrappers.ListTasksResult; import io.a2a.jsonrpc.common.wrappers.SendMessageRequest; import io.a2a.jsonrpc.common.wrappers.SendMessageResponse; @@ -37,12 +40,11 @@ import io.a2a.jsonrpc.common.wrappers.SendStreamingMessageResponse; import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigRequest; import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigResponse; -import io.a2a.jsonrpc.common.wrappers.ListTasksRequest; -import io.a2a.jsonrpc.common.wrappers.ListTasksResponse; import io.a2a.jsonrpc.common.wrappers.SubscribeToTaskRequest; import io.a2a.server.ServerCallContext; import io.a2a.server.auth.UnauthenticatedUser; import io.a2a.server.events.EventConsumer; +import io.a2a.server.events.MainEventBusProcessorCallback; import io.a2a.server.requesthandlers.AbstractA2ARequestHandlerTest; import io.a2a.server.requesthandlers.DefaultRequestHandler; import io.a2a.server.tasks.ResultAggregator; @@ -52,16 +54,15 @@ import io.a2a.spec.AgentExtension; import io.a2a.spec.AgentInterface; import io.a2a.spec.Artifact; -import io.a2a.spec.ExtendedAgentCardNotConfiguredError; -import io.a2a.spec.ExtensionSupportRequiredError; -import io.a2a.spec.VersionNotSupportedError; import io.a2a.spec.DeleteTaskPushNotificationConfigParams; import io.a2a.spec.Event; +import io.a2a.spec.ExtendedAgentCardNotConfiguredError; +import io.a2a.spec.ExtensionSupportRequiredError; import io.a2a.spec.GetTaskPushNotificationConfigParams; import io.a2a.spec.InternalError; import io.a2a.spec.InvalidRequestError; -import io.a2a.spec.ListTasksParams; import io.a2a.spec.ListTaskPushNotificationConfigParams; +import io.a2a.spec.ListTasksParams; import io.a2a.spec.Message; import io.a2a.spec.MessageSendParams; import io.a2a.spec.PushNotificationConfig; @@ -78,6 +79,7 @@ import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; import io.a2a.spec.UnsupportedOperationError; +import io.a2a.spec.VersionNotSupportedError; import mutiny.zero.ZeroPublisher; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; @@ -174,38 +176,9 @@ public void testOnMessageNewMessageSuccess() { SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); SendMessageResponse response = handler.onMessageSend(request, callContext); assertNull(response.getError()); - // The Python implementation returns a Task here, but then again they are using hardcoded mocks and - // bypassing the whole EventQueue. - // If we were to send a Task in agentExecutorExecute EventConsumer.consumeAll() would not exit due to - // the Task not having a 'final' state - // - // See testOnMessageNewMessageSuccessMocks() for a test more similar to the Python implementation Assertions.assertSame(message, response.getResult()); } - @Test - public void testOnMessageNewMessageSuccessMocks() { - JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - - Message message = Message.builder(MESSAGE) - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .build(); - - SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response; - try (MockedConstruction mocked = Mockito.mockConstruction( - EventConsumer.class, - (mock, context) -> { - Mockito.doReturn(ZeroPublisher.fromItems(wrapEvent(MINIMAL_TASK))).when(mock).consumeAll(); - Mockito.doCallRealMethod().when(mock).createAgentRunnableDoneCallback(); - })) { - response = handler.onMessageSend(request, callContext); - } - assertNull(response.getError()); - Assertions.assertSame(MINIMAL_TASK, response.getResult()); - } - @Test public void testOnMessageNewMessageWithExistingTaskSuccess() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); @@ -220,38 +193,9 @@ public void testOnMessageNewMessageWithExistingTaskSuccess() { SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); SendMessageResponse response = handler.onMessageSend(request, callContext); assertNull(response.getError()); - // The Python implementation returns a Task here, but then again they are using hardcoded mocks and - // bypassing the whole EventQueue. - // If we were to send a Task in agentExecutorExecute EventConsumer.consumeAll() would not exit due to - // the Task not having a 'final' state - // - // See testOnMessageNewMessageWithExistingTaskSuccessMocks() for a test more similar to the Python implementation Assertions.assertSame(message, response.getResult()); } - @Test - public void testOnMessageNewMessageWithExistingTaskSuccessMocks() { - JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); - - Message message = Message.builder(MESSAGE) - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .build(); - SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response; - try (MockedConstruction mocked = Mockito.mockConstruction( - EventConsumer.class, - (mock, context) -> { - Mockito.doReturn(ZeroPublisher.fromItems(wrapEvent(MINIMAL_TASK))).when(mock).consumeAll(); - })) { - response = handler.onMessageSend(request, callContext); - } - assertNull(response.getError()); - Assertions.assertSame(MINIMAL_TASK, response.getResult()); - - } - @Test public void testOnMessageError() { // See testMessageOnErrorMocks() for a test more similar to the Python implementation, using mocks for @@ -352,9 +296,11 @@ public void onComplete() { @Test public void testOnMessageStreamNewMessageMultipleEventsSuccess() throws InterruptedException { + // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback + // We'll verify persistence by checking TaskStore after streaming completes JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - // Create multiple events to be sent during streaming + // Create multiple events to be sent during streaming Task taskEvent = Task.builder(MINIMAL_TASK) .status(new TaskStatus(TaskState.WORKING)) .build(); @@ -429,8 +375,8 @@ public void onComplete() { } }); - // Wait for all events to be received - Assertions.assertTrue(latch.await(2, TimeUnit.SECONDS), + // Wait for all events to be received (increased timeout for async processing) + Assertions.assertTrue(latch.await(10, TimeUnit.SECONDS), "Expected to receive 3 events within timeout"); // Assert no error occurred during streaming @@ -456,6 +402,17 @@ public void onComplete() { "Third event should be a TaskStatusUpdateEvent"); assertEquals(MINIMAL_TASK.id(), receivedStatus.taskId()); assertEquals(TaskState.COMPLETED, receivedStatus.status().state()); + + // Verify events were persisted to TaskStore (poll for final state) + for (int i = 0; i < 50; i++) { + Task storedTask = taskStore.get(MINIMAL_TASK.id()); + if (storedTask != null && storedTask.status() != null + && TaskState.COMPLETED.equals(storedTask.status().state())) { + return; // Success - task finalized in TaskStore + } + Thread.sleep(100); + } + fail("Task should have been finalized in TaskStore within timeout"); } @Test @@ -729,106 +686,118 @@ public void testGetPushNotificationConfigSuccess() { @Test public void testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Exception { - JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); - - List events = List.of( - MINIMAL_TASK, - TaskArtifactUpdateEvent.builder() - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .artifact(Artifact.builder() - .artifactId("11") - .parts(new TextPart("text")) - .build()) - .build(), - TaskStatusUpdateEvent.builder() - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .status(new TaskStatus(TaskState.COMPLETED)) - .build()); - - agentExecutorExecute = (context, eventQueue) -> { - // Hardcode the events to send here - for (Event event : events) { - eventQueue.enqueueEvent(event); - } - }; - - TaskPushNotificationConfig config = new TaskPushNotificationConfig( - MINIMAL_TASK.id(), - PushNotificationConfig.builder().id("c295ea44-7543-4f78-b524-7a38915ad6e4").url("http://example.com").build(), "tenant"); - - SetTaskPushNotificationConfigRequest stpnRequest = new SetTaskPushNotificationConfigRequest("1", config); - SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotificationConfig(stpnRequest, callContext); - assertNull(stpnResponse.getError()); - - Message msg = Message.builder(MESSAGE) - .taskId(MINIMAL_TASK.id()) - .build(); - SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(msg, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request, callContext); - - final List results = Collections.synchronizedList(new ArrayList<>()); - final AtomicReference subscriptionRef = new AtomicReference<>(); - final CountDownLatch latch = new CountDownLatch(6); - httpClient.latch = latch; - - Executors.newSingleThreadExecutor().execute(() -> { - response.subscribe(new Flow.Subscriber<>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscriptionRef.set(subscription); - subscription.request(1); - } - - @Override - public void onNext(SendStreamingMessageResponse item) { - System.out.println("-> " + item.getResult()); - results.add(item.getResult()); - System.out.println(results); - subscriptionRef.get().request(1); - latch.countDown(); - } - - @Override - public void onError(Throwable throwable) { - subscriptionRef.get().cancel(); - } - - @Override - public void onComplete() { - subscriptionRef.get().cancel(); + // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback + + // Use synchronous executor for push notifications to ensure deterministic ordering + // Without this, async push notifications can execute out of order, causing test flakiness + mainEventBusProcessor.setPushNotificationExecutor(Runnable::run); + + try { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); + taskStore.save(MINIMAL_TASK); + + List events = List.of( + MINIMAL_TASK, + TaskArtifactUpdateEvent.builder() + .taskId(MINIMAL_TASK.id()) + .contextId(MINIMAL_TASK.contextId()) + .artifact(Artifact.builder() + .artifactId("11") + .parts(new TextPart("text")) + .build()) + .build(), + TaskStatusUpdateEvent.builder() + .taskId(MINIMAL_TASK.id()) + .contextId(MINIMAL_TASK.contextId()) + .status(new TaskStatus(TaskState.COMPLETED)) + .build()); + + + agentExecutorExecute = (context, eventQueue) -> { + // Hardcode the events to send here + for (Event event : events) { + eventQueue.enqueueEvent(event); } + }; + + TaskPushNotificationConfig config = new TaskPushNotificationConfig( + MINIMAL_TASK.id(), + PushNotificationConfig.builder().id("c295ea44-7543-4f78-b524-7a38915ad6e4").url("http://example.com").build(), "tenant"); + + SetTaskPushNotificationConfigRequest stpnRequest = new SetTaskPushNotificationConfigRequest("1", config); + SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotificationConfig(stpnRequest, callContext); + assertNull(stpnResponse.getError()); + + Message msg = Message.builder(MESSAGE) + .taskId(MINIMAL_TASK.id()) + .build(); + SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(msg, null, null)); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); + + final List results = Collections.synchronizedList(new ArrayList<>()); + final AtomicReference subscriptionRef = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(6); + httpClient.latch = latch; + + Executors.newSingleThreadExecutor().execute(() -> { + response.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscriptionRef.set(subscription); + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + System.out.println("-> " + item.getResult()); + results.add(item.getResult()); + System.out.println(results); + subscriptionRef.get().request(1); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscriptionRef.get().cancel(); + } + + @Override + public void onComplete() { + subscriptionRef.get().cancel(); + } + }); }); - }); - Assertions.assertTrue(latch.await(5, TimeUnit.SECONDS)); - subscriptionRef.get().cancel(); - assertEquals(3, results.size()); - assertEquals(3, httpClient.tasks.size()); - - Task curr = httpClient.tasks.get(0); - assertEquals(MINIMAL_TASK.id(), curr.id()); - assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); - assertEquals(MINIMAL_TASK.status().state(), curr.status().state()); - assertEquals(0, curr.artifacts() == null ? 0 : curr.artifacts().size()); - - curr = httpClient.tasks.get(1); - assertEquals(MINIMAL_TASK.id(), curr.id()); - assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); - assertEquals(MINIMAL_TASK.status().state(), curr.status().state()); - assertEquals(1, curr.artifacts().size()); - assertEquals(1, curr.artifacts().get(0).parts().size()); - assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text()); - - curr = httpClient.tasks.get(2); - assertEquals(MINIMAL_TASK.id(), curr.id()); - assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); - assertEquals(TaskState.COMPLETED, curr.status().state()); - assertEquals(1, curr.artifacts().size()); - assertEquals(1, curr.artifacts().get(0).parts().size()); - assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text()); + Assertions.assertTrue(latch.await(5, TimeUnit.SECONDS)); + + subscriptionRef.get().cancel(); + assertEquals(3, results.size()); + assertEquals(3, httpClient.tasks.size()); + + Task curr = httpClient.tasks.get(0); + assertEquals(MINIMAL_TASK.id(), curr.id()); + assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); + assertEquals(MINIMAL_TASK.status().state(), curr.status().state()); + assertEquals(0, curr.artifacts() == null ? 0 : curr.artifacts().size()); + + curr = httpClient.tasks.get(1); + assertEquals(MINIMAL_TASK.id(), curr.id()); + assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); + assertEquals(MINIMAL_TASK.status().state(), curr.status().state()); + assertEquals(1, curr.artifacts().size()); + assertEquals(1, curr.artifacts().get(0).parts().size()); + assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text()); + + curr = httpClient.tasks.get(2); + assertEquals(MINIMAL_TASK.id(), curr.id()); + assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); + assertEquals(TaskState.COMPLETED, curr.status().state()); + assertEquals(1, curr.artifacts().size()); + assertEquals(1, curr.artifacts().get(0).parts().size()); + assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text()); + } finally { + mainEventBusProcessor.setPushNotificationExecutor(null); + } } @Test @@ -1060,7 +1029,7 @@ public void onComplete() { if (results.get(0).getError() != null && results.get(0).getError() instanceof InvalidRequestError ire) { assertEquals("Streaming is not supported by the agent", ire.getMessage()); } else { - Assertions.fail("Expected a response containing an error"); + fail("Expected a response containing an error"); } } @@ -1107,7 +1076,7 @@ public void onComplete() { if (results.get(0).getError() != null && results.get(0).getError() instanceof InvalidRequestError ire) { assertEquals("Streaming is not supported by the agent", ire.getMessage()); } else { - Assertions.fail("Expected a response containing an error"); + fail("Expected a response containing an error"); } } @@ -1135,8 +1104,7 @@ public void testPushNotificationsNotSupportedError() { @Test public void testOnGetPushNotificationNoPushNotifierConfig() { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = createAgentCard(false, true, false); JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor); @@ -1154,8 +1122,7 @@ public void testOnGetPushNotificationNoPushNotifierConfig() { @Test public void testOnSetPushNotificationNoPushNotifierConfig() { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = createAgentCard(false, true, false); JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor); @@ -1246,8 +1213,7 @@ public void testDefaultRequestHandlerWithCustomComponents() { @Test public void testOnMessageSendErrorHandling() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = createAgentCard(false, true, false); JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor); @@ -1293,16 +1259,17 @@ public void testOnMessageSendTaskIdMismatch() { } @Test - public void testOnMessageStreamTaskIdMismatch() { + public void testOnMessageStreamTaskIdMismatch() throws InterruptedException { + // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK); - agentExecutorExecute = ((context, eventQueue) -> { - eventQueue.enqueueEvent(MINIMAL_TASK); - }); + agentExecutorExecute = ((context, eventQueue) -> { + eventQueue.enqueueEvent(MINIMAL_TASK); + }); - SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request, callContext); + SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); CompletableFuture future = new CompletableFuture<>(); List results = new ArrayList<>(); @@ -1404,8 +1371,7 @@ public void testListPushNotificationConfigNotSupported() { @Test public void testListPushNotificationConfigNoPushConfigStore() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); taskStore.save(MINIMAL_TASK); agentExecutorExecute = (context, eventQueue) -> { @@ -1496,8 +1462,8 @@ public void testDeletePushNotificationConfigNotSupported() { @Test public void testDeletePushNotificationConfigNoPushConfigStore() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = + DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); taskStore.save(MINIMAL_TASK); agentExecutorExecute = (context, eventQueue) -> { diff --git a/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java index 3ffb56c5f..3273d6119 100644 --- a/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java +++ b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java @@ -399,32 +399,46 @@ private Flow.Publisher convertToSendStreamingMessageResponse( Flow.Publisher publisher) { // We can't use the normal convertingProcessor since that propagates any errors as an error handled // via Subscriber.onError() rather than as part of the SendStreamingResponse payload + log.log(Level.FINE, "REST: convertToSendStreamingMessageResponse called, creating ZeroPublisher"); return ZeroPublisher.create(createTubeConfig(), tube -> { + log.log(Level.FINE, "REST: ZeroPublisher tube created, starting CompletableFuture.runAsync"); CompletableFuture.runAsync(() -> { + log.log(Level.FINE, "REST: Inside CompletableFuture, subscribing to EventKind publisher"); publisher.subscribe(new Flow.Subscriber() { Flow.@Nullable Subscription subscription; @Override public void onSubscribe(Flow.Subscription subscription) { + log.log(Level.FINE, "REST: onSubscribe called, storing subscription and requesting first event"); this.subscription = subscription; subscription.request(1); } @Override public void onNext(StreamingEventKind item) { + log.log(Level.FINE, "REST: onNext called with event: {0}", item.getClass().getSimpleName()); try { String payload = JsonFormat.printer().omittingInsignificantWhitespace().print(ProtoUtils.ToProto.taskOrMessageStream(item)); + log.log(Level.FINE, "REST: Converted to JSON, sending via tube: {0}", payload.substring(0, Math.min(100, payload.length()))); tube.send(payload); + log.log(Level.FINE, "REST: tube.send() completed, requesting next event from EventConsumer"); + // Request next event from EventConsumer (Chain 1: EventConsumer → RestHandler) + // This is safe because ZeroPublisher buffers items + // Chain 2 (ZeroPublisher → MultiSseSupport) controls actual delivery via request(1) in onWriteDone() if (subscription != null) { subscription.request(1); + } else { + log.log(Level.WARNING, "REST: subscription is null in onNext!"); } } catch (InvalidProtocolBufferException ex) { + log.log(Level.SEVERE, "REST: JSON conversion failed", ex); onError(ex); } } @Override public void onError(Throwable throwable) { + log.log(Level.SEVERE, "REST: onError called", throwable); if (throwable instanceof A2AError jsonrpcError) { tube.send(new HTTPRestErrorResponse(jsonrpcError).toJson()); } else { @@ -435,6 +449,7 @@ public void onError(Throwable throwable) { @Override public void onComplete() { + log.log(Level.FINE, "REST: onComplete called, calling tube.complete()"); tube.complete(); } });