diff --git a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift index 4c6d0cd..8a1e111 100644 --- a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift +++ b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift @@ -13,6 +13,95 @@ enum HTTP { } } +#if canImport(FoundationNetworking) + /// Serializes Linux URLSession operations to mitigate a FoundationNetworking race. + /// + /// AnyLanguageModel performs many concurrent HTTP requests across model implementations. + /// On Linux, `FoundationNetworking` routes `URLSession` through a shared + /// `_MultiHandle`, which has a known thread-safety bug that can crash under + /// concurrent access (`URLSession._MultiHandle.endOperation(for:)`). + /// + /// This gate intentionally allows only one in-flight request setup path at a time on Linux. + /// For non-streaming requests, callers typically hold this lock for the entire + /// request/response cycle, effectively serializing those operations and reducing + /// request-level parallelism (which can lower throughput for heavily concurrent + /// workloads). + /// + /// For streaming requests, callers usually acquire the gate only during initial + /// request setup and then release it once the stream has been established; stream + /// consumption itself is not serialized by this gate. + /// Keep this scoped to Linux-only code paths until the upstream issue is resolved. + /// + /// See: https://github.com/swiftlang/swift-corelibs-foundation/issues/4791 + actor LinuxURLSessionRequestGate { + private struct Waiter { + let id: UUID + let continuation: CheckedContinuation + } + + static let shared = LinuxURLSessionRequestGate() + + private var isLocked = false + private var waiters: [Waiter] = [] + + func acquire() async throws { + if Task.isCancelled { + throw CancellationError() + } + + if !isLocked { + isLocked = true + return + } + + let waiterID = UUID() + try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + waiters.append(Waiter(id: waiterID, continuation: continuation)) + } + } onCancel: { + Task { + await self.cancelWaiter(id: waiterID) + } + } + } + + func release() { + if waiters.isEmpty { + isLocked = false + return + } + + let waiter = waiters.removeFirst() + waiter.continuation.resume() + } + + private func cancelWaiter(id: UUID) { + guard let index = waiters.firstIndex(where: { $0.id == id }) else { + return + } + + let waiter = waiters.remove(at: index) + waiter.continuation.resume(throwing: CancellationError()) + } + + } + + func withLinuxRequestLock( + _ operation: () async throws -> Void + ) async throws { + let gate = LinuxURLSessionRequestGate.shared + try await gate.acquire() + do { + try await operation() + await gate.release() + } catch { + await gate.release() + throw error + } + } +#endif + extension URLSession { func fetch( _ method: HTTP.Method, @@ -34,7 +123,20 @@ extension URLSession { request.addValue("application/json", forHTTPHeaderField: "Content-Type") } - let (data, response) = try await data(for: request) + #if canImport(FoundationNetworking) + var lockedData: Data? + var lockedResponse: URLResponse? + try await withLinuxRequestLock { + let (data, response) = try await data(for: request) + lockedData = data + lockedResponse = response + } + guard let data = lockedData, let response = lockedResponse else { + throw URLSessionError.invalidResponse + } + #else + let (data, response) = try await data(for: request) + #endif guard let httpResponse = response as? HTTPURLResponse else { throw URLSessionError.invalidResponse @@ -83,7 +185,20 @@ extension URLSession { request.addValue("application/json", forHTTPHeaderField: "Content-Type") } - let (data, response) = try await self.data(for: request) + #if canImport(FoundationNetworking) + var lockedData: Data? + var lockedResponse: URLResponse? + try await withLinuxRequestLock { + let (data, response) = try await self.data(for: request) + lockedData = data + lockedResponse = response + } + guard let data = lockedData, let response = lockedResponse else { + throw URLSessionError.invalidResponse + } + #else + let (data, response) = try await self.data(for: request) + #endif guard let httpResponse = response as? HTTPURLResponse else { throw URLSessionError.invalidResponse @@ -143,35 +258,23 @@ extension URLSession { } #if canImport(FoundationNetworking) - let (asyncBytes, response) = try await self.linuxBytes(for: request) + var lockedAsyncBytes: AsyncThrowingStream? + var lockedResponse: URLResponse? + try await withLinuxRequestLock { + let (bytes, response) = try await self.linuxBytes(for: request) + lockedAsyncBytes = bytes + lockedResponse = response + } + guard let asyncBytes = lockedAsyncBytes, let response = lockedResponse else { + throw URLSessionError.invalidResponse + } + try await self.validateEventStreamResponse(response, asyncBytes: asyncBytes) + try await decodeAndYieldEventStream(asyncBytes, to: continuation) #else let (asyncBytes, response) = try await self.bytes(for: request) + try await validateEventStreamResponse(response, asyncBytes: asyncBytes) + try await decodeAndYieldEventStream(asyncBytes, to: continuation) #endif - - guard let httpResponse = response as? HTTPURLResponse else { - throw URLSessionError.invalidResponse - } - - guard (200 ..< 300).contains(httpResponse.statusCode) else { - var errorData = Data() - for try await byte in asyncBytes { - errorData.append(byte) - } - if let errorString = String(data: errorData, encoding: .utf8) { - throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: errorString) - } - throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: "Invalid response") - } - - let decoder = JSONDecoder() - - for try await event in asyncBytes.events { - guard let data = event.data.data(using: .utf8) else { continue } - if let decoded = try? decoder.decode(T.self, from: data) { - continuation.yield(decoded) - } - } - continuation.finish() } catch { continuation.finish(throwing: error) @@ -183,6 +286,39 @@ extension URLSession { } } } + + private func validateEventStreamResponse( + _ response: URLResponse, + asyncBytes: Bytes + ) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 { + guard let httpResponse = response as? HTTPURLResponse else { + throw URLSessionError.invalidResponse + } + + guard (200 ..< 300).contains(httpResponse.statusCode) else { + var errorData = Data() + for try await byte in asyncBytes { + errorData.append(byte) + } + if let errorString = String(data: errorData, encoding: .utf8) { + throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: errorString) + } + throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: "Invalid response") + } + } + + private func decodeAndYieldEventStream( + _ asyncBytes: Bytes, + to continuation: AsyncThrowingStream.Continuation + ) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 { + let decoder = JSONDecoder() + for try await event in asyncBytes.events { + guard let data = event.data.data(using: .utf8) else { continue } + if let decoded = try? decoder.decode(T.self, from: data) { + continuation.yield(decoded) + } + } + } } #if canImport(FoundationNetworking) diff --git a/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift b/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift index b672f92..c1ac266 100644 --- a/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift +++ b/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift @@ -19,3 +19,122 @@ struct URLSessionExtensionsTests { #expect(error.description == "Decoding error: keyNotFound") } } + +#if canImport(FoundationNetworking) + private actor GateCounter { + private(set) var current = 0 + private(set) var maxConcurrent = 0 + + func enter() { + current += 1 + maxConcurrent = max(maxConcurrent, current) + } + + func leave() { + current -= 1 + } + } + + private enum GateTestError: Error { + case expected + } + + private actor GateFlag { + private(set) var value = false + + func setTrue() { + value = true + } + } + + extension URLSessionExtensionsTests { + @Test func linuxGateSerializesConcurrentOperations() async throws { + let counter = GateCounter() + + try await withThrowingTaskGroup(of: Void.self) { group in + for _ in 0 ..< 8 { + group.addTask { + try await withLinuxRequestLock { + await counter.enter() + do { + try await Task.sleep(for: .milliseconds(20)) + await counter.leave() + } catch { + await counter.leave() + throw error + } + } + } + } + try await group.waitForAll() + } + + #expect(await counter.maxConcurrent == 1) + } + + @Test func linuxGateReleasesAfterError() async throws { + do { + try await withLinuxRequestLock { + throw GateTestError.expected + } + Issue.record("Expected error was not thrown") + } catch GateTestError.expected { + // expected + } + + var ranSecondOperation = false + try await withLinuxRequestLock { + ranSecondOperation = true + } + #expect(ranSecondOperation) + } + + @Test func linuxGateReleasesAfterCancellation() async throws { + let longTask = Task { + try await withLinuxRequestLock { + try await Task.sleep(for: .seconds(10)) + } + } + + try await Task.sleep(for: .milliseconds(30)) + longTask.cancel() + _ = await longTask.result + + var acquiredAfterCancellation = false + try await withLinuxRequestLock { + acquiredAfterCancellation = true + } + + #expect(acquiredAfterCancellation) + } + + @Test func linuxGateCancelledWaiterDoesNotExecute() async throws { + let ranCancelledOperation = GateFlag() + + let holder = Task { + try await withLinuxRequestLock { + try await Task.sleep(for: .milliseconds(200)) + } + } + + try await Task.sleep(for: .milliseconds(20)) + + let waiter = Task { + do { + try await withLinuxRequestLock { + await ranCancelledOperation.setTrue() + } + } catch { + // Cancellation is expected. + } + } + + waiter.cancel() + _ = await waiter.result + try await holder.value + try await Task.sleep(for: .milliseconds(20)) + + #expect(await ranCancelledOperation.value == false) + } + } +#endif