diff --git a/src/google/adk/code_executors/built_in_code_executor.py b/src/google/adk/code_executors/built_in_code_executor.py index 50a0b9f4f6..a4e3203461 100644 --- a/src/google/adk/code_executors/built_in_code_executor.py +++ b/src/google/adk/code_executors/built_in_code_executor.py @@ -20,6 +20,7 @@ from ..agents.invocation_context import InvocationContext from ..models import LlmRequest from ..utils.model_name_utils import is_gemini_2_or_above +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_code_executor import BaseCodeExecutor from .code_execution_utils import CodeExecutionInput from .code_execution_utils import CodeExecutionResult @@ -42,7 +43,8 @@ def execute_code( def process_llm_request(self, llm_request: LlmRequest) -> None: """Pre-process the LLM request for Gemini 2.0+ models to use the code execution tool.""" - if is_gemini_2_or_above(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + if is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] llm_request.config.tools.append( diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index 7bb18efae3..2218c8742b 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -65,6 +65,11 @@ 'wait_for_completion', }) +_ENABLE_CONSOLIDATION_KEY = 'enable_consolidation' +# Vertex docs for GenerateMemoriesRequest.DirectMemoriesSource allow +# at most 5 direct_memories per request. +_MAX_DIRECT_MEMORIES_PER_GENERATE_CALL = 5 + def _supports_generate_memories_metadata() -> bool: """Returns whether installed Vertex SDK supports config.metadata.""" @@ -160,6 +165,11 @@ def __init__( not use Google AI Studio API key for this field. For more details, visit https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview """ + if not agent_engine_id: + raise ValueError( + 'agent_engine_id is required for VertexAiMemoryBankService.' + ) + self._project = project self._location = location self._agent_engine_id = agent_engine_id @@ -219,7 +229,22 @@ async def add_memory( memories: Sequence[MemoryEntry], custom_metadata: Mapping[str, object] | None = None, ) -> None: - """Adds explicit memory items via Vertex memories.create.""" + """Adds explicit memory items using Vertex Memory Bank. + + By default, this writes directly via `memories.create`. + If `custom_metadata["enable_consolidation"]` is set to True, this uses + `memories.generate` with `direct_memories_source` so provided memories are + consolidated server-side. 
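+
+    Example (illustrative; ``memory_service`` and the argument values are
+    placeholders)::
+
+        await memory_service.add_memory(
+            app_name='my_app',
+            user_id='user_1',
+            memories=[memory_entry],
+            custom_metadata={'enable_consolidation': True},
+        )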
+ """ + if _is_consolidation_enabled(custom_metadata): + await self._add_memories_via_generate_direct_memories_source( + app_name=app_name, + user_id=user_id, + memories=memories, + custom_metadata=custom_metadata, + ) + return + await self._add_memories_via_create( app_name=app_name, user_id=user_id, @@ -235,9 +260,6 @@ async def _add_events_to_memory_from_events( events_to_process: Sequence[Event], custom_metadata: Mapping[str, object] | None = None, ) -> None: - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - direct_events = [] for event in events_to_process: if _should_filter_out_event(event.content): @@ -272,9 +294,6 @@ async def _add_memories_via_create( custom_metadata: Mapping[str, object] | None = None, ) -> None: """Adds direct memory items without server-side extraction.""" - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - normalized_memories = _normalize_memories_for_create(memories) api_client = self._get_api_client() for index, memory in enumerate(normalized_memories): @@ -300,11 +319,41 @@ async def _add_memories_via_create( logger.info('Create memory response received.') logger.debug('Create memory response: %s', operation) + async def _add_memories_via_generate_direct_memories_source( + self, + *, + app_name: str, + user_id: str, + memories: Sequence[MemoryEntry], + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + """Adds memories via generate API with direct_memories_source.""" + normalized_memories = _normalize_memories_for_create(memories) + memory_texts = [ + _memory_entry_to_fact(m, index=i) + for i, m in enumerate(normalized_memories) + ] + api_client = self._get_api_client() + config = _build_generate_memories_config(custom_metadata) + for memory_batch in _iter_memory_batches(memory_texts): + operation = await api_client.agent_engines.memories.generate( + name='reasoningEngines/' + self._agent_engine_id, + direct_memories_source={ + 'direct_memories': [ + {'fact': memory_text} for memory_text in memory_batch + ] + }, + scope={ + 'app_name': app_name, + 'user_id': user_id, + }, + config=config, + ) + logger.info('Generate direct memory response received.') + logger.debug('Generate direct memory response: %s', operation) + @override async def search_memory(self, *, app_name: str, user_id: str, query: str): - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - api_client = self._get_api_client() retrieved_memories_iterator = ( await api_client.agent_engines.memories.retrieve( @@ -379,6 +428,8 @@ def _build_generate_memories_config( metadata_by_key: dict[str, object] = {} for key, value in custom_metadata.items(): + if key == _ENABLE_CONSOLIDATION_KEY: + continue if key == 'ttl': if value is None: continue @@ -456,6 +507,8 @@ def _build_create_memory_config( metadata_by_key: dict[str, object] = {} custom_revision_labels: dict[str, str] = {} for key, value in (custom_metadata or {}).items(): + if key == _ENABLE_CONSOLIDATION_KEY: + continue if key == 'metadata': if value is None: continue @@ -641,6 +694,32 @@ def _extract_revision_labels( return revision_labels +def _is_consolidation_enabled( + custom_metadata: Mapping[str, object] | None, +) -> bool: + """Returns whether direct memories should be consolidated via generate API.""" + if not custom_metadata: + return False + enable_consolidation = custom_metadata.get(_ENABLE_CONSOLIDATION_KEY) + if enable_consolidation is None: + return False + if not 
isinstance(enable_consolidation, bool): + raise TypeError( + f'custom_metadata["{_ENABLE_CONSOLIDATION_KEY}"] must be a bool.' + ) + return enable_consolidation + + +def _iter_memory_batches(memories: Sequence[str]) -> Sequence[Sequence[str]]: + """Returns memory slices that comply with direct_memories limits.""" + memory_batches: list[Sequence[str]] = [] + for index in range(0, len(memories), _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL): + memory_batches.append( + memories[index : index + _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL] + ) + return memory_batches + + def _build_vertex_metadata( metadata_by_key: Mapping[str, object], ) -> dict[str, object]: diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index bc0251a81e..bc7b27764c 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -896,6 +896,19 @@ async def _append_new_message_to_session( new_message.parts[i] = types.Part( text=f'Uploaded file: {file_name}. It is saved into artifacts' ) + + if self._has_duplicate_user_event_for_invocation( + session=session, + invocation_id=invocation_context.invocation_id, + new_message=new_message, + state_delta=state_delta, + ): + logger.info( + 'Skipping duplicate user event append for invocation_id=%s', + invocation_context.invocation_id, + ) + return + # Appends only. We do not yield the event because it's not from the model. if state_delta: event = Event( @@ -918,6 +931,25 @@ async def _append_new_message_to_session( await self.session_service.append_event(session=session, event=event) + def _has_duplicate_user_event_for_invocation( + self, + *, + session: Session, + invocation_id: str, + new_message: types.Content, + state_delta: Optional[dict[str, Any]], + ) -> bool: + expected_state_delta = state_delta or {} + for event in session.events: + if event.invocation_id != invocation_id or event.author != 'user': + continue + if ( + event.content == new_message + and event.actions.state_delta == expected_state_delta + ): + return True + return False + async def run_live( self, *, diff --git a/src/google/adk/tools/enterprise_search_tool.py b/src/google/adk/tools/enterprise_search_tool.py index 4f7a0d7f35..c114fdb46d 100644 --- a/src/google/adk/tools/enterprise_search_tool.py +++ b/src/google/adk/tools/enterprise_search_tool.py @@ -21,6 +21,7 @@ from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -54,14 +55,16 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - if is_gemini_model(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.tools = llm_request.config.tools or [] + + if is_gemini_model(llm_request.model) or model_check_disabled: if is_gemini_1_model(llm_request.model) and llm_request.config.tools: raise ValueError( 'Enterprise Web Search tool cannot be used with other tools in' ' Gemini 1.x.' 
) - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.tools = llm_request.config.tools or [] llm_request.config.tools.append( types.Tool(enterprise_web_search=types.EnterpriseWebSearch()) ) diff --git a/src/google/adk/tools/google_maps_grounding_tool.py b/src/google/adk/tools/google_maps_grounding_tool.py index bade0a3385..d4b105ec1e 100644 --- a/src/google/adk/tools/google_maps_grounding_tool.py +++ b/src/google/adk/tools/google_maps_grounding_tool.py @@ -21,6 +21,7 @@ from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -49,13 +50,14 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): raise ValueError( 'Google Maps grounding tool cannot be used with Gemini 1.x models.' ) - elif is_gemini_model(llm_request.model): + elif is_gemini_model(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(google_maps=types.GoogleMaps()) ) diff --git a/src/google/adk/tools/google_search_tool.py b/src/google/adk/tools/google_search_tool.py index 406ad2189e..1c11e091de 100644 --- a/src/google/adk/tools/google_search_tool.py +++ b/src/google/adk/tools/google_search_tool.py @@ -21,6 +21,7 @@ from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -67,6 +68,7 @@ async def process_llm_request( if self.model is not None: llm_request.model = self.model + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): @@ -77,7 +79,7 @@ async def process_llm_request( llm_request.config.tools.append( types.Tool(google_search_retrieval=types.GoogleSearchRetrieval()) ) - elif is_gemini_model(llm_request.model): + elif is_gemini_model(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(google_search=types.GoogleSearch()) ) diff --git a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py index 206819a9be..4d564ca164 100644 --- a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +++ b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py @@ -24,6 +24,7 @@ from typing_extensions import override from ...utils.model_name_utils import is_gemini_2_or_above +from ...utils.model_name_utils import is_gemini_model_id_check_disabled from ..tool_context import ToolContext from .base_retrieval_tool import BaseRetrievalTool @@ -63,7 +64,8 @@ async def process_llm_request( llm_request: LlmRequest, ) -> None: # Use Gemini built-in Vertex AI RAG tool for Gemini 2 models. 
- if is_gemini_2_or_above(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + if is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config = ( types.GenerateContentConfig() if not llm_request.config diff --git a/src/google/adk/tools/url_context_tool.py b/src/google/adk/tools/url_context_tool.py index fcdf76dab5..5e923e7447 100644 --- a/src/google/adk/tools/url_context_tool.py +++ b/src/google/adk/tools/url_context_tool.py @@ -21,6 +21,7 @@ from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_2_or_above +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -46,11 +47,12 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): raise ValueError('Url context tool cannot be used in Gemini 1.x.') - elif is_gemini_2_or_above(llm_request.model): + elif is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(url_context=types.UrlContext()) ) diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index 91fe60e553..46104c5ed4 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -24,6 +24,7 @@ from ..agents.readonly_context import ReadonlyContext from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -141,14 +142,16 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - if is_gemini_model(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.tools = llm_request.config.tools or [] + + if is_gemini_model(llm_request.model) or model_check_disabled: if is_gemini_1_model(llm_request.model) and llm_request.config.tools: raise ValueError( 'Vertex AI search tool cannot be used with other tools in Gemini' ' 1.x.' ) - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.tools = llm_request.config.tools or [] # Build the search config (can be overridden by subclasses) vertex_ai_search_config = self._build_vertex_ai_search_config( diff --git a/src/google/adk/utils/model_name_utils.py b/src/google/adk/utils/model_name_utils.py index 4960b0b78f..57103fb2c7 100644 --- a/src/google/adk/utils/model_name_utils.py +++ b/src/google/adk/utils/model_name_utils.py @@ -22,6 +22,19 @@ from packaging.version import InvalidVersion from packaging.version import Version +from .env_utils import is_env_enabled + +_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR = 'ADK_DISABLE_GEMINI_MODEL_ID_CHECK' + + +def is_gemini_model_id_check_disabled() -> bool: + """Returns True when Gemini model-id validation should be bypassed. + + This opt-in environment variable is intended for internal usage where model + ids may not follow the public ``gemini-*`` naming convention. 
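+
+  Example (illustrative)::
+
+      os.environ['ADK_DISABLE_GEMINI_MODEL_ID_CHECK'] = 'true'
+      assert is_gemini_model_id_check_disabled() is True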
+ """ + return is_env_enabled(_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR) + def extract_model_name(model_string: str) -> str: """Extract the actual model name from either simple or path-based format. diff --git a/tests/unittests/code_executors/test_built_in_code_executor.py b/tests/unittests/code_executors/test_built_in_code_executor.py index 58f54c7cef..cbf128fba9 100644 --- a/tests/unittests/code_executors/test_built_in_code_executor.py +++ b/tests/unittests/code_executors/test_built_in_code_executor.py @@ -97,6 +97,22 @@ def test_process_llm_request_non_gemini_2_model( ) +def test_process_llm_request_non_gemini_2_model_with_disabled_check( + built_in_executor: BuiltInCodeExecutor, + monkeypatch, +): + """Tests non-Gemini models pass when model-id check is disabled.""" + monkeypatch.setenv("ADK_DISABLE_GEMINI_MODEL_ID_CHECK", "true") + llm_request = LlmRequest(model="internal-model-v1") + + built_in_executor.process_llm_request(llm_request) + + assert llm_request.config is not None + assert llm_request.config.tools == [ + types.Tool(code_execution=types.ToolCodeExecution()) + ] + + def test_process_llm_request_no_model_name( built_in_executor: BuiltInCodeExecutor, ): diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py index 6f342a08b1..c498b8335b 100644 --- a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -230,6 +230,14 @@ async def test_initialize_with_project_location_and_api_key_error(): ) +def test_initialize_without_agent_engine_id_error(): + with pytest.raises( + ValueError, + match='agent_engine_id is required for VertexAiMemoryBankService', + ): + mock_vertex_ai_memory_bank_service(agent_engine_id=None) + + @pytest.mark.asyncio async def test_add_session_to_memory(mock_vertexai_client): memory_service = mock_vertex_ai_memory_bank_service() @@ -481,6 +489,7 @@ async def test_add_memory_calls_create( ), ], custom_metadata={ + 'enable_consolidation': False, 'ttl': '6000s', 'source': 'agent', }, @@ -518,6 +527,139 @@ async def test_add_memory_calls_create( vertex_common_types.AgentEngineMemoryConfig(**create_config) +@pytest.mark.asyncio +async def test_add_memory_enable_consolidation_calls_generate_direct_source( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact two')]) + ), + ], + custom_metadata={ + 'enable_consolidation': True, + 'source': 'agent', + }, + ) + + expected_config = {'wait_for_completion': False} + if _supports_generate_memories_metadata(): + expected_config['metadata'] = {'source': {'string_value': 'agent'}} + + mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact one'}, + {'fact': 'fact two'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config=expected_config, + ) + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + generate_config = ( + mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[ + 'config' + ] + ) + vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config) + + +@pytest.mark.asyncio +async def 
test_add_memory_enable_consolidation_batches_generate_calls( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact two')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact three')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact four')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact five')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact six')]) + ), + ], + custom_metadata={ + 'enable_consolidation': True, + }, + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_has_awaits([ + mock.call( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact one'}, + {'fact': 'fact two'}, + {'fact': 'fact three'}, + {'fact': 'fact four'}, + {'fact': 'fact five'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={'wait_for_completion': False}, + ), + mock.call( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact six'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={'wait_for_completion': False}, + ), + ]) + assert mock_vertexai_client.agent_engines.memories.generate.await_count == 2 + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_memory_invalid_enable_consolidation_type_raises( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + with pytest.raises( + TypeError, + match=r'custom_metadata\["enable_consolidation"\] must be a bool', + ): + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ) + ], + custom_metadata={'enable_consolidation': 'yes'}, + ) + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + @pytest.mark.asyncio async def test_add_memory_calls_create_with_memory_entry_metadata( mock_vertexai_client, diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index ca7eb37533..63c14b9880 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -289,6 +289,153 @@ def _infer_agent_origin( assert event.content.parts[0].text == "Test LLM response" +@pytest.mark.asyncio +async def test_append_new_message_to_session_skips_duplicate_retry_message(): + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", + agent=MockLlmAgent("root_agent"), + session_service=session_service, + artifact_service=InMemoryArtifactService(), + ) + session = await session_service.create_session( + app_name="test_app", + user_id="test_user", + ) + user_message = types.Content( + role="user", + parts=[types.Part(text="retry message")], + ) + invocation_context = runner._new_invocation_context( + session, + invocation_id="inv-retry", + new_message=user_message, + run_config=RunConfig(), + ) + + await runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + ) + await 
runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + ) + + matched_events = [ + event + for event in session.events + if event.author == "user" + and event.invocation_id == "inv-retry" + and event.content == user_message + ] + assert len(matched_events) == 1 + + +@pytest.mark.asyncio +async def test_append_new_message_to_session_keeps_non_duplicate_messages(): + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", + agent=MockLlmAgent("root_agent"), + session_service=session_service, + artifact_service=InMemoryArtifactService(), + ) + session = await session_service.create_session( + app_name="test_app", + user_id="test_user", + ) + invocation_context = runner._new_invocation_context( + session, + invocation_id="inv-retry", + new_message=types.Content(role="user", parts=[types.Part(text="first")]), + run_config=RunConfig(), + ) + first_message = types.Content(role="user", parts=[types.Part(text="first")]) + second_message = types.Content(role="user", parts=[types.Part(text="second")]) + + await runner._append_new_message_to_session( + session=session, + new_message=first_message, + invocation_context=invocation_context, + ) + await runner._append_new_message_to_session( + session=session, + new_message=second_message, + invocation_context=invocation_context, + ) + + matched_events = [ + event + for event in session.events + if event.author == "user" and event.invocation_id == "inv-retry" + ] + assert len(matched_events) == 2 + + +@pytest.mark.asyncio +async def test_append_new_message_to_session_state_delta_deduping(): + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", + agent=MockLlmAgent("root_agent"), + session_service=session_service, + artifact_service=InMemoryArtifactService(), + ) + session = await session_service.create_session( + app_name="test_app", + user_id="test_user", + ) + user_message = types.Content( + role="user", parts=[types.Part(text="same message")] + ) + invocation_context = runner._new_invocation_context( + session, + invocation_id="inv-state-delta", + new_message=user_message, + run_config=RunConfig(), + ) + + await runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + state_delta={"attempt": 1}, + ) + await runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + state_delta={"attempt": 1}, + ) + await runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + state_delta={"attempt": 2}, + ) + await runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + state_delta=None, + ) + + matched_events = [ + event + for event in session.events + if event.author == "user" + and event.invocation_id == "inv-state-delta" + and event.content == user_message + ] + assert len(matched_events) == 3 + assert matched_events[0].actions.state_delta == {"attempt": 1} + assert matched_events[1].actions.state_delta == {"attempt": 2} + assert matched_events[2].actions.state_delta == {} + + @pytest.mark.asyncio async def test_rewind_auto_create_session_on_missing_session(): """When auto_create_session=True, rewind should create session if missing. 
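Reviewer note (illustrative, not part of the patch): the dedup rule the three tests above pin down reduces to a single predicate over the fields the check reads. The trimmed `_Event` type and `has_duplicate_user_event` below are hypothetical names for this sketch; the real check lives in `Runner._has_duplicate_user_event_for_invocation`.

    from dataclasses import dataclass, field
    from typing import Any, Optional


    @dataclass
    class _Event:  # hypothetical; only the fields the dedup check reads
      invocation_id: str
      author: str
      content: Any
      state_delta: dict[str, Any] = field(default_factory=dict)


    def has_duplicate_user_event(
        events: list[_Event],
        invocation_id: str,
        new_message: Any,
        state_delta: Optional[dict[str, Any]],
    ) -> bool:
      expected = state_delta or {}  # None normalizes to {}
      return any(
          e.invocation_id == invocation_id
          and e.author == 'user'
          and e.content == new_message
          and e.state_delta == expected
          for e in events
      )


    # A retried identical message is a duplicate; a changed state_delta is not.
    events = [_Event('inv-1', 'user', 'hello', {'attempt': 1})]
    assert has_duplicate_user_event(events, 'inv-1', 'hello', {'attempt': 1})
    assert not has_duplicate_user_event(events, 'inv-1', 'hello', {'attempt': 2})
    # state_delta=None collides with an event that carried an empty delta.
    assert has_duplicate_user_event([_Event('inv-1', 'user', 'hi')], 'inv-1', 'hi', None)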
diff --git a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py index 3b5aa26f8a..0a86d07c63 100644 --- a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py +++ b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py @@ -145,3 +145,43 @@ def test_vertex_rag_retrieval_for_gemini_2_x(): ) ] assert 'rag_retrieval' not in mockModel.requests[0].tools_dict + + +def test_vertex_rag_retrieval_for_non_gemini_with_disabled_check(monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + responses = [ + 'response1', + ] + mockModel = testing_utils.MockModel.create(responses=responses) + mockModel.model = 'internal-model-v1' + + agent = Agent( + name='root_agent', + model=mockModel, + tools=[ + VertexAiRagRetrieval( + name='rag_retrieval', + description='rag_retrieval', + rag_corpora=[ + 'projects/123456789/locations/us-central1/ragCorpora/1234567890' + ], + ) + ], + ) + runner = testing_utils.InMemoryRunner(agent) + runner.run('test1') + + assert len(mockModel.requests) == 1 + assert len(mockModel.requests[0].config.tools) == 1 + assert mockModel.requests[0].config.tools == [ + types.Tool( + retrieval=types.Retrieval( + vertex_rag_store=types.VertexRagStore( + rag_corpora=[ + 'projects/123456789/locations/us-central1/ragCorpora/1234567890' + ] + ) + ) + ) + ] + assert 'rag_retrieval' not in mockModel.requests[0].tools_dict diff --git a/tests/unittests/tools/test_enterprise_web_search_tool.py b/tests/unittests/tools/test_enterprise_web_search_tool.py index ed4715963e..7b28d858fd 100644 --- a/tests/unittests/tools/test_enterprise_web_search_tool.py +++ b/tests/unittests/tools/test_enterprise_web_search_tool.py @@ -76,6 +76,25 @@ async def test_process_llm_request_failure_with_non_gemini_models(): assert 'is not supported for model' in str(exc_info.value) +@pytest.mark.asyncio +async def test_process_llm_request_non_gemini_with_disabled_check(monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = EnterpriseWebSearchTool() + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + tool_context = await _create_tool_context() + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert ( + llm_request.config.tools[0].enterprise_web_search + == types.EnterpriseWebSearch() + ) + + @pytest.mark.asyncio async def test_process_llm_request_failure_with_multiple_tools_gemini_1_models(): tool = EnterpriseWebSearchTool() diff --git a/tests/unittests/tools/test_google_maps_grounding_tool.py b/tests/unittests/tools/test_google_maps_grounding_tool.py new file mode 100644 index 0000000000..0cd2c4fa6c --- /dev/null +++ b/tests/unittests/tools/test_google_maps_grounding_tool.py @@ -0,0 +1,92 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.google_maps_grounding_tool import GoogleMapsGroundingTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +import pytest + + +async def _create_tool_context() -> ToolContext: + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + agent = SequentialAgent(name='test_agent') + invocation_context = InvocationContext( + invocation_id='invocation_id', + agent=agent, + session=session, + session_service=session_service, + ) + return ToolContext(invocation_context=invocation_context) + + +class TestGoogleMapsGroundingTool: + """Tests for GoogleMapsGroundingTool.""" + + @pytest.mark.asyncio + async def test_process_llm_request_with_gemini_2_model(self): + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='gemini-2.5-pro', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_maps is not None + + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_raises_error(self): + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='claude-3-sonnet', config=types.GenerateContentConfig() + ) + + with pytest.raises( + ValueError, + match='Google maps tool is not supported for model claude-3-sonnet', + ): + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_and_disabled_check( + self, monkeypatch + ): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_maps is not None diff --git a/tests/unittests/tools/test_google_search_tool.py b/tests/unittests/tools/test_google_search_tool.py index ad5d46b59e..d71061b883 100644 --- a/tests/unittests/tools/test_google_search_tool.py +++ b/tests/unittests/tools/test_google_search_tool.py @@ -268,6 +268,27 @@ async def test_process_llm_request_with_non_gemini_model_raises_error(self): tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = GoogleSearchTool() + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert 
llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_search is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/tools/test_url_context_tool.py b/tests/unittests/tools/test_url_context_tool.py index 53ee7e6277..8fd44b59cb 100644 --- a/tests/unittests/tools/test_url_context_tool.py +++ b/tests/unittests/tools/test_url_context_tool.py @@ -190,6 +190,27 @@ async def test_process_llm_request_with_non_gemini_model_raises_error(self): tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = UrlContextTool() + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].url_context is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/tools/test_vertex_ai_search_tool.py b/tests/unittests/tools/test_vertex_ai_search_tool.py index 3ade634da6..b15d3a1f64 100644 --- a/tests/unittests/tools/test_vertex_ai_search_tool.py +++ b/tests/unittests/tools/test_vertex_ai_search_tool.py @@ -376,6 +376,29 @@ async def test_process_llm_request_with_non_gemini_model_raises_error(self): tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = VertexAiSearchTool(data_store_id='test_data_store') + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + retrieval_tool = llm_request.config.tools[0] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/utils/test_model_name_utils.py b/tests/unittests/utils/test_model_name_utils.py index cbac37e3f7..2af1584b05 100644 --- a/tests/unittests/utils/test_model_name_utils.py +++ b/tests/unittests/utils/test_model_name_utils.py @@ -18,6 +18,7 @@ from google.adk.utils.model_name_utils import is_gemini_1_model from google.adk.utils.model_name_utils import is_gemini_2_or_above from google.adk.utils.model_name_utils import is_gemini_model +from google.adk.utils.model_name_utils import is_gemini_model_id_check_disabled class TestExtractModelName: @@ -318,3 +319,15 @@ def test_path_vs_simple_model_consistency(self): f'Inconsistent Gemini 2.0+ classification for {simple_model} vs' f' {path_model}' ) + + +class 
TestGeminiModelIdCheckFlag: + """Tests for Gemini model-id check override flag.""" + + def test_default_is_disabled(self, monkeypatch): + monkeypatch.delenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', raising=False) + assert is_gemini_model_id_check_disabled() is False + + def test_true_enables_check_bypass(self, monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + assert is_gemini_model_id_check_disabled() is True
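
Illustrative end-to-end usage of the two opt-in behaviors this patch introduces, as a sketch rather than part of the diff. The constructor arguments `project`, `location`, and `agent_engine_id` values are placeholders, the `MemoryEntry` import path is assumed from the test imports, and the environment variable is the one read by `is_gemini_model_id_check_disabled()`.

    import asyncio
    import os

    from google.adk.memory.memory_entry import MemoryEntry  # path assumed
    from google.adk.memory.vertex_ai_memory_bank_service import (
        VertexAiMemoryBankService,
    )
    from google.genai import types

    # Opt in to bypassing the gemini-* model-id check for built-in tools
    # (google_search, url_context, enterprise search, etc.); off by default.
    os.environ['ADK_DISABLE_GEMINI_MODEL_ID_CHECK'] = 'true'


    async def main() -> None:
      service = VertexAiMemoryBankService(
          project='my-project',          # placeholder
          location='us-central1',        # placeholder
          agent_engine_id='1234567890',  # now required; None raises ValueError
      )
      # enable_consolidation=True routes add_memory through memories.generate
      # with direct_memories_source, batching at most five facts per call;
      # the key itself is stripped before building the request config.
      await service.add_memory(
          app_name='my_app',
          user_id='user_1',
          memories=[
              MemoryEntry(
                  content=types.Content(parts=[types.Part(text='a fact')])
              )
          ],
          custom_metadata={'enable_consolidation': True},
      )


    asyncio.run(main())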