diff --git a/Sources/SQLiteNIOExtras/SQLiteClient.swift b/Sources/SQLiteNIOExtras/SQLiteClient.swift index a57dab1..9d2a0c6 100644 --- a/Sources/SQLiteNIOExtras/SQLiteClient.swift +++ b/Sources/SQLiteNIOExtras/SQLiteClient.swift @@ -6,6 +6,7 @@ // import Logging +import NIOCore import SQLiteNIO /// A SQLite client backed by a connection pool. @@ -49,9 +50,9 @@ public final class SQLiteClient: Sendable { /// Logger used for pool operations. public let logger: Logger /// Journal mode applied to each pooled connection. - public let journalMode: JournalMode + public let journalMode: JournalMode? /// Busy timeout, in milliseconds, applied to each pooled connection. - public let busyTimeoutMilliseconds: Int + public let busyTimeoutMilliseconds: Int? /// Foreign key enforcement mode applied to each pooled connection. public let foreignKeysMode: ForeignKeysMode @@ -68,15 +69,17 @@ public final class SQLiteClient: Sendable { storage: SQLiteConnection.Storage, logger: Logger, minimumConnections: Int = 1, - maximumConnections: Int = 8, - journalMode: JournalMode = .wal, + maximumConnections: Int = System.coreCount, + journalMode: JournalMode? = nil, foreignKeysMode: ForeignKeysMode = .on, - busyTimeoutMilliseconds: Int = 1000 + busyTimeoutMilliseconds: Int? = nil ) { precondition(minimumConnections >= 0) precondition(maximumConnections >= 1) precondition(minimumConnections <= maximumConnections) - precondition(busyTimeoutMilliseconds >= 0) + if let busyTimeoutMilliseconds { + precondition(busyTimeoutMilliseconds >= 0) + } self.storage = storage self.minimumConnections = minimumConnections self.maximumConnections = maximumConnections diff --git a/Sources/SQLiteNIOExtras/SQLiteConnectionPool.swift b/Sources/SQLiteNIOExtras/SQLiteConnectionPool.swift index 308c28e..6fe6cba 100644 --- a/Sources/SQLiteNIOExtras/SQLiteConnectionPool.swift +++ b/Sources/SQLiteNIOExtras/SQLiteConnectionPool.swift @@ -191,15 +191,21 @@ actor SQLiteConnectionPool { logger: configuration.logger ) do { + let foreignKeys = configuration.foreignKeysMode + _ = try await connection.query( - "PRAGMA journal_mode = \(configuration.journalMode.rawValue);" - ) - _ = try await connection.query( - "PRAGMA busy_timeout = \(configuration.busyTimeoutMilliseconds);" - ) - _ = try await connection.query( - "PRAGMA foreign_keys = \(configuration.foreignKeysMode.rawValue);" + "PRAGMA foreign_keys = \(foreignKeys.rawValue);" ) + if let busyTimeout = configuration.busyTimeoutMilliseconds { + _ = try await connection.query( + "PRAGMA busy_timeout = \(busyTimeout);" + ) + } + if let journalMode = configuration.journalMode { + _ = try await connection.query( + "PRAGMA journal_mode = \(journalMode.rawValue);" + ) + } } catch { await closeConnection(connection) diff --git a/Tests/SQLiteNIOExtrasTests/SQLiteNIOExtrasTestSuite.swift b/Tests/SQLiteNIOExtrasTests/SQLiteNIOExtrasTestSuite.swift index 009769c..9aeef6d 100644 --- a/Tests/SQLiteNIOExtrasTests/SQLiteNIOExtrasTestSuite.swift +++ b/Tests/SQLiteNIOExtrasTests/SQLiteNIOExtrasTestSuite.swift @@ -229,4 +229,84 @@ struct SQLiteNIOExtrasTestSuite { #expect(result[0].column("is_valid")?.bool == true) } } + + // MARK: - lock + + private actor LockBarrier { + private var ready = false + private var waiters: [CheckedContinuation] = [] + + func waitUntilLocked() async { + if ready { return } + await withCheckedContinuation { continuation in + waiters.append(continuation) + } + } + + func markLocked() { + guard !ready else { return } + ready = true + let pending = waiters + waiters.removeAll(keepingCapacity: false) + for continuation in pending { + continuation.resume() + } + } + } + + @Test + func warmupWaitsForTransientExclusiveLock() async throws { + let dbPath = + "/tmp/feather-lock-\(UInt64.random(in: 0...UInt64.max)).sqlite" + + var logger = Logger(label: "test.sqlite.lock.warmup") + logger.logLevel = .info + + let config = SQLiteClient.Configuration( + storage: .file(path: dbPath), + logger: logger, + minimumConnections: 1, + maximumConnections: 1, + journalMode: .delete, + busyTimeoutMilliseconds: 5_000 + ) + + let clientA = SQLiteClient(configuration: config) + let clientB = SQLiteClient(configuration: config) + + try await clientA.run() + defer { + Task { + await clientB.shutdown() + await clientA.shutdown() + } + } + + let barrier = LockBarrier() + + let holder = Task { + try await clientA.withConnection { connection in + _ = try await connection.query("BEGIN EXCLUSIVE;") + await barrier.markLocked() + try await Task.sleep(for: .milliseconds(1200)) + _ = try await connection.query("COMMIT;") + } + } + + await barrier.waitUntilLocked() + + let clock = ContinuousClock() + let start = clock.now + + try await clientB.run() + try await clientB.withConnection { connection in + _ = try await connection.query("SELECT 1;") + } + + let elapsed = start.duration(to: clock.now) + #expect(elapsed >= .milliseconds(900)) + + _ = try await holder.value + } + }