4 changes: 3 additions & 1 deletion src/google/adk/code_executors/built_in_code_executor.py
@@ -20,6 +20,7 @@
from ..agents.invocation_context import InvocationContext
from ..models import LlmRequest
from ..utils.model_name_utils import is_gemini_2_or_above
from ..utils.model_name_utils import is_gemini_model_id_check_disabled
from .base_code_executor import BaseCodeExecutor
from .code_execution_utils import CodeExecutionInput
from .code_execution_utils import CodeExecutionResult
@@ -42,7 +43,8 @@ def execute_code(

def process_llm_request(self, llm_request: LlmRequest) -> None:
"""Pre-process the LLM request for Gemini 2.0+ models to use the code execution tool."""
if is_gemini_2_or_above(llm_request.model):
model_check_disabled = is_gemini_model_id_check_disabled()
if is_gemini_2_or_above(llm_request.model) or model_check_disabled:
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
llm_request.config.tools.append(
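A minimal sketch of the new bypass (not from this PR's tests; the `LlmRequest` construction and the accepted env values are assumptions):

    import os

    from google.adk.code_executors.built_in_code_executor import BuiltInCodeExecutor
    from google.adk.models import LlmRequest

    # Opt in to the bypass; assumes is_env_enabled() accepts 'true'.
    os.environ['ADK_DISABLE_GEMINI_MODEL_ID_CHECK'] = 'true'

    # Hypothetical non-Gemini model id that would previously skip the tool.
    request = LlmRequest(model='my-tuned-model')
    BuiltInCodeExecutor().process_llm_request(request)
    assert request.config.tools  # code-execution tool attached despite the id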
99 changes: 89 additions & 10 deletions src/google/adk/memory/vertex_ai_memory_bank_service.py
@@ -65,6 +65,11 @@
'wait_for_completion',
})

_ENABLE_CONSOLIDATION_KEY = 'enable_consolidation'
# Vertex docs for GenerateMemoriesRequest.DirectMemoriesSource allow
# at most 5 direct_memories per request.
_MAX_DIRECT_MEMORIES_PER_GENERATE_CALL = 5


def _supports_generate_memories_metadata() -> bool:
"""Returns whether installed Vertex SDK supports config.metadata."""
@@ -160,6 +165,11 @@ def __init__(
not use Google AI Studio API key for this field. For more details, visit
https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
"""
if not agent_engine_id:
raise ValueError(
'agent_engine_id is required for VertexAiMemoryBankService.'
)

self._project = project
self._location = location
self._agent_engine_id = agent_engine_id
@@ -219,7 +229,22 @@ async def add_memory(
memories: Sequence[MemoryEntry],
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Adds explicit memory items via Vertex memories.create."""
"""Adds explicit memory items using Vertex Memory Bank.

By default, this writes directly via `memories.create`.
If `custom_metadata["enable_consolidation"]` is set to True, this uses
`memories.generate` with `direct_memories_source` so provided memories are
consolidated server-side.
"""
if _is_consolidation_enabled(custom_metadata):
await self._add_memories_via_generate_direct_memories_source(
app_name=app_name,
user_id=user_id,
memories=memories,
custom_metadata=custom_metadata,
)
return

await self._add_memories_via_create(
app_name=app_name,
user_id=user_id,
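A usage sketch for the consolidation path (constructor arguments and ids are hypothetical; call from an async context):

    service = VertexAiMemoryBankService(
        project='my-project',
        location='us-central1',
        agent_engine_id='1234567890',
    )
    # memories is a Sequence[MemoryEntry] prepared elsewhere.
    await service.add_memory(
        app_name='my-app',
        user_id='user-1',
        memories=memories,
        custom_metadata={'enable_consolidation': True},
    )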
@@ -235,9 +260,6 @@ def _add_events_to_memory_from_events(
events_to_process: Sequence[Event],
custom_metadata: Mapping[str, object] | None = None,
) -> None:
if not self._agent_engine_id:
raise ValueError('Agent Engine ID is required for Memory Bank.')

direct_events = []
for event in events_to_process:
if _should_filter_out_event(event.content):
@@ -272,9 +294,6 @@ async def _add_memories_via_create(
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Adds direct memory items without server-side extraction."""
if not self._agent_engine_id:
raise ValueError('Agent Engine ID is required for Memory Bank.')

normalized_memories = _normalize_memories_for_create(memories)
api_client = self._get_api_client()
for index, memory in enumerate(normalized_memories):
@@ -300,11 +319,41 @@ async def _add_memories_via_create(
logger.info('Create memory response received.')
logger.debug('Create memory response: %s', operation)

async def _add_memories_via_generate_direct_memories_source(
self,
*,
app_name: str,
user_id: str,
memories: Sequence[MemoryEntry],
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Adds memories via generate API with direct_memories_source."""
normalized_memories = _normalize_memories_for_create(memories)
memory_texts = [
_memory_entry_to_fact(m, index=i)
for i, m in enumerate(normalized_memories)
]
api_client = self._get_api_client()
config = _build_generate_memories_config(custom_metadata)
for memory_batch in _iter_memory_batches(memory_texts):
operation = await api_client.agent_engines.memories.generate(
name='reasoningEngines/' + self._agent_engine_id,
direct_memories_source={
'direct_memories': [
{'fact': memory_text} for memory_text in memory_batch
]
},
scope={
'app_name': app_name,
'user_id': user_id,
},
config=config,
)
logger.info('Generate direct memory response received.')
logger.debug('Generate direct memory response: %s', operation)

@override
async def search_memory(self, *, app_name: str, user_id: str, query: str):
if not self._agent_engine_id:
raise ValueError('Agent Engine ID is required for Memory Bank.')

api_client = self._get_api_client()
retrieved_memories_iterator = (
await api_client.agent_engines.memories.retrieve(
@@ -379,6 +428,8 @@ def _build_generate_memories_config(

metadata_by_key: dict[str, object] = {}
for key, value in custom_metadata.items():
if key == _ENABLE_CONSOLIDATION_KEY:
continue
if key == 'ttl':
if value is None:
continue
@@ -456,6 +507,8 @@ def _build_create_memory_config(
metadata_by_key: dict[str, object] = {}
custom_revision_labels: dict[str, str] = {}
for key, value in (custom_metadata or {}).items():
if key == _ENABLE_CONSOLIDATION_KEY:
continue
if key == 'metadata':
if value is None:
continue
@@ -641,6 +694,32 @@ def _extract_revision_labels(
return revision_labels


def _is_consolidation_enabled(
custom_metadata: Mapping[str, object] | None,
) -> bool:
"""Returns whether direct memories should be consolidated via generate API."""
if not custom_metadata:
return False
enable_consolidation = custom_metadata.get(_ENABLE_CONSOLIDATION_KEY)
if enable_consolidation is None:
return False
if not isinstance(enable_consolidation, bool):
raise TypeError(
f'custom_metadata["{_ENABLE_CONSOLIDATION_KEY}"] must be a bool.'
)
return enable_consolidation
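The flag parsing in a nutshell (illustrative, exercising the helper directly):

    assert _is_consolidation_enabled(None) is False
    assert _is_consolidation_enabled({'enable_consolidation': True}) is True
    try:
        _is_consolidation_enabled({'enable_consolidation': 'yes'})
    except TypeError:
        pass  # truthy non-bool values are rejected rather than coerced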


def _iter_memory_batches(memories: Sequence[str]) -> Sequence[Sequence[str]]:
"""Returns memory slices that comply with direct_memories limits."""
memory_batches: list[Sequence[str]] = []
for index in range(0, len(memories), _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL):
memory_batches.append(
memories[index : index + _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL]
)
return memory_batches
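Illustrative chunking, matching the documented limit of 5 direct_memories per request:

    batches = _iter_memory_batches([f'fact {i}' for i in range(12)])
    assert [len(b) for b in batches] == [5, 5, 2]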


def _build_vertex_metadata(
metadata_by_key: Mapping[str, object],
) -> dict[str, object]:
32 changes: 32 additions & 0 deletions src/google/adk/runners.py
@@ -896,6 +896,19 @@ async def _append_new_message_to_session(
new_message.parts[i] = types.Part(
text=f'Uploaded file: {file_name}. It is saved into artifacts'
)

if self._has_duplicate_user_event_for_invocation(
session=session,
invocation_id=invocation_context.invocation_id,
new_message=new_message,
state_delta=state_delta,
):
logger.info(
'Skipping duplicate user event append for invocation_id=%s',
invocation_context.invocation_id,
)
return

# Appends only. We do not yield the event because it's not from the model.
if state_delta:
event = Event(
@@ -918,6 +931,25 @@

await self.session_service.append_event(session=session, event=event)

def _has_duplicate_user_event_for_invocation(
self,
*,
session: Session,
invocation_id: str,
new_message: types.Content,
state_delta: Optional[dict[str, Any]],
) -> bool:
expected_state_delta = state_delta or {}
for event in session.events:
if event.invocation_id != invocation_id or event.author != 'user':
continue
if (
event.content == new_message
and event.actions.state_delta == expected_state_delta
):
return True
return False
Comment on lines +943 to +951
Contributor (severity: medium)

For better readability and conciseness, you can refactor this loop into a single statement using a generator expression with any(). This is a more Pythonic way to check for the existence of an item in a sequence that matches a condition.

    return any(
        event.content == new_message
        and event.actions.state_delta == expected_state_delta
        for event in session.events
        if event.author == "user" and event.invocation_id == invocation_id
    )
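
For reference, the duplicate check can be exercised with stand-in types; this is a self-contained sketch, not ADK's actual Session/Event classes:

    from dataclasses import dataclass, field

    @dataclass
    class _Actions:
        state_delta: dict = field(default_factory=dict)

    @dataclass
    class _Event:
        invocation_id: str
        author: str
        content: str
        actions: _Actions = field(default_factory=_Actions)

    events = [_Event('inv-1', 'user', 'hello')]
    is_duplicate = any(
        e.content == 'hello' and e.actions.state_delta == {}
        for e in events
        if e.author == 'user' and e.invocation_id == 'inv-1'
    )
    assert is_duplicate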


async def run_live(
self,
*,
9 changes: 6 additions & 3 deletions src/google/adk/tools/enterprise_search_tool.py
@@ -21,6 +21,7 @@

from ..utils.model_name_utils import is_gemini_1_model
from ..utils.model_name_utils import is_gemini_model
from ..utils.model_name_utils import is_gemini_model_id_check_disabled
from .base_tool import BaseTool
from .tool_context import ToolContext

@@ -54,14 +55,16 @@ async def process_llm_request(
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
if is_gemini_model(llm_request.model):
model_check_disabled = is_gemini_model_id_check_disabled()
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []

if is_gemini_model(llm_request.model) or model_check_disabled:
if is_gemini_1_model(llm_request.model) and llm_request.config.tools:
raise ValueError(
'Enterprise Web Search tool cannot be used with other tools in'
' Gemini 1.x.'
)
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
llm_request.config.tools.append(
types.Tool(enterprise_web_search=types.EnterpriseWebSearch())
)
4 changes: 3 additions & 1 deletion src/google/adk/tools/google_maps_grounding_tool.py
@@ -21,6 +21,7 @@

from ..utils.model_name_utils import is_gemini_1_model
from ..utils.model_name_utils import is_gemini_model
from ..utils.model_name_utils import is_gemini_model_id_check_disabled
from .base_tool import BaseTool
from .tool_context import ToolContext

@@ -49,13 +50,14 @@ async def process_llm_request(
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
model_check_disabled = is_gemini_model_id_check_disabled()
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
if is_gemini_1_model(llm_request.model):
raise ValueError(
'Google Maps grounding tool cannot be used with Gemini 1.x models.'
)
elif is_gemini_model(llm_request.model):
elif is_gemini_model(llm_request.model) or model_check_disabled:
llm_request.config.tools.append(
types.Tool(google_maps=types.GoogleMaps())
)
4 changes: 3 additions & 1 deletion src/google/adk/tools/google_search_tool.py
@@ -21,6 +21,7 @@

from ..utils.model_name_utils import is_gemini_1_model
from ..utils.model_name_utils import is_gemini_model
from ..utils.model_name_utils import is_gemini_model_id_check_disabled
from .base_tool import BaseTool
from .tool_context import ToolContext

@@ -67,6 +68,7 @@ async def process_llm_request(
if self.model is not None:
llm_request.model = self.model

model_check_disabled = is_gemini_model_id_check_disabled()
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
if is_gemini_1_model(llm_request.model):
@@ -77,7 +79,7 @@
llm_request.config.tools.append(
types.Tool(google_search_retrieval=types.GoogleSearchRetrieval())
)
elif is_gemini_model(llm_request.model):
elif is_gemini_model(llm_request.model) or model_check_disabled:
llm_request.config.tools.append(
types.Tool(google_search=types.GoogleSearch())
)
4 changes: 3 additions & 1 deletion src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py
@@ -24,6 +24,7 @@
from typing_extensions import override

from ...utils.model_name_utils import is_gemini_2_or_above
from ...utils.model_name_utils import is_gemini_model_id_check_disabled
from ..tool_context import ToolContext
from .base_retrieval_tool import BaseRetrievalTool

@@ -63,7 +64,8 @@ async def process_llm_request(
llm_request: LlmRequest,
) -> None:
# Use Gemini built-in Vertex AI RAG tool for Gemini 2 models.
if is_gemini_2_or_above(llm_request.model):
model_check_disabled = is_gemini_model_id_check_disabled()
if is_gemini_2_or_above(llm_request.model) or model_check_disabled:
llm_request.config = (
types.GenerateContentConfig()
if not llm_request.config
4 changes: 3 additions & 1 deletion src/google/adk/tools/url_context_tool.py
@@ -21,6 +21,7 @@

from ..utils.model_name_utils import is_gemini_1_model
from ..utils.model_name_utils import is_gemini_2_or_above
from ..utils.model_name_utils import is_gemini_model_id_check_disabled
from .base_tool import BaseTool
from .tool_context import ToolContext

@@ -46,11 +47,12 @@ async def process_llm_request(
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
model_check_disabled = is_gemini_model_id_check_disabled()
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
if is_gemini_1_model(llm_request.model):
raise ValueError('Url context tool cannot be used in Gemini 1.x.')
elif is_gemini_2_or_above(llm_request.model):
elif is_gemini_2_or_above(llm_request.model) or model_check_disabled:
llm_request.config.tools.append(
types.Tool(url_context=types.UrlContext())
)
9 changes: 6 additions & 3 deletions src/google/adk/tools/vertex_ai_search_tool.py
@@ -24,6 +24,7 @@
from ..agents.readonly_context import ReadonlyContext
from ..utils.model_name_utils import is_gemini_1_model
from ..utils.model_name_utils import is_gemini_model
from ..utils.model_name_utils import is_gemini_model_id_check_disabled
from .base_tool import BaseTool
from .tool_context import ToolContext

@@ -141,14 +142,16 @@ async def process_llm_request(
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
if is_gemini_model(llm_request.model):
model_check_disabled = is_gemini_model_id_check_disabled()
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []

if is_gemini_model(llm_request.model) or model_check_disabled:
if is_gemini_1_model(llm_request.model) and llm_request.config.tools:
raise ValueError(
'Vertex AI search tool cannot be used with other tools in Gemini'
' 1.x.'
)
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []

# Build the search config (can be overridden by subclasses)
vertex_ai_search_config = self._build_vertex_ai_search_config(
13 changes: 13 additions & 0 deletions src/google/adk/utils/model_name_utils.py
@@ -22,6 +22,19 @@
from packaging.version import InvalidVersion
from packaging.version import Version

from .env_utils import is_env_enabled

_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR = 'ADK_DISABLE_GEMINI_MODEL_ID_CHECK'


def is_gemini_model_id_check_disabled() -> bool:
"""Returns True when Gemini model-id validation should be bypassed.

This opt-in environment variable is intended for internal usage where model
ids may not follow the public ``gemini-*`` naming convention.
"""
return is_env_enabled(_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR)
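A quick sketch of toggling the flag at runtime (assumes `is_env_enabled` treats '1'/'true' as enabled):

    import os

    os.environ['ADK_DISABLE_GEMINI_MODEL_ID_CHECK'] = '1'
    assert is_gemini_model_id_check_disabled()

    del os.environ['ADK_DISABLE_GEMINI_MODEL_ID_CHECK']
    assert not is_gemini_model_id_check_disabled()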


def extract_model_name(model_string: str) -> str:
"""Extract the actual model name from either simple or path-based format.