Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/google/adk/code_executors/built_in_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ..agents.invocation_context import InvocationContext
from ..models import LlmRequest
from ..utils.model_name_utils import is_gemini_2_or_above
from ..utils.model_name_utils import is_gemini_model_id_check_disabled
from .base_code_executor import BaseCodeExecutor
from .code_execution_utils import CodeExecutionInput
from .code_execution_utils import CodeExecutionResult
Expand All @@ -42,7 +43,8 @@ def execute_code(

def process_llm_request(self, llm_request: LlmRequest) -> None:
"""Pre-process the LLM request for Gemini 2.0+ models to use the code execution tool."""
if is_gemini_2_or_above(llm_request.model):
model_check_disabled = is_gemini_model_id_check_disabled()
if is_gemini_2_or_above(llm_request.model) or model_check_disabled:
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
llm_request.config.tools.append(
Expand Down
35 changes: 35 additions & 0 deletions src/google/adk/events/event_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,38 @@ class EventCompaction(BaseModel):
"""The compacted content of the events."""


class RewindAuditReceipt(BaseModel):  # type: ignore[misc]
  """Audit receipt metadata emitted for rewind operations.

  Captures the rewind target, the retained-history boundary, event counts
  taken before and after the rewind filtering, canonical digests of both
  histories, and a tamper-evident hash over the receipt summary itself, so
  a rewind can be audited after the fact.
  """

  model_config = ConfigDict(
      extra='forbid',
      alias_generator=alias_generators.to_camel,
      populate_by_name=True,
  )
  """The pydantic model config."""

  rewind_before_invocation_id: str
  """The invocation ID that the rewind operation targeted."""

  boundary_after_invocation_id: Optional[str] = None
  """The last invocation ID retained before the rewind boundary, if any."""

  events_before_rewind: int
  """The number of events present before appending the rewind event."""

  events_after_rewind: int
  """The number of pre-existing events retained after rewind filtering."""

  history_before_hash: str
  """Canonical hash of the full pre-rewind event history."""

  history_after_hash: str
  """Canonical hash of the retained pre-rewind event history."""

  receipt_hash: str
  """Tamper-evident hash over the rewind receipt summary."""


class EventActions(BaseModel):
"""Represents the actions attached to an event."""

Expand Down Expand Up @@ -108,3 +140,6 @@ class EventActions(BaseModel):

rewind_before_invocation_id: Optional[str] = None
"""The invocation id to rewind to. This is only set for rewind event."""

rewind_audit_receipt: Optional[RewindAuditReceipt] = None
"""Structured receipt proving rewind boundaries and history digests."""
99 changes: 89 additions & 10 deletions src/google/adk/memory/vertex_ai_memory_bank_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@
'wait_for_completion',
})

# Key in `custom_metadata` that opts add_memory() into server-side
# consolidation via memories.generate (see _is_consolidation_enabled);
# it is stripped from the metadata before it is forwarded to Vertex.
_ENABLE_CONSOLIDATION_KEY = 'enable_consolidation'
# Vertex docs for GenerateMemoriesRequest.DirectMemoriesSource allow
# at most 5 direct_memories per request.
_MAX_DIRECT_MEMORIES_PER_GENERATE_CALL = 5


