diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 7878daf22..a109f32be 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -119,15 +119,17 @@ public Single listSessions(String appName, String userId) .map( listSessionsResponseMap -> parseListSessionsResponse(listSessionsResponseMap, appName, userId)) - .defaultIfEmpty(ListSessionsResponse.builder().build()); + .defaultIfEmpty(ListSessionsResponse.builder().sessions(new ArrayList<>()).build()); } private ListSessionsResponse parseListSessionsResponse( JsonNode listSessionsResponseMap, String appName, String userId) { + JsonNode sessionsNode = listSessionsResponseMap.get("sessions"); + if (sessionsNode == null || sessionsNode.isNull() || sessionsNode.isEmpty()) { + return ListSessionsResponse.builder().sessions(new ArrayList<>()).build(); + } List> apiSessions = - objectMapper.convertValue( - listSessionsResponseMap.get("sessions"), - new TypeReference>>() {}); + objectMapper.convertValue(sessionsNode, new TypeReference>>() {}); List sessions = new ArrayList<>(); for (Map apiSession : apiSessions) { diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index 775b465ff..b8af31133 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -25,6 +25,8 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import okhttp3.MediaType; +import okhttp3.ResponseBody; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -37,6 +39,20 @@ public class VertexAiSessionServiceTest { private static final ObjectMapper mapper = JsonBaseModel.getMapper(); + private static final MediaType JSON_MEDIA_TYPE = + MediaType.parse("application/json; charset=utf-8"); + + private static ApiResponse apiResponseJson(String json) { + return new ApiResponse() { + @Override + public ResponseBody getResponseBody() { + return ResponseBody.create(JSON_MEDIA_TYPE, json); + } + + @Override + public void close() {} + }; + } private static final String MOCK_SESSION_STRING_1 = """ @@ -322,6 +338,24 @@ public void listSessions_empty() { .isEmpty(); } + @Test + public void listSessions_missingSessionsField_returnsEmpty() { + when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userX", "")) + .thenReturn(apiResponseJson("{}")); + + assertThat(vertexAiSessionService.listSessions("123", "userX").blockingGet().sessions()) + .isEmpty(); + } + + @Test + public void listSessions_nullSessionsField_returnsEmpty() { + when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userY", "")) + .thenReturn(apiResponseJson("{\"sessions\": null}")); + + assertThat(vertexAiSessionService.listSessions("123", "userY").blockingGet().sessions()) + .isEmpty(); + } + @Test public void listEvents_empty() { assertThat(vertexAiSessionService.listEvents("789", "user1", "3").blockingGet().events())