Skip to content

Commit c4160aa

Browse files
Tongzhou-Jiangcopybara-github
authored andcommitted
feat: Alow VertexAiSession for streaming_agent_run_with_events
PiperOrigin-RevId: 824600367
1 parent 35ac4c5 commit c4160aa

File tree

2 files changed

+61
-39
lines changed
  • vertexai
    • agent_engines/templates
    • preview/reasoning_engines/templates

2 files changed

+61
-39
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,6 @@ async def _init_session(
567567
):
568568
"""Initializes the session, and returns the session id."""
569569
from google.adk.events.event import Event
570-
import random
571570

572571
session_state = None
573572
if request.authorizations:
@@ -576,14 +575,9 @@ async def _init_session(
576575
auth = _Authorization(**auth)
577576
session_state[f"temp:{auth_id}"] = auth.access_token
578577

579-
if request.session_id:
580-
session_id = request.session_id
581-
else:
582-
session_id = f"temp_session_{random.randbytes(8).hex()}"
583578
session = await session_service.create_session(
584579
app_name=self._tmpl_attrs.get("app_name"),
585580
user_id=request.user_id,
586-
session_id=session_id,
587581
state=session_state,
588582
)
589583
if not session:
@@ -601,7 +595,7 @@ async def _init_session(
601595
saved_version = await artifact_service.save_artifact(
602596
app_name=self._tmpl_attrs.get("app_name"),
603597
user_id=request.user_id,
604-
session_id=session_id,
598+
session_id=session.id,
605599
filename=artifact.file_name,
606600
artifact=version_data.data,
607601
)
@@ -997,43 +991,61 @@ async def streaming_agent_run_with_events(self, request_json: str):
997991

998992
import json
999993
from google.genai import types
994+
from google.genai.errors import ClientError
1000995

1001996
request = _StreamRunRequest(**json.loads(request_json))
1002997
if not self._tmpl_attrs.get("in_memory_runner"):
1003998
self.set_up()
999+
if not self._tmpl_attrs.get("runner"):
1000+
self.set_up()
10041001
# Prepare the in-memory session.
10051002
if not self._tmpl_attrs.get("in_memory_artifact_service"):
10061003
self.set_up()
1004+
if not self._tmpl_attrs.get("artifact_service"):
1005+
self.set_up()
10071006
if not self._tmpl_attrs.get("in_memory_session_service"):
10081007
self.set_up()
1009-
session_service = self._tmpl_attrs.get("in_memory_session_service")
1010-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1008+
if not self._tmpl_attrs.get("session_service"):
1009+
self.set_up()
10111010
app = self._tmpl_attrs.get("app")
1011+
10121012
# Try to get the session, if it doesn't exist, create a new one.
1013-
session = None
10141013
if request.session_id:
1014+
session_service = self._tmpl_attrs.get("session_service")
1015+
artifact_service = self._tmpl_attrs.get("artifact_service")
1016+
runner = self._tmpl_attrs.get("runner")
10151017
try:
10161018
session = await session_service.get_session(
10171019
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
10181020
user_id=request.user_id,
10191021
session_id=request.session_id,
10201022
)
1021-
except RuntimeError:
1022-
pass
1023-
if not session:
1024-
# Fall back to create session if the session is not found.
1025-
session = await self._init_session(
1026-
session_service=session_service,
1027-
artifact_service=artifact_service,
1028-
request=request,
1023+
except ClientError:
1024+
# Fall back to create session if the session is not found.
1025+
# Specifying session_id on creation is not supported,
1026+
# so session id will be regenerated.
1027+
session = await self._init_session(
1028+
session_service=session_service,
1029+
artifact_service=artifact_service,
1030+
request=request,
1031+
)
1032+
else:
1033+
# Not providing a session ID will create a new in-memory session.
1034+
session_service = self._tmpl_attrs.get("in_memory_session_service")
1035+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1036+
runner = self._tmpl_attrs.get("in_memory_runner")
1037+
session = await session_service.create_session(
1038+
app_name=self._tmpl_attrs.get("app_name"),
1039+
user_id=request.user_id,
1040+
session_id=request.session_id,
10291041
)
10301042
if not session:
10311043
raise RuntimeError("Session initialization failed.")
10321044

10331045
# Run the agent
10341046
message_for_agent = types.Content(**request.message)
10351047
try:
1036-
async for event in self._tmpl_attrs.get("in_memory_runner").run_async(
1048+
async for event in runner.run_async(
10371049
user_id=request.user_id,
10381050
session_id=session.id,
10391051
new_message=message_for_agent,

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,6 @@ async def _init_session(
502502
):
503503
"""Initializes the session, and returns the session id."""
504504
from google.adk.events.event import Event
505-
import random
506505

507506
session_state = None
508507
if request.authorizations:
@@ -511,14 +510,9 @@ async def _init_session(
511510
auth = _Authorization(**auth)
512511
session_state[f"temp:{auth_id}"] = auth.access_token
513512

514-
if request.session_id:
515-
session_id = request.session_id
516-
else:
517-
session_id = f"temp_session_{random.randbytes(8).hex()}"
518513
session = await session_service.create_session(
519514
app_name=self._tmpl_attrs.get("app_name"),
520515
user_id=request.user_id,
521-
session_id=session_id,
522516
state=session_state,
523517
)
524518
if not session:
@@ -536,7 +530,7 @@ async def _init_session(
536530
saved_version = await artifact_service.save_artifact(
537531
app_name=self._tmpl_attrs.get("app_name"),
538532
user_id=request.user_id,
539-
session_id=session_id,
533+
session_id=session.id,
540534
filename=artifact.file_name,
541535
artifact=version_data.data,
542536
)
@@ -881,44 +875,60 @@ async def async_stream_query(
881875
def streaming_agent_run_with_events(self, request_json: str):
882876
import json
883877
from google.genai import types
878+
from google.genai.errors import ClientError
884879

885880
event_queue = queue.Queue(maxsize=1)
886881

887882
async def _invoke_agent_async():
888883
request = _StreamRunRequest(**json.loads(request_json))
889884
if not self._tmpl_attrs.get("in_memory_runner"):
890885
self.set_up()
886+
if not self._tmpl_attrs.get("runner"):
887+
self.set_up()
891888
# Prepare the in-memory session.
892889
if not self._tmpl_attrs.get("in_memory_artifact_service"):
893890
self.set_up()
891+
if not self._tmpl_attrs.get("artifact_service"):
892+
self.set_up()
894893
if not self._tmpl_attrs.get("in_memory_session_service"):
895894
self.set_up()
896-
session_service = self._tmpl_attrs.get("in_memory_session_service")
897-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
898-
# Try to get the session, if it doesn't exist, create a new one.
899-
session = None
895+
if not self._tmpl_attrs.get("session_service"):
896+
self.set_up()
900897
if request.session_id:
898+
session_service = self._tmpl_attrs.get("session_service")
899+
artifact_service = self._tmpl_attrs.get("artifact_service")
900+
runner = self._tmpl_attrs.get("runner")
901901
try:
902902
session = await session_service.get_session(
903903
app_name=self._tmpl_attrs.get("app_name"),
904904
user_id=request.user_id,
905905
session_id=request.session_id,
906906
)
907-
except RuntimeError:
908-
pass
909-
if not session:
910-
# Fall back to create session if the session is not found.
911-
session = await self._init_session(
912-
session_service=session_service,
913-
artifact_service=artifact_service,
914-
request=request,
907+
except ClientError:
908+
# Fall back to create session if the session is not found.
909+
# Specifying session_id on creation is not supported,
910+
# so session id will be regenerated.
911+
session = await self._init_session(
912+
session_service=session_service,
913+
artifact_service=artifact_service,
914+
request=request,
915+
)
916+
else:
917+
# Not providing a session ID will create a new in-memory session.
918+
session_service = self._tmpl_attrs.get("in_memory_session_service")
919+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
920+
runner = self._tmpl_attrs.get("in_memory_runner")
921+
session = await session_service.create_session(
922+
app_name=self._tmpl_attrs.get("app_name"),
923+
user_id=request.user_id,
924+
session_id=request.session_id,
915925
)
916926
if not session:
917927
raise RuntimeError("Session initialization failed.")
918928
# Run the agent.
919929
message_for_agent = types.Content(**request.message)
920930
try:
921-
for event in self._tmpl_attrs.get("in_memory_runner").run(
931+
for event in runner.run(
922932
user_id=request.user_id,
923933
session_id=session.id,
924934
new_message=message_for_agent,

0 commit comments

Comments
 (0)