def _supports_generate_memories_metadata() -> bool:
"""Returns whether installed Vertex SDK supports config.metadata."""
Expand Down Expand Up @@ -160,6 +165,11 @@ def __init__(
not use Google AI Studio API key for this field. For more details, visit
https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
"""
if not agent_engine_id:
raise ValueError(
'agent_engine_id is required for VertexAiMemoryBankService.'
)

self._project = project
self._location = location
self._agent_engine_id = agent_engine_id
Expand Down Expand Up @@ -219,7 +229,22 @@ async def add_memory(
memories: Sequence[MemoryEntry],
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Adds explicit memory items via Vertex memories.create."""
"""Adds explicit memory items using Vertex Memory Bank.

By default, this writes directly via `memories.create`.
If `custom_metadata["enable_consolidation"]` is set to True, this uses
`memories.generate` with `direct_memories_source` so provided memories are
consolidated server-side.
"""
if _is_consolidation_enabled(custom_metadata):
await self._add_memories_via_generate_direct_memories_source(
app_name=app_name,
user_id=user_id,
memories=memories,
custom_metadata=custom_metadata,
)
return

await self._add_memories_via_create(
app_name=app_name,
user_id=user_id,
Expand All @@ -235,9 +260,6 @@ async def _add_events_to_memory_from_events(
events_to_process: Sequence[Event],
custom_metadata: Mapping[str, object] | None = None,
) -> None:
if not self._agent_engine_id:
raise ValueError('Agent Engine ID is required for Memory Bank.')

direct_events = []
for event in events_to_process:
if _should_filter_out_event(event.content):
Expand Down Expand Up @@ -272,9 +294,6 @@ async def _add_memories_via_create(
custom_metadata: Mapping[str, object] | None = None,
) -> None:
"""Adds direct memory items without server-side extraction."""
if not self._agent_engine_id:
raise ValueError('Agent Engine ID is required for Memory Bank.')

normalized_memories = _normalize_memories_for_create(memories)
api_client = self._get_api_client()
for index, memory in enumerate(normalized_memories):
Expand All @@ -300,11 +319,41 @@ async def _add_memories_via_create(
logger.info('Create memory response received.')
logger.debug('Create memory response: %s', operation)

async def _add_memories_via_generate_direct_memories_source(
    self,
    *,
    app_name: str,
    user_id: str,
    memories: Sequence[MemoryEntry],
    custom_metadata: Mapping[str, object] | None = None,
) -> None:
  """Adds memories through the generate API's direct_memories_source path.

  Each memory entry is normalized, converted to a plain fact string, and
  submitted in batches that respect the per-request direct_memories limit
  so the server can consolidate them.
  """
  facts: list[str] = []
  for index, entry in enumerate(_normalize_memories_for_create(memories)):
    facts.append(_memory_entry_to_fact(entry, index=index))

  generate_config = _build_generate_memories_config(custom_metadata)
  client = self._get_api_client()
  engine_name = 'reasoningEngines/' + self._agent_engine_id
  for batch in _iter_memory_batches(facts):
    operation = await client.agent_engines.memories.generate(
        name=engine_name,
        direct_memories_source={
            'direct_memories': [{'fact': fact} for fact in batch]
        },
        scope={
            'app_name': app_name,
            'user_id': user_id,
        },
        config=generate_config,
    )
    logger.info('Generate direct memory response received.')
    logger.debug('Generate direct memory response: %s', operation)

@override
async def search_memory(self, *, app_name: str, user_id: str, query: str):
if not self._agent_engine_id:
raise ValueError('Agent Engine ID is required for Memory Bank.')

api_client = self._get_api_client()
retrieved_memories_iterator = (
await api_client.agent_engines.memories.retrieve(
Expand Down Expand Up @@ -379,6 +428,8 @@ def _build_generate_memories_config(

metadata_by_key: dict[str, object] = {}
for key, value in custom_metadata.items():
if key == _ENABLE_CONSOLIDATION_KEY:
continue
if key == 'ttl':
if value is None:
continue
Expand Down Expand Up @@ -456,6 +507,8 @@ def _build_create_memory_config(
metadata_by_key: dict[str, object] = {}
custom_revision_labels: dict[str, str] = {}
for key, value in (custom_metadata or {}).items():
if key == _ENABLE_CONSOLIDATION_KEY:
continue
if key == 'metadata':
if value is None:
continue
Expand Down Expand Up @@ -641,6 +694,32 @@ def _extract_revision_labels(
return revision_labels


def _is_consolidation_enabled(
custom_metadata: Mapping[str, object] | None,
) -> bool:
"""Returns whether direct memories should be consolidated via generate API."""
if not custom_metadata:
return False
enable_consolidation = custom_metadata.get(_ENABLE_CONSOLIDATION_KEY)
if enable_consolidation is None:
return False
if not isinstance(enable_consolidation, bool):
raise TypeError(
f'custom_metadata["{_ENABLE_CONSOLIDATION_KEY}"] must be a bool.'
)
return enable_consolidation


def _iter_memory_batches(memories: Sequence[str]) -> Sequence[Sequence[str]]:
  """Returns memory slices that comply with direct_memories limits.

  Splits `memories` into consecutive slices of at most
  `_MAX_DIRECT_MEMORIES_PER_GENERATE_CALL` items each; the final slice may
  be shorter. An empty input yields an empty list.

  Args:
    memories: Fact strings to split into request-sized batches.

  Returns:
    A list of slices covering `memories` in their original order.
  """
  batch_size = _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL
  # Comprehension replaces the manual append loop; slicing semantics are
  # identical (the final slice is simply truncated at the sequence end).
  return [
      memories[start : start + batch_size]
      for start in range(0, len(memories), batch_size)
  ]


def _build_vertex_metadata(
metadata_by_key: Mapping[str, object],
) -> dict[str, object]:
Expand Down
70 changes: 70 additions & 0 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from __future__ import annotations

import asyncio
import hashlib
import inspect
import json
import logging
from pathlib import Path
import queue
Expand Down Expand Up @@ -47,6 +49,7 @@
from .code_executors.built_in_code_executor import BuiltInCodeExecutor
from .events.event import Event
from .events.event import EventActions
from .events.event_actions import RewindAuditReceipt
from .flows.llm_flows import contents
from .flows.llm_flows.functions import find_matching_function_call
from .memory.base_memory_service import BaseMemoryService
Expand Down Expand Up @@ -591,6 +594,11 @@ async def rewind_async(
artifact_delta = await self._compute_artifact_delta_for_rewind(
session, rewind_event_index
)
rewind_audit_receipt = self._build_rewind_audit_receipt(
session=session,
rewind_event_index=rewind_event_index,
rewind_before_invocation_id=rewind_before_invocation_id,
)

# Create rewind event
rewind_event = Event(
Expand All @@ -600,13 +608,75 @@ async def rewind_async(
rewind_before_invocation_id=rewind_before_invocation_id,
state_delta=state_delta,
artifact_delta=artifact_delta,
rewind_audit_receipt=rewind_audit_receipt,
),
)

logger.info('Rewinding session to invocation: %s', rewind_event)

await self.session_service.append_event(session=session, event=rewind_event)

def _build_rewind_audit_receipt(
    self,
    *,
    session: Session,
    rewind_event_index: int,
    rewind_before_invocation_id: str,
) -> RewindAuditReceipt:
  """Builds a deterministic audit receipt for a rewind operation.

  Summarizes the full pre-rewind history and the retained prefix (events
  before `rewind_event_index`) into counts and canonical digests, then
  seals the summary with a receipt hash.
  """
  full_history = session.events
  retained_history = full_history[:rewind_event_index]

  # The boundary is the last event kept before the rewind point, if any.
  boundary_after_invocation_id = None
  if rewind_event_index > 0:
    boundary_after_invocation_id = full_history[
        rewind_event_index - 1
    ].invocation_id

  receipt_payload = {
      'rewind_before_invocation_id': rewind_before_invocation_id,
      'boundary_after_invocation_id': boundary_after_invocation_id,
      'events_before_rewind': len(full_history),
      'events_after_rewind': len(retained_history),
      'history_before_hash': self._hash_rewind_events(full_history),
      'history_after_hash': self._hash_rewind_events(retained_history),
  }
  return RewindAuditReceipt(
      **receipt_payload,
      receipt_hash=self._hash_rewind_payload(receipt_payload),
  )
Comment on lines 648 to 651
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The instantiation of RewindAuditReceipt repeats all the fields that were just defined in receipt_payload. You can simplify this and avoid duplication by unpacking the receipt_payload dictionary when creating the RewindAuditReceipt instance. This makes the code more concise and easier to maintain if the fields change in the future.

    return RewindAuditReceipt(**receipt_payload, receipt_hash=receipt_hash)


def _hash_rewind_events(self, events: List[Event]) -> str:
  """Hashes event summaries for deterministic rewind audit receipts.

  Only the identity and action fields shown below are folded into the
  digest; event content and timestamps are not part of it.
  """
  summaries = []
  for event in events:
    summaries.append({
        'event_id': event.id,
        'invocation_id': event.invocation_id,
        'author': event.author,
        'state_delta': event.actions.state_delta,
        'artifact_delta': event.actions.artifact_delta,
        'rewind_before_invocation_id': (
            event.actions.rewind_before_invocation_id
        ),
    })
  return self._hash_rewind_payload({'events': summaries})

def _hash_rewind_payload(self, payload: dict[str, Any]) -> str:
"""Returns a canonical SHA-256 digest for rewind audit payloads."""
canonical_json = json.dumps(
payload,
sort_keys=True,
separators=(',', ':'),
ensure_ascii=True,
)
return hashlib.sha256(canonical_json.encode('utf-8')).hexdigest()

async def _compute_state_delta_for_rewind(
self, session: Session, rewind_event_index: int
) -> dict[str, Any]:
Expand Down
9 changes: 6 additions & 3 deletions src/google/adk/tools/enterprise_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from ..utils.model_name_utils import is_gemini_1_model
from ..utils.model_name_utils import is_gemini_model
from ..utils.model_name_utils import is_gemini_model_id_check_disabled
from .base_tool import BaseTool
from .tool_context import ToolContext

Expand Down Expand Up @@ -54,14 +55,16 @@ async def process_llm_request(
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
if is_gemini_model(llm_request.model):
model_check_disabled = is_gemini_model_id_check_disabled()
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []

if is_gemini_model(llm_request.model) or model_check_disabled:
if is_gemini_1_model(llm_request.model) and llm_request.config.tools:
raise ValueError(
'Enterprise Web Search tool cannot be used with other tools in'
' Gemini 1.x.'
)
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
llm_request.config.tools.append(
types.Tool(enterprise_web_search=types.EnterpriseWebSearch())
)
Expand Down
4 changes: 3 additions & 1 deletion src/google/adk/tools/google_maps_grounding_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from ..utils.model_name_utils import is_gemini_1_model
from ..utils.model_name_utils import is_gemini_model
from ..utils.model_name_utils import is_gemini_model_id_check_disabled
from .base_tool import BaseTool
from .tool_context import ToolContext

Expand Down Expand Up @@ -49,13 +50,14 @@ async def process_llm_request(
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
model_check_disabled = is_gemini_model_id_check_disabled()
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
if is_gemini_1_model(llm_request.model):
raise ValueError(
'Google Maps grounding tool cannot be used with Gemini 1.x models.'
)
elif is_gemini_model(llm_request.model):
elif is_gemini_model(llm_request.model) or model_check_disabled:
llm_request.config.tools.append(
types.Tool(google_maps=types.GoogleMaps())
)
Expand Down
Loading