diff --git a/src/google/adk/events/event_actions.py b/src/google/adk/events/event_actions.py index fe8556088f..b3fe665455 100644 --- a/src/google/adk/events/event_actions.py +++ b/src/google/adk/events/event_actions.py @@ -47,6 +47,38 @@ class EventCompaction(BaseModel): """The compacted content of the events.""" +class RewindAuditReceipt(BaseModel): # type: ignore[misc] + """Audit receipt metadata emitted for rewind operations.""" + + model_config = ConfigDict( + extra='forbid', + alias_generator=alias_generators.to_camel, + populate_by_name=True, + ) + """The pydantic model config.""" + + rewind_before_invocation_id: str + """The invocation ID that the rewind operation targeted.""" + + boundary_after_invocation_id: Optional[str] = None + """The last invocation ID retained before the rewind boundary, if any.""" + + events_before_rewind: int + """The number of events present before appending the rewind event.""" + + events_after_rewind: int + """The number of pre-existing events retained after rewind filtering.""" + + history_before_hash: str + """Canonical hash of the full pre-rewind event history.""" + + history_after_hash: str + """Canonical hash of the retained pre-rewind event history.""" + + receipt_hash: str + """Tamper-evident hash over the rewind receipt summary.""" + + class EventActions(BaseModel): """Represents the actions attached to an event.""" @@ -108,3 +140,6 @@ class EventActions(BaseModel): rewind_before_invocation_id: Optional[str] = None """The invocation id to rewind to. This is only set for rewind event.""" + + rewind_audit_receipt: Optional[RewindAuditReceipt] = None + """Structured receipt proving rewind boundaries and history digests.""" diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index cdb878cf24..cf29ecc3ce 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -15,7 +15,9 @@ from __future__ import annotations import asyncio +import hashlib import inspect +import json import logging from pathlib import Path import queue @@ -47,6 +49,7 @@ from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .events.event import Event from .events.event import EventActions +from .events.event_actions import RewindAuditReceipt from .flows.llm_flows import contents from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService @@ -594,6 +597,11 @@ async def rewind_async( artifact_delta = await self._compute_artifact_delta_for_rewind( session, rewind_event_index ) + rewind_audit_receipt = self._build_rewind_audit_receipt( + session=session, + rewind_event_index=rewind_event_index, + rewind_before_invocation_id=rewind_before_invocation_id, + ) # Create rewind event rewind_event = Event( @@ -603,6 +611,7 @@ async def rewind_async( rewind_before_invocation_id=rewind_before_invocation_id, state_delta=state_delta, artifact_delta=artifact_delta, + rewind_audit_receipt=rewind_audit_receipt, ), ) @@ -610,6 +619,67 @@ async def rewind_async( await self.session_service.append_event(session=session, event=rewind_event) + def _build_rewind_audit_receipt( + self, + *, + session: Session, + rewind_event_index: int, + rewind_before_invocation_id: str, + ) -> RewindAuditReceipt: + """Builds a deterministic audit receipt for a rewind operation.""" + events_before = session.events + events_after = session.events[:rewind_event_index] + boundary_after_invocation_id = None + if rewind_event_index > 0: + boundary_after_invocation_id = session.events[ + rewind_event_index - 1 + ].invocation_id + + history_before_hash = self._hash_rewind_events(events_before) + history_after_hash = self._hash_rewind_events(events_after) + + receipt_payload = { + 'rewind_before_invocation_id': rewind_before_invocation_id, + 'boundary_after_invocation_id': boundary_after_invocation_id, + 'events_before_rewind': len(events_before), + 'events_after_rewind': len(events_after), + 'history_before_hash': history_before_hash, + 'history_after_hash': history_after_hash, + } + receipt_hash = self._hash_rewind_payload(receipt_payload) + + return RewindAuditReceipt( + **receipt_payload, + receipt_hash=receipt_hash, + ) + + def _hash_rewind_events(self, events: List[Event]) -> str: + """Hashes event summaries for deterministic rewind audit receipts.""" + summarized_events = [ + { + 'event_id': event.id, + 'invocation_id': event.invocation_id, + 'author': event.author, + 'state_delta': event.actions.state_delta, + 'artifact_delta': event.actions.artifact_delta, + 'rewind_before_invocation_id': ( + event.actions.rewind_before_invocation_id + ), + } + for event in events + ] + return self._hash_rewind_payload({'events': summarized_events}) + + def _hash_rewind_payload(self, payload: dict[str, Any]) -> str: + """Returns a canonical SHA-256 digest for rewind audit payloads.""" + canonical_json = json.dumps( + payload, + sort_keys=True, + separators=(',', ':'), + ensure_ascii=True, + ) + return hashlib.sha256(canonical_json.encode('utf-8')).hexdigest() + async def _compute_state_delta_for_rewind( self, session: Session, rewind_event_index: int ) -> dict[str, Any]: diff --git a/tests/unittests/runners/test_runner_rewind.py b/tests/unittests/runners/test_runner_rewind.py index 035d28437b..562b53bcb0 100644 --- a/tests/unittests/runners/test_runner_rewind.py +++ b/tests/unittests/runners/test_runner_rewind.py @@ -154,6 +154,15 @@ async def test_rewind_async_with_state_and_artifacts(self): ) is None ) + rewind_receipt = session.events[-1].actions.rewind_audit_receipt + assert rewind_receipt is not None + assert rewind_receipt.rewind_before_invocation_id == "invocation2" + assert rewind_receipt.boundary_after_invocation_id == "invocation1" + assert rewind_receipt.events_before_rewind == 3 + assert rewind_receipt.events_after_rewind == 1 + assert rewind_receipt.history_before_hash + assert rewind_receipt.history_after_hash + assert rewind_receipt.receipt_hash @pytest.mark.asyncio async def test_rewind_async_not_first_invocation(self): @@ -246,3 +255,40 @@ async def test_rewind_async_not_first_invocation(self): session_id=session_id, filename="f2", ) == types.Part.from_text(text="f2v0") + + @pytest.mark.asyncio + async def test_rewind_receipt_hash_is_deterministic(self): + """Tests that rewind receipt hashes are stable for the same history.""" + runner = self.runner + user_id = "test_user" + session_id = "test_session" + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id=user_id, session_id=session_id + ) + + for invocation_id in ("invocation1", "invocation2", "invocation3"): + await runner.session_service.append_event( + session=session, + event=Event( + invocation_id=invocation_id, + author="agent", + actions=EventActions(state_delta={invocation_id: invocation_id}), + ), + ) + + first_receipt = runner._build_rewind_audit_receipt( + session=session, + rewind_event_index=1, + rewind_before_invocation_id="invocation2", + ) + second_receipt = runner._build_rewind_audit_receipt( + session=session, + rewind_event_index=1, + rewind_before_invocation_id="invocation2", + ) + + assert ( + first_receipt.history_before_hash == second_receipt.history_before_hash + ) + assert first_receipt.history_after_hash == second_receipt.history_after_hash + assert first_receipt.receipt_hash == second_receipt.receipt_hash