From c8e0960836fcd435c866dfe4103bd98228913303 Mon Sep 17 00:00:00 2001 From: pkarmarkar Date: Sat, 13 Dec 2025 23:39:29 -0800 Subject: [PATCH] feat: Add DatabaseSessionService with Hibernate and Flyway support Summary Implements a production-ready database-backed session service that provides persistent storage for ADK sessions, events, and state using JPA/Hibernate. Key Features Hibernate 6.6 + HikariCP connection pooling for optimal performance Flyway migrations for schema versioning and zero-downtime deployments Multi-database support: PostgreSQL, MySQL, H2, and other RDBMS Thread-safe operations with pessimistic locking for concurrent updates Comprehensive test coverage with H2 in-memory database Dialect-aware JSON storage (JSONB for PostgreSQL, CLOB for others) Event filtering and pagination for efficient data retrieval Architecture Database dependencies are non-optional in the core module for ease of use, consistent with Python ADK's approach where `DatabaseSessionService` is in the main package. Note: Database drivers (PostgreSQL, MySQL) are marked as `true`, so users must add their specific driver dependency to their project. --- core/pom.xml | 75 ++ .../adk/sessions/DatabaseSessionService.java | 846 ++++++++++++++++++ .../db/converter/EventActionsUserType.java | 167 ++++ .../sessions/db/converter/JsonUserType.java | 174 ++++ .../adk/sessions/db/entity/EventId.java | 85 ++ .../adk/sessions/db/entity/SessionId.java | 74 ++ .../sessions/db/entity/StorageAppState.java | 73 ++ .../adk/sessions/db/entity/StorageEvent.java | 506 +++++++++++ .../sessions/db/entity/StorageSession.java | 132 +++ .../sessions/db/entity/StorageUserState.java | 86 ++ .../adk/sessions/db/entity/UserStateId.java | 62 ++ .../db/util/DatabaseDialectDetector.java | 81 ++ .../db/util/EntityManagerFactoryProvider.java | 83 ++ .../main/resources/META-INF/persistence.xml | 29 + .../db/migration/h2/V1__Initial_schema.sql | 75 ++ .../db/migration/mysql/V1__Initial_schema.sql | 77 ++ .../postgresql/V1__Initial_schema.sql | 75 ++ .../ConcurrentSessionOperationsTest.java | 429 +++++++++ .../sessions/DatabaseDialectDetectorTest.java | 179 ++++ .../sessions/DatabaseSessionServiceTest.java | 416 +++++++++ .../adk/sessions/EventFilteringTest.java | 416 +++++++++ .../adk/sessions/FlywayMigrationTest.java | 195 ++++ .../adk/sessions/PessimisticLockingTest.java | 325 +++++++ .../google/adk/sessions/StateDeltaTest.java | 382 ++++++++ .../adk/sessions/StateManagementTest.java | 308 +++++++ .../db/converter/JsonUserTypeTest.java | 150 ++++ .../adk/testing/TestDatabaseConfig.java | 95 ++ pom.xml | 28 + 28 files changed, 5623 insertions(+) create mode 100644 core/src/main/java/com/google/adk/sessions/DatabaseSessionService.java create mode 100644 core/src/main/java/com/google/adk/sessions/db/converter/EventActionsUserType.java create mode 100644 core/src/main/java/com/google/adk/sessions/db/converter/JsonUserType.java create mode 100644 core/src/main/java/com/google/adk/sessions/db/entity/EventId.java create mode 100644 core/src/main/java/com/google/adk/sessions/db/entity/SessionId.java create mode 100644 core/src/main/java/com/google/adk/sessions/db/entity/StorageAppState.java create mode 100644 core/src/main/java/com/google/adk/sessions/db/entity/StorageEvent.java create mode 100644 core/src/main/java/com/google/adk/sessions/db/entity/StorageSession.java create mode 100644 core/src/main/java/com/google/adk/sessions/db/entity/StorageUserState.java create mode 100644 core/src/main/java/com/google/adk/sessions/db/entity/UserStateId.java create mode 100644 core/src/main/java/com/google/adk/sessions/db/util/DatabaseDialectDetector.java create mode 100644 core/src/main/java/com/google/adk/sessions/db/util/EntityManagerFactoryProvider.java create mode 100644 core/src/main/resources/META-INF/persistence.xml create mode 100644 core/src/main/resources/db/migration/h2/V1__Initial_schema.sql create mode 100644 core/src/main/resources/db/migration/mysql/V1__Initial_schema.sql create mode 100644 core/src/main/resources/db/migration/postgresql/V1__Initial_schema.sql create mode 100644 core/src/test/java/com/google/adk/sessions/ConcurrentSessionOperationsTest.java create mode 100644 core/src/test/java/com/google/adk/sessions/DatabaseDialectDetectorTest.java create mode 100644 core/src/test/java/com/google/adk/sessions/DatabaseSessionServiceTest.java create mode 100644 core/src/test/java/com/google/adk/sessions/EventFilteringTest.java create mode 100644 core/src/test/java/com/google/adk/sessions/FlywayMigrationTest.java create mode 100644 core/src/test/java/com/google/adk/sessions/PessimisticLockingTest.java create mode 100644 core/src/test/java/com/google/adk/sessions/StateDeltaTest.java create mode 100644 core/src/test/java/com/google/adk/sessions/StateManagementTest.java create mode 100644 core/src/test/java/com/google/adk/sessions/db/converter/JsonUserTypeTest.java create mode 100644 core/src/test/java/com/google/adk/testing/TestDatabaseConfig.java diff --git a/core/pom.xml b/core/pom.xml index fe65715f3..00ad9c689 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -104,6 +104,17 @@ com.fasterxml.jackson.dataformat jackson-dataformat-yaml + + + com.fasterxml.jackson.module + jackson-module-jaxb-annotations + + + + jakarta.xml.bind + jakarta.xml.bind-api + 4.0.0 + com.google.protobuf protobuf-java @@ -189,6 +200,70 @@ opentelemetry-sdk-testing test + + + + + + org.hibernate.orm + hibernate-core + + + + + jakarta.persistence + jakarta.persistence-api + + + + + com.zaxxer + HikariCP + + + + + org.hibernate.orm + hibernate-hikaricp + + + + + + + com.h2database + h2 + test + + + + + org.postgresql + postgresql + true + + + + + com.mysql + mysql-connector-j + true + + + + + + + org.flywaydb + flyway-core + + + + + org.flywaydb + flyway-database-postgresql + runtime + diff --git a/core/src/main/java/com/google/adk/sessions/DatabaseSessionService.java b/core/src/main/java/com/google/adk/sessions/DatabaseSessionService.java new file mode 100644 index 000000000..009e77bb7 --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/DatabaseSessionService.java @@ -0,0 +1,846 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import com.google.adk.events.Event; +import com.google.adk.sessions.db.entity.SessionId; +import com.google.adk.sessions.db.entity.StorageAppState; +import com.google.adk.sessions.db.entity.StorageEvent; +import com.google.adk.sessions.db.entity.StorageSession; +import com.google.adk.sessions.db.entity.StorageUserState; +import com.google.adk.sessions.db.entity.UserStateId; +import com.google.adk.sessions.db.util.DatabaseDialectDetector; +import com.google.adk.sessions.db.util.EntityManagerFactoryProvider; +import com.google.common.annotations.VisibleForTesting; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; +import jakarta.persistence.EntityManager; +import jakarta.persistence.EntityManagerFactory; +import jakarta.persistence.LockModeType; +import jakarta.persistence.PersistenceException; +import jakarta.persistence.TypedQuery; +import jakarta.persistence.criteria.CriteriaBuilder; +import jakarta.persistence.criteria.CriteriaQuery; +import jakarta.persistence.criteria.Root; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.stream.Collectors; +import org.flywaydb.core.Flyway; +import org.flywaydb.core.api.FlywayException; +import org.flywaydb.core.api.output.MigrateResult; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link BaseSessionService} implementation that stores sessions in a relational database. + * Supports automatic table creation on initialization using flyway. + */ +public class DatabaseSessionService implements BaseSessionService { + + private static final Logger logger = LoggerFactory.getLogger(DatabaseSessionService.class); + + // Default database options if not specified + private static final int DEFAULT_FETCH_LIMIT = 1000; + + // The Entity Manager Factory for database access + private final EntityManagerFactory emf; + + // The database dialect being used + private final String dialect; + + /** + * Creates a new DatabaseSessionService with the specified database URL. + * + * @param dbUrl The database URL to connect to + */ + public DatabaseSessionService(String dbUrl) { + this(dbUrl, Collections.emptyMap()); + } + + /** + * Creates a new DatabaseSessionService with the specified database URL and configuration. + * + *

