From 5aec480ce3938d872e1cfbe9658ea3333e99fbcd Mon Sep 17 00:00:00 2001 From: lukas Date: Fri, 5 Dec 2025 13:10:52 +0100 Subject: [PATCH] feat: Enhance AgentTool to support custom services and plugins --- .../java/com/google/adk/tools/AgentTool.java | 65 +++++++++++++-- .../com/google/adk/tools/AgentToolTest.java | 79 +++++++++++++++++++ 2 files changed, 138 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index ebefb6e96..90d73891d 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -21,9 +21,15 @@ import com.google.adk.SchemaUtils; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.LlmAgent; +import com.google.adk.artifacts.BaseArtifactService; +import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; -import com.google.adk.runner.InMemoryRunner; +import com.google.adk.memory.BaseMemoryService; +import com.google.adk.memory.InMemoryMemoryService; +import com.google.adk.plugins.BasePlugin; import com.google.adk.runner.Runner; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -31,6 +37,7 @@ import com.google.genai.types.Part; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Single; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -39,19 +46,42 @@ public class AgentTool extends BaseTool { private final BaseAgent agent; private final boolean skipSummarization; + private final List plugins; + private final BaseSessionService sessionService; + private final BaseArtifactService artifactService; + private final BaseMemoryService memoryService; + + public static AgentTool create( + BaseAgent agent, + BaseSessionService sessionService, + BaseArtifactService artifactService, + BaseMemoryService memoryService, + List plugins) { + return new AgentTool(agent, false, sessionService, artifactService, memoryService, plugins); + } public static AgentTool create(BaseAgent agent, boolean skipSummarization) { - return new AgentTool(agent, skipSummarization); + return new AgentTool(agent, skipSummarization, null, null, null, ImmutableList.of()); } public static AgentTool create(BaseAgent agent) { - return new AgentTool(agent, false); + return new AgentTool(agent, false, null, null, null, ImmutableList.of()); } - protected AgentTool(BaseAgent agent, boolean skipSummarization) { + protected AgentTool( + BaseAgent agent, + boolean skipSummarization, + BaseSessionService sessionService, + BaseArtifactService artifactService, + BaseMemoryService memoryService, + List plugins) { super(agent.name(), agent.description()); this.agent = agent; this.skipSummarization = skipSummarization; + this.sessionService = sessionService; + this.artifactService = artifactService; + this.memoryService = memoryService; + this.plugins = plugins != null ? plugins : ImmutableList.of(); } @Override @@ -104,12 +134,34 @@ public Single> runAsync(Map args, ToolContex content = Content.fromParts(Part.fromText(input.toString())); } - Runner runner = new InMemoryRunner(this.agent, toolContext.agentName()); + // Determine effective services: use injected singletons if present, otherwise create fresh + // instances per run (default behavior) + BaseSessionService effectiveSessionService = + this.sessionService != null ? this.sessionService : new InMemorySessionService(); + BaseArtifactService effectiveArtifactService = + this.artifactService != null ? this.artifactService : new InMemoryArtifactService(); + BaseMemoryService effectiveMemoryService = + this.memoryService != null ? this.memoryService : new InMemoryMemoryService(); + + Runner runner = + new Runner( + this.agent, + toolContext.agentName(), + effectiveArtifactService, + effectiveSessionService, + effectiveMemoryService, + this.plugins); + + String userId = "tmp-user"; + if (toolContext.userId() != null) { + userId = toolContext.userId(); + } + // Session state is final, can't update to toolContext state // session.toBuilder().setState(toolContext.getState()); return runner .sessionService() - .createSession(toolContext.agentName(), "tmp-user", toolContext.state(), null) + .createSession(toolContext.agentName(), userId, toolContext.state(), null) .flatMapPublisher(session -> runner.runAsync(session.userId(), session.id(), content)) .lastElement() .map(Optional::of) @@ -140,3 +192,4 @@ public Single> runAsync(Map args, ToolContex }); } } + diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index a47408d43..4c05ae2dc 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -23,7 +23,12 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; +import com.google.adk.artifacts.InMemoryArtifactService; +import com.google.adk.memory.InMemoryMemoryService; import com.google.adk.models.LlmResponse; +import com.google.adk.plugins.BasePlugin; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.ListSessionsResponse; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; import com.google.common.collect.ImmutableList; @@ -33,8 +38,10 @@ import com.google.genai.types.Part; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -344,6 +351,78 @@ public void call_withoutInputSchema_requestIsSentToAgent() throws Exception { .containsExactly(Content.fromParts(Part.fromText("magic"))); } + @Test + public void create_withServicesAndPlugins_initializesCorrectly() { + LlmAgent testAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("agent name") + .description("agent description") + .build(); + + AgentTool agentTool = + AgentTool.create( + testAgent, + new InMemorySessionService(), + new InMemoryArtifactService(), + new InMemoryMemoryService(), + ImmutableList.of()); + + assertThat(agentTool).isNotNull(); + assertThat(agentTool.declaration()).isPresent(); + } + + @Test + public void runAsync_withServicesAndPlugins_usesThem() { + LlmAgent testAgent = + createTestAgentBuilder( + createTestLlm( + LlmResponse.builder() + .content(Content.fromParts(Part.fromText("Sub-agent executed"))) + .build())) + .name("sub-agent") + .description("sub-agent description") + .build(); + + InMemorySessionService sessionService = new InMemorySessionService(); + TestPlugin testPlugin = new TestPlugin(); + + AgentTool agentTool = + AgentTool.create( + testAgent, + sessionService, + new InMemoryArtifactService(), + new InMemoryMemoryService(), + ImmutableList.of(testPlugin)); + + ToolContext toolContext = createToolContext(testAgent); + + Map result = + agentTool.runAsync(ImmutableMap.of("request", "start"), toolContext).blockingGet(); + + assertThat(result).containsEntry("result", "Sub-agent executed"); + + assertThat(testPlugin.wasCalled.get()).isTrue(); + + ListSessionsResponse sessionsResponse = + sessionService.listSessions("sub-agent", "tmp-user").blockingGet(); + assertThat(sessionsResponse.sessions()).isNotEmpty(); + } + + private static class TestPlugin extends BasePlugin { + final AtomicBoolean wasCalled = new AtomicBoolean(false); + + TestPlugin() { + super("test-plugin"); + } + + @Override + public Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { + wasCalled.set(true); + return Maybe.empty(); + } + } + private static ToolContext createToolContext(LlmAgent agent) { return ToolContext.builder( new InvocationContext(