Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 164 additions & 28 deletions Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,95 @@ enum HTTP {
}
}

#if canImport(FoundationNetworking)
/// Serializes Linux URLSession operations to mitigate a FoundationNetworking race.
///
/// AnyLanguageModel performs many concurrent HTTP requests across model implementations.
/// On Linux, `FoundationNetworking` routes `URLSession` through a shared
/// `_MultiHandle`, which has a known thread-safety bug that can crash under
/// concurrent access (`URLSession._MultiHandle.endOperation(for:)`).
///
/// This gate intentionally allows only one in-flight request setup path at a time on Linux.
/// For non-streaming requests, callers typically hold this lock for the entire
/// request/response cycle, effectively serializing those operations and reducing
/// request-level parallelism (which can lower throughput for heavily concurrent
/// workloads).
///
/// For streaming requests, callers usually acquire the gate only during initial
/// request setup and then release it once the stream has been established; stream
/// consumption itself is not serialized by this gate.
/// Keep this scoped to Linux-only code paths until the upstream issue is resolved.
///
/// See: https://github.com/swiftlang/swift-corelibs-foundation/issues/4791
actor LinuxURLSessionRequestGate {
private struct Waiter {
let id: UUID
let continuation: CheckedContinuation<Void, Error>
}

static let shared = LinuxURLSessionRequestGate()

private var isLocked = false
private var waiters: [Waiter] = []

func acquire() async throws {
if Task.isCancelled {
throw CancellationError()
}

if !isLocked {
isLocked = true
return
}

let waiterID = UUID()
try await withTaskCancellationHandler {
try await withCheckedThrowingContinuation { continuation in
waiters.append(Waiter(id: waiterID, continuation: continuation))
}
} onCancel: {
Task {
await self.cancelWaiter(id: waiterID)
}
}
}

func release() {
if waiters.isEmpty {
isLocked = false
return
}

let waiter = waiters.removeFirst()
waiter.continuation.resume()
}

private func cancelWaiter(id: UUID) {
guard let index = waiters.firstIndex(where: { $0.id == id }) else {
return
}

let waiter = waiters.remove(at: index)
waiter.continuation.resume(throwing: CancellationError())
}

}

func withLinuxRequestLock(
_ operation: () async throws -> Void
) async throws {
let gate = LinuxURLSessionRequestGate.shared
try await gate.acquire()
do {
try await operation()
await gate.release()
} catch {
await gate.release()
throw error
}
}
#endif

extension URLSession {
func fetch<T: Decodable>(
_ method: HTTP.Method,
Expand All @@ -34,7 +123,20 @@ extension URLSession {
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
}

let (data, response) = try await data(for: request)
#if canImport(FoundationNetworking)
var lockedData: Data?
var lockedResponse: URLResponse?
try await withLinuxRequestLock {
let (data, response) = try await data(for: request)
lockedData = data
lockedResponse = response
}
guard let data = lockedData, let response = lockedResponse else {
throw URLSessionError.invalidResponse
}
#else
let (data, response) = try await data(for: request)
#endif

guard let httpResponse = response as? HTTPURLResponse else {
throw URLSessionError.invalidResponse
Expand Down Expand Up @@ -83,7 +185,20 @@ extension URLSession {
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
}

let (data, response) = try await self.data(for: request)
#if canImport(FoundationNetworking)
var lockedData: Data?
var lockedResponse: URLResponse?
try await withLinuxRequestLock {
let (data, response) = try await self.data(for: request)
lockedData = data
lockedResponse = response
}
guard let data = lockedData, let response = lockedResponse else {
throw URLSessionError.invalidResponse
}
#else
let (data, response) = try await self.data(for: request)
#endif

guard let httpResponse = response as? HTTPURLResponse else {
throw URLSessionError.invalidResponse
Expand Down Expand Up @@ -143,35 +258,23 @@ extension URLSession {
}

#if canImport(FoundationNetworking)
let (asyncBytes, response) = try await self.linuxBytes(for: request)
var lockedAsyncBytes: AsyncThrowingStream<UInt8, Error>?
var lockedResponse: URLResponse?
try await withLinuxRequestLock {
let (bytes, response) = try await self.linuxBytes(for: request)
lockedAsyncBytes = bytes
lockedResponse = response
}
guard let asyncBytes = lockedAsyncBytes, let response = lockedResponse else {
throw URLSessionError.invalidResponse
}
try await self.validateEventStreamResponse(response, asyncBytes: asyncBytes)
try await decodeAndYieldEventStream(asyncBytes, to: continuation)
#else
let (asyncBytes, response) = try await self.bytes(for: request)
try await validateEventStreamResponse(response, asyncBytes: asyncBytes)
try await decodeAndYieldEventStream(asyncBytes, to: continuation)
#endif

guard let httpResponse = response as? HTTPURLResponse else {
throw URLSessionError.invalidResponse
}

guard (200 ..< 300).contains(httpResponse.statusCode) else {
var errorData = Data()
for try await byte in asyncBytes {
errorData.append(byte)
}
if let errorString = String(data: errorData, encoding: .utf8) {
throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: errorString)
}
throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: "Invalid response")
}

let decoder = JSONDecoder()

for try await event in asyncBytes.events {
guard let data = event.data.data(using: .utf8) else { continue }
if let decoded = try? decoder.decode(T.self, from: data) {
continuation.yield(decoded)
}
}

continuation.finish()
} catch {
continuation.finish(throwing: error)
Expand All @@ -183,6 +286,39 @@ extension URLSession {
}
}
}

