Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .envrc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export GIT_DIR=$PWD/.jj/repo/store/git
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ DerivedData/
.swiftpm/configuration/registries.json
.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
.netrc
.index-build/
.index-build/
.claude/settings.local.json
.swift-version
120 changes: 109 additions & 11 deletions Sources/MCP/Base/Transports/NetworkTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,14 @@ import Logging
content: Heartbeat().data,
contentContext: .defaultMessage,
isComplete: true,
completion: .contentProcessed { [weak self] error in
completion: .contentProcessed { [weak self, continuation] error in
guard let self = self else {
continuation.resume(
throwing: MCPError.internalError(
"Transport deallocated during heartbeat"))
return
}

if let error = error {
continuation.resume(throwing: error)
} else {
Expand Down Expand Up @@ -511,9 +518,6 @@ import Logging
var messageWithNewline = message
messageWithNewline.append(UInt8(ascii: "\n"))

// Use a local actor-isolated variable to track continuation state
var sendContinuationResumed = false

try await withCheckedThrowingContinuation {
[weak self] (continuation: CheckedContinuation<Void, Swift.Error>) in
guard let self = self else {
Expand All @@ -528,6 +532,26 @@ import Logging
completion: .contentProcessed { [weak self] error in
guard let self = self else { return }

if let error = error {
self.logger.error("Send error: \(error)")

// Schedule reconnection check on a separate task
Task { [weak self] in
guard let self = self else { return }
let isStopping = await self.isStopping
if !isStopping && self.reconnectionConfig.enabled {
let isConnected = await self.isConnected
if isConnected && error.isConnectionLost {
self.logger.warning(
"Connection appears broken, will attempt to reconnect..."
)
await self.setIsConnected(false)
try? await Task.sleep(for: .milliseconds(500))

let currentIsStopping = await self.isStopping
if !currentIsStopping {
self.connection.cancel()
try? await self.connect()
Task { @MainActor in
if !sendContinuationResumed {
sendContinuationResumed = true
Expand Down Expand Up @@ -560,15 +584,63 @@ import Logging
}
}
}
if let error = error {
self.logger.error("Send error: \(error)")

// Schedule reconnection check on a separate task
Task { [weak self] in
guard let self = self else { return }
let isStopping = await self.isStopping
if !isStopping && self.reconnectionConfig.enabled {
let isConnected = await self.isConnected
if isConnected && error.isConnectionLost {
self.logger.warning(
"Connection appears broken, will attempt to reconnect..."
)
await self.setIsConnected(false)
try? await Task.sleep(for: .milliseconds(500))

let currentIsStopping = await self.isStopping
if !currentIsStopping {
self.connection.cancel()
try? await self.connect()
completion: .contentProcessed { [weak self, continuation] error in
guard let self = self else {
continuation.resume(
throwing: MCPError.internalError(
"Transport deallocated during send"))
return
}

if let error = error {
self.logger.error("Send error: \(error)")

// Schedule reconnection check on a separate task
Task { [weak self] in
guard let self = self else { return }
let isStopping = await self.isStopping
if !isStopping && self.reconnectionConfig.enabled {
let isConnected = await self.isConnected
if isConnected && error.isConnectionLost {
self.logger.warning(
"Connection appears broken, will attempt to reconnect..."
)
await self.setIsConnected(false)
try? await Task.sleep(for: .milliseconds(500))

let currentIsStopping = await self.isStopping
if !currentIsStopping {
self.connection.cancel()
try? await self.connect()
}
}

continuation.resume(
throwing: MCPError.internalError("Send error: \(error)"))
} else {
continuation.resume()
}
}

continuation.resume(
throwing: MCPError.internalError("Send error: \(error)"))
} else {
continuation.resume()
}
})
}
Expand Down Expand Up @@ -747,8 +819,6 @@ import Logging
/// - Returns: The received data chunk
/// - Throws: Network errors or transport failures
private func receiveData() async throws -> Data {
var receiveContinuationResumed = false

return try await withCheckedThrowingContinuation {
[weak self] (continuation: CheckedContinuation<Data, Swift.Error>) in
guard let self = self else {
Expand All @@ -759,6 +829,15 @@ import Logging
let maxLength = bufferConfig.maxReceiveBufferSize ?? Int.max
connection.receive(minimumIncompleteLength: 1, maximumLength: maxLength) {
content, _, isComplete, error in
if let error = error {
continuation.resume(throwing: MCPError.transportError(error))
} else if let content = content {
continuation.resume(returning: content)
} else if isComplete {
self.logger.trace("Connection completed by peer")
continuation.resume(throwing: MCPError.connectionClosed)
} else {
continuation.resume(returning: Data())
Task { @MainActor in
if !receiveContinuationResumed {
receiveContinuationResumed = true
Expand All @@ -774,6 +853,25 @@ import Logging
continuation.resume(returning: Data())
}
}
if let error = error {
continuation.resume(throwing: MCPError.transportError(error))
} else if let content = content {
continuation.resume(returning: content)
} else if isComplete {
self.logger.trace("Connection completed by peer")
continuation.resume(throwing: MCPError.connectionClosed)
} else {
continuation.resume(returning: Data())
[weak self, continuation] content, _, isComplete, error in
if let error = error {
continuation.resume(throwing: MCPError.transportError(error))
} else if let content = content {
continuation.resume(returning: content)
} else if isComplete {
self?.logger.trace("Connection completed by peer")
continuation.resume(throwing: MCPError.connectionClosed)
} else {
continuation.resume(returning: Data())
}
}
}
Expand Down
12 changes: 8 additions & 4 deletions Sources/MCP/Client/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -592,24 +592,27 @@ public actor Client {
/// Use this object to add requests to the batch.
/// - Throws: `MCPError.internalError` if the client is not connected.
/// Can also rethrow errors from the `body` closure or from sending the batch request.
public func withBatch(body: @escaping @Sendable (Batch) async throws -> Void) async throws {
@discardableResult
public func withBatch<T: Sendable>(
body: @escaping @Sendable (Batch) async throws -> T
) async throws -> T {
guard let connection = connection else {
throw MCPError.internalError("Client connection not initialized")
}

// Create Batch actor, passing self (Client)
let batch = Batch(client: self)

// Populate the batch actor by calling the user's closure.
try await body(batch)
// Populate the batch actor by calling the user's closure and capture result.
let result = try await body(batch)

// Get the collected requests from the batch actor
let requests = await batch.requests

// Check if there are any requests to send
guard !requests.isEmpty else {
await logger?.debug("Batch requested but no requests were added.")
return // Nothing to send
return result // Return result even if no requests
}

await logger?.debug(
Expand All @@ -620,6 +623,7 @@ public actor Client {
try await connection.send(data)

// Responses will be handled asynchronously by the message loop and handleBatchResponse/handleResponse.
return result
}

// MARK: - Lifecycle
Expand Down
72 changes: 54 additions & 18 deletions Tests/MCPTests/ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,10 @@ struct ClientTests {

let request1 = Ping.request()
let request2 = Ping.request()
nonisolated(unsafe) var resultTask1: Task<Ping.Result, Swift.Error>?
nonisolated(unsafe) var resultTask2: Task<Ping.Result, Swift.Error>?

try await client.withBatch { batch in
resultTask1 = try await batch.addRequest(request1)
resultTask2 = try await batch.addRequest(request2)
let (resultTask1, resultTask2) = try await client.withBatch { batch in
let task1 = try await batch.addRequest(request1)
let task2 = try await batch.addRequest(request2)
return (task1, task2)
}

// Check if batch message was sent (after initialize and initialized notification)
Expand Down Expand Up @@ -381,13 +379,8 @@ struct ClientTests {
try await transport.queue(batch: [anyResponse1, anyResponse2])

// Wait for results and verify
guard let task1 = resultTask1, let task2 = resultTask2 else {
#expect(Bool(false), "Result tasks not created")
return
}

_ = try await task1.value // Should succeed
_ = try await task2.value // Should succeed
_ = try await resultTask1.value // Should succeed
_ = try await resultTask2.value // Should succeed

#expect(Bool(true)) // Reaching here means success

Expand Down Expand Up @@ -426,11 +419,11 @@ struct ClientTests {
let request1 = Ping.request() // Success
let request2 = Ping.request() // Error

nonisolated(unsafe) var resultTasks: [Task<Ping.Result, Swift.Error>] = []

try await client.withBatch { batch in
resultTasks.append(try await batch.addRequest(request1))
resultTasks.append(try await batch.addRequest(request2))
let resultTasks = try await client.withBatch { batch in
[
try await batch.addRequest(request1),
try await batch.addRequest(request2),
]
}

// Check if batch message was sent (after initialize and initialized notification)
Expand Down Expand Up @@ -514,6 +507,49 @@ struct ClientTests {
await client.disconnect()
}

@Test("Batch request - empty with non-Void return")
func testBatchRequestEmptyNonVoid() async throws {
let transport = MockTransport()
let client = Client(name: "TestClient", version: "1.0")

// Set up a task to handle the initialize response
let initTask = Task {
try await Task.sleep(for: .milliseconds(10))
if let lastMessage = await transport.sentMessages.last,
let data = lastMessage.data(using: .utf8),
let request = try? JSONDecoder().decode(Request<Initialize>.self, from: data)
{
let response = Initialize.response(
id: request.id,
result: .init(
protocolVersion: Version.latest,
capabilities: .init(),
serverInfo: .init(name: "TestServer", version: "1.0"),
instructions: nil
)
)
try await transport.queue(response: response)
}
}

try await client.connect(transport: transport)
try await Task.sleep(for: .milliseconds(10))
initTask.cancel()

// Call withBatch with non-Void return but don't add any requests
let result: Int = try await client.withBatch { _ in
42
}

// Verify the closure's return value is passed through
#expect(result == 42)

// Check that only initialize message and initialized notification were sent
#expect(await transport.sentMessages.count == 2)

await client.disconnect()
}

@Test("Notify method sends notifications")
func testClientNotify() async throws {
let transport = MockTransport()
Expand Down