diff --git a/contrib/database-session-service/README.md b/contrib/database-session-service/README.md new file mode 100644 index 000000000..90b5d5415 --- /dev/null +++ b/contrib/database-session-service/README.md @@ -0,0 +1,87 @@ +# Database Session Service + +JDBC-based session service implementation for ADK Java. + +## Features + +- **No ORM Dependencies**: Uses JDBC with HikariCP for connection pooling +- **Multi-Database Support**: PostgreSQL, MySQL, H2 (SQLite not supported) +- **Automatic Schema Management**: Flyway migrations handle table creation/updates +- **3-Tier State Storage**: Separate tables for app-level, user-level, and session-level state +- **Reactive API**: RxJava 3 Single/Maybe/Completable return types + +## Dependencies + +- **HikariCP**: High-performance JDBC connection pool +- **Flyway**: Database schema versioning and migration +- **Jackson**: JSON serialization for events and state +- **RxJava 3**: Reactive programming support + +## Database Schema + +The service creates and manages these tables: + +- `app_states`: Application-level state (shared across all users) +- `user_states`: User-level state (shared across user's sessions) +- `sessions`: Individual session data +- `events`: Event history for each session + +## Usage + +```java +// Create service with database URL +String dbUrl = "jdbc:postgresql://localhost:5432/adk?user=postgres&password=secret"; +try (DatabaseSessionService sessionService = new DatabaseSessionService(dbUrl)) { + + // Create a session + Session session = sessionService.createSession( + "myApp", + "user123", + new ConcurrentHashMap<>(), + null + ).blockingGet(); + + // Append an event + Event event = Event.builder() + .id(UUID.randomUUID().toString()) + .invocationId("inv-1") + .timestamp(System.currentTimeMillis()) + .build(); + + Event appendedEvent = sessionService.appendEvent(session, event).blockingGet(); +} +``` + +## Supported Databases + +- **PostgreSQL**: Full support with JSONB + - URL: `jdbc:postgresql://host:port/database?user=...&password=...` +- **MySQL**: Full support with JSON + - URL: `jdbc:mysql://host:port/database?user=...&password=...` +- **H2**: For testing and development + - URL: `jdbc:h2:mem:testdb` or `jdbc:h2:file:./data/mydb` +- **Cloud Spanner**: Full support + - URL: `jdbc:cloudspanner:/projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID` +- **SQLite**: NOT supported (no UPSERT support) + +## State Management + +State is stored across three tables with merge priority: + +1. **App State** (lowest priority): `app:key` prefix +2. **User State** (medium priority): `user:key` prefix +3. **Session State** (highest priority): No prefix + +When retrieving a session, states are merged: app → user → session (higher priority overwrites). + +## Configuration + +Optional properties can be passed to the constructor: + +```java +Map props = new HashMap<>(); +props.put("connectionTimeout", 30000); +props.put("maximumPoolSize", 10); + +DatabaseSessionService service = new DatabaseSessionService(dbUrl, props); +``` diff --git a/contrib/database-session-service/pom.xml b/contrib/database-session-service/pom.xml new file mode 100644 index 000000000..bfab7cd22 --- /dev/null +++ b/contrib/database-session-service/pom.xml @@ -0,0 +1,186 @@ + + + + 4.0.0 + + + com.google.adk + google-adk-parent + 0.5.1-SNAPSHOT + ../../pom.xml + + + google-adk-database-session-service + Agent Development Kit - Database Session Service + Database integration with Agent Development Kit for User Session Management + + + + + com.google.adk + google-adk + ${project.version} + + + + + + + com.zaxxer + HikariCP + + + + + com.fasterxml.jackson.core + jackson-databind + + + + + io.reactivex.rxjava3 + rxjava + + + + + org.slf4j + slf4j-api + + + + + + + org.flywaydb + flyway-core + + + + + org.flywaydb + flyway-database-postgresql + runtime + + + + + org.flywaydb + flyway-mysql + runtime + + + + + org.flywaydb + flyway-gcp-spanner + runtime + + + + + + + org.postgresql + postgresql + true + + + + + com.mysql + mysql-connector-j + true + + + + + com.google.cloud + google-cloud-spanner-jdbc + true + + + + + + + com.h2database + h2 + test + + + + org.junit.jupiter + junit-jupiter-api + test + + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.slf4j + slf4j-simple + test + + + + com.google.truth + truth + test + + + + org.mockito + mockito-core + test + + + + + + + src/main/resources + true + + + + + maven-compiler-plugin + + + org.jacoco + jacoco-maven-plugin + + + org.apache.maven.plugins + maven-surefire-plugin + + + + diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/DatabaseSessionService.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/DatabaseSessionService.java new file mode 100644 index 000000000..e99ca91df --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/DatabaseSessionService.java @@ -0,0 +1,1186 @@ +package com.google.adk.sessions; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.events.Event; +import com.google.adk.sessions.dao.EventDao; +import com.google.adk.sessions.dao.SessionDao; +import com.google.adk.sessions.dao.StateDao; +import com.google.adk.sessions.dialect.DialectDetector; +import com.google.adk.sessions.dialect.SqlDialect; +import com.google.adk.sessions.model.AppStateRow; +import com.google.adk.sessions.model.EventRow; +import com.google.adk.sessions.model.SessionRow; +import com.google.adk.sessions.model.UserStateRow; +import com.google.adk.sessions.util.JdbcTemplate; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import org.flywaydb.core.Flyway; +import org.flywaydb.core.api.FlywayException; +import org.flywaydb.core.api.output.MigrateResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * JDBC-based implementation of {@link BaseSessionService}. + * + *

This service provides persistent session management using JDBC and HikariCP connection + * pooling. It supports multiple databases with automatic dialect detection and schema management + * via Flyway migrations. + * + *

Features: + * + *

+ * + *

Supported Databases: + * + *

+ * + *

State Management: + * + *

This service implements a 3-tier state storage model: + * + *

+ * + *

State is merged with priority: App → User → Session (higher priority overwrites lower). + * + *

Thread Safety: + * + *

This class is thread-safe. All database operations use connection pooling and database + * transactions. Concurrent operations on the same session are serialized via pessimistic locking + * (SELECT ... FOR UPDATE). + * + *

Resource Management: + * + *

