Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 123 additions & 115 deletions core/src/main/java/com/google/adk/flows/llmflows/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package com.google.adk.flows.llmflows;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;

import com.google.adk.Telemetry;
Expand Down Expand Up @@ -62,9 +63,13 @@

/** Utility class for handling function calls. */
public final class Functions {
/** The function call name for the request confirmation function. */
public static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation";

/** Session state key for storing the security policy outcomes for tool calls. */
public static final String TOOL_CALL_SECURITY_STATES = "adk_tool_call_security_states";

private static final String AF_FUNCTION_CALL_ID_PREFIX = "adk-";
static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation";
private static final Logger logger = LoggerFactory.getLogger(Functions.class);

/** Generates a unique ID for a function call. */
Expand Down Expand Up @@ -145,59 +150,7 @@ public static Maybe<Event> handleFunctionCalls(
}

Function<FunctionCall, Maybe<Event>> functionCallMapper =
functionCall -> {
BaseTool tool = tools.get(functionCall.name().get());
ToolContext toolContext =
ToolContext.builder(invocationContext)
.functionCallId(functionCall.id().orElse(""))
.toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null)))
.build();

Map<String, Object> functionArgs = functionCall.args().orElse(ImmutableMap.of());

Maybe<Map<String, Object>> maybeFunctionResult =
maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext)
.switchIfEmpty(Maybe.defer(() -> callTool(tool, functionArgs, toolContext)));

return maybeFunctionResult
.map(Optional::of)
.defaultIfEmpty(Optional.empty())
.onErrorResumeNext(
t ->
invocationContext
.pluginManager()
.runOnToolErrorCallback(tool, functionArgs, toolContext, t)
.map(Optional::of)
.switchIfEmpty(Single.error(t)))
.flatMapMaybe(
optionalInitialResult -> {
Map<String, Object> initialFunctionResult = optionalInitialResult.orElse(null);

Maybe<Map<String, Object>> afterToolResultMaybe =
maybeInvokeAfterToolCall(
invocationContext,
tool,
functionArgs,
toolContext,
initialFunctionResult);

return afterToolResultMaybe
.map(Optional::of)
.defaultIfEmpty(Optional.ofNullable(initialFunctionResult))
.flatMapMaybe(
finalOptionalResult -> {
Map<String, Object> finalFunctionResult =
finalOptionalResult.orElse(null);
if (tool.longRunning() && finalFunctionResult == null) {
return Maybe.empty();
}
Event functionResponseEvent =
buildResponseEvent(
tool, finalFunctionResult, toolContext, invocationContext);
return Maybe.just(functionResponseEvent);
});
});
};
getFunctionCallMapper(invocationContext, tools, toolConfirmations, false);

Flowable<Event> functionResponseEventsFlowable;
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
Expand Down Expand Up @@ -241,6 +194,18 @@ public static Maybe<Event> handleFunctionCalls(
*/
public static Maybe<Event> handleFunctionCallsLive(
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
return handleFunctionCallsLive(invocationContext, functionCallEvent, tools, ImmutableMap.of());
}

/**
* Handles function calls in a live/streaming context with tool confirmations, supporting
* background execution and stream termination.
*/
public static Maybe<Event> handleFunctionCallsLive(
InvocationContext invocationContext,
Event functionCallEvent,
Map<String, BaseTool> tools,
Map<String, ToolConfirmation> toolConfirmations) {
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();

for (FunctionCall functionCall : functionCalls) {
Expand All @@ -250,72 +215,12 @@ public static Maybe<Event> handleFunctionCallsLive(
}

Function<FunctionCall, Maybe<Event>> functionCallMapper =
functionCall -> {
BaseTool tool = tools.get(functionCall.name().get());
ToolContext toolContext =
ToolContext.builder(invocationContext)
.functionCallId(functionCall.id().orElse(""))
.build();
Map<String, Object> functionArgs = functionCall.args().orElse(new HashMap<>());

Maybe<Map<String, Object>> maybeFunctionResult =
maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext)
.switchIfEmpty(
Maybe.defer(
() ->
processFunctionLive(
invocationContext,
tool,
toolContext,
functionCall,
functionArgs)));

return maybeFunctionResult
.map(Optional::of)
.defaultIfEmpty(Optional.empty())
.onErrorResumeNext(
t ->
invocationContext
.pluginManager()
.runOnToolErrorCallback(tool, functionArgs, toolContext, t)
.map(Optional::ofNullable)
.switchIfEmpty(Single.error(t)))
.flatMapMaybe(
optionalInitialResult -> {
Map<String, Object> initialFunctionResult = optionalInitialResult.orElse(null);

Maybe<Map<String, Object>> afterToolResultMaybe =
maybeInvokeAfterToolCall(
invocationContext,
tool,
functionArgs,
toolContext,
initialFunctionResult);

return afterToolResultMaybe
.map(Optional::of)
.defaultIfEmpty(Optional.ofNullable(initialFunctionResult))
.flatMapMaybe(
finalOptionalResult -> {
Map<String, Object> finalFunctionResult =
finalOptionalResult.orElse(null);
if (tool.longRunning() && finalFunctionResult == null) {
return Maybe.empty();
}
Event functionResponseEvent =
buildResponseEvent(
tool, finalFunctionResult, toolContext, invocationContext);
return Maybe.just(functionResponseEvent);
});
});
};
getFunctionCallMapper(invocationContext, tools, toolConfirmations, true);

Flowable<Event> responseEventsFlowable;

if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
responseEventsFlowable =
Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);

} else {
responseEventsFlowable =
Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper);
Expand All @@ -332,6 +237,47 @@ public static Maybe<Event> handleFunctionCallsLive(
});
}

