From b16388e6d9ad3b0e4c50b29509488022d9f62681 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 23 Dec 2025 10:18:23 -0800 Subject: [PATCH] feat: Add VertexAiSearchTool and AgentTools for search PiperOrigin-RevId: 848220298 --- .../java/com/google/adk/tools/AgentTool.java | 6 + .../adk/tools/GoogleSearchAgentTool.java | 50 +++++++ .../adk/tools/VertexAiSearchAgentTool.java | 51 +++++++ .../google/adk/tools/VertexAiSearchTool.java | 133 +++++++++++++++++ .../adk/tools/GoogleSearchAgentToolTest.java | 22 +++ .../tools/VertexAiSearchAgentToolTest.java | 25 ++++ .../adk/tools/VertexAiSearchToolTest.java | 138 ++++++++++++++++++ 7 files changed, 425 insertions(+) create mode 100644 core/src/main/java/com/google/adk/tools/GoogleSearchAgentTool.java create mode 100644 core/src/main/java/com/google/adk/tools/VertexAiSearchAgentTool.java create mode 100644 core/src/main/java/com/google/adk/tools/VertexAiSearchTool.java create mode 100644 core/src/test/java/com/google/adk/tools/GoogleSearchAgentToolTest.java create mode 100644 core/src/test/java/com/google/adk/tools/VertexAiSearchAgentToolTest.java create mode 100644 core/src/test/java/com/google/adk/tools/VertexAiSearchToolTest.java diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index c31715fef..43f0797aa 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -28,6 +28,7 @@ import com.google.adk.events.Event; import com.google.adk.runner.InMemoryRunner; import com.google.adk.runner.Runner; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -76,6 +77,11 @@ protected AgentTool(BaseAgent agent, boolean skipSummarization) { this.skipSummarization = skipSummarization; } + @VisibleForTesting + BaseAgent getAgent() { + return agent; + } + @Override public Optional declaration() { FunctionDeclaration.Builder builder = diff --git a/core/src/main/java/com/google/adk/tools/GoogleSearchAgentTool.java b/core/src/main/java/com/google/adk/tools/GoogleSearchAgentTool.java new file mode 100644 index 000000000..29c09028a --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/GoogleSearchAgentTool.java @@ -0,0 +1,50 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import com.google.adk.agents.LlmAgent; +import com.google.adk.models.BaseLlm; +import com.google.common.collect.ImmutableList; + +/** + * A tool that wraps a sub-agent that only uses google_search tool. + * + *

This is a workaround to support using google_search tool with other tools. TODO(b/448114567): + * Remove once the workaround is no longer needed. + */ +public class GoogleSearchAgentTool extends AgentTool { + + public static GoogleSearchAgentTool create(BaseLlm model) { + LlmAgent googleSearchAgent = + LlmAgent.builder() + .name("google_search_agent") + .model(model) + .description("An agent for performing Google search using the `google_search` tool") + .instruction( + " You are a specialized Google search agent.\n" + + "\n" + + " When given a search query, use the `google_search` tool to find the" + + " related information.") + .tools(ImmutableList.of(GoogleSearchTool.INSTANCE)) + .build(); + return new GoogleSearchAgentTool(googleSearchAgent); + } + + protected GoogleSearchAgentTool(LlmAgent agent) { + super(agent, false); + } +} diff --git a/core/src/main/java/com/google/adk/tools/VertexAiSearchAgentTool.java b/core/src/main/java/com/google/adk/tools/VertexAiSearchAgentTool.java new file mode 100644 index 000000000..c8d1c34fc --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/VertexAiSearchAgentTool.java @@ -0,0 +1,51 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import com.google.adk.agents.LlmAgent; +import com.google.adk.models.BaseLlm; +import com.google.common.collect.ImmutableList; + +/** + * A tool that wraps a sub-agent that only uses vertex_ai_search tool. + * + *

This is a workaround to support using {@link VertexAiSearchTool} tool with other tools. + */ +public class VertexAiSearchAgentTool extends AgentTool { + + public static VertexAiSearchAgentTool create( + BaseLlm model, VertexAiSearchTool vertexAiSearchTool) { + LlmAgent vertexAiSearchAgent = + LlmAgent.builder() + .name("vertex_ai_search_agent") + .model(model) + .description( + "An agent for performing Vertex AI search using the `vertex_ai_search` tool") + .instruction( + " You are a specialized Vertex AI search agent.\n" + + "\n" + + " When given a search query, use the `vertex_ai_search` tool to find" + + " the related information.") + .tools(ImmutableList.of(vertexAiSearchTool)) + .build(); + return new VertexAiSearchAgentTool(vertexAiSearchAgent); + } + + protected VertexAiSearchAgentTool(LlmAgent agent) { + super(agent, false); + } +} diff --git a/core/src/main/java/com/google/adk/tools/VertexAiSearchTool.java b/core/src/main/java/com/google/adk/tools/VertexAiSearchTool.java new file mode 100644 index 000000000..ed8d2a6f4 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/VertexAiSearchTool.java @@ -0,0 +1,133 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import com.google.adk.models.LlmRequest; +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Retrieval; +import com.google.genai.types.Tool; +import com.google.genai.types.VertexAISearch; +import com.google.genai.types.VertexAISearchDataStoreSpec; +import io.reactivex.rxjava3.core.Completable; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * A built-in tool using Vertex AI Search. + * + *

This tool can be configured with either a {@code dataStoreId} (the Vertex AI search data store + * resource ID) or a {@code searchEngineId} (the Vertex AI search engine resource ID). + */ +@AutoValue +public abstract class VertexAiSearchTool extends BaseTool { + public abstract Optional dataStoreId(); + + public abstract Optional> dataStoreSpecs(); + + public abstract Optional searchEngineId(); + + public abstract Optional filter(); + + public abstract Optional maxResults(); + + public abstract Optional project(); + + public abstract Optional location(); + + public abstract Optional dataStore(); + + public static Builder builder() { + return new AutoValue_VertexAiSearchTool.Builder(); + } + + VertexAiSearchTool() { + super("vertex_ai_search", "vertex_ai_search"); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + LlmRequest llmRequest = llmRequestBuilder.build(); + + VertexAISearch.Builder vertexAiSearchBuilder = VertexAISearch.builder(); + dataStoreId().ifPresent(vertexAiSearchBuilder::datastore); + searchEngineId().ifPresent(vertexAiSearchBuilder::engine); + filter().ifPresent(vertexAiSearchBuilder::filter); + maxResults().ifPresent(vertexAiSearchBuilder::maxResults); + dataStoreSpecs().ifPresent(vertexAiSearchBuilder::dataStoreSpecs); + + Tool retrievalTool = + Tool.builder() + .retrieval(Retrieval.builder().vertexAiSearch(vertexAiSearchBuilder.build()).build()) + .build(); + + List currentTools = new ArrayList<>(); + if (llmRequest.config().isPresent() && llmRequest.config().get().tools().isPresent()) { + currentTools.addAll(llmRequest.config().get().tools().get()); + } + currentTools.add(retrievalTool); + + GenerateContentConfig newConfig = + llmRequest + .config() + .map(GenerateContentConfig::toBuilder) + .orElse(GenerateContentConfig.builder()) + .tools(ImmutableList.copyOf(currentTools)) + .build(); + llmRequestBuilder.config(newConfig); + return Completable.complete(); + } + + /** Builder for {@link VertexAiSearchTool}. */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder dataStoreId(String dataStoreId); + + public abstract Builder dataStoreSpecs(List dataStoreSpecs); + + public abstract Builder searchEngineId(String searchEngineId); + + public abstract Builder filter(String filter); + + public abstract Builder maxResults(Integer maxResults); + + public abstract Builder project(String project); + + public abstract Builder location(String location); + + public abstract Builder dataStore(String dataStore); + + abstract VertexAiSearchTool autoBuild(); + + public final VertexAiSearchTool build() { + VertexAiSearchTool tool = autoBuild(); + if ((tool.dataStoreId().isEmpty() && tool.searchEngineId().isEmpty()) + || (tool.dataStoreId().isPresent() && tool.searchEngineId().isPresent())) { + throw new IllegalArgumentException( + "Either dataStoreId or searchEngineId must be specified."); + } + if (tool.dataStoreSpecs().isPresent() && tool.searchEngineId().isEmpty()) { + throw new IllegalArgumentException( + "searchEngineId must be specified if dataStoreSpecs is specified."); + } + return tool; + } + } +} diff --git a/core/src/test/java/com/google/adk/tools/GoogleSearchAgentToolTest.java b/core/src/test/java/com/google/adk/tools/GoogleSearchAgentToolTest.java new file mode 100644 index 000000000..8305274a1 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/GoogleSearchAgentToolTest.java @@ -0,0 +1,22 @@ +package com.google.adk.tools; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.agents.LlmAgent; +import com.google.adk.testing.TestLlm; +import com.google.common.collect.ImmutableList; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class GoogleSearchAgentToolTest { + + @Test + public void create_createsAgent() { + GoogleSearchAgentTool tool = GoogleSearchAgentTool.create(new TestLlm(ImmutableList.of())); + assertThat(tool.getAgent().name()).isEqualTo("google_search_agent"); + assertThat(((LlmAgent) tool.getAgent()).tools().blockingGet()) + .containsExactly(GoogleSearchTool.INSTANCE); + } +} diff --git a/core/src/test/java/com/google/adk/tools/VertexAiSearchAgentToolTest.java b/core/src/test/java/com/google/adk/tools/VertexAiSearchAgentToolTest.java new file mode 100644 index 000000000..a93030735 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/VertexAiSearchAgentToolTest.java @@ -0,0 +1,25 @@ +package com.google.adk.tools; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.agents.LlmAgent; +import com.google.adk.testing.TestLlm; +import com.google.common.collect.ImmutableList; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class VertexAiSearchAgentToolTest { + + @Test + public void create_createsAgent() { + VertexAiSearchTool vertexAiSearchTool = + VertexAiSearchTool.builder().searchEngineId("test-engine").build(); + VertexAiSearchAgentTool tool = + VertexAiSearchAgentTool.create(new TestLlm(ImmutableList.of()), vertexAiSearchTool); + assertThat(tool.getAgent().name()).isEqualTo("vertex_ai_search_agent"); + assertThat(((LlmAgent) tool.getAgent()).tools().blockingGet()) + .containsExactly(vertexAiSearchTool); + } +} diff --git a/core/src/test/java/com/google/adk/tools/VertexAiSearchToolTest.java b/core/src/test/java/com/google/adk/tools/VertexAiSearchToolTest.java new file mode 100644 index 000000000..64b0edc4f --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/VertexAiSearchToolTest.java @@ -0,0 +1,138 @@ +package com.google.adk.tools; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Retrieval; +import com.google.genai.types.Tool; +import com.google.genai.types.VertexAISearch; +import com.google.genai.types.VertexAISearchDataStoreSpec; +import java.util.List; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public final class VertexAiSearchToolTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + @Mock InvocationContext invocationContext; + Session session = Session.builder("test-session").build(); + + @Before + public void setUp() { + when(invocationContext.session()).thenReturn(session); + } + + @Test + public void build_noDataStoreIdOrSearchEngineId_throwsException() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> VertexAiSearchTool.builder().build()); + assertThat(exception) + .hasMessageThat() + .isEqualTo("Either dataStoreId or searchEngineId must be specified."); + } + + @Test + public void build_bothDataStoreIdAndSearchEngineId_throwsException() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> VertexAiSearchTool.builder().dataStoreId("ds1").searchEngineId("se1").build()); + assertThat(exception) + .hasMessageThat() + .isEqualTo("Either dataStoreId or searchEngineId must be specified."); + } + + @Test + public void build_dataStoreSpecsWithoutSearchEngineId_throwsException() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + VertexAiSearchTool.builder() + .dataStoreId("ds1") + .dataStoreSpecs(ImmutableList.of()) + .build()); + assertThat(exception) + .hasMessageThat() + .isEqualTo("searchEngineId must be specified if dataStoreSpecs is specified."); + } + + @Test + public void build_withDataStoreId_succeeds() { + VertexAiSearchTool tool = VertexAiSearchTool.builder().dataStoreId("ds1").build(); + assertThat(tool.dataStoreId()).hasValue("ds1"); + } + + @Test + public void build_withSearchEngineId_succeeds() { + VertexAiSearchTool tool = VertexAiSearchTool.builder().searchEngineId("se1").build(); + assertThat(tool.searchEngineId()).hasValue("se1"); + } + + @Test + public void processLlmRequest_addsRetrievalTool() { + VertexAiSearchTool tool = + VertexAiSearchTool.builder().searchEngineId("se1").filter("filter1").maxResults(10).build(); + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + + tool.processLlmRequest(llmRequestBuilder, ToolContext.builder(invocationContext).build()) + .blockingAwait(); + + LlmRequest llmRequest = llmRequestBuilder.build(); + assertThat(llmRequest.config()).isPresent(); + GenerateContentConfig config = llmRequest.config().get(); + assertThat(config.tools()).isPresent(); + List tools = config.tools().get(); + assertThat(tools).hasSize(1); + Tool retrievalTool = tools.get(0); + assertThat(retrievalTool.retrieval()).isPresent(); + Retrieval retrieval = retrievalTool.retrieval().get(); + assertThat(retrieval.vertexAiSearch()).isPresent(); + VertexAISearch vertexAiSearch = retrieval.vertexAiSearch().get(); + assertThat(vertexAiSearch.engine()).hasValue("se1"); + assertThat(vertexAiSearch.filter()).hasValue("filter1"); + assertThat(vertexAiSearch.maxResults()).hasValue(10); + } + + @Test + public void processLlmRequest_withDataStoreSpecs_addsRetrievalTool() { + VertexAISearchDataStoreSpec spec = + VertexAISearchDataStoreSpec.builder().dataStore("ds1").build(); + VertexAiSearchTool tool = + VertexAiSearchTool.builder() + .searchEngineId("se1") + .dataStoreSpecs(ImmutableList.of(spec)) + .build(); + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + tool.processLlmRequest(llmRequestBuilder, ToolContext.builder(invocationContext).build()) + .blockingAwait(); + LlmRequest llmRequest = llmRequestBuilder.build(); + assertThat( + llmRequest + .config() + .get() + .tools() + .get() + .get(0) + .retrieval() + .get() + .vertexAiSearch() + .get() + .dataStoreSpecs() + .get()) + .containsExactly(spec); + } +}