From 2f61aa11e9e6e4b1a6a233fd610c6f13d5f6cc4a Mon Sep 17 00:00:00 2001 From: Simon Su Date: Wed, 17 Dec 2025 17:01:11 +1100 Subject: [PATCH] Add confirmed function call to next LLM call --- ...equestConfirmationLlmRequestProcessor.java | 15 +++++- ...stConfirmationLlmRequestProcessorTest.java | 47 ++++++++++++++++++- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java index b06abeacb..32e3e0457 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java @@ -42,6 +42,7 @@ import io.reactivex.rxjava3.core.Single; import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -109,8 +110,8 @@ public Single processRequest( continue; } - Map toolsToResumeWithConfirmation = new HashMap<>(); - Map toolsToResumeWithArgs = new HashMap<>(); + final Map toolsToResumeWithConfirmation = new HashMap<>(); + final Map toolsToResumeWithArgs = new HashMap<>(); event.functionCalls().stream() .filter( @@ -163,6 +164,16 @@ public Single processRequest( // Create an updated LlmRequest including the new event's content ImmutableList.Builder updatedContentsBuilder = ImmutableList.builder().addAll(llmRequest.contents()); + + final List functionCalls = + toolsToResumeWithArgs.values().stream() + .map(functionCall -> Part.builder().functionCall(functionCall).build()) + .collect(toImmutableList()); + Content functionCallContent = + Content.builder().role("model").parts(functionCalls).build(); + // add function call + updatedContentsBuilder.add(functionCallContent); + // add function response assembledEvent.content().ifPresent(updatedContentsBuilder::add); LlmRequest updatedLlmRequest = diff --git a/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java index 971930907..4f2ca3388 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java @@ -102,11 +102,48 @@ public class RequestConfirmationLlmRequestProcessorTest { .build())) .build(); + private static final Event FUNCTION_CALL_EVENT = + Event.builder() + .author("model") + .content( + Content.builder() + .role("model") + .parts( + Part.builder() + .functionCall( + FunctionCall.builder() + .id(ORIGINAL_FUNCTION_CALL_ID) + .name(ECHO_TOOL_NAME) + .args(ImmutableMap.of("say", "hello")) + .build()) + .build()) + .build()) + .build(); + + private static final Event FUNCTION_RESPONSE_EVENT = + Event.builder() + .author("user") + .content( + Content.builder() + .role("user") + .parts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(ORIGINAL_FUNCTION_CALL_ID) + .name(ECHO_TOOL_NAME) + .response( + ImmutableMap.of("result", ImmutableMap.of("say", "hello"))) + .build()) + .build()) + .build()) + .build(); + private static final RequestConfirmationLlmRequestProcessor processor = new RequestConfirmationLlmRequestProcessor(); @Test - public void runAsync_withConfirmation_callsOriginalFunction() { + public void runAsync_withConfirmation_callsOriginalFunctionAndAppendsToUpdatedRequest() { LlmAgent agent = createAgentWithEchoTool(); Session session = Session.builder("session_id") @@ -126,6 +163,14 @@ public void runAsync_withConfirmation_callsOriginalFunction() { assertThat(fr.id()).hasValue(ORIGINAL_FUNCTION_CALL_ID); assertThat(fr.name()).hasValue(ECHO_TOOL_NAME); assertThat(fr.response()).hasValue(ImmutableMap.of("result", ImmutableMap.of("say", "hello"))); + assertThat(result.updatedRequest()) + .isEqualTo( + LlmRequest.builder() + .contents( + ImmutableList.of( + FUNCTION_CALL_EVENT.content().get(), + FUNCTION_RESPONSE_EVENT.content().get())) + .build()); } @Test