diff --git a/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java index 6df01694a..b98fba314 100644 --- a/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java +++ b/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java @@ -2,8 +2,13 @@ import com.google.adk.a2a.converters.EventConverter; import com.google.adk.a2a.converters.PartConverter; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.RunConfig; +import com.google.adk.apps.App; +import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; +import com.google.adk.memory.BaseMemoryService; +import com.google.adk.plugins.Plugin; import com.google.adk.runner.Runner; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; @@ -21,6 +26,7 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.disposables.CompositeDisposable; import io.reactivex.rxjava3.disposables.Disposable; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; @@ -41,29 +47,96 @@ public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor private static final RunConfig DEFAULT_RUN_CONFIG = RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build(); - private final Runner runner; private final Map activeTasks = new ConcurrentHashMap<>(); + private final Runner.Builder runnerBuilder; + private final RunConfig runConfig; - private AgentExecutor(Runner runner) { - this.runner = runner; + private AgentExecutor( + App app, + BaseAgent agent, + String appName, + BaseArtifactService artifactService, + BaseSessionService sessionService, + BaseMemoryService memoryService, + List plugins, + RunConfig runConfig) { + this.runnerBuilder = + Runner.builder() + .app(app) + .agent(agent) + .appName(appName) + .artifactService(artifactService) + .sessionService(sessionService) + .memoryService(memoryService) + .plugins(plugins); + // Check that the runner is configured correctly and can be built. + var unused = runnerBuilder.build(); + this.runConfig = runConfig == null ? DEFAULT_RUN_CONFIG : runConfig; } /** Builder for {@link AgentExecutor}. */ public static class Builder { - private Runner runner; + private App app; + private BaseAgent agent; + private String appName; + private BaseArtifactService artifactService; + private BaseSessionService sessionService; + private BaseMemoryService memoryService; + private List plugins; + private RunConfig runConfig; @CanIgnoreReturnValue - public Builder runner(Runner runner) { - this.runner = runner; + public Builder app(App app) { + this.app = app; + return this; + } + + @CanIgnoreReturnValue + public Builder agent(BaseAgent agent) { + this.agent = agent; + return this; + } + + @CanIgnoreReturnValue + public Builder appName(String appName) { + this.appName = appName; + return this; + } + + @CanIgnoreReturnValue + public Builder artifactService(BaseArtifactService artifactService) { + this.artifactService = artifactService; + return this; + } + + @CanIgnoreReturnValue + public Builder sessionService(BaseSessionService sessionService) { + this.sessionService = sessionService; + return this; + } + + @CanIgnoreReturnValue + public Builder memoryService(BaseMemoryService memoryService) { + this.memoryService = memoryService; + return this; + } + + @CanIgnoreReturnValue + public Builder plugins(List plugins) { + this.plugins = plugins; + return this; + } + + @CanIgnoreReturnValue + public Builder runConfig(RunConfig runConfig) { + this.runConfig = runConfig; return this; } @CanIgnoreReturnValue public AgentExecutor build() { - if (runner == null) { - throw new IllegalStateException("Runner must be provided."); - } - return new AgentExecutor(runner); + return new AgentExecutor( + app, agent, appName, artifactService, sessionService, memoryService, plugins, runConfig); } } @@ -96,13 +169,14 @@ public void execute(RequestContext ctx, EventQueue eventQueue) { EventProcessor p = new EventProcessor(); Content content = PartConverter.messageToContent(message); + Runner runner = runnerBuilder.build(); taskDisposables.add( - prepareSession(ctx, runner.sessionService()) + prepareSession(ctx, runner.appName(), runner.sessionService()) .flatMapPublisher( session -> { updater.startWork(); - return runner.runAsync(getUserId(ctx), session.id(), content, DEFAULT_RUN_CONFIG); + return runner.runAsync(getUserId(ctx), session.id(), content, runConfig); }) .subscribe( event -> { @@ -130,13 +204,14 @@ private String getUserId(RequestContext ctx) { return USER_ID_PREFIX + ctx.getContextId(); } - private Maybe prepareSession(RequestContext ctx, BaseSessionService service) { + private Maybe prepareSession( + RequestContext ctx, String appName, BaseSessionService service) { return service - .getSession(runner.appName(), getUserId(ctx), ctx.getContextId(), Optional.empty()) + .getSession(appName, getUserId(ctx), ctx.getContextId(), Optional.empty()) .switchIfEmpty( Maybe.defer( () -> { - return service.createSession(runner.appName(), getUserId(ctx)).toMaybe(); + return service.createSession(appName, getUserId(ctx)).toMaybe(); })); }