The database URL should include all connection parameters, including credentials and SSL + * configuration, following standard JDBC URL format: + * + *

    + *
  • PostgreSQL: {@code + * jdbc:postgresql://host:5432/dbname?user=username&password=pass&ssl=true&sslmode=verify-full} + *
  • MySQL: {@code + * jdbc:mysql://host:3306/dbname?user=username&password=pass&useSSL=true&requireSSL=true} + *
  • H2 (in-memory): {@code jdbc:h2:mem:testdb} + *
  • SQLite: {@code jdbc:sqlite:./sessions.db} + *
+ * + *

For production use, load the database URL from environment variables to avoid hardcoding + * credentials: + * + *

{@code
+   * String dbUrl = System.getenv("DATABASE_URL");
+   * DatabaseSessionService service = new DatabaseSessionService(dbUrl);
+   * }
+ * + * @param dbUrl The database URL to connect to (including credentials) + * @param properties Additional Hibernate configuration properties (optional) + */ + public DatabaseSessionService(String dbUrl, Map properties) { + Objects.requireNonNull(dbUrl, "Database URL cannot be null"); + + // Detect dialect from URL + this.dialect = DatabaseDialectDetector.detectDialect(dbUrl); + + // Apply Flyway migrations first + try { + // Use dialect-specific migration folder to avoid duplicate version conflicts + String dialectFolder = extractDialectFolderName(this.dialect); + String flywayLocation = "classpath:db/migration/" + dialectFolder; + + // Configure Flyway with the database URL + // Flyway will extract credentials from the JDBC URL automatically + // Use system property or environment variable for FLYWAY_BASELINE_ON_MIGRATE (default: false) + String baselineOnMigrateStr = + System.getProperty( + "FLYWAY_BASELINE_ON_MIGRATE", + System.getenv().getOrDefault("FLYWAY_BASELINE_ON_MIGRATE", "false")); + boolean baselineOnMigrate = Boolean.parseBoolean(baselineOnMigrateStr); + + Flyway flyway = + Flyway.configure() + .dataSource(dbUrl, null, null) + .locations(flywayLocation) + .cleanDisabled(true) + .lockRetryCount(50) + .baselineOnMigrate(baselineOnMigrate) + .load(); + + // Run migrations with retry logic for multi-pod scenarios + MigrateResult migrateResult = flyway.migrate(); + + if (migrateResult.migrationsExecuted > 0) { + logger.info( + "Flyway migrations completed successfully. Applied {} new migrations.", + migrateResult.migrationsExecuted); + } else { + logger.info("Database schema is already up to date. No migrations applied."); + } + } catch (FlywayException e) { + logger.warn("Flyway migration issue: {}", e.getMessage()); + + if (e.getMessage() != null && e.getMessage().contains("lock")) { + logger.warn( + "Migration lock detected - another instance may be migrating. Validating schema..."); + try { + String dialectFolder = extractDialectFolderName(this.dialect); + String flywayLocation = "classpath:db/migration/" + dialectFolder; + + Flyway flyway = + Flyway.configure() + .dataSource(dbUrl, null, null) + .locations(flywayLocation) + .cleanDisabled(true) + .baselineOnMigrate(true) + .load(); + + flyway.validate(); + } catch (Exception validationEx) { + logger.error("Schema validation failed after migration lock", validationEx); + throw new RuntimeException( + "Failed to validate database schema after migration lock", validationEx); + } + } else { + logger.error("Error applying Flyway migrations", e); + throw new RuntimeException("Failed to apply database migrations", e); + } + } + + // Create entity manager factory with specified options, changing schema mode to 'none' + Map config = new HashMap<>(properties); + // Override hbm2ddl.auto to none since Flyway handles the schema creation and validation + config.put("hibernate.hbm2ddl.auto", "none"); + this.emf = EntityManagerFactoryProvider.createEntityManagerFactory(dbUrl, config); + } + + /** + * Creates a new DatabaseSessionService with a pre-configured EntityManagerFactory. This + * constructor is primarily used for testing. + * + * @param emf The EntityManagerFactory to use + */ + @VisibleForTesting + DatabaseSessionService(EntityManagerFactory emf) { + this.emf = emf; + this.dialect = "Test"; + } + + @Override + public Single createSession( + String appName, String userId, ConcurrentMap state, String sessionId) { + + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + + return Single.fromCallable( + () -> + executeInTransaction( + em -> { + // Split state by prefix + Map appStateMap = new HashMap<>(); + Map userStateMap = new HashMap<>(); + Map sessionStateMap = new HashMap<>(); + + if (state != null) { + for (Map.Entry entry : state.entrySet()) { + String key = entry.getKey(); + // Skip temp: prefixed keys + if (key.startsWith(State.TEMP_PREFIX)) { + continue; + } + + // Route keys to different tables based on prefix + if (key.startsWith(State.APP_PREFIX)) { + appStateMap.put(key, entry.getValue()); + } else if (key.startsWith(State.USER_PREFIX)) { + userStateMap.put(key, entry.getValue()); + } else { + sessionStateMap.put(key, entry.getValue()); + } + } + } + + // Get or create app state and merge new keys + StorageAppState appState = getOrCreateAppState(em, appName); + if (!appStateMap.isEmpty()) { + appState.getState().putAll(appStateMap); + appState.setUpdateTime(Instant.now()); + em.merge(appState); + } + + // Get or create user state and merge new keys + UserStateId userStateId = new UserStateId(appName, userId); + StorageUserState userState = getOrCreateUserState(em, userStateId); + if (!userStateMap.isEmpty()) { + userState.getState().putAll(userStateMap); + userState.setUpdateTime(Instant.now()); + em.merge(userState); + } + + // Create session entity with only session-specific state + StorageSession session = new StorageSession(); + session.setAppName(appName); + session.setUserId(userId); + session.setId(sessionId != null ? sessionId : UUID.randomUUID().toString()); + session.setState(sessionStateMap); + session.setCreateTime(Instant.now()); + session.setUpdateTime(Instant.now()); + + em.persist(session); + + // Convert to domain object with merged state + return createDomainSessionWithMergedState(session, appState, userState); + }, + "Error creating session")) + .subscribeOn(Schedulers.io()); + } + + /** {@inheritDoc} */ + @Override + public Maybe getSession( + String appName, String userId, String sessionId, Optional config) { + + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + Objects.requireNonNull(sessionId, "sessionId cannot be null"); + Objects.requireNonNull(config, "configOpt cannot be null"); + + return Maybe.fromCallable( + () -> + executeInTransaction( + em -> { + // Create composite key for session lookup + SessionId id = new SessionId(appName, userId, sessionId); + + // Find session + StorageSession session = em.find(StorageSession.class, id); + if (session == null) { + return null; // No session found, Maybe will be empty + } + + // Load events for the session with optional filtering + CriteriaBuilder cb = em.getCriteriaBuilder(); + CriteriaQuery eventQuery = cb.createQuery(StorageEvent.class); + Root eventRoot = eventQuery.from(StorageEvent.class); + + // Base filters for session identification + jakarta.persistence.criteria.Predicate basePredicate = + cb.and( + cb.equal(eventRoot.get("appName"), appName), + cb.equal(eventRoot.get("userId"), userId), + cb.equal(eventRoot.get("sessionId"), sessionId)); + + // Apply timestamp filter if present in config + if (config.isPresent() && config.get().afterTimestamp().isPresent()) { + Instant afterTimestamp = config.get().afterTimestamp().get(); + basePredicate = + cb.and( + basePredicate, + cb.greaterThan(eventRoot.get("timestamp"), afterTimestamp)); + } + + eventQuery.where(basePredicate); + + // Determine sort order based on whether we need to limit results + boolean needsReverse = false; + if (config.isPresent() && config.get().numRecentEvents().isPresent()) { + // Order descending to get most recent events first + eventQuery.orderBy(cb.desc(eventRoot.get("timestamp"))); + needsReverse = true; + } else { + // Order ascending for chronological order + eventQuery.orderBy(cb.asc(eventRoot.get("timestamp"))); + } + + // Execute query with optional limit + TypedQuery query = em.createQuery(eventQuery); + if (config.isPresent() && config.get().numRecentEvents().isPresent()) { + int numEvents = config.get().numRecentEvents().get(); + if (numEvents >= 0) { + query.setMaxResults(numEvents); + } + } + + List events = query.getResultList(); + + // Reverse if we fetched in descending order + if (needsReverse) { + Collections.reverse(events); + } + + // IMPORTANT: Use clear() + addAll() to maintain Hibernate's collection + // tracking. + // Replacing the collection reference (e.g., setEvents()) would break cascade + // operations + // and cause the EntityManager to lose change tracking on this relationship. + session.getEvents().clear(); + session.getEvents().addAll(events); + + // Find app state + StorageAppState appState = em.find(StorageAppState.class, appName); + + // Find user state + UserStateId userStateId = new UserStateId(appName, userId); + StorageUserState userState = em.find(StorageUserState.class, userStateId); + + // Convert to domain object with merged state + return createDomainSessionWithMergedState(session, appState, userState); + }, + "Error getting session")) + .subscribeOn(Schedulers.io()); + } + + @Override + public Single listSessions(String appName, String userId) { + return Single.fromCallable( + () -> + executeReadOnly( + em -> { + List sessions = new ArrayList<>(); + + // Create query to find sessions + CriteriaBuilder cb = em.getCriteriaBuilder(); + CriteriaQuery cq = cb.createQuery(StorageSession.class); + Root root = cq.from(StorageSession.class); + + // Apply filters + cq.where( + cb.and( + cb.equal(root.get("appName"), appName), + cb.equal(root.get("userId"), userId))); + + // Order by update time descending + cq.orderBy(cb.desc(root.get("updateTime"))); + + // Execute query + List results = + em.createQuery(cq).setMaxResults(DEFAULT_FETCH_LIMIT).getResultList(); + + // Find app and user state + StorageAppState appState = em.find(StorageAppState.class, appName); + if (appState == null) { + appState = new StorageAppState(); + appState.setAppName(appName); + appState.setState(new HashMap<>()); + appState.setUpdateTime(Instant.now()); + } + + UserStateId userStateId = new UserStateId(appName, userId); + StorageUserState userState = em.find(StorageUserState.class, userStateId); + if (userState == null) { + userState = new StorageUserState(); + userState.setAppName(appName); + userState.setUserId(userId); + userState.setState(new HashMap<>()); + userState.setUpdateTime(Instant.now()); + } + + // Convert to domain objects + for (StorageSession result : results) { + // For listing, we don't need to load all events + result.setEvents(Collections.emptyList()); + sessions.add( + createDomainSessionWithMergedState(result, appState, userState)); + } + + return ListSessionsResponse.builder().sessions(sessions).build(); + }, + "Error listing sessions")) + .subscribeOn(Schedulers.io()); + } + + /** + * Lists the events within a specific session. + * + *

This implementation delegates to the 5-parameter version with default values: + * + *

    + *
  • pageSize: {@link #DEFAULT_FETCH_LIMIT} (1000) + *
  • pageToken: null (start from the beginning) + *
+ * + * @param appName The name of the application + * @param userId The identifier of the user + * @param sessionId The unique identifier of the session whose events are to be listed + * @return A {@link ListEventsResponse} containing a list of events and an optional token for + * retrieving the next page + * @throws SessionNotFoundException if the session doesn't exist + * @throws RuntimeException for other listing errors + */ + /** {@inheritDoc} */ + @Override + public Single listEvents(String appName, String userId, String sessionId) { + return listEvents(appName, userId, sessionId, DEFAULT_FETCH_LIMIT, null); + } + + /** + * Lists the events within a specific session with pagination support. + * + * @param appName The name of the application + * @param userId The identifier of the user + * @param sessionId The unique identifier of the session whose events are to be listed + * @param pageSize The maximum number of events to return in a single page + * @param pageToken A token for pagination, representing the offset + * @return A ListEventsResponse containing a list of events and an optional token for the next + * page + */ + public Single listEvents( + String appName, String userId, String sessionId, int pageSize, @Nullable String pageToken) { + return Single.fromCallable( + () -> + executeReadOnly( + em -> { + // Parse page token once with proper error handling + int offset = 0; + if (pageToken != null) { + try { + offset = Integer.parseInt(pageToken); + } catch (NumberFormatException e) { + logger.warn("Invalid page token: {}. Defaulting to offset 0.", pageToken); + } + } + + // Create query to find events + CriteriaBuilder cb = em.getCriteriaBuilder(); + CriteriaQuery cq = cb.createQuery(StorageEvent.class); + Root root = cq.from(StorageEvent.class); + + // Apply filters + cq.where( + cb.and( + cb.equal(root.get("appName"), appName), + cb.equal(root.get("userId"), userId), + cb.equal(root.get("sessionId"), sessionId))); + + // Order by timestamp ascending + cq.orderBy(cb.asc(root.get("timestamp"))); + + // Execute query with pagination + TypedQuery query = em.createQuery(cq); + if (offset > 0) { + query.setFirstResult(offset); + } + query.setMaxResults(pageSize); + + List results = query.getResultList(); + + // Convert to domain objects + List events = + results.stream() + .map(StorageEvent::toDomainEvent) + .collect(Collectors.toList()); + + // Calculate next page token + String nextPageToken = null; + if (results.size() >= pageSize) { + nextPageToken = String.valueOf(offset + pageSize); + } + + ListEventsResponse.Builder responseBuilder = + ListEventsResponse.builder().events(events); + if (nextPageToken != null) { + responseBuilder.nextPageToken(nextPageToken); + } + return responseBuilder.build(); + }, + "Error listing events")) + .subscribeOn(Schedulers.io()); + } + + @Override + public Single appendEvent(Session session, Event event) { + Objects.requireNonNull(session, "session cannot be null"); + Objects.requireNonNull(event, "event cannot be null"); + Objects.requireNonNull(session.appName(), "session.appName cannot be null"); + Objects.requireNonNull(session.userId(), "session.userId cannot be null"); + Objects.requireNonNull(session.id(), "session.id cannot be null"); + + return BaseSessionService.super + .appendEvent(session, event) + .flatMap( + updatedEvent -> + appendEvent(session.appName(), session.userId(), session.id(), event) + .map(dbSession -> updatedEvent)); + } + + /** + * Appends an event to a session identified by app name, user ID, and session ID. + * + * @param appName The name of the application + * @param userId The identifier of the user + * @param sessionId The unique identifier of the session + * @param event The event to append + * @return The updated session + */ + public Single appendEvent(String appName, String userId, String sessionId, Event event) { + return Single.fromCallable( + () -> + executeInTransaction( + em -> { + // Find session + SessionId id = new SessionId(appName, userId, sessionId); + StorageSession session = + em.find(StorageSession.class, id, LockModeType.PESSIMISTIC_WRITE); + if (session == null) { + throw new SessionNotFoundException( + String.format( + "Session not found: appName=%s, userId=%s, sessionId=%s", + appName, userId, sessionId)); + } + + // Find app state , create new entry in app_states table if not present + StorageAppState appState = em.find(StorageAppState.class, appName); + if (appState == null) { + appState = getOrCreateAppState(em, appName); + } + + // Find user state , create new entry in user_states table if not present + UserStateId userStateId = new UserStateId(appName, userId); + StorageUserState userState = em.find(StorageUserState.class, userStateId); + if (userState == null) { + userState = getOrCreateUserState(em, userStateId); + } + + // Process state delta if present in event + if (event.actions() != null && event.actions().stateDelta() != null) { + ConcurrentMap stateDelta = event.actions().stateDelta(); + + Map appStateDelta = new HashMap<>(); + Map userStateDelta = new HashMap<>(); + Map sessionStateDelta = new HashMap<>(); + + // Split delta by prefix + for (Map.Entry entry : stateDelta.entrySet()) { + String key = entry.getKey(); + if (key.startsWith(State.APP_PREFIX)) { + appStateDelta.put(key, entry.getValue()); + } else if (key.startsWith(State.USER_PREFIX)) { + userStateDelta.put(key, entry.getValue()); + } else if (!key.startsWith(State.TEMP_PREFIX)) { + sessionStateDelta.put(key, entry.getValue()); + } + } + + // Update app_states - state , update_time column + if (!appStateDelta.isEmpty()) { + appState.getState().putAll(appStateDelta); + appState.setUpdateTime(Instant.now()); + em.merge(appState); + } + + // Update user_states - state , update_time column + if (!userStateDelta.isEmpty()) { + userState.getState().putAll(userStateDelta); + userState.setUpdateTime(Instant.now()); + em.merge(userState); + } + + // Update session state + if (!sessionStateDelta.isEmpty()) { + session.getState().putAll(sessionStateDelta); + } + } + + // Create event entity + StorageEvent storageEvent = StorageEvent.fromDomainEvent(event, session); + session.addEvent(storageEvent); + + // Update session timestamp + session.setUpdateTime(Instant.now()); + + // Save changes + em.persist(storageEvent); + em.merge(session); + + // Convert to domain object with merged state + return createDomainSessionWithMergedState(session, appState, userState); + }, + "Error appending event")) + .subscribeOn(Schedulers.io()); + } + + @Override + public Completable deleteSession(String appName, String userId, String sessionId) { + return Completable.fromCallable( + () -> { + executeInTransaction( + em -> { + // Find session + SessionId id = new SessionId(appName, userId, sessionId); + StorageSession session = em.find(StorageSession.class, id); + if (session == null) { + throw new SessionNotFoundException( + String.format( + "Session not found: appName=%s, userId=%s, sessionId=%s", + appName, userId, sessionId)); + } + + // Delete session (cascade will delete events) + em.remove(session); + + return null; + }, + "Error deleting session"); + return null; + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Cleans up resources used by this service. This method should be called when the service is no + * longer needed. + */ + public void close() { + if (emf != null && emf.isOpen()) { + emf.close(); + } + } + + /** + * Executes a database operation within a transaction. + * + * @param operation The operation to execute + * @param errorMessage The error message to log on failure + * @param The return type + * @return The result of the operation + */ + private T executeInTransaction( + java.util.function.Function operation, String errorMessage) { + EntityManager em = emf.createEntityManager(); + try { + em.getTransaction().begin(); + T result = operation.apply(em); + em.getTransaction().commit(); + return result; + } catch (Exception e) { + if (em.getTransaction().isActive()) { + em.getTransaction().rollback(); + } + if (e instanceof SessionNotFoundException) { + throw e; + } + logger.error(errorMessage, e); + throw new RuntimeException(errorMessage, e); + } finally { + em.close(); + } + } + + /** + * Executes a read-only database operation without an explicit transaction. + * + * @param operation The operation to execute + * @param errorMessage The error message to log on failure + * @param The return type + * @return The result of the operation + */ + private T executeReadOnly( + java.util.function.Function operation, String errorMessage) { + EntityManager em = emf.createEntityManager(); + try { + return operation.apply(em); + } catch (Exception e) { + logger.error(errorMessage, e); + throw new RuntimeException(errorMessage, e); + } finally { + em.close(); + } + } + + /** + * Gets or creates an app state entity. + * + * @param em The EntityManager + * @param appName The application name + * @return The app state entity + */ + private StorageAppState getOrCreateAppState(EntityManager em, String appName) { + StorageAppState appState = em.find(StorageAppState.class, appName); + if (appState == null) { + appState = new StorageAppState(); + appState.setAppName(appName); + appState.setState(new HashMap<>()); + appState.setUpdateTime(Instant.now()); + try { + em.persist(appState); + em.flush(); + } catch (PersistenceException e) { + StorageAppState existingState = em.find(StorageAppState.class, appName); + if (existingState != null) { + return existingState; + } + throw e; + } + } + return appState; + } + + /** + * Gets or creates a user state entity. + * + * @param em The EntityManager + * @param userStateId The user state ID + * @return The user state entity + */ + private StorageUserState getOrCreateUserState(EntityManager em, UserStateId userStateId) { + StorageUserState userState = em.find(StorageUserState.class, userStateId); + if (userState == null) { + userState = new StorageUserState(); + userState.setAppName(userStateId.getAppName()); + userState.setUserId(userStateId.getUserId()); + userState.setState(new HashMap<>()); + userState.setUpdateTime(Instant.now()); + try { + em.persist(userState); + em.flush(); + } catch (PersistenceException e) { + StorageUserState existingState = em.find(StorageUserState.class, userStateId); + if (existingState != null) { + return existingState; + } + throw e; + } + } + return userState; + } + + private Session createDomainSessionWithMergedState( + StorageSession storage, StorageAppState appState, StorageUserState userState) { + // Merge state from all three tables: app -> user -> session + ConcurrentHashMap mergedState = new ConcurrentHashMap<>(); + + // 1. Add app state (lowest priority) + if (appState != null && appState.getState() != null) { + mergedState.putAll(appState.getState()); + } + + // 2. Add user state (medium priority, overwrites app state) + if (userState != null && userState.getState() != null) { + mergedState.putAll(userState.getState()); + } + + // 3. Add session state (highest priority, overwrites user and app state) + if (storage.getState() != null) { + mergedState.putAll(storage.getState()); + } + + // Convert storage entity to domain object with merged state + Session.Builder sessionBuilder = + Session.builder(storage.getId()) + .appName(storage.getAppName()) + .userId(storage.getUserId()) + .state(mergedState) + .lastUpdateTime(storage.getUpdateTime()); + + // Convert events if needed - use ArrayList for mutability + if (storage.getEvents() != null && !storage.getEvents().isEmpty()) { + List events = + storage.getEvents().stream() + .map(StorageEvent::toDomainEvent) + .collect(Collectors.toCollection(ArrayList::new)); + sessionBuilder.events(events); + } + + return sessionBuilder.build(); + } + + /** + * Extracts the dialect folder name from the Hibernate dialect class name. + * + * @param dialect The Hibernate dialect class name + * @return The simplified dialect name for folder lookup + */ + private String extractDialectFolderName(String dialect) { + if (dialect == null) { + return "postgresql"; // Default to PostgreSQL if unknown + } + + // Extract the database name from the dialect class + if (dialect.contains("PostgreSQL")) { + return "postgresql"; + } else if (dialect.contains("MySQL")) { + return "mysql"; + } else if (dialect.contains("H2")) { + return "h2"; + } else if (dialect.contains("SQLite")) { + return "sqlite"; + } else if (dialect.contains("Spanner")) { + return "spanner"; + } + + logger.warn("Unknown dialect '{}', defaulting to PostgreSQL", dialect); + return "postgresql"; + } +} diff --git a/core/src/main/java/com/google/adk/sessions/db/converter/EventActionsUserType.java b/core/src/main/java/com/google/adk/sessions/db/converter/EventActionsUserType.java new file mode 100644 index 000000000..7d3acf35a --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/db/converter/EventActionsUserType.java @@ -0,0 +1,167 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.converter; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.events.EventActions; +import java.io.Serializable; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; +import org.hibernate.engine.spi.SharedSessionContractImplementor; +import org.hibernate.usertype.UserType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class EventActionsUserType implements UserType { + + private static final Logger logger = LoggerFactory.getLogger(EventActionsUserType.class); + private static final ObjectMapper MAPPER = new ObjectMapper(); + + static { + MAPPER.findAndRegisterModules(); + } + + @Override + public int getSqlType() { + return Types.OTHER; + } + + @Override + public Class returnedClass() { + return EventActions.class; + } + + @Override + public boolean equals(EventActions x, EventActions y) { + if (x == y) { + return true; + } + if (x == null || y == null) { + return false; + } + return x.equals(y); + } + + @Override + public int hashCode(EventActions x) { + return x == null ? 0 : x.hashCode(); + } + + @Override + public EventActions nullSafeGet( + ResultSet rs, int position, SharedSessionContractImplementor session, Object owner) + throws SQLException { + + Object obj = rs.getObject(position); + if (obj == null) { + return EventActions.builder().build(); + } + + try { + if (obj instanceof org.postgresql.util.PGobject) { + String json = ((org.postgresql.util.PGobject) obj).getValue(); + if (json == null || json.isEmpty()) { + return EventActions.builder().build(); + } + return MAPPER.readValue(json, EventActions.class); + } else if (obj instanceof String) { + String json = (String) obj; + if (json.isEmpty()) { + return EventActions.builder().build(); + } + return MAPPER.readValue(json, EventActions.class); + } else if (obj instanceof java.sql.Clob) { + java.sql.Clob clob = (java.sql.Clob) obj; + String json = clob.getSubString(1, (int) clob.length()); + if (json == null || json.isEmpty()) { + return EventActions.builder().build(); + } + return MAPPER.readValue(json, EventActions.class); + } else { + logger.warn("Unexpected type from database: {}", obj.getClass().getName()); + return EventActions.builder().build(); + } + } catch (Exception e) { + logger.error("Error deserializing EventActions from database", e); + return EventActions.builder().build(); + } + } + + @Override + public void nullSafeSet( + PreparedStatement st, EventActions value, int index, SharedSessionContractImplementor session) + throws SQLException { + + if (value == null) { + st.setNull(index, Types.OTHER); + return; + } + + try { + String json = MAPPER.writeValueAsString(value); + + String databaseProductName = st.getConnection().getMetaData().getDatabaseProductName(); + + if (databaseProductName.toLowerCase().contains("postgres")) { + org.postgresql.util.PGobject pgo = new org.postgresql.util.PGobject(); + pgo.setType("jsonb"); + pgo.setValue(json); + st.setObject(index, pgo, Types.OTHER); + } else { + st.setString(index, json); + } + } catch (Exception e) { + logger.error("Error serializing EventActions to database", e); + throw new SQLException("Failed to convert EventActions to JSON", e); + } + } + + @Override + public EventActions deepCopy(EventActions value) { + if (value == null) { + return null; + } + try { + String json = MAPPER.writeValueAsString(value); + return MAPPER.readValue(json, EventActions.class); + } catch (Exception e) { + logger.error("Error deep copying EventActions", e); + return value; + } + } + + @Override + public boolean isMutable() { + return true; + } + + @Override + public Serializable disassemble(EventActions value) { + return (Serializable) deepCopy(value); + } + + @Override + public EventActions assemble(Serializable cached, Object owner) { + return (EventActions) cached; + } + + @Override + public EventActions replace(EventActions detached, EventActions managed, Object owner) { + return deepCopy(detached); + } +} diff --git a/core/src/main/java/com/google/adk/sessions/db/converter/JsonUserType.java b/core/src/main/java/com/google/adk/sessions/db/converter/JsonUserType.java new file mode 100644 index 000000000..3c4bccab9 --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/db/converter/JsonUserType.java @@ -0,0 +1,174 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.converter; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.Serializable; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; +import java.util.HashMap; +import java.util.Map; +import org.hibernate.HibernateException; +import org.hibernate.engine.spi.SharedSessionContractImplementor; +import org.hibernate.usertype.UserType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class JsonUserType implements UserType> { + + private static final Logger logger = LoggerFactory.getLogger(JsonUserType.class); + private static final ObjectMapper MAPPER = new ObjectMapper(); + + static { + MAPPER.findAndRegisterModules(); + } + + @Override + public int getSqlType() { + return Types.OTHER; + } + + @Override + public Class> returnedClass() { + return (Class) Map.class; + } + + @Override + public boolean equals(Map x, Map y) { + if (x == y) { + return true; + } + if (x == null || y == null) { + return false; + } + return x.equals(y); + } + + @Override + public int hashCode(Map x) { + return x == null ? 0 : x.hashCode(); + } + + @Override + public Map nullSafeGet( + ResultSet rs, int position, SharedSessionContractImplementor session, Object owner) + throws SQLException { + + Object obj = rs.getObject(position); + if (obj == null) { + return new HashMap<>(); + } + + try { + if (obj instanceof org.postgresql.util.PGobject) { + String json = ((org.postgresql.util.PGobject) obj).getValue(); + if (json == null || json.isEmpty()) { + return new HashMap<>(); + } + return MAPPER.readValue(json, new TypeReference>() {}); + } else if (obj instanceof String) { + String json = (String) obj; + if (json.isEmpty()) { + return new HashMap<>(); + } + return MAPPER.readValue(json, new TypeReference>() {}); + } else if (obj instanceof java.sql.Clob) { + java.sql.Clob clob = (java.sql.Clob) obj; + String json = clob.getSubString(1, (int) clob.length()); + if (json == null || json.isEmpty()) { + return new HashMap<>(); + } + return MAPPER.readValue(json, new TypeReference>() {}); + } else { + logger.warn("Unexpected type from database: {}", obj.getClass().getName()); + return new HashMap<>(); + } + } catch (Exception e) { + logger.error("Error deserializing JSON from database", e); + return new HashMap<>(); + } + } + + @Override + public void nullSafeSet( + PreparedStatement st, + Map value, + int index, + SharedSessionContractImplementor session) + throws SQLException { + + if (value == null || value.isEmpty()) { + st.setNull(index, Types.OTHER); + return; + } + + try { + String json = MAPPER.writeValueAsString(value); + + String databaseProductName = st.getConnection().getMetaData().getDatabaseProductName(); + + if (databaseProductName.toLowerCase().contains("postgres")) { + org.postgresql.util.PGobject pgo = new org.postgresql.util.PGobject(); + pgo.setType("jsonb"); + pgo.setValue(json); + st.setObject(index, pgo, Types.OTHER); + } else { + st.setString(index, json); + } + } catch (Exception e) { + logger.error("Error serializing JSON to database", e); + throw new SQLException("Failed to convert Map to JSON", e); + } + } + + @Override + public Map deepCopy(Map value) { + if (value == null) { + return null; + } + try { + String json = MAPPER.writeValueAsString(value); + return MAPPER.readValue(json, new TypeReference>() {}); + } catch (JsonProcessingException e) { + throw new HibernateException("Failed to deep copy map state", e); + } + } + + @Override + public boolean isMutable() { + return true; + } + + @Override + public Serializable disassemble(Map value) { + return (Serializable) deepCopy(value); + } + + @Override + public Map assemble(Serializable cached, Object owner) { + return (Map) cached; + } + + @Override + public Map replace( + Map detached, Map managed, Object owner) { + return deepCopy(detached); + } +} diff --git a/core/src/main/java/com/google/adk/sessions/db/entity/EventId.java b/core/src/main/java/com/google/adk/sessions/db/entity/EventId.java new file mode 100644 index 000000000..2f0f82a8a --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/db/entity/EventId.java @@ -0,0 +1,85 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.entity; + +import java.io.Serializable; +import java.util.Objects; + +/** Composite key for the StorageEvent entity, consisting of id, appName, userId, and sessionId. */ +public class EventId implements Serializable { + private String id; + private String appName; + private String userId; + private String sessionId; + + // Default constructor required by JPA + public EventId() {} + + public EventId(String id, String appName, String userId, String sessionId) { + this.id = id; + this.appName = appName; + this.userId = userId; + this.sessionId = sessionId; + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getSessionId() { + return sessionId; + } + + public void setSessionId(String sessionId) { + this.sessionId = sessionId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + EventId eventId = (EventId) o; + return Objects.equals(id, eventId.id) + && Objects.equals(appName, eventId.appName) + && Objects.equals(userId, eventId.userId) + && Objects.equals(sessionId, eventId.sessionId); + } + + @Override + public int hashCode() { + return Objects.hash(id, appName, userId, sessionId); + } +} diff --git a/core/src/main/java/com/google/adk/sessions/db/entity/SessionId.java b/core/src/main/java/com/google/adk/sessions/db/entity/SessionId.java new file mode 100644 index 000000000..4400c0883 --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/db/entity/SessionId.java @@ -0,0 +1,74 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.entity; + +import java.io.Serializable; +import java.util.Objects; + +/** Composite key for the StorageSession entity, consisting of appName, userId, and id. */ +public class SessionId implements Serializable { + private String appName; + private String userId; + private String id; + + // Default constructor required by JPA + public SessionId() {} + + public SessionId(String appName, String userId, String id) { + this.appName = appName; + this.userId = userId; + this.id = id; + } + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SessionId sessionId = (SessionId) o; + return Objects.equals(appName, sessionId.appName) + && Objects.equals(userId, sessionId.userId) + && Objects.equals(id, sessionId.id); + } + + @Override + public int hashCode() { + return Objects.hash(appName, userId, id); + } +} diff --git a/core/src/main/java/com/google/adk/sessions/db/entity/StorageAppState.java b/core/src/main/java/com/google/adk/sessions/db/entity/StorageAppState.java new file mode 100644 index 000000000..a04f79e8a --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/db/entity/StorageAppState.java @@ -0,0 +1,73 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.entity; + +import com.google.adk.sessions.db.converter.JsonUserType; +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Table; +import java.time.Instant; +import java.util.Map; +import org.hibernate.annotations.Type; + +/** + * Entity for storing application-level state in the database. This is mapped to the "app_states" + * table. + */ +@Entity +@Table(name = "app_states") +public class StorageAppState { + + @Id + @Column(name = "app_name", length = 128) + private String appName; + + @Column(name = "state") + @Type(JsonUserType.class) + private Map state; + + @Column(name = "update_time") + private Instant updateTime; + + // Default constructor + public StorageAppState() {} + + // Getters and setters + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public Map getState() { + return state; + } + + public void setState(Map state) { + this.state = state; + } + + public Instant getUpdateTime() { + return updateTime; + } + + public void setUpdateTime(Instant updateTime) { + this.updateTime = updateTime; + } +} diff --git a/core/src/main/java/com/google/adk/sessions/db/entity/StorageEvent.java b/core/src/main/java/com/google/adk/sessions/db/entity/StorageEvent.java new file mode 100644 index 000000000..6cc90ffe3 --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/db/entity/StorageEvent.java @@ -0,0 +1,506 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.entity; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.sessions.db.converter.EventActionsUserType; +import com.google.adk.sessions.db.converter.JsonUserType; +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.FetchType; +import jakarta.persistence.Id; +import jakarta.persistence.IdClass; +import jakarta.persistence.JoinColumn; +import jakarta.persistence.JoinColumns; +import jakarta.persistence.ManyToOne; +import jakarta.persistence.Table; +import java.time.Instant; +import java.util.Map; +import java.util.Optional; +import org.hibernate.annotations.Type; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Entity for storing event data in the database. This is mapped to the "events" table. */ +@Entity +@Table(name = "events") +@IdClass(EventId.class) +public class StorageEvent { + + private static final Logger logger = LoggerFactory.getLogger(StorageEvent.class); + + @Id + @Column(name = "id", length = 128) + private String id; + + @Id + @Column(name = "app_name", length = 128) + private String appName; + + @Id + @Column(name = "user_id", length = 128) + private String userId; + + @Id + @Column(name = "session_id", length = 128) + private String sessionId; + + @Column(name = "invocation_id", length = 256) + private String invocationId; + + @Column(name = "author", length = 256) + private String author; + + @Column(name = "actions") + @Type(EventActionsUserType.class) + private EventActions actions; + + @Column(name = "timestamp") + private Instant timestamp; + + @Column(name = "content") + @Type(JsonUserType.class) + private Map content; + + @Column(name = "grounding_metadata") + @Type(JsonUserType.class) + private Map groundingMetadata; + + @Column(name = "custom_metadata") + @Type(JsonUserType.class) + private Map customMetadata; + + @Column(name = "usage_metadata") + @Type(JsonUserType.class) + private Map usageMetadata; + + @Column(name = "citation_metadata") + @Type(JsonUserType.class) + private Map citationMetadata; + + @Column(name = "partial") + private Boolean partial; + + @Column(name = "turn_complete") + private Boolean turnComplete; + + @Column(name = "error_code", length = 256) + private String errorCode; + + @Column(name = "error_message", columnDefinition = "TEXT") + private String errorMessage; + + @Column(name = "interrupted") + private Boolean interrupted; + + @Column(name = "branch", length = 256) + private String branch; + + @Column(name = "long_running_tool_ids_json", columnDefinition = "TEXT") + private String longRunningToolIdsJson; + + @Column(name = "input_transcription") + @Type(JsonUserType.class) + private Map inputTranscription; + + @Column(name = "output_transcription") + @Type(JsonUserType.class) + private Map outputTranscription; + + @ManyToOne(fetch = FetchType.LAZY) + @JoinColumns({ + @JoinColumn( + name = "app_name", + referencedColumnName = "app_name", + insertable = false, + updatable = false), + @JoinColumn( + name = "user_id", + referencedColumnName = "user_id", + insertable = false, + updatable = false), + @JoinColumn( + name = "session_id", + referencedColumnName = "id", + insertable = false, + updatable = false) + }) + private StorageSession session; + + // Default constructor + public StorageEvent() {} + + // Getters and setters + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getSessionId() { + return sessionId; + } + + public void setSessionId(String sessionId) { + this.sessionId = sessionId; + } + + public String getInvocationId() { + return invocationId; + } + + public void setInvocationId(String invocationId) { + this.invocationId = invocationId; + } + + public String getAuthor() { + return author; + } + + public void setAuthor(String author) { + this.author = author; + } + + public EventActions getActions() { + return actions; + } + + public void setActions(EventActions actions) { + this.actions = actions; + } + + public Instant getTimestamp() { + return timestamp; + } + + public void setTimestamp(Instant timestamp) { + this.timestamp = timestamp; + } + + public Map getContent() { + return content; + } + + public void setContent(Map content) { + this.content = content; + } + + public Map getGroundingMetadata() { + return groundingMetadata; + } + + public void setGroundingMetadata(Map groundingMetadata) { + this.groundingMetadata = groundingMetadata; + } + + public Map getCustomMetadata() { + return customMetadata; + } + + public void setCustomMetadata(Map customMetadata) { + this.customMetadata = customMetadata; + } + + public Map getUsageMetadata() { + return usageMetadata; + } + + public void setUsageMetadata(Map usageMetadata) { + this.usageMetadata = usageMetadata; + } + + public Map getCitationMetadata() { + return citationMetadata; + } + + public void setCitationMetadata(Map citationMetadata) { + this.citationMetadata = citationMetadata; + } + + public Boolean getPartial() { + return partial; + } + + public void setPartial(Boolean partial) { + this.partial = partial; + } + + public Boolean getTurnComplete() { + return turnComplete; + } + + public void setTurnComplete(Boolean turnComplete) { + this.turnComplete = turnComplete; + } + + public String getErrorCode() { + return errorCode; + } + + public void setErrorCode(String errorCode) { + this.errorCode = errorCode; + } + + public String getErrorMessage() { + return errorMessage; + } + + public void setErrorMessage(String errorMessage) { + this.errorMessage = errorMessage; + } + + public Boolean getInterrupted() { + return interrupted; + } + + public void setInterrupted(Boolean interrupted) { + this.interrupted = interrupted; + } + + public String getBranch() { + return branch; + } + + public void setBranch(String branch) { + this.branch = branch; + } + + public String getLongRunningToolIdsJson() { + return longRunningToolIdsJson; + } + + public void setLongRunningToolIdsJson(String longRunningToolIdsJson) { + this.longRunningToolIdsJson = longRunningToolIdsJson; + } + + public Map getInputTranscription() { + return inputTranscription; + } + + public void setInputTranscription(Map inputTranscription) { + this.inputTranscription = inputTranscription; + } + + public Map getOutputTranscription() { + return outputTranscription; + } + + public void setOutputTranscription(Map outputTranscription) { + this.outputTranscription = outputTranscription; + } + + public StorageSession getSession() { + return session; + } + + public void setSession(StorageSession session) { + this.session = session; + } + + /** + * Converts a storage entity to a domain model Event. + * + * @return A domain Event object created from this storage entity + */ + public Event toDomainEvent() { + Event.Builder builder = + Event.builder().id(this.id).invocationId(this.invocationId).author(this.author); + + if (this.actions != null) { + builder.actions(this.actions); + } + + // Handle Optional fields + builder.partial(Optional.ofNullable(this.partial)); + builder.turnComplete(Optional.ofNullable(this.turnComplete)); + builder.interrupted(Optional.ofNullable(this.interrupted)); + + // Set timestamp (long expected) + if (this.timestamp != null) { + builder.timestamp(this.timestamp.toEpochMilli()); + } + + // Handle content map to Content object + if (this.content != null) { + try { + com.google.genai.types.Content contentObj = deserializeContent(this.content); + builder.content(Optional.of(contentObj)); + } catch (Exception e) { + logger.warn("Failed to deserialize content for event {}: {}", this.id, e.getMessage(), e); + builder.content(Optional.empty()); + } + } + + if (this.errorCode != null && !this.errorCode.isEmpty()) { + builder.errorCode(Optional.of(new com.google.genai.types.FinishReason(this.errorCode))); + } + + if (this.errorMessage != null) { + // Use explicit Optional.of() to avoid ambiguity + builder.errorMessage(Optional.of(this.errorMessage)); + } + + if (this.groundingMetadata != null && !this.groundingMetadata.isEmpty()) { + try { + String json = + com.google.adk.JsonBaseModel.getMapper().writeValueAsString(this.groundingMetadata); + com.google.genai.types.GroundingMetadata metadata = + com.google.adk.JsonBaseModel.getMapper() + .readValue(json, com.google.genai.types.GroundingMetadata.class); + builder.groundingMetadata(Optional.of(metadata)); + } catch (Exception e) { + logger.warn( + "Failed to deserialize grounding metadata for event {}: {}", + this.id, + e.getMessage(), + e); + builder.groundingMetadata(Optional.empty()); + } + } + + // Set branch from storage + builder.branch(Optional.ofNullable(this.branch)); + + return builder.build(); + } + + /** + * Creates a StorageEvent entity from a domain Event model. + * + * @param event The domain Event to convert + * @param session The parent StorageSession + * @return A StorageEvent entity + */ + public static StorageEvent fromDomainEvent(Event event, StorageSession session) { + StorageEvent storageEvent = new StorageEvent(); + storageEvent.setId(event.id()); + storageEvent.setAppName(session.getAppName()); + storageEvent.setUserId(session.getUserId()); + storageEvent.setSessionId(session.getId()); + storageEvent.setSession(session); + storageEvent.setInvocationId(event.invocationId()); + storageEvent.setAuthor(event.author()); + + storageEvent.setActions(event.actions()); + + // Convert long timestamp to Instant + storageEvent.setTimestamp(Instant.ofEpochMilli(event.timestamp())); + + // Handle content - Convert Content to Map if present + event + .content() + .ifPresent( + content -> { + try { + String json = com.google.adk.JsonBaseModel.getMapper().writeValueAsString(content); + @SuppressWarnings("unchecked") + Map contentMap = + com.google.adk.JsonBaseModel.getMapper() + .readValue( + json, + new com.fasterxml.jackson.core.type.TypeReference< + Map>() {}); + storageEvent.setContent(contentMap); + } catch (Exception e) { + logger.warn( + "Failed to serialize content for event {}: {}", event.id(), e.getMessage(), e); + } + }); + + event + .groundingMetadata() + .ifPresent( + metadata -> { + try { + String json = com.google.adk.JsonBaseModel.getMapper().writeValueAsString(metadata); + @SuppressWarnings("unchecked") + Map metadataMap = + com.google.adk.JsonBaseModel.getMapper() + .readValue( + json, + new com.fasterxml.jackson.core.type.TypeReference< + Map>() {}); + storageEvent.setGroundingMetadata(metadataMap); + } catch (Exception e) { + logger.warn( + "Failed to serialize grounding metadata for event {}: {}", + event.id(), + e.getMessage(), + e); + } + }); + + // Handle Boolean fields that are now Optional + storageEvent.setPartial(event.partial().orElse(null)); + storageEvent.setTurnComplete(event.turnComplete().orElse(null)); + storageEvent.setInterrupted(event.interrupted().orElse(null)); + + // Handle error code - FinishReason enum to String + event + .errorCode() + .ifPresent( + code -> { + // We can't directly use name(), so store the toString() result + storageEvent.setErrorCode(code.toString()); + }); + + // Handle error message + event.errorMessage().ifPresent(storageEvent::setErrorMessage); + + // Set the parent session + storageEvent.setSession(session); + return storageEvent; + } + + /** + * Deserializes a Content object from a Map. + * + * @param contentMap The map containing content data + * @return A Content object + */ + private static com.google.genai.types.Content deserializeContent(Map contentMap) { + if (contentMap == null) { + return null; + } + + try { + return com.google.adk.JsonBaseModel.getMapper() + .convertValue(contentMap, com.google.genai.types.Content.class); + } catch (IllegalArgumentException e) { + logger.warn("Failed to deserialize content from map: {}", e.getMessage(), e); + return null; + } + } +} diff --git a/core/src/main/java/com/google/adk/sessions/db/entity/StorageSession.java b/core/src/main/java/com/google/adk/sessions/db/entity/StorageSession.java new file mode 100644 index 000000000..16ccbd2c8 --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/db/entity/StorageSession.java @@ -0,0 +1,132 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.entity; + +import com.google.adk.sessions.db.converter.JsonUserType; +import jakarta.persistence.CascadeType; +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.IdClass; +import jakarta.persistence.OneToMany; +import jakarta.persistence.Table; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.hibernate.annotations.Type; + +/** Entity for storing session data in the database. This is mapped to the "sessions" table. */ +@Entity +@Table(name = "sessions") +@IdClass(SessionId.class) +public class StorageSession { + + @Id + @Column(name = "app_name", length = 128) + private String appName; + + @Id + @Column(name = "user_id", length = 128) + private String userId; + + @Id + @Column(name = "id", length = 128) + private String id; + + @Column(name = "state") + @Type(JsonUserType.class) + private Map state; + + @Column(name = "create_time") + private Instant createTime; + + @Column(name = "update_time") + private Instant updateTime; + + @OneToMany(mappedBy = "session", cascade = CascadeType.ALL, orphanRemoval = true) + private List events = new ArrayList<>(); + + // Default constructor + public StorageSession() {} + + // Getter and setter methods + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public Map getState() { + return state; + } + + public void setState(Map state) { + this.state = state; + } + + public Instant getCreateTime() { + return createTime; + } + + public void setCreateTime(Instant createTime) { + this.createTime = createTime; + } + + public Instant getUpdateTime() { + return updateTime; + } + + public void setUpdateTime(Instant updateTime) { + this.updateTime = updateTime; + } + + public List getEvents() { + return events; + } + + public void setEvents(List events) { + this.events = events; + } + + /** + * Adds an event to this session. + * + * @param event The event to add + */ + public void addEvent(StorageEvent event) { + events.add(event); + event.setSession(this); + } +} diff --git a/core/src/main/java/com/google/adk/sessions/db/entity/StorageUserState.java b/core/src/main/java/com/google/adk/sessions/db/entity/StorageUserState.java new file mode 100644 index 000000000..588de48e9 --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/db/entity/StorageUserState.java @@ -0,0 +1,86 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.entity; + +import com.google.adk.sessions.db.converter.JsonUserType; +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.IdClass; +import jakarta.persistence.Table; +import java.time.Instant; +import java.util.Map; +import org.hibernate.annotations.Type; + +/** + * Entity for storing user-level state in the database. This is mapped to the "user_states" table. + */ +@Entity +@Table(name = "user_states") +@IdClass(UserStateId.class) +public class StorageUserState { + + @Id + @Column(name = "app_name", length = 128) + private String appName; + + @Id + @Column(name = "user_id", length = 128) + private String userId; + + @Column(name = "state") + @Type(JsonUserType.class) + private Map state; + + @Column(name = "update_time") + private Instant updateTime; + + // Default constructor + public StorageUserState() {} + + // Getters and setters + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public Map getState() { + return state; + } + + public void setState(Map state) { + this.state = state; + } + + public Instant getUpdateTime() { + return updateTime; + } + + public void setUpdateTime(Instant updateTime) { + this.updateTime = updateTime; + } +} diff --git a/core/src/main/java/com/google/adk/sessions/db/entity/UserStateId.java b/core/src/main/java/com/google/adk/sessions/db/entity/UserStateId.java new file mode 100644 index 000000000..fe1e66b93 --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/db/entity/UserStateId.java @@ -0,0 +1,62 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.entity; + +import java.io.Serializable; +import java.util.Objects; + +/** Composite key for the StorageUserState entity, consisting of appName and userId. */ +public class UserStateId implements Serializable { + private String appName; + private String userId; + + // Default constructor required by JPA + public UserStateId() {} + + public UserStateId(String appName, String userId) { + this.appName = appName; + this.userId = userId; + } + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + UserStateId that = (UserStateId) o; + return Objects.equals(appName, that.appName) && Objects.equals(userId, that.userId); + } + + @Override + public int hashCode() { + return Objects.hash(appName, userId); + } +} diff --git a/core/src/main/java/com/google/adk/sessions/db/util/DatabaseDialectDetector.java b/core/src/main/java/com/google/adk/sessions/db/util/DatabaseDialectDetector.java new file mode 100644 index 000000000..9e67a80bf --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/db/util/DatabaseDialectDetector.java @@ -0,0 +1,81 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.util; + +/** Utility class for detecting the appropriate Hibernate dialect based on the database URL. */ +public class DatabaseDialectDetector { + + private DatabaseDialectDetector() { + // Utility class, no instantiation + } + + /** + * Detects the appropriate Hibernate dialect based on the database URL. + * + * @param dbUrl The database URL + * @return The Hibernate dialect class name + * @throws IllegalArgumentException If the database type is not supported + */ + public static String detectDialect(String dbUrl) { + if (dbUrl == null) { + throw new IllegalArgumentException("Database URL cannot be null"); + } + + // PostgreSQL + if (dbUrl.startsWith("jdbc:postgresql:")) { + return "org.hibernate.dialect.PostgreSQLDialect"; + } + + // MySQL + if (dbUrl.startsWith("jdbc:mysql:")) { + return "org.hibernate.dialect.MySQLDialect"; + } + + // SQLite + if (dbUrl.startsWith("jdbc:sqlite:")) { + return "org.hibernate.dialect.SQLiteDialect"; + } + + // H2 Database + if (dbUrl.startsWith("jdbc:h2:")) { + return "org.hibernate.dialect.H2Dialect"; + } + + // Cloud Spanner + if (dbUrl.startsWith("jdbc:cloudspanner:")) { + return "com.google.cloud.spanner.hibernate.SpannerDialect"; + } + + throw new IllegalArgumentException("Unsupported database URL: " + maskConnectionUrl(dbUrl)); + } + + /** + * Masks sensitive information in the database connection URL for logging purposes. + * + * @param url The database URL to mask + * @return A masked version of the URL + */ + public static String maskConnectionUrl(String url) { + if (url == null) { + return null; + } + + String result = url; + result = result.replaceAll(":([^:/@]+)@", "password:****@"); + result = result.replaceAll("password=([^&]*)", "password=****"); + return result; + } +} diff --git a/core/src/main/java/com/google/adk/sessions/db/util/EntityManagerFactoryProvider.java b/core/src/main/java/com/google/adk/sessions/db/util/EntityManagerFactoryProvider.java new file mode 100644 index 000000000..bf562dfa1 --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/db/util/EntityManagerFactoryProvider.java @@ -0,0 +1,83 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.util; + +import jakarta.persistence.EntityManagerFactory; +import jakarta.persistence.Persistence; +import java.util.HashMap; +import java.util.Map; + +/** + * Provider for creating and managing an EntityManagerFactory for the DatabaseSessionService. Uses + * the singleton pattern to ensure only one EntityManagerFactory exists per persistence unit. + */ +public class EntityManagerFactoryProvider { + + private static final String PERSISTENCE_UNIT_NAME = "adk-sessions"; + + /** + * Creates an EntityManagerFactory with the specified database URL and configuration. + * + * @param dbUrl The database URL to connect to + * @param properties Additional properties for the EntityManagerFactory + * @return A configured EntityManagerFactory + */ + public static EntityManagerFactory createEntityManagerFactory( + String dbUrl, Map properties) { + + Map config = new HashMap<>(properties); + + // Set required properties if not already provided + if (!config.containsKey("jakarta.persistence.jdbc.url")) { + config.put("jakarta.persistence.jdbc.url", dbUrl); + } + + if (!config.containsKey("hibernate.dialect")) { + config.put("hibernate.dialect", DatabaseDialectDetector.detectDialect(dbUrl)); + } + + // Set default credentials for H2 if not provided + String dialect = (String) config.get("hibernate.dialect"); + if ("org.hibernate.dialect.H2Dialect".equals(dialect)) { + if (!config.containsKey("jakarta.persistence.jdbc.user")) { + config.put("jakarta.persistence.jdbc.user", "sa"); + } + if (!config.containsKey("jakarta.persistence.jdbc.password")) { + config.put("jakarta.persistence.jdbc.password", ""); + } + } + + // Set default schema generation mode if not specified + // Use validate mode by default to work with Flyway-managed schema + if (!config.containsKey("hibernate.hbm2ddl.auto")) { + config.put("hibernate.hbm2ddl.auto", "validate"); + } + + // Set default connection pooling properties + if (!config.containsKey("hibernate.connection.provider_class")) { + config.put( + "hibernate.connection.provider_class", + "org.hibernate.hikaricp.internal.HikariCPConnectionProvider"); + // Hibernate expects String values in configuration map; provide numbers as strings + config.put("hibernate.hikari.minimumIdle", "5"); + config.put("hibernate.hikari.maximumPoolSize", "20"); + config.put("hibernate.hikari.idleTimeout", "30000"); + } + + // Create the EntityManagerFactory + return Persistence.createEntityManagerFactory(PERSISTENCE_UNIT_NAME, config); + } +} diff --git a/core/src/main/resources/META-INF/persistence.xml b/core/src/main/resources/META-INF/persistence.xml new file mode 100644 index 000000000..a7c08d663 --- /dev/null +++ b/core/src/main/resources/META-INF/persistence.xml @@ -0,0 +1,29 @@ + + + + Persistence unit for ADK Session Storage + + com.google.adk.sessions.db.entity.StorageSession + com.google.adk.sessions.db.entity.StorageEvent + com.google.adk.sessions.db.entity.StorageAppState + com.google.adk.sessions.db.entity.StorageUserState + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/core/src/main/resources/db/migration/h2/V1__Initial_schema.sql b/core/src/main/resources/db/migration/h2/V1__Initial_schema.sql new file mode 100644 index 000000000..6553bb9fd --- /dev/null +++ b/core/src/main/resources/db/migration/h2/V1__Initial_schema.sql @@ -0,0 +1,75 @@ +-- V1__Initial_schema.sql for H2 Database +-- Initial database schema for ADK DatabaseSessionService +-- This represents the baseline schema for all database session operations + +-- Create sessions table +CREATE TABLE IF NOT EXISTS sessions ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + id VARCHAR(128) NOT NULL, + state CLOB, + create_time TIMESTAMP(6), + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id, id) +); + +-- Create events table +CREATE TABLE IF NOT EXISTS events ( + id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + author VARCHAR(256), + actions CLOB, + timestamp TIMESTAMP(6), + content CLOB, + grounding_metadata CLOB, + custom_metadata CLOB, + usage_metadata CLOB, + citation_metadata CLOB, + partial BOOLEAN, + turn_complete BOOLEAN, + error_code VARCHAR(256), + error_message CLOB, + interrupted BOOLEAN, + branch VARCHAR(256), + long_running_tool_ids_json CLOB, + input_transcription CLOB, + output_transcription CLOB, + finish_reason VARCHAR(256), + avg_logprobs DOUBLE, + model_version VARCHAR(256), + PRIMARY KEY (id, app_name, user_id, session_id), + FOREIGN KEY (app_name, user_id, session_id) + REFERENCES sessions(app_name, user_id, id) + ON DELETE CASCADE +); + +-- Create app states table +CREATE TABLE IF NOT EXISTS app_states ( + app_name VARCHAR(128) NOT NULL, + state CLOB, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name) +); + +-- Create user states table +CREATE TABLE IF NOT EXISTS user_states ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state CLOB, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) +); + +-- Add indexes to improve query performance + +-- Index for looking up sessions by app_name and user_id +CREATE INDEX IF NOT EXISTS idx_sessions_app_user ON sessions(app_name, user_id); + +-- Index for looking up events by session +CREATE INDEX IF NOT EXISTS idx_events_session ON events(app_name, user_id, session_id); + +-- Index for sorting events by timestamp +CREATE INDEX IF NOT EXISTS idx_events_timestamp ON events(timestamp); \ No newline at end of file diff --git a/core/src/main/resources/db/migration/mysql/V1__Initial_schema.sql b/core/src/main/resources/db/migration/mysql/V1__Initial_schema.sql new file mode 100644 index 000000000..63982c365 --- /dev/null +++ b/core/src/main/resources/db/migration/mysql/V1__Initial_schema.sql @@ -0,0 +1,77 @@ +-- V1__Initial_schema.sql for MySQL +-- Initial database schema for ADK DatabaseSessionService +-- This represents the baseline schema for all database session operations + +-- Create sessions table +CREATE TABLE IF NOT EXISTS sessions ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + id VARCHAR(128) NOT NULL, + state LONGTEXT, + create_time TIMESTAMP(6), + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id, id) +); + +-- Create events table +CREATE TABLE IF NOT EXISTS events ( + id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + author VARCHAR(256), + actions LONGTEXT, + timestamp TIMESTAMP(6), + content LONGTEXT, + grounding_metadata LONGTEXT, + custom_metadata LONGTEXT, + usage_metadata LONGTEXT, + citation_metadata LONGTEXT, + partial BOOLEAN, + turn_complete BOOLEAN, + error_code VARCHAR(256), + error_message TEXT, + interrupted BOOLEAN, + branch VARCHAR(256), + long_running_tool_ids_json TEXT, + input_transcription LONGTEXT, + output_transcription LONGTEXT, + finish_reason VARCHAR(256), + avg_logprobs DOUBLE, + model_version VARCHAR(256), + PRIMARY KEY (id, app_name, user_id, session_id), + FOREIGN KEY (app_name, user_id, session_id) + REFERENCES sessions(app_name, user_id, id) + ON DELETE CASCADE +); + +-- Create app states table +CREATE TABLE IF NOT EXISTS app_states ( + app_name VARCHAR(128) NOT NULL, + state LONGTEXT, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name) +); + +-- Create user states table +CREATE TABLE IF NOT EXISTS user_states ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state LONGTEXT, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) +); + +-- Add indexes to improve query performance +-- Note: CREATE INDEX IF NOT EXISTS requires MySQL 8.0.32+ +-- Using simple CREATE INDEX for broader compatibility + +-- Index for looking up sessions by app_name and user_id +CREATE INDEX idx_sessions_app_user ON sessions(app_name, user_id); + +-- Index for looking up events by session +CREATE INDEX idx_events_session ON events(app_name, user_id, session_id); + +-- Index for sorting events by timestamp +CREATE INDEX idx_events_timestamp ON events(timestamp); \ No newline at end of file diff --git a/core/src/main/resources/db/migration/postgresql/V1__Initial_schema.sql b/core/src/main/resources/db/migration/postgresql/V1__Initial_schema.sql new file mode 100644 index 000000000..f192220fd --- /dev/null +++ b/core/src/main/resources/db/migration/postgresql/V1__Initial_schema.sql @@ -0,0 +1,75 @@ +-- V1__Initial_schema.sql for PostgreSQL +-- Initial database schema for ADK DatabaseSessionService +-- This represents the baseline schema for all database session operations + +-- Create sessions table +CREATE TABLE IF NOT EXISTS sessions ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + id VARCHAR(128) NOT NULL, + state JSONB, + create_time TIMESTAMP, + update_time TIMESTAMP, + PRIMARY KEY (app_name, user_id, id) +); + +-- Create events table +CREATE TABLE IF NOT EXISTS events ( + id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + author VARCHAR(256), + actions JSONB, + timestamp TIMESTAMP, + content JSONB, + grounding_metadata JSONB, + custom_metadata JSONB, + usage_metadata JSONB, + citation_metadata JSONB, + partial BOOLEAN, + turn_complete BOOLEAN, + error_code VARCHAR(256), + error_message TEXT, + interrupted BOOLEAN, + branch VARCHAR(256), + long_running_tool_ids_json TEXT, + input_transcription JSONB, + output_transcription JSONB, + finish_reason VARCHAR(256), + avg_logprobs DOUBLE PRECISION, + model_version VARCHAR(256), + PRIMARY KEY (id, app_name, user_id, session_id), + FOREIGN KEY (app_name, user_id, session_id) + REFERENCES sessions(app_name, user_id, id) + ON DELETE CASCADE +); + +-- Create app states table +CREATE TABLE IF NOT EXISTS app_states ( + app_name VARCHAR(128) NOT NULL, + state JSONB, + update_time TIMESTAMP, + PRIMARY KEY (app_name) +); + +-- Create user states table +CREATE TABLE IF NOT EXISTS user_states ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB, + update_time TIMESTAMP, + PRIMARY KEY (app_name, user_id) +); + +-- Add indexes to improve query performance + +-- Index for looking up sessions by app_name and user_id +CREATE INDEX IF NOT EXISTS idx_sessions_app_user ON sessions(app_name, user_id); + +-- Index for looking up events by session +CREATE INDEX IF NOT EXISTS idx_events_session ON events(app_name, user_id, session_id); + +-- Index for sorting events by timestamp +CREATE INDEX IF NOT EXISTS idx_events_timestamp ON events(timestamp); \ No newline at end of file diff --git a/core/src/test/java/com/google/adk/sessions/ConcurrentSessionOperationsTest.java b/core/src/test/java/com/google/adk/sessions/ConcurrentSessionOperationsTest.java new file mode 100644 index 000000000..b8c1bcfcb --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/ConcurrentSessionOperationsTest.java @@ -0,0 +1,429 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ConcurrentSessionOperationsTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:concurrent_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "concurrent-test-app"; + private static final String TEST_USER_ID = "concurrent-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testConcurrentEventAppends() throws InterruptedException { + String sessionId = "concurrent-append-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + AtomicInteger successCount = new AtomicInteger(0); + + for (int i = 0; i < threadCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + eventNum) + .author("thread-" + eventNum) + .content(Content.fromParts(Part.fromText("Message " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService + .appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event) + .blockingGet(); + successCount.incrementAndGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + executor.shutdown(); + + assertEquals(threadCount, successCount.get()); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(threadCount, finalSession.events().size()); + } + + @Test + public void testConcurrentSessionCreation() throws InterruptedException { + int sessionCount = 10; + + for (int i = 0; i < sessionCount; i++) { + String sessionId = "session-" + i; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("index", i); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + } + + ListSessionsResponse response = + sessionService.listSessions(TEST_APP_NAME, TEST_USER_ID).blockingGet(); + assertEquals(sessionCount, response.sessions().size()); + + for (int i = 0; i < sessionCount; i++) { + String sessionId = "session-" + i; + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(session); + assertEquals(i, session.state().get("index")); + } + } + + @Test + public void testConcurrentStateDeltaUpdates() throws InterruptedException { + String sessionId = "concurrent-delta-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int i = 0; i < threadCount; i++) { + final int increment = i + 1; + executor.submit( + () -> { + try { + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("field_" + increment, "value_" + increment); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("thread-" + increment) + .content(Content.fromParts(Part.fromText("Update " + increment))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + sessionService + .appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event) + .blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + executor.shutdown(); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(threadCount, finalSession.events().size()); + + for (int i = 1; i <= threadCount; i++) { + assertTrue(finalSession.state().containsKey("field_" + i)); + assertEquals("value_" + i, finalSession.state().get("field_" + i)); + } + } + + @Test + public void testConcurrentAppStateUpdates() throws InterruptedException { + int sessionCount = 5; + List sessionIds = new ArrayList<>(); + + for (int i = 0; i < sessionCount; i++) { + String sessionId = "app-state-session-" + i; + sessionIds.add(sessionId); + sessionService + .createSession(TEST_APP_NAME, "user-" + i, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + } + + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int i = 0; i < threadCount; i++) { + final int updateNum = i; + final String sessionId = sessionIds.get(i % sessionCount); + final String userId = "user-" + (i % sessionCount); + + executor.submit( + () -> { + try { + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("app:shared_counter", updateNum); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("thread-" + updateNum) + .content(Content.fromParts(Part.fromText("Update " + updateNum))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, userId, sessionId, event).blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + executor.shutdown(); + + for (int i = 0; i < sessionCount; i++) { + Session session = + sessionService + .getSession(TEST_APP_NAME, "user-" + i, sessionIds.get(i), Optional.empty()) + .blockingGet(); + assertNotNull(session); + assertTrue(session.state().containsKey("app:shared_counter")); + } + } + + @Test + public void testConcurrentUserStateUpdates() throws InterruptedException { + int sessionCount = 5; + List sessionIds = new ArrayList<>(); + + for (int i = 0; i < sessionCount; i++) { + String sessionId = "user-state-session-" + i; + sessionIds.add(sessionId); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + } + + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int i = 0; i < threadCount; i++) { + final int updateNum = i; + final String sessionId = sessionIds.get(i % sessionCount); + + executor.submit( + () -> { + try { + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("user:counter", updateNum); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("thread-" + updateNum) + .content(Content.fromParts(Part.fromText("Update " + updateNum))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + sessionService + .appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event) + .blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + executor.shutdown(); + + for (String sessionId : sessionIds) { + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(session); + assertTrue(session.state().containsKey("user:counter")); + } + } + + @Test + public void testConcurrentReadAndWrite() throws InterruptedException { + String sessionId = "read-write-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int threadCount = 20; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + AtomicInteger readCount = new AtomicInteger(0); + AtomicInteger writeCount = new AtomicInteger(0); + + for (int i = 0; i < threadCount; i++) { + final int threadNum = i; + executor.submit( + () -> { + try { + if (threadNum % 2 == 0) { + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + if (session != null) { + readCount.incrementAndGet(); + } + } else { + Event event = + Event.builder() + .id("event-" + threadNum) + .author("thread-" + threadNum) + .content(Content.fromParts(Part.fromText("Message " + threadNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService + .appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event) + .blockingGet(); + writeCount.incrementAndGet(); + } + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + executor.shutdown(); + + assertEquals(threadCount / 2, readCount.get()); + assertEquals(threadCount / 2, writeCount.get()); + } + + @Test + public void testConcurrentDeleteOperations() throws InterruptedException { + int sessionCount = 10; + List sessionIds = new ArrayList<>(); + + for (int i = 0; i < sessionCount; i++) { + String sessionId = "delete-session-" + i; + sessionIds.add(sessionId); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + } + + ExecutorService executor = Executors.newFixedThreadPool(sessionCount); + CountDownLatch latch = new CountDownLatch(sessionCount); + AtomicInteger deleteCount = new AtomicInteger(0); + + for (String sessionId : sessionIds) { + executor.submit( + () -> { + try { + sessionService.deleteSession(TEST_APP_NAME, TEST_USER_ID, sessionId).blockingAwait(); + deleteCount.incrementAndGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + executor.shutdown(); + + assertEquals(sessionCount, deleteCount.get()); + + ListSessionsResponse response = + sessionService.listSessions(TEST_APP_NAME, TEST_USER_ID).blockingGet(); + assertEquals(0, response.sessions().size()); + } +} diff --git a/core/src/test/java/com/google/adk/sessions/DatabaseDialectDetectorTest.java b/core/src/test/java/com/google/adk/sessions/DatabaseDialectDetectorTest.java new file mode 100644 index 000000000..df5b7ed75 --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/DatabaseDialectDetectorTest.java @@ -0,0 +1,179 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.sessions.db.util.DatabaseDialectDetector; +import org.junit.jupiter.api.Test; + +public class DatabaseDialectDetectorTest { + + @Test + public void testDetectPostgreSQLDialect() { + String url = "jdbc:postgresql://localhost:5432/testdb"; + String dialect = DatabaseDialectDetector.detectDialect(url); + assertEquals("org.hibernate.dialect.PostgreSQLDialect", dialect); + } + + @Test + public void testDetectMySQLDialect() { + String url = "jdbc:mysql://localhost:3306/testdb"; + String dialect = DatabaseDialectDetector.detectDialect(url); + assertEquals("org.hibernate.dialect.MySQLDialect", dialect); + } + + @Test + public void testDetectH2Dialect() { + String url = "jdbc:h2:mem:testdb"; + String dialect = DatabaseDialectDetector.detectDialect(url); + assertEquals("org.hibernate.dialect.H2Dialect", dialect); + } + + @Test + public void testDetectSQLiteDialect() { + String url = "jdbc:sqlite:/path/to/database.db"; + String dialect = DatabaseDialectDetector.detectDialect(url); + assertEquals("org.hibernate.dialect.SQLiteDialect", dialect); + } + + @Test + public void testDetectSpannerDialect() { + String url = "jdbc:cloudspanner:/projects/test/instances/test/databases/test"; + String dialect = DatabaseDialectDetector.detectDialect(url); + assertEquals("com.google.cloud.spanner.hibernate.SpannerDialect", dialect); + } + + @Test + public void testDetectDialectWithParameters() { + String url = "jdbc:postgresql://localhost:5432/testdb?user=admin&password=secret"; + String dialect = DatabaseDialectDetector.detectDialect(url); + assertEquals("org.hibernate.dialect.PostgreSQLDialect", dialect); + } + + @Test + public void testDetectDialectNullUrl() { + assertThrows(IllegalArgumentException.class, () -> DatabaseDialectDetector.detectDialect(null)); + } + + @Test + public void testDetectDialectUnsupportedDatabase() { + String url = "jdbc:oracle:thin:@localhost:1521:testdb"; + assertThrows(IllegalArgumentException.class, () -> DatabaseDialectDetector.detectDialect(url)); + } + + @Test + public void testDetectDialectInvalidUrl() { + String url = "not-a-jdbc-url"; + assertThrows(IllegalArgumentException.class, () -> DatabaseDialectDetector.detectDialect(url)); + } + + @Test + public void testMaskConnectionUrlWithPassword() { + String url = "jdbc:postgresql://localhost:5432/testdb?user=admin&password=secret123"; + String masked = DatabaseDialectDetector.maskConnectionUrl(url); + assertFalse(masked.contains("secret123")); + assertTrue(masked.contains("password=****")); + } + + @Test + public void testMaskConnectionUrlWithColonPassword() { + String url = "jdbc:mysql://admin:secret456@localhost:3306/testdb"; + String masked = DatabaseDialectDetector.maskConnectionUrl(url); + assertFalse(masked.contains("secret456")); + assertTrue(masked.contains("password:****@")); + } + + @Test + public void testMaskConnectionUrlMultiplePasswords() { + String url = + "jdbc:postgresql://localhost:5432/testdb?password=first&user=admin&password=second"; + String masked = DatabaseDialectDetector.maskConnectionUrl(url); + assertFalse(masked.contains("first")); + assertFalse(masked.contains("second")); + assertTrue(masked.contains("password=****")); + } + + @Test + public void testMaskConnectionUrlNoPassword() { + String url = "jdbc:h2:mem:testdb"; + String masked = DatabaseDialectDetector.maskConnectionUrl(url); + assertEquals(url, masked); + } + + @Test + public void testMaskConnectionUrlNull() { + String masked = DatabaseDialectDetector.maskConnectionUrl(null); + assertNull(masked); + } + + @Test + public void testMaskConnectionUrlWithUserButNoPassword() { + String url = "jdbc:postgresql://localhost:5432/testdb?user=admin"; + String masked = DatabaseDialectDetector.maskConnectionUrl(url); + assertEquals(url, masked); + } + + @Test + public void testMaskConnectionUrlComplexCase() { + String url = + "jdbc:mysql://user:mypassword123@host:3306/db?password=anotherpass&ssl=true&password=thirdpass"; + String masked = DatabaseDialectDetector.maskConnectionUrl(url); + assertFalse(masked.contains("mypassword123")); + assertFalse(masked.contains("anotherpass")); + assertFalse(masked.contains("thirdpass")); + assertTrue(masked.contains("password:****@")); + assertTrue(masked.contains("password=****")); + } + + @Test + public void testH2InMemoryDatabase() { + String url = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"; + String dialect = DatabaseDialectDetector.detectDialect(url); + assertEquals("org.hibernate.dialect.H2Dialect", dialect); + } + + @Test + public void testH2FileDatabase() { + String url = "jdbc:h2:file:/data/testdb"; + String dialect = DatabaseDialectDetector.detectDialect(url); + assertEquals("org.hibernate.dialect.H2Dialect", dialect); + } + + @Test + public void testPostgreSQLWithSSL() { + String url = + "jdbc:postgresql://localhost:5432/testdb?ssl=true&sslfactory=org.postgresql.ssl.NonValidatingFactory"; + String dialect = DatabaseDialectDetector.detectDialect(url); + assertEquals("org.hibernate.dialect.PostgreSQLDialect", dialect); + } + + @Test + public void testMySQLWithUTF8() { + String url = "jdbc:mysql://localhost:3306/testdb?useUnicode=true&characterEncoding=UTF-8"; + String dialect = DatabaseDialectDetector.detectDialect(url); + assertEquals("org.hibernate.dialect.MySQLDialect", dialect); + } + + @Test + public void testEmptyStringUrl() { + assertThrows(IllegalArgumentException.class, () -> DatabaseDialectDetector.detectDialect("")); + } +} diff --git a/core/src/test/java/com/google/adk/sessions/DatabaseSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/DatabaseSessionServiceTest.java new file mode 100644 index 000000000..f59aa5126 --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/DatabaseSessionServiceTest.java @@ -0,0 +1,416 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** Unit tests for DatabaseSessionService. */ +public class DatabaseSessionServiceTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "test-app"; + private static final String TEST_USER_ID = "test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + // Initialize schema with Flyway + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); // Clean the database (safe in tests with in-memory DB) + flyway.migrate(); // Apply migrations + + // Create service with H2 in-memory database + Map properties = + Map.of( + "hibernate.show_sql", "true", + "hibernate.format_sql", "true", + "hibernate.hbm2ddl.auto", "validate"); // Set to validate mode + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + } + + @AfterEach + public void tearDown() { + // Clean up resources + sessionService.close(); + } + + @Test + public void testCreateSession() { + // Arrange + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key1", "value1"); + state.put("key2", 42); + + // Act + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, null).blockingGet(); + + // Assert + assertNotNull(session); + assertNotNull(session.id()); + assertEquals(TEST_APP_NAME, session.appName()); + assertEquals(TEST_USER_ID, session.userId()); + assertEquals("value1", session.state().get("key1")); + assertEquals(42, session.state().get("key2")); + assertTrue(session.events().isEmpty()); + // Note: Session no longer has appState/userState accessors + // Database still stores these values but they aren't exposed in the Session model + } + + @Test + public void testCreateSessionWithId() { + // Arrange + String sessionId = "custom-session-id"; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + + // Act + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Assert + assertNotNull(session); + assertEquals(sessionId, session.id()); + assertEquals(TEST_APP_NAME, session.appName()); + assertEquals(TEST_USER_ID, session.userId()); + } + + @Test + public void testGetSession() { + // Arrange + String sessionId = "get-session-test"; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key", "value"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + // Act + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + // Assert + assertNotNull(retrievedSession); + assertEquals(sessionId, retrievedSession.id()); + assertEquals(TEST_APP_NAME, retrievedSession.appName()); + assertEquals(TEST_USER_ID, retrievedSession.userId()); + assertEquals("value", retrievedSession.state().get("key")); + } + + @Test + public void testGetSessionNotFound() { + // Arrange + String nonExistentId = "non-existent"; + + // Act & Assert + assertNull( + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, nonExistentId, Optional.empty()) + .blockingGet()); + } + + @Test + public void testLifecycleNoSession() { + assertNull( + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "non-existent-session", Optional.empty()) + .blockingGet()); + + ListSessionsResponse sessionsResponse = + sessionService.listSessions(TEST_APP_NAME, TEST_USER_ID).blockingGet(); + assertEquals(0, sessionsResponse.sessions().size()); + + ListEventsResponse eventsResponse = + sessionService + .listEvents(TEST_APP_NAME, TEST_USER_ID, "non-existent-session") + .blockingGet(); + assertEquals(0, eventsResponse.events().size()); + } + + @Test + public void testListSessionsEmpty() { + // Act + ListSessionsResponse response = + sessionService.listSessions(TEST_APP_NAME, TEST_USER_ID).blockingGet(); + + // Assert + assertNotNull(response); + assertEquals(0, response.sessions().size()); + } + + @Test + public void testListSessions() { + // Arrange + String sessionId1 = "list-test-1"; + String sessionId2 = "list-test-2"; + + // Create two sessions + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId1) + .blockingGet(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId2) + .blockingGet(); + + // Act + ListSessionsResponse response = + sessionService.listSessions(TEST_APP_NAME, TEST_USER_ID).blockingGet(); + + // Assert + assertNotNull(response); + List sessions = response.sessions(); + assertEquals(2, sessions.size()); + assertTrue(sessions.stream().anyMatch(s -> s.id().equals(sessionId1))); + assertTrue(sessions.stream().anyMatch(s -> s.id().equals(sessionId2))); + } + + @Test + public void testAppendEvent() { + // Arrange + String sessionId = "event-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Hello, world!"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + // Act + Session updatedSession = + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + + // Assert + assertNotNull(updatedSession); + assertEquals(1, updatedSession.events().size()); + Event retrievedEvent = updatedSession.events().get(0); + assertEquals(event.id(), retrievedEvent.id()); + assertEquals(event.author(), retrievedEvent.author()); + assertEquals( + "Hello, world!", + retrievedEvent.content().flatMap(c -> c.parts()).stream() + .flatMap(List::stream) + .flatMap(p -> p.text().stream()) + .findFirst() + .orElse("")); + } + + @Test + public void testAppendEventToNonExistentSession() { + // Arrange + String nonExistentId = "non-existent"; + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Hello, world!"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + // Act & Assert + assertThrows( + SessionNotFoundException.class, + () -> + sessionService + .appendEvent(TEST_APP_NAME, TEST_USER_ID, nonExistentId, event) + .blockingGet()); + } + + @Test + public void testDeleteSession() { + // Arrange + String sessionId = "delete-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Verify session exists + assertNotNull( + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet()); + + // Act + sessionService.deleteSession(TEST_APP_NAME, TEST_USER_ID, sessionId).blockingAwait(); + + // Assert + assertNull( + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet()); + } + + @Test + public void testListEvents() { + // Arrange + String sessionId = "list-events-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + // Create events + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("index: " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + // Add small delay to ensure different timestamps + try { + TimeUnit.MILLISECONDS.sleep(10); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + // Act + ListEventsResponse response = + sessionService.listEvents(TEST_APP_NAME, TEST_USER_ID, sessionId, 3, null).blockingGet(); + + // Assert + assertNotNull(response); + assertEquals(3, response.events().size()); + assertNotNull(response.nextPageToken()); + + // Test pagination + ListEventsResponse page2 = + sessionService + .listEvents( + TEST_APP_NAME, TEST_USER_ID, sessionId, 3, response.nextPageToken().orElse(null)) + .blockingGet(); + assertNotNull(page2); + assertEquals(2, page2.events().size()); // 5 total - 3 from first page = 2 events + assertTrue(page2.nextPageToken().isEmpty()); // No more pages + } + + @Test + public void testGetSessionWithFiltering() { + // Arrange + String sessionId = "filter-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant startTime = Instant.now(); + + // Create events with incrementing timestamps + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("index: " + i))) + .timestamp(startTime.plusSeconds(i).toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + // Act - Get only recent events + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(2).build(); + Session sessionWithRecentEvents = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, java.util.Optional.of(config)) + .blockingGet(); + + // Assert - Should have only the 2 most recent events + assertNotNull(sessionWithRecentEvents); + assertEquals(2, sessionWithRecentEvents.events().size()); + // Content structure has changed, content is now Optional instead of Map + + // Act - Get events after a specific timestamp + GetSessionConfig timestampConfig = + GetSessionConfig.builder().afterTimestamp(startTime.plusSeconds(3)).build(); + Session sessionWithTimestampFilter = + sessionService + .getSession( + TEST_APP_NAME, TEST_USER_ID, sessionId, java.util.Optional.of(timestampConfig)) + .blockingGet(); + + // Assert - Should have only events after the specified timestamp + assertNotNull(sessionWithTimestampFilter); + assertEquals(2, sessionWithTimestampFilter.events().size()); + // Content structure has changed, content is now Optional instead of Map + } + + @Test + public void testAppendEventUpdatesSessionState() { + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("sessionKey", "sessionValue"); + stateDelta.put("_app_appKey", "appValue"); + stateDelta.put("_user_userKey", "userValue"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertNotNull(retrievedSession); + assertEquals("sessionValue", retrievedSession.state().get("sessionKey")); + assertEquals("appValue", retrievedSession.state().get("_app_appKey")); + assertEquals("userValue", retrievedSession.state().get("_user_userKey")); + } +} diff --git a/core/src/test/java/com/google/adk/sessions/EventFilteringTest.java b/core/src/test/java/com/google/adk/sessions/EventFilteringTest.java new file mode 100644 index 000000000..efccbd78c --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/EventFilteringTest.java @@ -0,0 +1,416 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class EventFilteringTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:filter_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "filter-test-app"; + private static final String TEST_USER_ID = "filter-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testNumRecentEventsFilter() { + String sessionId = "recent-events-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 10; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + i))) + .timestamp(Instant.now().plusSeconds(i).toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(3).build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(3, session.events().size()); + } + + @Test + public void testNumRecentEventsZero() { + String sessionId = "zero-events-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(0).build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(0, session.events().size()); + } + + @Test + public void testNumRecentEventsExceedsTotal() { + String sessionId = "exceed-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 3; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(10).build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(3, session.events().size()); + } + + @Test + public void testAfterTimestampFilter() { + String sessionId = "timestamp-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant baseTime = Instant.now(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + i))) + .timestamp(baseTime.plusSeconds(i).toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + Instant filterTime = baseTime.plusSeconds(3); + GetSessionConfig config = GetSessionConfig.builder().afterTimestamp(filterTime).build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(2, session.events().size()); + } + + @Test + public void testAfterTimestampNoMatches() { + String sessionId = "no-matches-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant baseTime = Instant.now(); + + for (int i = 1; i <= 3; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + i))) + .timestamp(baseTime.plusSeconds(i).toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + Instant futureTime = baseTime.plusSeconds(100); + GetSessionConfig config = GetSessionConfig.builder().afterTimestamp(futureTime).build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(0, session.events().size()); + } + + @Test + public void testCombinedFilters() { + String sessionId = "combined-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant baseTime = Instant.now(); + + for (int i = 1; i <= 10; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + i))) + .timestamp(baseTime.plusSeconds(i).toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + Instant filterTime = baseTime.plusSeconds(3); + GetSessionConfig config = + GetSessionConfig.builder().afterTimestamp(filterTime).numRecentEvents(3).build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(3, session.events().size()); + } + + @Test + public void testNoFilterConfiguration() { + String sessionId = "no-filter-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(session); + assertEquals(5, session.events().size()); + } + + @Test + public void testFilteringPreservesOrder() { + String sessionId = "order-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant baseTime = Instant.now(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + i))) + .timestamp(baseTime.plusSeconds(i).toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(3).build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(3, session.events().size()); + + assertEquals("event-3", session.events().get(0).id()); + assertEquals("event-4", session.events().get(1).id()); + assertEquals("event-5", session.events().get(2).id()); + } + + @Test + public void testFilterOnEmptyEventList() { + String sessionId = "empty-filter-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(5).build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertTrue(session.events().isEmpty()); + } + + @Test + public void testEventPagination() { + String sessionId = "pagination-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 20; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + ListEventsResponse page1 = + sessionService.listEvents(TEST_APP_NAME, TEST_USER_ID, sessionId, 5, null).blockingGet(); + + assertNotNull(page1); + assertEquals(5, page1.events().size()); + assertTrue(page1.nextPageToken().isPresent()); + + ListEventsResponse page2 = + sessionService + .listEvents(TEST_APP_NAME, TEST_USER_ID, sessionId, 5, page1.nextPageToken().get()) + .blockingGet(); + + assertNotNull(page2); + assertEquals(5, page2.events().size()); + assertTrue(page2.nextPageToken().isPresent()); + } + + @Test + public void testEventPaginationLastPage() { + String sessionId = "last-page-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 7; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + ListEventsResponse page1 = + sessionService.listEvents(TEST_APP_NAME, TEST_USER_ID, sessionId, 5, null).blockingGet(); + + assertNotNull(page1); + assertEquals(5, page1.events().size()); + + ListEventsResponse page2 = + sessionService + .listEvents(TEST_APP_NAME, TEST_USER_ID, sessionId, 5, page1.nextPageToken().get()) + .blockingGet(); + + assertNotNull(page2); + assertEquals(2, page2.events().size()); + assertTrue(page2.nextPageToken().isEmpty()); + } +} diff --git a/core/src/test/java/com/google/adk/sessions/FlywayMigrationTest.java b/core/src/test/java/com/google/adk/sessions/FlywayMigrationTest.java new file mode 100644 index 000000000..a64240c6f --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/FlywayMigrationTest.java @@ -0,0 +1,195 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.Map; +import org.flywaydb.core.Flyway; +import org.flywaydb.core.api.MigrationInfo; +import org.junit.jupiter.api.Test; + +/** Tests for Flyway migrations to ensure they're applied correctly. */ +public class FlywayMigrationTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:testdb_migrations;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + + @Test + public void testMigrations() throws Exception { + // Initialize Flyway + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + + // Clean and migrate + flyway.clean(); + flyway.migrate(); + + // Get migration info to verify + MigrationInfo[] migrations = flyway.info().applied(); + + // Check that all expected migrations were applied + assertTrue(migrations.length >= 1, "Expected at least 1 migration to be applied"); + assertEquals("1", migrations[0].getVersion().toString(), "First migration should be V1"); + assertEquals( + "Initial schema", + migrations[0].getDescription(), + "First migration should be Initial schema"); + + // Verify tables and indexes exist via JDBC + try (Connection conn = flyway.getConfiguration().getDataSource().getConnection(); + Statement stmt = conn.createStatement()) { + + // Check tables exist + String[] expectedTables = {"sessions", "events", "app_states", "user_states"}; + for (String table : expectedTables) { + ResultSet rs = + stmt.executeQuery( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'PUBLIC' AND table_name = '" + + table.toUpperCase() + + "'"); + rs.next(); + assertEquals(1, rs.getInt(1), "Table " + table + " should exist"); + } + } + + // Verify DatabaseSessionService can be created with validate mode + Map properties = + Map.of( + "hibernate.hbm2ddl.auto", "validate", + "hibernate.show_sql", "true"); + + DatabaseSessionService service = new DatabaseSessionService(TEST_DB_URL, properties); + assertNotNull(service, "Service should be created successfully"); + service.close(); + } + + @Test + public void testMigrationIdempotency() throws Exception { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + + flyway.clean(); + flyway.migrate(); + + MigrationInfo[] firstRun = flyway.info().applied(); + int firstRunCount = firstRun.length; + + flyway.migrate(); + + MigrationInfo[] secondRun = flyway.info().applied(); + int secondRunCount = secondRun.length; + + assertEquals( + firstRunCount, secondRunCount, "Running migrations twice should not apply more migrations"); + } + + @Test + public void testMigrationValidation() throws Exception { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + + flyway.clean(); + flyway.migrate(); + + flyway.validate(); + } + + @Test + public void testConcurrentMigrationAttempts() throws Exception { + String concurrentDbUrl = + "jdbc:h2:mem:concurrent_migrations;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + + Flyway flyway1 = + Flyway.configure() + .dataSource(concurrentDbUrl, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .lockRetryCount(5) + .load(); + + flyway1.clean(); + + Flyway flyway2 = + Flyway.configure() + .dataSource(concurrentDbUrl, null, null) + .locations("classpath:db/migration/h2") + .lockRetryCount(5) + .load(); + + Thread thread1 = + new Thread( + () -> { + try { + flyway1.migrate(); + } catch (Exception e) { + } + }); + + Thread thread2 = + new Thread( + () -> { + try { + flyway2.migrate(); + } catch (Exception e) { + } + }); + + thread1.start(); + thread2.start(); + + thread1.join(10000); + thread2.join(10000); + + flyway1.validate(); + } + + @Test + public void testSchemaVersioning() throws Exception { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + + flyway.clean(); + flyway.migrate(); + + MigrationInfo current = flyway.info().current(); + assertNotNull(current, "Current migration should not be null"); + assertTrue( + Integer.parseInt(current.getVersion().toString()) >= 1, + "Current version should be at least 1"); + } +} diff --git a/core/src/test/java/com/google/adk/sessions/PessimisticLockingTest.java b/core/src/test/java/com/google/adk/sessions/PessimisticLockingTest.java new file mode 100644 index 000000000..595bd0f5c --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/PessimisticLockingTest.java @@ -0,0 +1,325 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class PessimisticLockingTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:locking_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "locking-test-app"; + private static final String TEST_USER_ID = "locking-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testSerializedEventAppends() throws InterruptedException { + String sessionId = "serialized-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int threadCount = 20; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + AtomicInteger successCount = new AtomicInteger(0); + + for (int i = 0; i < threadCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + eventNum) + .author("thread-" + eventNum) + .content(Content.fromParts(Part.fromText("Message " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService + .appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event) + .blockingGet(); + successCount.incrementAndGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS)); + executor.shutdown(); + + assertEquals(threadCount, successCount.get()); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(threadCount, finalSession.events().size()); + } + + @Test + public void testNoLostUpdates() throws InterruptedException { + String sessionId = "no-lost-updates"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int i = 0; i < threadCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + eventNum) + .author("thread-" + eventNum) + .content(Content.fromParts(Part.fromText("Message " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService + .appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event) + .blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS)); + executor.shutdown(); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(threadCount, finalSession.events().size()); + } + + @Test + public void testConcurrentAppendDifferentSessions() throws InterruptedException { + int sessionCount = 5; + int eventsPerSession = 10; + + for (int i = 0; i < sessionCount; i++) { + String sessionId = "session-" + i; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + } + + ExecutorService executor = Executors.newFixedThreadPool(sessionCount * eventsPerSession); + CountDownLatch latch = new CountDownLatch(sessionCount * eventsPerSession); + + for (int i = 0; i < sessionCount; i++) { + final String sessionId = "session-" + i; + for (int j = 0; j < eventsPerSession; j++) { + final int eventNum = j; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Event " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService + .appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event) + .blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + } + + assertTrue(latch.await(60, TimeUnit.SECONDS)); + executor.shutdown(); + + for (int i = 0; i < sessionCount; i++) { + String sessionId = "session-" + i; + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(session); + assertEquals(eventsPerSession, session.events().size()); + } + } + + @Test + public void testAppendEventUnderLoad() throws InterruptedException { + String sessionId = "load-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int threadCount = 50; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + AtomicInteger successCount = new AtomicInteger(0); + AtomicInteger failureCount = new AtomicInteger(0); + + for (int i = 0; i < threadCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + eventNum) + .author("thread-" + eventNum) + .content(Content.fromParts(Part.fromText("Load test " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + sessionService + .appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event) + .blockingGet(); + successCount.incrementAndGet(); + } catch (Exception e) { + failureCount.incrementAndGet(); + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(120, TimeUnit.SECONDS)); + executor.shutdown(); + + assertEquals(threadCount, successCount.get()); + assertEquals(0, failureCount.get()); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(threadCount, finalSession.events().size()); + } + + @Test + public void testEventOrderingConsistency() throws InterruptedException { + String sessionId = "ordering-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int eventCount = 100; + ExecutorService executor = Executors.newFixedThreadPool(10); + CountDownLatch latch = new CountDownLatch(eventCount); + + for (int i = 0; i < eventCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + String.format("%03d", eventNum)) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + eventNum))) + .timestamp(Instant.now().toEpochMilli() + eventNum) + .build(); + + sessionService + .appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event) + .blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(120, TimeUnit.SECONDS)); + executor.shutdown(); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(eventCount, finalSession.events().size()); + } +} diff --git a/core/src/test/java/com/google/adk/sessions/StateDeltaTest.java b/core/src/test/java/com/google/adk/sessions/StateDeltaTest.java new file mode 100644 index 000000000..0ef757850 --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/StateDeltaTest.java @@ -0,0 +1,382 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class StateDeltaTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:delta_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "delta-test-app"; + private static final String TEST_USER_ID = "delta-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testStateDeltaInEvent() { + String sessionId = "delta-event-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("counter", 1); + delta.put("new_field", "added"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session updated = + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + + assertNotNull(updated); + assertEquals(1, updated.state().get("counter")); + assertEquals("added", updated.state().get("new_field")); + } + + @Test + public void testAppStateDeltaInEvent() { + String sessionId = "app-delta-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("app:config_version", "v1"); + delta.put("app:feature_flag", true); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session updated = + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + + assertNotNull(updated); + assertEquals("v1", updated.state().get("app:config_version")); + assertEquals(true, updated.state().get("app:feature_flag")); + } + + @Test + public void testUserStateDeltaInEvent() { + String sessionId = "user-delta-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("user:preference", "new_preference"); + delta.put("user:score", 100); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session updated = + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + + assertNotNull(updated); + assertEquals("new_preference", updated.state().get("user:preference")); + assertEquals(100, updated.state().get("user:score")); + } + + @Test + public void testMixedStateDeltaInEvent() { + String sessionId = "mixed-delta-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("app:app_field", "app_value"); + delta.put("user:user_field", "user_value"); + delta.put("session_field", "session_value"); + delta.put("temp:temp_field", "temp_value"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session updated = + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + + assertNotNull(updated); + assertEquals("app_value", updated.state().get("app:app_field")); + assertEquals("user_value", updated.state().get("user:user_field")); + assertEquals("session_value", updated.state().get("session_field")); + assertFalse(updated.state().containsKey("temp:temp_field")); + } + + @Test + public void testMultipleDeltaUpdates() { + String sessionId = "multi-delta-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("counter", i); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Update " + i))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + } + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(finalSession); + assertEquals(5, finalSession.state().get("counter")); + assertEquals(5, finalSession.events().size()); + } + + @Test + public void testStateDeltaWithNullValues() { + String sessionId = "null-delta-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("field1", "value1"); + initialState.put("field2", "value2"); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("field3", "value3"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session updated = + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + + assertNotNull(updated); + assertEquals("value1", updated.state().get("field1")); + assertEquals("value2", updated.state().get("field2")); + assertEquals("value3", updated.state().get("field3")); + } + + @Test + public void testStateDeltaOverwritesExisting() { + String sessionId = "overwrite-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("name", "old_name"); + initialState.put("version", 1); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("name", "new_name"); + delta.put("version", 2); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session updated = + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId, event).blockingGet(); + + assertNotNull(updated); + assertEquals("new_name", updated.state().get("name")); + assertEquals(2, updated.state().get("version")); + } + + @Test + public void testAppStateDeltaPropagation() { + String sessionId1 = "app-propagate-1"; + String sessionId2 = "app-propagate-2"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId1) + .blockingGet(); + sessionService + .createSession(TEST_APP_NAME, "other-user", new ConcurrentHashMap<>(), sessionId2) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("app:global_counter", 42); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId1, event).blockingGet(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId1, Optional.empty()) + .blockingGet(); + Session session2 = + sessionService + .getSession(TEST_APP_NAME, "other-user", sessionId2, Optional.empty()) + .blockingGet(); + + assertEquals(42, session1.state().get("app:global_counter")); + assertEquals(42, session2.state().get("app:global_counter")); + } + + @Test + public void testUserStateDeltaPropagation() { + String sessionId1 = "user-propagate-1"; + String sessionId2 = "user-propagate-2"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId1) + .blockingGet(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId2) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("user:points", 999); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + sessionService.appendEvent(TEST_APP_NAME, TEST_USER_ID, sessionId1, event).blockingGet(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId1, Optional.empty()) + .blockingGet(); + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId2, Optional.empty()) + .blockingGet(); + + assertEquals(999, session1.state().get("user:points")); + assertEquals(999, session2.state().get("user:points")); + } +} diff --git a/core/src/test/java/com/google/adk/sessions/StateManagementTest.java b/core/src/test/java/com/google/adk/sessions/StateManagementTest.java new file mode 100644 index 000000000..01d77036c --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/StateManagementTest.java @@ -0,0 +1,308 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class StateManagementTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:state_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "state-test-app"; + private static final String TEST_USER_ID_1 = "user-1"; + private static final String TEST_USER_ID_2 = "user-2"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testAppStateSharing() { + String sessionId1 = "session-1"; + String sessionId2 = "session-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("app:global_setting", "shared_value"); + state1.put("session:local", "private_value_1"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("app:global_setting", "updated_value"); + state2.put("session:local", "private_value_2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_2, state2, sessionId2).blockingGet(); + + Session retrieved1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session retrieved2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_2, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved1); + assertNotNull(retrieved2); + + assertEquals("updated_value", retrieved1.state().get("app:global_setting")); + assertEquals("updated_value", retrieved2.state().get("app:global_setting")); + + assertEquals("private_value_1", retrieved1.state().get("session:local")); + assertEquals("private_value_2", retrieved2.state().get("session:local")); + } + + @Test + public void testUserStateSharing() { + String sessionId1 = "user-session-1"; + String sessionId2 = "user-session-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("user:preference", "dark_mode"); + state1.put("session:data", "session_specific_1"); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("user:preference", "light_mode"); + state2.put("session:data", "session_specific_2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(session1); + assertNotNull(session2); + + assertEquals("light_mode", session1.state().get("user:preference")); + assertEquals("light_mode", session2.state().get("user:preference")); + + assertEquals("session_specific_1", session1.state().get("session:data")); + assertEquals("session_specific_2", session2.state().get("session:data")); + } + + @Test + public void testSessionStateIsolation() { + String sessionId1 = "isolated-1"; + String sessionId2 = "isolated-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("private_key", "value_1"); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("private_key", "value_2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(session1); + assertNotNull(session2); + assertEquals("value_1", session1.state().get("private_key")); + assertEquals("value_2", session2.state().get("private_key")); + } + + @Test + public void testAllStateIsPersisted() { + String sessionId = "persist-test"; + + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("data1", "value1"); + state.put("data2", "value2"); + + Session created = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state, sessionId).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertTrue(retrieved.state().containsKey("data1")); + assertTrue(retrieved.state().containsKey("data2")); + assertEquals("value1", retrieved.state().get("data1")); + assertEquals("value2", retrieved.state().get("data2")); + } + + @Test + public void testMultipleStateKeys() { + String sessionId = "multiple-keys-test"; + + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("setting1", "value1"); + state.put("setting2", "value2"); + state.put("setting3", "value3"); + state.put("setting4", "value4"); + + Session created = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state, sessionId).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertTrue(retrieved.state().containsKey("setting1")); + assertTrue(retrieved.state().containsKey("setting2")); + assertTrue(retrieved.state().containsKey("setting3")); + assertTrue(retrieved.state().containsKey("setting4")); + } + + @Test + public void testEmptyStateCreation() { + String sessionId = "empty-state"; + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, null, sessionId).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertNotNull(retrieved.state()); + assertTrue(retrieved.state().isEmpty()); + } + + @Test + public void testStateRetrieval() { + String sessionId = "state-retrieval"; + + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("config", "config_value"); + state.put("setting", "setting_value"); + state.put("data", "data_value"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state, sessionId).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertTrue(retrieved.state().containsKey("config")); + assertTrue(retrieved.state().containsKey("setting")); + assertTrue(retrieved.state().containsKey("data")); + assertEquals("config_value", retrieved.state().get("config")); + assertEquals("setting_value", retrieved.state().get("setting")); + assertEquals("data_value", retrieved.state().get("data")); + } + + @Test + public void testThreeTierStateMergingPriority() { + String sessionId = "merge-test"; + + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("app:key", "app_value"); + state.put("user:key", "user_value"); + state.put("session:key", "session_value"); + state.put("key", "unprefixed_value"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state, sessionId).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("app_value", retrieved.state().get("app:key")); + assertEquals("user_value", retrieved.state().get("user:key")); + assertEquals("session_value", retrieved.state().get("session:key")); + assertEquals("unprefixed_value", retrieved.state().get("key")); + } + + @Test + public void testStateMergingOverridePriority() { + String sessionId1 = "priority-test-1"; + String sessionId2 = "priority-test-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("app:shared", "app_level"); + state1.put("user:shared", "user_level"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("session:shared", "session_level"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(session1); + assertNotNull(session2); + + assertEquals("app_level", session1.state().get("app:shared")); + assertEquals("user_level", session1.state().get("user:shared")); + + assertEquals("app_level", session2.state().get("app:shared")); + assertEquals("user_level", session2.state().get("user:shared")); + assertEquals("session_level", session2.state().get("session:shared")); + } +} diff --git a/core/src/test/java/com/google/adk/sessions/db/converter/JsonUserTypeTest.java b/core/src/test/java/com/google/adk/sessions/db/converter/JsonUserTypeTest.java new file mode 100644 index 000000000..66a0a994e --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/db/converter/JsonUserTypeTest.java @@ -0,0 +1,150 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions.db.converter; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.sql.Clob; +import java.sql.ResultSet; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.postgresql.util.PGobject; + +public class JsonUserTypeTest { + + private JsonUserType jsonUserType; + + @BeforeEach + public void setUp() { + jsonUserType = new JsonUserType(); + } + + @Test + public void testNullSafeGet_PostgreSQL_PGobject() throws Exception { + PGobject pgObject = new PGobject(); + pgObject.setType("jsonb"); + pgObject.setValue("{\"key\":\"value\"}"); + + ResultSet rs = mock(ResultSet.class); + when(rs.getObject(0)).thenReturn(pgObject); + + Map result = (Map) jsonUserType.nullSafeGet(rs, 0, null, null); + + assertNotNull(result); + assertEquals("value", result.get("key")); + } + + @Test + public void testNullSafeGet_MySQL_String() throws Exception { + String jsonString = "{\"key\":\"value\",\"number\":123}"; + + ResultSet rs = mock(ResultSet.class); + when(rs.getObject(0)).thenReturn(jsonString); + + Map result = (Map) jsonUserType.nullSafeGet(rs, 0, null, null); + + assertNotNull(result); + assertEquals("value", result.get("key")); + assertEquals(123, result.get("number")); + } + + @Test + public void testNullSafeGet_H2_Clob() throws Exception { + String jsonString = "{\"key\":\"value\"}"; + Clob clob = mock(Clob.class); + when(clob.length()).thenReturn((long) jsonString.length()); + when(clob.getSubString(1, jsonString.length())).thenReturn(jsonString); + + ResultSet rs = mock(ResultSet.class); + when(rs.getObject(0)).thenReturn(clob); + + Map result = (Map) jsonUserType.nullSafeGet(rs, 0, null, null); + + assertNotNull(result); + assertEquals("value", result.get("key")); + } + + @Test + public void testNullSafeGet_EmptyJson() throws Exception { + String emptyJson = "{}"; + + ResultSet rs = mock(ResultSet.class); + when(rs.getObject(0)).thenReturn(emptyJson); + + Map result = (Map) jsonUserType.nullSafeGet(rs, 0, null, null); + + assertNotNull(result); + assertTrue(result.isEmpty()); + } + + @Test + public void testNullSafeGet_Null() throws Exception { + ResultSet rs = mock(ResultSet.class); + when(rs.getObject(0)).thenReturn(null); + + Map result = (Map) jsonUserType.nullSafeGet(rs, 0, null, null); + + assertNotNull(result); + assertTrue(result.isEmpty()); + } + + @Test + public void testNullSafeGet_ComplexNestedJson() throws Exception { + String complexJson = "{\"level1\":{\"level2\":{\"level3\":\"deep\"}},\"array\":[1,2,3]}"; + + ResultSet rs = mock(ResultSet.class); + when(rs.getObject(0)).thenReturn(complexJson); + + Map result = (Map) jsonUserType.nullSafeGet(rs, 0, null, null); + + assertNotNull(result); + assertTrue(result.containsKey("level1")); + assertTrue(result.containsKey("array")); + } + + @Test + public void testNullSafeGet_EmptyClob() throws Exception { + Clob clob = mock(Clob.class); + when(clob.length()).thenReturn(0L); + when(clob.getSubString(1, 0)).thenReturn(""); + + ResultSet rs = mock(ResultSet.class); + when(rs.getObject(0)).thenReturn(clob); + + Map result = (Map) jsonUserType.nullSafeGet(rs, 0, null, null); + + assertNotNull(result); + assertTrue(result.isEmpty()); + } + + @Test + public void testNullSafeGet_JsonWithSpecialCharacters() throws Exception { + String jsonWithSpecial = "{\"emoji\":\"🚀\",\"quote\":\"He said \\\"hello\\\"\"}"; + + ResultSet rs = mock(ResultSet.class); + when(rs.getObject(0)).thenReturn(jsonWithSpecial); + + Map result = (Map) jsonUserType.nullSafeGet(rs, 0, null, null); + + assertNotNull(result); + assertEquals("🚀", result.get("emoji")); + } +} diff --git a/core/src/test/java/com/google/adk/testing/TestDatabaseConfig.java b/core/src/test/java/com/google/adk/testing/TestDatabaseConfig.java new file mode 100644 index 000000000..3268a5314 --- /dev/null +++ b/core/src/test/java/com/google/adk/testing/TestDatabaseConfig.java @@ -0,0 +1,95 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.testing; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; + +/** + * Centralized configuration for integration test databases. + * + *

These connection strings assume test databases are running via docker-compose.test.yml: + * + *

{@code
+ * docker-compose -f docker-compose.test.yml up -d
+ * }
+ */ +public final class TestDatabaseConfig { + + private TestDatabaseConfig() {} + + // MySQL Test Database Configuration + public static final String MYSQL_HOST = "localhost"; + public static final int MYSQL_PORT = 3307; + public static final String MYSQL_DATABASE = "adk_test"; + public static final String MYSQL_USER = "adk_user"; + public static final String MYSQL_PASSWORD = "adk_password"; + public static final String MYSQL_JDBC_URL = + String.format( + "jdbc:mysql://%s:%d/%s?user=%s&password=%s&useSSL=false&allowPublicKeyRetrieval=true", + MYSQL_HOST, MYSQL_PORT, MYSQL_DATABASE, MYSQL_USER, MYSQL_PASSWORD); + + // PostgreSQL Test Database Configuration + public static final String POSTGRES_HOST = "localhost"; + public static final int POSTGRES_PORT = 5433; + public static final String POSTGRES_DATABASE = "adk_test"; + public static final String POSTGRES_USER = "adk_user"; + public static final String POSTGRES_PASSWORD = "adk_password"; + public static final String POSTGRES_JDBC_URL = + String.format( + "jdbc:postgresql://%s:%d/%s?user=%s&password=%s", + POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DATABASE, POSTGRES_USER, POSTGRES_PASSWORD); + + /** + * Checks if MySQL test database is available. + * + * @return true if connection succeeds, false otherwise + */ + public static boolean isMySQLAvailable() { + try (Connection conn = DriverManager.getConnection(MYSQL_JDBC_URL)) { + return conn.isValid(2); + } catch (SQLException e) { + return false; + } + } + + /** + * Checks if PostgreSQL test database is available. + * + * @return true if connection succeeds, false otherwise + */ + public static boolean isPostgreSQLAvailable() { + try (Connection conn = DriverManager.getConnection(POSTGRES_JDBC_URL)) { + return conn.isValid(2); + } catch (SQLException e) { + return false; + } + } + + /** + * Returns a helpful message for skipped tests when database is not available. + * + * @param databaseName The name of the database (MySQL or PostgreSQL) + * @return A message explaining how to start the database + */ + public static String getDatabaseNotAvailableMessage(String databaseName) { + return String.format( + "%s test database not available. Start it with: " + + "docker-compose -f docker-compose.test.yml up -d %s-test", + databaseName, databaseName.toLowerCase()); + } +} diff --git a/pom.xml b/pom.xml index 6009c7316..3a77b88a5 100644 --- a/pom.xml +++ b/pom.xml @@ -72,6 +72,9 @@ 1.4.0 3.9.0 5.4.3 + 6.6.4.Final + 6.2.1 + 11.17.0 @@ -244,6 +247,31 @@ graphviz-java ${graphviz.version} + + org.hibernate.orm + hibernate-core + ${hibernate.version} + + + org.hibernate.orm + hibernate-hikaricp + ${hibernate.version} + + + com.zaxxer + HikariCP + ${hikaricp.version} + + + org.flywaydb + flyway-core + ${flyway.version} + + + org.flywaydb + flyway-database-postgresql + ${flyway.version} + org.eclipse.jdt ecj