private func validateEventStreamResponse<Bytes>(
_ response: URLResponse,
asyncBytes: Bytes
) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 {
guard let httpResponse = response as? HTTPURLResponse else {
throw URLSessionError.invalidResponse
}

guard (200 ..< 300).contains(httpResponse.statusCode) else {
var errorData = Data()
for try await byte in asyncBytes {
errorData.append(byte)
}
if let errorString = String(data: errorData, encoding: .utf8) {
throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: errorString)
}
throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: "Invalid response")
}
}

private func decodeAndYieldEventStream<T: Decodable & Sendable, Bytes>(
_ asyncBytes: Bytes,
to continuation: AsyncThrowingStream<T, any Error>.Continuation
) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 {
let decoder = JSONDecoder()
for try await event in asyncBytes.events {
guard let data = event.data.data(using: .utf8) else { continue }
if let decoded = try? decoder.decode(T.self, from: data) {
continuation.yield(decoded)
}
}
}
}

#if canImport(FoundationNetworking)
Expand Down
119 changes: 119 additions & 0 deletions Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,122 @@ struct URLSessionExtensionsTests {
#expect(error.description == "Decoding error: keyNotFound")
}
}

#if canImport(FoundationNetworking)
private actor GateCounter {
private(set) var current = 0
private(set) var maxConcurrent = 0

func enter() {
current += 1
maxConcurrent = max(maxConcurrent, current)
}

func leave() {
current -= 1
}
}

private enum GateTestError: Error {
case expected
}

private actor GateFlag {
private(set) var value = false

func setTrue() {
value = true
}
}

extension URLSessionExtensionsTests {
@Test func linuxGateSerializesConcurrentOperations() async throws {
let counter = GateCounter()

try await withThrowingTaskGroup(of: Void.self) { group in
for _ in 0 ..< 8 {
group.addTask {
try await withLinuxRequestLock {
await counter.enter()
do {
try await Task.sleep(for: .milliseconds(20))
await counter.leave()
} catch {
await counter.leave()
throw error
}
}
}
}
try await group.waitForAll()
}

#expect(await counter.maxConcurrent == 1)
}

@Test func linuxGateReleasesAfterError() async throws {
do {
try await withLinuxRequestLock {
throw GateTestError.expected
}
Issue.record("Expected error was not thrown")
} catch GateTestError.expected {
// expected
}

var ranSecondOperation = false
try await withLinuxRequestLock {
ranSecondOperation = true
}
#expect(ranSecondOperation)
}

@Test func linuxGateReleasesAfterCancellation() async throws {
let longTask = Task {
try await withLinuxRequestLock {
try await Task.sleep(for: .seconds(10))
}
}

try await Task.sleep(for: .milliseconds(30))
longTask.cancel()
_ = await longTask.result

var acquiredAfterCancellation = false
try await withLinuxRequestLock {
acquiredAfterCancellation = true
}

#expect(acquiredAfterCancellation)
}

@Test func linuxGateCancelledWaiterDoesNotExecute() async throws {
let ranCancelledOperation = GateFlag()

let holder = Task {
try await withLinuxRequestLock {
try await Task.sleep(for: .milliseconds(200))
}
}

try await Task.sleep(for: .milliseconds(20))

let waiter = Task {
do {
try await withLinuxRequestLock {
await ranCancelledOperation.setTrue()
}
} catch {
// Cancellation is expected.
}
}

waiter.cancel()
_ = await waiter.result
try await holder.value
try await Task.sleep(for: .milliseconds(20))

#expect(await ranCancelledOperation.value == false)
}
}
#endif