@@ -502,7 +502,6 @@ async def _init_session(
502502 ):
503503 """Initializes the session, and returns the session id."""
504504 from google .adk .events .event import Event
505- import random
506505
507506 session_state = None
508507 if request .authorizations :
@@ -511,14 +510,9 @@ async def _init_session(
511510 auth = _Authorization (** auth )
512511 session_state [f"temp:{ auth_id } " ] = auth .access_token
513512
514- if request .session_id :
515- session_id = request .session_id
516- else :
517- session_id = f"temp_session_{ random .randbytes (8 ).hex ()} "
518513 session = await session_service .create_session (
519514 app_name = self ._tmpl_attrs .get ("app_name" ),
520515 user_id = request .user_id ,
521- session_id = session_id ,
522516 state = session_state ,
523517 )
524518 if not session :
@@ -536,7 +530,7 @@ async def _init_session(
536530 saved_version = await artifact_service .save_artifact (
537531 app_name = self ._tmpl_attrs .get ("app_name" ),
538532 user_id = request .user_id ,
539- session_id = session_id ,
533+ session_id = session . id ,
540534 filename = artifact .file_name ,
541535 artifact = version_data .data ,
542536 )
@@ -881,44 +875,60 @@ async def async_stream_query(
881875 def streaming_agent_run_with_events (self , request_json : str ):
882876 import json
883877 from google .genai import types
878+ from google .genai .errors import ClientError
884879
885880 event_queue = queue .Queue (maxsize = 1 )
886881
887882 async def _invoke_agent_async ():
888883 request = _StreamRunRequest (** json .loads (request_json ))
889884 if not self ._tmpl_attrs .get ("in_memory_runner" ):
890885 self .set_up ()
886+ if not self ._tmpl_attrs .get ("runner" ):
887+ self .set_up ()
891888 # Prepare the in-memory session.
892889 if not self ._tmpl_attrs .get ("in_memory_artifact_service" ):
893890 self .set_up ()
891+ if not self ._tmpl_attrs .get ("artifact_service" ):
892+ self .set_up ()
894893 if not self ._tmpl_attrs .get ("in_memory_session_service" ):
895894 self .set_up ()
896- session_service = self ._tmpl_attrs .get ("in_memory_session_service" )
897- artifact_service = self ._tmpl_attrs .get ("in_memory_artifact_service" )
898- # Try to get the session, if it doesn't exist, create a new one.
899- session = None
895+ if not self ._tmpl_attrs .get ("session_service" ):
896+ self .set_up ()
900897 if request .session_id :
898+ session_service = self ._tmpl_attrs .get ("session_service" )
899+ artifact_service = self ._tmpl_attrs .get ("artifact_service" )
900+ runner = self ._tmpl_attrs .get ("runner" )
901901 try :
902902 session = await session_service .get_session (
903903 app_name = self ._tmpl_attrs .get ("app_name" ),
904904 user_id = request .user_id ,
905905 session_id = request .session_id ,
906906 )
907- except RuntimeError :
908- pass
909- if not session :
910- # Fall back to create session if the session is not found.
911- session = await self ._init_session (
912- session_service = session_service ,
913- artifact_service = artifact_service ,
914- request = request ,
907+ except ClientError :
908+ # Fall back to create session if the session is not found.
909+ # Specifying session_id on creation is not supported,
910+ # so session id will be regenerated.
911+ session = await self ._init_session (
912+ session_service = session_service ,
913+ artifact_service = artifact_service ,
914+ request = request ,
915+ )
916+ else :
917+ # Not providing a session ID will create a new in-memory session.
918+ session_service = self ._tmpl_attrs .get ("in_memory_session_service" )
919+ artifact_service = self ._tmpl_attrs .get ("in_memory_artifact_service" )
920+ runner = self ._tmpl_attrs .get ("in_memory_runner" )
921+ session = await session_service .create_session (
922+ app_name = self ._tmpl_attrs .get ("app_name" ),
923+ user_id = request .user_id ,
924+ session_id = request .session_id ,
915925 )
916926 if not session :
917927 raise RuntimeError ("Session initialization failed." )
918928 # Run the agent.
919929 message_for_agent = types .Content (** request .message )
920930 try :
921- for event in self . _tmpl_attrs . get ( "in_memory_runner" ) .run (
931+ for event in runner .run (
922932 user_id = request .user_id ,
923933 session_id = session .id ,
924934 new_message = message_for_agent ,
0 commit comments