private static Function<FunctionCall, Maybe<Event>> getFunctionCallMapper(
InvocationContext invocationContext,
Map<String, BaseTool> tools,
Map<String, ToolConfirmation> toolConfirmations,
boolean isLive) {
return functionCall -> {
BaseTool tool = tools.get(functionCall.name().get());
ToolContext toolContext =
ToolContext.builder(invocationContext)
.functionCallId(functionCall.id().orElse(""))
.toolConfirmation(functionCall.id().map(toolConfirmations::get).orElse(null))
.build();

ToolConfirmation toolConfirmation = toolContext.toolConfirmation().orElse(null);
if (toolConfirmation != null && !toolConfirmation.confirmed()) {
return Maybe.just(
buildResponseEvent(
tool,
ImmutableMap.of("error", "User declined tool execution for " + tool.name()),
toolContext,
invocationContext));
}

Map<String, Object> functionArgs =
functionCall.args().orElse(isLive ? new HashMap<>() : ImmutableMap.of());

Maybe<Map<String, Object>> maybeFunctionResult =
maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext)
.switchIfEmpty(
Maybe.defer(
() ->
isLive
? processFunctionLive(
invocationContext, tool, toolContext, functionCall, functionArgs)
: callTool(tool, functionArgs, toolContext)));

return postProcessFunctionResult(
maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive);
};
}

/**
* Processes a single function call in a live context. Manages starting, stopping, and running
* tools.
Expand Down Expand Up @@ -427,6 +373,48 @@ public static Set<String> getLongRunningFunctionCalls(
return longRunningFunctionCalls;
}

private static Maybe<Event> postProcessFunctionResult(
Maybe<Map<String, Object>> maybeFunctionResult,
InvocationContext invocationContext,
BaseTool tool,
Map<String, Object> functionArgs,
ToolContext toolContext,
boolean isLive) {
return maybeFunctionResult
.map(Optional::of)
.defaultIfEmpty(Optional.empty())
.onErrorResumeNext(
t ->
invocationContext
.pluginManager()
.runOnToolErrorCallback(tool, functionArgs, toolContext, t)
.map(isLive ? Optional::ofNullable : Optional::of)
.switchIfEmpty(Single.error(t)))
.flatMapMaybe(
optionalInitialResult -> {
Map<String, Object> initialFunctionResult = optionalInitialResult.orElse(null);

Maybe<Map<String, Object>> afterToolResultMaybe =
maybeInvokeAfterToolCall(
invocationContext, tool, functionArgs, toolContext, initialFunctionResult);

return afterToolResultMaybe
.map(Optional::of)
.defaultIfEmpty(Optional.ofNullable(initialFunctionResult))
.flatMapMaybe(
finalOptionalResult -> {
Map<String, Object> finalFunctionResult = finalOptionalResult.orElse(null);
if (tool.longRunning() && finalFunctionResult == null) {
return Maybe.empty();
}
Event functionResponseEvent =
buildResponseEvent(
tool, finalFunctionResult, toolContext, invocationContext);
return Maybe.just(functionResponseEvent);
});
});
}

private static Optional<Event> mergeParallelFunctionResponseEvents(
List<Event> functionResponseEvents) {
if (functionResponseEvents.isEmpty()) {
Expand Down Expand Up @@ -664,5 +652,25 @@ public static Optional<Event> generateRequestConfirmationEvent(
.build());
}

/**
* Gets the ask user confirmation function calls from the event.
*
* @param event The event to extract function calls from.
* @return A list of function calls for asking user confirmation.
*/
public static ImmutableList<FunctionCall> getAskUserConfirmationFunctionCalls(Event event) {
return event.content().flatMap(Content::parts).orElse(ImmutableList.of()).stream()
.flatMap(part -> part.functionCall().stream())
.filter(Functions::isRequestConfirmationFunctionCall)
.collect(toImmutableList());
}

private static boolean isRequestConfirmationFunctionCall(FunctionCall functionCall) {
return functionCall
.name()
.map(name -> name.equals(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
.orElse(false);
}

private Functions() {}
}
Loading