This class implements {@link AutoCloseable}. Always use try-with-resources or explicitly call + * {@link #close()} to release database connections: + * + *

{@code
+ * try (DatabaseSessionService service = new DatabaseSessionService(jdbcUrl)) {
+ *     // Use the service
+ * } // Connections automatically closed
+ * }
+ * + *

Example Usage: + * + *

{@code
+ * // Create service with PostgreSQL
+ * String jdbcUrl = "jdbc:postgresql://localhost:5432/adk?user=postgres&password=secret";
+ * try (DatabaseSessionService service = new DatabaseSessionService(jdbcUrl)) {
+ *
+ *     // Create a session with initial state
+ *     ConcurrentMap state = new ConcurrentHashMap<>();
+ *     state.put("app:version", "1.0");      // App-level state
+ *     state.put("user:theme", "dark");       // User-level state
+ *     state.put("currentStep", 1);           // Session-level state
+ *
+ *     Session session = service.createSession("myApp", "user123", state, null).blockingGet();
+ *
+ *     // Append an event with state delta
+ *     Event event = Event.builder()
+ *         .id(UUID.randomUUID().toString())
+ *         .invocationId("inv-1")
+ *         .timestamp(System.currentTimeMillis())
+ *         .actions(EventActions.builder()
+ *             .stateDelta(Map.of("currentStep", 2))
+ *             .build())
+ *         .build();
+ *
+ *     Event appendedEvent = service.appendEvent(session, event).blockingGet();
+ *
+ *     // Retrieve updated session
+ *     Session updated = service.getSession("myApp", "user123", session.id(), Optional.empty())
+ *         .blockingGet();
+ * }
+ * }
+ * + * @see BaseSessionService + * @see Session + * @see Event + * @see State + */ +public class DatabaseSessionService implements BaseSessionService, AutoCloseable { + + private static final Logger logger = LoggerFactory.getLogger(DatabaseSessionService.class); + + private final HikariDataSource dataSource; + private final JdbcTemplate jdbcTemplate; + private final SqlDialect dialect; + private final ObjectMapper objectMapper; + private final SessionDao sessionDao; + private final EventDao eventDao; + private final StateDao stateDao; + private final AtomicBoolean closed = new AtomicBoolean(false); + + /** + * Creates a new DatabaseSessionService with default configuration. + * + *

This constructor uses default HikariCP connection pool settings: max pool size = 10, min + * idle = 2, connection timeout = 30s. + * + *

The database dialect is automatically detected from the JDBC URL, and Flyway migrations are + * run automatically to create/update the schema. + * + * @param jdbcUrl the JDBC connection URL (e.g., {@code + * jdbc:postgresql://localhost:5432/adk?user=postgres&password=secret}) + * @throws NullPointerException if jdbcUrl is null + * @throws IllegalArgumentException if the database dialect cannot be detected from the JDBC URL + * @throws SessionException if database migration fails + * @see #DatabaseSessionService(String, Map) + */ + public DatabaseSessionService(String jdbcUrl) { + this(jdbcUrl, Collections.emptyMap()); + } + + /** + * Creates a new DatabaseSessionService with custom configuration. + * + *

This constructor allows customization of both HikariCP connection pool settings and + * database-specific properties. + * + *

Example with custom connection pool settings: + * + *

{@code
+   * Map props = new HashMap<>();
+   * props.put("hikari.connectionTimeout", 60000);  // 60 seconds
+   * props.put("hikari.maximumPoolSize", 20);       // 20 connections
+   * DatabaseSessionService service = new DatabaseSessionService(jdbcUrl, props);
+   * }
+ * + *

Supported HikariCP properties (prefix with "hikari."): + * + *

+ * + *

All properties without the "hikari." prefix are passed to the underlying DataSource. + * + * @param jdbcUrl the JDBC connection URL + * @param properties configuration properties for HikariCP and the DataSource + * @throws NullPointerException if jdbcUrl is null + * @throws IllegalArgumentException if the database dialect cannot be detected from the JDBC URL + * @throws SessionException if database migration fails + */ + public DatabaseSessionService(String jdbcUrl, Map properties) { + Objects.requireNonNull(jdbcUrl, "JDBC URL cannot be null"); + + this.dialect = DialectDetector.detectFromJdbcUrl(jdbcUrl); + logger.info("Detected SQL dialect: {}", dialect.dialectName()); + + runMigrations(jdbcUrl); + + this.dataSource = createDataSource(jdbcUrl, properties); + + this.jdbcTemplate = new JdbcTemplate(dataSource); + this.objectMapper = com.google.adk.JsonBaseModel.getMapper(); + this.sessionDao = new SessionDao(dialect); + this.eventDao = new EventDao(dialect); + this.stateDao = new StateDao(dialect); + + logger.info( + "DatabaseSessionService initialized with {} (JDBC implementation)", dialect.dialectName()); + } + + private void runMigrations(String jdbcUrl) { + try { + String dialectFolder = extractDialectFolder(dialect.dialectName()); + String flywayLocation = "classpath:db/migration/" + dialectFolder; + + logger.info("Starting Flyway database migration"); + logger.info("Dialect: {}", dialect.dialectName()); + logger.info("Migration location: {}", flywayLocation); + logger.info("JDBC URL: {}", jdbcUrl.replaceAll("password=[^&;]*", "password=***")); + + String baselineOnMigrateStr = + System.getProperty( + "FLYWAY_BASELINE_ON_MIGRATE", + System.getenv().getOrDefault("FLYWAY_BASELINE_ON_MIGRATE", "false")); + boolean baselineOnMigrate = Boolean.parseBoolean(baselineOnMigrateStr); + + String lockRetryCountStr = + System.getProperty( + "FLYWAY_LOCK_RETRY_COUNT", + System.getenv().getOrDefault("FLYWAY_LOCK_RETRY_COUNT", "120")); + int lockRetryCount = Integer.parseInt(lockRetryCountStr); + + logger.info( + "Flyway configuration: baselineOnMigrate={}, lockRetryCount={}", + baselineOnMigrate, + lockRetryCount); + + Flyway flyway = + Flyway.configure() + .dataSource(jdbcUrl, null, null) + .locations(flywayLocation) + .cleanDisabled(true) + .lockRetryCount(lockRetryCount) + .baselineOnMigrate(baselineOnMigrate) + .load(); + + MigrateResult result = flyway.migrate(); + + if (result.migrationsExecuted > 0) { + logger.info( + "Flyway migration completed: {} migration(s) applied successfully", + result.migrationsExecuted); + } else { + logger.info("Database schema is up to date (no migrations applied)"); + } + logger.info("Flyway migration complete"); + } catch (FlywayException e) { + throw new SessionException("Failed to run database migrations", e); + } + } + + private String extractDialectFolder(String dialectName) { + String lower = dialectName.toLowerCase(); + if (lower.contains("postgres")) return "postgresql"; + if (lower.contains("mysql")) return "mysql"; + if (lower.contains("h2")) return "h2"; + if (lower.contains("spanner")) return "spanner"; + throw new IllegalArgumentException("Unsupported dialect: " + dialectName); + } + + private HikariDataSource createDataSource(String jdbcUrl, Map properties) { + HikariConfig config = new HikariConfig(); + config.setJdbcUrl(jdbcUrl); + + int maxPoolSize = getIntProperty(properties, "hikari.maximumPoolSize", 10); + int minIdle = getIntProperty(properties, "hikari.minimumIdle", 2); + long connTimeout = getLongProperty(properties, "hikari.connectionTimeout", 30000L); + long idleTimeout = getLongProperty(properties, "hikari.idleTimeout", 600000L); + long maxLifetime = getLongProperty(properties, "hikari.maxLifetime", 1800000L); + + config.setMaximumPoolSize(maxPoolSize); + config.setMinimumIdle(minIdle); + config.setConnectionTimeout(connTimeout); + config.setIdleTimeout(idleTimeout); + config.setMaxLifetime(maxLifetime); + + properties.entrySet().stream() + .filter(e -> !e.getKey().startsWith("hikari.")) + .forEach(e -> config.addDataSourceProperty(e.getKey(), e.getValue())); + + logger.debug("Initializing HikariCP connection pool"); + logger.debug( + "Pool configuration: maxPoolSize={}, minIdle={}, connectionTimeout={}ms, idleTimeout={}ms, maxLifetime={}ms", + maxPoolSize, + minIdle, + connTimeout, + idleTimeout, + maxLifetime); + + HikariDataSource dataSource = new HikariDataSource(config); + logger.debug("HikariCP connection pool created successfully"); + return dataSource; + } + + /** + * Creates a new session with the specified parameters. + * + *

This method creates a new session and initializes the 3-tier state storage (app, user, + * session). If the provided state map contains keys with prefixes {@code app:} or {@code user:}, + * those entries are stored in the corresponding state tables. + * + *

State Handling: + * + *

+ * + *

If app or user state already exists, it is updated with the new values. + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param state the initial state map, can be null or empty + * @param sessionId optional session ID; if null or empty, a UUID is generated + * @return a Single that emits the created Session + * @throws NullPointerException if appName or userId is null (checked by BaseSessionService) + * @see State#APP_PREFIX + * @see State#USER_PREFIX + * @see State#TEMP_PREFIX + */ + @Override + public Single createSession( + String appName, String userId, ConcurrentMap state, String sessionId) { + checkNotClosed(); + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + + return Single.fromCallable(() -> createSessionInTransaction(appName, userId, state, sessionId)) + .subscribeOn(Schedulers.io()); + } + + /** + * Retrieves a session by its identifiers. + * + *

This method fetches the session from the database and merges the 3-tier state (app → user → + * session) before returning. The returned session includes all events up to the specified limit. + * + *

Event Filtering: + * + *

+ * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param sessionId the session identifier (must not be null) + * @param config optional configuration for event filtering + * @return a Maybe that emits the Session if found, or completes empty if not found + * @throws NullPointerException if appName, userId, or sessionId is null + * @see GetSessionConfig + */ + @Override + public Maybe getSession( + String appName, String userId, String sessionId, Optional config) { + checkNotClosed(); + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + Objects.requireNonNull(sessionId, "sessionId cannot be null"); + Objects.requireNonNull(config, "config cannot be null"); + + return Maybe.fromCallable( + () -> + jdbcTemplate.inTransaction( + ops -> { + Optional sessionOpt = + sessionDao.findSession(ops, appName, userId, sessionId); + + if (!sessionOpt.isPresent()) { + return null; + } + + return buildSessionFromRow(ops, sessionOpt.get(), config); + })) + .subscribeOn(Schedulers.io()); + } + + /** + * Lists all sessions for a specific application and user. + * + *

The sessions are returned without events and without merged app/user state (state maps will + * be empty). Use {@link #getSession} to retrieve full session details. + * + *

Sessions are ordered by update_time descending (most recently updated first), with a limit + * of 1000 sessions. + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @return a Single that emits a ListSessionsResponse containing the sessions + * @throws NullPointerException if appName or userId is null + */ + @Override + public Single listSessions(String appName, String userId) { + checkNotClosed(); + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + + return Single.fromCallable( + () -> { + return jdbcTemplate.inTransaction( + ops -> { + List sessionRows = sessionDao.listSessions(ops, appName, userId); + + List sessions = + sessionRows.stream() + .map( + row -> + toDomainSession( + row, + Collections.emptyList(), + new ConcurrentHashMap<>(), + new ConcurrentHashMap<>())) + .collect(Collectors.toList()); + + return ListSessionsResponse.builder().sessions(sessions).build(); + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Deletes a session and all its associated events. + * + *

This operation cascades to delete all events associated with the session. App-level and + * user-level state are NOT deleted (they may be shared with other sessions). + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param sessionId the session identifier (must not be null) + * @return a Completable that completes when the session is deleted + * @throws SessionNotFoundException if the session does not exist + * @throws NullPointerException if any parameter is null + */ + @Override + public Completable deleteSession(String appName, String userId, String sessionId) { + checkNotClosed(); + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + Objects.requireNonNull(sessionId, "sessionId cannot be null"); + + return Completable.fromAction( + () -> { + jdbcTemplate.inTransaction( + ops -> { + Optional sessionOpt = + sessionDao.findSession(ops, appName, userId, sessionId); + + if (!sessionOpt.isPresent()) { + throw new SessionNotFoundException( + "Session not found: " + appName + "/" + userId + "/" + sessionId); + } + + sessionDao.deleteSession(ops, appName, userId, sessionId); + return null; + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Lists all events for a session. + * + *

This method fetches ALL events for the session in chronological order (oldest first). + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param sessionId the session identifier (must not be null) + * @return a Single that emits a ListEventsResponse containing all events + * @throws SessionNotFoundException if the session does not exist + * @throws IllegalStateException if the service has been closed + */ + @Override + public Single listEvents(String appName, String userId, String sessionId) { + checkNotClosed(); + Objects.requireNonNull(appName, "appName cannot be null"); + Objects.requireNonNull(userId, "userId cannot be null"); + Objects.requireNonNull(sessionId, "sessionId cannot be null"); + + return Single.fromCallable( + () -> { + return jdbcTemplate.inTransaction( + ops -> { + Optional sessionOpt = + sessionDao.findSession(ops, appName, userId, sessionId); + + if (!sessionOpt.isPresent()) { + throw new SessionNotFoundException( + "Session not found: " + appName + "/" + userId + "/" + sessionId); + } + + List eventRows = eventDao.listEvents(ops, appName, userId, sessionId); + + List events = + eventRows.stream().map(this::toEvent).collect(Collectors.toList()); + + return ListEventsResponse.builder().events(events).build(); + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Appends an event to a session and persists it to the database. + * + *

This method processes the event's state delta (if present) and applies it to the appropriate + * state tier (app, user, or session). The session's update_time is refreshed even if there is no + * state delta. + * + *

State Delta Processing: + * + *

    + *
  • Keys starting with {@code app:} update app-level state + *
  • Keys starting with {@code user:} update user-level state + *
  • Keys starting with {@code temp:} are ignored (not persisted) + *
  • All other keys update session-level state + *
  • Use {@link State#REMOVED} as a value to delete a state key + *
+ * + *

This operation uses pessimistic locking (SELECT ... FOR UPDATE) to prevent concurrent + * modifications to the same session. + * + * @param session the session to append the event to (must not be null) + * @param event the event to append (must not be null) + * @return a Single that emits the updated Event after processing + * @throws NullPointerException if session or event is null + * @throws SessionNotFoundException if the session does not exist + * @throws IllegalStateException if the service has been closed + * @see State#REMOVED + */ + @Override + public Single appendEvent(Session session, Event event) { + checkNotClosed(); + Objects.requireNonNull(session, "session cannot be null"); + Objects.requireNonNull(event, "event cannot be null"); + Objects.requireNonNull(session.appName(), "session.appName cannot be null"); + Objects.requireNonNull(session.userId(), "session.userId cannot be null"); + Objects.requireNonNull(session.id(), "session.id cannot be null"); + + // DB first, then memory + // If DB fails, transaction rolls back and memory is never updated + return persistEventToDatabase(session.appName(), session.userId(), session.id(), event) + .andThen(BaseSessionService.super.appendEvent(session, event)) + .doOnError( + throwable -> { + logger.error( + "Failed to append event to session {}/{}/{}: {}", + session.appName(), + session.userId(), + session.id(), + throwable.getMessage(), + throwable); + }); + } + + /** + * Persists an event to the database. + * + *

This method handles event persistence and state delta updates. It acquires row-level locks + * on the session, app state, and user state during the transaction. + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param sessionId the session identifier (must not be null) + * @param event the event to append (must not be null) + * @return a Completable that completes when the event is persisted + * @throws SessionNotFoundException if the session does not exist + */ + private Completable persistEventToDatabase( + String appName, String userId, String sessionId, Event event) { + return Completable.fromAction( + () -> { + jdbcTemplate.inTransaction( + ops -> { + Optional sessionOpt = + sessionDao.findSessionForUpdate(ops, appName, userId, sessionId); + + if (!sessionOpt.isPresent()) { + throw new SessionNotFoundException( + "Session not found: " + appName + "/" + userId + "/" + sessionId); + } + + SessionRow sessionRow = sessionOpt.get(); + + Optional appStateOpt = stateDao.getAppStateForUpdate(ops, appName); + Map appState = + appStateOpt + .map(s -> fromJson(s.getState())) + .orElse(new ConcurrentHashMap<>()); + + Optional userStateOpt = + stateDao.getUserStateForUpdate(ops, appName, userId); + Map userState = + userStateOpt + .map(s -> fromJson(s.getState())) + .orElse(new ConcurrentHashMap<>()); + + if (event.actions() != null && event.actions().stateDelta() != null) { + ConcurrentMap stateDelta = event.actions().stateDelta(); + + Map appStateDelta = new ConcurrentHashMap<>(); + Map userStateDelta = new ConcurrentHashMap<>(); + Map sessionStateDelta = new ConcurrentHashMap<>(); + + for (Map.Entry entry : stateDelta.entrySet()) { + String key = entry.getKey(); + if (key.startsWith(State.APP_PREFIX)) { + String unprefixedKey = key.substring(State.APP_PREFIX.length()); + appStateDelta.put(unprefixedKey, entry.getValue()); + } else if (key.startsWith(State.USER_PREFIX)) { + String unprefixedKey = key.substring(State.USER_PREFIX.length()); + userStateDelta.put(unprefixedKey, entry.getValue()); + } else if (!key.startsWith(State.TEMP_PREFIX)) { + sessionStateDelta.put(key, entry.getValue()); + } + } + + if (!appStateDelta.isEmpty()) { + for (Map.Entry entry : appStateDelta.entrySet()) { + if (entry.getValue() == State.REMOVED) { + appState.remove(entry.getKey()); + } else { + appState.put(entry.getKey(), entry.getValue()); + } + } + AppStateRow updatedAppState = new AppStateRow(); + updatedAppState.setAppName(appName); + updatedAppState.setState(toJson(appState)); + updatedAppState.setUpdateTime(Instant.now()); + stateDao.upsertAppState(ops, updatedAppState); + } + + if (!userStateDelta.isEmpty()) { + for (Map.Entry entry : userStateDelta.entrySet()) { + if (entry.getValue() == State.REMOVED) { + userState.remove(entry.getKey()); + } else { + userState.put(entry.getKey(), entry.getValue()); + } + } + UserStateRow updatedUserState = new UserStateRow(); + updatedUserState.setAppName(appName); + updatedUserState.setUserId(userId); + updatedUserState.setState(toJson(userState)); + updatedUserState.setUpdateTime(Instant.now()); + stateDao.upsertUserState(ops, updatedUserState); + } + + if (!sessionStateDelta.isEmpty()) { + Map sessionState = fromJson(sessionRow.getState()); + for (Map.Entry entry : sessionStateDelta.entrySet()) { + if (entry.getValue() == State.REMOVED) { + sessionState.remove(entry.getKey()); + } else { + sessionState.put(entry.getKey(), entry.getValue()); + } + } + sessionRow.setState(toJson(sessionState)); + } + } + + sessionRow.setUpdateTime(Instant.now()); + sessionDao.updateSession(ops, sessionRow); + + EventRow eventRow = fromDomainEvent(event, appName, userId, sessionId); + eventDao.insertEvent(ops, eventRow); + + return null; + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Retrieves the app-level state for an application. + * + *

App-level state is shared across all users and sessions for the specified application. This + * is typically used for application-wide configuration or data. + * + * @param appName the application name (must not be null) + * @return a Single that emits the app state map, or null if no app state exists + */ + public Single> getAppState(String appName) { + return Single.fromCallable( + () -> { + return jdbcTemplate.inTransaction( + ops -> { + Optional appStateOpt = stateDao.getAppState(ops, appName); + return appStateOpt + .map(s -> (Map) fromJson(s.getState())) + .orElse(null); + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Sets or replaces the app-level state for an application. + * + *

This operation completely replaces the existing app state. If you need to update specific + * keys, retrieve the current state first, modify it, and then set it back. + * + *

Warning: This affects all users and sessions for the application. + * + * @param appName the application name (must not be null) + * @param state the new app state map (must not be null) + * @return a Completable that completes when the state is updated + */ + public Completable setAppState(String appName, Map state) { + return Completable.fromAction( + () -> { + jdbcTemplate.inTransaction( + ops -> { + AppStateRow row = new AppStateRow(); + row.setAppName(appName); + row.setState(toJson(state)); + row.setUpdateTime(Instant.now()); + + stateDao.upsertAppState(ops, row); + return null; + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Retrieves the user-level state for a specific user in an application. + * + *

User-level state is shared across all sessions for the specified user. This is typically + * used for user preferences or data that should persist across sessions. + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @return a Single that emits the user state map, or null if no user state exists + */ + public Single> getUserState(String appName, String userId) { + return Single.fromCallable( + () -> { + return jdbcTemplate.inTransaction( + ops -> { + Optional userStateOpt = + stateDao.getUserState(ops, appName, userId); + return userStateOpt + .map(s -> (Map) fromJson(s.getState())) + .orElse(null); + }); + }) + .subscribeOn(Schedulers.io()); + } + + /** + * Sets or replaces the user-level state for a specific user in an application. + * + *

This operation completely replaces the existing user state. If you need to update specific + * keys, retrieve the current state first, modify it, and then set it back. + * + *

Warning: This affects all sessions for the specified user. + * + * @param appName the application name (must not be null) + * @param userId the user identifier (must not be null) + * @param state the new user state map (must not be null) + * @return a Completable that completes when the state is updated + */ + public Completable setUserState(String appName, String userId, Map state) { + return Completable.fromAction( + () -> { + jdbcTemplate.inTransaction( + ops -> { + UserStateRow row = new UserStateRow(); + row.setAppName(appName); + row.setUserId(userId); + row.setState(toJson(state)); + row.setUpdateTime(Instant.now()); + + stateDao.upsertUserState(ops, row); + return null; + }); + }) + .subscribeOn(Schedulers.io()); + } + + private Session toDomainSession( + SessionRow row, + List eventRows, + Map appState, + Map userState) { + ConcurrentMap mergedState = new ConcurrentHashMap<>(); + + if (appState != null) { + for (Map.Entry entry : appState.entrySet()) { + mergedState.put(State.APP_PREFIX + entry.getKey(), entry.getValue()); + } + } + + if (userState != null) { + for (Map.Entry entry : userState.entrySet()) { + mergedState.put(State.USER_PREFIX + entry.getKey(), entry.getValue()); + } + } + + Map sessionStateMap = fromJson(row.getState()); + if (sessionStateMap != null) { + mergedState.putAll(sessionStateMap); + } + + List events = eventRows.stream().map(this::toEvent).collect(Collectors.toList()); + + return Session.builder(row.getId()) + .appName(row.getAppName()) + .userId(row.getUserId()) + .state(mergedState) + .events(events) + .lastUpdateTime(row.getUpdateTime()) + .build(); + } + + private Event toEvent(EventRow row) { + try { + Event event = objectMapper.readValue(row.getEventData(), Event.class); + + event.setId(row.getId()); + event.setInvocationId(row.getInvocationId()); + event.setTimestamp(row.getTimestamp().toEpochMilli()); + + return event; + } catch (Exception e) { + logger.error("Failed to deserialize event {}: {}", row.getId(), e.getMessage(), e); + throw new SessionException("Failed to convert EventRow to Event", e); + } + } + + private EventRow fromDomainEvent(Event event, String appName, String userId, String sessionId) { + EventRow row = new EventRow(); + row.setId(event.id()); + row.setAppName(appName); + row.setUserId(userId); + row.setSessionId(sessionId); + row.setInvocationId(event.invocationId()); + row.setTimestamp(Instant.ofEpochMilli(event.timestamp())); + + try { + Map eventDataMap = + objectMapper.convertValue(event, new TypeReference>() {}); + + eventDataMap.remove("id"); + eventDataMap.remove("invocationId"); + eventDataMap.remove("timestamp"); + + String eventDataJson = objectMapper.writeValueAsString(eventDataMap); + row.setEventData(eventDataJson); + } catch (Exception e) { + logger.error("Failed to serialize event {}: {}", event.id(), e.getMessage(), e); + throw new SessionException("Failed to convert Event to EventRow", e); + } + + return row; + } + + private String toJson(Map map) { + try { + return objectMapper.writeValueAsString(map); + } catch (Exception e) { + throw new SessionException("Failed to serialize to JSON", e); + } + } + + private Map fromJson(String json) { + if (json == null || json.isEmpty()) { + return new ConcurrentHashMap<>(); + } + try { + return objectMapper.readValue(json, new TypeReference>() {}); + } catch (Exception e) { + throw new SessionException("Failed to deserialize from JSON", e); + } + } + + private void checkNotClosed() { + if (closed.get()) { + throw new IllegalStateException( + "DatabaseSessionService is closed. Create a new instance or ensure close() is not called prematurely."); + } + } + + private static int getIntProperty(Map properties, String key, int defaultValue) { + Object value = properties.get(key); + if (value == null) { + return defaultValue; + } + if (value instanceof Integer) { + return (Integer) value; + } + if (value instanceof Number) { + return ((Number) value).intValue(); + } + if (value instanceof String) { + try { + return Integer.parseInt((String) value); + } catch (NumberFormatException e) { + logger.warn( + "Invalid integer value for property {}: {}. Using default: {}", + key, + value, + defaultValue); + return defaultValue; + } + } + logger.warn( + "Unsupported type for property {}: {}. Using default: {}", + key, + value.getClass().getName(), + defaultValue); + return defaultValue; + } + + private static long getLongProperty( + Map properties, String key, long defaultValue) { + Object value = properties.get(key); + if (value == null) { + return defaultValue; + } + if (value instanceof Long) { + return (Long) value; + } + if (value instanceof Number) { + return ((Number) value).longValue(); + } + if (value instanceof String) { + try { + return Long.parseLong((String) value); + } catch (NumberFormatException e) { + logger.warn( + "Invalid long value for property {}: {}. Using default: {}", key, value, defaultValue); + return defaultValue; + } + } + logger.warn( + "Unsupported type for property {}: {}. Using default: {}", + key, + value.getClass().getName(), + defaultValue); + return defaultValue; + } + + /** + * Helper method to create a session within a transaction. + * + *

Extracted from {@link #createSession} to improve testability and reduce lambda nesting. + * + * @param appName the application name + * @param userId the user identifier + * @param state initial state map (may be null) + * @param sessionId session ID (generates UUID if null/empty) + * @return the created Session + */ + private Session createSessionInTransaction( + String appName, String userId, ConcurrentMap state, String sessionId) + throws java.sql.SQLException { + String id = + (sessionId != null && !sessionId.isEmpty()) ? sessionId : UUID.randomUUID().toString(); + + Instant now = Instant.now(); + + return jdbcTemplate.inTransaction( + ops -> { + Map appStateMap = new ConcurrentHashMap<>(); + Map userStateMap = new ConcurrentHashMap<>(); + Map sessionStateMap = new ConcurrentHashMap<>(); + + if (state != null) { + for (Map.Entry entry : state.entrySet()) { + String key = entry.getKey(); + if (key.startsWith(State.TEMP_PREFIX)) { + continue; + } + + if (key.startsWith(State.APP_PREFIX)) { + String unprefixedKey = key.substring(State.APP_PREFIX.length()); + appStateMap.put(unprefixedKey, entry.getValue()); + } else if (key.startsWith(State.USER_PREFIX)) { + String unprefixedKey = key.substring(State.USER_PREFIX.length()); + userStateMap.put(unprefixedKey, entry.getValue()); + } else { + sessionStateMap.put(key, entry.getValue()); + } + } + } + + Map appState = upsertAppStateIfNeeded(ops, appName, appStateMap, now); + Map userState = + upsertUserStateIfNeeded(ops, appName, userId, userStateMap, now); + + SessionRow row = new SessionRow(); + row.setAppName(appName); + row.setUserId(userId); + row.setId(id); + row.setState(toJson(sessionStateMap)); + row.setCreateTime(now); + row.setUpdateTime(now); + + sessionDao.insertSession(ops, row); + + return toDomainSession(row, Collections.emptyList(), appState, userState); + }); + } + + /** + * Helper method to build a Session from a SessionRow within a transaction. + * + *

Extracted from {@link #getSession} to improve testability and reduce lambda nesting. + * + * @param ops transaction operations + * @param sessionRow the session row from database + * @param config optional configuration for event filtering + * @return the built Session + */ + private Session buildSessionFromRow( + JdbcTemplate.JdbcOperations ops, SessionRow sessionRow, Optional config) + throws java.sql.SQLException { + String appName = sessionRow.getAppName(); + String userId = sessionRow.getUserId(); + String sessionId = sessionRow.getId(); + + Optional limit = config.flatMap(GetSessionConfig::numRecentEvents); + + List eventRows; + if (config.isPresent() && config.get().afterTimestamp().isPresent()) { + Instant afterTimestamp = config.get().afterTimestamp().get(); + eventRows = + eventDao.listEventsAfterTimestamp( + ops, appName, userId, sessionId, afterTimestamp, limit, 0); + } else { + eventRows = eventDao.listEvents(ops, appName, userId, sessionId, limit); + } + + Optional appStateOpt = stateDao.getAppState(ops, appName); + Map appState = + appStateOpt.map(s -> fromJson(s.getState())).orElse(new ConcurrentHashMap<>()); + + Optional userStateOpt = stateDao.getUserState(ops, appName, userId); + Map userState = + userStateOpt.map(s -> fromJson(s.getState())).orElse(new ConcurrentHashMap<>()); + + return toDomainSession(sessionRow, eventRows, appState, userState); + } + + /** + * Helper method to upsert app state if needed. + * + * @param ops transaction operations + * @param appName application name + * @param appStateMap state map to upsert (may be empty) + * @param now current timestamp + * @return the merged app state map + */ + private Map upsertAppStateIfNeeded( + JdbcTemplate.JdbcOperations ops, String appName, Map appStateMap, Instant now) + throws java.sql.SQLException { + Optional appStateOpt = stateDao.getAppStateForUpdate(ops, appName); + Map appState; + + if (appStateOpt.isPresent()) { + appState = fromJson(appStateOpt.get().getState()); + if (!appStateMap.isEmpty()) { + appState.putAll(appStateMap); + AppStateRow updatedAppState = new AppStateRow(); + updatedAppState.setAppName(appName); + updatedAppState.setState(toJson(appState)); + updatedAppState.setUpdateTime(now); + stateDao.upsertAppState(ops, updatedAppState); + } + } else if (!appStateMap.isEmpty()) { + appState = new ConcurrentHashMap<>(appStateMap); + AppStateRow newAppState = new AppStateRow(); + newAppState.setAppName(appName); + newAppState.setState(toJson(appState)); + newAppState.setUpdateTime(now); + stateDao.upsertAppState(ops, newAppState); + } else { + appState = new ConcurrentHashMap<>(); + } + + return appState; + } + + /** + * Helper method to upsert user state if needed. + * + * @param ops transaction operations + * @param appName application name + * @param userId user identifier + * @param userStateMap state map to upsert (may be empty) + * @param now current timestamp + * @return the merged user state map + */ + private Map upsertUserStateIfNeeded( + JdbcTemplate.JdbcOperations ops, + String appName, + String userId, + Map userStateMap, + Instant now) + throws java.sql.SQLException { + Optional userStateOpt = stateDao.getUserStateForUpdate(ops, appName, userId); + Map userState; + + if (userStateOpt.isPresent()) { + userState = fromJson(userStateOpt.get().getState()); + if (!userStateMap.isEmpty()) { + userState.putAll(userStateMap); + UserStateRow updatedUserState = new UserStateRow(); + updatedUserState.setAppName(appName); + updatedUserState.setUserId(userId); + updatedUserState.setState(toJson(userState)); + updatedUserState.setUpdateTime(now); + stateDao.upsertUserState(ops, updatedUserState); + } + } else if (!userStateMap.isEmpty()) { + userState = new ConcurrentHashMap<>(userStateMap); + UserStateRow newUserState = new UserStateRow(); + newUserState.setAppName(appName); + newUserState.setUserId(userId); + newUserState.setState(toJson(userState)); + newUserState.setUpdateTime(now); + stateDao.upsertUserState(ops, newUserState); + } else { + userState = new ConcurrentHashMap<>(); + } + + return userState; + } + + /** + * Closes this service and releases all database connections. + * + *

This method shuts down the HikariCP connection pool and releases all associated resources. + * After calling this method, the service cannot be used again - create a new instance if needed. + * + *

This method is idempotent - calling it multiple times has no additional effect. + * + *

Thread Safety: This method uses atomic compare-and-set to ensure the connection pool + * is closed exactly once, even if called concurrently from multiple threads. + * + *

Best Practice: Use try-with-resources to ensure automatic cleanup: + * + *

{@code
+   * try (DatabaseSessionService service = new DatabaseSessionService(jdbcUrl)) {
+   *     // Use the service
+   * } // Automatically closed
+   * }
+ * + * @throws IllegalStateException if any operations are attempted after closing + */ + @Override + public void close() { + if (closed.compareAndSet(false, true)) { + logger.info("Closing DatabaseSessionService"); + if (dataSource != null && !dataSource.isClosed()) { + dataSource.close(); + logger.info("HikariCP connection pool closed"); + } + } + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/EventDao.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/EventDao.java new file mode 100644 index 000000000..79f49f5c4 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/EventDao.java @@ -0,0 +1,141 @@ +package com.google.adk.sessions.dao; + +import com.google.adk.sessions.dialect.SqlDialect; +import com.google.adk.sessions.model.EventRow; +import com.google.adk.sessions.util.JdbcTemplate.JdbcOperations; +import com.google.adk.sessions.util.RowMapper; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class EventDao { + + private static final Logger logger = LoggerFactory.getLogger(EventDao.class); + private final SqlDialect dialect; + + public EventDao(SqlDialect dialect) { + this.dialect = dialect; + logger.debug("EventDao initialized with {} dialect", dialect.dialectName()); + } + + private static final RowMapper ROW_MAPPER = + rs -> { + EventRow row = new EventRow(); + row.setId(rs.getString("id")); + row.setAppName(rs.getString("app_name")); + row.setUserId(rs.getString("user_id")); + row.setSessionId(rs.getString("session_id")); + row.setInvocationId(rs.getString("invocation_id")); + row.setEventData(rs.getString("event_data")); + + Timestamp ts = rs.getTimestamp("timestamp"); + row.setTimestamp(ts != null ? ts.toInstant() : null); + + return row; + }; + + public List listEvents( + JdbcOperations ops, String appName, String userId, String sessionId) throws SQLException { + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("sessionId", sessionId); + + String sql = + "SELECT * FROM events " + + "WHERE app_name = :appName AND user_id = :userId AND session_id = :sessionId " + + "ORDER BY timestamp ASC"; + + return ops.query(sql, params, ROW_MAPPER); + } + + public List listEvents( + JdbcOperations ops, String appName, String userId, String sessionId, Optional limit) + throws SQLException { + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("sessionId", sessionId); + + String sql = + "SELECT * FROM events " + + "WHERE app_name = :appName AND user_id = :userId AND session_id = :sessionId " + + "ORDER BY timestamp DESC"; + + if (limit.isPresent()) { + sql += " LIMIT :limit"; + params.put("limit", limit.get()); + } + + List events = ops.query(sql, params, ROW_MAPPER); + Collections.reverse(events); + return events; + } + + public List listEventsAfterTimestamp( + JdbcOperations ops, + String appName, + String userId, + String sessionId, + java.time.Instant afterTimestamp, + Optional limit, + int offset) + throws SQLException { + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("sessionId", sessionId); + params.put("afterTimestamp", java.sql.Timestamp.from(afterTimestamp)); + + String sql = + "SELECT * FROM events " + + "WHERE app_name = :appName AND user_id = :userId AND session_id = :sessionId " + + "AND timestamp > :afterTimestamp " + + "ORDER BY timestamp ASC"; + + if (limit.isPresent()) { + sql += " LIMIT :limit OFFSET :offset"; + params.put("limit", limit.get()); + params.put("offset", offset); + } + + return ops.query(sql, params, ROW_MAPPER); + } + + public void insertEvent(JdbcOperations ops, EventRow event) throws SQLException { + String sql = + "INSERT INTO events (id, app_name, user_id, session_id, invocation_id, timestamp, event_data) " + + "VALUES (:id, :appName, :userId, :sessionId, :invocationId, :timestamp, " + + dialect.jsonValue(":eventData") + + ")"; + + Map params = new HashMap<>(); + params.put("id", event.getId()); + params.put("appName", event.getAppName()); + params.put("userId", event.getUserId()); + params.put("sessionId", event.getSessionId()); + params.put("invocationId", event.getInvocationId()); + params.put("timestamp", Timestamp.from(event.getTimestamp())); + params.put("eventData", event.getEventData()); + + logger.debug("Appending event: eventId={}, sessionId={}", event.getId(), event.getSessionId()); + + ops.update(sql, params); + } + + public long countEvents(JdbcOperations ops, String appName, String userId, String sessionId) + throws SQLException { + String sql = + "SELECT COUNT(*) as count FROM events " + + "WHERE app_name = :appName AND user_id = :userId AND session_id = :sessionId"; + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("sessionId", sessionId); + + return ops.queryForObject(sql, params, rs -> rs.getLong("count")).orElse(0L); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/SessionDao.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/SessionDao.java new file mode 100644 index 000000000..71e664390 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/SessionDao.java @@ -0,0 +1,131 @@ +package com.google.adk.sessions.dao; + +import com.google.adk.sessions.dialect.SqlDialect; +import com.google.adk.sessions.model.SessionRow; +import com.google.adk.sessions.util.JdbcTemplate.JdbcOperations; +import com.google.adk.sessions.util.RowMapper; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SessionDao { + + private static final Logger logger = LoggerFactory.getLogger(SessionDao.class); + private final SqlDialect dialect; + + public SessionDao(SqlDialect dialect) { + this.dialect = dialect; + logger.debug("SessionDao initialized with {} dialect", dialect.dialectName()); + } + + private static final RowMapper ROW_MAPPER = + rs -> { + SessionRow row = new SessionRow(); + row.setAppName(rs.getString("app_name")); + row.setUserId(rs.getString("user_id")); + row.setId(rs.getString("id")); + row.setState(rs.getString("state")); + + Timestamp createTs = rs.getTimestamp("create_time"); + row.setCreateTime(createTs != null ? createTs.toInstant() : null); + + Timestamp updateTs = rs.getTimestamp("update_time"); + row.setUpdateTime(updateTs != null ? updateTs.toInstant() : null); + + return row; + }; + + public Optional findSession( + JdbcOperations ops, String appName, String userId, String id) throws SQLException { + String sql = + "SELECT * FROM sessions " + "WHERE app_name = :appName AND user_id = :userId AND id = :id"; + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("id", id); + + return ops.queryForObject(sql, params, ROW_MAPPER); + } + + public Optional findSessionForUpdate( + JdbcOperations ops, String appName, String userId, String id) throws SQLException { + String sql = + "SELECT * FROM sessions " + + "WHERE app_name = :appName AND user_id = :userId AND id = :id " + + dialect.forUpdateSyntax(); + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("id", id); + + return ops.queryForObject(sql, params, ROW_MAPPER); + } + + public List listSessions(JdbcOperations ops, String appName, String userId) + throws SQLException { + String sql = + "SELECT * FROM sessions " + + "WHERE app_name = :appName AND user_id = :userId " + + "ORDER BY update_time DESC"; + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + + return ops.query(sql, params, ROW_MAPPER); + } + + public void insertSession(JdbcOperations ops, SessionRow session) throws SQLException { + String sql = + "INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time) " + + "VALUES (:appName, :userId, :id, " + + dialect.jsonValue(":state") + + ", :createTime, :updateTime)"; + + Map params = new HashMap<>(); + params.put("appName", session.getAppName()); + params.put("userId", session.getUserId()); + params.put("id", session.getId()); + params.put("state", session.getState()); + params.put("createTime", Timestamp.from(session.getCreateTime())); + params.put("updateTime", Timestamp.from(session.getUpdateTime())); + + logger.debug( + "Inserting session: app={}, user={}, sessionId={}", + session.getAppName(), + session.getUserId(), + session.getId()); + ops.update(sql, params); + logger.debug("Session created successfully: {}", session.getId()); + } + + public void updateSession(JdbcOperations ops, SessionRow session) throws SQLException { + String sql = + "UPDATE sessions " + + "SET state = " + + dialect.jsonValue(":state") + + ", update_time = :updateTime " + + "WHERE app_name = :appName AND user_id = :userId AND id = :id"; + + Map params = new HashMap<>(); + params.put("state", session.getState()); + params.put("updateTime", Timestamp.from(session.getUpdateTime())); + params.put("appName", session.getAppName()); + params.put("userId", session.getUserId()); + params.put("id", session.getId()); + + ops.update(sql, params); + } + + public void deleteSession(JdbcOperations ops, String appName, String userId, String id) + throws SQLException { + dialect.deleteSession(ops, appName, userId, id); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/StateDao.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/StateDao.java new file mode 100644 index 000000000..1a8b672f9 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dao/StateDao.java @@ -0,0 +1,110 @@ +package com.google.adk.sessions.dao; + +import com.google.adk.sessions.dialect.SqlDialect; +import com.google.adk.sessions.model.AppStateRow; +import com.google.adk.sessions.model.UserStateRow; +import com.google.adk.sessions.util.JdbcTemplate.JdbcOperations; +import com.google.adk.sessions.util.RowMapper; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class StateDao { + + private static final Logger logger = LoggerFactory.getLogger(StateDao.class); + private final SqlDialect dialect; + + public StateDao(SqlDialect dialect) { + this.dialect = dialect; + } + + private static final RowMapper APP_STATE_MAPPER = + rs -> { + AppStateRow row = new AppStateRow(); + row.setAppName(rs.getString("app_name")); + row.setState(rs.getString("state")); + + Timestamp updateTs = rs.getTimestamp("update_time"); + row.setUpdateTime(updateTs != null ? updateTs.toInstant() : null); + + return row; + }; + + private static final RowMapper USER_STATE_MAPPER = + rs -> { + UserStateRow row = new UserStateRow(); + row.setAppName(rs.getString("app_name")); + row.setUserId(rs.getString("user_id")); + row.setState(rs.getString("state")); + + Timestamp updateTs = rs.getTimestamp("update_time"); + row.setUpdateTime(updateTs != null ? updateTs.toInstant() : null); + + return row; + }; + + public Optional getAppState(JdbcOperations ops, String appName) throws SQLException { + String sql = "SELECT * FROM app_states WHERE app_name = :appName"; + + Map params = new HashMap<>(); + params.put("appName", appName); + + return ops.queryForObject(sql, params, APP_STATE_MAPPER); + } + + public Optional getAppStateForUpdate(JdbcOperations ops, String appName) + throws SQLException { + String sql = "SELECT * FROM app_states WHERE app_name = :appName " + dialect.forUpdateSyntax(); + + Map params = new HashMap<>(); + params.put("appName", appName); + + return ops.queryForObject(sql, params, APP_STATE_MAPPER); + } + + public void upsertAppState(JdbcOperations ops, AppStateRow appState) throws SQLException { + logger.debug("Upserting app state for app: {}", appState.getAppName()); + dialect.upsertAppState(ops, appState); + logger.debug("App state upserted successfully for app: {}", appState.getAppName()); + } + + public Optional getUserState(JdbcOperations ops, String appName, String userId) + throws SQLException { + String sql = "SELECT * FROM user_states WHERE app_name = :appName AND user_id = :userId"; + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + + return ops.queryForObject(sql, params, USER_STATE_MAPPER); + } + + public Optional getUserStateForUpdate( + JdbcOperations ops, String appName, String userId) throws SQLException { + String sql = + "SELECT * FROM user_states WHERE app_name = :appName AND user_id = :userId " + + dialect.forUpdateSyntax(); + + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + + return ops.queryForObject(sql, params, USER_STATE_MAPPER); + } + + public void upsertUserState(JdbcOperations ops, UserStateRow userState) throws SQLException { + logger.debug( + "Upserting user state for app: {}, user: {}", + userState.getAppName(), + userState.getUserId()); + dialect.upsertUserState(ops, userState); + logger.debug( + "User state upserted successfully for app: {}, user: {}", + userState.getAppName(), + userState.getUserId()); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/DialectDetector.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/DialectDetector.java new file mode 100644 index 000000000..0aa7afe77 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/DialectDetector.java @@ -0,0 +1,45 @@ +package com.google.adk.sessions.dialect; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.SQLException; + +public class DialectDetector { + + public static SqlDialect detect(Connection connection) throws SQLException { + DatabaseMetaData metaData = connection.getMetaData(); + String productName = metaData.getDatabaseProductName().toLowerCase(); + + if (productName.contains("postgresql")) { + return new PostgresDialect(); + } else if (productName.contains("mysql")) { + return new MySqlDialect(); + } else if (productName.contains("h2")) { + return new H2Dialect(); + } else if (productName.contains("spanner")) { + return new SpannerDialect(); + } else { + throw new IllegalArgumentException( + "Unsupported database: " + + productName + + ". " + + "Supported databases: PostgreSQL, MySQL, H2, Cloud Spanner"); + } + } + + public static SqlDialect detectFromJdbcUrl(String jdbcUrl) { + String url = jdbcUrl.toLowerCase(); + + if (url.startsWith("jdbc:postgresql:")) { + return new PostgresDialect(); + } else if (url.startsWith("jdbc:mysql:")) { + return new MySqlDialect(); + } else if (url.startsWith("jdbc:h2:")) { + return new H2Dialect(); + } else if (url.startsWith("jdbc:cloudspanner:")) { + return new SpannerDialect(); + } else { + throw new IllegalArgumentException("Cannot detect dialect from JDBC URL: " + jdbcUrl); + } + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/H2Dialect.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/H2Dialect.java new file mode 100644 index 000000000..d75a66dc7 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/H2Dialect.java @@ -0,0 +1,33 @@ +package com.google.adk.sessions.dialect; + +public class H2Dialect implements SqlDialect { + + @Override + public String dialectName() { + return "H2"; + } + + @Override + public String jsonCastSyntax() { + return ""; + } + + @Override + public String forUpdateSyntax() { + return "FOR UPDATE"; + } + + @Override + public String upsertAppStateSql() { + return "MERGE INTO app_states (app_name, state, update_time) " + + "KEY (app_name) " + + "VALUES (:appName, :state, :updateTime)"; + } + + @Override + public String upsertUserStateSql() { + return "MERGE INTO user_states (app_name, user_id, state, update_time) " + + "KEY (app_name, user_id) " + + "VALUES (:appName, :userId, :state, :updateTime)"; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/MySqlDialect.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/MySqlDialect.java new file mode 100644 index 000000000..b30e37d78 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/MySqlDialect.java @@ -0,0 +1,33 @@ +package com.google.adk.sessions.dialect; + +public class MySqlDialect implements SqlDialect { + + @Override + public String dialectName() { + return "MySQL"; + } + + @Override + public String jsonCastSyntax() { + return ""; + } + + @Override + public String forUpdateSyntax() { + return "FOR UPDATE"; + } + + @Override + public String upsertAppStateSql() { + return "INSERT INTO app_states (app_name, state, update_time) " + + "VALUES (:appName, :state, :updateTime) " + + "ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = VALUES(update_time)"; + } + + @Override + public String upsertUserStateSql() { + return "INSERT INTO user_states (app_name, user_id, state, update_time) " + + "VALUES (:appName, :userId, :state, :updateTime) " + + "ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = VALUES(update_time)"; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/PostgresDialect.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/PostgresDialect.java new file mode 100644 index 000000000..c69c8f1b1 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/PostgresDialect.java @@ -0,0 +1,35 @@ +package com.google.adk.sessions.dialect; + +public class PostgresDialect implements SqlDialect { + + @Override + public String dialectName() { + return "PostgreSQL"; + } + + @Override + public String jsonCastSyntax() { + return "::jsonb"; + } + + @Override + public String forUpdateSyntax() { + return "FOR UPDATE"; + } + + @Override + public String upsertAppStateSql() { + return "INSERT INTO app_states (app_name, state, update_time) " + + "VALUES (:appName, :state::jsonb, :updateTime) " + + "ON CONFLICT (app_name) " + + "DO UPDATE SET state = EXCLUDED.state, update_time = EXCLUDED.update_time"; + } + + @Override + public String upsertUserStateSql() { + return "INSERT INTO user_states (app_name, user_id, state, update_time) " + + "VALUES (:appName, :userId, :state::jsonb, :updateTime) " + + "ON CONFLICT (app_name, user_id) " + + "DO UPDATE SET state = EXCLUDED.state, update_time = EXCLUDED.update_time"; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SpannerDialect.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SpannerDialect.java new file mode 100644 index 000000000..b163a0013 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SpannerDialect.java @@ -0,0 +1,65 @@ +package com.google.adk.sessions.dialect; + +import com.google.adk.sessions.util.JdbcTemplate.JdbcOperations; +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Map; + +public class SpannerDialect implements SqlDialect { + + @Override + public String dialectName() { + return "Cloud Spanner"; + } + + @Override + public String jsonCastSyntax() { + return ""; + } + + @Override + public String jsonValue(String paramName) { + return "PARSE_JSON(" + paramName + ")"; + } + + @Override + public String forUpdateSyntax() { + return ""; + } + + @Override + public String upsertAppStateSql() { + return "INSERT OR UPDATE app_states (app_name, state, update_time) " + + "VALUES (:appName, " + + jsonValue(":state") + + ", :updateTime)"; + } + + @Override + public String upsertUserStateSql() { + return "INSERT OR UPDATE user_states (app_name, user_id, state, update_time) " + + "VALUES (:appName, :userId, " + + jsonValue(":state") + + ", :updateTime)"; + } + + @Override + public void deleteSession(JdbcOperations ops, String appName, String userId, String sessionId) + throws SQLException { + String deleteEventsSql = + "DELETE FROM events WHERE app_name = :appName AND user_id = :userId AND session_id = :sessionId"; + Map eventsParams = new HashMap<>(); + eventsParams.put("appName", appName); + eventsParams.put("userId", userId); + eventsParams.put("sessionId", sessionId); + ops.update(deleteEventsSql, eventsParams); + + String deleteSessionSql = + "DELETE FROM sessions WHERE app_name = :appName AND user_id = :userId AND id = :id"; + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("id", sessionId); + ops.update(deleteSessionSql, params); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SqlDialect.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SqlDialect.java new file mode 100644 index 000000000..b9089033e --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/dialect/SqlDialect.java @@ -0,0 +1,56 @@ +package com.google.adk.sessions.dialect; + +import com.google.adk.sessions.model.AppStateRow; +import com.google.adk.sessions.model.UserStateRow; +import com.google.adk.sessions.util.JdbcTemplate.JdbcOperations; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.HashMap; +import java.util.Map; + +public interface SqlDialect { + + String dialectName(); + + String jsonCastSyntax(); + + String forUpdateSyntax(); + + String upsertAppStateSql(); + + String upsertUserStateSql(); + + default String jsonValue(String paramName) { + return paramName + jsonCastSyntax(); + } + + default void upsertAppState(JdbcOperations ops, AppStateRow appState) throws SQLException { + String sql = upsertAppStateSql(); + Map params = new HashMap<>(); + params.put("appName", appState.getAppName()); + params.put("state", appState.getState()); + params.put("updateTime", Timestamp.from(appState.getUpdateTime())); + ops.update(sql, params); + } + + default void upsertUserState(JdbcOperations ops, UserStateRow userState) throws SQLException { + String sql = upsertUserStateSql(); + Map params = new HashMap<>(); + params.put("appName", userState.getAppName()); + params.put("userId", userState.getUserId()); + params.put("state", userState.getState()); + params.put("updateTime", Timestamp.from(userState.getUpdateTime())); + ops.update(sql, params); + } + + default void deleteSession(JdbcOperations ops, String appName, String userId, String sessionId) + throws SQLException { + String sql = + "DELETE FROM sessions WHERE app_name = :appName AND user_id = :userId AND id = :id"; + Map params = new HashMap<>(); + params.put("appName", appName); + params.put("userId", userId); + params.put("id", sessionId); + ops.update(sql, params); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/AppStateRow.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/AppStateRow.java new file mode 100644 index 000000000..8060a0416 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/AppStateRow.java @@ -0,0 +1,33 @@ +package com.google.adk.sessions.model; + +import java.time.Instant; + +public class AppStateRow { + private String appName; + private String state; + private Instant updateTime; + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public Instant getUpdateTime() { + return updateTime; + } + + public void setUpdateTime(Instant updateTime) { + this.updateTime = updateTime; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/EventRow.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/EventRow.java new file mode 100644 index 000000000..607c5df9d --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/EventRow.java @@ -0,0 +1,69 @@ +package com.google.adk.sessions.model; + +import java.time.Instant; + +public class EventRow { + private String id; + private String appName; + private String userId; + private String sessionId; + private String invocationId; + private Instant timestamp; + private String eventData; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getSessionId() { + return sessionId; + } + + public void setSessionId(String sessionId) { + this.sessionId = sessionId; + } + + public String getInvocationId() { + return invocationId; + } + + public void setInvocationId(String invocationId) { + this.invocationId = invocationId; + } + + public Instant getTimestamp() { + return timestamp; + } + + public void setTimestamp(Instant timestamp) { + this.timestamp = timestamp; + } + + public String getEventData() { + return eventData; + } + + public void setEventData(String eventData) { + this.eventData = eventData; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/SessionRow.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/SessionRow.java new file mode 100644 index 000000000..0dcf1a97e --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/SessionRow.java @@ -0,0 +1,60 @@ +package com.google.adk.sessions.model; + +import java.time.Instant; + +public class SessionRow { + private String appName; + private String userId; + private String id; + private String state; + private Instant createTime; + private Instant updateTime; + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public Instant getCreateTime() { + return createTime; + } + + public void setCreateTime(Instant createTime) { + this.createTime = createTime; + } + + public Instant getUpdateTime() { + return updateTime; + } + + public void setUpdateTime(Instant updateTime) { + this.updateTime = updateTime; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/UserStateRow.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/UserStateRow.java new file mode 100644 index 000000000..bd02f0ddb --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/model/UserStateRow.java @@ -0,0 +1,42 @@ +package com.google.adk.sessions.model; + +import java.time.Instant; + +public class UserStateRow { + private String appName; + private String userId; + private String state; + private Instant updateTime; + + public String getAppName() { + return appName; + } + + public void setAppName(String appName) { + this.appName = appName; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public Instant getUpdateTime() { + return updateTime; + } + + public void setUpdateTime(Instant updateTime) { + this.updateTime = updateTime; + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/JdbcTemplate.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/JdbcTemplate.java new file mode 100644 index 000000000..b0aa3ee66 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/JdbcTemplate.java @@ -0,0 +1,101 @@ +package com.google.adk.sessions.util; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import javax.sql.DataSource; + +public class JdbcTemplate { + + private final DataSource dataSource; + + public JdbcTemplate(DataSource dataSource) { + this.dataSource = dataSource; + } + + public T inTransaction(TransactionCallback callback) throws SQLException { + try (Connection conn = dataSource.getConnection()) { + boolean originalAutoCommit = conn.getAutoCommit(); + try { + conn.setAutoCommit(false); + T result = callback.doInTransaction(new JdbcOperations(conn)); + conn.commit(); + return result; + } catch (Exception e) { + conn.rollback(); + throw e; + } finally { + conn.setAutoCommit(originalAutoCommit); + } + } + } + + @FunctionalInterface + public interface TransactionCallback { + T doInTransaction(JdbcOperations ops) throws SQLException; + } + + public static class JdbcOperations { + private final Connection connection; + + JdbcOperations(Connection connection) { + this.connection = connection; + } + + public Connection getConnection() { + return connection; + } + + public Optional queryForObject( + String sql, Map params, RowMapper mapper) throws SQLException { + NamedParameterSupport nps = NamedParameterSupport.parse(sql); + + try (PreparedStatement ps = connection.prepareStatement(nps.getParsedSql())) { + nps.setParameters(ps, params); + + try (ResultSet rs = ps.executeQuery()) { + if (rs.next()) { + return Optional.of(mapper.mapRow(rs)); + } + return Optional.empty(); + } + } + } + + public List query(String sql, Map params, RowMapper mapper) + throws SQLException { + NamedParameterSupport nps = NamedParameterSupport.parse(sql); + List results = new ArrayList<>(); + + try (PreparedStatement ps = connection.prepareStatement(nps.getParsedSql())) { + nps.setParameters(ps, params); + + try (ResultSet rs = ps.executeQuery()) { + while (rs.next()) { + results.add(mapper.mapRow(rs)); + } + } + } + + return results; + } + + public int update(String sql, Map params) throws SQLException { + NamedParameterSupport nps = NamedParameterSupport.parse(sql); + + try (PreparedStatement ps = connection.prepareStatement(nps.getParsedSql())) { + nps.setParameters(ps, params); + return ps.executeUpdate(); + } + } + + public int execute(String sql, Map params) throws SQLException { + return update(sql, params); + } + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/NamedParameterSupport.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/NamedParameterSupport.java new file mode 100644 index 000000000..4c4ab705d --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/NamedParameterSupport.java @@ -0,0 +1,58 @@ +package com.google.adk.sessions.util; + +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class NamedParameterSupport { + + private static final Pattern NAMED_PARAM_PATTERN = Pattern.compile("(? parameterNames; + + private NamedParameterSupport(String parsedSql, List parameterNames) { + this.parsedSql = parsedSql; + this.parameterNames = parameterNames; + } + + public static NamedParameterSupport parse(String namedSql) { + List parameterNames = new ArrayList<>(); + Matcher matcher = NAMED_PARAM_PATTERN.matcher(namedSql); + StringBuffer parsedSql = new StringBuffer(); + + while (matcher.find()) { + String paramName = matcher.group(1); + parameterNames.add(paramName); + matcher.appendReplacement(parsedSql, "?"); + } + matcher.appendTail(parsedSql); + + return new NamedParameterSupport(parsedSql.toString(), parameterNames); + } + + public String getParsedSql() { + return parsedSql; + } + + public void setParameters(PreparedStatement ps, Map params) throws SQLException { + for (int i = 0; i < parameterNames.size(); i++) { + String paramName = parameterNames.get(i); + + if (!params.containsKey(paramName)) { + throw new IllegalArgumentException("Missing parameter: " + paramName); + } + + Object value = params.get(paramName); + ps.setObject(i + 1, value); + } + } + + public List getParameterNames() { + return new ArrayList<>(parameterNames); + } +} diff --git a/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/RowMapper.java b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/RowMapper.java new file mode 100644 index 000000000..312c67d60 --- /dev/null +++ b/contrib/database-session-service/src/main/java/com/google/adk/sessions/util/RowMapper.java @@ -0,0 +1,9 @@ +package com.google.adk.sessions.util; + +import java.sql.ResultSet; +import java.sql.SQLException; + +@FunctionalInterface +public interface RowMapper { + T mapRow(ResultSet rs) throws SQLException; +} diff --git a/contrib/database-session-service/src/main/resources/db/migration/h2/V1__Initial_schema.sql b/contrib/database-session-service/src/main/resources/db/migration/h2/V1__Initial_schema.sql new file mode 100644 index 000000000..b15292da3 --- /dev/null +++ b/contrib/database-session-service/src/main/resources/db/migration/h2/V1__Initial_schema.sql @@ -0,0 +1,66 @@ +-- V1__Initial_schema.sql for H2 Database +-- Initial database schema for ADK DatabaseSessionService (v1 format) +-- This schema matches Python ADK v1 with simplified event storage using JSON + +-- Create metadata table for schema versioning +CREATE TABLE IF NOT EXISTS adk_internal_metadata ( + "KEY" VARCHAR(128) PRIMARY KEY, + "VALUE" VARCHAR(256) +); + +-- Insert schema version (1 = v1 JSON schema format, compatible with Python ADK) +MERGE INTO adk_internal_metadata ("KEY", "VALUE") VALUES ('schema_version', '1'); + +-- Create sessions table +CREATE TABLE IF NOT EXISTS sessions ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + id VARCHAR(128) NOT NULL, + state CLOB, + create_time TIMESTAMP(6), + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id, id) +); + +-- Create events table (v1 format with event_data JSON column) +CREATE TABLE IF NOT EXISTS events ( + id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + timestamp TIMESTAMP(6), + event_data CLOB, + PRIMARY KEY (id, app_name, user_id, session_id), + FOREIGN KEY (app_name, user_id, session_id) + REFERENCES sessions(app_name, user_id, id) + ON DELETE CASCADE +); + +-- Create app states table +CREATE TABLE IF NOT EXISTS app_states ( + app_name VARCHAR(128) NOT NULL, + state CLOB, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name) +); + +-- Create user states table +CREATE TABLE IF NOT EXISTS user_states ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state CLOB, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) +); + +-- Add indexes to improve query performance + +-- Index for looking up sessions by app_name and user_id +CREATE INDEX IF NOT EXISTS idx_sessions_app_user ON sessions(app_name, user_id); + +-- Index for looking up events by session +CREATE INDEX IF NOT EXISTS idx_events_session ON events(app_name, user_id, session_id); + +-- Index for sorting events by timestamp +CREATE INDEX IF NOT EXISTS idx_events_timestamp ON events(timestamp); diff --git a/contrib/database-session-service/src/main/resources/db/migration/mysql/V1__Initial_schema.sql b/contrib/database-session-service/src/main/resources/db/migration/mysql/V1__Initial_schema.sql new file mode 100644 index 000000000..6d8184799 --- /dev/null +++ b/contrib/database-session-service/src/main/resources/db/migration/mysql/V1__Initial_schema.sql @@ -0,0 +1,66 @@ +-- V1__Initial_schema.sql for MySQL +-- Initial database schema for ADK DatabaseSessionService (v1 format) +-- This schema matches Python ADK v1 with simplified event storage using JSON + +-- Create metadata table for schema versioning +CREATE TABLE IF NOT EXISTS adk_internal_metadata ( + `key` VARCHAR(128) PRIMARY KEY, + value VARCHAR(256) +); + +-- Insert schema version (1 = v1 JSON schema format, compatible with Python ADK) +INSERT INTO adk_internal_metadata (`key`, value) VALUES ('schema_version', '1'); + +-- Create sessions table +CREATE TABLE IF NOT EXISTS sessions ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + id VARCHAR(128) NOT NULL, + state LONGTEXT, + create_time TIMESTAMP(6), + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id, id) +); + +-- Create events table (v1 format with event_data JSON column) +CREATE TABLE IF NOT EXISTS events ( + id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + timestamp TIMESTAMP(6), + event_data LONGTEXT, + PRIMARY KEY (id, app_name, user_id, session_id), + FOREIGN KEY (app_name, user_id, session_id) + REFERENCES sessions(app_name, user_id, id) + ON DELETE CASCADE +); + +-- Create app states table +CREATE TABLE IF NOT EXISTS app_states ( + app_name VARCHAR(128) NOT NULL, + state LONGTEXT, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name) +); + +-- Create user states table +CREATE TABLE IF NOT EXISTS user_states ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state LONGTEXT, + update_time TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) +); + +-- Add indexes to improve query performance + +-- Index for looking up sessions by app_name and user_id +CREATE INDEX idx_sessions_app_user ON sessions(app_name, user_id); + +-- Index for looking up events by session +CREATE INDEX idx_events_session ON events(app_name, user_id, session_id); + +-- Index for sorting events by timestamp +CREATE INDEX idx_events_timestamp ON events(timestamp); diff --git a/contrib/database-session-service/src/main/resources/db/migration/postgresql/V1__Initial_schema.sql b/contrib/database-session-service/src/main/resources/db/migration/postgresql/V1__Initial_schema.sql new file mode 100644 index 000000000..810a2eb87 --- /dev/null +++ b/contrib/database-session-service/src/main/resources/db/migration/postgresql/V1__Initial_schema.sql @@ -0,0 +1,66 @@ +-- V1__Initial_schema.sql for PostgreSQL +-- Initial database schema for ADK DatabaseSessionService (v1 format) +-- This schema matches Python ADK v1 with simplified event storage using JSON + +-- Create metadata table for schema versioning +CREATE TABLE IF NOT EXISTS adk_internal_metadata ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(256) +); + +-- Insert schema version (1 = v1 JSON schema format, compatible with Python ADK) +INSERT INTO adk_internal_metadata (key, value) VALUES ('schema_version', '1'); + +-- Create sessions table +CREATE TABLE IF NOT EXISTS sessions ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + id VARCHAR(128) NOT NULL, + state JSONB, + create_time TIMESTAMP, + update_time TIMESTAMP, + PRIMARY KEY (app_name, user_id, id) +); + +-- Create events table (v1 format with event_data JSON column) +CREATE TABLE IF NOT EXISTS events ( + id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + timestamp TIMESTAMP, + event_data JSONB, + PRIMARY KEY (id, app_name, user_id, session_id), + FOREIGN KEY (app_name, user_id, session_id) + REFERENCES sessions(app_name, user_id, id) + ON DELETE CASCADE +); + +-- Create app states table +CREATE TABLE IF NOT EXISTS app_states ( + app_name VARCHAR(128) NOT NULL, + state JSONB, + update_time TIMESTAMP, + PRIMARY KEY (app_name) +); + +-- Create user states table +CREATE TABLE IF NOT EXISTS user_states ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB, + update_time TIMESTAMP, + PRIMARY KEY (app_name, user_id) +); + +-- Add indexes to improve query performance + +-- Index for looking up sessions by app_name and user_id +CREATE INDEX IF NOT EXISTS idx_sessions_app_user ON sessions(app_name, user_id); + +-- Index for looking up events by session +CREATE INDEX IF NOT EXISTS idx_events_session ON events(app_name, user_id, session_id); + +-- Index for sorting events by timestamp +CREATE INDEX IF NOT EXISTS idx_events_timestamp ON events(timestamp); diff --git a/contrib/database-session-service/src/main/resources/db/migration/spanner/V1__Initial_schema.sql b/contrib/database-session-service/src/main/resources/db/migration/spanner/V1__Initial_schema.sql new file mode 100644 index 000000000..f169d0008 --- /dev/null +++ b/contrib/database-session-service/src/main/resources/db/migration/spanner/V1__Initial_schema.sql @@ -0,0 +1,59 @@ +-- V1__Initial_schema.sql for Cloud Spanner +-- Initial database schema for ADK DatabaseSessionService (v1 format) +-- This schema matches Python ADK v1 with simplified event storage using JSON + +-- Create metadata table for schema versioning +CREATE TABLE adk_internal_metadata ( + key STRING(128) NOT NULL, + value STRING(256) +) PRIMARY KEY (key); + +-- Insert schema version (1 = v1 JSON schema format, compatible with Python ADK) +INSERT INTO adk_internal_metadata (key, value) VALUES ('schema_version', '1'); + +-- Create sessions table +CREATE TABLE sessions ( + app_name STRING(128) NOT NULL, + user_id STRING(128) NOT NULL, + id STRING(128) NOT NULL, + state JSON, + create_time TIMESTAMP, + update_time TIMESTAMP +) PRIMARY KEY (app_name, user_id, id); + +-- Create events table (v1 format with event_data JSON column) +-- Note: Spanner does not support traditional FOREIGN KEY constraints with ON DELETE CASCADE +-- We avoid INTERLEAVE IN PARENT to keep the schema simpler and compatible with the DAO layer +-- Applications must handle cascade deletes manually if needed +CREATE TABLE events ( + id STRING(128) NOT NULL, + app_name STRING(128) NOT NULL, + user_id STRING(128) NOT NULL, + session_id STRING(128) NOT NULL, + invocation_id STRING(256), + timestamp TIMESTAMP, + event_data JSON +) PRIMARY KEY (id, app_name, user_id, session_id); + +-- Create app states table +CREATE TABLE app_states ( + app_name STRING(128) NOT NULL, + state JSON, + update_time TIMESTAMP +) PRIMARY KEY (app_name); + +-- Create user states table +CREATE TABLE user_states ( + app_name STRING(128) NOT NULL, + user_id STRING(128) NOT NULL, + state JSON, + update_time TIMESTAMP +) PRIMARY KEY (app_name, user_id); + +-- Add indexes to improve query performance + +-- Index for looking up sessions by app_name and user_id +CREATE INDEX idx_sessions_app_user ON sessions(app_name, user_id); + +-- Index for sorting events by timestamp +CREATE INDEX idx_events_timestamp ON events(timestamp); diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppUserStateLockingTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppUserStateLockingTest.java new file mode 100644 index 000000000..61d6505ee --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppUserStateLockingTest.java @@ -0,0 +1,298 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests for pessimistic locking on app_states and user_states tables. + * + *

This test verifies that concurrent updates to app-level and user-level state from multiple + * sessions do not result in lost updates. Without pessimistic locking, concurrent read-modify-write + * operations can overwrite each other's changes. + */ +public class AppUserStateLockingTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:app_user_locking_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "app-user-lock-test"; + private static final String TEST_USER_ID = "test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + /** + * Tests that concurrent updates to app state from multiple threads on the SAME session preserve + * all changes. + * + *

Scenario: - 10 threads concurrently append events to the SAME session - Each event sets a + * unique key in app state - Expected: All 10 keys present - Without locking on app_states: some + * keys would be lost + * + *

Note: This tests the real-world pattern where events carry state deltas, not + * read-modify-write. + */ + @Test + public void testAppStateConcurrentUpdates_noLostUpdates() throws InterruptedException { + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + // Create initial session + String sharedSessionId = "shared-session"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sharedSessionId) + .blockingGet(); + + // Each thread appends event with a unique app state key + for (int i = 0; i < threadCount; i++) { + final int threadNum = i; + executor.submit( + () -> { + try { + // Each thread sets its own unique key in app state + ConcurrentMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("app:thread_" + threadNum, threadNum); + + Event event = + Event.builder() + .id("event-" + threadNum) + .author("thread-" + threadNum) + .content(Content.fromParts(Part.fromText("Increment app counter"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + } catch (Exception e) { + throw new RuntimeException("Thread " + threadNum + " failed", e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS), "Threads did not complete in time"); + executor.shutdown(); + + // Verify final counter value + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(finalSession); + + // Check that all thread keys are present + for (int i = 0; i < threadCount; i++) { + String key = "app:thread_" + i; + assertTrue(finalSession.state().containsKey(key), "app:thread_" + i + " should exist"); + assertEquals(i, finalSession.state().get(key), "app:thread_" + i + " should equal " + i); + } + } + + /** + * Tests that concurrent updates to user state from multiple sessions preserve all changes. + * + *

Scenario: - Same user has 10 different sessions (e.g., phone, laptop, tablet) - Each session + * concurrently increments user:notification_count - Expected final value: 10 (all updates + * preserved) - Without locking: final value would be < 10 (lost updates) + */ + @Test + public void testUserStateConcurrentUpdates_noLostUpdates() throws InterruptedException { + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + // Create initial session + String sharedSessionId = "shared-session"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sharedSessionId) + .blockingGet(); + + // Each thread appends event with a unique user state key + for (int i = 0; i < threadCount; i++) { + final int threadNum = i; + executor.submit( + () -> { + try { + // Each thread sets its own unique key in user state + ConcurrentMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("user:thread_" + threadNum, threadNum); + + Event event = + Event.builder() + .id("notif-" + threadNum) + .author("device-" + threadNum) + .content(Content.fromParts(Part.fromText("New notification"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + } catch (Exception e) { + throw new RuntimeException("Thread " + threadNum + " failed", e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS), "Threads did not complete in time"); + executor.shutdown(); + + // Verify final notification count + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(finalSession); + + // Check that all thread keys are present + for (int i = 0; i < threadCount; i++) { + String key = "user:thread_" + i; + assertTrue(finalSession.state().containsKey(key), "user:thread_" + i + " should exist"); + assertEquals(i, finalSession.state().get(key), "user:thread_" + i + " should equal " + i); + } + } + + /** + * Tests that concurrent updates to both app and user state work correctly. + * + *

Scenario: - 5 sessions concurrently update both app:total_requests and user:request_count - + * Tests that locks on app_states and user_states don't deadlock + */ + @Test + public void testConcurrentAppAndUserStateUpdates() throws InterruptedException { + int threadCount = 5; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + // Create initial session + String sharedSessionId = "shared-session"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sharedSessionId) + .blockingGet(); + + // Each thread appends event with both app and user state updates + for (int i = 0; i < threadCount; i++) { + final int threadNum = i; + executor.submit( + () -> { + try { + // Each thread sets unique keys in both app and user state + ConcurrentMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("app:req_" + threadNum, threadNum); + stateDelta.put("user:req_" + threadNum, threadNum); + + Event event = + Event.builder() + .id("req-" + threadNum) + .author("thread-" + threadNum) + .content(Content.fromParts(Part.fromText("API request"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + } catch (Exception e) { + throw new RuntimeException("Thread " + threadNum + " failed", e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS), "Threads did not complete in time"); + executor.shutdown(); + + // Verify both counters + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sharedSessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(finalSession); + + // Check that all app and user keys are present + for (int i = 0; i < threadCount; i++) { + String appKey = "app:req_" + i; + String userKey = "user:req_" + i; + assertTrue(finalSession.state().containsKey(appKey), appKey + " should exist"); + assertTrue(finalSession.state().containsKey(userKey), userKey + " should exist"); + assertEquals(i, finalSession.state().get(appKey), appKey + " should equal " + i); + assertEquals(i, finalSession.state().get(userKey), userKey + " should equal " + i); + } + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppendEventRaceConditionTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppendEventRaceConditionTest.java new file mode 100644 index 000000000..aa0d676a4 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/AppendEventRaceConditionTest.java @@ -0,0 +1,300 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import java.time.Instant; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Test that demonstrates the race condition between event emission and persistence. + * + *

This test proves that using {@code doOnNext()} to append events creates a race condition where + * events flow downstream before being persisted to the database, while using {@code flatMap()} + * correctly waits for persistence to complete. + */ +@RunWith(JUnit4.class) +public class AppendEventRaceConditionTest { + + private DatabaseSessionService sessionService; + private static final String APP_NAME = "race-test-app"; + private static final String USER_ID = "race-test-user"; + + @Before + public void setUp() throws Exception { + sessionService = + new DatabaseSessionService("jdbc:h2:mem:race_test_db;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="); + } + + @After + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + /** + * This test demonstrates the RACE CONDITION with doOnNext(). + * + *

Timeline: T=0ms: Event emitted T=1ms: doOnNext() fires appendEvent() (doesn't wait!) T=2ms: + * Event flows downstream immediately T=5ms: We query listEvents() T=6ms: Query reads database + * Result: Event might NOT be in database yet! ← RACE CONDITION T=100ms: Database write finally + * completes + */ + @Test + public void testDoOnNext_hasRaceCondition() throws Exception { + Session testSession = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), "session-doOnNext") + .blockingGet(); + + AtomicInteger eventsSeenInQuery = new AtomicInteger(0); + AtomicBoolean appendStarted = new AtomicBoolean(false); + CountDownLatch queryLatch = new CountDownLatch(1); + + Event testEvent = + Event.builder() + .id("race-event-1") + .invocationId("inv-1") + .author("test-agent") + .content(Content.builder().parts(Part.builder().text("Test").build()).build()) + .actions(EventActions.builder().build()) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Flowable.just(testEvent) + .doOnNext( + event -> { + appendStarted.set(true); + sessionService.appendEvent(testSession, event); + }) + .doOnNext( + event -> { + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + List events = + sessionService + .listEvents(APP_NAME, USER_ID, testSession.id()) + .blockingGet() + .events(); + eventsSeenInQuery.set(events.size()); + queryLatch.countDown(); + }) + .blockingSubscribe(); + + queryLatch.await(5, TimeUnit.SECONDS); + + System.out.println( + "doOnNext() test - Events seen in query: " + + eventsSeenInQuery.get() + + " (expected 0 or 1 due to race)"); + } + + /** + * This test demonstrates the CORRECT BEHAVIOR with flatMap(). + * + *

Timeline: T=0ms: Event emitted T=1ms: flatMap() calls appendEvent() T=2ms: Waits for + * appendEvent() Single to complete T=100ms: Database write completes T=101ms: Event flows + * downstream T=102ms: We query listEvents() T=103ms: Query reads database Result: Event IS in + * database! ← CORRECT + */ + @Test + public void testFlatMap_waitsForPersistence() throws Exception { + Session testSession = + sessionService + .createSession(APP_NAME, USER_ID, new ConcurrentHashMap<>(), "session-flatMap") + .blockingGet(); + + AtomicInteger eventsSeenInQuery = new AtomicInteger(0); + CountDownLatch queryLatch = new CountDownLatch(1); + + Event testEvent = + Event.builder() + .id("race-event-2") + .invocationId("inv-2") + .author("test-agent") + .content(Content.builder().parts(Part.builder().text("Test").build()).build()) + .actions(EventActions.builder().build()) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Flowable.just(testEvent) + .flatMap( + event -> + sessionService + .appendEvent(testSession, event) + .toFlowable() + .onErrorResumeNext( + error -> { + System.err.println("Failed to append event: " + error.getMessage()); + return Flowable.just(event); + })) + .doOnNext( + event -> { + List events = + sessionService + .listEvents(APP_NAME, USER_ID, testSession.id()) + .blockingGet() + .events(); + eventsSeenInQuery.set(events.size()); + queryLatch.countDown(); + }) + .blockingSubscribe(); + + queryLatch.await(5, TimeUnit.SECONDS); + + System.out.println( + "flatMap() test - Events seen in query: " + + eventsSeenInQuery.get() + + " (expected 1 - always present)"); + + assertThat(eventsSeenInQuery.get()).isEqualTo(1); + } + + /** + * This test runs multiple iterations to increase the chance of catching the race condition. + * + *

With doOnNext(), we expect to see the race condition manifest as inconsistent query results. + * With flatMap(), we expect 100% consistency. + */ + @Test + public void testRaceCondition_multipleIterations() throws Exception { + int iterations = 10; + int doOnNextMisses = 0; + int flatMapMisses = 0; + + for (int i = 0; i < iterations; i++) { + final int iteration = i; + Session session = + sessionService + .createSession( + APP_NAME, "user-" + iteration, new ConcurrentHashMap<>(), "session-" + iteration) + .blockingGet(); + + Event event = + Event.builder() + .id("event-" + iteration) + .invocationId("inv-" + iteration) + .author("test") + .content( + Content.builder().parts(Part.builder().text("Test " + iteration).build()).build()) + .actions(EventActions.builder().build()) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + AtomicInteger doOnNextCount = new AtomicInteger(0); + CountDownLatch doOnNextLatch = new CountDownLatch(1); + + Flowable.just(event) + .doOnNext(e -> sessionService.appendEvent(session, e)) + .delay(10, TimeUnit.MILLISECONDS) + .doOnNext( + e -> { + int count = + sessionService + .listEvents(APP_NAME, "user-" + iteration, "session-" + iteration) + .blockingGet() + .events() + .size(); + doOnNextCount.set(count); + doOnNextLatch.countDown(); + }) + .blockingSubscribe(); + + doOnNextLatch.await(2, TimeUnit.SECONDS); + if (doOnNextCount.get() == 0) { + doOnNextMisses++; + } + + Session session2 = + sessionService + .createSession( + APP_NAME, + "user2-" + iteration, + new ConcurrentHashMap<>(), + "session2-" + iteration) + .blockingGet(); + + Event event2 = + Event.builder() + .id("event2-" + iteration) + .invocationId("inv2-" + iteration) + .author("test") + .content( + Content.builder() + .parts(Part.builder().text("Test2 " + iteration).build()) + .build()) + .actions(EventActions.builder().build()) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + AtomicInteger flatMapCount = new AtomicInteger(0); + CountDownLatch flatMapLatch = new CountDownLatch(1); + + Flowable.just(event2) + .flatMap( + e -> + sessionService + .appendEvent(session2, e) + .toFlowable() + .onErrorResumeNext(err -> Flowable.just(e))) + .doOnNext( + e -> { + int count = + sessionService + .listEvents(APP_NAME, "user2-" + iteration, "session2-" + iteration) + .blockingGet() + .events() + .size(); + flatMapCount.set(count); + flatMapLatch.countDown(); + }) + .blockingSubscribe(); + + flatMapLatch.await(2, TimeUnit.SECONDS); + if (flatMapCount.get() == 0) { + flatMapMisses++; + } + } + + System.out.println("Race condition test results over " + iterations + " iterations:"); + System.out.println( + " doOnNext() misses: " + doOnNextMisses + " (race condition manifestations)"); + System.out.println(" flatMap() misses: " + flatMapMisses + " (should be 0)"); + + assertThat(flatMapMisses).isEqualTo(0); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/ConcurrentSessionOperationsTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ConcurrentSessionOperationsTest.java new file mode 100644 index 000000000..7ffb162ab --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ConcurrentSessionOperationsTest.java @@ -0,0 +1,498 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ConcurrentSessionOperationsTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:concurrency_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private static final String TEST_APP_NAME = "concurrency-test-app"; + private static final String TEST_USER_ID = "concurrency-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testConcurrentEventAppends() throws InterruptedException { + String sessionId = "concurrent-append-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int threadCount = 10; + int eventsPerThread = 5; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + for (int i = 0; i < eventsPerThread; i++) { + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("thread-" + threadId) + .content(Content.fromParts(Part.fromText("Event from thread " + threadId))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(session); + assertEquals(threadCount * eventsPerThread, session.events().size()); + } + + @Test + public void testConcurrentSessionCreations() throws InterruptedException { + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + List sessionIds = new ArrayList<>(); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + String sessionId = "session-" + threadId; + sessionIds.add(sessionId); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("thread", threadId); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId) + .blockingGet(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + for (String sessionId : sessionIds) { + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(session); + } + } + + @Test + public void testConcurrentReadsAndWrites() throws InterruptedException { + String sessionId = "read-write-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int readerCount = 5; + int writerCount = 5; + ExecutorService executor = Executors.newFixedThreadPool(readerCount + writerCount); + CountDownLatch latch = new CountDownLatch(readerCount + writerCount); + + for (int i = 0; i < writerCount; i++) { + final int writerId = i; + executor.submit( + () -> { + try { + for (int j = 0; j < 3; j++) { + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("writer-" + writerId) + .content(Content.fromParts(Part.fromText("Event"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(20); + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + for (int i = 0; i < readerCount; i++) { + executor.submit( + () -> { + try { + for (int j = 0; j < 5; j++) { + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(60, TimeUnit.SECONDS); + executor.shutdown(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(session); + assertEquals(writerCount * 3, session.events().size()); + } + + @Test + public void testConcurrentAppStateUpdates() throws InterruptedException { + int threadCount = 5; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + String sessionId = "app-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("_app_counter", threadId); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId) + .blockingGet(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "app-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(session); + assertNotNull(session.state().get("_app_counter")); + } + + @Test + public void testConcurrentCreateSessionsWithSameAppName_noStateCorruption() + throws InterruptedException { + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + TimeUnit.MILLISECONDS.sleep(10); + executor.submit( + () -> { + try { + String sessionId = "concurrent-app-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put(State.APP_PREFIX + "key_" + threadId, "value_" + threadId); + + sessionService + .createSession(TEST_APP_NAME, "user-" + threadId, state, sessionId) + .blockingGet(); + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session verifySession = + sessionService + .getSession(TEST_APP_NAME, "user-0", "concurrent-app-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(verifySession); + for (int i = 0; i < threadCount; i++) { + String key = State.APP_PREFIX + "key_" + i; + assertEquals( + "value_" + i, + verifySession.state().get(key), + "App state should contain all keys from concurrent creates without corruption"); + } + } + + @Test + public void testConcurrentCreateSessionsWithSameUser_noUserStateCorruption() + throws InterruptedException { + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + TimeUnit.MILLISECONDS.sleep(10); + executor.submit( + () -> { + try { + String sessionId = "concurrent-user-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put(State.USER_PREFIX + "pref_" + threadId, threadId * 100); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId) + .blockingGet(); + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session verifySession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "concurrent-user-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(verifySession); + for (int i = 0; i < threadCount; i++) { + String key = State.USER_PREFIX + "pref_" + i; + assertEquals( + i * 100, + verifySession.state().get(key), + "User state should contain all keys from concurrent creates without corruption"); + } + } + + @Test + public void testConcurrentCreateSessionsWithMixedStateUpdates() throws InterruptedException { + int threadCount = 8; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + String sessionId = "mixed-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put(State.APP_PREFIX + "shared_app_key", "app_value_" + threadId); + state.put(State.USER_PREFIX + "user_pref_" + threadId, threadId); + state.put("session_local", "local_" + threadId); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId) + .blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session session0 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "mixed-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(session0); + assertNotNull(session0.state().get(State.APP_PREFIX + "shared_app_key")); + + for (int i = 0; i < threadCount; i++) { + assertEquals(i, session0.state().get(State.USER_PREFIX + "user_pref_" + i), "User pref " + i); + } + + for (int i = 0; i < threadCount; i++) { + Session sessionI = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "mixed-state-" + i, Optional.empty()) + .blockingGet(); + assertEquals("local_" + i, sessionI.state().get("session_local")); + } + } + + @Test + public void testConcurrentCreateSessionsWithPreExistingAppState_noLag() + throws InterruptedException { + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "key_initial", "initial_value"); + sessionService + .createSession(TEST_APP_NAME, "initial-user", initialState, "initial-session") + .blockingGet(); + + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + String sessionId = "pre-existing-app-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put(State.APP_PREFIX + "key_" + threadId, "value_" + threadId); + + sessionService + .createSession(TEST_APP_NAME, "user-" + threadId, state, sessionId) + .blockingGet(); + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session verifySession = + sessionService + .getSession(TEST_APP_NAME, "user-0", "pre-existing-app-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(verifySession); + assertEquals( + "initial_value", + verifySession.state().get(State.APP_PREFIX + "key_initial"), + "Initial app state key should be preserved"); + for (int i = 0; i < threadCount; i++) { + String key = State.APP_PREFIX + "key_" + i; + assertEquals( + "value_" + i, + verifySession.state().get(key), + "App state should contain all keys when row pre-exists (SELECT FOR UPDATE works)"); + } + } + + @Test + public void testConcurrentCreateSessionsWithPreExistingUserState_noLag() + throws InterruptedException { + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.USER_PREFIX + "pref_initial", -1); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, "initial-session") + .blockingGet(); + + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int t = 0; t < threadCount; t++) { + final int threadId = t; + executor.submit( + () -> { + try { + String sessionId = "pre-existing-user-state-" + threadId; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put(State.USER_PREFIX + "pref_" + threadId, threadId * 100); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId) + .blockingGet(); + } catch (Exception e) { + e.printStackTrace(); + } finally { + latch.countDown(); + } + }); + } + + latch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + Session verifySession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "pre-existing-user-state-0", Optional.empty()) + .blockingGet(); + + assertNotNull(verifySession); + assertEquals( + -1, + verifySession.state().get(State.USER_PREFIX + "pref_initial"), + "Initial user state key should be preserved"); + for (int i = 0; i < threadCount; i++) { + String key = State.USER_PREFIX + "pref_" + i; + assertEquals( + i * 100, + verifySession.state().get(key), + "User state should contain all keys when row pre-exists (SELECT FOR UPDATE works)"); + } + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/ContentSerializationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ContentSerializationTest.java new file mode 100644 index 000000000..d09fdf8b7 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ContentSerializationTest.java @@ -0,0 +1,522 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.ExecutableCode; +import com.google.genai.types.FileData; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ContentSerializationTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:testdb_content;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "content-test-app"; + private static final String TEST_USER_ID = "content-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + sessionService.close(); + } + + @Test + public void testTextPartRoundTrip() { + String sessionId = "text-part-test"; + String testText = "Hello, world! This is a test message."; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText(testText))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals(1, retrievedSession.events().size()); + + Event retrievedEvent = retrievedSession.events().get(0); + assertNotNull(retrievedEvent.content()); + assertTrue(retrievedEvent.content().isPresent()); + + Content content = retrievedEvent.content().get(); + assertNotNull(content.parts()); + assertTrue(content.parts().isPresent()); + + List parts = content.parts().get(); + assertEquals(1, parts.size()); + + Part part = parts.get(0); + assertTrue(part.text().isPresent()); + assertEquals(testText, part.text().get()); + } + + @Test + public void testFunctionCallPartRoundTrip() { + String sessionId = "function-call-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + FunctionCall functionCall = + FunctionCall.builder() + .name("get_weather") + .args(Map.of("location", "San Francisco", "unit", "celsius")) + .id("call-123") + .build(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("model") + .content(Content.fromParts(Part.builder().functionCall(functionCall).build())) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals(1, retrievedSession.events().size()); + + Event retrievedEvent = retrievedSession.events().get(0); + Content content = retrievedEvent.content().get(); + Part part = content.parts().get().get(0); + + assertTrue(part.functionCall().isPresent()); + FunctionCall retrievedCall = part.functionCall().get(); + + assertEquals("get_weather", retrievedCall.name().get()); + assertEquals("call-123", retrievedCall.id().get()); + + Map retrievedArgs = retrievedCall.args().get(); + assertEquals("San Francisco", retrievedArgs.get("location")); + assertEquals("celsius", retrievedArgs.get("unit")); + } + + @Test + public void testFileDataPartRoundTrip() { + String sessionId = "file-data-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + FileData fileData = + FileData.builder() + .fileUri("gs://bucket/path/to/file.pdf") + .mimeType("application/pdf") + .build(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("user") + .content(Content.fromParts(Part.builder().fileData(fileData).build())) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + Event retrievedEvent = retrievedSession.events().get(0); + Part part = retrievedEvent.content().get().parts().get().get(0); + + assertTrue(part.fileData().isPresent()); + FileData retrievedFileData = part.fileData().get(); + + assertEquals("gs://bucket/path/to/file.pdf", retrievedFileData.fileUri().get()); + assertEquals("application/pdf", retrievedFileData.mimeType().get()); + } + + @Test + public void testFunctionResponsePartRoundTrip() { + String sessionId = "function-response-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + FunctionResponse functionResponse = + FunctionResponse.builder() + .name("get_weather") + .response(Map.of("temperature", 72, "conditions", "sunny")) + .id("call-123") + .build(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("tool") + .content(Content.fromParts(Part.builder().functionResponse(functionResponse).build())) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + Event retrievedEvent = retrievedSession.events().get(0); + Part part = retrievedEvent.content().get().parts().get().get(0); + + assertTrue(part.functionResponse().isPresent()); + FunctionResponse retrievedResponse = part.functionResponse().get(); + + assertEquals("get_weather", retrievedResponse.name().get()); + assertEquals("call-123", retrievedResponse.id().get()); + + Map responseData = retrievedResponse.response().get(); + assertEquals(72, responseData.get("temperature")); + assertEquals("sunny", responseData.get("conditions")); + } + + @Test + public void testExecutableCodePartRoundTrip() { + String sessionId = "executable-code-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ExecutableCode executableCode = ExecutableCode.builder().code("print('Hello, World!')").build(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("model") + .content(Content.fromParts(Part.builder().executableCode(executableCode).build())) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + Event retrievedEvent = retrievedSession.events().get(0); + Part part = retrievedEvent.content().get().parts().get().get(0); + + assertTrue(part.executableCode().isPresent()); + ExecutableCode retrievedCode = part.executableCode().get(); + + assertEquals("print('Hello, World!')", retrievedCode.code().get()); + } + + @Test + public void testMixedPartsInSingleEvent() { + String sessionId = "mixed-parts-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Part textPart = Part.fromText("Let me call a function:"); + Part functionCallPart = + Part.builder() + .functionCall( + FunctionCall.builder() + .name("calculate") + .args(Map.of("expression", "2+2")) + .id("calc-1") + .build()) + .build(); + Part fileDataPart = + Part.builder() + .fileData( + FileData.builder().fileUri("gs://bucket/data.csv").mimeType("text/csv").build()) + .build(); + + Event originalEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("model") + .content(Content.fromParts(textPart, functionCallPart, fileDataPart)) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, originalEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + Event retrievedEvent = retrievedSession.events().get(0); + List parts = retrievedEvent.content().get().parts().get(); + + assertEquals(3, parts.size()); + + assertTrue(parts.get(0).text().isPresent()); + assertEquals("Let me call a function:", parts.get(0).text().get()); + + assertTrue(parts.get(1).functionCall().isPresent()); + assertEquals("calculate", parts.get(1).functionCall().get().name().get()); + + assertTrue(parts.get(2).fileData().isPresent()); + assertEquals("gs://bucket/data.csv", parts.get(2).fileData().get().fileUri().get()); + } + + /** + * Tests that a multi-turn conversation with function calls is correctly serialized and + * deserialized. This verifies the complete workflow: user message -> model function call -> tool + * response -> model final response. + * + *

IMPORTANT: Events are created with incrementing timestamps (100ms apart) to simulate + * realistic timing. In production, events naturally have different timestamps due to processing + * delays. Without timestamp separation, events with identical timestamps would have undefined + * ordering since the database only sorts by timestamp. This test previously failed intermittently + * because it created all events with Instant.now() within the same millisecond, causing + * non-deterministic ordering. + */ + @Test + public void testMultiTurnConversationWithTools() { + String sessionId = "multi-turn-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + long baseTimestamp = Instant.now().toEpochMilli(); + + Event userMessage = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("user") + .content(Content.fromParts(Part.fromText("What's the weather in Tokyo?"))) + .timestamp(baseTimestamp) + .build(); + + Event modelFunctionCall = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("model") + .content( + Content.fromParts( + Part.builder() + .functionCall( + FunctionCall.builder() + .name("get_weather") + .args(Map.of("city", "Tokyo")) + .id("weather-1") + .build()) + .build())) + .timestamp(baseTimestamp + 100) + .build(); + + Event toolResponse = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("tool") + .content( + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .name("get_weather") + .response(Map.of("temp", 18, "condition", "cloudy")) + .id("weather-1") + .build()) + .build())) + .timestamp(baseTimestamp + 200) + .build(); + + Event modelFinalResponse = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("model") + .content(Content.fromParts(Part.fromText("The weather in Tokyo is 18°C and cloudy."))) + .timestamp(baseTimestamp + 300) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, userMessage).blockingGet(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, modelFunctionCall).blockingGet(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, toolResponse).blockingGet(); + + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, modelFinalResponse).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertEquals(4, retrievedSession.events().size()); + + assertEquals( + "What's the weather in Tokyo?", + retrievedSession.events().get(0).content().get().parts().get().get(0).text().get()); + + FunctionCall retrievedCall = + retrievedSession.events().get(1).content().get().parts().get().get(0).functionCall().get(); + assertEquals("get_weather", retrievedCall.name().get()); + assertEquals("Tokyo", retrievedCall.args().get().get("city")); + + FunctionResponse retrievedResponse = + retrievedSession + .events() + .get(2) + .content() + .get() + .parts() + .get() + .get(0) + .functionResponse() + .get(); + assertEquals("get_weather", retrievedResponse.name().get()); + assertEquals(18, retrievedResponse.response().get().get("temp")); + + assertEquals( + "The weather in Tokyo is 18°C and cloudy.", + retrievedSession.events().get(3).content().get().parts().get().get(0).text().get()); + } + + @Test + public void testEmptyAndNullContent() { + String sessionId = "empty-content-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event emptyContentEvent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("system") + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, emptyContentEvent).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + Event retrievedEvent = retrievedSession.events().get(0); + assertTrue( + retrievedEvent.content().isEmpty() || retrievedEvent.content().get().parts().isEmpty()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/DatabaseSessionServiceTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DatabaseSessionServiceTest.java new file mode 100644 index 000000000..01d6507b5 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DatabaseSessionServiceTest.java @@ -0,0 +1,693 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class DatabaseSessionServiceTest { + + private static final String TEST_DB_URL = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private static final String TEST_APP_NAME = "test-app"; + private static final String TEST_USER_ID = "test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement()) { + stmt.execute("DELETE FROM events"); + stmt.execute("DELETE FROM sessions"); + stmt.execute("DELETE FROM app_states"); + stmt.execute("DELETE FROM user_states"); + } catch (Exception e) { + } + sessionService.close(); + } + } + + private long countEventsInDatabase(String sessionId) throws Exception { + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement(); + ResultSet rs = + stmt.executeQuery( + "SELECT COUNT(*) FROM events WHERE session_id = '" + sessionId + "'")) { + if (rs.next()) { + return rs.getLong(1); + } + return 0; + } + } + + private long countSessionsInDatabase(String sessionId) throws Exception { + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement(); + ResultSet rs = + stmt.executeQuery("SELECT COUNT(*) FROM sessions WHERE id = '" + sessionId + "'")) { + if (rs.next()) { + return rs.getLong(1); + } + return 0; + } + } + + @Test + public void testCreateSession() { + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key1", "value1"); + state.put("key2", 42); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, null).blockingGet(); + + assertNotNull(session); + assertNotNull(session.id()); + assertEquals(TEST_APP_NAME, session.appName()); + assertEquals(TEST_USER_ID, session.userId()); + assertEquals("value1", session.state().get("key1")); + assertEquals(42, session.state().get("key2")); + assertTrue(session.events().isEmpty()); + } + + @Test + public void testCreateSessionWithId() { + String sessionId = "custom-session-id"; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals(sessionId, session.id()); + assertEquals(TEST_APP_NAME, session.appName()); + assertEquals(TEST_USER_ID, session.userId()); + } + + @Test + public void testGetSession() { + String sessionId = "get-session-test"; + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key", "value"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals(sessionId, retrievedSession.id()); + assertEquals(TEST_APP_NAME, retrievedSession.appName()); + assertEquals(TEST_USER_ID, retrievedSession.userId()); + assertEquals("value", retrievedSession.state().get("key")); + } + + @Test + public void testGetSessionNotFound() { + String nonExistentId = "non-existent"; + + assertNull( + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, nonExistentId, Optional.empty()) + .blockingGet()); + } + + @Test + public void testListSessionsEmpty() { + ListSessionsResponse response = + sessionService.listSessions(TEST_APP_NAME, TEST_USER_ID).blockingGet(); + + assertNotNull(response); + assertEquals(0, response.sessions().size()); + } + + @Test + public void testListSessions() { + String sessionId1 = "list-test-1"; + String sessionId2 = "list-test-2"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId1) + .blockingGet(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId2) + .blockingGet(); + + ListSessionsResponse response = + sessionService.listSessions(TEST_APP_NAME, TEST_USER_ID).blockingGet(); + + assertNotNull(response); + List sessions = response.sessions(); + assertEquals(2, sessions.size()); + assertTrue(sessions.stream().anyMatch(s -> s.id().equals(sessionId1))); + assertTrue(sessions.stream().anyMatch(s -> s.id().equals(sessionId2))); + } + + @Test + public void testAppendEvent() { + String sessionId = "event-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Hello, world!"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updatedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updatedSession); + assertEquals(1, updatedSession.events().size()); + Event retrievedEvent = updatedSession.events().get(0); + assertEquals(event.id(), retrievedEvent.id()); + assertEquals(event.author(), retrievedEvent.author()); + } + + @Test + public void testAppendEventToNonExistentSession() { + String nonExistentId = "non-existent"; + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Hello, world!"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session nonExistentSession = + Session.builder(nonExistentId) + .appName(TEST_APP_NAME) + .userId(TEST_USER_ID) + .state(new ConcurrentHashMap<>()) + .events(new ArrayList<>()) + .build(); + assertThrows( + SessionNotFoundException.class, + () -> sessionService.appendEvent(nonExistentSession, event).blockingGet()); + } + + @Test + public void testDeleteSession() { + String sessionId = "delete-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + assertNotNull( + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet()); + + sessionService.deleteSession(TEST_APP_NAME, TEST_USER_ID, sessionId).blockingAwait(); + + assertNull( + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet()); + } + + @Test + public void testListEvents() { + String sessionId = "list-events-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("index: " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + try { + TimeUnit.MILLISECONDS.sleep(10); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + ListEventsResponse response = + sessionService.listEvents(TEST_APP_NAME, TEST_USER_ID, sessionId).blockingGet(); + + assertNotNull(response); + assertEquals(5, response.events().size()); + } + + @Test + public void testGetSessionWithNumRecentEvents() { + String sessionId = "filter-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant startTime = Instant.now(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("index: " + i))) + .timestamp(startTime.plusSeconds(i).toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(2).build(); + Session sessionWithRecentEvents = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(sessionWithRecentEvents); + assertEquals(2, sessionWithRecentEvents.events().size()); + } + + @Test + public void testAppendEventUpdatesSessionState() { + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("sessionKey", "sessionValue"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertNotNull(retrievedSession); + assertEquals("sessionValue", retrievedSession.state().get("sessionKey")); + } + + @Test + public void testAppendEventUpdatesAppState() { + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), "session2") + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("_app_appKey", "appValue"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertNotNull(retrievedSession); + assertEquals("appValue", retrievedSession.state().get("_app_appKey")); + } + + @Test + public void testAppendEventUpdatesUserState() { + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), "session3") + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("_user_userKey", "userValue"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertNotNull(retrievedSession); + assertEquals("userValue", retrievedSession.state().get("_user_userKey")); + } + + @Test + public void testDeleteSessionRemovesAllRelatedData() throws Exception { + String sessionId = "delete-cascade-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + long eventsBefore = countEventsInDatabase(sessionId); + assertEquals(5, eventsBefore, "Should have 5 events before deletion"); + + long sessionsBefore = countSessionsInDatabase(sessionId); + assertEquals(1, sessionsBefore, "Should have 1 session before deletion"); + + sessionService.deleteSession(TEST_APP_NAME, TEST_USER_ID, sessionId).blockingAwait(); + + long eventsAfter = countEventsInDatabase(sessionId); + assertEquals(0, eventsAfter, "All events should be deleted from database"); + + long sessionsAfter = countSessionsInDatabase(sessionId); + assertEquals(0, sessionsAfter, "Session should be deleted from database"); + } + + @Test + public void testEventsPersistAfterMultipleReads() throws Exception { + String sessionId = "persist-after-reads-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 3; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + } + + long eventsBeforeReads = countEventsInDatabase(sessionId); + assertEquals(3, eventsBeforeReads); + + for (int i = 0; i < 5; i++) { + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + } + + long eventsAfterReads = countEventsInDatabase(sessionId); + assertEquals(3, eventsAfterReads, "Events should persist in database after multiple reads"); + } + + @Test + public void testAppendEventWithNullContent() { + String sessionId = "null-content-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event eventWithNullContent = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Optional.empty()) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, eventWithNullContent).blockingGet(); + Session updatedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updatedSession); + assertEquals(1, updatedSession.events().size()); + } + + @Test + public void testEmptyStateDelta() { + String sessionId = "empty-delta-test"; + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event eventWithEmptyDelta = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(new ConcurrentHashMap<>()).build()) + .build(); + + sessionService.appendEvent(session, eventWithEmptyDelta).blockingGet(); + Session updatedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updatedSession); + assertEquals(1, updatedSession.events().size()); + } + + @Test + public void testNullStateDeltaHandling() { + String sessionId = "null-delta-test"; + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event eventWithNullActions = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(null) + .build(); + + sessionService.appendEvent(session, eventWithNullActions).blockingGet(); + Session updatedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updatedSession); + assertEquals(1, updatedSession.events().size()); + } + + @Test + public void testAppendEventWithRemovedDeletesKeys() throws Exception { + String sessionId = UUID.randomUUID().toString(); + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "app_key", "app_value"); + initialState.put(State.USER_PREFIX + "user_key", "user_value"); + initialState.put("session_key", "session_value"); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + assertNotNull(session); + assertEquals("app_value", session.state().get(State.APP_PREFIX + "app_key")); + assertEquals("user_value", session.state().get(State.USER_PREFIX + "user_key")); + assertEquals("session_value", session.state().get("session_key")); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "app_key", State.REMOVED); + delta.put(State.USER_PREFIX + "user_key", State.REMOVED); + delta.put("session_key", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Remove keys"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertFalse(updated.state().containsKey(State.APP_PREFIX + "app_key")); + assertFalse(updated.state().containsKey(State.USER_PREFIX + "user_key")); + assertFalse(updated.state().containsKey("session_key")); + + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement()) { + + ResultSet rs = + stmt.executeQuery( + "SELECT state FROM app_states WHERE app_name = '" + TEST_APP_NAME + "'"); + if (rs.next()) { + String appStateJson = rs.getString("state"); + assertFalse(appStateJson.contains("app_key"), "app_key should be removed from database"); + } + + rs = + stmt.executeQuery( + "SELECT state FROM user_states WHERE app_name = '" + + TEST_APP_NAME + + "' AND user_id = '" + + TEST_USER_ID + + "'"); + if (rs.next()) { + String userStateJson = rs.getString("state"); + assertFalse(userStateJson.contains("user_key"), "user_key should be removed from database"); + } + + rs = stmt.executeQuery("SELECT state FROM sessions WHERE id = '" + sessionId + "'"); + if (rs.next()) { + String sessionStateJson = rs.getString("state"); + assertFalse( + sessionStateJson.contains("session_key"), + "session_key should be removed from database"); + } + } + } + + @Test + public void testRemovedOnlyAffectsSpecifiedTier() throws Exception { + String sessionId = UUID.randomUUID().toString(); + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "app_keep", "app_value"); + initialState.put(State.APP_PREFIX + "app_remove", "remove_this"); + initialState.put(State.USER_PREFIX + "user_keep", "user_value"); + initialState.put(State.USER_PREFIX + "user_remove", "remove_this"); + initialState.put("session_keep", "session_value"); + initialState.put("session_remove", "remove_this"); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "app_remove", State.REMOVED); + delta.put(State.USER_PREFIX + "user_remove", State.REMOVED); + delta.put("session_remove", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Selective removal"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + + assertEquals("app_value", updated.state().get(State.APP_PREFIX + "app_keep")); + assertFalse(updated.state().containsKey(State.APP_PREFIX + "app_remove")); + + assertEquals("user_value", updated.state().get(State.USER_PREFIX + "user_keep")); + assertFalse(updated.state().containsKey(State.USER_PREFIX + "user_remove")); + + assertEquals("session_value", updated.state().get("session_keep")); + assertFalse(updated.state().containsKey("session_remove")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/DiagnosticTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DiagnosticTest.java new file mode 100644 index 000000000..82a6a8591 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DiagnosticTest.java @@ -0,0 +1,112 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.events.Event; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +@Tag("integration") +public class DiagnosticTest { + private static final String TEST_DB_URL = TestDatabaseConfig.MYSQL_JDBC_URL; + private static final String TEST_APP_NAME = "diagnostic-test"; + private static final String TEST_USER_ID = "diagnostic-user"; + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isMySQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("MySQL")); + sessionService = new DatabaseSessionService(TEST_DB_URL, new java.util.HashMap<>()); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void diagnosticAfterTimestampFiltering() { + String sessionId = "diag-" + System.currentTimeMillis(); + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant baseTime = Instant.now(); + System.out.println("Base time: " + baseTime); + + // Create 10 events + for (int i = 1; i <= 10; i++) { + Instant eventTime = baseTime.plusSeconds(i); + Event event = + Event.builder() + .id("event-" + i) + .author("test") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(eventTime.toEpochMilli()) + .build(); + session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + System.out.println( + "Created event-" + + i + + " with timestamp: " + + eventTime + + " (" + + eventTime.toEpochMilli() + + ")"); + try { + TimeUnit.MILLISECONDS.sleep(5); + } catch (InterruptedException e) { + } + } + + // Get all events + Session allEvents = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + System.out.println("\n=== ALL EVENTS ==="); + for (Event e : allEvents.events()) { + System.out.println( + e.id() + ": " + e.timestamp() + " (" + Instant.ofEpochMilli(e.timestamp()) + ")"); + } + + // Filter after 5 seconds + Instant threshold = baseTime.plusSeconds(5); + System.out.println( + "\n=== FILTERING AFTER: " + threshold + " (" + threshold.toEpochMilli() + ") ==="); + GetSessionConfig config = GetSessionConfig.builder().afterTimestamp(threshold).build(); + Session filtered = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + System.out.println("Expected ~5 events (event-6 through event-10)"); + System.out.println("Actually got: " + filtered.events().size() + " events"); + for (Event e : filtered.events()) { + System.out.println( + " " + e.id() + ": " + e.timestamp() + " (" + Instant.ofEpochMilli(e.timestamp()) + ")"); + } + + assertEquals(5, filtered.events().size(), "Should get exactly 5 events after threshold"); + assertEquals("event-6", filtered.events().get(0).id()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/DialectDetectorTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DialectDetectorTest.java new file mode 100644 index 000000000..a2dbbb24a --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/DialectDetectorTest.java @@ -0,0 +1,119 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.adk.sessions.dialect.DialectDetector; +import com.google.adk.sessions.dialect.H2Dialect; +import com.google.adk.sessions.dialect.MySqlDialect; +import com.google.adk.sessions.dialect.PostgresDialect; +import com.google.adk.sessions.dialect.SpannerDialect; +import com.google.adk.sessions.dialect.SqlDialect; +import org.junit.jupiter.api.Test; + +public class DialectDetectorTest { + + @Test + public void testDetectPostgreSQLDialectFromUrl() { + String url = "jdbc:postgresql://localhost:5432/testdb"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(PostgresDialect.class, dialect.getClass()); + assertEquals("PostgreSQL", dialect.dialectName()); + } + + @Test + public void testDetectMySQLDialectFromUrl() { + String url = "jdbc:mysql://localhost:3306/testdb"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(MySqlDialect.class, dialect.getClass()); + assertEquals("MySQL", dialect.dialectName()); + } + + @Test + public void testDetectH2DialectFromUrl() { + String url = "jdbc:h2:mem:testdb"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(H2Dialect.class, dialect.getClass()); + assertEquals("H2", dialect.dialectName()); + } + + @Test + public void testDetectSpannerDialectFromUrl() { + String url = "jdbc:cloudspanner:/projects/test/instances/test/databases/test"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(SpannerDialect.class, dialect.getClass()); + assertEquals("Cloud Spanner", dialect.dialectName()); + } + + @Test + public void testDetectDialectWithParametersInUrl() { + String url = "jdbc:postgresql://localhost:5432/testdb?user=admin&password=secret"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(PostgresDialect.class, dialect.getClass()); + } + + @Test + public void testDetectDialectUnsupportedDatabase() { + String url = "jdbc:oracle:thin:@localhost:1521:testdb"; + assertThrows(IllegalArgumentException.class, () -> DialectDetector.detectFromJdbcUrl(url)); + } + + @Test + public void testDetectH2InMemoryDatabase() { + String url = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(H2Dialect.class, dialect.getClass()); + } + + @Test + public void testDetectH2FileDatabase() { + String url = "jdbc:h2:file:/data/testdb"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(H2Dialect.class, dialect.getClass()); + } + + @Test + public void testDetectPostgreSQLWithSSL() { + String url = + "jdbc:postgresql://localhost:5432/testdb?ssl=true&sslfactory=org.postgresql.ssl.NonValidatingFactory"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(PostgresDialect.class, dialect.getClass()); + } + + @Test + public void testDetectMySQLWithUTF8() { + String url = "jdbc:mysql://localhost:3306/testdb?useUnicode=true&characterEncoding=UTF-8"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(MySqlDialect.class, dialect.getClass()); + } + + @Test + public void testDetectDialectCaseInsensitive() { + String urlUpper = "JDBC:POSTGRESQL://localhost:5432/testdb"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(urlUpper); + assertEquals(PostgresDialect.class, dialect.getClass()); + } + + @Test + public void testDetectSpannerWithComplexUrl() { + String url = + "jdbc:cloudspanner:/projects/my-project/instances/my-instance/databases/my-database?credentials=/path/to/credentials.json"; + SqlDialect dialect = DialectDetector.detectFromJdbcUrl(url); + assertEquals(SpannerDialect.class, dialect.getClass()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/EventFilteringTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/EventFilteringTest.java new file mode 100644 index 000000000..10f1b2856 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/EventFilteringTest.java @@ -0,0 +1,297 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class EventFilteringTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:filter_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private static final String TEST_APP_NAME = "filter-test-app"; + private static final String TEST_USER_ID = "filter-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testFilterByNumRecentEvents() throws InterruptedException { + String sessionId = "recent-events-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 10; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(3).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(3, session.events().size()); + assertEquals("event-8", session.events().get(0).id()); + assertEquals("event-9", session.events().get(1).id()); + assertEquals("event-10", session.events().get(2).id()); + } + + @Test + public void testFilterByAfterTimestamp() throws InterruptedException { + String sessionId = "timestamp-filter-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant startTime = Instant.now(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(startTime.plusSeconds(i).toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + GetSessionConfig config = + GetSessionConfig.builder().afterTimestamp(startTime.plusSeconds(3)).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(2, session.events().size()); + assertEquals("event-4", session.events().get(0).id()); + assertEquals("event-5", session.events().get(1).id()); + } + + @Test + public void testFilterByNumRecentEventsZero() { + String sessionId = "zero-events-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(0).build(); + Session filteredSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(filteredSession); + assertEquals(0, filteredSession.events().size()); + } + + @Test + public void testNoFilterReturnsAllEvents() throws InterruptedException { + String sessionId = "all-events-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(session); + assertEquals(5, session.events().size()); + } + + @Test + public void testCombinedFilters() throws InterruptedException { + String sessionId = "combined-filter-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant startTime = Instant.now(); + + for (int i = 1; i <= 10; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(startTime.plusSeconds(i).toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + GetSessionConfig config = + GetSessionConfig.builder() + .afterTimestamp(startTime.plusSeconds(3)) + .numRecentEvents(3) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertTrue(session.events().size() <= 3); + } + + @Test + public void testNoFilterReturnsAllEventsLargeDataset() throws InterruptedException { + String sessionId = "large-dataset-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 50; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(session); + assertEquals(50, session.events().size()); + assertEquals("event-1", session.events().get(0).id()); + assertEquals("event-25", session.events().get(24).id()); + assertEquals("event-50", session.events().get(49).id()); + } + + @Test + public void testLimitedEventsFromLargeDatasetReturnsCorrectOrder() throws InterruptedException { + String sessionId = "limited-large-dataset-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 50; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(20).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + + assertNotNull(session); + assertEquals(20, session.events().size()); + assertEquals("event-31", session.events().get(0).id()); + assertEquals("event-40", session.events().get(9).id()); + assertEquals("event-50", session.events().get(19).id()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/FlywayMigrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/FlywayMigrationTest.java new file mode 100644 index 000000000..6f00a1749 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/FlywayMigrationTest.java @@ -0,0 +1,70 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import org.junit.jupiter.api.Test; + +public class FlywayMigrationTest { + + @Test + public void testFlywayMigrationsApplied() { + String dbUrl = "jdbc:h2:mem:flyway_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + + assertDoesNotThrow( + () -> { + try (DatabaseSessionService service = new DatabaseSessionService(dbUrl)) { + assertNotNull(service); + } + }); + + try (Connection conn = DriverManager.getConnection(dbUrl); + Statement stmt = conn.createStatement()) { + + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM \"flyway_schema_history\""); + rs.next(); + int migrationCount = rs.getInt(1); + assert migrationCount > 0 : "Flyway migrations should be applied"; + + rs = + stmt.executeQuery( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'PUBLIC'"); + boolean hasAppStates = false; + boolean hasUserStates = false; + boolean hasSessions = false; + boolean hasEvents = false; + + while (rs.next()) { + String tableName = rs.getString("table_name"); + if (tableName.equalsIgnoreCase("APP_STATES")) hasAppStates = true; + if (tableName.equalsIgnoreCase("USER_STATES")) hasUserStates = true; + if (tableName.equalsIgnoreCase("SESSIONS")) hasSessions = true; + if (tableName.equalsIgnoreCase("EVENTS")) hasEvents = true; + } + + assert hasAppStates : "app_states table should exist"; + assert hasUserStates : "user_states table should exist"; + assert hasSessions : "sessions table should exist"; + assert hasEvents : "events table should exist"; + + } catch (Exception e) { + throw new RuntimeException("Failed to verify Flyway migrations", e); + } + } + + @Test + public void testMultipleServiceInstancesShareSchema() { + String dbUrl = "jdbc:h2:mem:shared_schema_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + + try (DatabaseSessionService service1 = new DatabaseSessionService(dbUrl); + DatabaseSessionService service2 = new DatabaseSessionService(dbUrl)) { + + assertNotNull(service1); + assertNotNull(service2); + } + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/HikariConfigTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/HikariConfigTest.java new file mode 100644 index 000000000..a5111a5f7 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/HikariConfigTest.java @@ -0,0 +1,216 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.zaxxer.hikari.HikariDataSource; +import java.lang.reflect.Field; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +public class HikariConfigTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:hikari_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private DatabaseSessionService sessionService; + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testDefaultHikariConfig() throws Exception { + sessionService = new DatabaseSessionService(TEST_DB_URL); + + HikariDataSource dataSource = getDataSource(sessionService); + assertNotNull(dataSource); + assertEquals(10, dataSource.getMaximumPoolSize()); + assertEquals(2, dataSource.getMinimumIdle()); + assertEquals(30000, dataSource.getConnectionTimeout()); + assertEquals(600000, dataSource.getIdleTimeout()); + assertEquals(1800000, dataSource.getMaxLifetime()); + } + + @Test + public void testCustomMaximumPoolSize() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", 20); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(20, dataSource.getMaximumPoolSize()); + assertEquals(2, dataSource.getMinimumIdle()); + } + + @Test + public void testCustomMinimumIdle() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.minimumIdle", 5); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(10, dataSource.getMaximumPoolSize()); + assertEquals(5, dataSource.getMinimumIdle()); + } + + @Test + public void testCustomConnectionTimeout() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.connectionTimeout", 60000L); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(60000, dataSource.getConnectionTimeout()); + } + + @Test + public void testCustomIdleTimeout() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.idleTimeout", 300000L); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(300000, dataSource.getIdleTimeout()); + } + + @Test + public void testCustomMaxLifetime() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maxLifetime", 900000L); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(900000, dataSource.getMaxLifetime()); + } + + @Test + public void testAllCustomHikariProperties() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", 25); + properties.put("hikari.minimumIdle", 10); + properties.put("hikari.connectionTimeout", 45000L); + properties.put("hikari.idleTimeout", 400000L); + properties.put("hikari.maxLifetime", 1200000L); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(25, dataSource.getMaximumPoolSize()); + assertEquals(10, dataSource.getMinimumIdle()); + assertEquals(45000, dataSource.getConnectionTimeout()); + assertEquals(400000, dataSource.getIdleTimeout()); + assertEquals(1200000, dataSource.getMaxLifetime()); + } + + @Test + public void testNonHikariPropertiesArePassedToDataSource() throws Exception { + Map properties = new HashMap<>(); + properties.put("cachePrepStmts", "true"); + properties.put("prepStmtCacheSize", 250); + properties.put("hikari.maximumPoolSize", 15); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(15, dataSource.getMaximumPoolSize()); + } + + @Test + public void testInvalidIntegerPropertyUsesDefault() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", "invalid"); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(10, dataSource.getMaximumPoolSize()); + } + + @Test + public void testInvalidLongPropertyUsesDefault() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.connectionTimeout", "invalid"); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(30000, dataSource.getConnectionTimeout()); + } + + @Test + public void testIntegerAsNumberType() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", Integer.valueOf(30)); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(30, dataSource.getMaximumPoolSize()); + } + + @Test + public void testLongAsNumberType() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.connectionTimeout", Long.valueOf(50000L)); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(50000, dataSource.getConnectionTimeout()); + } + + @Test + public void testStringNumberConversion() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", "35"); + properties.put("hikari.connectionTimeout", "40000"); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(35, dataSource.getMaximumPoolSize()); + assertEquals(40000, dataSource.getConnectionTimeout()); + } + + @Test + public void testNullPropertyUsesDefault() throws Exception { + Map properties = new HashMap<>(); + properties.put("hikari.maximumPoolSize", null); + + sessionService = new DatabaseSessionService(TEST_DB_URL, properties); + + HikariDataSource dataSource = getDataSource(sessionService); + assertEquals(10, dataSource.getMaximumPoolSize()); + } + + private HikariDataSource getDataSource(DatabaseSessionService service) throws Exception { + Field dataSourceField = DatabaseSessionService.class.getDeclaredField("dataSource"); + dataSourceField.setAccessible(true); + return (HikariDataSource) dataSourceField.get(service); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/ListSessionsEventsTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ListSessionsEventsTest.java new file mode 100644 index 000000000..9973d6d31 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ListSessionsEventsTest.java @@ -0,0 +1,89 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ListSessionsEventsTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:list_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private static final String TEST_APP_NAME = "list-test-app"; + private static final String TEST_USER_ID = "list-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testListSessionsReturnsAllSessions() { + String userId = "list-sessions-user"; + for (int i = 1; i <= 50; i++) { + sessionService + .createSession(TEST_APP_NAME, userId, new ConcurrentHashMap<>(), "session-" + i) + .blockingGet(); + } + + ListSessionsResponse response = + sessionService.listSessions(TEST_APP_NAME, userId).blockingGet(); + + assertNotNull(response); + assertEquals(50, response.sessions().size()); + } + + @Test + public void testListEventsReturnsAllEvents() throws InterruptedException { + String userId = "list-events-user"; + String sessionId = "all-events-test"; + sessionService + .createSession(TEST_APP_NAME, userId, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 50; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, userId, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + TimeUnit.MILLISECONDS.sleep(10); + } + + ListEventsResponse response = + sessionService.listEvents(TEST_APP_NAME, userId, sessionId).blockingGet(); + + assertNotNull(response); + assertEquals(50, response.events().size()); + assertEquals("event-1", response.events().get(0).id()); + assertEquals("event-25", response.events().get(24).id()); + assertEquals("event-50", response.events().get(49).id()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLAgentIntegrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLAgentIntegrationTest.java new file mode 100644 index 000000000..f8afa272b --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLAgentIntegrationTest.java @@ -0,0 +1,355 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.SequentialAgent; +import com.google.adk.events.Event; +import com.google.adk.models.LlmResponse; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.adk.testing.TestLlm; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +/** + * Integration tests for Agents using DatabaseSessionService with real MySQL 8.0 database. + * + *

This test suite verifies that agents work correctly with MySQL-backed session persistence, + * including: - Sequential agent execution with database persistence - State propagation between + * agents via outputKey - Event storage and retrieval - App/user/session state management with + * database backend + * + *

Prerequisites: Start MySQL test database with: + * + *

{@code
+ * docker-compose -f scripts/docker-compose.test.yml up -d mysql-test
+ * }
+ * + *

Configuration: - Host: localhost:3307 - Database: adk_test - User: adk_user - Password: + * adk_password + */ +@Tag("integration") +public class MySQLAgentIntegrationTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.MYSQL_JDBC_URL; + private static final String TEST_APP_NAME = "mysql-agent-integration-test"; + private static final String TEST_USER_ID = "agent-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isMySQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("MySQL")); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testSequentialAgentWithDatabasePersistence() { + Content agentAResponse = Content.fromParts(Part.fromText("The topic is: AI")); + TestLlm llmA = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(agentAResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentA = + LlmAgent.builder() + .name("AgentA") + .model(llmA) + .instruction("Extract topic") + .outputKey("topic") + .build(); + + Content agentBResponse = Content.fromParts(Part.fromText("Summary: AI is important")); + TestLlm llmB = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(agentBResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentB = + LlmAgent.builder() + .name("AgentB") + .model(llmB) + .instruction("Summarize topic: ${topic}") + .outputKey("summary") + .build(); + + SequentialAgent sequential = + SequentialAgent.builder() + .name("SequentialAgent") + .subAgents(ImmutableList.of(agentA, agentB)) + .build(); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx = + InvocationContext.builder() + .sessionService(sessionService) + .session(session) + .agent(sequential) + .build(); + + List events = + sequential + .runAsync(ctx) + .flatMap(event -> ctx.sessionService().appendEvent(ctx.session(), event).toFlowable()) + .toList() + .blockingGet(); + + assertNotNull(events); + assertTrue(events.size() >= 2, "Expected at least 2 events from sequential agents"); + + assertEquals("The topic is: AI", ctx.session().state().get("topic")); + assertEquals("Summary: AI is important", ctx.session().state().get("summary")); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, session.id(), Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals(session.id(), retrievedSession.id()); + assertEquals("The topic is: AI", retrievedSession.state().get("topic")); + assertEquals("Summary: AI is important", retrievedSession.state().get("summary")); + assertTrue( + retrievedSession.events().size() >= 2, "Expected at least 2 events persisted in database"); + } + + @Test + public void testAgentWithAppAndUserStatePersistence() { + Content configResponse = Content.fromParts(Part.fromText("{\"version\": \"1.0\"}")); + TestLlm llmA = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(configResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentA = + LlmAgent.builder() + .name("ConfigAgent") + .model(llmA) + .instruction("Return config") + .outputKey("app:config") + .build(); + + Content prefResponse = Content.fromParts(Part.fromText("dark")); + TestLlm llmB = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(prefResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentB = + LlmAgent.builder() + .name("PreferenceAgent") + .model(llmB) + .instruction("Use config: ${app:config}") + .outputKey("user:theme") + .build(); + + SequentialAgent sequential = + SequentialAgent.builder() + .name("Sequential") + .subAgents(ImmutableList.of(agentA, agentB)) + .build(); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx = + InvocationContext.builder() + .sessionService(sessionService) + .session(session) + .agent(sequential) + .build(); + + sequential + .runAsync(ctx) + .flatMap(event -> ctx.sessionService().appendEvent(ctx.session(), event).toFlowable()) + .toList() + .blockingGet(); + + assertEquals("{\"version\": \"1.0\"}", ctx.session().state().get("app:config")); + assertEquals("dark", ctx.session().state().get("user:theme")); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, session.id(), Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals("{\"version\": \"1.0\"}", retrievedSession.state().get("app:config")); + assertEquals("dark", retrievedSession.state().get("user:theme")); + } + + @Test + public void testAgentStatePersistedAcrossSessions() { + Content response1 = Content.fromParts(Part.fromText("User preference stored")); + TestLlm llm1 = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(response1) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agent1 = + LlmAgent.builder() + .name("PreferenceAgent") + .model(llm1) + .instruction("Store preference") + .outputKey("user:language") + .build(); + + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx1 = + InvocationContext.builder() + .sessionService(sessionService) + .session(session1) + .agent(agent1) + .build(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("user:language", "English"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author(agent1.name()) + .content(response1) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertEquals("English", session2.state().get("user:language")); + } + + @Test + public void testRegularStateIsolatedBetweenSessions() { + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("session_data", "session1_value"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author("test-agent") + .content(Content.fromParts(Part.fromText("Session 1 data"))) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + assertEquals("session1_value", session1.state().get("session_data")); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertNull( + session2.state().get("session_data"), "Regular state should not persist across sessions"); + } + + @Test + public void testAppStatePersistedAcrossSessions() { + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("app:api_key", "key-12345"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author("test-agent") + .content(Content.fromParts(Part.fromText("App config stored"))) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertEquals("key-12345", session2.state().get("app:api_key")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLIntegrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLIntegrationTest.java new file mode 100644 index 000000000..fa13b9ee9 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/MySQLIntegrationTest.java @@ -0,0 +1,140 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +@Tag("integration") +public class MySQLIntegrationTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.MYSQL_JDBC_URL; + private String TEST_APP_NAME; + private String TEST_USER_ID; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isMySQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("MySQL")); + + TEST_APP_NAME = "jdbc-mysql-test-app-" + System.currentTimeMillis(); + TEST_USER_ID = "jdbc-mysql-test-user-" + System.currentTimeMillis(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testBasicSessionOperations() { + String sessionId = "mysql-basic-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key", "value"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals(sessionId, session.id()); + assertEquals("value", session.state().get("key")); + } + + @Test + public void testEventActionsWithStateDelta() { + String sessionId = "mysql-actions-test-" + System.currentTimeMillis(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("count", 1); + stateDelta.put("app:shared", "global"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals(1, retrieved.state().get("count")); + assertEquals("global", retrieved.state().get("app:shared")); + } + + @Test + public void testJSONStorageAndRetrieval() { + String sessionId = "mysql-json-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("nested", java.util.Map.of("inner", "value")); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertNotNull(retrieved.state().get("nested")); + } + + @Test + public void testUpsertAppState() { + String sessionId1 = "mysql-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "mysql-upsert-2-" + System.currentTimeMillis(); + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("app:config", "value1"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("app:config", "value2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId1, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("value2", retrieved.state().get("app:config")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/NegativeTestCases.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/NegativeTestCases.java new file mode 100644 index 000000000..425625092 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/NegativeTestCases.java @@ -0,0 +1,400 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class NegativeTestCases { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:negative_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "negative-test-app"; + private static final String TEST_USER_ID = "negative-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testCreateSessionWithNullAppName() { + assertThrows( + NullPointerException.class, + () -> + sessionService + .createSession(null, TEST_USER_ID, new ConcurrentHashMap<>(), "session-1") + .blockingGet()); + } + + @Test + public void testCreateSessionWithNullUserId() { + assertThrows( + NullPointerException.class, + () -> + sessionService + .createSession(TEST_APP_NAME, null, new ConcurrentHashMap<>(), "session-1") + .blockingGet()); + } + + @Test + public void testCreateSessionWithNullState() { + assertThrows( + NullPointerException.class, + () -> + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, null, "session-1") + .blockingGet()); + } + + @Test + public void testGetSessionWithNullAppName() { + assertThrows( + NullPointerException.class, + () -> + sessionService + .getSession(null, TEST_USER_ID, "session-1", Optional.empty()) + .blockingGet()); + } + + @Test + public void testGetSessionWithNullUserId() { + assertThrows( + NullPointerException.class, + () -> + sessionService + .getSession(TEST_APP_NAME, null, "session-1", Optional.empty()) + .blockingGet()); + } + + @Test + public void testGetSessionWithNullSessionId() { + assertThrows( + NullPointerException.class, + () -> + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, null, Optional.empty()) + .blockingGet()); + } + + @Test + public void testAppendEventToDeletedSession() { + String sessionId = "deleted-session"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + sessionService.deleteSession(TEST_APP_NAME, TEST_USER_ID, sessionId).blockingAwait(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + assertThrows( + SessionNotFoundException.class, + () -> + sessionService + .appendEvent( + Session.builder(sessionId) + .appName(TEST_APP_NAME) + .userId(TEST_USER_ID) + .state(new ConcurrentHashMap<>()) + .events(new ArrayList<>()) + .build(), + event) + .blockingGet()); + } + + @Test + public void testDeleteNonExistentSession() { + sessionService.deleteSession(TEST_APP_NAME, TEST_USER_ID, "non-existent").blockingAwait(); + } + + @Test + public void testGetNonExistentSession() { + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, "non-existent", Optional.empty()) + .blockingGet(); + assertNull(session); + } + + @Test + public void testCreateSessionWithEmptySessionId() { + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), "") + .blockingGet(); + assertNotNull(session); + } + + @Test + public void testCreateSessionWithVeryLongSessionId() { + String longId = "a".repeat(200); + + Exception exception = + assertThrows( + Exception.class, + () -> + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), longId) + .blockingGet()); + + assertTrue( + exception.getMessage().contains("too long") || exception.getCause() != null, + "Should fail with constraint violation for long session ID"); + } + + @Test + public void testCreateDuplicateSession() { + String sessionId = "duplicate-session"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + assertThrows( + Exception.class, + () -> + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet()); + } + + @Test + public void testAppendEventWithDuplicateId() { + String sessionId = "duplicate-event-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + String eventId = "duplicate-event-id"; + Event event1 = + Event.builder() + .id(eventId) + .author("author-1") + .content(Content.fromParts(Part.fromText("Event 1"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event1).blockingGet(); + + Event event2 = + Event.builder() + .id(eventId) + .author("author-2") + .content(Content.fromParts(Part.fromText("Event 2"))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + assertThrows( + Exception.class, + () -> + sessionService + .appendEvent( + Session.builder(sessionId) + .appName(TEST_APP_NAME) + .userId(TEST_USER_ID) + .state(new ConcurrentHashMap<>()) + .events(new ArrayList<>()) + .build(), + event2) + .blockingGet()); + } + + @Test + public void testStateDeltaWithComplexNestedStructures() { + String sessionId = "complex-state-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap complexDelta = new ConcurrentHashMap<>(); + complexDelta.put("level1", Map.of("level2", Map.of("level3", "deep-value"))); + complexDelta.put("array", java.util.List.of(1, 2, 3, 4, 5)); + complexDelta.put("mixed", Map.of("num", 42, "str", "text", "bool", true)); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Complex state"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(com.google.adk.events.EventActions.builder().stateDelta(complexDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session updatedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(updatedSession); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(retrieved); + assertTrue(retrieved.state().containsKey("level1")); + assertTrue(retrieved.state().containsKey("array")); + assertTrue(retrieved.state().containsKey("mixed")); + } + + @Test + public void testGetSessionWithInvalidConfig() { + String sessionId = "invalid-config-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + GetSessionConfig negativeNumEvents = GetSessionConfig.builder().numRecentEvents(-1).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(negativeNumEvents)) + .blockingGet(); + assertNotNull(session); + } + + @Test + public void testConcurrentDeleteAndRead() throws InterruptedException { + String sessionId = "concurrent-delete-read-test"; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + Thread deleter = + new Thread( + () -> { + try { + Thread.sleep(50); + sessionService + .deleteSession(TEST_APP_NAME, TEST_USER_ID, sessionId) + .blockingAwait(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + Thread reader = + new Thread( + () -> { + for (int i = 0; i < 10; i++) { + try { + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + Thread.sleep(20); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } catch (Exception e) { + } + } + }); + + deleter.start(); + reader.start(); + + deleter.join(); + reader.join(); + + Session finalCheck = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNull(finalCheck, "Session should be deleted"); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/PessimisticLockingTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PessimisticLockingTest.java new file mode 100644 index 000000000..f19461d2e --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PessimisticLockingTest.java @@ -0,0 +1,340 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class PessimisticLockingTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:locking_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "locking-test-app"; + private static final String TEST_USER_ID = "locking-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testSerializedEventAppends() throws InterruptedException { + String sessionId = "serialized-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int threadCount = 20; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + AtomicInteger successCount = new AtomicInteger(0); + + for (int i = 0; i < threadCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + eventNum) + .author("thread-" + eventNum) + .content(Content.fromParts(Part.fromText("Message " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + successCount.incrementAndGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS)); + executor.shutdown(); + + assertEquals(threadCount, successCount.get()); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(threadCount, finalSession.events().size()); + } + + @Test + public void testNoLostUpdates() throws InterruptedException { + String sessionId = "no-lost-updates"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + int threadCount = 10; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + + for (int i = 0; i < threadCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + eventNum) + .author("thread-" + eventNum) + .content(Content.fromParts(Part.fromText("Message " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(60, TimeUnit.SECONDS)); + executor.shutdown(); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(threadCount, finalSession.events().size()); + } + + @Test + public void testConcurrentAppendDifferentSessions() throws InterruptedException { + int sessionCount = 5; + int eventsPerSession = 10; + + for (int i = 0; i < sessionCount; i++) { + String sessionId = "session-" + i; + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + } + + ExecutorService executor = Executors.newFixedThreadPool(sessionCount * eventsPerSession); + CountDownLatch latch = new CountDownLatch(sessionCount * eventsPerSession); + + for (int i = 0; i < sessionCount; i++) { + final String sessionId = "session-" + i; + for (int j = 0; j < eventsPerSession; j++) { + final int eventNum = j; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Event " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + } + + assertTrue(latch.await(60, TimeUnit.SECONDS)); + executor.shutdown(); + + for (int i = 0; i < sessionCount; i++) { + String sessionId = "session-" + i; + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(session); + assertEquals(eventsPerSession, session.events().size()); + } + } + + @Test + public void testAppendEventUnderLoad() throws InterruptedException { + String sessionId = "load-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int threadCount = 50; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + AtomicInteger successCount = new AtomicInteger(0); + AtomicInteger failureCount = new AtomicInteger(0); + + for (int i = 0; i < threadCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + eventNum) + .author("thread-" + eventNum) + .content(Content.fromParts(Part.fromText("Load test " + eventNum))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + successCount.incrementAndGet(); + } catch (Exception e) { + failureCount.incrementAndGet(); + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(120, TimeUnit.SECONDS)); + executor.shutdown(); + + assertEquals(threadCount, successCount.get()); + assertEquals(0, failureCount.get()); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(threadCount, finalSession.events().size()); + } + + @Test + public void testEventOrderingConsistency() throws InterruptedException { + String sessionId = "ordering-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + int eventCount = 100; + ExecutorService executor = Executors.newFixedThreadPool(10); + CountDownLatch latch = new CountDownLatch(eventCount); + + for (int i = 0; i < eventCount; i++) { + final int eventNum = i; + executor.submit( + () -> { + try { + Event event = + Event.builder() + .id("event-" + String.format("%03d", eventNum)) + .author("test") + .content(Content.fromParts(Part.fromText("Message " + eventNum))) + .timestamp(Instant.now().toEpochMilli() + eventNum) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }); + } + + assertTrue(latch.await(120, TimeUnit.SECONDS)); + executor.shutdown(); + + Session finalSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(finalSession); + assertEquals(eventCount, finalSession.events().size()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLAgentIntegrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLAgentIntegrationTest.java new file mode 100644 index 000000000..7db4e3745 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLAgentIntegrationTest.java @@ -0,0 +1,364 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.SequentialAgent; +import com.google.adk.events.Event; +import com.google.adk.models.LlmResponse; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.adk.testing.TestLlm; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +/** + * Integration tests for Agents using DatabaseSessionService with real PostgreSQL 16 database. + * + *

This test suite verifies that agents work correctly with PostgreSQL-backed session + * persistence, including: - Sequential agent execution with database persistence - State + * propagation between agents via outputKey - Event storage and retrieval - App/user/session state + * management with database backend - JSONB storage for complex state data + * + *

Prerequisites: Start PostgreSQL test database with: + * + *

{@code
+ * docker-compose -f scripts/docker-compose.test.yml up -d postgres-test
+ * }
+ * + *

Configuration: - Host: localhost:5433 - Database: adk_test - User: adk_user - Password: + * adk_password + */ +@Tag("integration") +public class PostgreSQLAgentIntegrationTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.POSTGRES_JDBC_URL; + private static final String TEST_APP_NAME = "postgres-agent-integration-test"; + private static final String TEST_USER_ID = "agent-test-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isPostgreSQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("PostgreSQL")); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testSequentialAgentWithDatabasePersistence() { + Content agentAResponse = Content.fromParts(Part.fromText("The topic is: Machine Learning")); + TestLlm llmA = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(agentAResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentA = + LlmAgent.builder() + .name("AgentA") + .model(llmA) + .instruction("Extract topic") + .outputKey("topic") + .build(); + + Content agentBResponse = + Content.fromParts(Part.fromText("Summary: Machine Learning is transformative")); + TestLlm llmB = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(agentBResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentB = + LlmAgent.builder() + .name("AgentB") + .model(llmB) + .instruction("Summarize topic: ${topic}") + .outputKey("summary") + .build(); + + SequentialAgent sequential = + SequentialAgent.builder() + .name("SequentialAgent") + .subAgents(ImmutableList.of(agentA, agentB)) + .build(); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx = + InvocationContext.builder() + .sessionService(sessionService) + .session(session) + .agent(sequential) + .build(); + + List events = + sequential + .runAsync(ctx) + .flatMap(event -> ctx.sessionService().appendEvent(ctx.session(), event).toFlowable()) + .toList() + .blockingGet(); + + assertNotNull(events); + assertTrue(events.size() >= 2, "Expected at least 2 events from sequential agents"); + + assertEquals("The topic is: Machine Learning", ctx.session().state().get("topic")); + assertEquals( + "Summary: Machine Learning is transformative", ctx.session().state().get("summary")); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, session.id(), Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals(session.id(), retrievedSession.id()); + assertEquals("The topic is: Machine Learning", retrievedSession.state().get("topic")); + assertEquals( + "Summary: Machine Learning is transformative", retrievedSession.state().get("summary")); + assertTrue( + retrievedSession.events().size() >= 2, "Expected at least 2 events persisted in database"); + } + + @Test + public void testAgentWithAppAndUserStatePersistence() { + Content configResponse = + Content.fromParts( + Part.fromText("{\"version\": \"2.0\", \"feature_flags\": {\"new_ui\": true}}")); + TestLlm llmA = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(configResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentA = + LlmAgent.builder() + .name("ConfigAgent") + .model(llmA) + .instruction("Return config") + .outputKey("app:config") + .build(); + + Content prefResponse = Content.fromParts(Part.fromText("light")); + TestLlm llmB = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(prefResponse) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agentB = + LlmAgent.builder() + .name("PreferenceAgent") + .model(llmB) + .instruction("Use config: ${app:config}") + .outputKey("user:theme") + .build(); + + SequentialAgent sequential = + SequentialAgent.builder() + .name("Sequential") + .subAgents(ImmutableList.of(agentA, agentB)) + .build(); + + Session session = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx = + InvocationContext.builder() + .sessionService(sessionService) + .session(session) + .agent(sequential) + .build(); + + sequential + .runAsync(ctx) + .flatMap(event -> ctx.sessionService().appendEvent(ctx.session(), event).toFlowable()) + .toList() + .blockingGet(); + + assertEquals( + "{\"version\": \"2.0\", \"feature_flags\": {\"new_ui\": true}}", + ctx.session().state().get("app:config")); + assertEquals("light", ctx.session().state().get("user:theme")); + + Session retrievedSession = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, session.id(), Optional.empty()) + .blockingGet(); + + assertNotNull(retrievedSession); + assertEquals( + "{\"version\": \"2.0\", \"feature_flags\": {\"new_ui\": true}}", + retrievedSession.state().get("app:config")); + assertEquals("light", retrievedSession.state().get("user:theme")); + } + + @Test + public void testAgentStatePersistedAcrossSessions() { + Content response1 = Content.fromParts(Part.fromText("User preference stored")); + TestLlm llm1 = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content(response1) + .partial(false) + .turnComplete(true) + .build())); + + LlmAgent agent1 = + LlmAgent.builder() + .name("PreferenceAgent") + .model(llm1) + .instruction("Store preference") + .outputKey("user:language") + .build(); + + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + InvocationContext ctx1 = + InvocationContext.builder() + .sessionService(sessionService) + .session(session1) + .agent(agent1) + .build(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("user:language", "French"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author(agent1.name()) + .content(response1) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertEquals("French", session2.state().get("user:language")); + } + + @Test + public void testRegularStateIsolatedBetweenSessions() { + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("session_data", "session1_value"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author("test-agent") + .content(Content.fromParts(Part.fromText("Session 1 data"))) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + assertEquals("session1_value", session1.state().get("session_data")); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertNull( + session2.state().get("session_data"), "Regular state should not persist across sessions"); + } + + @Test + public void testAppStatePersistedAcrossSessions() { + Session session1 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("app:api_key", "key-67890"); + + Event event = + Event.builder() + .id(java.util.UUID.randomUUID().toString()) + .author("test-agent") + .content(Content.fromParts(Part.fromText("App config stored"))) + .actions(com.google.adk.events.EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), null) + .blockingGet(); + + assertNotNull(session2); + assertEquals("key-67890", session2.state().get("app:api_key")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLIntegrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLIntegrationTest.java new file mode 100644 index 000000000..cd13578f0 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/PostgreSQLIntegrationTest.java @@ -0,0 +1,140 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +@Tag("integration") +public class PostgreSQLIntegrationTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.POSTGRES_JDBC_URL; + private String TEST_APP_NAME; + private String TEST_USER_ID; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isPostgreSQLAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("PostgreSQL")); + + TEST_APP_NAME = "jdbc-postgres-test-app-" + System.currentTimeMillis(); + TEST_USER_ID = "jdbc-postgres-test-user-" + System.currentTimeMillis(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testBasicSessionOperations() { + String sessionId = "postgres-basic-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key", "value"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals(sessionId, session.id()); + assertEquals("value", session.state().get("key")); + } + + @Test + public void testEventActionsWithStateDelta() { + String sessionId = "postgres-actions-test-" + System.currentTimeMillis(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("count", 1); + stateDelta.put("app:shared", "global"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals(1, retrieved.state().get("count")); + assertEquals("global", retrieved.state().get("app:shared")); + } + + @Test + public void testJSONBStorageAndRetrieval() { + String sessionId = "postgres-jsonb-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("nested", java.util.Map.of("inner", "value")); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertNotNull(retrieved.state().get("nested")); + } + + @Test + public void testUpsertAppState() { + String sessionId1 = "postgres-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "postgres-upsert-2-" + System.currentTimeMillis(); + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("app:config", "value1"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("app:config", "value2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId1, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("value2", retrieved.state().get("app:config")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/ReadTwiceNonDestructiveTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ReadTwiceNonDestructiveTest.java new file mode 100644 index 000000000..224c6c2b9 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/ReadTwiceNonDestructiveTest.java @@ -0,0 +1,267 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.adk.events.Event; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ReadTwiceNonDestructiveTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:read_twice_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "read-twice-test-app"; + private static final String TEST_USER_ID = "read-twice-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testGetSessionTwiceDoesNotDeleteEvents() { + String sessionId = "non-destructive-read-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + Session firstRead = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(firstRead); + assertEquals(5, firstRead.events().size(), "First read should return 5 events"); + + Session secondRead = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(secondRead); + assertEquals(5, secondRead.events().size(), "Second read should still return 5 events"); + + Session thirdRead = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertNotNull(thirdRead); + assertEquals(5, thirdRead.events().size(), "Third read should still return 5 events"); + } + + @Test + public void testGetSessionMultipleTimesWithFiltering() { + String sessionId = "filter-read-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant startTime = Instant.now(); + + for (int i = 1; i <= 10; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(startTime.plusSeconds(i).toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + Session allEvents = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals(10, allEvents.events().size()); + + GetSessionConfig recentConfig = GetSessionConfig.builder().numRecentEvents(3).build(); + Session recentEvents = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(recentConfig)) + .blockingGet(); + assertEquals(3, recentEvents.events().size()); + + Session allEventsAgain = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals(10, allEventsAgain.events().size(), "All events should still exist in DB"); + + GetSessionConfig timestampConfig = + GetSessionConfig.builder().afterTimestamp(startTime.plusSeconds(5)).build(); + Session filteredByTime = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(timestampConfig)) + .blockingGet(); + assertEquals(5, filteredByTime.events().size()); + + Session finalRead = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals(10, finalRead.events().size(), "Events should persist after all filtered reads"); + } + + @Test + public void testConcurrentReadsDoNotAffectData() throws InterruptedException { + String sessionId = "concurrent-read-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + for (int i = 1; i <= 5; i++) { + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(Instant.now().toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + Thread[] readers = new Thread[10]; + for (int i = 0; i < 10; i++) { + readers[i] = + new Thread( + () -> { + for (int j = 0; j < 5; j++) { + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals( + 5, session.events().size(), "Each concurrent read should return 5 events"); + } + }); + readers[i].start(); + } + + for (Thread reader : readers) { + reader.join(); + } + + Session finalCheck = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals( + 5, finalCheck.events().size(), "Events should remain intact after concurrent reads"); + } + + @Test + public void testReadWithDifferentConfigs() { + String sessionId = "config-variation-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + Instant startTime = Instant.now(); + + for (int i = 1; i <= 10; i++) { + Event event = + Event.builder() + .id("event-" + i) + .author("test-author") + .content(Content.fromParts(Part.fromText("Event " + i))) + .timestamp(startTime.plusSeconds(i).toEpochMilli()) + .build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + } + + for (int recentCount = 1; recentCount <= 10; recentCount++) { + GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(recentCount).build(); + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.of(config)) + .blockingGet(); + assertEquals( + recentCount, session.events().size(), "Should get " + recentCount + " recent events"); + } + + Session fullRead = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + assertEquals( + 10, + fullRead.events().size(), + "All 10 events should still exist after multiple config reads"); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/SpannerIntegrationTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/SpannerIntegrationTest.java new file mode 100644 index 000000000..710b2f75a --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/SpannerIntegrationTest.java @@ -0,0 +1,140 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.testing.TestDatabaseConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +@Tag("integration") +public class SpannerIntegrationTest { + + private static final String TEST_DB_URL = TestDatabaseConfig.SPANNER_JDBC_URL; + private String TEST_APP_NAME; + private String TEST_USER_ID; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + assumeTrue( + TestDatabaseConfig.isSpannerAvailable(), + TestDatabaseConfig.getDatabaseNotAvailableMessage("Spanner")); + + TEST_APP_NAME = "jdbc-spanner-test-app-" + System.currentTimeMillis(); + TEST_USER_ID = "jdbc-spanner-test-user-" + System.currentTimeMillis(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testBasicSessionOperations() { + String sessionId = "spanner-basic-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("key", "value"); + + Session session = + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + assertNotNull(session); + assertEquals(sessionId, session.id()); + assertEquals("value", session.state().get("key")); + } + + @Test + public void testEventActionsWithStateDelta() { + String sessionId = "spanner-actions-test-" + System.currentTimeMillis(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("count", 1); + stateDelta.put("_app_shared", "global"); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test-author") + .content(Content.fromParts(Part.fromText("Test event"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + sessionService.appendEvent(session, event).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals(1, retrieved.state().get("count")); + assertEquals("global", retrieved.state().get("_app_shared")); + } + + @Test + public void testJSONStorageAndRetrieval() { + String sessionId = "spanner-json-test-" + System.currentTimeMillis(); + ConcurrentHashMap state = new ConcurrentHashMap<>(); + state.put("nested", java.util.Map.of("inner", "value")); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state, sessionId).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertNotNull(retrieved.state().get("nested")); + } + + @Test + public void testUpsertAppState() { + String sessionId1 = "spanner-upsert-1-" + System.currentTimeMillis(); + String sessionId2 = "spanner-upsert-2-" + System.currentTimeMillis(); + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("app:config", "value1"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("app:config", "value2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId1, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("value2", retrieved.state().get("app:config")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateDeltaTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateDeltaTest.java new file mode 100644 index 000000000..dc4467b4b --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateDeltaTest.java @@ -0,0 +1,400 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.flywaydb.core.Flyway; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class StateDeltaTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:jdbc_delta_test;DB_CLOSE_DELAY=-1;USER=sa;PASSWORD="; + private static final String TEST_APP_NAME = "delta-test-app"; + private static final String TEST_USER_ID = "delta-user"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + Flyway flyway = + Flyway.configure() + .dataSource(TEST_DB_URL, null, null) + .locations("classpath:db/migration/h2") + .cleanDisabled(false) + .load(); + flyway.clean(); + flyway.migrate(); + + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testStateDeltaInEvent() { + String sessionId = "delta-event-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("counter", 1); + delta.put("new_field", "added"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertEquals(1, updated.state().get("counter")); + assertEquals("added", updated.state().get("new_field")); + } + + @Test + public void testAppStateDeltaInEvent() { + String sessionId = "app-delta-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "app_counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "app_counter", 10); + delta.put(State.APP_PREFIX + "app_field", "app_value"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertEquals(10, updated.state().get(State.APP_PREFIX + "app_counter")); + assertEquals("app_value", updated.state().get(State.APP_PREFIX + "app_field")); + } + + @Test + public void testUserStateDeltaInEvent() { + String sessionId = "user-delta-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.USER_PREFIX + "user_counter", 0); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.USER_PREFIX + "user_counter", 5); + delta.put(State.USER_PREFIX + "user_field", "user_value"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertEquals(5, updated.state().get(State.USER_PREFIX + "user_counter")); + assertEquals("user_value", updated.state().get(State.USER_PREFIX + "user_field")); + } + + @Test + public void testMixedStateDeltaInEvent() { + String sessionId = "mixed-delta-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "app_value", "initial_app"); + initialState.put(State.USER_PREFIX + "user_value", "initial_user"); + initialState.put("session_value", "initial_session"); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "app_value", "updated_app"); + delta.put(State.USER_PREFIX + "user_value", "updated_user"); + delta.put("session_value", "updated_session"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertEquals("updated_app", updated.state().get(State.APP_PREFIX + "app_value")); + assertEquals("updated_user", updated.state().get(State.USER_PREFIX + "user_value")); + assertEquals("updated_session", updated.state().get("session_value")); + } + + @Test + public void testStateRemovalViaRemoved() { + String sessionId = "session-removal-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("key_to_remove", "value"); + initialState.put("key_to_keep", "keep_this"); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put("key_to_remove", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertFalse(updated.state().containsKey("key_to_remove")); + assertEquals("keep_this", updated.state().get("key_to_keep")); + } + + @Test + public void testAppStateRemovalViaRemoved() { + String sessionId = "app-removal-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "deprecated", "old_value"); + initialState.put(State.APP_PREFIX + "current", "keep_this"); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "deprecated", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertFalse(updated.state().containsKey(State.APP_PREFIX + "deprecated")); + assertEquals("keep_this", updated.state().get(State.APP_PREFIX + "current")); + } + + @Test + public void testUserStateRemovalViaRemoved() { + String sessionId = "user-removal-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.USER_PREFIX + "old_pref", "remove_me"); + initialState.put(State.USER_PREFIX + "new_pref", "keep_this"); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.USER_PREFIX + "old_pref", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertFalse(updated.state().containsKey(State.USER_PREFIX + "old_pref")); + assertEquals("keep_this", updated.state().get(State.USER_PREFIX + "new_pref")); + } + + @Test + public void testMixedStateRemovalViaRemoved() { + String sessionId = "mixed-removal-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put(State.APP_PREFIX + "app_deprecated", "remove"); + initialState.put(State.APP_PREFIX + "app_current", "keep"); + initialState.put(State.USER_PREFIX + "user_old", "remove"); + initialState.put(State.USER_PREFIX + "user_new", "keep"); + initialState.put("session_temp", "remove"); + initialState.put("session_data", "keep"); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID, initialState, sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "app_deprecated", State.REMOVED); + delta.put(State.USER_PREFIX + "user_old", State.REMOVED); + delta.put("session_temp", State.REMOVED); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + Session updated = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(updated); + assertFalse(updated.state().containsKey(State.APP_PREFIX + "app_deprecated")); + assertFalse(updated.state().containsKey(State.USER_PREFIX + "user_old")); + assertFalse(updated.state().containsKey("session_temp")); + assertEquals("keep", updated.state().get(State.APP_PREFIX + "app_current")); + assertEquals("keep", updated.state().get(State.USER_PREFIX + "user_new")); + assertEquals("keep", updated.state().get("session_data")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateManagementTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateManagementTest.java new file mode 100644 index 000000000..66301903c --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StateManagementTest.java @@ -0,0 +1,392 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class StateManagementTest { + + private static final String TEST_DB_URL = + "jdbc:h2:mem:state_test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"; + private static final String TEST_APP_NAME = "state-test-app"; + private static final String TEST_USER_ID_1 = "user-1"; + private static final String TEST_USER_ID_2 = "user-2"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + sessionService.close(); + } + } + + @Test + public void testAppStateSharing() { + String sessionId1 = "session-1"; + String sessionId2 = "session-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.APP_PREFIX + "global_setting", "shared_value"); + state1.put("local", "private_value_1"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.APP_PREFIX + "global_setting", "updated_value"); + state2.put("local", "private_value_2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_2, state2, sessionId2).blockingGet(); + + Session retrieved1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session retrieved2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_2, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved1); + assertNotNull(retrieved2); + + assertEquals("updated_value", retrieved1.state().get(State.APP_PREFIX + "global_setting")); + assertEquals("updated_value", retrieved2.state().get(State.APP_PREFIX + "global_setting")); + + assertEquals("private_value_1", retrieved1.state().get("local")); + assertEquals("private_value_2", retrieved2.state().get("local")); + } + + @Test + public void testUserStateSharing() { + String sessionId1 = "user-session-1"; + String sessionId2 = "user-session-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.USER_PREFIX + "preference", "dark_mode"); + state1.put("data", "session_specific_1"); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.USER_PREFIX + "preference", "light_mode"); + state2.put("data", "session_specific_2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(session1); + assertNotNull(session2); + + assertEquals("light_mode", session1.state().get(State.USER_PREFIX + "preference")); + assertEquals("light_mode", session2.state().get(State.USER_PREFIX + "preference")); + + assertEquals("session_specific_1", session1.state().get("data")); + assertEquals("session_specific_2", session2.state().get("data")); + } + + @Test + public void testSessionStateIsolation() { + String sessionId1 = "isolated-1"; + String sessionId2 = "isolated-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("private_key", "value_1"); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("private_key", "value_2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(session1); + assertNotNull(session2); + + assertEquals("value_1", session1.state().get("private_key")); + assertEquals("value_2", session2.state().get("private_key")); + } + + @Test + public void testStatePriorityMerging() { + String sessionId = "priority-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("_app_key", "app_value"); + initialState.put("_user_key", "user_value"); + initialState.put("key", "session_value"); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, initialState, sessionId) + .blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("app_value", retrieved.state().get("_app_key")); + assertEquals("user_value", retrieved.state().get("_user_key")); + assertEquals("session_value", retrieved.state().get("key")); + } + + @Test + public void testTempStateIsIgnored() { + String sessionId = "temp-test"; + + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("temp:ignored", "should_not_persist"); + initialState.put("persisted", "should_persist"); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, initialState, sessionId) + .blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("should_persist", retrieved.state().get("persisted")); + assertEquals(null, retrieved.state().get("temp:ignored")); + } + + @Test + public void testStateMerge_putAllDoesNotLoseData() { + String sessionId1 = "merge-test-1"; + String sessionId2 = "merge-test-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.APP_PREFIX + "key1", "value1"); + state1.put(State.APP_PREFIX + "key2", "value2"); + state1.put(State.APP_PREFIX + "key3", "value3"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.APP_PREFIX + "key4", "value4"); + state2.put(State.APP_PREFIX + "key5", "value5"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("value1", retrieved.state().get(State.APP_PREFIX + "key1")); + assertEquals("value2", retrieved.state().get(State.APP_PREFIX + "key2")); + assertEquals("value3", retrieved.state().get(State.APP_PREFIX + "key3")); + assertEquals("value4", retrieved.state().get(State.APP_PREFIX + "key4")); + assertEquals("value5", retrieved.state().get(State.APP_PREFIX + "key5")); + } + + @Test + public void testStateMerge_nestedObjectsPreserved() { + String sessionId1 = "nested-merge-1"; + String sessionId2 = "nested-merge-2"; + + ConcurrentHashMap nestedMap1 = new ConcurrentHashMap<>(); + nestedMap1.put("nested_key_1", "nested_value_1"); + nestedMap1.put("nested_key_2", 42); + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.APP_PREFIX + "config", nestedMap1); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap nestedMap2 = new ConcurrentHashMap<>(); + nestedMap2.put("another_nested_key", "another_value"); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.APP_PREFIX + "other_config", nestedMap2); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertNotNull(retrieved.state().get(State.APP_PREFIX + "config")); + assertNotNull(retrieved.state().get(State.APP_PREFIX + "other_config")); + } + + @Test + public void testStateMerge_overwriteExistingKeys() { + String sessionId1 = "overwrite-1"; + String sessionId2 = "overwrite-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.APP_PREFIX + "shared_key", "original_value"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.APP_PREFIX + "shared_key", "updated_value"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("updated_value", retrieved.state().get(State.APP_PREFIX + "shared_key")); + } + + @Test + public void testStateMerge_userStateDoesNotLoseData() { + String sessionId1 = "user-merge-1"; + String sessionId2 = "user-merge-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put(State.USER_PREFIX + "pref1", "value1"); + state1.put(State.USER_PREFIX + "pref2", "value2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put(State.USER_PREFIX + "pref3", "value3"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("value1", retrieved.state().get(State.USER_PREFIX + "pref1")); + assertEquals("value2", retrieved.state().get(State.USER_PREFIX + "pref2")); + assertEquals("value3", retrieved.state().get(State.USER_PREFIX + "pref3")); + } + + @Test + public void testStateMerge_sessionStateRemainsIsolated() { + String sessionId1 = "session-isolated-1"; + String sessionId2 = "session-isolated-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + state1.put("session_key1", "session_value1"); + state1.put(State.APP_PREFIX + "app_key", "shared"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + state2.put("session_key2", "session_value2"); + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + Session retrieved2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved1); + assertNotNull(retrieved2); + + assertEquals("session_value1", retrieved1.state().get("session_key1")); + assertEquals(null, retrieved1.state().get("session_key2")); + + assertEquals(null, retrieved2.state().get("session_key1")); + assertEquals("session_value2", retrieved2.state().get("session_key2")); + + assertEquals("shared", retrieved1.state().get(State.APP_PREFIX + "app_key")); + assertEquals("shared", retrieved2.state().get(State.APP_PREFIX + "app_key")); + } + + @Test + public void testStateMerge_largeStateDoesNotLoseData() { + String sessionId1 = "large-state-1"; + String sessionId2 = "large-state-2"; + + ConcurrentHashMap state1 = new ConcurrentHashMap<>(); + for (int i = 0; i < 50; i++) { + state1.put(State.APP_PREFIX + "key_" + i, "value_" + i); + } + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state1, sessionId1).blockingGet(); + + ConcurrentHashMap state2 = new ConcurrentHashMap<>(); + for (int i = 50; i < 100; i++) { + state2.put(State.APP_PREFIX + "key_" + i, "value_" + i); + } + + sessionService.createSession(TEST_APP_NAME, TEST_USER_ID_1, state2, sessionId2).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + + for (int i = 0; i < 100; i++) { + assertEquals( + "value_" + i, + retrieved.state().get(State.APP_PREFIX + "key_" + i), + "Key " + i + " should not be lost during merge"); + } + } + + @Test + public void testStateMerge_roundTripSerialization() { + String sessionId = "roundtrip-test"; + + ConcurrentHashMap originalState = new ConcurrentHashMap<>(); + originalState.put(State.APP_PREFIX + "string_key", "string_value"); + originalState.put(State.APP_PREFIX + "int_key", 42); + originalState.put(State.APP_PREFIX + "double_key", 3.14); + originalState.put(State.APP_PREFIX + "boolean_key", true); + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, originalState, sessionId) + .blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals("string_value", retrieved.state().get(State.APP_PREFIX + "string_key")); + assertEquals(42, retrieved.state().get(State.APP_PREFIX + "int_key")); + assertEquals(3.14, retrieved.state().get(State.APP_PREFIX + "double_key")); + assertEquals(true, retrieved.state().get(State.APP_PREFIX + "boolean_key")); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/sessions/StatePrefixHandlingTest.java b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StatePrefixHandlingTest.java new file mode 100644 index 000000000..5dcb1a2ff --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/sessions/StatePrefixHandlingTest.java @@ -0,0 +1,334 @@ +package com.google.adk.sessions; + +import static org.junit.jupiter.api.Assertions.*; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Integration test to verify that app/user state is stored WITHOUT prefixes in the database but + * retrieved WITH prefixes in the session state. + * + *

This ensures compatibility with: + * + *

    + *
  • Python DatabaseSessionService implementation + *
  • Java InMemorySessionService implementation + *
  • Proper state isolation and namespace handling + *
+ */ +public class StatePrefixHandlingTest { + + private static final String TEST_DB_URL = "jdbc:h2:mem:testdb_prefix;DB_CLOSE_DELAY=-1"; + private static final String TEST_APP_NAME = "test-app"; + private static final String TEST_USER_ID_1 = "user1"; + private static final String TEST_USER_ID_2 = "user2"; + + private DatabaseSessionService sessionService; + + @BeforeEach + public void setUp() { + sessionService = new DatabaseSessionService(TEST_DB_URL); + } + + @AfterEach + public void tearDown() { + if (sessionService != null) { + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement()) { + stmt.execute("DELETE FROM events"); + stmt.execute("DELETE FROM sessions"); + stmt.execute("DELETE FROM app_states"); + stmt.execute("DELETE FROM user_states"); + } catch (Exception e) { + } + sessionService.close(); + } + } + + @Test + public void testAppStatePrefixStrippedInDatabase() throws Exception { + String sessionId = "app-prefix-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "counter", 42); + delta.put(State.APP_PREFIX + "theme", "dark"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement(); + ResultSet rs = + stmt.executeQuery( + "SELECT state FROM app_states WHERE app_name = '" + TEST_APP_NAME + "'")) { + + assertTrue(rs.next(), "App state should exist in database"); + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertFalse( + stateJson.contains("\"app:counter\""), "Database should NOT contain 'app:' prefix"); + assertFalse(stateJson.contains("\"app:theme\""), "Database should NOT contain 'app:' prefix"); + assertTrue(stateJson.contains("\"counter\""), "Database should contain unprefixed 'counter'"); + assertTrue(stateJson.contains("\"theme\""), "Database should contain unprefixed 'theme'"); + } + } + + @Test + public void testUserStatePrefixStrippedInDatabase() throws Exception { + String sessionId = "user-prefix-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.USER_PREFIX + "preference", "enabled"); + delta.put(State.USER_PREFIX + "language", "en"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + + try (Connection conn = DriverManager.getConnection(TEST_DB_URL); + Statement stmt = conn.createStatement(); + ResultSet rs = + stmt.executeQuery( + "SELECT state FROM user_states WHERE app_name = '" + + TEST_APP_NAME + + "' AND user_id = '" + + TEST_USER_ID_1 + + "'")) { + + assertTrue(rs.next(), "User state should exist in database"); + String stateJson = rs.getString("state"); + assertNotNull(stateJson); + assertFalse( + stateJson.contains("\"user:preference\""), "Database should NOT contain 'user:' prefix"); + assertFalse( + stateJson.contains("\"user:language\""), "Database should NOT contain 'user:' prefix"); + assertTrue( + stateJson.contains("\"preference\""), "Database should contain unprefixed 'preference'"); + assertTrue( + stateJson.contains("\"language\""), "Database should contain unprefixed 'language'"); + } + } + + @Test + public void testSessionStatePrefixAddedDuringRetrieval() { + String sessionId = "retrieval-test"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "global_value", "shared"); + delta.put(State.USER_PREFIX + "user_value", "personal"); + delta.put("session_value", "private"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session, event).blockingGet(); + + Session retrieved = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId, Optional.empty()) + .blockingGet(); + + assertNotNull(retrieved); + assertEquals( + "shared", + retrieved.state().get(State.APP_PREFIX + "global_value"), + "App state should have 'app:' prefix"); + assertEquals( + "personal", + retrieved.state().get(State.USER_PREFIX + "user_value"), + "User state should have 'user:' prefix"); + assertEquals( + "private", retrieved.state().get("session_value"), "Session state should NOT have prefix"); + } + + @Test + public void testAppStateSharedAcrossUsers() { + String sessionId1 = "session1"; + String sessionId2 = "session2"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId1) + .blockingGet(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_2, new ConcurrentHashMap<>(), sessionId2) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.APP_PREFIX + "feature_flag", true); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_2, sessionId2, Optional.empty()) + .blockingGet(); + + assertTrue( + (Boolean) session2.state().get(State.APP_PREFIX + "feature_flag"), + "User 2 should see app state set by User 1"); + } + + @Test + public void testUserStateIsolatedBetweenUsers() { + String sessionId1 = "session1"; + String sessionId2 = "session2"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId1) + .blockingGet(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_2, new ConcurrentHashMap<>(), sessionId2) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.USER_PREFIX + "timezone", "UTC"); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_2, sessionId2, Optional.empty()) + .blockingGet(); + + assertNull( + session2.state().get(State.USER_PREFIX + "timezone"), + "User 2 should NOT see user state set by User 1"); + } + + @Test + public void testUserStateSharedAcrossSessionsForSameUser() { + String sessionId1 = "session1"; + String sessionId2 = "session2"; + + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId1) + .blockingGet(); + sessionService + .createSession(TEST_APP_NAME, TEST_USER_ID_1, new ConcurrentHashMap<>(), sessionId2) + .blockingGet(); + + ConcurrentHashMap delta = new ConcurrentHashMap<>(); + delta.put(State.USER_PREFIX + "notification_enabled", true); + + EventActions actions = EventActions.builder().stateDelta(delta).build(); + + Event event = + Event.builder() + .id(UUID.randomUUID().toString()) + .author("test") + .content(Content.fromParts(Part.fromText("Test"))) + .timestamp(Instant.now().toEpochMilli()) + .actions(actions) + .build(); + + Session session1 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId1, Optional.empty()) + .blockingGet(); + sessionService.appendEvent(session1, event).blockingGet(); + + Session session2 = + sessionService + .getSession(TEST_APP_NAME, TEST_USER_ID_1, sessionId2, Optional.empty()) + .blockingGet(); + + assertTrue( + (Boolean) session2.state().get(State.USER_PREFIX + "notification_enabled"), + "Same user should see user state across different sessions"); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/testing/TestDatabaseConfig.java b/contrib/database-session-service/src/test/java/com/google/adk/testing/TestDatabaseConfig.java new file mode 100644 index 000000000..b9b113ade --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/testing/TestDatabaseConfig.java @@ -0,0 +1,125 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.testing; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; + +/** + * Centralized configuration for integration test databases. + * + *

These connection strings assume test databases are running via + * scripts/docker-compose.test.yml: + * + *

{@code
+ * docker-compose -f scripts/docker-compose.test.yml up -d
+ * }
+ */ +public final class TestDatabaseConfig { + + private TestDatabaseConfig() {} + + // MySQL Test Database Configuration + public static final String MYSQL_HOST = "localhost"; + public static final int MYSQL_PORT = 3306; + public static final String MYSQL_DATABASE = "adk_test"; + public static final String MYSQL_USER = "adk_user"; + public static final String MYSQL_PASSWORD = "adk_password"; + public static final String MYSQL_JDBC_URL = + String.format( + "jdbc:mysql://%s:%d/%s?user=%s&password=%s&useSSL=false&allowPublicKeyRetrieval=true", + MYSQL_HOST, MYSQL_PORT, MYSQL_DATABASE, MYSQL_USER, MYSQL_PASSWORD); + + // PostgreSQL Test Database Configuration + public static final String POSTGRES_HOST = "localhost"; + public static final int POSTGRES_PORT = 5432; + public static final String POSTGRES_DATABASE = "adk_test"; + public static final String POSTGRES_USER = "adk_user"; + public static final String POSTGRES_PASSWORD = "adk_password"; + public static final String POSTGRES_JDBC_URL = + String.format( + "jdbc:postgresql://%s:%d/%s?user=%s&password=%s", + POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DATABASE, POSTGRES_USER, POSTGRES_PASSWORD); + + // Cloud Spanner Emulator Configuration + public static final String SPANNER_HOST = "localhost"; + public static final int SPANNER_PORT = 9010; + public static final String SPANNER_PROJECT = "test-project"; + public static final String SPANNER_INSTANCE = "test-instance"; + public static final String SPANNER_DATABASE = "test-db"; + public static final String SPANNER_JDBC_URL = + String.format( + "jdbc:cloudspanner://%s:%d/projects/%s/instances/%s/databases/%s?autoConfigEmulator=true", + SPANNER_HOST, SPANNER_PORT, SPANNER_PROJECT, SPANNER_INSTANCE, SPANNER_DATABASE); + + /** + * Checks if MySQL test database is available. + * + * @return true if connection succeeds, false otherwise + */ + public static boolean isMySQLAvailable() { + try (Connection conn = DriverManager.getConnection(MYSQL_JDBC_URL)) { + return conn.isValid(2); + } catch (SQLException e) { + return false; + } + } + + /** + * Checks if PostgreSQL test database is available. + * + * @return true if connection succeeds, false otherwise + */ + public static boolean isPostgreSQLAvailable() { + try (Connection conn = DriverManager.getConnection(POSTGRES_JDBC_URL)) { + return conn.isValid(2); + } catch (SQLException e) { + return false; + } + } + + /** + * Checks if Cloud Spanner emulator is available. + * + * @return true if connection succeeds, false otherwise + */ + public static boolean isSpannerAvailable() { + try (Connection conn = DriverManager.getConnection(SPANNER_JDBC_URL)) { + return conn.isValid(2); + } catch (SQLException e) { + return false; + } + } + + /** + * Returns a helpful message for skipped tests when database is not available. + * + * @param databaseName The name of the database (MySQL, PostgreSQL, or Spanner) + * @return A message explaining how to start the database + */ + public static String getDatabaseNotAvailableMessage(String databaseName) { + if ("Spanner".equalsIgnoreCase(databaseName)) { + return "Cloud Spanner emulator not available. Start it with: " + + "docker run -d -p 9010:9010 -p 9020:9020 gcr.io/cloud-spanner-emulator/emulator && " + + "export SPANNER_EMULATOR_HOST=localhost:9010"; + } + return String.format( + "%s test database not available. Start it with: " + + "docker-compose -f docker-compose.test.yml up -d %s-test", + databaseName, databaseName.toLowerCase()); + } +} diff --git a/contrib/database-session-service/src/test/java/com/google/adk/testing/TestLlm.java b/contrib/database-session-service/src/test/java/com/google/adk/testing/TestLlm.java new file mode 100644 index 000000000..aaacf00a0 --- /dev/null +++ b/contrib/database-session-service/src/test/java/com/google/adk/testing/TestLlm.java @@ -0,0 +1,310 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.testing; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.adk.agents.LiveRequest; +import com.google.adk.models.BaseLlm; +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; +import java.util.function.Supplier; +import javax.annotation.Nullable; + +/** + * A test implementation of {@link BaseLlm}. + * + *

Supports providing responses via a sequence of {@link LlmResponse} objects or a {@link + * Supplier} of {@code Flowable}. It also captures all standard and live requests for + * assertion in tests. + */ +public final class TestLlm extends BaseLlm { + private final List llmRequests = Collections.synchronizedList(new ArrayList<>()); + private final List liveRequestHistory = + Collections.synchronizedList(new ArrayList<>()); + + private final List responseSequence; + private final AtomicInteger responseIndex = new AtomicInteger(0); + + private final Supplier> responsesSupplier; + private final Optional error; + + private TestLlm( + @Nullable List responses, + @Nullable Supplier> responsesSupplier, + @Nullable Throwable error) { + super("test-llm"); + this.responseSequence = responses; + this.responsesSupplier = responsesSupplier; + this.error = Optional.ofNullable(error); + } + + /** + * Constructs a TestLlm that serves responses sequentially from the provided list. + * + * @param responses A list of LlmResponse objects to be served in order. Can be null or empty. + */ + public TestLlm(@Nullable List responses) { + this(responses == null ? ImmutableList.of() : ImmutableList.copyOf(responses), null, null); + } + + /** + * Constructs a TestLlm that uses the provided supplier to get responses. + * + * @param responsesSupplier A supplier that provides a Flowable of LlmResponse. + */ + public TestLlm(Supplier> responsesSupplier) { + this(null, responsesSupplier, null); + } + + @CanIgnoreReturnValue + public static TestLlm create(@Nullable List responses, @Nullable Throwable error) { + if (error != null) { + return new TestLlm(ImmutableList.of(), null, error); + } + if (responses == null || responses.isEmpty()) { + return new TestLlm(ImmutableList.of(), null, null); + } + + List llmResponses = new ArrayList<>(); + Object first = responses.get(0); + if (first instanceof LlmResponse) { + // responses is List + for (Object response : responses) { + if (response instanceof LlmResponse llmResponse) { + llmResponses.add(llmResponse); + } else { + throw new IllegalArgumentException("Mixed response types in List"); + } + } + } else if (first instanceof String) { + // responses is List + for (Object item : responses) { + if (item instanceof String string) { + llmResponses.add( + LlmResponse.builder() + .content(Content.builder().parts(ImmutableList.of(Part.fromText(string))).build()) + .build()); + } else { + throw new IllegalArgumentException("Mixed response types in List"); + } + } + } else if (first instanceof Part) { + // responses is List + for (Object item : responses) { + if (item instanceof Part part) { + llmResponses.add( + LlmResponse.builder() + .content(Content.builder().parts(ImmutableList.of(part)).build()) + .build()); + } else { + throw new IllegalArgumentException("Mixed response types in List"); + } + } + } else if (first instanceof List) { + // responses is List> + for (Object item : responses) { + if (item instanceof List) { + List partList = (List) item; + if (!partList.isEmpty() && partList.get(0) instanceof Part) { + llmResponses.add( + LlmResponse.builder() + .content( + Content.builder() + .parts(partList.stream().map(p -> (Part) p).collect(toImmutableList())) + .build()) + .build()); + } else { + throw new IllegalArgumentException("Inner list elements are not Part instances."); + } + } else { + throw new IllegalArgumentException("Mixed response types in List"); + } + } + } else { + throw new IllegalArgumentException("Unsupported response type in List" + first.getClass()); + } + return new TestLlm(llmResponses, null, null); + } + + @CanIgnoreReturnValue + public static TestLlm create(@Nullable List responses) { + return create(responses, null); + } + + @CanIgnoreReturnValue + public static TestLlm create(String... responses) { + return create(Arrays.asList(responses), null); + } + + @CanIgnoreReturnValue + public static TestLlm create(LlmResponse... responses) { + return create(Arrays.asList(responses), null); + } + + @CanIgnoreReturnValue + public static TestLlm create(Part... responses) { + return create(Arrays.asList(responses), null); + } + + @Override + public Flowable generateContent(LlmRequest llmRequest, boolean stream) { + llmRequests.add(llmRequest); + + if (error.isPresent()) { + return Flowable.error(error.get()); + } + + if (this.responseSequence != null) { + // Sequential discrete response mode + int currentIndex = responseIndex.getAndIncrement(); + if (currentIndex < responseSequence.size()) { + LlmResponse nextResponse = responseSequence.get(currentIndex); + return Flowable.just(nextResponse); + } else { + return Flowable.error( + new NoSuchElementException( + "TestLlm (List mode) out of responses. Requested response for LLM call " + + llmRequests.size() + + " (index " + + currentIndex + + ") but only " + + responseSequence.size() + + " were configured.")); + } + } else if (this.responsesSupplier != null) { + // Legacy/streaming supplier mode + return responsesSupplier.get(); + } else { + // Should not happen if constructors are used properly + return Flowable.error(new IllegalStateException("TestLlm not initialized with responses.")); + } + } + + @Override + public BaseLlmConnection connect(LlmRequest llmRequest) { + llmRequests.add(llmRequest); + return new TestLlmConnection(); + } + + public ImmutableList getRequests() { + return ImmutableList.copyOf(llmRequests); + } + + public LlmRequest getLastRequest() { + return Iterables.getLast(llmRequests); + } + + /** Returns an immutable list of all {@link LiveRequest}s sent to the live connection. */ + public ImmutableList getLiveRequestHistory() { + return ImmutableList.copyOf(liveRequestHistory); + } + + public boolean waitForStreamingToolResults(String toolName, int expectedCount, Duration timeout) { + Instant deadline = Instant.now().plus(timeout); + String prefix = "Function " + toolName + " returned:"; + + Predicate isStreamingToolResult = + req -> + req.content() + .filter( + content -> + content.role().orElse("").equals("user") + && content.text() != null + && content.text().startsWith(prefix)) + .isPresent(); + + long currentCount = 0; + while (Instant.now().isBefore(deadline)) { + currentCount = getLiveRequestHistory().stream().filter(isStreamingToolResult).count(); + if (currentCount >= expectedCount) { + return true; + } + try { + Thread.sleep(200); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return false; + } + } + return false; + } + + /** A test implementation of {@link BaseLlmConnection} for {@link TestLlm}. */ + private final class TestLlmConnection implements BaseLlmConnection { + + @Override + public Completable sendHistory(List history) { + return Completable.complete(); + } + + @Override + public Completable sendContent(Content content) { + liveRequestHistory.add(LiveRequest.builder().content(content).build()); + return Completable.complete(); + } + + @Override + public Completable sendRealtime(Blob blob) { + liveRequestHistory.add(LiveRequest.builder().blob(blob).build()); + return Completable.complete(); + } + + @Override + public Flowable receive() { + if (error.isPresent()) { + return Flowable.error(error.get()); + } + if (responseSequence != null) { + return Flowable.fromIterable(responseSequence); + } else if (responsesSupplier != null) { + return responsesSupplier.get(); + } else { + return Flowable.error(new IllegalStateException("TestLlm not initialized with responses.")); + } + } + + @Override + public void close() { + liveRequestHistory.add(LiveRequest.builder().close(true).build()); + } + + @Override + public void close(Throwable throwable) { + close(); + } + } +} diff --git a/pom.xml b/pom.xml index 6a1aa5af5..da06ccd8e 100644 --- a/pom.xml +++ b/pom.xml @@ -34,6 +34,7 @@ contrib/spring-ai contrib/samples contrib/firestore-session-service + contrib/database-session-service tutorials/city-time-weather tutorials/live-audio-single-agent a2a @@ -72,6 +73,8 @@ 1.4.0 3.9.0 5.4.3 + 6.2.1 + 11.17.0 @@ -244,6 +247,31 @@ graphviz-java ${graphviz.version} + + com.zaxxer + HikariCP + ${hikaricp.version} + + + org.flywaydb + flyway-core + ${flyway.version} + + + org.flywaydb + flyway-database-postgresql + ${flyway.version} + + + org.flywaydb + flyway-mysql + ${flyway.version} + + + org.flywaydb + flyway-gcp-spanner + ${flyway.version} + org.eclipse.jdt ecj @@ -558,4 +586,4 @@ https://central.sonatype.com/repository/maven-snapshots/ - \ No newline at end of file +