From 4b15e1073a290227001a30b0686077a07acdca18 Mon Sep 17 00:00:00 2001 From: Emmanuel Hugonnet Date: Fri, 23 Jan 2026 09:53:52 +0100 Subject: [PATCH] test: Add input-required workflow coverage Add test for input-required task state transitions: create task, receive input-required status, send input, and complete task. Signed-off-by: Emmanuel Hugonnet --- .../apps/common/AbstractA2AServerTest.java | 62 +++++++++++++++++++ .../apps/common/AgentExecutorProducer.java | 42 +++++++++++++ 2 files changed, 104 insertions(+) diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java index 5fab9d8d7..3c964759d 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java @@ -1418,6 +1418,68 @@ public void testNonBlockingWithMultipleMessages() throws Exception { } } + @Test + @Timeout(value = 1, unit = TimeUnit.MINUTES) + public void testInputRequiredWorkflow() throws Exception { + String inputRequiredTaskId = "input-required-test-" + java.util.UUID.randomUUID(); + try { + // 1. Send initial message - AgentExecutor will transition task to INPUT_REQUIRED + Message initialMessage = Message.builder(MESSAGE) + .taskId(inputRequiredTaskId) + .contextId("test-context") + .parts(new TextPart("Initial request")) + .build(); + + CountDownLatch initialLatch = new CountDownLatch(1); + AtomicReference initialState = new AtomicReference<>(); + AtomicBoolean initialUnexpectedEvent = new AtomicBoolean(false); + + BiConsumer initialConsumer = (event, agentCard) -> { + if (event instanceof TaskEvent te) { + initialState.set(te.getTask().status().state()); + initialLatch.countDown(); + } else { + initialUnexpectedEvent.set(true); + } + }; + + // Send initial message - task will go to INPUT_REQUIRED state + getNonStreamingClient().sendMessage(initialMessage, List.of(initialConsumer), null); + assertTrue(initialLatch.await(10, TimeUnit.SECONDS)); + assertFalse(initialUnexpectedEvent.get()); + assertEquals(TaskState.INPUT_REQUIRED, initialState.get()); + + // 2. Send input message - AgentExecutor will complete the task + Message inputMessage = Message.builder(MESSAGE) + .taskId(inputRequiredTaskId) + .contextId("test-context") + .parts(new TextPart("User input")) + .build(); + + CountDownLatch completionLatch = new CountDownLatch(1); + AtomicReference completedState = new AtomicReference<>(); + AtomicBoolean completionUnexpectedEvent = new AtomicBoolean(false); + + BiConsumer completionConsumer = (event, agentCard) -> { + if (event instanceof TaskEvent te) { + completedState.set(te.getTask().status().state()); + completionLatch.countDown(); + } else { + completionUnexpectedEvent.set(true); + } + }; + + // Send input - task will be completed + getNonStreamingClient().sendMessage(inputMessage, List.of(completionConsumer), null); + assertTrue(completionLatch.await(10, TimeUnit.SECONDS)); + assertFalse(completionUnexpectedEvent.get()); + assertEquals(TaskState.COMPLETED, completedState.get()); + + } finally { + deleteTaskInTaskStore(inputRequiredTaskId); + } + } + @Test public void testMalformedJSONRPCRequest() { // skip this test for non-JSONRPC transports diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java index 8b03a68f4..364d2275f 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java @@ -10,6 +10,9 @@ import io.a2a.server.events.EventQueue; import io.a2a.server.tasks.TaskUpdater; import io.a2a.spec.A2AError; +import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.Message; +import io.a2a.spec.Part; import io.a2a.spec.TextPart; import io.a2a.spec.UnsupportedOperationError; import io.quarkus.arc.profile.IfBuildProfile; @@ -43,6 +46,28 @@ public void execute(RequestContext context, EventQueue eventQueue) throws A2AErr } } + // Special handling for input-required test + if (taskId != null && taskId.startsWith("input-required-test")) { + // First call: context.getTask() == null (new task) + if (context.getTask() == null) { + updater.startWork(); + updater.requiresInput(updater.newAgentMessage( + List.of(new TextPart("Please provide additional information")), + context.getMessage().metadata())); + // Return immediately - queue stays open because task is in INPUT_REQUIRED state + return; + } else { + String input = extractTextFromMessage(context.getMessage()); + if(! "User input".equals(input)) { + throw new InvalidParamsError("We didn't get the expected input"); + } + // Second call: context.getTask() != null (input provided) + updater.startWork(); + updater.complete(); + return; + } + } + if (context.getTaskId().equals("task-not-supported-123")) { eventQueue.enqueueEvent(new UnsupportedOperationError()); } @@ -60,4 +85,21 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws A2AErro } }; } + + /** + * Extract the content of TextPart in a message to create a single String. + * @param message the message containing the TextPart. + * @return a String aggreagating all the TextPart contents of the message. + */ + private String extractTextFromMessage(final Message message) { + final StringBuilder textBuilder = new StringBuilder(); + if (message.parts() != null) { + for (final Part part : message.parts()) { + if (part instanceof TextPart textPart) { + textBuilder.append(textPart.text()); + } + } + } + return textBuilder.toString(); + } }