Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions core/src/main/java/com/google/adk/tools/AgentTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,23 @@
import com.google.adk.SchemaUtils;
import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.LlmAgent;
import com.google.adk.artifacts.BaseArtifactService;
import com.google.adk.artifacts.InMemoryArtifactService;
import com.google.adk.events.Event;
import com.google.adk.runner.InMemoryRunner;
import com.google.adk.memory.BaseMemoryService;
import com.google.adk.memory.InMemoryMemoryService;
import com.google.adk.plugins.BasePlugin;
import com.google.adk.runner.Runner;
import com.google.adk.sessions.BaseSessionService;
import com.google.adk.sessions.InMemorySessionService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.Part;
import com.google.genai.types.Schema;
import io.reactivex.rxjava3.core.Single;
import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand All @@ -39,19 +46,42 @@ public class AgentTool extends BaseTool {

private final BaseAgent agent;
private final boolean skipSummarization;
private final List<BasePlugin> plugins;
private final BaseSessionService sessionService;
private final BaseArtifactService artifactService;
private final BaseMemoryService memoryService;

public static AgentTool create(
BaseAgent agent,
BaseSessionService sessionService,
BaseArtifactService artifactService,
BaseMemoryService memoryService,
List<BasePlugin> plugins) {
return new AgentTool(agent, false, sessionService, artifactService, memoryService, plugins);
}

public static AgentTool create(BaseAgent agent, boolean skipSummarization) {
return new AgentTool(agent, skipSummarization);
return new AgentTool(agent, skipSummarization, null, null, null, ImmutableList.of());
}

public static AgentTool create(BaseAgent agent) {
return new AgentTool(agent, false);
return new AgentTool(agent, false, null, null, null, ImmutableList.of());
}

protected AgentTool(BaseAgent agent, boolean skipSummarization) {
protected AgentTool(
BaseAgent agent,
boolean skipSummarization,
BaseSessionService sessionService,
BaseArtifactService artifactService,
BaseMemoryService memoryService,
List<BasePlugin> plugins) {
super(agent.name(), agent.description());
this.agent = agent;
this.skipSummarization = skipSummarization;
this.sessionService = sessionService;
this.artifactService = artifactService;
this.memoryService = memoryService;
this.plugins = plugins != null ? plugins : ImmutableList.of();
}

@Override
Expand Down Expand Up @@ -104,12 +134,34 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
content = Content.fromParts(Part.fromText(input.toString()));
}

Runner runner = new InMemoryRunner(this.agent, toolContext.agentName());
// Determine effective services: use injected singletons if present, otherwise create fresh
// instances per run (default behavior)
BaseSessionService effectiveSessionService =
this.sessionService != null ? this.sessionService : new InMemorySessionService();
BaseArtifactService effectiveArtifactService =
this.artifactService != null ? this.artifactService : new InMemoryArtifactService();
BaseMemoryService effectiveMemoryService =
this.memoryService != null ? this.memoryService : new InMemoryMemoryService();

Runner runner =
new Runner(
this.agent,
toolContext.agentName(),
effectiveArtifactService,
effectiveSessionService,
effectiveMemoryService,
this.plugins);

String userId = "tmp-user";
if (toolContext.userId() != null) {
userId = toolContext.userId();
}

// Session state is final, can't update to toolContext state
// session.toBuilder().setState(toolContext.getState());
return runner
.sessionService()
.createSession(toolContext.agentName(), "tmp-user", toolContext.state(), null)
.createSession(toolContext.agentName(), userId, toolContext.state(), null)
.flatMapPublisher(session -> runner.runAsync(session.userId(), session.id(), content))
.lastElement()
.map(Optional::of)
Expand Down Expand Up @@ -140,3 +192,4 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
});
}
}

79 changes: 79 additions & 0 deletions core/src/test/java/com/google/adk/tools/AgentToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@

import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LlmAgent;
import com.google.adk.artifacts.InMemoryArtifactService;
import com.google.adk.memory.InMemoryMemoryService;
import com.google.adk.models.LlmResponse;
import com.google.adk.plugins.BasePlugin;
import com.google.adk.sessions.InMemorySessionService;
import com.google.adk.sessions.ListSessionsResponse;
import com.google.adk.sessions.Session;
import com.google.adk.testing.TestLlm;
import com.google.common.collect.ImmutableList;
Expand All @@ -33,8 +38,10 @@
import com.google.genai.types.Part;
import com.google.genai.types.Schema;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -344,6 +351,78 @@ public void call_withoutInputSchema_requestIsSentToAgent() throws Exception {
.containsExactly(Content.fromParts(Part.fromText("magic")));
}

@Test
public void create_withServicesAndPlugins_initializesCorrectly() {
LlmAgent testAgent =
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
.name("agent name")
.description("agent description")
.build();

AgentTool agentTool =
AgentTool.create(
testAgent,
new InMemorySessionService(),
new InMemoryArtifactService(),
new InMemoryMemoryService(),
ImmutableList.of());

assertThat(agentTool).isNotNull();
assertThat(agentTool.declaration()).isPresent();
}

@Test
public void runAsync_withServicesAndPlugins_usesThem() {
LlmAgent testAgent =
createTestAgentBuilder(
createTestLlm(
LlmResponse.builder()
.content(Content.fromParts(Part.fromText("Sub-agent executed")))
.build()))
.name("sub-agent")
.description("sub-agent description")
.build();

InMemorySessionService sessionService = new InMemorySessionService();
TestPlugin testPlugin = new TestPlugin();

AgentTool agentTool =
AgentTool.create(
testAgent,
sessionService,
new InMemoryArtifactService(),
new InMemoryMemoryService(),
ImmutableList.of(testPlugin));

ToolContext toolContext = createToolContext(testAgent);

Map<String, Object> result =
agentTool.runAsync(ImmutableMap.of("request", "start"), toolContext).blockingGet();

assertThat(result).containsEntry("result", "Sub-agent executed");

assertThat(testPlugin.wasCalled.get()).isTrue();

ListSessionsResponse sessionsResponse =
sessionService.listSessions("sub-agent", "tmp-user").blockingGet();
assertThat(sessionsResponse.sessions()).isNotEmpty();
}

private static class TestPlugin extends BasePlugin {
final AtomicBoolean wasCalled = new AtomicBoolean(false);

TestPlugin() {
super("test-plugin");
}

@Override
public Maybe<Content> onUserMessageCallback(
InvocationContext invocationContext, Content userMessage) {
wasCalled.set(true);
return Maybe.empty();
}
}

private static ToolContext createToolContext(LlmAgent agent) {
return ToolContext.builder(
new InvocationContext(
Expand Down