Skip to content

Commit f9ce76f

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
chore: force flush OTel logs at the end of each request
Similar to what is already done for spans. Logs and spans are flushed in concurrently. PiperOrigin-RevId: 826620266
1 parent 9a46e67 commit f9ce76f

File tree

4 files changed

+242
-24
lines changed

4 files changed

+242
-24
lines changed

tests/unit/vertex_adk/test_agent_engine_templates_adk.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,34 @@ def trace_provider_mock():
170170
yield tracer_provider_mock
171171

172172

173+
@pytest.fixture
174+
def trace_provider_force_flush_mock():
175+
import opentelemetry.trace
176+
import opentelemetry.sdk.trace
177+
178+
with mock.patch.object(
179+
opentelemetry.trace, "get_tracer_provider"
180+
) as get_tracer_provider_mock:
181+
get_tracer_provider_mock.return_value = mock.Mock(
182+
spec=opentelemetry.sdk.trace.TracerProvider()
183+
)
184+
yield get_tracer_provider_mock.return_value.force_flush
185+
186+
187+
@pytest.fixture
188+
def logger_provider_force_flush_mock():
189+
import opentelemetry._logs
190+
import opentelemetry.sdk._logs
191+
192+
with mock.patch.object(
193+
opentelemetry._logs, "get_logger_provider"
194+
) as get_logger_provider_mock:
195+
get_logger_provider_mock.return_value = mock.Mock(
196+
spec=opentelemetry.sdk._logs.LoggerProvider()
197+
)
198+
yield get_logger_provider_mock.return_value.force_flush
199+
200+
173201
@pytest.fixture
174202
def default_instrumentor_builder_mock():
175203
with mock.patch(
@@ -351,6 +379,29 @@ async def test_async_stream_query(self):
351379
events.append(event)
352380
assert len(events) == 1
353381

382+
@pytest.mark.asyncio
383+
@mock.patch.dict(
384+
os.environ,
385+
{GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"},
386+
)
387+
async def test_async_stream_query_force_flush_otel(
388+
self,
389+
trace_provider_force_flush_mock: mock.Mock,
390+
logger_provider_force_flush_mock: mock.Mock,
391+
):
392+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
393+
assert app._tmpl_attrs.get("runner") is None
394+
app.set_up()
395+
app._tmpl_attrs["runner"] = _MockRunner()
396+
async for _ in app.async_stream_query(
397+
user_id=_TEST_USER_ID,
398+
message="test message",
399+
):
400+
pass
401+
402+
trace_provider_force_flush_mock.assert_called_once()
403+
logger_provider_force_flush_mock.assert_called_once()
404+
354405
@pytest.mark.asyncio
355406
async def test_async_stream_query_with_content(self):
356407
app = agent_engines.AdkApp(agent=_TEST_AGENT)
@@ -403,6 +454,46 @@ async def test_streaming_agent_run_with_events(self):
403454
events.append(event)
404455
assert len(events) == 1
405456

457+
@pytest.mark.asyncio
458+
@mock.patch.dict(
459+
os.environ,
460+
{GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"},
461+
)
462+
async def test_streaming_agent_run_with_events_force_flush_otel(
463+
self,
464+
trace_provider_force_flush_mock: mock.Mock,
465+
logger_provider_force_flush_mock: mock.Mock,
466+
):
467+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
468+
app.set_up()
469+
app._tmpl_attrs["in_memory_runner"] = _MockRunner()
470+
request_json = json.dumps(
471+
{
472+
"artifacts": [
473+
{
474+
"file_name": "test_file_name",
475+
"versions": [{"version": "v1", "data": "v1data"}],
476+
}
477+
],
478+
"authorizations": {
479+
"test_user_id1": {"access_token": "test_access_token"},
480+
"test_user_id2": {"accessToken": "test-access-token"},
481+
},
482+
"user_id": _TEST_USER_ID,
483+
"message": {
484+
"parts": [{"text": "What is the exchange rate from USD to SEK?"}],
485+
"role": "user",
486+
},
487+
}
488+
)
489+
async for _ in app.streaming_agent_run_with_events(
490+
request_json=request_json,
491+
):
492+
pass
493+
494+
trace_provider_force_flush_mock.assert_called_once()
495+
logger_provider_force_flush_mock.assert_called_once()
496+
406497
@pytest.mark.asyncio
407498
async def test_async_create_session(self):
408499
app = agent_engines.AdkApp(agent=_TEST_AGENT)

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,34 @@ def trace_provider_mock():
128128
yield tracer_provider_mock
129129

130130

131+
@pytest.fixture
132+
def trace_provider_force_flush_mock():
133+
import opentelemetry.trace
134+
import opentelemetry.sdk.trace
135+
136+
with mock.patch.object(
137+
opentelemetry.trace, "get_tracer_provider"
138+
) as get_tracer_provider_mock:
139+
get_tracer_provider_mock.return_value = mock.Mock(
140+
spec=opentelemetry.sdk.trace.TracerProvider()
141+
)
142+
yield get_tracer_provider_mock.return_value.force_flush
143+
144+
145+
@pytest.fixture
146+
def logger_provider_force_flush_mock():
147+
import opentelemetry._logs
148+
import opentelemetry.sdk._logs
149+
150+
with mock.patch.object(
151+
opentelemetry._logs, "get_logger_provider"
152+
) as get_logger_provider_mock:
153+
get_logger_provider_mock.return_value = mock.Mock(
154+
spec=opentelemetry.sdk._logs.LoggerProvider()
155+
)
156+
yield get_logger_provider_mock.return_value.force_flush
157+
158+
131159
@pytest.fixture
132160
def default_instrumentor_builder_mock():
133161
with mock.patch(
@@ -353,6 +381,31 @@ async def test_async_stream_query(self):
353381
events.append(event)
354382
assert len(events) == 1
355383

384+
@pytest.mark.asyncio
385+
@mock.patch.dict(
386+
os.environ,
387+
{"GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY": "true"},
388+
)
389+
async def test_async_stream_query_force_flush_otel(
390+
self,
391+
trace_provider_force_flush_mock: mock.Mock,
392+
logger_provider_force_flush_mock: mock.Mock,
393+
):
394+
app = reasoning_engines.AdkApp(
395+
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL), enable_tracing=True
396+
)
397+
assert app._tmpl_attrs.get("runner") is None
398+
app.set_up()
399+
app._tmpl_attrs["runner"] = _MockRunner()
400+
async for _ in app.async_stream_query(
401+
user_id=_TEST_USER_ID,
402+
message="test message",
403+
):
404+
pass
405+
406+
trace_provider_force_flush_mock.assert_called_once()
407+
logger_provider_force_flush_mock.assert_called_once()
408+
356409
@pytest.mark.asyncio
357410
async def test_async_stream_query_with_content(self):
358411
app = reasoning_engines.AdkApp(
@@ -404,6 +457,46 @@ def test_streaming_agent_run_with_events(self):
404457
events = list(app.streaming_agent_run_with_events(request_json=request_json))
405458
assert len(events) == 1
406459

460+
@pytest.mark.asyncio
461+
@mock.patch.dict(
462+
os.environ,
463+
{"GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY": "true"},
464+
)
465+
async def test_streaming_agent_run_with_events_force_flush_otel(
466+
self,
467+
trace_provider_force_flush_mock: mock.Mock,
468+
logger_provider_force_flush_mock: mock.Mock,
469+
):
470+
app = reasoning_engines.AdkApp(
471+
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL),
472+
enable_tracing=True,
473+
)
474+
app.set_up()
475+
app._tmpl_attrs["in_memory_runner"] = _MockRunner()
476+
request_json = json.dumps(
477+
{
478+
"artifacts": [
479+
{
480+
"file_name": "test_file_name",
481+
"versions": [{"version": "v1", "data": "v1data"}],
482+
}
483+
],
484+
"authorizations": {
485+
"test_user_id1": {"access_token": "test_access_token"},
486+
"test_user_id2": {"accessToken": "test-access-token"},
487+
},
488+
"user_id": _TEST_USER_ID,
489+
"message": {
490+
"parts": [{"text": "What is the exchange rate from USD to SEK?"}],
491+
"role": "user",
492+
},
493+
}
494+
)
495+
list(app.streaming_agent_run_with_events(request_json=request_json))
496+
497+
trace_provider_force_flush_mock.assert_called_once()
498+
logger_provider_force_flush_mock.assert_called_once()
499+
407500
@pytest.mark.asyncio
408501
async def test_async_bidi_stream_query(self):
409502
app = reasoning_engines.AdkApp(

vertexai/agent_engines/templates/adk.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626

2727
import asyncio
28+
from collections.abc import Awaitable
2829
import queue
2930
import threading
3031
import warnings
@@ -231,26 +232,38 @@ def _warn(msg: str):
231232
_warn._LOGGER.warning(msg) # pyright: ignore[reportFunctionMemberAccess]
232233

233234

234-
def _force_flush_traces():
235+
async def _force_flush_otel(tracing_enabled: bool, logging_enabled: bool):
235236
try:
236237
import opentelemetry.trace
238+
import opentelemetry._logs
237239
except (ImportError, AttributeError):
238240
_warn(
239-
"Could not force flush traces. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
241+
"Could not force flush telemetry data. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
240242
)
241243
return None
242244

243245
try:
244246
import opentelemetry.sdk.trace
247+
import opentelemetry.sdk._logs
245248
except (ImportError, AttributeError):
246249
_warn(
247-
"Could not force flush traces. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
250+
"Could not force flush telemetry data. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
248251
)
249252
return None
250253

251-
provider = opentelemetry.trace.get_tracer_provider()
252-
if isinstance(provider, opentelemetry.sdk.trace.TracerProvider):
253-
_ = provider.force_flush()
254+
coros: List[Awaitable[bool]] = []
255+
256+
if tracing_enabled:
257+
tracer_provider = opentelemetry.trace.get_tracer_provider()
258+
if isinstance(tracer_provider, opentelemetry.sdk.trace.TracerProvider):
259+
coros.append(asyncio.to_thread(tracer_provider.force_flush))
260+
261+
if logging_enabled:
262+
logger_provider = opentelemetry._logs.get_logger_provider()
263+
if isinstance(logger_provider, opentelemetry.sdk._logs.LoggerProvider):
264+
coros.append(asyncio.to_thread(logger_provider.force_flush))
265+
266+
await asyncio.gather(*coros, return_exceptions=True)
254267

255268

256269
def _default_instrumentor_builder(
@@ -894,9 +907,11 @@ async def async_stream_query(
894907
# Yield the event data as a dictionary
895908
yield _utils.dump_event_for_json(event)
896909
finally:
897-
# Avoid trace data loss having to do with CPU throttling on instance turndown
898-
if self._tracing_enabled():
899-
_ = await asyncio.to_thread(_force_flush_traces)
910+
# Avoid telemetry data loss having to do with CPU throttling on instance turndown
911+
_ = await _force_flush_otel(
912+
tracing_enabled=self._tracing_enabled(),
913+
logging_enabled=bool(self._telemetry_enabled()),
914+
)
900915

901916
def stream_query(
902917
self,
@@ -1066,9 +1081,11 @@ async def streaming_agent_run_with_events(self, request_json: str):
10661081
user_id=request.user_id,
10671082
session_id=session.id,
10681083
)
1069-
# Avoid trace data loss having to do with CPU throttling on instance turndown
1070-
if self._tracing_enabled():
1071-
_ = await asyncio.to_thread(_force_flush_traces)
1084+
# Avoid telemetry data loss having to do with CPU throttling on instance turndown
1085+
_ = await _force_flush_otel(
1086+
tracing_enabled=self._tracing_enabled(),
1087+
logging_enabled=bool(self._telemetry_enabled()),
1088+
)
10721089

10731090
async def async_get_session(
10741091
self,

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626

2727
import asyncio
28+
from collections.abc import Awaitable
2829
import queue
2930
import threading
3031

@@ -233,26 +234,38 @@ def _warn(msg: str):
233234
_warn._LOGGER.warning(msg) # pyright: ignore[reportFunctionMemberAccess]
234235

235236

236-
def _force_flush_traces():
237+
async def _force_flush_otel(tracing_enabled: bool, logging_enabled: bool):
237238
try:
238239
import opentelemetry.trace
240+
import opentelemetry._logs
239241
except (ImportError, AttributeError):
240242
_warn(
241-
"Could not force flush traces. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
243+
"Could not force flush telemetry data. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
242244
)
243245
return None
244246

245247
try:
246248
import opentelemetry.sdk.trace
249+
import opentelemetry.sdk._logs
247250
except (ImportError, AttributeError):
248251
_warn(
249-
"Could not force flush traces. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
252+
"Could not force flush telemetry data. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
250253
)
251254
return None
252255

253-
provider = opentelemetry.trace.get_tracer_provider()
254-
if isinstance(provider, opentelemetry.sdk.trace.TracerProvider):
255-
_ = provider.force_flush()
256+
coros: List[Awaitable[bool]] = []
257+
258+
if tracing_enabled:
259+
tracer_provider = opentelemetry.trace.get_tracer_provider()
260+
if isinstance(tracer_provider, opentelemetry.sdk.trace.TracerProvider):
261+
coros.append(asyncio.to_thread(tracer_provider.force_flush))
262+
263+
if logging_enabled:
264+
logger_provider = opentelemetry._logs.get_logger_provider()
265+
if isinstance(logger_provider, opentelemetry.sdk._logs.LoggerProvider):
266+
coros.append(asyncio.to_thread(logger_provider.force_flush))
267+
268+
await asyncio.gather(*coros, return_exceptions=True)
256269

257270

258271
def _default_instrumentor_builder(
@@ -891,9 +904,11 @@ async def async_stream_query(
891904
# Yield the event data as a dictionary
892905
yield _utils.dump_event_for_json(event)
893906
finally:
894-
# Avoid trace data loss having to do with CPU throttling on instance turndown
895-
if self._tracing_enabled():
896-
_ = await asyncio.to_thread(_force_flush_traces)
907+
# Avoid telemetry data loss having to do with CPU throttling on instance turndown
908+
_ = await _force_flush_otel(
909+
tracing_enabled=self._tracing_enabled(),
910+
logging_enabled=bool(self._telemetry_enabled()),
911+
)
897912

898913
def streaming_agent_run_with_events(self, request_json: str):
899914
import json
@@ -970,9 +985,11 @@ async def _invoke_agent_async():
970985
user_id=request.user_id,
971986
session_id=session.id,
972987
)
973-
# Avoid trace data loss having to do with CPU throttling on instance turndown
974-
if self._tracing_enabled():
975-
_ = await asyncio.to_thread(_force_flush_traces)
988+
# Avoid telemetry data loss having to do with CPU throttling on instance turndown
989+
_ = await _force_flush_otel(
990+
tracing_enabled=self._tracing_enabled(),
991+
logging_enabled=bool(self._telemetry_enabled()),
992+
)
976993

977994
def _asyncio_thread_main():
978995
try:

0 commit comments

Comments
 (0)