From 9611f8967e528c6242e17ad3ad5419e0b25fb3fb Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 19 Dec 2025 06:29:32 -0800 Subject: [PATCH] feat(HITL): Let ADK resume after HITL approval is present feat(HITL): Declining a proposal now correctly intercepts the run fix: Events for HITL are now emitted correctly fix: HITL endless loop when asking for approvals PiperOrigin-RevId: 846709135 --- .../google/adk/flows/llmflows/Functions.java | 238 +++++++++--------- ...equestConfirmationLlmRequestProcessor.java | 222 +++++++++++----- .../google/adk/flows/llmflows/SingleFlow.java | 1 + .../com/google/adk/tools/ToolContext.java | 14 ++ .../adk/flows/llmflows/FunctionsTest.java | 103 +++++++- ...stConfirmationLlmRequestProcessorTest.java | 93 ++++++- 6 files changed, 490 insertions(+), 181 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index a952d602d..2f1126abc 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -17,6 +17,7 @@ package com.google.adk.flows.llmflows; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import com.google.adk.Telemetry; @@ -62,9 +63,13 @@ /** Utility class for handling function calls. */ public final class Functions { + /** The function call name for the request confirmation function. */ + public static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation"; + + /** Session state key for storing the security policy outcomes for tool calls. */ + public static final String TOOL_CALL_SECURITY_STATES = "adk_tool_call_security_states"; private static final String AF_FUNCTION_CALL_ID_PREFIX = "adk-"; - static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation"; private static final Logger logger = LoggerFactory.getLogger(Functions.class); /** Generates a unique ID for a function call. */ @@ -145,59 +150,7 @@ public static Maybe handleFunctionCalls( } Function> functionCallMapper = - functionCall -> { - BaseTool tool = tools.get(functionCall.name().get()); - ToolContext toolContext = - ToolContext.builder(invocationContext) - .functionCallId(functionCall.id().orElse("")) - .toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null))) - .build(); - - Map functionArgs = functionCall.args().orElse(ImmutableMap.of()); - - Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) - .switchIfEmpty(Maybe.defer(() -> callTool(tool, functionArgs, toolContext))); - - return maybeFunctionResult - .map(Optional::of) - .defaultIfEmpty(Optional.empty()) - .onErrorResumeNext( - t -> - invocationContext - .pluginManager() - .runOnToolErrorCallback(tool, functionArgs, toolContext, t) - .map(Optional::of) - .switchIfEmpty(Single.error(t))) - .flatMapMaybe( - optionalInitialResult -> { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - Maybe> afterToolResultMaybe = - maybeInvokeAfterToolCall( - invocationContext, - tool, - functionArgs, - toolContext, - initialFunctionResult); - - return afterToolResultMaybe - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - Event functionResponseEvent = - buildResponseEvent( - tool, finalFunctionResult, toolContext, invocationContext); - return Maybe.just(functionResponseEvent); - }); - }); - }; + getFunctionCallMapper(invocationContext, tools, toolConfirmations, false); Flowable functionResponseEventsFlowable; if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { @@ -241,6 +194,18 @@ public static Maybe handleFunctionCalls( */ public static Maybe handleFunctionCallsLive( InvocationContext invocationContext, Event functionCallEvent, Map tools) { + return handleFunctionCallsLive(invocationContext, functionCallEvent, tools, ImmutableMap.of()); + } + + /** + * Handles function calls in a live/streaming context with tool confirmations, supporting + * background execution and stream termination. + */ + public static Maybe handleFunctionCallsLive( + InvocationContext invocationContext, + Event functionCallEvent, + Map tools, + Map toolConfirmations) { ImmutableList functionCalls = functionCallEvent.functionCalls(); for (FunctionCall functionCall : functionCalls) { @@ -250,72 +215,12 @@ public static Maybe handleFunctionCallsLive( } Function> functionCallMapper = - functionCall -> { - BaseTool tool = tools.get(functionCall.name().get()); - ToolContext toolContext = - ToolContext.builder(invocationContext) - .functionCallId(functionCall.id().orElse("")) - .build(); - Map functionArgs = functionCall.args().orElse(new HashMap<>()); - - Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) - .switchIfEmpty( - Maybe.defer( - () -> - processFunctionLive( - invocationContext, - tool, - toolContext, - functionCall, - functionArgs))); - - return maybeFunctionResult - .map(Optional::of) - .defaultIfEmpty(Optional.empty()) - .onErrorResumeNext( - t -> - invocationContext - .pluginManager() - .runOnToolErrorCallback(tool, functionArgs, toolContext, t) - .map(Optional::ofNullable) - .switchIfEmpty(Single.error(t))) - .flatMapMaybe( - optionalInitialResult -> { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - Maybe> afterToolResultMaybe = - maybeInvokeAfterToolCall( - invocationContext, - tool, - functionArgs, - toolContext, - initialFunctionResult); - - return afterToolResultMaybe - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - Event functionResponseEvent = - buildResponseEvent( - tool, finalFunctionResult, toolContext, invocationContext); - return Maybe.just(functionResponseEvent); - }); - }); - }; + getFunctionCallMapper(invocationContext, tools, toolConfirmations, true); Flowable responseEventsFlowable; - if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { responseEventsFlowable = Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper); - } else { responseEventsFlowable = Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper); @@ -332,6 +237,47 @@ public static Maybe handleFunctionCallsLive( }); } + private static Function> getFunctionCallMapper( + InvocationContext invocationContext, + Map tools, + Map toolConfirmations, + boolean isLive) { + return functionCall -> { + BaseTool tool = tools.get(functionCall.name().get()); + ToolContext toolContext = + ToolContext.builder(invocationContext) + .functionCallId(functionCall.id().orElse("")) + .toolConfirmation(functionCall.id().map(toolConfirmations::get).orElse(null)) + .build(); + + ToolConfirmation toolConfirmation = toolContext.toolConfirmation().orElse(null); + if (toolConfirmation != null && !toolConfirmation.confirmed()) { + return Maybe.just( + buildResponseEvent( + tool, + ImmutableMap.of("error", "User declined tool execution for " + tool.name()), + toolContext, + invocationContext)); + } + + Map functionArgs = + functionCall.args().orElse(isLive ? new HashMap<>() : ImmutableMap.of()); + + Maybe> maybeFunctionResult = + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + .switchIfEmpty( + Maybe.defer( + () -> + isLive + ? processFunctionLive( + invocationContext, tool, toolContext, functionCall, functionArgs) + : callTool(tool, functionArgs, toolContext))); + + return postProcessFunctionResult( + maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive); + }; + } + /** * Processes a single function call in a live context. Manages starting, stopping, and running * tools. @@ -427,6 +373,48 @@ public static Set getLongRunningFunctionCalls( return longRunningFunctionCalls; } + private static Maybe postProcessFunctionResult( + Maybe> maybeFunctionResult, + InvocationContext invocationContext, + BaseTool tool, + Map functionArgs, + ToolContext toolContext, + boolean isLive) { + return maybeFunctionResult + .map(Optional::of) + .defaultIfEmpty(Optional.empty()) + .onErrorResumeNext( + t -> + invocationContext + .pluginManager() + .runOnToolErrorCallback(tool, functionArgs, toolContext, t) + .map(isLive ? Optional::ofNullable : Optional::of) + .switchIfEmpty(Single.error(t))) + .flatMapMaybe( + optionalInitialResult -> { + Map initialFunctionResult = optionalInitialResult.orElse(null); + + Maybe> afterToolResultMaybe = + maybeInvokeAfterToolCall( + invocationContext, tool, functionArgs, toolContext, initialFunctionResult); + + return afterToolResultMaybe + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + Event functionResponseEvent = + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext); + return Maybe.just(functionResponseEvent); + }); + }); + } + private static Optional mergeParallelFunctionResponseEvents( List functionResponseEvents) { if (functionResponseEvents.isEmpty()) { @@ -664,5 +652,25 @@ public static Optional generateRequestConfirmationEvent( .build()); } + /** + * Gets the ask user confirmation function calls from the event. + * + * @param event The event to extract function calls from. + * @return A list of function calls for asking user confirmation. + */ + public static ImmutableList getAskUserConfirmationFunctionCalls(Event event) { + return event.content().flatMap(Content::parts).orElse(ImmutableList.of()).stream() + .flatMap(part -> part.functionCall().stream()) + .filter(Functions::isRequestConfirmationFunctionCall) + .collect(toImmutableList()); + } + + private static boolean isRequestConfirmationFunctionCall(FunctionCall functionCall) { + return functionCall + .name() + .map(name -> name.equals(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)) + .orElse(false); + } + private Functions() {} } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java index cd08a2c78..b06abeacb 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java @@ -17,12 +17,14 @@ package com.google.adk.flows.llmflows; import static com.google.adk.flows.llmflows.Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME; +import static com.google.adk.flows.llmflows.Functions.TOOL_CALL_SECURITY_STATES; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.google.adk.JsonBaseModel; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.events.Event; @@ -31,6 +33,7 @@ import com.google.adk.tools.ToolConfirmation; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; @@ -38,7 +41,7 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.Collection; -import java.util.List; +import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -49,59 +52,154 @@ public class RequestConfirmationLlmRequestProcessor implements RequestProcessor { private static final Logger logger = LoggerFactory.getLogger(RequestConfirmationLlmRequestProcessor.class); - private final ObjectMapper objectMapper; - - public RequestConfirmationLlmRequestProcessor() { - objectMapper = new ObjectMapper().registerModule(new Jdk8Module()); - } + private static final ObjectMapper OBJECT_MAPPER = JsonBaseModel.getMapper(); + private static final String ORIGINAL_FUNCTION_CALL = "originalFunctionCall"; @Override public Single processRequest( InvocationContext invocationContext, LlmRequest llmRequest) { - List events = invocationContext.session().events(); + ImmutableList events = ImmutableList.copyOf(invocationContext.session().events()); if (events.isEmpty()) { logger.info( "No events are present in the session. Skipping request confirmation processing."); return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of())); } - ImmutableMap requestConfirmationFunctionResponses = - filterRequestConfirmationFunctionResponses(events); - if (requestConfirmationFunctionResponses.isEmpty()) { + int confirmationEventIndex = -1; + ImmutableMap responses = ImmutableMap.of(); + // Search backwards for the most recent user event that contains request confirmation + // function responses. + for (int i = events.size() - 1; i >= 0; i--) { + Event event = events.get(i); + if (!Objects.equals(event.author(), "user") || event.functionResponses().isEmpty()) { + continue; + } + + ImmutableMap confirmationsInEvent = + event.functionResponses().stream() + .filter(functionResponse -> functionResponse.id().isPresent()) + .filter( + functionResponse -> + Objects.equals( + functionResponse.name().orElse(null), + REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)) + .map(this::maybeCreateToolConfirmationEntry) + .flatMap(Optional::stream) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + if (!confirmationsInEvent.isEmpty()) { + responses = confirmationsInEvent; + confirmationEventIndex = i; + break; + } + } + if (responses.isEmpty()) { logger.info("No request confirmation function responses found."); return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of())); } - for (ImmutableList functionCalls : - events.stream() - .map(Event::functionCalls) - .filter(fc -> !fc.isEmpty()) - .collect(toImmutableList())) { + // Make them final to enable access from lambda expressions. + final int finalConfirmationEventIndex = confirmationEventIndex; + final ImmutableMap requestConfirmationFunctionResponses = responses; - ImmutableMap toolsToResumeWithArgs = - filterToolsToResumeWithArgs(functionCalls, requestConfirmationFunctionResponses); - ImmutableMap toolsToResumeWithConfirmation = - toolsToResumeWithArgs.keySet().stream() - .filter( - id -> - events.stream() - .flatMap(e -> e.functionResponses().stream()) - .anyMatch(fr -> Objects.equals(fr.id().orElse(null), id))) - .collect(toImmutableMap(k -> k, requestConfirmationFunctionResponses::get)); + // Search backwards from the event before confirmation for the corresponding + // request_confirmation function calls emitted by the model. + for (int i = finalConfirmationEventIndex - 1; i >= 0; i--) { + Event event = events.get(i); + if (event.functionCalls().isEmpty()) { + continue; + } + + Map toolsToResumeWithConfirmation = new HashMap<>(); + Map toolsToResumeWithArgs = new HashMap<>(); + + event.functionCalls().stream() + .filter( + fc -> + fc.id().isPresent() + && requestConfirmationFunctionResponses.containsKey(fc.id().get())) + .forEach( + fc -> + getOriginalFunctionCall(fc) + .ifPresent( + ofc -> { + toolsToResumeWithConfirmation.put( + ofc.id().get(), + requestConfirmationFunctionResponses.get(fc.id().get())); + toolsToResumeWithArgs.put(ofc.id().get(), ofc); + })); + + if (toolsToResumeWithConfirmation.isEmpty()) { + continue; + } + + // If a tool has been confirmed, it might have been executed by a subsequent + // processor, or in a subsequent turn. We identify tools that have already been + // executed by checking for function responses with matching IDs in events that + // occurred *after* the user confirmation event. + ImmutableSet alreadyConfirmedIds = + events.subList(finalConfirmationEventIndex + 1, events.size()).stream() + .flatMap(e -> e.functionResponses().stream()) + .map(FunctionResponse::id) + .flatMap(Optional::stream) + .collect(toImmutableSet()); + toolsToResumeWithConfirmation.keySet().removeAll(alreadyConfirmedIds); + toolsToResumeWithArgs.keySet().removeAll(alreadyConfirmedIds); + + // If all confirmed tools in this event have already been processed, continue + // searching in older events. if (toolsToResumeWithConfirmation.isEmpty()) { - logger.info("No tools to resume with confirmation."); continue; } + // If we found tools that were confirmed but not yet executed, execute them now. return assembleEvent( - invocationContext, toolsToResumeWithArgs.values(), toolsToResumeWithConfirmation) - .map(event -> RequestProcessingResult.create(llmRequest, ImmutableList.of(event))) - .toSingle(); + invocationContext, + toolsToResumeWithArgs.values(), + ImmutableMap.copyOf(toolsToResumeWithConfirmation)) + .map( + assembledEvent -> { + clearToolCallSecurityStates(invocationContext, toolsToResumeWithArgs.keySet()); + + // Create an updated LlmRequest including the new event's content + ImmutableList.Builder updatedContentsBuilder = + ImmutableList.builder().addAll(llmRequest.contents()); + assembledEvent.content().ifPresent(updatedContentsBuilder::add); + + LlmRequest updatedLlmRequest = + llmRequest.toBuilder().contents(updatedContentsBuilder.build()).build(); + + return RequestProcessingResult.create( + updatedLlmRequest, ImmutableList.of(assembledEvent)); + }) + .toSingle() + .onErrorReturn( + e -> { + logger.error("Error processing request confirmation", e); + return RequestProcessingResult.create(llmRequest, ImmutableList.of()); + }); } return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of())); } + private Optional getOriginalFunctionCall(FunctionCall functionCall) { + if (!functionCall.args().orElse(ImmutableMap.of()).containsKey(ORIGINAL_FUNCTION_CALL)) { + return Optional.empty(); + } + try { + FunctionCall originalFunctionCall = + OBJECT_MAPPER.convertValue( + functionCall.args().get().get(ORIGINAL_FUNCTION_CALL), FunctionCall.class); + if (originalFunctionCall.id().isEmpty()) { + return Optional.empty(); + } + return Optional.of(originalFunctionCall); + } catch (IllegalArgumentException e) { + logger.warn("Failed to convert originalFunctionCall argument.", e); + return Optional.empty(); + } + } + private Maybe assembleEvent( InvocationContext invocationContext, Collection functionCalls, @@ -135,21 +233,6 @@ private Maybe assembleEvent( invocationContext, functionCallEvent, toolsMap, toolConfirmations)); } - private ImmutableMap filterRequestConfirmationFunctionResponses( - List events) { - return events.stream() - .filter(event -> Objects.equals(event.author(), "user")) - .flatMap(event -> event.functionResponses().stream()) - .filter(functionResponse -> functionResponse.id().isPresent()) - .filter( - functionResponse -> - Objects.equals( - functionResponse.name().orElse(null), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)) - .map(this::maybeCreateToolConfirmationEntry) - .flatMap(Optional::stream) - .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); - } - private Optional> maybeCreateToolConfirmationEntry( FunctionResponse functionResponse) { Map responseMap = functionResponse.response().orElse(ImmutableMap.of()); @@ -157,14 +240,14 @@ private Optional> maybeCreateToolConfirmatio return Optional.of( Map.entry( functionResponse.id().get(), - objectMapper.convertValue(responseMap, ToolConfirmation.class))); + OBJECT_MAPPER.convertValue(responseMap, ToolConfirmation.class))); } try { return Optional.of( Map.entry( functionResponse.id().get(), - objectMapper.readValue( + OBJECT_MAPPER.readValue( (String) responseMap.get("response"), ToolConfirmation.class))); } catch (JsonProcessingException e) { logger.error("Failed to parse tool confirmation response", e); @@ -173,20 +256,35 @@ private Optional> maybeCreateToolConfirmatio return Optional.empty(); } - private ImmutableMap filterToolsToResumeWithArgs( - ImmutableList functionCalls, - Map requestConfirmationFunctionResponses) { - return functionCalls.stream() - .filter(fc -> fc.id().isPresent()) - .filter(fc -> requestConfirmationFunctionResponses.containsKey(fc.id().get())) - .filter( - fc -> Objects.equals(fc.name().orElse(null), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)) - .filter(fc -> fc.args().orElse(ImmutableMap.of()).containsKey("originalFunctionCall")) - .collect( - toImmutableMap( - fc -> fc.id().get(), - fc -> - objectMapper.convertValue( - fc.args().get().get("originalFunctionCall"), FunctionCall.class))); + private void clearToolCallSecurityStates( + InvocationContext invocationContext, Collection processedFunctionCallIds) { + var state = invocationContext.session().state(); + Object statesObj = state.get(TOOL_CALL_SECURITY_STATES); + + if (statesObj == null) { + return; + } + if (!(statesObj instanceof Map)) { + logger.warn( + "Session key {} does not contain a Map, cannot clear tool states. Found: {}", + TOOL_CALL_SECURITY_STATES, + statesObj.getClass().getName()); + return; + } + + try { + @SuppressWarnings("unchecked") // safe after instanceof check + Map updatedToolCallStates = new HashMap<>((Map) statesObj); + + // Remove the entries for the function calls that just got processed + processedFunctionCallIds.forEach(updatedToolCallStates::remove); + + state.put(TOOL_CALL_SECURITY_STATES, updatedToolCallStates); + } catch (ClassCastException e) { + logger.warn( + "Session key {} has unexpected map types, cannot clear tool states.", + TOOL_CALL_SECURITY_STATES, + e); + } } } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java index 2aeaf6225..e96099a72 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java @@ -31,6 +31,7 @@ public class SingleFlow extends BaseLlmFlow { new Identity(), new Contents(), new Examples(), + new RequestConfirmationLlmRequestProcessor(), CodeExecution.requestProcessor); protected static final ImmutableList RESPONSE_PROCESSORS = diff --git a/core/src/main/java/com/google/adk/tools/ToolContext.java b/core/src/main/java/com/google/adk/tools/ToolContext.java index 76fe47a74..50a27bd85 100644 --- a/core/src/main/java/com/google/adk/tools/ToolContext.java +++ b/core/src/main/java/com/google/adk/tools/ToolContext.java @@ -127,6 +127,20 @@ public Builder toBuilder() { .toolConfirmation(toolConfirmation.orElse(null)); } + @Override + public String toString() { + return "ToolContext{" + + "invocationContext=" + + invocationContext + + ", eventActions=" + + eventActions + + ", functionCallId=" + + functionCallId + + ", toolConfirmation=" + + toolConfirmation + + '}'; + } + /** Builder for {@link ToolContext}. */ public static final class Builder { private final InvocationContext invocationContext; diff --git a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java index d880d7d85..775fab5fc 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java @@ -31,6 +31,7 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -38,6 +39,39 @@ /** Unit tests for {@link Functions}. */ @RunWith(JUnit4.class) public final class FunctionsTest { + + private static final Event EVENT_WITH_NO_CONTENT = + Event.builder() + .id("event1") + .invocationId("invocation1") + .author("agent") + .content(Optional.empty()) + .build(); + + private static final Event EVENT_WITH_NO_PARTS = + Event.builder() + .id("event1") + .invocationId("invocation1") + .author("agent") + .content(Content.builder().role("model").parts(ImmutableList.of()).build()) + .build(); + + private static final Event EVENT_WITH_NO_FUNCTION_CALLS = + Event.builder() + .id("event1") + .invocationId("invocation1") + .author("agent") + .content(Content.fromParts(Part.fromText("hello"))) + .build(); + + private static final Event EVENT_WITH_NON_CONFIRMATION_FUNCTION_CALL = + Event.builder() + .id("event1") + .invocationId("invocation1") + .author("agent") + .content(Content.fromParts(Part.fromFunctionCall("other_function", ImmutableMap.of()))) + .build(); + @Test public void handleFunctionCalls_noFunctionCalls() { InvocationContext invocationContext = createInvocationContext(createRootAgent()); @@ -215,6 +249,7 @@ public void populateClientFunctionCallId_withEmptyId_populatesId() { @Test public void populateClientFunctionCallId_withExistingId_noChange() { + String id = "some_id"; Event event = createEvent("event").toBuilder() .content( @@ -223,13 +258,77 @@ public void populateClientFunctionCallId_withExistingId_noChange() { .functionCall( FunctionCall.builder() .name("echo_tool") - .id("some_id") + .id(id) .args(ImmutableMap.of("key", "value")) .build()) .build())) .build(); Functions.populateClientFunctionCallId(event); - assertThat(event).isEqualTo(event); + + assertThat(event.content().get().parts().get().get(0).functionCall().get().id()).hasValue(id); + } + + @Test + public void getAskUserConfirmationFunctionCalls_eventWithNoContent_returnsEmptyList() { + assertThat(Functions.getAskUserConfirmationFunctionCalls(EVENT_WITH_NO_CONTENT)).isEmpty(); + } + + @Test + public void getAskUserConfirmationFunctionCalls_eventWithNoParts_returnsEmptyList() { + assertThat(Functions.getAskUserConfirmationFunctionCalls(EVENT_WITH_NO_PARTS)).isEmpty(); + } + + @Test + public void getAskUserConfirmationFunctionCalls_eventWithNoFunctionCalls_returnsEmptyList() { + assertThat(Functions.getAskUserConfirmationFunctionCalls(EVENT_WITH_NO_FUNCTION_CALLS)) + .isEmpty(); + } + + @Test + public void + getAskUserConfirmationFunctionCalls_eventWithNonConfirmationFunctionCall_returnsEmptyList() { + assertThat( + Functions.getAskUserConfirmationFunctionCalls( + EVENT_WITH_NON_CONFIRMATION_FUNCTION_CALL)) + .isEmpty(); + } + + @Test + public void getAskUserConfirmationFunctionCalls_eventWithConfirmationFunctionCall_returnsCall() { + FunctionCall confirmationCall = + FunctionCall.builder().name(Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME).build(); + Event event = + Event.builder() + .id("event1") + .invocationId("invocation1") + .author("agent") + .content(Content.fromParts(Part.builder().functionCall(confirmationCall).build())) + .build(); + ImmutableList result = Functions.getAskUserConfirmationFunctionCalls(event); + assertThat(result).containsExactly(confirmationCall); + } + + @Test + public void + getAskUserConfirmationFunctionCalls_eventWithMixedParts_returnsOnlyConfirmationCalls() { + FunctionCall confirmationCall1 = + FunctionCall.builder().name(Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME).build(); + FunctionCall confirmationCall2 = + FunctionCall.builder().name(Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME).build(); + Event event = + Event.builder() + .id("event1") + .invocationId("invocation1") + .author("agent") + .content( + Content.fromParts( + Part.fromText("hello"), + Part.builder().functionCall(confirmationCall1).build(), + Part.fromFunctionCall("other_function", ImmutableMap.of()), + Part.builder().functionCall(confirmationCall2).build())) + .build(); + ImmutableList result = Functions.getAskUserConfirmationFunctionCalls(event); + assertThat(result).containsExactly(confirmationCall1, confirmationCall2); } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java index 9c3de70b0..971930907 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java @@ -44,6 +44,7 @@ @RunWith(JUnit4.class) public class RequestConfirmationLlmRequestProcessorTest { + private static final String ORIGINAL_FUNCTION_CALL_ID = "fc0"; private static final String REQUEST_CONFIRMATION_FUNCTION_CALL_ID = "fc1"; private static final String ECHO_TOOL_NAME = "echo_tool"; @@ -62,7 +63,7 @@ public class RequestConfirmationLlmRequestProcessorTest { "originalFunctionCall", ImmutableMap.of( "id", - "fc0", + ORIGINAL_FUNCTION_CALL_ID, "name", ECHO_TOOL_NAME, "args", @@ -86,6 +87,21 @@ public class RequestConfirmationLlmRequestProcessorTest { .build())) .build(); + private static final Event USER_DECLINE_EVENT = + Event.builder() + .author("user") + .content( + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(REQUEST_CONFIRMATION_FUNCTION_CALL_ID) + .name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME) + .response(ImmutableMap.of("confirmed", false)) + .build()) + .build())) + .build(); + private static final RequestConfirmationLlmRequestProcessor processor = new RequestConfirmationLlmRequestProcessor(); @@ -107,11 +123,69 @@ public void runAsync_withConfirmation_callsOriginalFunction() { Event event = result.events().iterator().next(); assertThat(event.functionResponses()).hasSize(1); FunctionResponse fr = event.functionResponses().get(0); - assertThat(fr.id()).hasValue("fc0"); + assertThat(fr.id()).hasValue(ORIGINAL_FUNCTION_CALL_ID); assertThat(fr.name()).hasValue(ECHO_TOOL_NAME); assertThat(fr.response()).hasValue(ImmutableMap.of("result", ImmutableMap.of("say", "hello"))); } + @Test + public void runAsync_withDecline_returnsErrorFunctionResponse() { + LlmAgent agent = createAgentWithEchoTool(); + Session session = + Session.builder("session_id") + .events(ImmutableList.of(REQUEST_CONFIRMATION_EVENT, USER_DECLINE_EVENT)) + .build(); + + InvocationContext context = createInvocationContext(agent, session); + + RequestProcessor.RequestProcessingResult result = + processor.processRequest(context, LlmRequest.builder().build()).blockingGet(); + + assertThat(result).isNotNull(); + assertThat(result.events()).hasSize(1); + Event event = result.events().iterator().next(); + assertThat(event.functionResponses()).hasSize(1); + FunctionResponse fr = event.functionResponses().get(0); + assertThat(fr.id()).hasValue(ORIGINAL_FUNCTION_CALL_ID); + assertThat(fr.name()).hasValue(ECHO_TOOL_NAME); + assertThat(fr.response()) + .hasValue(ImmutableMap.of("error", "User declined tool execution for echo_tool")); + } + + @Test + public void runAsync_withConfirmationAndToolAlreadyCalled_doesNotCallOriginalFunction() { + LlmAgent agent = createAgentWithEchoTool(); + Event toolResponseEvent = + Event.builder() + .author("model") + .content( + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(ORIGINAL_FUNCTION_CALL_ID) + .name(ECHO_TOOL_NAME) + .response( + ImmutableMap.of("result", ImmutableMap.of("say", "hello"))) + .build()) + .build())) + .build(); + Session session = + Session.builder("session_id") + .events( + ImmutableList.of( + REQUEST_CONFIRMATION_EVENT, USER_CONFIRMATION_EVENT, toolResponseEvent)) + .build(); + + InvocationContext context = createInvocationContext(agent, session); + + RequestProcessor.RequestProcessingResult result = + processor.processRequest(context, LlmRequest.builder().build()).blockingGet(); + + assertThat(result).isNotNull(); + assertThat(result.events()).isEmpty(); + } + @Test public void runAsync_noEvents_empty() { LlmAgent agent = createAgentWithEchoTool(); @@ -126,6 +200,21 @@ public void runAsync_noEvents_empty() { .isEmpty(); } + @Test + public void runAsync_noUserConfirmationEvent_empty() { + LlmAgent agent = createAgentWithEchoTool(); + Session session = + Session.builder("session_id").events(ImmutableList.of(REQUEST_CONFIRMATION_EVENT)).build(); + + assertThat( + processor + .processRequest( + createInvocationContext(agent, session), LlmRequest.builder().build()) + .blockingGet() + .events()) + .isEmpty(); + } + private static InvocationContext createInvocationContext(LlmAgent agent, Session session) { return new InvocationContext( /* sessionService= */ null,