Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 90 additions & 15 deletions a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

import com.google.adk.a2a.converters.EventConverter;
import com.google.adk.a2a.converters.PartConverter;
import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.RunConfig;
import com.google.adk.apps.App;
import com.google.adk.artifacts.BaseArtifactService;
import com.google.adk.events.Event;
import com.google.adk.memory.BaseMemoryService;
import com.google.adk.plugins.Plugin;
import com.google.adk.runner.Runner;
import com.google.adk.sessions.BaseSessionService;
import com.google.adk.sessions.Session;
Expand All @@ -21,6 +26,7 @@
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.disposables.CompositeDisposable;
import io.reactivex.rxjava3.disposables.Disposable;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
Expand All @@ -41,29 +47,96 @@ public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor
private static final RunConfig DEFAULT_RUN_CONFIG =
RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build();

private final Runner runner;
private final Map<String, Disposable> activeTasks = new ConcurrentHashMap<>();
private final Runner.Builder runnerBuilder;
private final RunConfig runConfig;

private AgentExecutor(Runner runner) {
this.runner = runner;
private AgentExecutor(
App app,
BaseAgent agent,
String appName,
BaseArtifactService artifactService,
BaseSessionService sessionService,
BaseMemoryService memoryService,
List<? extends Plugin> plugins,
RunConfig runConfig) {
this.runnerBuilder =
Runner.builder()
.app(app)
.agent(agent)
.appName(appName)
.artifactService(artifactService)
.sessionService(sessionService)
.memoryService(memoryService)
.plugins(plugins);
// Check that the runner is configured correctly and can be built.
var unused = runnerBuilder.build();
this.runConfig = runConfig == null ? DEFAULT_RUN_CONFIG : runConfig;
}

/** Builder for {@link AgentExecutor}. */
public static class Builder {
private Runner runner;
private App app;
private BaseAgent agent;
private String appName;
private BaseArtifactService artifactService;
private BaseSessionService sessionService;
private BaseMemoryService memoryService;
private List<? extends Plugin> plugins;
private RunConfig runConfig;

@CanIgnoreReturnValue
public Builder runner(Runner runner) {
this.runner = runner;
public Builder app(App app) {
this.app = app;
return this;
}

@CanIgnoreReturnValue
public Builder agent(BaseAgent agent) {
this.agent = agent;
return this;
}

@CanIgnoreReturnValue
public Builder appName(String appName) {
this.appName = appName;
return this;
}

@CanIgnoreReturnValue
public Builder artifactService(BaseArtifactService artifactService) {
this.artifactService = artifactService;
return this;
}

@CanIgnoreReturnValue
public Builder sessionService(BaseSessionService sessionService) {
this.sessionService = sessionService;
return this;
}

@CanIgnoreReturnValue
public Builder memoryService(BaseMemoryService memoryService) {
this.memoryService = memoryService;
return this;
}

@CanIgnoreReturnValue
public Builder plugins(List<? extends Plugin> plugins) {
this.plugins = plugins;
return this;
}

@CanIgnoreReturnValue
public Builder runConfig(RunConfig runConfig) {
this.runConfig = runConfig;
return this;
}

@CanIgnoreReturnValue
public AgentExecutor build() {
if (runner == null) {
throw new IllegalStateException("Runner must be provided.");
}
return new AgentExecutor(runner);
return new AgentExecutor(
app, agent, appName, artifactService, sessionService, memoryService, plugins, runConfig);
}
}

Expand Down Expand Up @@ -96,13 +169,14 @@ public void execute(RequestContext ctx, EventQueue eventQueue) {

EventProcessor p = new EventProcessor();
Content content = PartConverter.messageToContent(message);
Runner runner = runnerBuilder.build();

taskDisposables.add(
prepareSession(ctx, runner.sessionService())
prepareSession(ctx, runner.appName(), runner.sessionService())
.flatMapPublisher(
session -> {
updater.startWork();
return runner.runAsync(getUserId(ctx), session.id(), content, DEFAULT_RUN_CONFIG);
return runner.runAsync(getUserId(ctx), session.id(), content, runConfig);
})
.subscribe(
event -> {
Expand Down Expand Up @@ -130,13 +204,14 @@ private String getUserId(RequestContext ctx) {
return USER_ID_PREFIX + ctx.getContextId();
}

private Maybe<Session> prepareSession(RequestContext ctx, BaseSessionService service) {
private Maybe<Session> prepareSession(
RequestContext ctx, String appName, BaseSessionService service) {
return service
.getSession(runner.appName(), getUserId(ctx), ctx.getContextId(), Optional.empty())
.getSession(appName, getUserId(ctx), ctx.getContextId(), Optional.empty())
.switchIfEmpty(
Maybe.defer(
() -> {
return service.createSession(runner.appName(), getUserId(ctx)).toMaybe();
return service.createSession(appName, getUserId(ctx)).toMaybe();
}));
}

Expand Down