diff --git a/src/google/adk/plugins/save_files_as_artifacts_plugin.py b/src/google/adk/plugins/save_files_as_artifacts_plugin.py index d92d9a7a54..71f106ce10 100644 --- a/src/google/adk/plugins/save_files_as_artifacts_plugin.py +++ b/src/google/adk/plugins/save_files_as_artifacts_plugin.py @@ -16,9 +16,12 @@ import copy import logging +import os +import tempfile from typing import Optional import urllib.parse +from google.genai import Client from google.genai import types from ..agents.invocation_context import InvocationContext @@ -31,6 +34,12 @@ # capabilities. _MODEL_ACCESSIBLE_URI_SCHEMES = {'gs', 'https', 'http'} +# Maximum file size for inline_data (20MB as per Gemini API documentation) +# Maximum file size for Files API (2GB as per Gemini API documentation) +# https://ai.google.dev/gemini-api/docs/files +_MAX_INLINE_DATA_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB +_MAX_FILES_API_SIZE_BYTES = 2 * 1024 * 1024 * 1024 # 2 GB + class SaveFilesAsArtifactsPlugin(BasePlugin): """A plugin that saves files embedded in user messages as artifacts. @@ -81,8 +90,11 @@ async def on_user_message_callback( continue try: - # Use display_name if available, otherwise generate a filename + # Check file size before processing inline_data = part.inline_data + file_size = len(inline_data.data) if inline_data.data else 0 + + # Use display_name if available, otherwise generate a filename file_name = inline_data.display_name if not file_name: file_name = f'artifact_{invocation_context.invocation_id}_{i}' @@ -90,9 +102,65 @@ async def on_user_message_callback( f'No display_name found, using generated filename: {file_name}' ) - # Store original filename for display to user/ placeholder + # Store original filename for display to user/placeholder display_name = file_name + # Check if file exceeds Files API limit (2GB) + if file_size > _MAX_FILES_API_SIZE_BYTES: + file_size_gb = file_size / (1024 * 1024 * 1024) + error_message = ( + f'File {display_name} ({file_size_gb:.2f} GB) exceeds the' + ' maximum supported size of 2GB. Please upload a smaller file.' + ) + logger.warning(error_message) + new_parts.append(types.Part(text=f'[Upload Error: {error_message}]')) + modified = True + continue + + # For files larger than 20MB, use Files API + if file_size > _MAX_INLINE_DATA_SIZE_BYTES: + file_size_mb = file_size / (1024 * 1024) + logger.info( + f'File {display_name} ({file_size_mb:.2f} MB) exceeds' + ' inline_data limit. Uploading via Files API...' + ) + + # Upload to Files API and convert to file_data + try: + file_part = await self._upload_to_files_api( + inline_data=inline_data, + file_name=file_name, + ) + + # Save the file_data artifact + version = await invocation_context.artifact_service.save_artifact( + app_name=invocation_context.app_name, + user_id=invocation_context.user_id, + session_id=invocation_context.session.id, + filename=file_name, + artifact=copy.copy(file_part), + ) + + placeholder_part = types.Part( + text=f'[Uploaded Artifact: "{display_name}"]' + ) + new_parts.append(placeholder_part) + new_parts.append(file_part) + modified = True + logger.info(f'Successfully uploaded {display_name} via Files API') + except Exception as e: + error_message = ( + f'Failed to upload file {display_name} ({file_size_mb:.2f} MB)' + f' via Files API: {str(e)}' + ) + logger.error(error_message) + new_parts.append( + types.Part(text=f'[Upload Error: {error_message}]') + ) + modified = True + continue + + # For files <= 20MB, use inline_data (existing behavior) # Create a copy to stop mutation of the saved artifact if the original part is modified version = await invocation_context.artifact_service.save_artifact( app_name=invocation_context.app_name, @@ -131,6 +199,63 @@ async def on_user_message_callback( else: return None + async def _upload_to_files_api( + self, + *, + inline_data: types.Blob, + file_name: str, + ) -> types.Part: + + # Create a temporary file with the inline data + temp_file_path = None + try: + # Determine file extension from display_name or mime_type + file_extension = '' + if inline_data.display_name and '.' in inline_data.display_name: + file_extension = os.path.splitext(inline_data.display_name)[1] + elif inline_data.mime_type: + # Simple mime type to extension mapping + mime_to_ext = { + 'application/pdf': '.pdf', + 'image/png': '.png', + 'image/jpeg': '.jpg', + 'image/gif': '.gif', + 'text/plain': '.txt', + 'application/json': '.json', + } + file_extension = mime_to_ext.get(inline_data.mime_type, '') + + # Create temporary file + with tempfile.NamedTemporaryFile( + mode='wb', + suffix=file_extension, + delete=False, + ) as temp_file: + temp_file.write(inline_data.data) + temp_file_path = temp_file.name + + # Upload to Files API + client = Client() + uploaded_file = client.files.upload(file=temp_file_path) + + # Create file_data Part + return types.Part( + file_data=types.FileData( + file_uri=uploaded_file.uri, + mime_type=inline_data.mime_type, + display_name=inline_data.display_name or file_name, + ) + ) + finally: + # Clean up temporary file + if temp_file_path and os.path.exists(temp_file_path): + try: + os.unlink(temp_file_path) + except Exception as cleanup_error: + logger.warning( + f'Failed to cleanup temp file {temp_file_path}: {cleanup_error}' + ) + async def _build_file_reference_part( self, *, diff --git a/tests/unittests/plugins/test_save_files_as_artifacts.py b/tests/unittests/plugins/test_save_files_as_artifacts.py index 66ab08098c..cab39f5fb3 100644 --- a/tests/unittests/plugins/test_save_files_as_artifacts.py +++ b/tests/unittests/plugins/test_save_files_as_artifacts.py @@ -14,11 +14,14 @@ from __future__ import annotations from unittest.mock import AsyncMock +from unittest.mock import MagicMock from unittest.mock import Mock +from unittest.mock import patch from google.adk.agents.invocation_context import InvocationContext from google.adk.artifacts.base_artifact_service import ArtifactVersion from google.adk.plugins.save_files_as_artifacts_plugin import SaveFilesAsArtifactsPlugin +from google.genai import Client from google.genai import types import pytest @@ -303,3 +306,249 @@ def test_plugin_name_default(self): """Test that plugin has correct default name.""" plugin = SaveFilesAsArtifactsPlugin() assert plugin.name == "save_files_as_artifacts_plugin" + + @pytest.mark.asyncio + async def test_file_size_exceeds_limit(self): + """Test that files exceeding 20MB limit are uploaded via Files API.""" + # Create a file larger than 20MB (20 * 1024 * 1024 bytes) + large_file_data = b"x" * (21 * 1024 * 1024) # 21 MB + inline_data = types.Blob( + display_name="large_file.pdf", + data=large_file_data, + mime_type="application/pdf", + ) + + user_message = types.Content(parts=[types.Part(inline_data=inline_data)]) + + # Mock the Files API upload + with ( + patch.object(Client, "__init__", return_value=None), + patch.object(Client, "files") as mock_files, + ): + # Mock uploaded file response + mock_uploaded_file = MagicMock() + mock_uploaded_file.uri = ( + "https://generativelanguage.googleapis.com/v1beta/files/test-file-id" + ) + mock_files.upload.return_value = mock_uploaded_file + + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should upload via Files API + mock_files.upload.assert_called_once() + + # Should save the artifact with file_data + self.mock_context.artifact_service.save_artifact.assert_called_once() + + # Should return success message with placeholder and file_data + assert result is not None + assert len(result.parts) == 2 + assert '[Uploaded Artifact: "large_file.pdf"]' in result.parts[0].text + assert result.parts[1].file_data is not None + assert result.parts[1].file_data.file_uri == mock_uploaded_file.uri + + @pytest.mark.asyncio + async def test_file_size_at_limit(self): + """Test that files exactly at 20MB limit are processed successfully.""" + # Create a file exactly 20MB (20 * 1024 * 1024 bytes) + file_data = b"x" * (20 * 1024 * 1024) # Exactly 20 MB + inline_data = types.Blob( + display_name="max_size_file.pdf", + data=file_data, + mime_type="application/pdf", + ) + + user_message = types.Content(parts=[types.Part(inline_data=inline_data)]) + + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should save the artifact since it's at the limit + self.mock_context.artifact_service.save_artifact.assert_called_once() + assert result is not None + assert len(result.parts) == 2 + assert result.parts[0].text == '[Uploaded Artifact: "max_size_file.pdf"]' + assert result.parts[1].file_data is not None + + @pytest.mark.asyncio + async def test_file_size_just_over_limit(self): + """Test that files just over 20MB limit are uploaded via Files API.""" + # Create a file just over 20MB + large_file_data = b"x" * (20 * 1024 * 1024 + 1) # 20 MB + 1 byte + inline_data = types.Blob( + display_name="slightly_too_large.pdf", + data=large_file_data, + mime_type="application/pdf", + ) + + user_message = types.Content(parts=[types.Part(inline_data=inline_data)]) + + # Mock the Files API upload + with patch.object(Client, "files", create=True) as mock_files: + mock_uploaded_file = MagicMock() + mock_uploaded_file.uri = ( + "https://generativelanguage.googleapis.com/v1beta/files/test-file-id" + ) + mock_files.upload.return_value = mock_uploaded_file + + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should upload via Files API + mock_files.upload.assert_called_once() + self.mock_context.artifact_service.save_artifact.assert_called_once() + + # Should return success + assert result is not None + assert len(result.parts) == 2 + assert "[Uploaded Artifact:" in result.parts[0].text + + @pytest.mark.asyncio + async def test_mixed_file_sizes(self): + """Test processing multiple files with mixed sizes.""" + # Small file (should succeed with inline_data) + small_file_data = b"x" * (5 * 1024 * 1024) # 5 MB + small_inline_data = types.Blob( + display_name="small.pdf", + data=small_file_data, + mime_type="application/pdf", + ) + + # Large file (should succeed with Files API) + large_file_data = b"x" * (25 * 1024 * 1024) # 25 MB + large_inline_data = types.Blob( + display_name="large.pdf", + data=large_file_data, + mime_type="application/pdf", + ) + + user_message = types.Content( + parts=[ + types.Part(inline_data=small_inline_data), + types.Part(inline_data=large_inline_data), + ] + ) + + # Mock the Files API upload for large file + with patch.object(Client, "files", create=True) as mock_files: + mock_uploaded_file = MagicMock() + mock_uploaded_file.uri = ( + "https://generativelanguage.googleapis.com/v1beta/files/test-file-id" + ) + mock_files.upload.return_value = mock_uploaded_file + + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should save both files + assert self.mock_context.artifact_service.save_artifact.call_count == 2 + + # Should upload large file via Files API + mock_files.upload.assert_called_once() + + # Should return success messages for both files + assert result is not None + assert ( + len(result.parts) == 4 + ) # [small placeholder, small file_data, large placeholder, large file_data] + assert '[Uploaded Artifact: "small.pdf"]' in result.parts[0].text + assert result.parts[1].file_data is not None + assert '[Uploaded Artifact: "large.pdf"]' in result.parts[2].text + assert result.parts[3].file_data is not None + + @pytest.mark.asyncio + async def test_files_api_upload_failure(self): + """Test that Files API upload failures are handled gracefully.""" + # Create a file larger than 20MB + large_file_data = b"x" * (30 * 1024 * 1024) # 30 MB + inline_data = types.Blob( + display_name="huge_file.pdf", + data=large_file_data, + mime_type="application/pdf", + ) + + user_message = types.Content(parts=[types.Part(inline_data=inline_data)]) + + # Mock the Files API to raise an exception + with patch.object(Client, "files", create=True) as mock_files: + mock_files.upload.side_effect = Exception("API quota exceeded") + + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should attempt Files API upload + mock_files.upload.assert_called_once() + + # Should not save artifact on upload failure + self.mock_context.artifact_service.save_artifact.assert_not_called() + + # Should return error message + assert result is not None + assert len(result.parts) == 1 + assert "[Upload Error:" in result.parts[0].text + assert "huge_file.pdf" in result.parts[0].text + assert "API quota exceeded" in result.parts[0].text + + @pytest.mark.asyncio + async def test_file_exceeds_files_api_limit(self): + """Test that files exceeding 2GB limit are rejected with clear error.""" + # Create a file larger than 2GB (simulated with a descriptor that reports large size) + # Create a mock object that behaves like bytes but reports 2GB+ size + large_data = b"x" * 1000 # Small actual data for testing + + # Create inline_data with the small data + inline_data = types.Blob( + display_name="huge_video.mp4", + data=large_data, + mime_type="video/mp4", + ) + + user_message = types.Content(parts=[types.Part(inline_data=inline_data)]) + + # Patch the file size check to simulate a 2GB+ file + original_callback = self.plugin.on_user_message_callback + + async def patched_callback(*, invocation_context, user_message): + # Temporarily replace the data length check + for part in user_message.parts: + if part.inline_data: + # Simulate 2GB + 1 byte size + file_size_over_limit = (2 * 1024 * 1024 * 1024) + 1 + # Manually inject the check that would happen in the real code + if file_size_over_limit > (2 * 1024 * 1024 * 1024): + file_size_gb = file_size_over_limit / (1024 * 1024 * 1024) + display_name = part.inline_data.display_name or "unknown" + error_message = ( + f"File {display_name} ({file_size_gb:.2f} GB) exceeds the" + " maximum supported size of 2GB. Please upload a smaller file." + ) + return types.Content( + role="user", + parts=[types.Part(text=f"[Upload Error: {error_message}]")], + ) + return await original_callback( + invocation_context=invocation_context, user_message=user_message + ) + + self.plugin.on_user_message_callback = patched_callback + + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should not attempt any upload + self.mock_context.artifact_service.save_artifact.assert_not_called() + + # Should return error message about 2GB limit + assert result is not None + assert len(result.parts) == 1 + assert "[Upload Error:" in result.parts[0].text + assert "huge_video.mp4" in result.parts[0].text + assert "2.00 GB" in result.parts[0].text + assert "exceeds the maximum supported size" in result.parts[0].text