From e48f897850b3944b4c35c91d730bf117d1701ed6 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Feb 2026 02:28:30 -0800 Subject: [PATCH 1/9] Serialize Linux URLSession request paths to mitigate _MultiHandle race --- .../Extensions/URLSession+Extensions.swift | 147 +++++++++++++++--- 1 file changed, 125 insertions(+), 22 deletions(-) diff --git a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift index 4c6d0cdd..f4b07da3 100644 --- a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift +++ b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift @@ -13,6 +13,56 @@ enum HTTP { } } +#if canImport(FoundationNetworking) + /// Serializes Linux URLSession operations to mitigate a FoundationNetworking race. + /// + /// AnyLanguageModel performs many concurrent HTTP requests across model implementations. + /// On Linux, `FoundationNetworking` routes `URLSession` through a shared + /// `_MultiHandle`, which has a known thread-safety bug that can crash under + /// concurrent access (`URLSession._MultiHandle.endOperation(for:)`). + /// + /// This gate intentionally allows only one in-flight request path at a time on Linux. + /// Keep this scoped to Linux-only code paths until the upstream issue is resolved. + /// + /// See: https://github.com/swiftlang/swift-corelibs-foundation/issues/4791 + private actor LinuxURLSessionRequestGate { + static let shared = LinuxURLSessionRequestGate() + + private var isLocked = false + private var waiters: [CheckedContinuation] = [] + + private func acquire() async { + if !isLocked { + isLocked = true + return + } + + await withCheckedContinuation { continuation in + waiters.append(continuation) + } + } + + private func release() { + if waiters.isEmpty { + isLocked = false + return + } + + let continuation = waiters.removeFirst() + continuation.resume() + } + + /// Executes an async operation while holding the gate lock. + func withLock( + _ operation: () async throws -> T + ) async rethrows -> T { + await acquire() + defer { release() } + return try await operation() + } + } +#endif + extension URLSession { func fetch( _ method: HTTP.Method, @@ -34,7 +84,14 @@ extension URLSession { request.addValue("application/json", forHTTPHeaderField: "Content-Type") } - let (data, response) = try await data(for: request) + #if canImport(FoundationNetworking) + let dataAndResponse = try await LinuxURLSessionRequestGate.shared.withLock { + try await data(for: request) + } + let (data, response) = dataAndResponse + #else + let (data, response) = try await data(for: request) + #endif guard let httpResponse = response as? HTTPURLResponse else { throw URLSessionError.invalidResponse @@ -83,7 +140,14 @@ extension URLSession { request.addValue("application/json", forHTTPHeaderField: "Content-Type") } - let (data, response) = try await self.data(for: request) + #if canImport(FoundationNetworking) + let dataAndResponse = try await LinuxURLSessionRequestGate.shared.withLock { + try await self.data(for: request) + } + let (data, response) = dataAndResponse + #else + let (data, response) = try await self.data(for: request) + #endif guard let httpResponse = response as? HTTPURLResponse else { throw URLSessionError.invalidResponse @@ -143,34 +207,73 @@ extension URLSession { } #if canImport(FoundationNetworking) - let (asyncBytes, response) = try await self.linuxBytes(for: request) + try await LinuxURLSessionRequestGate.shared.withLock { + let asyncBytesAndResponse = try await self.linuxBytes(for: request) + let (asyncBytes, response) = asyncBytesAndResponse + + guard let httpResponse = response as? HTTPURLResponse else { + throw URLSessionError.invalidResponse + } + + guard (200 ..< 300).contains(httpResponse.statusCode) else { + var errorData = Data() + for try await byte in asyncBytes { + errorData.append(byte) + } + if let errorString = String(data: errorData, encoding: .utf8) { + throw URLSessionError.httpError( + statusCode: httpResponse.statusCode, + detail: errorString + ) + } + throw URLSessionError.httpError( + statusCode: httpResponse.statusCode, + detail: "Invalid response" + ) + } + + let decoder = JSONDecoder() + + for try await event in asyncBytes.events { + guard let data = event.data.data(using: .utf8) else { continue } + if let decoded = try? decoder.decode(T.self, from: data) { + continuation.yield(decoded) + } + } + } #else let (asyncBytes, response) = try await self.bytes(for: request) - #endif - - guard let httpResponse = response as? HTTPURLResponse else { - throw URLSessionError.invalidResponse - } - guard (200 ..< 300).contains(httpResponse.statusCode) else { - var errorData = Data() - for try await byte in asyncBytes { - errorData.append(byte) + guard let httpResponse = response as? HTTPURLResponse else { + throw URLSessionError.invalidResponse } - if let errorString = String(data: errorData, encoding: .utf8) { - throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: errorString) + + guard (200 ..< 300).contains(httpResponse.statusCode) else { + var errorData = Data() + for try await byte in asyncBytes { + errorData.append(byte) + } + if let errorString = String(data: errorData, encoding: .utf8) { + throw URLSessionError.httpError( + statusCode: httpResponse.statusCode, + detail: errorString + ) + } + throw URLSessionError.httpError( + statusCode: httpResponse.statusCode, + detail: "Invalid response" + ) } - throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: "Invalid response") - } - let decoder = JSONDecoder() + let decoder = JSONDecoder() - for try await event in asyncBytes.events { - guard let data = event.data.data(using: .utf8) else { continue } - if let decoded = try? decoder.decode(T.self, from: data) { - continuation.yield(decoded) + for try await event in asyncBytes.events { + guard let data = event.data.data(using: .utf8) else { continue } + if let decoded = try? decoder.decode(T.self, from: data) { + continuation.yield(decoded) + } } - } + #endif continuation.finish() } catch { From 2027d9ea633b4728d62319e72f09913ad9306f56 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Feb 2026 02:44:33 -0800 Subject: [PATCH 2/9] Incorporate feedback from review --- .../Extensions/URLSession+Extensions.swift | 108 +++++++----------- .../URLSessionExtensionsTests.swift | 87 ++++++++++++++ 2 files changed, 130 insertions(+), 65 deletions(-) diff --git a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift index f4b07da3..aa17292b 100644 --- a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift +++ b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift @@ -22,10 +22,12 @@ enum HTTP { /// concurrent access (`URLSession._MultiHandle.endOperation(for:)`). /// /// This gate intentionally allows only one in-flight request path at a time on Linux. + /// This fully serializes HTTP request setup paths on Linux and reduces request-level + /// parallelism, which can lower throughput for heavily concurrent workloads. /// Keep this scoped to Linux-only code paths until the upstream issue is resolved. /// /// See: https://github.com/swiftlang/swift-corelibs-foundation/issues/4791 - private actor LinuxURLSessionRequestGate { + actor LinuxURLSessionRequestGate { static let shared = LinuxURLSessionRequestGate() private var isLocked = false @@ -207,74 +209,17 @@ extension URLSession { } #if canImport(FoundationNetworking) - try await LinuxURLSessionRequestGate.shared.withLock { - let asyncBytesAndResponse = try await self.linuxBytes(for: request) - let (asyncBytes, response) = asyncBytesAndResponse - - guard let httpResponse = response as? HTTPURLResponse else { - throw URLSessionError.invalidResponse - } - - guard (200 ..< 300).contains(httpResponse.statusCode) else { - var errorData = Data() - for try await byte in asyncBytes { - errorData.append(byte) - } - if let errorString = String(data: errorData, encoding: .utf8) { - throw URLSessionError.httpError( - statusCode: httpResponse.statusCode, - detail: errorString - ) - } - throw URLSessionError.httpError( - statusCode: httpResponse.statusCode, - detail: "Invalid response" - ) - } - - let decoder = JSONDecoder() - - for try await event in asyncBytes.events { - guard let data = event.data.data(using: .utf8) else { continue } - if let decoded = try? decoder.decode(T.self, from: data) { - continuation.yield(decoded) - } - } + let asyncBytes = try await LinuxURLSessionRequestGate.shared.withLock { + let (bytes, response) = try await self.linuxBytes(for: request) + try await self.validateEventStreamResponse(response, asyncBytes: bytes) + return bytes } + try await decodeAndYieldEventStream(asyncBytes, to: continuation) #else let (asyncBytes, response) = try await self.bytes(for: request) - - guard let httpResponse = response as? HTTPURLResponse else { - throw URLSessionError.invalidResponse - } - - guard (200 ..< 300).contains(httpResponse.statusCode) else { - var errorData = Data() - for try await byte in asyncBytes { - errorData.append(byte) - } - if let errorString = String(data: errorData, encoding: .utf8) { - throw URLSessionError.httpError( - statusCode: httpResponse.statusCode, - detail: errorString - ) - } - throw URLSessionError.httpError( - statusCode: httpResponse.statusCode, - detail: "Invalid response" - ) - } - - let decoder = JSONDecoder() - - for try await event in asyncBytes.events { - guard let data = event.data.data(using: .utf8) else { continue } - if let decoded = try? decoder.decode(T.self, from: data) { - continuation.yield(decoded) - } - } + try await validateEventStreamResponse(response, asyncBytes: asyncBytes) + try await decodeAndYieldEventStream(asyncBytes, to: continuation) #endif - continuation.finish() } catch { continuation.finish(throwing: error) @@ -286,6 +231,39 @@ extension URLSession { } } } + + private func validateEventStreamResponse( + _ response: URLResponse, + asyncBytes: Bytes + ) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 { + guard let httpResponse = response as? HTTPURLResponse else { + throw URLSessionError.invalidResponse + } + + guard (200 ..< 300).contains(httpResponse.statusCode) else { + var errorData = Data() + for try await byte in asyncBytes { + errorData.append(byte) + } + if let errorString = String(data: errorData, encoding: .utf8) { + throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: errorString) + } + throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: "Invalid response") + } + } + + private func decodeAndYieldEventStream( + _ asyncBytes: Bytes, + to continuation: AsyncThrowingStream.Continuation + ) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 { + let decoder = JSONDecoder() + for try await event in asyncBytes.events { + guard let data = event.data.data(using: .utf8) else { continue } + if let decoded = try? decoder.decode(T.self, from: data) { + continuation.yield(decoded) + } + } + } } #if canImport(FoundationNetworking) diff --git a/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift b/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift index b672f928..d87a40a9 100644 --- a/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift +++ b/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift @@ -19,3 +19,90 @@ struct URLSessionExtensionsTests { #expect(error.description == "Decoding error: keyNotFound") } } + +#if canImport(FoundationNetworking) + private actor GateCounter { + private(set) var current = 0 + private(set) var maxConcurrent = 0 + + func enter() { + current += 1 + maxConcurrent = max(maxConcurrent, current) + } + + func leave() { + current -= 1 + } + } + + private enum GateTestError: Error { + case expected + } + + extension URLSessionExtensionsTests { + @Test func linuxGateSerializesConcurrentOperations() async throws { + let gate = LinuxURLSessionRequestGate() + let counter = GateCounter() + + try await withThrowingTaskGroup(of: Void.self) { group in + for _ in 0 ..< 8 { + group.addTask { + try await gate.withLock { + await counter.enter() + do { + try await Task.sleep(for: .milliseconds(20)) + await counter.leave() + } catch { + await counter.leave() + throw error + } + } + } + } + try await group.waitForAll() + } + + #expect(await counter.maxConcurrent == 1) + } + + @Test func linuxGateReleasesAfterError() async throws { + let gate = LinuxURLSessionRequestGate() + + do { + _ = try await gate.withLock { + throw GateTestError.expected + } + Issue.record("Expected error was not thrown") + } catch GateTestError.expected { + // expected + } + + var ranSecondOperation = false + _ = try await gate.withLock { + ranSecondOperation = true + } + #expect(ranSecondOperation) + } + + @Test func linuxGateReleasesAfterCancellation() async throws { + let gate = LinuxURLSessionRequestGate() + + let longTask = Task { + try await gate.withLock { + try await Task.sleep(for: .seconds(10)) + } + } + + try await Task.sleep(for: .milliseconds(30)) + longTask.cancel() + _ = await longTask.result + + var acquiredAfterCancellation = false + _ = try await gate.withLock { + acquiredAfterCancellation = true + } + + #expect(acquiredAfterCancellation) + } + } +#endif From 260abcf1ec1196b2039a04d798b0a7f457b73f54 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Feb 2026 02:49:03 -0800 Subject: [PATCH 3/9] Replace withLock instance method with top-level withLinuxRequestLock helper --- .../Extensions/URLSession+Extensions.swift | 31 ++++++++++++------- .../URLSessionExtensionsTests.swift | 15 +++------ 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift index aa17292b..70cd714c 100644 --- a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift +++ b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift @@ -33,7 +33,7 @@ enum HTTP { private var isLocked = false private var waiters: [CheckedContinuation] = [] - private func acquire() async { + func acquire() async { if !isLocked { isLocked = true return @@ -44,7 +44,7 @@ enum HTTP { } } - private func release() { + func release() { if waiters.isEmpty { isLocked = false return @@ -54,13 +54,20 @@ enum HTTP { continuation.resume() } - /// Executes an async operation while holding the gate lock. - func withLock( - _ operation: () async throws -> T - ) async rethrows -> T { - await acquire() - defer { release() } - return try await operation() + } + + func withLinuxRequestLock( + _ operation: () async throws -> T + ) async rethrows -> T { + let gate = LinuxURLSessionRequestGate.shared + await gate.acquire() + do { + let result = try await operation() + await gate.release() + return result + } catch { + await gate.release() + throw error } } #endif @@ -87,7 +94,7 @@ extension URLSession { } #if canImport(FoundationNetworking) - let dataAndResponse = try await LinuxURLSessionRequestGate.shared.withLock { + let dataAndResponse = try await withLinuxRequestLock { try await data(for: request) } let (data, response) = dataAndResponse @@ -143,7 +150,7 @@ extension URLSession { } #if canImport(FoundationNetworking) - let dataAndResponse = try await LinuxURLSessionRequestGate.shared.withLock { + let dataAndResponse = try await withLinuxRequestLock { try await self.data(for: request) } let (data, response) = dataAndResponse @@ -209,7 +216,7 @@ extension URLSession { } #if canImport(FoundationNetworking) - let asyncBytes = try await LinuxURLSessionRequestGate.shared.withLock { + let asyncBytes = try await withLinuxRequestLock { let (bytes, response) = try await self.linuxBytes(for: request) try await self.validateEventStreamResponse(response, asyncBytes: bytes) return bytes diff --git a/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift b/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift index d87a40a9..d8fb92b2 100644 --- a/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift +++ b/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift @@ -41,13 +41,12 @@ struct URLSessionExtensionsTests { extension URLSessionExtensionsTests { @Test func linuxGateSerializesConcurrentOperations() async throws { - let gate = LinuxURLSessionRequestGate() let counter = GateCounter() try await withThrowingTaskGroup(of: Void.self) { group in for _ in 0 ..< 8 { group.addTask { - try await gate.withLock { + try await withLinuxRequestLock { await counter.enter() do { try await Task.sleep(for: .milliseconds(20)) @@ -66,10 +65,8 @@ struct URLSessionExtensionsTests { } @Test func linuxGateReleasesAfterError() async throws { - let gate = LinuxURLSessionRequestGate() - do { - _ = try await gate.withLock { + try await withLinuxRequestLock { throw GateTestError.expected } Issue.record("Expected error was not thrown") @@ -78,17 +75,15 @@ struct URLSessionExtensionsTests { } var ranSecondOperation = false - _ = try await gate.withLock { + try await withLinuxRequestLock { ranSecondOperation = true } #expect(ranSecondOperation) } @Test func linuxGateReleasesAfterCancellation() async throws { - let gate = LinuxURLSessionRequestGate() - let longTask = Task { - try await gate.withLock { + try await withLinuxRequestLock { try await Task.sleep(for: .seconds(10)) } } @@ -98,7 +93,7 @@ struct URLSessionExtensionsTests { _ = await longTask.result var acquiredAfterCancellation = false - _ = try await gate.withLock { + try await withLinuxRequestLock { acquiredAfterCancellation = true } From f94d23375fcff0fd686c68d50fecc19533e4091f Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Feb 2026 02:54:52 -0800 Subject: [PATCH 4/9] Fix Linux compiler bug around generic returning lock helper --- .../Extensions/URLSession+Extensions.swift | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift index 70cd714c..ede16cca 100644 --- a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift +++ b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift @@ -56,15 +56,14 @@ enum HTTP { } - func withLinuxRequestLock( - _ operation: () async throws -> T - ) async rethrows -> T { + func withLinuxRequestLock( + _ operation: () async throws -> Void + ) async throws { let gate = LinuxURLSessionRequestGate.shared await gate.acquire() do { - let result = try await operation() + try await operation() await gate.release() - return result } catch { await gate.release() throw error @@ -94,8 +93,12 @@ extension URLSession { } #if canImport(FoundationNetworking) - let dataAndResponse = try await withLinuxRequestLock { - try await data(for: request) + var dataAndResponse: (Data, URLResponse)? + try await withLinuxRequestLock { + dataAndResponse = try await data(for: request) + } + guard let dataAndResponse else { + throw URLSessionError.invalidResponse } let (data, response) = dataAndResponse #else @@ -150,8 +153,12 @@ extension URLSession { } #if canImport(FoundationNetworking) - let dataAndResponse = try await withLinuxRequestLock { - try await self.data(for: request) + var dataAndResponse: (Data, URLResponse)? + try await withLinuxRequestLock { + dataAndResponse = try await self.data(for: request) + } + guard let dataAndResponse else { + throw URLSessionError.invalidResponse } let (data, response) = dataAndResponse #else @@ -216,10 +223,14 @@ extension URLSession { } #if canImport(FoundationNetworking) - let asyncBytes = try await withLinuxRequestLock { + var asyncBytes: AsyncThrowingStream? + try await withLinuxRequestLock { let (bytes, response) = try await self.linuxBytes(for: request) try await self.validateEventStreamResponse(response, asyncBytes: bytes) - return bytes + asyncBytes = bytes + } + guard let asyncBytes else { + throw URLSessionError.invalidResponse } try await decodeAndYieldEventStream(asyncBytes, to: continuation) #else From 6edc374bc6aa7e3f290c38e5b5b3d3b40fa9d515 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Feb 2026 02:59:25 -0800 Subject: [PATCH 5/9] More workarounds for Linux compiler bugs --- .../Extensions/URLSession+Extensions.swift | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift index ede16cca..87efc636 100644 --- a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift +++ b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift @@ -93,14 +93,16 @@ extension URLSession { } #if canImport(FoundationNetworking) - var dataAndResponse: (Data, URLResponse)? + var lockedData: Data? + var lockedResponse: URLResponse? try await withLinuxRequestLock { - dataAndResponse = try await data(for: request) + let (data, response) = try await data(for: request) + lockedData = data + lockedResponse = response } - guard let dataAndResponse else { + guard let data = lockedData, let response = lockedResponse else { throw URLSessionError.invalidResponse } - let (data, response) = dataAndResponse #else let (data, response) = try await data(for: request) #endif @@ -153,14 +155,16 @@ extension URLSession { } #if canImport(FoundationNetworking) - var dataAndResponse: (Data, URLResponse)? + var lockedData: Data? + var lockedResponse: URLResponse? try await withLinuxRequestLock { - dataAndResponse = try await self.data(for: request) + let (data, response) = try await self.data(for: request) + lockedData = data + lockedResponse = response } - guard let dataAndResponse else { + guard let data = lockedData, let response = lockedResponse else { throw URLSessionError.invalidResponse } - let (data, response) = dataAndResponse #else let (data, response) = try await self.data(for: request) #endif @@ -223,13 +227,13 @@ extension URLSession { } #if canImport(FoundationNetworking) - var asyncBytes: AsyncThrowingStream? + var lockedAsyncBytes: AsyncThrowingStream? try await withLinuxRequestLock { let (bytes, response) = try await self.linuxBytes(for: request) try await self.validateEventStreamResponse(response, asyncBytes: bytes) - asyncBytes = bytes + lockedAsyncBytes = bytes } - guard let asyncBytes else { + guard let asyncBytes = lockedAsyncBytes else { throw URLSessionError.invalidResponse } try await decodeAndYieldEventStream(asyncBytes, to: continuation) From f13a1366129fbfad940213ef433428a88807981c Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 19 Feb 2026 03:49:25 -0800 Subject: [PATCH 6/9] Incorporate feedback from review --- .../Extensions/URLSession+Extensions.swift | 57 +++++++++++++++---- .../URLSessionExtensionsTests.swift | 37 ++++++++++++ 2 files changed, 82 insertions(+), 12 deletions(-) diff --git a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift index 87efc636..8a1e111a 100644 --- a/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift +++ b/Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift @@ -21,26 +21,48 @@ enum HTTP { /// `_MultiHandle`, which has a known thread-safety bug that can crash under /// concurrent access (`URLSession._MultiHandle.endOperation(for:)`). /// - /// This gate intentionally allows only one in-flight request path at a time on Linux. - /// This fully serializes HTTP request setup paths on Linux and reduces request-level - /// parallelism, which can lower throughput for heavily concurrent workloads. + /// This gate intentionally allows only one in-flight request setup path at a time on Linux. + /// For non-streaming requests, callers typically hold this lock for the entire + /// request/response cycle, effectively serializing those operations and reducing + /// request-level parallelism (which can lower throughput for heavily concurrent + /// workloads). + /// + /// For streaming requests, callers usually acquire the gate only during initial + /// request setup and then release it once the stream has been established; stream + /// consumption itself is not serialized by this gate. /// Keep this scoped to Linux-only code paths until the upstream issue is resolved. /// /// See: https://github.com/swiftlang/swift-corelibs-foundation/issues/4791 actor LinuxURLSessionRequestGate { + private struct Waiter { + let id: UUID + let continuation: CheckedContinuation + } + static let shared = LinuxURLSessionRequestGate() private var isLocked = false - private var waiters: [CheckedContinuation] = [] + private var waiters: [Waiter] = [] + + func acquire() async throws { + if Task.isCancelled { + throw CancellationError() + } - func acquire() async { if !isLocked { isLocked = true return } - await withCheckedContinuation { continuation in - waiters.append(continuation) + let waiterID = UUID() + try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + waiters.append(Waiter(id: waiterID, continuation: continuation)) + } + } onCancel: { + Task { + await self.cancelWaiter(id: waiterID) + } } } @@ -50,8 +72,17 @@ enum HTTP { return } - let continuation = waiters.removeFirst() - continuation.resume() + let waiter = waiters.removeFirst() + waiter.continuation.resume() + } + + private func cancelWaiter(id: UUID) { + guard let index = waiters.firstIndex(where: { $0.id == id }) else { + return + } + + let waiter = waiters.remove(at: index) + waiter.continuation.resume(throwing: CancellationError()) } } @@ -60,7 +91,7 @@ enum HTTP { _ operation: () async throws -> Void ) async throws { let gate = LinuxURLSessionRequestGate.shared - await gate.acquire() + try await gate.acquire() do { try await operation() await gate.release() @@ -228,14 +259,16 @@ extension URLSession { #if canImport(FoundationNetworking) var lockedAsyncBytes: AsyncThrowingStream? + var lockedResponse: URLResponse? try await withLinuxRequestLock { let (bytes, response) = try await self.linuxBytes(for: request) - try await self.validateEventStreamResponse(response, asyncBytes: bytes) lockedAsyncBytes = bytes + lockedResponse = response } - guard let asyncBytes = lockedAsyncBytes else { + guard let asyncBytes = lockedAsyncBytes, let response = lockedResponse else { throw URLSessionError.invalidResponse } + try await self.validateEventStreamResponse(response, asyncBytes: asyncBytes) try await decodeAndYieldEventStream(asyncBytes, to: continuation) #else let (asyncBytes, response) = try await self.bytes(for: request) diff --git a/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift b/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift index d8fb92b2..c1ac2669 100644 --- a/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift +++ b/Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift @@ -39,6 +39,14 @@ struct URLSessionExtensionsTests { case expected } + private actor GateFlag { + private(set) var value = false + + func setTrue() { + value = true + } + } + extension URLSessionExtensionsTests { @Test func linuxGateSerializesConcurrentOperations() async throws { let counter = GateCounter() @@ -99,5 +107,34 @@ struct URLSessionExtensionsTests { #expect(acquiredAfterCancellation) } + + @Test func linuxGateCancelledWaiterDoesNotExecute() async throws { + let ranCancelledOperation = GateFlag() + + let holder = Task { + try await withLinuxRequestLock { + try await Task.sleep(for: .milliseconds(200)) + } + } + + try await Task.sleep(for: .milliseconds(20)) + + let waiter = Task { + do { + try await withLinuxRequestLock { + await ranCancelledOperation.setTrue() + } + } catch { + // Cancellation is expected. + } + } + + waiter.cancel() + _ = await waiter.result + try await holder.value + try await Task.sleep(for: .milliseconds(20)) + + #expect(await ranCancelledOperation.value == false) + } } #endif From 056dccdecea995f3ccc0917d8f8dc2c25e4619c2 Mon Sep 17 00:00:00 2001 From: Jonas Stoehr Date: Mon, 23 Mar 2026 12:51:48 +0100 Subject: [PATCH 7/9] AsyncHTTPClient support for linux (#143) * deps: add trait-based import of AsyncHTTPClient * feat: implement transparent AsyncHTTPClient wrapper * deps: conditionally include trait in EventSource * Increase HTTP request timeout from 60 to 180 seconds --------- Co-authored-by: Leonhard Solbach <49833472+KotlinFactory@users.noreply.github.com> --- Package.resolved | 193 +++++++++++++--- Package.swift | 21 +- .../Extensions/HTTPClient+Extensions.swift | 213 ++++++++++++++++++ .../Models/AnthropicLanguageModel.swift | 4 +- .../Models/GeminiLanguageModel.swift | 6 +- .../Models/OllamaLanguageModel.swift | 4 +- .../Models/OpenAILanguageModel.swift | 4 +- .../Models/OpenResponsesLanguageModel.swift | 4 +- Sources/AnyLanguageModel/Transport.swift | 20 ++ 9 files changed, 427 insertions(+), 42 deletions(-) create mode 100644 Sources/AnyLanguageModel/Extensions/HTTPClient+Extensions.swift create mode 100644 Sources/AnyLanguageModel/Transport.swift diff --git a/Package.resolved b/Package.resolved index 837d7768..7ed371b3 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,13 +1,22 @@ { - "originHash" : "f7b86b800200fa069a2b288e06bafe53bc937a1851b6effeebba326a62be227e", + "originHash" : "f2f0ba1d1b9625bd5147b2fbd7b82236dac35ee1baa399fcf5c76b22fd428bb8", "pins" : [ + { + "identity" : "async-http-client", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swift-server/async-http-client.git", + "state" : { + "revision" : "2fc4652fb4689eb24af10e55cabaa61d8ba774fd", + "version" : "1.32.0" + } + }, { "identity" : "eventsource", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/EventSource.git", + "location" : "https://github.com/mattt/EventSource", "state" : { - "revision" : "ca2a9d90cbe49e09b92f4b6ebd922c03ebea51d0", - "version" : "1.3.0" + "revision" : "bd64824505da71a1a403adb221f6e25413c0bc7f", + "version" : "1.4.0" } }, { @@ -20,39 +29,57 @@ } }, { - "identity" : "llama.swift", + "identity" : "partialjsondecoder", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/llama.swift", + "location" : "https://github.com/mattt/PartialJSONDecoder.git", "state" : { - "revision" : "4d57cff84ba85914baa39850157e7c27684db9c8", - "version" : "2.7966.0" + "revision" : "e4d389e6bcc6771bb988d1a8a17695d8bfa97172", + "version" : "1.0.0" } }, { - "identity" : "mlx-swift", + "identity" : "swift-algorithms", "kind" : "remoteSourceControl", - "location" : "https://github.com/ml-explore/mlx-swift", + "location" : "https://github.com/apple/swift-algorithms.git", "state" : { - "revision" : "072b684acaae80b6a463abab3a103732f33774bf", - "version" : "0.29.1" + "revision" : "87e50f483c54e6efd60e885f7f5aa946cee68023", + "version" : "1.2.1" } }, { - "identity" : "mlx-swift-lm", + "identity" : "swift-asn1", "kind" : "remoteSourceControl", - "location" : "https://github.com/ml-explore/mlx-swift-lm", + "location" : "https://github.com/apple/swift-asn1.git", "state" : { - "revision" : "5064b8c5d8ed3b0bbb71385c4124f0fc102e74a2", - "version" : "2.29.3" + "revision" : "810496cf121e525d660cd0ea89a758740476b85f", + "version" : "1.5.1" } }, { - "identity" : "partialjsondecoder", + "identity" : "swift-async-algorithms", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/PartialJSONDecoder.git", + "location" : "https://github.com/apple/swift-async-algorithms.git", "state" : { - "revision" : "e4d389e6bcc6771bb988d1a8a17695d8bfa97172", - "version" : "1.0.0" + "revision" : "9d349bcc328ac3c31ce40e746b5882742a0d1272", + "version" : "1.1.3" + } + }, + { + "identity" : "swift-atomics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-atomics.git", + "state" : { + "revision" : "b601256eab081c0f92f059e12818ac1d4f178ff7", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-certificates", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-certificates.git", + "state" : { + "revision" : "24ccdeeeed4dfaae7955fcac9dbf5489ed4f1a25", + "version" : "1.18.0" } }, { @@ -65,23 +92,131 @@ } }, { - "identity" : "swift-jinja", + "identity" : "swift-configuration", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-configuration.git", + "state" : { + "revision" : "be76c4ad929eb6c4bcaf3351799f2adf9e6848a9", + "version" : "1.2.0" + } + }, + { + "identity" : "swift-crypto", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-crypto.git", + "state" : { + "revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095", + "version" : "4.2.0" + } + }, + { + "identity" : "swift-distributed-tracing", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-distributed-tracing.git", + "state" : { + "revision" : "e109d8b5308d0e05201d9a1dd1c475446a946a11", + "version" : "1.4.0" + } + }, + { + "identity" : "swift-http-structured-headers", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-http-structured-headers.git", + "state" : { + "revision" : "76d7627bd88b47bf5a0f8497dd244885960dde0b", + "version" : "1.6.0" + } + }, + { + "identity" : "swift-http-types", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-http-types.git", + "state" : { + "revision" : "45eb0224913ea070ec4fba17291b9e7ecf4749ca", + "version" : "1.5.1" + } + }, + { + "identity" : "swift-log", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-log.git", + "state" : { + "revision" : "bbd81b6725ae874c69e9b8c8804d462356b55523", + "version" : "1.10.1" + } + }, + { + "identity" : "swift-nio", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio.git", + "state" : { + "revision" : "e932d3c4d8f77433c8f7093b5ebcbf91463948a0", + "version" : "2.95.0" + } + }, + { + "identity" : "swift-nio-extras", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-extras.git", + "state" : { + "revision" : "3df009d563dc9f21a5c85b33d8c2e34d2e4f8c3b", + "version" : "1.32.1" + } + }, + { + "identity" : "swift-nio-http2", "kind" : "remoteSourceControl", - "location" : "https://github.com/huggingface/swift-jinja.git", + "location" : "https://github.com/apple/swift-nio-http2.git", "state" : { - "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0", - "version" : "2.3.1" + "revision" : "b6571f3db40799df5a7fc0e92c399aa71c883edd", + "version" : "1.40.0" + } + }, + { + "identity" : "swift-nio-ssl", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-ssl.git", + "state" : { + "revision" : "173cc69a058623525a58ae6710e2f5727c663793", + "version" : "2.36.0" + } + }, + { + "identity" : "swift-nio-transport-services", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-transport-services.git", + "state" : { + "revision" : "60c3e187154421171721c1a38e800b390680fb5d", + "version" : "1.26.0" } }, { "identity" : "swift-numerics", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-numerics", + "location" : "https://github.com/apple/swift-numerics.git", "state" : { "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2", "version" : "1.1.1" } }, + { + "identity" : "swift-service-context", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-service-context.git", + "state" : { + "revision" : "d0997351b0c7779017f88e7a93bc30a1878d7f29", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-service-lifecycle", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swift-server/swift-service-lifecycle", + "state" : { + "revision" : "89888196dd79c61c50bca9a103d8114f32e1e598", + "version" : "2.10.1" + } + }, { "identity" : "swift-syntax", "kind" : "remoteSourceControl", @@ -92,12 +227,12 @@ } }, { - "identity" : "swift-transformers", + "identity" : "swift-system", "kind" : "remoteSourceControl", - "location" : "https://github.com/huggingface/swift-transformers", + "location" : "https://github.com/apple/swift-system", "state" : { - "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", - "version" : "1.1.6" + "revision" : "7c6ad0fc39d0763e0b699210e4124afd5041c5df", + "version" : "1.6.4" } } ], diff --git a/Package.swift b/Package.swift index 3916bf01..b8d69621 100644 --- a/Package.swift +++ b/Package.swift @@ -25,17 +25,22 @@ let package = Package( .trait(name: "CoreML"), .trait(name: "MLX"), .trait(name: "Llama"), + .trait(name: "AsyncHTTPClient"), .default(enabledTraits: []), ], dependencies: [ .package(url: "https://github.com/huggingface/swift-transformers", from: "1.0.0"), - .package(url: "https://github.com/mattt/EventSource", from: "1.3.0"), + .package(url: "https://github.com/mattt/EventSource", from: "1.3.0", traits: [ + .defaults, + .trait(name: "AsyncHTTPClient", condition: .when(traits: ["AsyncHTTPClient"])) + ]), .package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"), .package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")), .package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"), // mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:). .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.25.5"), .package(url: "https://github.com/swiftlang/swift-syntax", from: "600.0.0"), + .package(url: "https://github.com/swift-server/async-http-client.git", from: "1.24.0"), ], targets: [ .target( @@ -70,6 +75,11 @@ let package = Package( package: "llama.swift", condition: .when(traits: ["Llama"]) ), + .product( + name: "AsyncHTTPClient", + package: "async-http-client", + condition: .when(traits: ["AsyncHTTPClient"]) + ), ] ), .macro( @@ -83,7 +93,14 @@ let package = Package( ), .testTarget( name: "AnyLanguageModelTests", - dependencies: ["AnyLanguageModel"] + dependencies: [ + "AnyLanguageModel", + .product( + name: "AsyncHTTPClient", + package: "async-http-client", + condition: .when(traits: ["AsyncHTTPClient"]) + ), + ], ), ] ) diff --git a/Sources/AnyLanguageModel/Extensions/HTTPClient+Extensions.swift b/Sources/AnyLanguageModel/Extensions/HTTPClient+Extensions.swift new file mode 100644 index 00000000..10896089 --- /dev/null +++ b/Sources/AnyLanguageModel/Extensions/HTTPClient+Extensions.swift @@ -0,0 +1,213 @@ +#if canImport(AsyncHTTPClient) +// AsyncHTTPClient.HTTPHandler introduces a Task type that clashes +typealias SwiftTask = Task + +import AsyncHTTPClient +import EventSource +import Foundation +#if canImport(FoundationNetworking) +import FoundationNetworking +#endif +import JSONSchema +import NIOCore +import NIOHTTP1 +import NIOFoundationCompat + +extension HTTPClient { + func fetch( + _ method: HTTP.Method, + url: URL, + headers: [String: String] = [:], + body: Data? = nil, + dateDecodingStrategy: JSONDecoder.DateDecodingStrategy = .deferredToDate + ) async throws -> T { + var request = HTTPClientRequest(url: url.absoluteString) + request.method = HTTPMethod(rawValue: method.rawValue) + request.headers.add(name: "Accept", value: "application/json") + + for (key, value) in headers { + request.headers.add(name: key, value: value) + } + + if let body { + request.body = .bytes(ByteBuffer(data: body)) + request.headers.add(name: "Content-Type", value: "application/json") + } + + let response = try await self.execute(request, timeout: .seconds(180)) + + guard (200 ..< 300).contains(response.status.code) else { + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + if let errorString = String(data: bodyData, encoding: .utf8) { + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: errorString) + } + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: "Invalid response") + } + + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = dateDecodingStrategy + + do { + return try decoder.decode(T.self, from: bodyData) + } catch { + throw HTTPClientError.decodingError(detail: error.localizedDescription) + } + } + + func fetchStream( + _ method: HTTP.Method, + url: URL, + headers: [String: String] = [:], + body: Data? = nil, + dateDecodingStrategy: JSONDecoder.DateDecodingStrategy = .deferredToDate + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + let task = SwiftTask { @Sendable in + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = dateDecodingStrategy + + do { + var request = HTTPClientRequest(url: url.absoluteString) + request.method = HTTPMethod(rawValue: method.rawValue) + request.headers.add(name: "Accept", value: "application/json") + + for (key, value) in headers { + request.headers.add(name: key, value: value) + } + + if let body { + request.body = .bytes(ByteBuffer(data: body)) + request.headers.add(name: "Content-Type", value: "application/json") + } + + let response = try await self.execute(request, timeout: .seconds(60)) + + guard (200 ..< 300).contains(response.status.code) else { + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + if let errorString = String(data: bodyData, encoding: .utf8) { + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: errorString) + } + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: "Invalid response") + } + + var buffer = Data() + + for try await chunk in response.body { + buffer.append(contentsOf: chunk.readableBytesView) + + while let newlineIndex = buffer.firstIndex(of: UInt8(ascii: "\n")) { + let line = buffer[..( + _ method: HTTP.Method, + url: URL, + headers: [String: String] = [:], + body: Data? = nil + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + let task = SwiftTask { @Sendable in + do { + var request = HTTPClientRequest(url: url.absoluteString) + request.method = HTTPMethod(rawValue: method.rawValue) + request.headers.add(name: "Accept", value: "text/event-stream") + + for (key, value) in headers { + request.headers.add(name: key, value: value) + } + + if let body { + request.body = .bytes(ByteBuffer(data: body)) + request.headers.add(name: "Content-Type", value: "application/json") + } + + let response = try await self.execute(request, timeout: .seconds(60)) + + guard (200 ..< 300).contains(response.status.code) else { + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + if let errorString = String(data: bodyData, encoding: .utf8) { + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: errorString) + } + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: "Invalid response") + } + + let asyncBytes = AsyncStream { byteContinuation in + SwiftTask { + do { + for try await buffer in response.body { + for byte in buffer.readableBytesView { + byteContinuation.yield(byte) + } + } + byteContinuation.finish() + } catch { + byteContinuation.finish() + } + } + } + + try await self.decodeAndYieldEventStream(asyncBytes, to: continuation) + continuation.finish() + } catch { + continuation.finish(throwing: error) + } + } + + continuation.onTermination = { _ in + task.cancel() + } + } + } + + private func decodeAndYieldEventStream( + _ asyncBytes: Bytes, + to continuation: AsyncThrowingStream.Continuation + ) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 { + let decoder = JSONDecoder() + for try await event in asyncBytes.events { + guard let data = event.data.data(using: .utf8) else { continue } + if let decoded = try? decoder.decode(T.self, from: data) { + continuation.yield(decoded) + } + } + } +} + +enum HTTPClientError: Error, CustomStringConvertible { + case invalidResponse + case httpError(statusCode: Int, detail: String) + case decodingError(detail: String) + + var description: String { + switch self { + case .invalidResponse: + return "Invalid response" + case .httpError(let statusCode, let detail): + return "HTTP error (Status \(statusCode)): \(detail)" + case .decodingError(let detail): + return "Decoding error: \(detail)" + } + } +} +#endif diff --git a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift index bd21a1f2..851c5559 100644 --- a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift @@ -278,7 +278,7 @@ public struct AnthropicLanguageModel: LanguageModel { /// The model identifier to use for generation. public let model: String - private let urlSession: URLSession + private let urlSession: SessionType /// Creates an Anthropic language model. /// @@ -295,7 +295,7 @@ public struct AnthropicLanguageModel: LanguageModel { apiVersion: String = defaultAPIVersion, betas: [String]? = nil, model: String, - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { diff --git a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift index 2da15a48..e2230ab2 100644 --- a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift @@ -186,7 +186,7 @@ public struct GeminiLanguageModel: LanguageModel { /// Internal storage for the deprecated serverTools property. internal var _serverTools: [CustomGenerationOptions.ServerTool] - private let urlSession: URLSession + private let urlSession: SessionType /// Creates a new Gemini language model. /// @@ -201,7 +201,7 @@ public struct GeminiLanguageModel: LanguageModel { apiKey tokenProvider: @escaping @autoclosure @Sendable () -> String, apiVersion: String = defaultAPIVersion, model: String, - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { @@ -243,7 +243,7 @@ public struct GeminiLanguageModel: LanguageModel { model: String, thinking: CustomGenerationOptions.Thinking = .disabled, serverTools: [CustomGenerationOptions.ServerTool] = [], - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { diff --git a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift index 6be5d02d..1e498566 100644 --- a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift @@ -46,7 +46,7 @@ public struct OllamaLanguageModel: LanguageModel { /// The model identifier to use for generation. public let model: String - private let urlSession: URLSession + private let urlSession: SessionType /// Creates an Ollama language model. /// @@ -57,7 +57,7 @@ public struct OllamaLanguageModel: LanguageModel { public init( baseURL: URL = defaultBaseURL, model: String, - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { diff --git a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift index db4eab33..a1b6b0ee 100644 --- a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift @@ -393,7 +393,7 @@ public struct OpenAILanguageModel: LanguageModel { /// The API variant to use. public let apiVariant: APIVariant - private let urlSession: URLSession + private let urlSession: SessionType /// Creates an OpenAI language model. /// @@ -408,7 +408,7 @@ public struct OpenAILanguageModel: LanguageModel { apiKey tokenProvider: @escaping @autoclosure @Sendable () -> String, model: String, apiVariant: APIVariant = .chatCompletions, - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { diff --git a/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift index c4ba51ef..c123e800 100644 --- a/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift @@ -365,7 +365,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { /// Model identifier to use for generation. public let model: String - private let urlSession: URLSession + private let urlSession: SessionType /// Creates an Open Responses language model. /// @@ -378,7 +378,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { baseURL: URL, apiKey tokenProvider: @escaping @autoclosure @Sendable () -> String, model: String, - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { diff --git a/Sources/AnyLanguageModel/Transport.swift b/Sources/AnyLanguageModel/Transport.swift new file mode 100644 index 00000000..a8d81861 --- /dev/null +++ b/Sources/AnyLanguageModel/Transport.swift @@ -0,0 +1,20 @@ +#if canImport(AsyncHTTPClient) + import AsyncHTTPClient + + public typealias SessionType = HTTPClient + + public func makeDefaultSession() -> SessionType { + return HTTPClient.shared + } +#else + import Foundation + #if canImport(FoundationNetworking) + import FoundationNetworking + #endif + + public typealias SessionType = URLSession + + public func makeDefaultSession() -> SessionType { + return URLSession(configuration: .default) + } +#endif From b518a560605f63d944ced165f3caa43ef76f2510 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 23 Mar 2026 04:53:01 -0700 Subject: [PATCH 8/9] Incorporate feedback from review --- .../Extensions/HTTPClient+Extensions.swift | 352 ++++++++++-------- .../Models/AnthropicLanguageModel.swift | 12 +- .../Models/GeminiLanguageModel.swift | 18 +- .../Models/OllamaLanguageModel.swift | 12 +- .../Models/OpenAILanguageModel.swift | 16 +- .../Models/OpenResponsesLanguageModel.swift | 12 +- .../{ => Shared}/Transport.swift | 8 +- 7 files changed, 226 insertions(+), 204 deletions(-) rename Sources/AnyLanguageModel/{ => Shared}/Transport.swift (58%) diff --git a/Sources/AnyLanguageModel/Extensions/HTTPClient+Extensions.swift b/Sources/AnyLanguageModel/Extensions/HTTPClient+Extensions.swift index 10896089..1d723cdd 100644 --- a/Sources/AnyLanguageModel/Extensions/HTTPClient+Extensions.swift +++ b/Sources/AnyLanguageModel/Extensions/HTTPClient+Extensions.swift @@ -1,213 +1,235 @@ #if canImport(AsyncHTTPClient) -// AsyncHTTPClient.HTTPHandler introduces a Task type that clashes -typealias SwiftTask = Task - -import AsyncHTTPClient -import EventSource -import Foundation -#if canImport(FoundationNetworking) -import FoundationNetworking -#endif -import JSONSchema -import NIOCore -import NIOHTTP1 -import NIOFoundationCompat - -extension HTTPClient { - func fetch( - _ method: HTTP.Method, - url: URL, - headers: [String: String] = [:], - body: Data? = nil, - dateDecodingStrategy: JSONDecoder.DateDecodingStrategy = .deferredToDate - ) async throws -> T { - var request = HTTPClientRequest(url: url.absoluteString) - request.method = HTTPMethod(rawValue: method.rawValue) - request.headers.add(name: "Accept", value: "application/json") - - for (key, value) in headers { - request.headers.add(name: key, value: value) - } + // AsyncHTTPClient.HTTPHandler introduces a Task type that clashes with Swift's Task. + // Bind Swift's structured-concurrency Task before importing AsyncHTTPClient. + typealias SwiftTask = Task - if let body { - request.body = .bytes(ByteBuffer(data: body)) - request.headers.add(name: "Content-Type", value: "application/json") - } + /// Holds the body-read task for `fetchEventStream` so outer stream cancellation can stop NIO body iteration. + private final class HTTPClientBodyReaderTaskBox: @unchecked Sendable { + var task: SwiftTask? + } - let response = try await self.execute(request, timeout: .seconds(180)) + import AsyncHTTPClient + import EventSource + import Foundation + import NIOCore + import NIOHTTP1 + import NIOFoundationCompat + + extension HTTPClient { + func fetch( + _ method: HTTP.Method, + url: URL, + headers: [String: String] = [:], + body: Data? = nil, + dateDecodingStrategy: JSONDecoder.DateDecodingStrategy = .deferredToDate + ) async throws -> T { + var request = HTTPClientRequest(url: url.absoluteString) + request.method = HTTPMethod(rawValue: method.rawValue) + request.headers.add(name: "Accept", value: "application/json") + + for (key, value) in headers { + request.headers.add(name: key, value: value) + } - guard (200 ..< 300).contains(response.status.code) else { - let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) - if let errorString = String(data: bodyData, encoding: .utf8) { - throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: errorString) + if let body { + request.body = .bytes(ByteBuffer(data: body)) + request.headers.add(name: "Content-Type", value: "application/json") + } + + let response = try await self.execute(request, timeout: .seconds(180)) + + guard (200 ..< 300).contains(response.status.code) else { + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + if let errorString = String(data: bodyData, encoding: .utf8) { + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: errorString) + } + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: "Invalid response") } - throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: "Invalid response") - } - let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) - let decoder = JSONDecoder() - decoder.dateDecodingStrategy = dateDecodingStrategy + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = dateDecodingStrategy - do { - return try decoder.decode(T.self, from: bodyData) - } catch { - throw HTTPClientError.decodingError(detail: error.localizedDescription) + do { + return try decoder.decode(T.self, from: bodyData) + } catch { + throw HTTPClientError.decodingError(detail: error.localizedDescription) + } } - } - func fetchStream( - _ method: HTTP.Method, - url: URL, - headers: [String: String] = [:], - body: Data? = nil, - dateDecodingStrategy: JSONDecoder.DateDecodingStrategy = .deferredToDate - ) -> AsyncThrowingStream { - AsyncThrowingStream { continuation in - let task = SwiftTask { @Sendable in - let decoder = JSONDecoder() - decoder.dateDecodingStrategy = dateDecodingStrategy - - do { - var request = HTTPClientRequest(url: url.absoluteString) - request.method = HTTPMethod(rawValue: method.rawValue) - request.headers.add(name: "Accept", value: "application/json") - - for (key, value) in headers { - request.headers.add(name: key, value: value) - } + func fetchStream( + _ method: HTTP.Method, + url: URL, + headers: [String: String] = [:], + body: Data? = nil, + dateDecodingStrategy: JSONDecoder.DateDecodingStrategy = .deferredToDate + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + let task = SwiftTask { @Sendable in + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = dateDecodingStrategy + + do { + var request = HTTPClientRequest(url: url.absoluteString) + request.method = HTTPMethod(rawValue: method.rawValue) + request.headers.add(name: "Accept", value: "application/json") + + for (key, value) in headers { + request.headers.add(name: key, value: value) + } - if let body { - request.body = .bytes(ByteBuffer(data: body)) - request.headers.add(name: "Content-Type", value: "application/json") - } + if let body { + request.body = .bytes(ByteBuffer(data: body)) + request.headers.add(name: "Content-Type", value: "application/json") + } - let response = try await self.execute(request, timeout: .seconds(60)) + let response = try await self.execute(request, timeout: .seconds(60)) - guard (200 ..< 300).contains(response.status.code) else { - let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) - if let errorString = String(data: bodyData, encoding: .utf8) { - throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: errorString) + guard (200 ..< 300).contains(response.status.code) else { + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + if let errorString = String(data: bodyData, encoding: .utf8) { + throw HTTPClientError.httpError( + statusCode: Int(response.status.code), + detail: errorString + ) + } + throw HTTPClientError.httpError( + statusCode: Int(response.status.code), + detail: "Invalid response" + ) } - throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: "Invalid response") - } - var buffer = Data() + var buffer = Data() - for try await chunk in response.body { - buffer.append(contentsOf: chunk.readableBytesView) + for try await chunk in response.body { + try SwiftTask.checkCancellation() + buffer.append(contentsOf: chunk.readableBytesView) - while let newlineIndex = buffer.firstIndex(of: UInt8(ascii: "\n")) { - let line = buffer[..( - _ method: HTTP.Method, - url: URL, - headers: [String: String] = [:], - body: Data? = nil - ) -> AsyncThrowingStream { - AsyncThrowingStream { continuation in - let task = SwiftTask { @Sendable in - do { - var request = HTTPClientRequest(url: url.absoluteString) - request.method = HTTPMethod(rawValue: method.rawValue) - request.headers.add(name: "Accept", value: "text/event-stream") - - for (key, value) in headers { - request.headers.add(name: key, value: value) - } + func fetchEventStream( + _ method: HTTP.Method, + url: URL, + headers: [String: String] = [:], + body: Data? = nil + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + let bodyReaderBox = HTTPClientBodyReaderTaskBox() + + let task = SwiftTask { @Sendable in + do { + var request = HTTPClientRequest(url: url.absoluteString) + request.method = HTTPMethod(rawValue: method.rawValue) + request.headers.add(name: "Accept", value: "text/event-stream") + + for (key, value) in headers { + request.headers.add(name: key, value: value) + } - if let body { - request.body = .bytes(ByteBuffer(data: body)) - request.headers.add(name: "Content-Type", value: "application/json") - } + if let body { + request.body = .bytes(ByteBuffer(data: body)) + request.headers.add(name: "Content-Type", value: "application/json") + } - let response = try await self.execute(request, timeout: .seconds(60)) + let response = try await self.execute(request, timeout: .seconds(60)) - guard (200 ..< 300).contains(response.status.code) else { - let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) - if let errorString = String(data: bodyData, encoding: .utf8) { - throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: errorString) + guard (200 ..< 300).contains(response.status.code) else { + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + if let errorString = String(data: bodyData, encoding: .utf8) { + throw HTTPClientError.httpError( + statusCode: Int(response.status.code), + detail: errorString + ) + } + throw HTTPClientError.httpError( + statusCode: Int(response.status.code), + detail: "Invalid response" + ) } - throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: "Invalid response") - } - let asyncBytes = AsyncStream { byteContinuation in - SwiftTask { - do { - for try await buffer in response.body { - for byte in buffer.readableBytesView { - byteContinuation.yield(byte) + let asyncBytes = AsyncStream { byteContinuation in + bodyReaderBox.task = SwiftTask { + do { + for try await buffer in response.body { + try SwiftTask.checkCancellation() + for byte in buffer.readableBytesView { + byteContinuation.yield(byte) + } } + byteContinuation.finish() + } catch { + byteContinuation.finish() } - byteContinuation.finish() - } catch { - byteContinuation.finish() + } + byteContinuation.onTermination = { _ in + bodyReaderBox.task?.cancel() } } + + try await self.decodeAndYieldEventStream(asyncBytes, to: continuation) + continuation.finish() + } catch { + continuation.finish(throwing: error) } + } - try await self.decodeAndYieldEventStream(asyncBytes, to: continuation) - continuation.finish() - } catch { - continuation.finish(throwing: error) + continuation.onTermination = { _ in + bodyReaderBox.task?.cancel() + task.cancel() } } + } - continuation.onTermination = { _ in - task.cancel() + private func decodeAndYieldEventStream( + _ asyncBytes: Bytes, + to continuation: AsyncThrowingStream.Continuation + ) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 { + let decoder = JSONDecoder() + for try await event in asyncBytes.events { + guard let data = event.data.data(using: .utf8) else { continue } + if let decoded = try? decoder.decode(T.self, from: data) { + continuation.yield(decoded) + } } } } - private func decodeAndYieldEventStream( - _ asyncBytes: Bytes, - to continuation: AsyncThrowingStream.Continuation - ) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 { - let decoder = JSONDecoder() - for try await event in asyncBytes.events { - guard let data = event.data.data(using: .utf8) else { continue } - if let decoded = try? decoder.decode(T.self, from: data) { - continuation.yield(decoded) + enum HTTPClientError: Error, CustomStringConvertible { + case invalidResponse + case httpError(statusCode: Int, detail: String) + case decodingError(detail: String) + + var description: String { + switch self { + case .invalidResponse: + return "Invalid response" + case .httpError(let statusCode, let detail): + return "HTTP error (Status \(statusCode)): \(detail)" + case .decodingError(let detail): + return "Decoding error: \(detail)" } } } -} - -enum HTTPClientError: Error, CustomStringConvertible { - case invalidResponse - case httpError(statusCode: Int, detail: String) - case decodingError(detail: String) - - var description: String { - switch self { - case .invalidResponse: - return "Invalid response" - case .httpError(let statusCode, let detail): - return "HTTP error (Status \(statusCode)): \(detail)" - case .decodingError(let detail): - return "Decoding error: \(detail)" - } - } -} #endif diff --git a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift index 851c5559..f78cd078 100644 --- a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift @@ -278,7 +278,7 @@ public struct AnthropicLanguageModel: LanguageModel { /// The model identifier to use for generation. public let model: String - private let urlSession: SessionType + private let httpSession: HTTPSession /// Creates an Anthropic language model. /// @@ -288,14 +288,14 @@ public struct AnthropicLanguageModel: LanguageModel { /// - apiVersion: The API version to use for requests. Defaults to `2023-06-01`. /// - betas: Optional beta version(s) of the API to use. /// - model: The model identifier (for example, "claude-3-5-sonnet-20241022"). - /// - session: The URL session to use for network requests. + /// - session: The HTTP session or client used for network requests. public init( baseURL: URL = defaultBaseURL, apiKey tokenProvider: @escaping @autoclosure @Sendable () -> String, apiVersion: String = defaultAPIVersion, betas: [String]? = nil, model: String, - session: SessionType = makeDefaultSession(), + session: HTTPSession = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { @@ -307,7 +307,7 @@ public struct AnthropicLanguageModel: LanguageModel { self.apiVersion = apiVersion self.betas = betas self.model = model - self.urlSession = session + self.httpSession = session } public func respond( @@ -337,7 +337,7 @@ public struct AnthropicLanguageModel: LanguageModel { let body = try JSONEncoder().encode(params) - let message: AnthropicMessageResponse = try await urlSession.fetch( + let message: AnthropicMessageResponse = try await httpSession.fetch( .post, url: url, headers: headers, @@ -435,7 +435,7 @@ public struct AnthropicLanguageModel: LanguageModel { // Stream server-sent events from Anthropic API let events: AsyncThrowingStream = - urlSession + httpSession .fetchEventStream( .post, url: url, diff --git a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift index e2230ab2..b074ddca 100644 --- a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift @@ -186,7 +186,7 @@ public struct GeminiLanguageModel: LanguageModel { /// Internal storage for the deprecated serverTools property. internal var _serverTools: [CustomGenerationOptions.ServerTool] - private let urlSession: SessionType + private let httpSession: HTTPSession /// Creates a new Gemini language model. /// @@ -195,13 +195,13 @@ public struct GeminiLanguageModel: LanguageModel { /// - tokenProvider: A closure that provides the API key. /// - apiVersion: The API version to use. /// - model: The model identifier. - /// - session: The URL session for network requests. + /// - session: The HTTP session or client used for network requests. public init( baseURL: URL = defaultBaseURL, apiKey tokenProvider: @escaping @autoclosure @Sendable () -> String, apiVersion: String = defaultAPIVersion, model: String, - session: SessionType = makeDefaultSession(), + session: HTTPSession = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { @@ -214,7 +214,7 @@ public struct GeminiLanguageModel: LanguageModel { self.model = model self._thinking = .disabled self._serverTools = [] - self.urlSession = session + self.httpSession = session } /// Creates a new Gemini language model with thinking and server tools configuration. @@ -226,7 +226,7 @@ public struct GeminiLanguageModel: LanguageModel { /// - model: The model identifier. /// - thinking: The thinking mode configuration. /// - serverTools: Server-side tools to enable. - /// - session: The URL session for network requests. + /// - session: The HTTP session or client used for network requests. /// /// - Important: This initializer is deprecated. Use the initializer without /// `thinking` and `serverTools` parameters, and pass these options through @@ -243,7 +243,7 @@ public struct GeminiLanguageModel: LanguageModel { model: String, thinking: CustomGenerationOptions.Thinking = .disabled, serverTools: [CustomGenerationOptions.ServerTool] = [], - session: SessionType = makeDefaultSession(), + session: HTTPSession = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { @@ -256,7 +256,7 @@ public struct GeminiLanguageModel: LanguageModel { self.model = model self._thinking = thinking self._serverTools = serverTools - self.urlSession = session + self.httpSession = session } public func respond( @@ -295,7 +295,7 @@ public struct GeminiLanguageModel: LanguageModel { let body = try JSONEncoder().encode(params) - let response: GeminiGenerateContentResponse = try await urlSession.fetch( + let response: GeminiGenerateContentResponse = try await httpSession.fetch( .post, url: url, headers: headers, @@ -407,7 +407,7 @@ public struct GeminiLanguageModel: LanguageModel { let body = try JSONEncoder().encode(params) let stream: AsyncThrowingStream = - urlSession + httpSession .fetchEventStream( .post, url: url, diff --git a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift index 1e498566..82be5cc7 100644 --- a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift @@ -46,18 +46,18 @@ public struct OllamaLanguageModel: LanguageModel { /// The model identifier to use for generation. public let model: String - private let urlSession: SessionType + private let httpSession: HTTPSession /// Creates an Ollama language model. /// /// - Parameters: /// - baseURL: The base URL for the Ollama server. Defaults to `http://localhost:11434`. /// - model: The model identifier (for example, "qwen2.5" or "llama3.3"). - /// - session: The URL session to use for network requests. + /// - session: The HTTP session or client used for network requests. public init( baseURL: URL = defaultBaseURL, model: String, - session: SessionType = makeDefaultSession(), + session: HTTPSession = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { @@ -66,7 +66,7 @@ public struct OllamaLanguageModel: LanguageModel { self.baseURL = baseURL self.model = model - self.urlSession = session + self.httpSession = session } public func respond( @@ -105,7 +105,7 @@ public struct OllamaLanguageModel: LanguageModel { let url = baseURL.appendingPathComponent("api/chat") let body = try JSONEncoder().encode(params) - let chatResponse: ChatResponse = try await urlSession.fetch( + let chatResponse: ChatResponse = try await httpSession.fetch( .post, url: url, body: body, @@ -199,7 +199,7 @@ public struct OllamaLanguageModel: LanguageModel { // Reuse ChatResponse as each streamed line shares the same shape do { let chunks = - urlSession.fetchStream( + httpSession.fetchStream( .post, url: url, body: body, diff --git a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift index a1b6b0ee..20bca5eb 100644 --- a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift @@ -393,7 +393,7 @@ public struct OpenAILanguageModel: LanguageModel { /// The API variant to use. public let apiVariant: APIVariant - private let urlSession: SessionType + private let httpSession: HTTPSession /// Creates an OpenAI language model. /// @@ -402,13 +402,13 @@ public struct OpenAILanguageModel: LanguageModel { /// - apiKey: Your OpenAI API key or a closure that returns it. /// - model: The model identifier (for example, "gpt-4" or "gpt-3.5-turbo"). /// - apiVariant: The API variant to use. Defaults to `.chatCompletions`. - /// - session: The URL session to use for network requests. + /// - session: The HTTP session or client used for network requests. public init( baseURL: URL = defaultBaseURL, apiKey tokenProvider: @escaping @autoclosure @Sendable () -> String, model: String, apiVariant: APIVariant = .chatCompletions, - session: SessionType = makeDefaultSession(), + session: HTTPSession = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { @@ -419,7 +419,7 @@ public struct OpenAILanguageModel: LanguageModel { self.tokenProvider = tokenProvider self.model = model self.apiVariant = apiVariant - self.urlSession = session + self.httpSession = session } public func respond( @@ -485,7 +485,7 @@ public struct OpenAILanguageModel: LanguageModel { let url = baseURL.appendingPathComponent("chat/completions") let body = try JSONEncoder().encode(params) - let resp: ChatCompletions.Response = try await urlSession.fetch( + let resp: ChatCompletions.Response = try await httpSession.fetch( .post, url: url, headers: [ @@ -593,7 +593,7 @@ public struct OpenAILanguageModel: LanguageModel { let encoder = JSONEncoder() let body = try encoder.encode(params) - let resp: Responses.Response = try await urlSession.fetch( + let resp: Responses.Response = try await httpSession.fetch( .post, url: url, headers: [ @@ -704,7 +704,7 @@ public struct OpenAILanguageModel: LanguageModel { let body = try JSONEncoder().encode(params) let events: AsyncThrowingStream = - urlSession.fetchEventStream( + httpSession.fetchEventStream( .post, url: url, headers: [ @@ -788,7 +788,7 @@ public struct OpenAILanguageModel: LanguageModel { let body = try JSONEncoder().encode(params) let events: AsyncThrowingStream = - urlSession.fetchEventStream( + httpSession.fetchEventStream( .post, url: url, headers: [ diff --git a/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift index c123e800..8370e14f 100644 --- a/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift @@ -365,7 +365,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { /// Model identifier to use for generation. public let model: String - private let urlSession: SessionType + private let httpSession: HTTPSession /// Creates an Open Responses language model. /// @@ -373,12 +373,12 @@ public struct OpenResponsesLanguageModel: LanguageModel { /// - baseURL: Base URL for the API (e.g. `https://api.openai.com/v1/` or `https://openrouter.ai/api/v1/`). Must end with `/`. /// - apiKey: API key or closure that returns it. /// - model: Model identifier (e.g. `gpt-4o-mini` or provider-specific id). - /// - session: URL session for network requests. + /// - session: The HTTP session or client used for network requests. public init( baseURL: URL, apiKey tokenProvider: @escaping @autoclosure @Sendable () -> String, model: String, - session: SessionType = makeDefaultSession(), + session: HTTPSession = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { @@ -387,7 +387,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { self.baseURL = baseURL self.tokenProvider = tokenProvider self.model = model - self.urlSession = session + self.httpSession = session } public func respond( @@ -433,7 +433,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { do { let body = try JSONEncoder().encode(params) let events: AsyncThrowingStream = - urlSession.fetchEventStream( + httpSession.fetchEventStream( .post, url: url, headers: ["Authorization": "Bearer \(tokenProvider())"], @@ -505,7 +505,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { stream: false ) let body = try JSONEncoder().encode(params) - let resp: OpenResponsesAPI.Response = try await urlSession.fetch( + let resp: OpenResponsesAPI.Response = try await httpSession.fetch( .post, url: url, headers: ["Authorization": "Bearer \(tokenProvider())"], diff --git a/Sources/AnyLanguageModel/Transport.swift b/Sources/AnyLanguageModel/Shared/Transport.swift similarity index 58% rename from Sources/AnyLanguageModel/Transport.swift rename to Sources/AnyLanguageModel/Shared/Transport.swift index a8d81861..5e418a09 100644 --- a/Sources/AnyLanguageModel/Transport.swift +++ b/Sources/AnyLanguageModel/Shared/Transport.swift @@ -1,9 +1,9 @@ #if canImport(AsyncHTTPClient) import AsyncHTTPClient - public typealias SessionType = HTTPClient + public typealias HTTPSession = HTTPClient - public func makeDefaultSession() -> SessionType { + public func makeDefaultSession() -> HTTPSession { return HTTPClient.shared } #else @@ -12,9 +12,9 @@ import FoundationNetworking #endif - public typealias SessionType = URLSession + public typealias HTTPSession = URLSession - public func makeDefaultSession() -> SessionType { + public func makeDefaultSession() -> HTTPSession { return URLSession(configuration: .default) } #endif From 4d52094174cbb65f69f33560d797074e9b0e38cf Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 23 Mar 2026 04:53:11 -0700 Subject: [PATCH 9/9] swift format -i -r . --- Package.swift | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/Package.swift b/Package.swift index b8d69621..933f837b 100644 --- a/Package.swift +++ b/Package.swift @@ -30,10 +30,14 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/huggingface/swift-transformers", from: "1.0.0"), - .package(url: "https://github.com/mattt/EventSource", from: "1.3.0", traits: [ - .defaults, - .trait(name: "AsyncHTTPClient", condition: .when(traits: ["AsyncHTTPClient"])) - ]), + .package( + url: "https://github.com/mattt/EventSource", + from: "1.3.0", + traits: [ + .defaults, + .trait(name: "AsyncHTTPClient", condition: .when(traits: ["AsyncHTTPClient"])), + ] + ), .package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"), .package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")), .package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"),