diff --git a/core/src/main/java/com/google/adk/agents/BaseAgentState.java b/core/src/main/java/com/google/adk/agents/BaseAgentState.java new file mode 100644 index 000000000..dedcb93ab --- /dev/null +++ b/core/src/main/java/com/google/adk/agents/BaseAgentState.java @@ -0,0 +1,39 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.agents; + +import com.google.adk.JsonBaseModel; + +/** Base class for all agent states. */ +public class BaseAgentState extends JsonBaseModel { + + protected BaseAgentState() {} + + /** Returns a new {@link Builder} for creating {@link BaseAgentState} instances. */ + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link BaseAgentState}. */ + public static class Builder { + private Builder() {} + + public BaseAgentState build() { + return new BaseAgentState(); + } + } +} diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 6457a8ca4..afee5065a 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -18,7 +18,9 @@ import static com.google.common.base.Strings.isNullOrEmpty; +import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; +import com.google.adk.events.Event; import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.Plugin; @@ -26,9 +28,12 @@ import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.adk.summarizer.EventsCompactionConfig; +import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.InlineMe; import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -49,6 +54,9 @@ public class InvocationContext { private final Session session; private final Optional userContent; private final RunConfig runConfig; + private final Map agentStates; + private final Map endOfAgents; + private final ResumabilityConfig resumabilityConfig; @Nullable private final EventsCompactionConfig eventsCompactionConfig; @Nullable private final ContextCacheConfig contextCacheConfig; private final InvocationCostManager invocationCostManager; @@ -72,10 +80,13 @@ protected InvocationContext(Builder builder) { this.userContent = builder.userContent; this.runConfig = builder.runConfig; this.endInvocation = builder.endInvocation; + this.agentStates = builder.agentStates; + this.endOfAgents = builder.endOfAgents; + this.resumabilityConfig = builder.resumabilityConfig; this.eventsCompactionConfig = builder.eventsCompactionConfig; this.contextCacheConfig = builder.contextCacheConfig; this.invocationCostManager = builder.invocationCostManager; - this.callbackContextData = new ConcurrentHashMap<>(builder.callbackContextData); + this.callbackContextData = builder.callbackContextData; } /** @@ -256,7 +267,10 @@ public String invocationId() { /** * Sets the [branch] ID for the current invocation. A branch represents a fork in the conversation * history. + * + * @deprecated Use {@link #toBuilder()} and {@link Builder#branch(String)} instead. */ + @Deprecated(forRemoval = true) public void branch(@Nullable String branch) { this.branch = Optional.ofNullable(branch); } @@ -307,6 +321,16 @@ public Map callbackContextData() { return callbackContextData; } + /** Returns agent-specific state saved within this invocation. */ + public Map agentStates() { + return agentStates; + } + + /** Returns map of agents that ended during this invocation. */ + public Map endOfAgents() { + return endOfAgents; + } + /** * Returns whether this invocation should be ended, e.g., due to reaching a terminal state or * error. @@ -345,6 +369,36 @@ public void incrementLlmCallsCount() throws LlmCallsLimitExceededException { this.invocationCostManager.incrementAndEnforceLlmCallsLimit(this.runConfig); } + /** Returns whether the current invocation is resumable. */ + public boolean isResumable() { + return resumabilityConfig.isResumable(); + } + + /** Returns ResumabilityConfig for this invocation. */ + public ResumabilityConfig resumabilityConfig() { + return resumabilityConfig; + } + + /** + * Populates agentStates and endOfAgents maps by reading session events for this invocation id. + */ + public void populateAgentStates(List events) { + events.stream() + .filter(event -> invocationId().equals(event.invocationId())) + .forEach( + event -> { + if (event.actions() != null) { + if (event.actions().agentState() != null + && !event.actions().agentState().isEmpty()) { + agentStates.putAll(event.actions().agentState()); + } + if (event.actions().endOfAgent()) { + endOfAgents.put(event.author(), true); + } + } + }); + } + /** Returns the events compaction configuration for the current agent run. */ public Optional eventsCompactionConfig() { return Optional.ofNullable(eventsCompactionConfig); @@ -355,6 +409,23 @@ public Optional contextCacheConfig() { return Optional.ofNullable(contextCacheConfig); } + /** Returns whether to pause the invocation right after this [event]. */ + public boolean shouldPauseInvocation(Event event) { + if (!isResumable()) { + return false; + } + + var longRunningToolIds = event.longRunningToolIds().orElse(ImmutableSet.of()); + if (longRunningToolIds.isEmpty()) { + return false; + } + + return event.functionCalls().stream() + .map(FunctionCall::id) + .flatMap(Optional::stream) + .anyMatch(functionCallId -> longRunningToolIds.contains(functionCallId)); + } + private static class InvocationCostManager { private int numberOfLlmCalls = 0; @@ -406,10 +477,13 @@ private Builder(InvocationContext context) { this.userContent = context.userContent; this.runConfig = context.runConfig; this.endInvocation = context.endInvocation; + this.agentStates = new ConcurrentHashMap<>(context.agentStates); + this.endOfAgents = new ConcurrentHashMap<>(context.endOfAgents); + this.resumabilityConfig = context.resumabilityConfig; this.eventsCompactionConfig = context.eventsCompactionConfig; this.contextCacheConfig = context.contextCacheConfig; this.invocationCostManager = context.invocationCostManager; - this.callbackContextData = new ConcurrentHashMap<>(context.callbackContextData); + this.callbackContextData = context.callbackContextData; } private BaseSessionService sessionService; @@ -425,6 +499,9 @@ private Builder(InvocationContext context) { private Optional userContent = Optional.empty(); private RunConfig runConfig = RunConfig.builder().build(); private boolean endInvocation = false; + private Map agentStates = new ConcurrentHashMap<>(); + private Map endOfAgents = new ConcurrentHashMap<>(); + private ResumabilityConfig resumabilityConfig = new ResumabilityConfig(); @Nullable private EventsCompactionConfig eventsCompactionConfig; @Nullable private ContextCacheConfig contextCacheConfig; private InvocationCostManager invocationCostManager = new InvocationCostManager(); @@ -616,6 +693,42 @@ public Builder endInvocation(boolean endInvocation) { return this; } + /** + * Sets agent-specific state saved within this invocation. + * + * @param agentStates agent-specific state saved within this invocation. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder agentStates(Map agentStates) { + this.agentStates = agentStates; + return this; + } + + /** + * Sets agent end-of-invocation status. + * + * @param endOfAgents agent end-of-invocation status. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder endOfAgents(Map endOfAgents) { + this.endOfAgents = endOfAgents; + return this; + } + + /** + * Sets the resumability configuration for the current agent run. + * + * @param resumabilityConfig the resumability configuration. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) { + this.resumabilityConfig = resumabilityConfig; + return this; + } + /** * Sets the events compaction configuration for the current agent run. * @@ -705,6 +818,9 @@ public boolean equals(Object o) { && Objects.equals(session, that.session) && Objects.equals(userContent, that.userContent) && Objects.equals(runConfig, that.runConfig) + && Objects.equals(agentStates, that.agentStates) + && Objects.equals(endOfAgents, that.endOfAgents) + && Objects.equals(resumabilityConfig, that.resumabilityConfig) && Objects.equals(eventsCompactionConfig, that.eventsCompactionConfig) && Objects.equals(contextCacheConfig, that.contextCacheConfig) && Objects.equals(invocationCostManager, that.invocationCostManager) @@ -727,6 +843,9 @@ public int hashCode() { userContent, runConfig, endInvocation, + agentStates, + endOfAgents, + resumabilityConfig, eventsCompactionConfig, contextCacheConfig, invocationCostManager, diff --git a/core/src/main/java/com/google/adk/apps/App.java b/core/src/main/java/com/google/adk/apps/App.java index 18e8753c7..d6635d2e7 100644 --- a/core/src/main/java/com/google/adk/apps/App.java +++ b/core/src/main/java/com/google/adk/apps/App.java @@ -41,6 +41,7 @@ public class App { private final BaseAgent rootAgent; private final ImmutableList plugins; @Nullable private final EventsCompactionConfig eventsCompactionConfig; + @Nullable private final ResumabilityConfig resumabilityConfig; @Nullable private final ContextCacheConfig contextCacheConfig; private App( @@ -48,11 +49,13 @@ private App( BaseAgent rootAgent, List plugins, @Nullable EventsCompactionConfig eventsCompactionConfig, + @Nullable ResumabilityConfig resumabilityConfig, @Nullable ContextCacheConfig contextCacheConfig) { this.name = name; this.rootAgent = rootAgent; this.plugins = ImmutableList.copyOf(plugins); this.eventsCompactionConfig = eventsCompactionConfig; + this.resumabilityConfig = resumabilityConfig; this.contextCacheConfig = contextCacheConfig; } @@ -73,6 +76,11 @@ public EventsCompactionConfig eventsCompactionConfig() { return eventsCompactionConfig; } + @Nullable + public ResumabilityConfig resumabilityConfig() { + return resumabilityConfig; + } + @Nullable public ContextCacheConfig contextCacheConfig() { return contextCacheConfig; @@ -84,6 +92,7 @@ public static class Builder { private BaseAgent rootAgent; private List plugins = ImmutableList.of(); @Nullable private EventsCompactionConfig eventsCompactionConfig; + @Nullable private ResumabilityConfig resumabilityConfig; @Nullable private ContextCacheConfig contextCacheConfig; @CanIgnoreReturnValue @@ -110,6 +119,12 @@ public Builder eventsCompactionConfig(EventsCompactionConfig eventsCompactionCon return this; } + @CanIgnoreReturnValue + public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) { + this.resumabilityConfig = resumabilityConfig; + return this; + } + @CanIgnoreReturnValue public Builder contextCacheConfig(ContextCacheConfig contextCacheConfig) { this.contextCacheConfig = contextCacheConfig; @@ -124,7 +139,8 @@ public App build() { throw new IllegalStateException("Root agent must be provided."); } validateAppName(name); - return new App(name, rootAgent, plugins, eventsCompactionConfig, contextCacheConfig); + return new App( + name, rootAgent, plugins, eventsCompactionConfig, resumabilityConfig, contextCacheConfig); } } diff --git a/core/src/main/java/com/google/adk/apps/ResumabilityConfig.java b/core/src/main/java/com/google/adk/apps/ResumabilityConfig.java new file mode 100644 index 000000000..b80ce709c --- /dev/null +++ b/core/src/main/java/com/google/adk/apps/ResumabilityConfig.java @@ -0,0 +1,28 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS language governing permissions and + * limitations under the License. + */ +package com.google.adk.apps; + +/** + * An app contains Resumability configuration for the agents. + * + * @param isResumable Whether the app is resumable. + */ +public record ResumabilityConfig(boolean isResumable) { + + /** Creates a new {@code ResumabilityConfig} with resumability disabled. */ + public ResumabilityConfig() { + this(false); + } +} diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 6d8c698dd..07488a171 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -19,6 +19,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.google.adk.JsonBaseModel; +import com.google.adk.agents.BaseAgentState; import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.HashSet; @@ -43,7 +44,9 @@ public class EventActions extends JsonBaseModel { private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent; + private ConcurrentMap agentState; private Optional compaction; + private Optional rewindBeforeInvocationId; /** Default constructor for Jackson. */ public EventActions() { @@ -57,6 +60,8 @@ public EventActions() { this.requestedToolConfirmations = new ConcurrentHashMap<>(); this.endOfAgent = false; this.compaction = Optional.empty(); + this.agentState = new ConcurrentHashMap<>(); + this.rewindBeforeInvocationId = Optional.empty(); } private EventActions(Builder builder) { @@ -70,6 +75,8 @@ private EventActions(Builder builder) { this.requestedToolConfirmations = builder.requestedToolConfirmations; this.endOfAgent = builder.endOfAgent; this.compaction = builder.compaction; + this.agentState = builder.agentState; + this.rewindBeforeInvocationId = builder.rewindBeforeInvocationId; } @JsonProperty("skipSummarization") @@ -216,6 +223,25 @@ public void setCompaction(Optional compaction) { this.compaction = compaction; } + @JsonProperty("agentState") + @JsonInclude(JsonInclude.Include.NON_EMPTY) + public ConcurrentMap agentState() { + return agentState; + } + + public void setAgentState(ConcurrentMap agentState) { + this.agentState = agentState; + } + + @JsonProperty("rewindBeforeInvocationId") + public Optional rewindBeforeInvocationId() { + return rewindBeforeInvocationId; + } + + public void setRewindBeforeInvocationId(@Nullable String rewindBeforeInvocationId) { + this.rewindBeforeInvocationId = Optional.ofNullable(rewindBeforeInvocationId); + } + public static Builder builder() { return new Builder(); } @@ -241,7 +267,9 @@ public boolean equals(Object o) { && Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs) && Objects.equals(requestedToolConfirmations, that.requestedToolConfirmations) && (endOfAgent == that.endOfAgent) - && Objects.equals(compaction, that.compaction); + && Objects.equals(compaction, that.compaction) + && Objects.equals(agentState, that.agentState) + && Objects.equals(rewindBeforeInvocationId, that.rewindBeforeInvocationId); } @Override @@ -256,7 +284,9 @@ public int hashCode() { requestedAuthConfigs, requestedToolConfirmations, endOfAgent, - compaction); + compaction, + agentState, + rewindBeforeInvocationId); } /** Builder for {@link EventActions}. */ @@ -271,6 +301,8 @@ public static class Builder { private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent = false; private Optional compaction; + private ConcurrentMap agentState; + private Optional rewindBeforeInvocationId; public Builder() { this.skipSummarization = Optional.empty(); @@ -282,6 +314,8 @@ public Builder() { this.requestedAuthConfigs = new ConcurrentHashMap<>(); this.requestedToolConfirmations = new ConcurrentHashMap<>(); this.compaction = Optional.empty(); + this.agentState = new ConcurrentHashMap<>(); + this.rewindBeforeInvocationId = Optional.empty(); } private Builder(EventActions eventActions) { @@ -296,6 +330,8 @@ private Builder(EventActions eventActions) { new ConcurrentHashMap<>(eventActions.requestedToolConfirmations()); this.endOfAgent = eventActions.endOfAgent(); this.compaction = eventActions.compaction(); + this.agentState = new ConcurrentHashMap<>(eventActions.agentState()); + this.rewindBeforeInvocationId = eventActions.rewindBeforeInvocationId(); } @CanIgnoreReturnValue @@ -380,6 +416,20 @@ public Builder compaction(EventCompaction value) { return this; } + @CanIgnoreReturnValue + @JsonProperty("agentState") + public Builder agentState(ConcurrentMap agentState) { + this.agentState = agentState; + return this; + } + + @CanIgnoreReturnValue + @JsonProperty("rewindBeforeInvocationId") + public Builder rewindBeforeInvocationId(String rewindBeforeInvocationId) { + this.rewindBeforeInvocationId = Optional.ofNullable(rewindBeforeInvocationId); + return this; + } + @CanIgnoreReturnValue public Builder merge(EventActions other) { other.skipSummarization().ifPresent(this::skipSummarization); @@ -392,6 +442,8 @@ public Builder merge(EventActions other) { this.requestedToolConfirmations.putAll(other.requestedToolConfirmations()); this.endOfAgent = other.endOfAgent(); other.compaction().ifPresent(this::compaction); + this.agentState.putAll(other.agentState()); + other.rewindBeforeInvocationId().ifPresent(this::rewindBeforeInvocationId); return this; } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 3ff778011..070ecec23 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -24,6 +24,7 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; +import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; @@ -73,6 +74,7 @@ public class Runner { private final BaseSessionService sessionService; @Nullable private final BaseMemoryService memoryService; private final PluginManager pluginManager; + private final ResumabilityConfig resumabilityConfig; @Nullable private final EventsCompactionConfig eventsCompactionConfig; @Nullable private final ContextCacheConfig contextCacheConfig; @@ -136,6 +138,7 @@ public Runner build() { BaseAgent buildAgent; String buildAppName; List buildPlugins; + ResumabilityConfig buildResumabilityConfig; EventsCompactionConfig buildEventsCompactionConfig; ContextCacheConfig buildContextCacheConfig; @@ -149,12 +152,17 @@ public Runner build() { buildAgent = this.app.rootAgent(); buildPlugins = this.app.plugins(); buildAppName = this.appName == null ? this.app.name() : this.appName; + buildResumabilityConfig = + this.app.resumabilityConfig() != null + ? this.app.resumabilityConfig() + : new ResumabilityConfig(); buildEventsCompactionConfig = this.app.eventsCompactionConfig(); buildContextCacheConfig = this.app.contextCacheConfig(); } else { buildAgent = this.agent; buildAppName = this.appName; buildPlugins = this.plugins; + buildResumabilityConfig = new ResumabilityConfig(); buildEventsCompactionConfig = null; buildContextCacheConfig = null; } @@ -178,6 +186,7 @@ public Runner build() { sessionService, memoryService, buildPlugins, + buildResumabilityConfig, buildEventsCompactionConfig, buildContextCacheConfig); } @@ -199,7 +208,14 @@ public Runner( BaseArtifactService artifactService, BaseSessionService sessionService, @Nullable BaseMemoryService memoryService) { - this(agent, appName, artifactService, sessionService, memoryService, ImmutableList.of()); + this( + agent, + appName, + artifactService, + sessionService, + memoryService, + ImmutableList.of(), + new ResumabilityConfig()); } /** @@ -215,7 +231,40 @@ public Runner( BaseSessionService sessionService, @Nullable BaseMemoryService memoryService, List plugins) { - this(agent, appName, artifactService, sessionService, memoryService, plugins, null, null); + this( + agent, + appName, + artifactService, + sessionService, + memoryService, + plugins, + new ResumabilityConfig()); + } + + /** + * Creates a new {@code Runner} with a list of plugins and resumability config. + * + * @deprecated Use {@link Runner.Builder} instead. + */ + @Deprecated + public Runner( + BaseAgent agent, + String appName, + BaseArtifactService artifactService, + BaseSessionService sessionService, + @Nullable BaseMemoryService memoryService, + List plugins, + ResumabilityConfig resumabilityConfig) { + this( + agent, + appName, + artifactService, + sessionService, + memoryService, + plugins, + resumabilityConfig, + null, + null); } /** @@ -231,6 +280,7 @@ protected Runner( BaseSessionService sessionService, @Nullable BaseMemoryService memoryService, List plugins, + ResumabilityConfig resumabilityConfig, @Nullable EventsCompactionConfig eventsCompactionConfig, @Nullable ContextCacheConfig contextCacheConfig) { this.agent = agent; @@ -239,6 +289,7 @@ protected Runner( this.sessionService = sessionService; this.memoryService = memoryService; this.pluginManager = new PluginManager(plugins); + this.resumabilityConfig = resumabilityConfig; this.eventsCompactionConfig = createEventsCompactionConfig(agent, eventsCompactionConfig); this.contextCacheConfig = contextCacheConfig; } @@ -617,6 +668,7 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { .pluginManager(this.pluginManager) .agent(rootAgent) .session(session) + .resumabilityConfig(this.resumabilityConfig) .eventsCompactionConfig(this.eventsCompactionConfig) .contextCacheConfig(this.contextCacheConfig) .agent(this.findAgentToRun(session, rootAgent)); diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index 0237261c5..bbfbb74bd 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -20,15 +20,22 @@ import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; +import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.adk.summarizer.EventsCompactionConfig; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.Part; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -179,12 +186,12 @@ public void testToBuilder() { assertThat(copiedContext.activeStreamingTools()) .isEqualTo(originalContext.activeStreamingTools()); assertThat(copiedContext.callbackContextData()) - .isEqualTo(originalContext.callbackContextData()); + .isSameInstanceAs(originalContext.callbackContextData()); } @Test public void testBuildWithCallbackContextData() { - ConcurrentHashMap data = new ConcurrentHashMap<>(); + Map data = new ConcurrentHashMap<>(); data.put("key", "value"); InvocationContext context = InvocationContext.builder() @@ -196,6 +203,7 @@ public void testBuildWithCallbackContextData() { .build(); assertThat(context.callbackContextData()).isEqualTo(data); + assertThat(context.callbackContextData()).isSameInstanceAs(data); } @Test @@ -435,7 +443,7 @@ public void testEquals_differentValues() { .userContent(userContent) .runConfig(runConfig) .endInvocation(false) - .callbackContextData(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) + .callbackContextData(ImmutableMap.of("key", "value")) .build(); assertThat(context.equals(contextWithDiffCallbackContextData)).isFalse(); } @@ -500,11 +508,161 @@ public void testHashCode_differentValues() { .userContent(userContent) .runConfig(runConfig) .endInvocation(false) - .callbackContextData(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) + .callbackContextData(ImmutableMap.of("key", "value")) .build(); assertThat(context.hashCode()).isNotEqualTo(contextWithDiffCallbackContextData.hashCode()); } + @Test + public void isResumable_whenResumabilityConfigIsNotResumable_isFalse() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(false)) + .build(); + assertThat(context.isResumable()).isFalse(); + } + + @Test + public void isResumable_whenResumabilityConfigIsResumable_isTrue() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(true)) + .build(); + assertThat(context.isResumable()).isTrue(); + } + + @Test + public void shouldPauseInvocation_whenNotResumable_isFalse() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(false)) + .build(); + Event event = + Event.builder() + .longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))) + .content( + Content.builder() + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder().name("tool1").id("fc1").build()) + .build())) + .build()) + .build(); + assertThat(context.shouldPauseInvocation(event)).isFalse(); + } + + @Test + public void shouldPauseInvocation_whenResumableAndNoLongRunningToolIds_isFalse() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(true)) + .build(); + Event event = + Event.builder() + .content( + Content.builder() + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder().name("tool1").id("fc1").build()) + .build())) + .build()) + .build(); + assertThat(context.shouldPauseInvocation(event)).isFalse(); + } + + @Test + public void shouldPauseInvocation_whenResumableAndNoFunctionCalls_isFalse() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(true)) + .build(); + Event event = Event.builder().longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))).build(); + assertThat(context.shouldPauseInvocation(event)).isFalse(); + } + + @Test + public void shouldPauseInvocation_whenResumableAndNoMatchingFunctionCallId_isFalse() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(true)) + .build(); + Event event = + Event.builder() + .longRunningToolIds(Optional.of(ImmutableSet.of("fc2"))) + .content( + Content.builder() + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder().name("tool1").id("fc1").build()) + .build())) + .build()) + .build(); + assertThat(context.shouldPauseInvocation(event)).isFalse(); + } + + @Test + public void shouldPauseInvocation_whenResumableAndMatchingFunctionCallId_isTrue() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(true)) + .build(); + Event event = + Event.builder() + .longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))) + .content( + Content.builder() + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder().name("tool1").id("fc1").build()) + .build())) + .build()) + .build(); + assertThat(context.shouldPauseInvocation(event)).isTrue(); + } + @Test public void incrementLlmCallsCount_whenLimitNotExceeded_doesNotThrow() throws Exception { InvocationContext context = @@ -571,6 +729,26 @@ public void testSessionGetters() { assertThat(context.userId()).isEqualTo("test-user"); } + @Test + public void testAgentStatesAndEndOfAgents() { + BaseAgentState mockState = mock(BaseAgentState.class); + ImmutableMap states = ImmutableMap.of("agent1", mockState); + ImmutableMap endOfAgents = ImmutableMap.of("agent1", true); + + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .agentStates(states) + .endOfAgents(endOfAgents) + .build(); + + assertThat(context.agentStates()).isEqualTo(states); + assertThat(context.endOfAgents()).isEqualTo(endOfAgents); + } + @Test public void testSetEndInvocation() { InvocationContext context = @@ -758,6 +936,58 @@ public void testDeprecatedConstructor_11params() { assertThat(context.endInvocation()).isTrue(); } + @Test + public void populateAgentStates_populatesAgentStatesAndEndOfAgents() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .invocationId(testInvocationId) + .build(); + + BaseAgentState agent1State = mock(BaseAgentState.class); + ConcurrentHashMap agent1StateMap = new ConcurrentHashMap<>(); + agent1StateMap.put("agent1", agent1State); + Event event1 = + Event.builder() + .invocationId(testInvocationId) + .author("agent1") + .actions(EventActions.builder().agentState(agent1StateMap).endOfAgent(true).build()) + .build(); + Event event2 = + Event.builder() + .invocationId("other-invocation-id") + .author("agent2") + .actions(EventActions.builder().endOfAgent(true).build()) + .build(); + Event event3 = + Event.builder() + .invocationId(testInvocationId) + .author("agent3") + .actions(EventActions.builder().endOfAgent(false).build()) + .build(); + BaseAgentState agent4State = mock(BaseAgentState.class); + ConcurrentHashMap agent4StateMap = new ConcurrentHashMap<>(); + agent4StateMap.put("agent4", agent4State); + Event event4 = + Event.builder() + .invocationId(testInvocationId) + .author("agent4") + .actions(EventActions.builder().agentState(agent4StateMap).endOfAgent(false).build()) + .build(); + Event event5 = Event.builder().invocationId(testInvocationId).author("agent5").build(); + + context.populateAgentStates(ImmutableList.of(event1, event2, event3, event4, event5)); + + assertThat(context.agentStates()).hasSize(2); + assertThat(context.agentStates()).containsEntry("agent1", agent1State); + assertThat(context.agentStates()).containsEntry("agent4", agent4State); + assertThat(context.endOfAgents()).hasSize(1); + assertThat(context.endOfAgents()).containsEntry("agent1", true); + } + @Test public void build_missingInvocationId_null_throwsException() { InvocationContext.Builder builder = diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 421b79abb..86b0a81ec 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -36,6 +36,7 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; +import com.google.adk.apps.ResumabilityConfig; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; import com.google.adk.models.LlmResponse; @@ -928,6 +929,48 @@ public void runLive_createsInvocationSpan() { assertThat(invocationSpan.get().hasEnded()).isTrue(); } + @Test + public void resumabilityConfig_isResumable_isTrueInInvocationContext() { + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(InvocationContext.class); + when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty()); + Runner runner = + Runner.builder() + .app( + App.builder() + .name("test") + .rootAgent(agent) + .plugins(ImmutableList.of(plugin)) + .resumabilityConfig(new ResumabilityConfig(true)) + .build()) + .build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + var unused = + runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); + assertThat(contextCaptor.getValue().isResumable()).isTrue(); + } + + @Test + public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() { + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(InvocationContext.class); + when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty()); + Runner runner = + Runner.builder() + .app( + App.builder() + .name("test") + .rootAgent(agent) + .plugins(ImmutableList.of(plugin)) + .resumabilityConfig(new ResumabilityConfig(false)) + .build()) + .build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + var unused = + runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); + assertThat(contextCaptor.getValue().isResumable()).isFalse(); + } + @Test public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() { RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build();