Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,68 @@ public void testNonBlockingWithMultipleMessages() throws Exception {
}
}

@Test
@Timeout(value = 1, unit = TimeUnit.MINUTES)
public void testInputRequiredWorkflow() throws Exception {
String inputRequiredTaskId = "input-required-test-" + java.util.UUID.randomUUID();
try {
// 1. Send initial message - AgentExecutor will transition task to INPUT_REQUIRED
Message initialMessage = Message.builder(MESSAGE)
.taskId(inputRequiredTaskId)
.contextId("test-context")
.parts(new TextPart("Initial request"))
.build();

CountDownLatch initialLatch = new CountDownLatch(1);
AtomicReference<TaskState> initialState = new AtomicReference<>();
AtomicBoolean initialUnexpectedEvent = new AtomicBoolean(false);

BiConsumer<ClientEvent, AgentCard> initialConsumer = (event, agentCard) -> {
if (event instanceof TaskEvent te) {
initialState.set(te.getTask().status().state());
initialLatch.countDown();
} else {
initialUnexpectedEvent.set(true);
}
};

// Send initial message - task will go to INPUT_REQUIRED state
getNonStreamingClient().sendMessage(initialMessage, List.of(initialConsumer), null);
assertTrue(initialLatch.await(10, TimeUnit.SECONDS));
assertFalse(initialUnexpectedEvent.get());
assertEquals(TaskState.INPUT_REQUIRED, initialState.get());

// 2. Send input message - AgentExecutor will complete the task
Message inputMessage = Message.builder(MESSAGE)
.taskId(inputRequiredTaskId)
.contextId("test-context")
.parts(new TextPart("User input"))
.build();

CountDownLatch completionLatch = new CountDownLatch(1);
AtomicReference<TaskState> completedState = new AtomicReference<>();
AtomicBoolean completionUnexpectedEvent = new AtomicBoolean(false);

BiConsumer<ClientEvent, AgentCard> completionConsumer = (event, agentCard) -> {
if (event instanceof TaskEvent te) {
completedState.set(te.getTask().status().state());
completionLatch.countDown();
} else {
completionUnexpectedEvent.set(true);
}
};

// Send input - task will be completed
getNonStreamingClient().sendMessage(inputMessage, List.of(completionConsumer), null);
assertTrue(completionLatch.await(10, TimeUnit.SECONDS));
assertFalse(completionUnexpectedEvent.get());
assertEquals(TaskState.COMPLETED, completedState.get());

} finally {
deleteTaskInTaskStore(inputRequiredTaskId);
}
}

@Test
public void testMalformedJSONRPCRequest() {
// skip this test for non-JSONRPC transports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import io.a2a.server.events.EventQueue;
import io.a2a.server.tasks.TaskUpdater;
import io.a2a.spec.A2AError;
import io.a2a.spec.InvalidParamsError;
import io.a2a.spec.Message;
import io.a2a.spec.Part;
import io.a2a.spec.TextPart;
import io.a2a.spec.UnsupportedOperationError;
import io.quarkus.arc.profile.IfBuildProfile;
Expand Down Expand Up @@ -43,6 +46,28 @@ public void execute(RequestContext context, EventQueue eventQueue) throws A2AErr
}
}

// Special handling for input-required test
if (taskId != null && taskId.startsWith("input-required-test")) {
// First call: context.getTask() == null (new task)
if (context.getTask() == null) {
updater.startWork();
updater.requiresInput(updater.newAgentMessage(
List.of(new TextPart("Please provide additional information")),
context.getMessage().metadata()));
// Return immediately - queue stays open because task is in INPUT_REQUIRED state
return;
} else {
String input = extractTextFromMessage(context.getMessage());
if(! "User input".equals(input)) {
throw new InvalidParamsError("We didn't get the expected input");
}
// Second call: context.getTask() != null (input provided)
updater.startWork();
updater.complete();
return;
}
}

if (context.getTaskId().equals("task-not-supported-123")) {
eventQueue.enqueueEvent(new UnsupportedOperationError());
}
Expand All @@ -60,4 +85,21 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws A2AErro
}
};
}

/**
* Extract the content of TextPart in a message to create a single String.
* @param message the message containing the TextPart.
* @return a String aggreagating all the TextPart contents of the message.
*/
private String extractTextFromMessage(final Message message) {
final StringBuilder textBuilder = new StringBuilder();
if (message.parts() != null) {
for (final Part part : message.parts()) {
if (part instanceof TextPart textPart) {
textBuilder.append(textPart.text());
}
}
}
return textBuilder.toString();
}
}