diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f586c9cd..7db6ef0f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -43,6 +43,58 @@ jobs: - name: Run tests run: swift test -v + conformance: + timeout-minutes: 10 + runs-on: macos-latest + name: MCP Conformance Tests + + steps: + - uses: actions/checkout@v4 + + - name: Setup Swift + uses: swift-actions/setup-swift@v2 + with: + swift-version: 6.1.0 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + + - name: Build Swift executables + run: | + swift build --product mcp-everything-client + swift build --product mcp-everything-server + + - name: Run client conformance tests + uses: modelcontextprotocol/conformance@v0.1.11 + with: + mode: client + command: '.build/debug/mcp-everything-client' + suite: 'core' + expected-failures: './conformance-baseline.yml' + + - name: Start server for testing + run: | + .build/debug/mcp-everything-server & + echo "SERVER_PID=$!" >> $GITHUB_ENV + sleep 3 + + - name: Run server conformance tests + uses: modelcontextprotocol/conformance@v0.1.11 + with: + mode: server + url: 'http://localhost:3001/mcp' + suite: 'core' + expected-failures: './conformance-baseline.yml' + + - name: Cleanup server + if: always() + run: | + if [ -n "$SERVER_PID" ]; then + kill $SERVER_PID 2>/dev/null || true + fi + static-linux-sdk-build: name: Linux Static SDK Build (${{ matrix.swift-version }} - ${{ matrix.os }}) strategy: diff --git a/Package.resolved b/Package.resolved index fb776dd5..8dc9917f 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "371f3dfcfa1201fc8d50e924ad31f9ebc4f90242924df1275958ac79df15dc12", + "originHash" : "06a30e0a3f4c69c306d3b14f13c2b4b3964674139bfeec9b920a2bc3d5b1ca20", "pins" : [ { "identity" : "eventsource", @@ -10,6 +10,33 @@ "version" : "1.1.0" } }, + { + "identity" : "swift-async-algorithms", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-async-algorithms.git", + "state" : { + "revision" : "6c050d5ef8e1aa6342528460db614e9770d7f804", + "version" : "1.1.1" + } + }, + { + "identity" : "swift-atomics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-atomics.git", + "state" : { + "revision" : "b601256eab081c0f92f059e12818ac1d4f178ff7", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e", + "version" : "1.3.0" + } + }, { "identity" : "swift-log", "kind" : "remoteSourceControl", @@ -19,6 +46,15 @@ "version" : "1.6.2" } }, + { + "identity" : "swift-nio", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio.git", + "state" : { + "revision" : "5e72fc102906ebe75a3487595a653e6f43725552", + "version" : "2.94.0" + } + }, { "identity" : "swift-system", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index 064e0d87..60f555fb 100644 --- a/Package.swift +++ b/Package.swift @@ -7,16 +7,22 @@ import PackageDescription var dependencies: [Package.Dependency] = [ .package(url: "https://github.com/apple/swift-system.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.0"), + .package(url: "https://github.com/apple/swift-async-algorithms.git", from: "1.0.0"), .package(url: "https://github.com/mattt/eventsource.git", from: "1.1.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.65.0"), ] // Target dependencies needed on all platforms var targetDependencies: [Target.Dependency] = [ .product(name: "SystemPackage", package: "swift-system"), .product(name: "Logging", package: "swift-log"), + .product(name: "AsyncAlgorithms", package: "swift-async-algorithms"), .product( name: "EventSource", package: "eventsource", condition: .when(platforms: [.macOS, .iOS, .tvOS, .visionOS, .watchOS, .macCatalyst])), + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOPosix", package: "swift-nio"), + .product(name: "NIOHTTP1", package: "swift-nio"), ] let package = Package( @@ -33,7 +39,13 @@ let package = Package( // Products define the executables and libraries a package produces, making them visible to other packages. .library( name: "MCP", - targets: ["MCP"]) + targets: ["MCP"]), + .executable( + name: "mcp-everything-server", + targets: ["MCPConformanceServer"]), + .executable( + name: "mcp-everything-client", + targets: ["MCPConformanceClient"]) ], dependencies: dependencies, targets: [ @@ -49,5 +61,13 @@ let package = Package( .testTarget( name: "MCPTests", dependencies: ["MCP"] + targetDependencies), + .executableTarget( + name: "MCPConformanceServer", + dependencies: ["MCP"] + targetDependencies, + path: "Sources/MCPConformance/Server"), + .executableTarget( + name: "MCPConformanceClient", + dependencies: ["MCP"] + targetDependencies, + path: "Sources/MCPConformance/Client") ] ) diff --git a/README.md b/README.md index 7d3b3274..52d94f0b 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Official Swift SDK for the [Model Context Protocol][mcp] (MCP). The Model Context Protocol (MCP) defines a standardized way for applications to communicate with AI and ML models. This Swift SDK implements both client and server components -according to the [2025-06-18][mcp-spec-2025-06-18] (latest) version +according to the [2025-11-25][mcp-spec-2025-11-25] (latest) version of the MCP specification. ## Requirements @@ -60,6 +60,14 @@ let result = try await client.connect(transport: transport) if result.capabilities.tools != nil { // Server supports tools (implicitly including tool calling if the 'tools' capability object is present) } + +if result.capabilities.logging != nil { + // Server supports sending log messages +} + +if result.capabilities.completions != nil { + // Server supports argument autocompletion +} ``` > [!NOTE] @@ -298,6 +306,61 @@ for message in messages { } ``` +### Completions + +Completions allow servers to provide autocompletion suggestions for prompt and resource template arguments as users type: + +```swift +// Request completions for a prompt argument +let completion = try await client.complete( + promptName: "code_review", + argumentName: "language", + argumentValue: "py" +) + +// Display suggestions to the user +for value in completion.values { + print("Suggestion: \(value)") +} + +if completion.hasMore == true { + print("More suggestions available (total: \(completion.total ?? 0))") +} +``` + +You can also provide context with already-resolved arguments: + +```swift +// First, user selects a language +let languageCompletion = try await client.complete( + promptName: "code_review", + argumentName: "language", + argumentValue: "py" +) +// User selects "python" + +// Then get framework suggestions based on the selected language +let frameworkCompletion = try await client.complete( + promptName: "code_review", + argumentName: "framework", + argumentValue: "fla", + context: ["language": .string("python")] +) +// Returns: ["flask"] +``` + +Completions work for resource templates as well: + +```swift +// Get path completions for a resource URI template +let pathCompletion = try await client.complete( + resourceURI: "file:///{path}", + argumentName: "path", + argumentValue: "/usr/" +) +// Returns: ["/usr/bin", "/usr/lib", "/usr/local"] +``` + ### Sampling Sampling allows servers to request LLM completions through the client, @@ -477,6 +540,42 @@ Common use cases for elicitation: - **Configuration**: Collect preferences or settings during operation - **Missing information**: Request additional details not provided initially +### Logging + +Clients can control server logging levels and receive structured log messages: + +```swift +// Set the minimum logging level +try await client.setLoggingLevel(.warning) + +// Register a handler for log messages from the server +await client.onNotification(LogMessageNotification.self) { message in + let level = message.params.level // LogLevel (debug, info, warning, etc.) + let logger = message.params.logger // Optional logger name + let data = message.params.data // Arbitrary JSON data + + // Display log message based on level + switch level { + case .error, .critical, .alert, .emergency: + print("❌ [\(logger ?? "server")] \(data)") + case .warning: + print("⚠️ [\(logger ?? "server")] \(data)") + default: + print("ℹ️ [\(logger ?? "server")] \(data)") + } +} +``` + +Log levels follow the standard syslog severity levels (RFC 5424): +- **debug**: Detailed debugging information +- **info**: General informational messages +- **notice**: Normal but significant events +- **warning**: Warning conditions +- **error**: Error conditions +- **critical**: Critical conditions +- **alert**: Action must be taken immediately +- **emergency**: System is unusable + ### Error Handling Handle common client errors: @@ -612,6 +711,8 @@ let server = Server( name: "MyModelServer", version: "1.0.0", capabilities: .init( + completions: .init(), + logging: .init(), prompts: .init(listChanged: true), resources: .init(subscribe: true, listChanged: true), tools: .init(listChanged: true) @@ -796,6 +897,156 @@ await server.withMethodHandler(GetPrompt.self) { params in } ``` +### Completions + +Servers can provide autocompletion suggestions for prompt and resource template arguments: + +```swift +// Enable completions capability +let server = Server( + name: "MyServer", + version: "1.0.0", + capabilities: .init( + completions: .init(), + prompts: .init(listChanged: true) + ) +) + +// Register a completion handler +await server.withMethodHandler(Complete.self) { params in + // Get the argument being completed + let argumentName = params.argument.name + let currentValue = params.argument.value + + // Check which prompt or resource is being completed + switch params.ref { + case .prompt(let promptRef): + // Provide completions for a prompt argument + if promptRef.name == "code_review" && argumentName == "language" { + // Simple prefix matching + let allLanguages = ["python", "perl", "php", "javascript", "java", "swift"] + let matches = allLanguages.filter { $0.hasPrefix(currentValue.lowercased()) } + + return .init( + completion: .init( + values: Array(matches.prefix(100)), // Max 100 items + total: matches.count, + hasMore: matches.count > 100 + ) + ) + } + + case .resource(let resourceRef): + // Provide completions for a resource template argument + if resourceRef.uri == "file:///{path}" && argumentName == "path" { + // Return directory suggestions + let suggestions = try getDirectoryCompletions(for: currentValue) + return .init( + completion: .init( + values: suggestions, + total: suggestions.count, + hasMore: false + ) + ) + } + } + + // No completions available + return .init(completion: .init(values: [], total: 0, hasMore: false)) +} +``` + +You can also use context from already-resolved arguments: + +```swift +await server.withMethodHandler(Complete.self) { params in + // Access context from previous argument completions + if let context = params.context, + let language = context.arguments["language"]?.stringValue { + + // Provide framework suggestions based on selected language + if language == "python" { + let frameworks = ["flask", "django", "fastapi", "tornado"] + let matches = frameworks.filter { + $0.hasPrefix(params.argument.value.lowercased()) + } + return .init( + completion: .init(values: matches, total: matches.count, hasMore: false) + ) + } + } + + return .init(completion: .init(values: [], total: 0, hasMore: false)) +} +``` + +### Logging + +Servers can send structured log messages to clients: + +```swift +// Enable logging capability +let server = Server( + name: "MyServer", + version: "1.0.0", + capabilities: .init( + logging: .init(), + tools: .init(listChanged: true) + ) +) + +// Send log messages at different severity levels +try await server.log( + level: .info, + logger: "database", + data: Value.object([ + "message": .string("Database connected successfully"), + "host": .string("localhost"), + "port": .int(5432) + ]) +) + +try await server.log( + level: .error, + logger: "api", + data: Value.object([ + "message": .string("Request failed"), + "statusCode": .int(500), + "error": .string("Internal server error") + ]) +) + +// You can also use codable types directly +struct ErrorLog: Codable { + let message: String + let code: Int + let timestamp: String +} + +let errorLog = ErrorLog( + message: "Operation failed", + code: 500, + timestamp: ISO8601DateFormatter().string(from: Date()) +) + +try await server.log(level: .error, logger: "operations", data: errorLog) +``` + +Clients can control which log levels they receive: + +```swift +// Register a handler for client's logging level preferences +await server.withMethodHandler(SetLoggingLevel.self) { params in + let minimumLevel = params.level + + // Store the client's preference and filter log messages accordingly + // (Implementation depends on your server architecture) + storeLogLevel(minimumLevel) + + return Empty() +} +``` + ### Sampling Servers can request LLM completions from clients through sampling. This enables agentic behaviors where servers can ask for AI assistance while maintaining human oversight. @@ -1089,4 +1340,4 @@ see the [GitHub Releases page](https://github.com/modelcontextprotocol/swift-sdk This project is licensed under the MIT License. [mcp]: https://modelcontextprotocol.io -[mcp-spec-2025-06-18]: https://modelcontextprotocol.io/specification/2025-06-18 +[mcp-spec-2025-11-25]: https://modelcontextprotocol.io/specification/2025-11-25 diff --git a/Sources/MCP/Base/Transports/HTTPClientTransport.swift b/Sources/MCP/Base/Transports/HTTPClientTransport.swift index 982cd505..d5b56329 100644 --- a/Sources/MCP/Base/Transports/HTTPClientTransport.swift +++ b/Sources/MCP/Base/Transports/HTTPClientTransport.swift @@ -9,16 +9,20 @@ import Logging import FoundationNetworking #endif +// MARK: - Timeout Helpers + +/// Error thrown when an operation times out /// An implementation of the MCP Streamable HTTP transport protocol for clients. /// -/// This transport implements the [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http) -/// specification from the Model Context Protocol. +/// This transport implements the [Streamable HTTP transport](https://spec.modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http) +/// specification from the Model Context Protocol (version 2025-11-25). /// /// It supports: /// - Sending JSON-RPC messages via HTTP POST requests /// - Receiving responses via both direct JSON responses and SSE streams -/// - Session management using the `Mcp-Session-Id` header -/// - Automatic reconnection for dropped SSE streams +/// - Session management using the `MCP-Session-Id` header +/// - Protocol version negotiation via `MCP-Protocol-Version` header +/// - Automatic reconnection for dropped SSE streams with resumability support /// - Platform-specific optimizations for different operating systems /// /// The transport supports two modes: @@ -56,6 +60,10 @@ public actor HTTPClientTransport: Transport { /// The session ID assigned by the server, used for maintaining state across requests public private(set) var sessionID: String? + + /// The negotiated protocol version to send in MCP-Protocol-Version header + public var protocolVersion: String? + private let streaming: Bool private var streamingTask: Task? @@ -75,6 +83,16 @@ public actor HTTPClientTransport: Transport { private var initialSessionIDSignalTask: Task? private var initialSessionIDContinuation: CheckedContinuation? + /// The last event ID received from the server for SSE stream resumability + private var lastEventID: String? + + /// The retry interval (in milliseconds) from the server's SSE `retry:` field + private var retryInterval: Int = 3000 // Default 3000ms per SSE spec + + /// The underlying URLSession task for the active GET SSE stream. + /// Used to trigger reconnection when a POST SSE stream closes without delivering data. + private var activeGETSessionTask: URLSessionDataTask? + /// Creates a new HTTP transport client with the specified endpoint /// /// - Parameters: @@ -82,6 +100,7 @@ public actor HTTPClientTransport: Transport { /// - configuration: URLSession configuration to use for HTTP requests /// - streaming: Whether to enable SSE streaming mode (default: true) /// - sseInitializationTimeout: Maximum time to wait for session ID before proceeding with SSE (default: 10 seconds) + /// - protocolVersion: The MCP protocol version to use (default: "2025-11-25") /// - requestModifier: Optional closure to customize requests before they are sent (default: no modification) /// - logger: Optional logger instance for transport events public init( @@ -89,6 +108,7 @@ public actor HTTPClientTransport: Transport { configuration: URLSessionConfiguration = .default, streaming: Bool = true, sseInitializationTimeout: TimeInterval = 10, + protocolVersion: String = Version.latest, requestModifier: @escaping (URLRequest) -> URLRequest = { $0 }, logger: Logger? = nil ) { @@ -97,6 +117,7 @@ public actor HTTPClientTransport: Transport { session: URLSession(configuration: configuration), streaming: streaming, sseInitializationTimeout: sseInitializationTimeout, + protocolVersion: protocolVersion, requestModifier: requestModifier, logger: logger ) @@ -107,6 +128,7 @@ public actor HTTPClientTransport: Transport { session: URLSession, streaming: Bool = false, sseInitializationTimeout: TimeInterval = 10, + protocolVersion: String = Version.latest, requestModifier: @escaping (URLRequest) -> URLRequest = { $0 }, logger: Logger? = nil ) { @@ -114,6 +136,7 @@ public actor HTTPClientTransport: Transport { self.session = session self.streaming = streaming self.sseInitializationTimeout = sseInitializationTimeout + self.protocolVersion = protocolVersion self.requestModifier = requestModifier // Create message stream @@ -144,7 +167,9 @@ public actor HTTPClientTransport: Transport { if let continuation = self.initialSessionIDContinuation { continuation.resume() self.initialSessionIDContinuation = nil // Consume the continuation - logger.trace("Initial session ID signal triggered for SSE task.") + logger.debug("✓ Initial session ID signal triggered for SSE task") + } else { + logger.debug("✗ No continuation to trigger - signal already consumed or SSE task not waiting") } } @@ -202,6 +227,7 @@ public actor HTTPClientTransport: Transport { /// the response according to the MCP Streamable HTTP specification. It handles: /// /// - Adding appropriate Accept headers for both JSON and SSE + /// - Including the MCP-Protocol-Version header as required by the specification /// - Including the session ID in requests if one has been established /// - Processing different response types (JSON vs SSE) /// - Handling HTTP error codes according to the specification @@ -219,9 +245,14 @@ public actor HTTPClientTransport: Transport { request.addValue("application/json", forHTTPHeaderField: "Content-Type") request.httpBody = data + // Add protocol version header (required by MCP specification 2025-11-25) + if let protocolVersion = protocolVersion { + request.addValue(protocolVersion, forHTTPHeaderField: "MCP-Protocol-Version") + } + // Add session ID if available if let sessionID = sessionID { - request.addValue(sessionID, forHTTPHeaderField: "Mcp-Session-Id") + request.addValue(sessionID, forHTTPHeaderField: "MCP-Session-Id") } // Apply request modifier @@ -249,7 +280,7 @@ public actor HTTPClientTransport: Transport { let contentType = httpResponse.value(forHTTPHeaderField: "Content-Type") ?? "" // Extract session ID if present - if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") { + if let newSessionID = httpResponse.value(forHTTPHeaderField: "MCP-Session-Id") { let wasSessionIDNil = (self.sessionID == nil) self.sessionID = newSessionID if wasSessionIDNil { @@ -286,7 +317,7 @@ public actor HTTPClientTransport: Transport { let contentType = httpResponse.value(forHTTPHeaderField: "Content-Type") ?? "" // Extract session ID if present - if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") { + if let newSessionID = httpResponse.value(forHTTPHeaderField: "MCP-Session-Id") { let wasSessionIDNil = (self.sessionID == nil) self.sessionID = newSessionID if wasSessionIDNil { @@ -302,7 +333,15 @@ public actor HTTPClientTransport: Transport { if contentType.contains("text/event-stream") { // For SSE, processing happens via the stream logger.trace("Received SSE response, processing in streaming task") - try await self.processSSE(stream) + let hadData = try await self.processSSE(stream) + + // If the POST SSE stream closed without delivering a JSON-RPC response, + // trigger GET reconnection so the server can deliver it there. + // This implements standard SSE reconnection behavior per the spec. + if !hadData { + logger.debug("POST SSE stream closed without data, triggering GET reconnection") + self.activeGETSessionTask?.cancel() + } } else if contentType.contains("application/json") { // For JSON responses, collect and deliver the data var buffer = Data() @@ -404,75 +443,93 @@ public actor HTTPClientTransport: Transport { // This is the original code for platforms that support SSE guard isConnected else { return } - // Wait for the initial session ID signal, but only if sessionID isn't already set - if self.sessionID == nil, let signalTask = self.initialSessionIDSignalTask { - logger.trace("SSE streaming task waiting for initial sessionID signal...") - - // Race the signalTask against a timeout - let timeoutTask = Task { - try? await Task.sleep(for: .seconds(self.sseInitializationTimeout)) - return false - } + // Wait for session ID to be available before opening SSE stream + if self.sessionID == nil { + logger.debug("⏳ Waiting for session ID to be set (timeout: \(self.sseInitializationTimeout)s)...") - let signalCompletionTask = Task { - await signalTask.value - return true // Indicates signal received - } + let startTime = Date() + let timeout = self.sseInitializationTimeout - // Use TaskGroup to race the two tasks - var signalReceived = false - do { - signalReceived = try await withThrowingTaskGroup(of: Bool.self) { group in - group.addTask { - await signalCompletionTask.value - } - group.addTask { - await timeoutTask.value - } - - // Take the first result and cancel the other task - if let firstResult = try await group.next() { - group.cancelAll() - return firstResult - } - return false + // Poll for session ID with exponential backoff + var attempt = 0 + while self.sessionID == nil && !Task.isCancelled { + let elapsed = Date().timeIntervalSince(startTime) + if elapsed >= timeout { + logger.warning("⏱️ Timeout waiting for session ID (\(timeout)s). SSE stream will proceed anyway.") + break } - } catch { - logger.error("Error while waiting for session ID signal: \(error)") - } - // Clean up tasks - timeoutTask.cancel() + // Exponential backoff: 10ms, 20ms, 50ms, 100ms, 200ms, then 500ms + let delay = min(500, max(10, 10 * (1 << attempt))) + try? await Task.sleep(for: .milliseconds(delay)) + attempt += 1 + } - if signalReceived { - logger.trace("SSE streaming task proceeding after initial sessionID signal.") - } else { - logger.warning( - "Timeout waiting for initial sessionID signal. SSE stream will proceed (sessionID might be nil)." - ) + if self.sessionID != nil { + let elapsed = Date().timeIntervalSince(startTime) + logger.debug("✓ Session ID received after \(Int(elapsed * 1000))ms, proceeding with SSE connection") } - } else if self.sessionID != nil { - logger.trace( - "Initial sessionID already available. Proceeding with SSE streaming task immediately." - ) } else { - logger.trace( - "Proceeding with SSE connection attempt; sessionID is nil. This might be expected for stateless servers or if initialize hasn't provided one yet." - ) + logger.debug("✓ Session ID already available, proceeding with SSE connection immediately") } // Retry loop for connection drops + var isFirstAttempt = true + var attemptCount = 0 + + logger.debug("🔄 Starting SSE retry loop", metadata: [ + "isConnected": "\(isConnected)", + "isCancelled": "\(Task.isCancelled)" + ]) + while isConnected && !Task.isCancelled { + attemptCount += 1 + logger.debug("🔄 SSE retry loop iteration", metadata: [ + "attempt": "\(attemptCount)", + "isFirstAttempt": "\(isFirstAttempt)" + ]) + do { - try await connectToEventStream() + // Wait for retry interval before reconnecting (except first attempt) + if !isFirstAttempt { + let delayMs = self.retryInterval + logger.debug("⏳ Waiting before SSE reconnection", metadata: ["retryMs": "\(delayMs)"]) + try await Task.sleep(for: .milliseconds(delayMs)) + logger.debug("✓ Wait complete, reconnecting now") + } + isFirstAttempt = false + + logger.debug("📡 Calling connectToEventStream (attempt #\(attemptCount))") + + try await self.connectToEventStream() + + // If connectToEventStream() returns without error, + // it means the stream closed gracefully - reconnect with retry interval + logger.info("🔌 SSE stream closed gracefully, will reconnect", metadata: [ + "attempt": "\(attemptCount)", + "willRetryAfter": "\(self.retryInterval)ms" + ]) } catch { if !Task.isCancelled { - logger.error("SSE connection error: \(error)") - // Wait before retrying - try? await Task.sleep(for: .seconds(1)) + logger.error("❌ SSE connection error (attempt #\(attemptCount)): \(error)") + // Error case - will also use retry interval on next iteration + } else { + logger.debug("⏹️ SSE task cancelled") } } + + logger.debug("🔄 End of retry loop iteration", metadata: [ + "isConnected": "\(isConnected)", + "isCancelled": "\(Task.isCancelled)", + "willContinue": "\(isConnected && !Task.isCancelled)" + ]) } + + logger.debug("⏹️ SSE retry loop exited", metadata: [ + "isConnected": "\(isConnected)", + "isCancelled": "\(Task.isCancelled)", + "totalAttempts": "\(attemptCount)" + ]) #endif } @@ -481,19 +538,38 @@ public actor HTTPClientTransport: Transport { /// /// This initiates a GET request to the server endpoint with appropriate /// headers to establish an SSE stream according to the MCP specification. + /// Supports stream resumability via Last-Event-ID header. /// /// - Throws: MCPError for connection failures or server errors private func connectToEventStream() async throws { - guard isConnected else { return } + guard isConnected else { + logger.debug("⚠️ Skipping connectToEventStream - transport not connected") + return + } + + logger.debug("🔌 Preparing SSE connection request") var request = URLRequest(url: endpoint) request.httpMethod = "GET" request.addValue("text/event-stream", forHTTPHeaderField: "Accept") request.addValue("no-cache", forHTTPHeaderField: "Cache-Control") + // Add protocol version header (required by MCP specification 2025-11-25) + if let protocolVersion = protocolVersion { + request.addValue(protocolVersion, forHTTPHeaderField: "MCP-Protocol-Version") + } + // Add session ID if available if let sessionID = sessionID { - request.addValue(sessionID, forHTTPHeaderField: "Mcp-Session-Id") + request.addValue(sessionID, forHTTPHeaderField: "MCP-Session-Id") + } + + // Add last event ID for resumability (if available) + if let lastEventID = lastEventID { + request.addValue(lastEventID, forHTTPHeaderField: "Last-Event-ID") + logger.info("→ Resuming SSE stream with Last-Event-ID", metadata: ["lastEventID": "\(lastEventID)"]) + } else { + logger.info("→ Connecting to SSE stream (no last event ID to resume from)") } // Apply request modifier @@ -503,6 +579,7 @@ public actor HTTPClientTransport: Transport { // Create URLSession task for SSE let (stream, response) = try await session.bytes(for: request) + self.activeGETSessionTask = stream.task guard let httpResponse = response as? HTTPURLResponse else { throw MCPError.internalError("Invalid HTTP response") @@ -520,7 +597,7 @@ public actor HTTPClientTransport: Transport { } // Extract session ID if present - if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") { + if let newSessionID = httpResponse.value(forHTTPHeaderField: "MCP-Session-Id") { let wasSessionIDNil = (self.sessionID == nil) self.sessionID = newSessionID if wasSessionIDNil { @@ -531,36 +608,62 @@ public actor HTTPClientTransport: Transport { logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"]) } + defer { self.activeGETSessionTask = nil } try await self.processSSE(stream) } - /// Processes an SSE byte stream, extracting events and delivering them + /// Processes an SSE byte stream, extracting events and delivering them. + /// + /// This method processes Server-Sent Events according to the MCP specification, + /// including support for event IDs for resumability. /// /// - Parameter stream: The URLSession.AsyncBytes stream to process + /// - Returns: `true` if any data events were received, `false` otherwise. /// - Throws: Error for stream processing failures - private func processSSE(_ stream: URLSession.AsyncBytes) async throws { - do { - for try await event in stream.events { - // Check if task has been cancelled - if Task.isCancelled { break } - - logger.trace( - "SSE event received", - metadata: [ - "type": "\(event.event ?? "message")", - "id": "\(event.id ?? "none")", - ] - ) - - // Convert the event data to Data and yield it to the message stream - if !event.data.isEmpty, let data = event.data.data(using: .utf8) { - messageContinuation.yield(data) - } + @discardableResult + private func processSSE(_ stream: URLSession.AsyncBytes) async throws -> Bool { + logger.debug("📥 Starting SSE event processing") + var eventCount = 0 + var hadDataEvent = false + + for try await event in stream.events { + eventCount += 1 + + // Check if task has been cancelled + if Task.isCancelled { + logger.debug("⏹️ SSE processing cancelled", metadata: ["eventsProcessed": "\(eventCount)"]) + break + } + + logger.trace( + "SSE event received", + metadata: [ + "type": "\(event.event ?? "message")", + "id": "\(event.id ?? "none")", + ] + ) + + // Store event ID for resumability support + if let eventID = event.id, !eventID.isEmpty { + self.lastEventID = eventID + logger.debug("Stored event ID for resumability", metadata: ["eventID": "\(eventID)"]) + } + + // Store retry interval if provided by server + if let retry = event.retry { + self.retryInterval = retry + logger.debug("SSE retry interval updated", metadata: ["retryMs": "\(retry)"]) + } + + // Convert the event data to Data and yield it to the message stream + if !event.data.isEmpty, let data = event.data.data(using: .utf8) { + hadDataEvent = true + messageContinuation.yield(data) } - } catch { - logger.error("Error processing SSE events: \(error)") - throw error } + + logger.debug("✓ SSE event stream completed", metadata: ["eventsProcessed": "\(eventCount)", "hadData": "\(hadDataEvent)"]) + return hadDataEvent } #endif } diff --git a/Sources/MCP/Base/Transports/HTTPServer/HTTPRequestValidation.swift b/Sources/MCP/Base/Transports/HTTPServer/HTTPRequestValidation.swift new file mode 100644 index 00000000..805a195e --- /dev/null +++ b/Sources/MCP/Base/Transports/HTTPServer/HTTPRequestValidation.swift @@ -0,0 +1,371 @@ +import Foundation + +// MARK: - Validation Protocol + +/// Validates an incoming HTTP request before the transport processes it. +/// +/// Validators are composed into a pipeline and executed in order. The first validator +/// that returns a non-nil response short-circuits the pipeline and that error response +/// is returned to the client. +/// +/// Conform to this protocol to add custom validation (e.g., authentication): +/// ```swift +/// struct BearerTokenValidator: HTTPRequestValidator { +/// func validate(_ request: HTTPRequest, context: HTTPValidationContext) -> HTTPResponse? { +/// guard let auth = request.header("Authorization"), +/// auth.hasPrefix("Bearer ") else { +/// return .error(statusCode: 401, .invalidRequest("Missing bearer token")) +/// } +/// return nil +/// } +/// } +/// ``` +public protocol HTTPRequestValidator: Sendable { + /// Validates the request. Returns an error response if invalid, or `nil` if valid. + func validate(_ request: HTTPRequest, context: HTTPValidationContext) -> HTTPResponse? +} + +// MARK: - Validation Context + +/// Context provided to validators for making validation decisions. +public struct HTTPValidationContext: Sendable { + /// The HTTP method of the request (GET, POST, DELETE). + public let httpMethod: String + + /// The current session ID, if any (nil in stateless mode or before initialization). + public let sessionID: String? + + /// Whether the request body contains an `initialize` JSON-RPC request. + public let isInitializationRequest: Bool + + /// The set of protocol versions this server supports. + public let supportedProtocolVersions: Set + + public init( + httpMethod: String, + sessionID: String? = nil, + isInitializationRequest: Bool = false, + supportedProtocolVersions: Set = Version.supported + ) { + self.httpMethod = httpMethod + self.sessionID = sessionID + self.isInitializationRequest = isInitializationRequest + self.supportedProtocolVersions = supportedProtocolVersions + } +} + +// MARK: - Accept Header Validator + +/// Validates the `Accept` header based on the HTTP method and transport response mode. +/// +/// - Stateful (SSE) mode: POST requests must accept both `application/json` and `text/event-stream` +/// - Stateless (JSON) mode: POST requests only need to accept `application/json` +/// - GET requests always require `text/event-stream` +public struct AcceptHeaderValidator: HTTPRequestValidator { + /// The response mode determines which content types are required. + public enum Mode: Sendable { + /// POST requires both `application/json` and `text/event-stream`. + case sseRequired + /// POST only requires `application/json`. + case jsonOnly + } + + public let mode: Mode + + public init(mode: Mode) { + self.mode = mode + } + + public func validate(_ request: HTTPRequest, context: HTTPValidationContext) -> HTTPResponse? { + let accept = request.header(HTTPHeaderName.accept) ?? "" + let acceptTypes = accept.split(separator: ",").map { + $0.trimmingCharacters(in: .whitespaces) + } + + let hasJSON = acceptTypes.contains { $0.hasPrefix(ContentType.json) } + let hasSSE = acceptTypes.contains { $0.hasPrefix(ContentType.sse) } + + switch context.httpMethod { + case "POST": + switch mode { + case .sseRequired: + guard hasJSON, hasSSE else { + return .error( + statusCode: 406, + .invalidRequest( + "Not Acceptable: Client must accept both application/json and text/event-stream" + ), + sessionID: context.sessionID + ) + } + case .jsonOnly: + guard hasJSON else { + return .error( + statusCode: 406, + .invalidRequest( + "Not Acceptable: Client must accept application/json" + ), + sessionID: context.sessionID + ) + } + } + case "GET": + guard hasSSE else { + return .error( + statusCode: 406, + .invalidRequest( + "Not Acceptable: Client must accept text/event-stream" + ), + sessionID: context.sessionID + ) + } + default: + break + } + + return nil + } +} + +// MARK: - Content-Type Validator + +/// Validates that POST requests have `Content-Type: application/json`. +public struct ContentTypeValidator: HTTPRequestValidator { + public init() {} + + public func validate(_ request: HTTPRequest, context: HTTPValidationContext) -> HTTPResponse? { + guard context.httpMethod == "POST" else { return nil } + + let contentType = request.header(HTTPHeaderName.contentType) ?? "" + let mainType = contentType.split(separator: ";").first? + .trimmingCharacters(in: .whitespaces) ?? "" + + guard mainType == ContentType.json else { + return .error( + statusCode: 415, + .invalidRequest( + "Unsupported Media Type: Content-Type must be application/json" + ), + sessionID: context.sessionID + ) + } + + return nil + } +} + +// MARK: - Protocol Version Validator + +/// Validates the `MCP-Protocol-Version` header against supported versions. +/// +/// Per spec: +/// - If the header is absent, the server assumes the default negotiated version +/// - If the header is present but unsupported, the server returns 400 Bad Request +/// - Initialization requests are exempt (protocol version comes from the request body) +public struct ProtocolVersionValidator: HTTPRequestValidator { + public init() {} + + public func validate(_ request: HTTPRequest, context: HTTPValidationContext) -> HTTPResponse? { + // Skip for initialization requests (version is in the body, not the header) + guard !context.isInitializationRequest else { return nil } + + // Skip for non-POST methods (GET/DELETE don't carry protocol version) + // Actually, per spec, all subsequent requests should include it + guard let version = request.header(HTTPHeaderName.protocolVersion) else { + // Per spec: if not received, assume default version + return nil + } + + guard context.supportedProtocolVersions.contains(version) else { + let supported = context.supportedProtocolVersions.sorted().joined(separator: ", ") + return .error( + statusCode: 400, + .invalidRequest( + "Bad Request: Unsupported protocol version: \(version). Supported: \(supported)" + ), + sessionID: context.sessionID + ) + } + + return nil + } +} + +// MARK: - Session Validator + +/// Validates the `Mcp-Session-Id` header for stateful transports. +/// +/// - Initialization requests are exempt (no session exists yet) +/// - Non-initialization requests must include the session ID header +/// - The session ID must match the active session +public struct SessionValidator: HTTPRequestValidator { + public init() {} + + public func validate(_ request: HTTPRequest, context: HTTPValidationContext) -> HTTPResponse? { + // Skip validation for initialization requests + guard !context.isInitializationRequest else { return nil } + + // If no session exists yet, skip (server hasn't been initialized) + guard let expectedSessionID = context.sessionID else { return nil } + + let requestSessionID = request.header(HTTPHeaderName.sessionID) + + guard let requestSessionID else { + return .error( + statusCode: 400, + .invalidRequest("Bad Request: Missing \(HTTPHeaderName.sessionID) header"), + sessionID: expectedSessionID + ) + } + + guard requestSessionID == expectedSessionID else { + return .error( + statusCode: 404, + .invalidRequest("Not Found: Invalid or expired session ID"), + sessionID: expectedSessionID + ) + } + + return nil + } +} + +// MARK: - Origin Validator + +/// DNS rebinding protection: validates `Origin` and `Host` headers. +/// +/// Per spec, servers MUST validate the Origin header to prevent DNS rebinding attacks. +/// This is particularly important for servers running on localhost. +/// +/// Use `.localhost()` for local development servers. +/// Use `.disabled` to skip validation (e.g., cloud deployments). +/// Use `init(allowedHosts:allowedOrigins:)` for custom configurations. +public struct OriginValidator: HTTPRequestValidator { + public let allowedHosts: [String] + public let allowedOrigins: [String] + private let enabled: Bool + + public init(allowedHosts: [String], allowedOrigins: [String]) { + self.allowedHosts = allowedHosts + self.allowedOrigins = allowedOrigins + self.enabled = true + } + + private init(disabled: Void) { + self.allowedHosts = [] + self.allowedOrigins = [] + self.enabled = false + } + + /// Protection for localhost-bound servers. + /// Allows requests from `localhost`, `127.0.0.1`, and `[::1]` with the specified port. + public static func localhost(port: Int? = nil) -> OriginValidator { + let portPattern = port.map { String($0) } ?? "*" + return OriginValidator( + allowedHosts: [ + "127.0.0.1:\(portPattern)", + "localhost:\(portPattern)", + "[::1]:\(portPattern)", + ], + allowedOrigins: [ + "http://127.0.0.1:\(portPattern)", + "http://localhost:\(portPattern)", + "http://[::1]:\(portPattern)", + ] + ) + } + + /// Disables DNS rebinding protection. + /// Use for cloud deployments where DNS rebinding is not a threat. + public static var disabled: OriginValidator { + OriginValidator(disabled: ()) + } + + public func validate(_ request: HTTPRequest, context: HTTPValidationContext) -> HTTPResponse? { + guard enabled else { return nil } + + // Validate Host header + if let host = request.header(HTTPHeaderName.host) { + let hostAllowed = allowedHosts.contains { pattern in + matchesPattern(host, pattern: pattern) + } + if !hostAllowed { + return .error( + statusCode: 421, + .invalidRequest("Misdirected Request: Host header not allowed"), + sessionID: context.sessionID + ) + } + } + + // Validate Origin header (only if present — non-browser clients won't send it) + if let origin = request.header(HTTPHeaderName.origin) { + let originAllowed = allowedOrigins.contains { pattern in + matchesPattern(origin, pattern: pattern) + } + if !originAllowed { + return .error( + statusCode: 403, + .invalidRequest("Forbidden: Origin not allowed"), + sessionID: context.sessionID + ) + } + } + + return nil + } + + /// Matches a value against a pattern that may contain a port wildcard `:*`. + /// + /// Examples: + /// - `"localhost:*"` matches `"localhost:8080"`, `"localhost:3000"` + /// - `"http://localhost:*"` matches `"http://localhost:8080"` + /// - `"localhost:8080"` matches only `"localhost:8080"` exactly + private func matchesPattern(_ value: String, pattern: String) -> Bool { + guard pattern.hasSuffix(":*") else { + return value == pattern + } + + let prefix = String(pattern.dropLast(2)) + guard value.hasPrefix(prefix + ":") else { return false } + + let portPart = value.dropFirst(prefix.count + 1) + return !portPart.isEmpty && portPart.allSatisfy(\.isNumber) + } +} + +// MARK: - Validation Pipeline Protocol + +/// Runs a validation pipeline against an HTTP request. +/// +/// Implementations execute a sequence of validators and return the first error, +/// or `nil` if all validations pass. +public protocol HTTPRequestValidationPipeline: Sendable { + /// Validates the request using the configured pipeline. + /// Returns an error response if validation fails, or `nil` if the request is valid. + func validate(_ request: HTTPRequest, context: HTTPValidationContext) -> HTTPResponse? +} + +// MARK: - Standard Validation Pipeline + +/// Standard implementation of `HTTPRequestValidationPipeline` that runs validators in sequence. +/// +/// The first validator that returns a non-nil error response short-circuits the pipeline. +public struct StandardValidationPipeline: HTTPRequestValidationPipeline { + private let validators: [any HTTPRequestValidator] + + /// Creates a pipeline with the given validators. + /// Validators are executed in the order provided. + public init(validators: [any HTTPRequestValidator]) { + self.validators = validators + } + + public func validate(_ request: HTTPRequest, context: HTTPValidationContext) -> HTTPResponse? { + for validator in validators { + if let errorResponse = validator.validate(request, context: context) { + return errorResponse + } + } + return nil + } +} diff --git a/Sources/MCP/Base/Transports/HTTPServer/HTTPServerTypes.swift b/Sources/MCP/Base/Transports/HTTPServer/HTTPServerTypes.swift new file mode 100644 index 00000000..5ec82dc9 --- /dev/null +++ b/Sources/MCP/Base/Transports/HTTPServer/HTTPServerTypes.swift @@ -0,0 +1,259 @@ +import Foundation + +// MARK: - Session ID Generator + +/// Generates unique session identifiers for stateful HTTP server transports. +/// +/// Conform to this protocol to provide custom session ID generation logic. +/// Session IDs **MUST** contain only visible ASCII characters (0x21–0x7E) +/// per the MCP specification. +/// +/// A default implementation using UUID is provided via ``UUIDSessionIDGenerator``. +public protocol SessionIDGenerator: Sendable { + /// Generates a new unique session identifier. + func generateSessionID() -> String +} + +/// Default session ID generator that produces UUID strings. +/// +/// UUID strings consist of hexadecimal characters and hyphens, +/// which are all within the valid ASCII range (0x21–0x7E). +public struct UUIDSessionIDGenerator: SessionIDGenerator { + public init() {} + + public func generateSessionID() -> String { + UUID().uuidString + } +} + +// MARK: - HTTP Request + +/// A framework-agnostic HTTP request representation. +/// +/// This type decouples the transport from any specific HTTP framework. +/// The HTTP framework adapter converts its native request type into this before passing to the transport. +public struct HTTPRequest: Sendable { + /// The HTTP method (e.g., "GET", "POST", "DELETE"). + public let method: String + + /// HTTP headers as key-value pairs. + public let headers: [String: String] + + /// The request body data, if any. + public let body: Data? + + public init(method: String, headers: [String: String] = [:], body: Data? = nil) { + self.method = method + self.headers = headers + self.body = body + } + + /// Case-insensitive header lookup. + public func header(_ name: String) -> String? { + let lowercased = name.lowercased() + return headers.first { $0.key.lowercased() == lowercased }?.value + } +} + +// MARK: - HTTP Response + +/// A framework-agnostic HTTP response. +/// +/// The HTTP framework adapter converts this into its native response type. +/// +/// Use computed properties (`statusCode`, `headers`, `bodyData`) for generic access, +/// or switch on the enum for case-specific handling (e.g., streaming): +/// +/// ```swift +/// let response = await transport.handleRequest(request) +/// switch response { +/// case .stream(let sseStream, _): +/// // Pipe the async stream to the HTTP response body +/// default: +/// // Use response.bodyData for the body +/// } +/// ``` +public enum HTTPResponse: Sendable { + /// 202 Accepted, no body. Used for notifications and client responses. + case accepted(headers: [String: String] = [:]) + + /// 200 OK, no body. Used for DELETE confirmation. + case ok(headers: [String: String] = [:]) + + /// 200 OK with data body (typically JSON). + case data(Data, headers: [String: String] = [:]) + + /// 200 OK with SSE streaming body. + case stream(AsyncThrowingStream, headers: [String: String] = [:]) + + /// Error response with a JSON-RPC error body. + /// The status code, headers, and body are derived automatically. + case error(statusCode: Int, MCPError, sessionID: String? = nil) + + // MARK: - Computed Properties + + public var statusCode: Int { + switch self { + case .accepted: 202 + case .ok, .data, .stream: 200 + case .error(let code, _, _): code + } + } + + public var headers: [String: String] { + switch self { + case .accepted(let headers), .ok(let headers), .data(_, let headers), .stream(_, let headers): + return headers + case .error(_, _, let sessionID): + var headers: [String: String] = [HTTPHeaderName.contentType: ContentType.json] + if let sessionID { headers[HTTPHeaderName.sessionID] = sessionID } + return headers + } + } + + /// The response body as data. `nil` for `.accepted`, `.ok`, and `.stream`. + public var bodyData: Data? { + switch self { + case .accepted, .ok, .stream: + return nil + case .data(let data, _): + return data + case .error(_, let error, _): + let errorBody: [String: Any] = [ + "jsonrpc": "2.0", + "error": [ + "code": error.code, + "message": error.errorDescription ?? "Unknown error", + ], + "id": NSNull(), + ] + return try? JSONSerialization.data(withJSONObject: errorBody) + } + } +} + +// MARK: - HTTP Header Names + +/// Standard header names used by the MCP Streamable HTTP transport. +public enum HTTPHeaderName { + public static let sessionID = "Mcp-Session-Id" + public static let protocolVersion = "Mcp-Protocol-Version" + public static let lastEventID = "Last-Event-Id" + public static let accept = "Accept" + public static let contentType = "Content-Type" + public static let origin = "Origin" + public static let host = "Host" + public static let cacheControl = "Cache-Control" + public static let connection = "Connection" + public static let allow = "Allow" +} + +// MARK: - Content Types + +enum ContentType { + static let json = "application/json" + static let sse = "text/event-stream" +} + +// MARK: - SSE Event + +/// A Server-Sent Event (SSE) data structure. +/// +/// Formats according to the SSE specification: +/// https://html.spec.whatwg.org/multipage/server-sent-events.html +struct SSEEvent: Sendable { + var id: String? + var event: String? + var data: String + var retry: Int? + + /// Formats the event as SSE wire data. + func formatted() -> Data { + var result = "" + if let id { + result += "id: \(id)\n" + } + if let event { + result += "event: \(event)\n" + } + if let retry { + result += "retry: \(retry)\n" + } + result += "data: \(data)\n\n" + return Data(result.utf8) + } + + /// Creates a priming event with an empty data field. + /// Per spec, this is sent immediately to prime the client for reconnection. + static func priming(id: String, retry: Int? = nil) -> SSEEvent { + SSEEvent(id: id, event: nil, data: "", retry: retry) + } + + /// Creates a message event wrapping JSON-RPC data. + static func message(data: Data, id: String? = nil) -> SSEEvent { + SSEEvent( + id: id, + event: "message", + data: String(decoding: data, as: UTF8.self) + ) + } +} + +// MARK: - JSON-RPC Message Classification + +/// Classifies a raw JSON-RPC message for routing purposes. +/// +/// Used by transports to determine where to route outgoing messages: +/// - Responses are routed to the originating request's stream +/// - Notifications and server requests are routed to the standalone GET stream +package enum JSONRPCMessageKind { + case request(id: String, method: String) + case notification(method: String) + case response(id: String) + + /// Attempts to classify raw JSON-RPC data. + /// Returns `nil` if the data cannot be parsed or classified. + package init?(data: Data) { + guard let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] else { + return nil + } + + let id = Self.extractID(from: json) + + if let method = json["method"] as? String { + if let id { + self = .request(id: id, method: method) + } else { + self = .notification(method: method) + } + } else if json["result"] != nil || json["error"] != nil { + guard let id else { return nil } + self = .response(id: id) + } else { + return nil + } + } + + /// Whether this message is a JSON-RPC response (success or error). + var isResponse: Bool { + if case .response = self { return true } + return false + } + + /// Whether this message is an `initialize` request. + package var isInitializeRequest: Bool { + if case .request(_, let method) = self { + return method == "initialize" + } + return false + } + + private static func extractID(from json: [String: Any]) -> String? { + if let stringID = json["id"] as? String { + return stringID + } else if let intID = json["id"] as? Int { + return String(intID) + } + return nil + } +} diff --git a/Sources/MCP/Base/Transports/HTTPServer/StatefulHTTPServerTransport.swift b/Sources/MCP/Base/Transports/HTTPServer/StatefulHTTPServerTransport.swift new file mode 100644 index 00000000..aba32061 --- /dev/null +++ b/Sources/MCP/Base/Transports/HTTPServer/StatefulHTTPServerTransport.swift @@ -0,0 +1,536 @@ +import Foundation +import Logging + +/// A stateful HTTP server transport that manages sessions and uses SSE for streaming responses. +/// +/// This transport implements the MCP Streamable HTTP specification with full session management: +/// - Assigns a session ID during initialization (via `Mcp-Session-Id` header) +/// - POST requests receive SSE-streamed responses +/// - GET requests establish a standalone SSE stream for server-initiated messages +/// - DELETE requests terminate the session +/// - Built-in event store for resumability (reconnection with `Last-Event-ID`) +/// +/// ## Usage +/// +/// ```swift +/// let transport = StatefulHTTPServerTransport() // Uses UUID by default +/// +/// // Start the MCP server with this transport +/// try await server.start(transport: transport) +/// +/// // In your HTTP framework handler: +/// let response = await transport.handleRequest(httpRequest) +/// // Convert response to your framework's response type and return it +/// ``` +/// +/// ## Framework Integration +/// +/// This transport is framework-agnostic. You provide incoming requests as `HTTPRequest` +/// and receive `HTTPResponse` values to convert to your framework's native types. +/// For SSE responses, the `.stream` case provides an `AsyncThrowingStream` +/// to pipe to the client. +public actor StatefulHTTPServerTransport: Transport { + public nonisolated let logger: Logger + + // MARK: - Dependencies + + private let sessionIDGenerator: any SessionIDGenerator + private let validationPipeline: any HTTPRequestValidationPipeline + private let retryInterval: Int? + + // MARK: - State + + private var sessionID: String? + private var terminated = false + private var started = false + + // MARK: - Incoming message stream (client → server) + + private let incomingStream: AsyncThrowingStream + private let incomingContinuation: AsyncThrowingStream.Continuation + + // MARK: - SSE streams for POST request responses + + /// Maps request ID → SSE stream continuation for active POST request streams. + private var requestSSEContinuations: [String: AsyncThrowingStream.Continuation] = [:] + + // MARK: - Standalone GET SSE stream + + /// The standalone SSE stream continuation for server-initiated messages. + /// Only one GET stream is allowed per session. + private var standaloneSSEContinuation: AsyncThrowingStream.Continuation? + + /// Internal identifier for the standalone GET stream in the event store. + private let standaloneStreamID = "_GET_stream" + + // MARK: - Event Store (Resumability) + + private struct StoredEvent { + let streamID: String + let eventID: String + let message: Data? + } + + private var storedEvents: [StoredEvent] = [] + private var eventCounter: Int = 0 + + // MARK: - Init + + /// Creates a new stateful HTTP server transport. + /// + /// - Parameters: + /// - sessionIDGenerator: Generator for session IDs. The IDs MUST contain + /// only visible ASCII characters (0x21-0x7E) per the MCP specification. + /// Defaults to ``UUIDSessionIDGenerator``. + /// - validationPipeline: Custom validation pipeline. If `nil`, uses sensible defaults: + /// origin validation (localhost), Accept header (SSE required), Content-Type, + /// protocol version, and session validation. + /// - retryInterval: Retry interval in milliseconds for SSE priming events. + /// Controls how long clients wait before attempting to reconnect. + /// - logger: Optional logger. If `nil`, a no-op logger is used. + public init( + sessionIDGenerator: any SessionIDGenerator = UUIDSessionIDGenerator(), + validationPipeline: (any HTTPRequestValidationPipeline)? = nil, + retryInterval: Int? = nil, + logger: Logger? = nil + ) { + self.sessionIDGenerator = sessionIDGenerator + self.validationPipeline = validationPipeline ?? StandardValidationPipeline(validators: [ + OriginValidator.localhost(), + AcceptHeaderValidator(mode: .sseRequired), + ContentTypeValidator(), + ProtocolVersionValidator(), + SessionValidator(), + ]) + self.retryInterval = retryInterval + self.logger = logger ?? Logger( + label: "mcp.transport.http.server.stateful", + factory: { _ in SwiftLogNoOpLogHandler() } + ) + + let (stream, continuation) = AsyncThrowingStream.makeStream() + self.incomingStream = stream + self.incomingContinuation = continuation + } + + // MARK: - Transport Conformance + + public func connect() async throws { + guard !started else { + throw MCPError.internalError("Transport already started") + } + started = true + logger.debug("Stateful HTTP server transport started") + } + + public func disconnect() async { + terminate() + } + + /// Routes outgoing server messages to the appropriate client connection. + /// + /// - Responses are routed to the SSE stream matching the response's JSON-RPC ID. + /// - Notifications and server-initiated requests are routed to the standalone GET stream. + public func send(_ data: Data) async throws { + guard !terminated else { + throw MCPError.connectionClosed + } + + guard let kind = JSONRPCMessageKind(data: data) else { + logger.warning("Could not classify outgoing message for routing") + return + } + + switch kind { + case .response(let id): + routeResponse(data, requestID: id) + case .notification, .request: + routeServerInitiatedMessage(data) + } + } + + public func receive() -> AsyncThrowingStream { + incomingStream + } + + // MARK: - HTTP Request Handler + + /// Handles an incoming HTTP request from the framework adapter. + /// + /// Routes by HTTP method: + /// - **POST**: JSON-RPC messages (requests, notifications) + /// - **GET**: Establish standalone SSE stream for server-initiated messages + /// - **DELETE**: Terminate the session + /// - Others: 405 Method Not Allowed + public func handleRequest(_ request: HTTPRequest) async -> HTTPResponse { + if terminated { + return .error( + statusCode: 404, + .invalidRequest("Not Found: Session has been terminated"), + sessionID: sessionID + ) + } + + switch request.method.uppercased() { + case "POST": + return handlePost(request) + case "GET": + return handleGet(request) + case "DELETE": + return handleDelete(request) + default: + return .error( + statusCode: 405, + .invalidRequest("Method Not Allowed"), + sessionID: sessionID + ) + } + } + + // MARK: - POST Handler + + private func handlePost(_ request: HTTPRequest) -> HTTPResponse { + // Parse body first so we can determine if it's an initialization request + guard let body = request.body, !body.isEmpty else { + return .error( + statusCode: 400, + .parseError("Empty request body"), + sessionID: sessionID + ) + } + + guard let messageKind = JSONRPCMessageKind(data: body) else { + return .error( + statusCode: 400, + .parseError("Invalid JSON-RPC message"), + sessionID: sessionID + ) + } + + // Build validation context + let context = HTTPValidationContext( + httpMethod: "POST", + sessionID: sessionID, + isInitializationRequest: messageKind.isInitializeRequest, + supportedProtocolVersions: Version.supported + ) + + // Run validation pipeline + if let errorResponse = validationPipeline.validate(request, context: context) { + return errorResponse + } + + // Handle initialization request specially + if messageKind.isInitializeRequest { + return handleInitializationRequest(body, request: request) + } + + // Handle by message type + switch messageKind { + case .notification, .response: + // Yield to server and return 202 Accepted + incomingContinuation.yield(body) + return .accepted(headers: sessionHeaders()) + + case .request(let id, _): + return handleJSONRPCRequest(body, requestID: id, request: request) + } + } + + private func handleInitializationRequest(_ body: Data, request: HTTPRequest) -> HTTPResponse { + // Generate session ID + let newSessionID = sessionIDGenerator.generateSessionID() + + // Validate session ID contains only visible ASCII (0x21-0x7E) + guard isValidSessionID(newSessionID) else { + logger.error("Generated session ID contains invalid characters") + return .error( + statusCode: 500, + .internalError("Internal error: Invalid session ID generated") + ) + } + + self.sessionID = newSessionID + logger.info("Session initialized", metadata: ["sessionID": "\(newSessionID)"]) + + // Extract request ID for routing the response + guard case .request(let requestID, _) = JSONRPCMessageKind(data: body) else { + return .error( + statusCode: 400, + .parseError("Invalid initialize request"), + sessionID: newSessionID + ) + } + + // For the initialize request, use SSE streaming like any other request + return handleJSONRPCRequest(body, requestID: requestID, request: request) + } + + private func handleJSONRPCRequest(_ body: Data, requestID: String, request: HTTPRequest) -> HTTPResponse { + // Create SSE stream for this request + let (sseStream, sseContinuation) = AsyncThrowingStream.makeStream() + requestSSEContinuations[requestID] = sseContinuation + + // Extract protocol version for priming event decision + let protocolVersion = extractProtocolVersion(from: body, request: request) + + // Send priming event for resumability + sendPrimingEvent( + streamID: requestID, + continuation: sseContinuation, + protocolVersion: protocolVersion + ) + + // Yield the incoming message to the server + incomingContinuation.yield(body) + + // Build response headers + var headers = sessionHeaders() + headers[HTTPHeaderName.contentType] = ContentType.sse + headers[HTTPHeaderName.cacheControl] = "no-cache, no-transform" + headers[HTTPHeaderName.connection] = "keep-alive" + + return .stream(sseStream, headers: headers) + } + + // MARK: - GET Handler + + private func handleGet(_ request: HTTPRequest) -> HTTPResponse { + // Build validation context (GET is never an initialization request) + let context = HTTPValidationContext( + httpMethod: "GET", + sessionID: sessionID, + isInitializationRequest: false, + supportedProtocolVersions: Version.supported + ) + + // Run validation pipeline + if let errorResponse = validationPipeline.validate(request, context: context) { + return errorResponse + } + + // Handle resumability: check for Last-Event-ID header + if let lastEventID = request.header(HTTPHeaderName.lastEventID) { + return handleResumeRequest(lastEventID: lastEventID, request: request) + } + + // Only one standalone GET stream per session + guard standaloneSSEContinuation == nil else { + return .error( + statusCode: 409, + .invalidRequest("Conflict: Only one SSE stream is allowed per session"), + sessionID: sessionID + ) + } + + // Create standalone SSE stream + let (sseStream, sseContinuation) = AsyncThrowingStream.makeStream() + standaloneSSEContinuation = sseContinuation + + // Extract protocol version for priming event + let protocolVersion = request.header(HTTPHeaderName.protocolVersion) ?? Version.latest + + // Send priming event + sendPrimingEvent( + streamID: standaloneStreamID, + continuation: sseContinuation, + protocolVersion: protocolVersion + ) + + // Build response headers + var headers = sessionHeaders() + headers[HTTPHeaderName.contentType] = ContentType.sse + headers[HTTPHeaderName.cacheControl] = "no-cache, no-transform" + headers[HTTPHeaderName.connection] = "keep-alive" + + return .stream(sseStream, headers: headers) + } + + // MARK: - DELETE Handler + + private func handleDelete(_ request: HTTPRequest) -> HTTPResponse { + // Validate session + let context = HTTPValidationContext( + httpMethod: "DELETE", + sessionID: sessionID, + isInitializationRequest: false, + supportedProtocolVersions: Version.supported + ) + + // Only run session validation for DELETE (not all validators) + let sessionValidator = SessionValidator() + if let errorResponse = sessionValidator.validate(request, context: context) { + return errorResponse + } + + terminate() + + return .ok(headers: sessionHeaders()) + } + + // MARK: - Message Routing + + private func routeResponse(_ data: Data, requestID: String) { + let eventID = storeEvent(streamID: requestID, message: data) + + guard let continuation = requestSSEContinuations[requestID] else { + logger.debug( + "No active stream for request, response stored for replay", + metadata: ["requestID": "\(requestID)"] + ) + return + } + + // Format as SSE and yield + let sseEvent = SSEEvent.message(data: data, id: eventID) + continuation.yield(sseEvent.formatted()) + + // Response means the request is complete — close the stream + continuation.finish() + requestSSEContinuations.removeValue(forKey: requestID) + } + + private func routeServerInitiatedMessage(_ data: Data) { + let eventID = storeEvent(streamID: standaloneStreamID, message: data) + + guard let continuation = standaloneSSEContinuation else { + logger.debug("No standalone GET stream connected, message stored for replay") + return + } + + let sseEvent = SSEEvent.message(data: data, id: eventID) + continuation.yield(sseEvent.formatted()) + } + + // MARK: - Resumability + + private func handleResumeRequest(lastEventID: String, request: HTTPRequest) -> HTTPResponse { + guard let replay = replayEventsAfter(lastEventID) else { + return .error( + statusCode: 400, + .invalidRequest("Invalid Last-Event-ID"), + sessionID: sessionID + ) + } + + let (sseStream, sseContinuation) = AsyncThrowingStream.makeStream() + + // Replay stored events + for (eventID, message) in replay.events { + let sseEvent = SSEEvent.message(data: message, id: eventID) + sseContinuation.yield(sseEvent.formatted()) + } + + // Re-register the stream for future messages + if replay.streamID == standaloneStreamID { + standaloneSSEContinuation = sseContinuation + } else { + requestSSEContinuations[replay.streamID] = sseContinuation + } + + // Send a new priming event so the client can resume again if disconnected + let protocolVersion = request.header(HTTPHeaderName.protocolVersion) ?? Version.latest + sendPrimingEvent( + streamID: replay.streamID, + continuation: sseContinuation, + protocolVersion: protocolVersion + ) + + var headers = sessionHeaders() + headers[HTTPHeaderName.contentType] = ContentType.sse + headers[HTTPHeaderName.cacheControl] = "no-cache, no-transform" + headers[HTTPHeaderName.connection] = "keep-alive" + + return .stream(sseStream, headers: headers) + } + + // MARK: - Internal Event Store + + private func storeEvent(streamID: String, message: Data?) -> String { + eventCounter += 1 + let eventID = "\(streamID)_\(eventCounter)" + storedEvents.append(StoredEvent(streamID: streamID, eventID: eventID, message: message)) + return eventID + } + + private func replayEventsAfter(_ lastEventID: String) -> (streamID: String, events: [(eventID: String, message: Data)])? { + guard let index = storedEvents.firstIndex(where: { $0.eventID == lastEventID }) else { + return nil + } + let streamID = storedEvents[index].streamID + let eventsToReplay = storedEvents[(index + 1)...] + .filter { $0.streamID == streamID && $0.message != nil } + .map { (eventID: $0.eventID, message: $0.message!) } + return (streamID, eventsToReplay) + } + + // MARK: - SSE Helpers + + private func sendPrimingEvent( + streamID: String, + continuation: AsyncThrowingStream.Continuation, + protocolVersion: String + ) { + // Priming events with empty data are only safe for clients >= 2025-11-25 + guard protocolVersion >= "2025-11-25" else { return } + + let primingEventID = storeEvent(streamID: streamID, message: nil) + let primingEvent = SSEEvent.priming(id: primingEventID, retry: retryInterval) + continuation.yield(primingEvent.formatted()) + } + + // MARK: - Session Helpers + + private func sessionHeaders() -> [String: String] { + var headers: [String: String] = [:] + if let sessionID { + headers[HTTPHeaderName.sessionID] = sessionID + } + return headers + } + + private func isValidSessionID(_ id: String) -> Bool { + guard !id.isEmpty else { return false } + return id.utf8.allSatisfy { $0 >= 0x21 && $0 <= 0x7E } + } + + private func extractProtocolVersion(from body: Data, request: HTTPRequest) -> String { + // For initialize requests, extract from the request body params + if let json = try? JSONSerialization.jsonObject(with: body) as? [String: Any], + let method = json["method"] as? String, method == "initialize", + let params = json["params"] as? [String: Any], + let version = params["protocolVersion"] as? String + { + return version + } + // For other requests, use the header + return request.header(HTTPHeaderName.protocolVersion) ?? Version.latest + } + + // MARK: - Termination + + /// Terminates the session, closing all active streams. + /// After termination, all requests receive 404 Not Found. + private func terminate() { + guard !terminated else { return } + terminated = true + + logger.info("Terminating session", metadata: ["sessionID": "\(sessionID ?? "none")"]) + + // Close all request SSE streams + for (_, continuation) in requestSSEContinuations { + continuation.finish() + } + requestSSEContinuations.removeAll() + + // Close standalone GET stream + standaloneSSEContinuation?.finish() + standaloneSSEContinuation = nil + + // Clear stored events + storedEvents.removeAll() + + // Close incoming stream + incomingContinuation.finish() + } +} diff --git a/Sources/MCP/Base/Transports/HTTPServer/StatelessHTTPServerTransport.swift b/Sources/MCP/Base/Transports/HTTPServer/StatelessHTTPServerTransport.swift new file mode 100644 index 00000000..76948414 --- /dev/null +++ b/Sources/MCP/Base/Transports/HTTPServer/StatelessHTTPServerTransport.swift @@ -0,0 +1,251 @@ +import Foundation +import Logging + +/// A stateless HTTP server transport that returns single JSON responses. +/// +/// This transport implements a minimal subset of the MCP Streamable HTTP specification: +/// - No session management (no `Mcp-Session-Id` header) +/// - POST requests receive direct JSON responses (no SSE streaming) +/// - GET and DELETE requests return 405 Method Not Allowed +/// +/// ## Usage +/// +/// ```swift +/// let transport = StatelessHTTPServerTransport() +/// +/// // Start the MCP server with this transport +/// try await server.start(transport: transport) +/// +/// // In your HTTP framework handler: +/// let response = await transport.handleRequest(httpRequest) +/// // Convert response to your framework's response type and return it +/// ``` +/// +/// ## When to Use +/// +/// Use this transport when: +/// - You don't need server-initiated messages (no GET SSE stream) +/// - You want simple request-response semantics +/// - Session management is handled externally or not needed +/// +/// For full streaming and session support, use ``StatefulHTTPServerTransport`` instead. +public actor StatelessHTTPServerTransport: Transport { + public nonisolated let logger: Logger + + // MARK: - Dependencies + + private let validationPipeline: any HTTPRequestValidationPipeline + + // MARK: - State + + private var terminated = false + private var started = false + + // MARK: - Incoming message stream (client → server) + + private let incomingStream: AsyncThrowingStream + private let incomingContinuation: AsyncThrowingStream.Continuation + + // MARK: - Response waiters + + /// Maps request ID → continuation waiting for the server's response. + /// When the server calls `send()` with a response, the matching continuation is resumed. + private var responseWaiters: [String: CheckedContinuation] = [:] + + // MARK: - Init + + /// Creates a new stateless HTTP server transport. + /// + /// - Parameters: + /// - validationPipeline: Custom validation pipeline. If `nil`, uses sensible defaults: + /// origin validation (localhost), Accept header (JSON only), Content-Type, + /// and protocol version validation. + /// - logger: Optional logger. If `nil`, a no-op logger is used. + public init( + validationPipeline: (any HTTPRequestValidationPipeline)? = nil, + logger: Logger? = nil + ) { + self.validationPipeline = validationPipeline ?? StandardValidationPipeline(validators: [ + OriginValidator.localhost(), + AcceptHeaderValidator(mode: .jsonOnly), + ContentTypeValidator(), + ProtocolVersionValidator(), + ]) + self.logger = logger ?? Logger( + label: "mcp.transport.http.server.stateless", + factory: { _ in SwiftLogNoOpLogHandler() } + ) + + let (stream, continuation) = AsyncThrowingStream.makeStream() + self.incomingStream = stream + self.incomingContinuation = continuation + } + + // MARK: - Transport Conformance + + public func connect() async throws { + guard !started else { + throw MCPError.internalError("Transport already started") + } + started = true + logger.debug("Stateless HTTP server transport started") + } + + public func disconnect() async { + await terminate() + } + + /// Routes outgoing server messages to the appropriate waiting HTTP handler. + /// + /// - Responses are matched by JSON-RPC ID and delivered to the waiting `handleRequest` call. + /// - Notifications and server-initiated requests are logged and dropped + /// (no streaming channel available in stateless mode). + public func send(_ data: Data) async throws { + guard !terminated else { + throw MCPError.connectionClosed + } + + guard let kind = JSONRPCMessageKind(data: data) else { + logger.warning("Could not classify outgoing message for routing") + return + } + + switch kind { + case .response(let id): + guard let continuation = responseWaiters.removeValue(forKey: id) else { + logger.debug( + "No waiter for response, may have timed out", + metadata: ["requestID": "\(id)"] + ) + return + } + continuation.resume(returning: data) + + case .notification(let method): + logger.debug( + "Server-initiated notification dropped in stateless mode (no GET SSE stream)", + metadata: ["method": "\(method)"] + ) + + case .request(_, let method): + logger.debug( + "Server-initiated request dropped in stateless mode (no GET SSE stream)", + metadata: ["method": "\(method)"] + ) + } + } + + public func receive() -> AsyncThrowingStream { + incomingStream + } + + // MARK: - HTTP Request Handler + + /// Handles an incoming HTTP request from the framework adapter. + /// + /// Only POST is supported: + /// - **POST**: JSON-RPC messages (requests, notifications) + /// - **GET**: 405 Method Not Allowed + /// - **DELETE**: 405 Method Not Allowed + /// - Others: 405 Method Not Allowed + public func handleRequest(_ request: HTTPRequest) async -> HTTPResponse { + if terminated { + return .error( + statusCode: 404, + .invalidRequest("Not Found: Transport has been terminated") + ) + } + + switch request.method.uppercased() { + case "POST": + return await handlePost(request) + default: + return .error( + statusCode: 405, + .invalidRequest("Method Not Allowed") + ) + } + } + + // MARK: - POST Handler + + private func handlePost(_ request: HTTPRequest) async -> HTTPResponse { + // Parse body first to determine message type + guard let body = request.body, !body.isEmpty else { + return .error( + statusCode: 400, + .parseError("Empty request body") + ) + } + + guard let messageKind = JSONRPCMessageKind(data: body) else { + return .error( + statusCode: 400, + .parseError("Invalid JSON-RPC message") + ) + } + + // Build validation context + let context = HTTPValidationContext( + httpMethod: "POST", + sessionID: nil, + isInitializationRequest: messageKind.isInitializeRequest, + supportedProtocolVersions: Version.supported + ) + + // Run validation pipeline + if let errorResponse = validationPipeline.validate(request, context: context) { + return errorResponse + } + + // Handle by message type + switch messageKind { + case .notification, .response: + // Yield to server and return 202 Accepted + incomingContinuation.yield(body) + return .accepted() + + case .request(let id, _): + return await handleJSONRPCRequest(body, requestID: id) + } + } + + private func handleJSONRPCRequest(_ body: Data, requestID: String) async -> HTTPResponse { + // Yield the incoming message to the server + incomingContinuation.yield(body) + + // Wait for the server to process and send a response + let responseData: Data + do { + responseData = try await withCheckedThrowingContinuation { continuation in + responseWaiters[requestID] = continuation + } + } catch { + return .error( + statusCode: 500, + .internalError("Error processing request: \(error.localizedDescription)") + ) + } + + return .data(responseData, headers: [HTTPHeaderName.contentType: ContentType.json]) + } + + // MARK: - Termination + + private func terminate() async { + guard !terminated else { return } + terminated = true + + logger.debug("Stateless HTTP server transport terminated") + + // Cancel all waiting continuations + for (id, continuation) in responseWaiters { + continuation.resume(throwing: MCPError.connectionClosed) + logger.debug("Cancelled waiter for request", metadata: ["requestID": "\(id)"]) + } + responseWaiters.removeAll() + + // Close incoming stream + incomingContinuation.finish() + } +} diff --git a/Sources/MCP/Base/Versioning.swift b/Sources/MCP/Base/Versioning.swift index 9e29a82f..0bed5c49 100644 --- a/Sources/MCP/Base/Versioning.swift +++ b/Sources/MCP/Base/Versioning.swift @@ -7,7 +7,7 @@ import Foundation /// - SeeAlso: https://modelcontextprotocol.io/specification/2025-11-25/ public enum Version { /// All protocol versions supported by this implementation, ordered from newest to oldest. - static let supported: Set = [ + public static let supported: Set = [ "2025-11-25", "2025-06-18", "2025-03-26", diff --git a/Sources/MCP/Client/Client.swift b/Sources/MCP/Client/Client.swift index a3cf45a7..68487c7f 100644 --- a/Sources/MCP/Client/Client.swift +++ b/Sources/MCP/Client/Client.swift @@ -807,6 +807,87 @@ public actor Client { return self } + // MARK: - Logging + + /// Set the minimum logging level for server log messages. + /// + /// Servers that declare the `logging` capability will send log messages via + /// `notifications/message` notifications. Use this method to control which + /// severity levels the server should send. + /// + /// - Parameter level: The minimum log level to receive + /// - Throws: MCPError if the client is not connected or if the server doesn't support logging + /// - SeeAlso: https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/logging/ + public func setLoggingLevel(_ level: LogLevel) async throws { + try validateServerCapability(\.logging, "Logging") + let request = SetLoggingLevel.request(.init(level: level)) + _ = try await sendAndAwait(request) + } + + // MARK: - Completions + + /// Request completion suggestions for a prompt argument. + /// + /// Servers that declare the `completions` capability can provide autocompletion + /// suggestions for prompt arguments as users type. + /// + /// - Parameters: + /// - promptName: The name of the prompt + /// - argumentName: The name of the argument being completed + /// - argumentValue: The current (partial) value of the argument + /// - context: Optional context with already-resolved arguments + /// - Returns: A completion result containing suggested values + /// - Throws: MCPError if the client is not connected or if the server doesn't support completions + /// - SeeAlso: https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/completion/ + public func complete( + promptName: String, + argumentName: String, + argumentValue: String, + context: [String: Value]? = nil + ) async throws -> Complete.Result.Completion { + try validateServerCapability(\.completions, "Completions") + let request = Complete.request( + .init( + ref: .prompt(.init(name: promptName)), + argument: .init(name: argumentName, value: argumentValue), + context: context.map { .init(arguments: $0) } + ) + ) + let result = try await sendAndAwait(request) + return result.completion + } + + /// Request completion suggestions for a resource template argument. + /// + /// Servers that declare the `completions` capability can provide autocompletion + /// suggestions for resource template arguments as users type. + /// + /// - Parameters: + /// - resourceURI: The URI of the resource template + /// - argumentName: The name of the argument being completed + /// - argumentValue: The current (partial) value of the argument + /// - context: Optional context with already-resolved arguments + /// - Returns: A completion result containing suggested values + /// - Throws: MCPError if the client is not connected or if the server doesn't support completions + /// - SeeAlso: https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/completion/ + public func complete( + resourceURI: String, + argumentName: String, + argumentValue: String, + context: [String: Value]? = nil + ) async throws -> Complete.Result.Completion { + try validateServerCapability(\.completions, "Completions") + let request = Complete.request( + .init( + ref: .resource(.init(uri: resourceURI)), + argument: .init(name: argumentName, value: argumentValue), + context: context.map { .init(arguments: $0) } + ) + ) + let result = try await sendAndAwait(request) + return result.completion + } + // MARK: - private func handleResponse(_ response: Response) async { diff --git a/Sources/MCP/Server/Completion.swift b/Sources/MCP/Server/Completion.swift new file mode 100644 index 00000000..dab83190 --- /dev/null +++ b/Sources/MCP/Server/Completion.swift @@ -0,0 +1,192 @@ +import Foundation + +/// The Model Context Protocol (MCP) provides a standardized way for servers to offer +/// autocompletion suggestions for the arguments of prompts and resource templates. +/// +/// - SeeAlso: https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/completion/ + +// MARK: - Reference Types + +/// A reference to a prompt by name +public struct PromptReference: Hashable, Codable, Sendable { + /// The prompt name + public let name: String + + public init(name: String) { + self.name = name + } + + private enum CodingKeys: String, CodingKey { + case type, name + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode("ref/prompt", forKey: .type) + try container.encode(name, forKey: .name) + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(String.self, forKey: .type) + guard type == "ref/prompt" else { + throw DecodingError.dataCorruptedError( + forKey: .type, + in: container, + debugDescription: "Expected ref/prompt type" + ) + } + name = try container.decode(String.self, forKey: .name) + } +} + +/// A reference to a resource by URI +public struct ResourceReference: Hashable, Codable, Sendable { + /// The resource URI + public let uri: String + + public init(uri: String) { + self.uri = uri + } + + private enum CodingKeys: String, CodingKey { + case type, uri + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode("ref/resource", forKey: .type) + try container.encode(uri, forKey: .uri) + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(String.self, forKey: .type) + guard type == "ref/resource" else { + throw DecodingError.dataCorruptedError( + forKey: .type, + in: container, + debugDescription: "Expected ref/resource type" + ) + } + uri = try container.decode(String.self, forKey: .uri) + } +} + +/// A reference type for completion requests (either prompt or resource) +public enum CompletionReference: Hashable, Codable, Sendable { + /// References a prompt by name + case prompt(PromptReference) + /// References a resource URI + case resource(ResourceReference) + + private enum CodingKeys: String, CodingKey { + case type + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(String.self, forKey: .type) + + switch type { + case "ref/prompt": + self = .prompt(try PromptReference(from: decoder)) + case "ref/resource": + self = .resource(try ResourceReference(from: decoder)) + default: + throw DecodingError.dataCorruptedError( + forKey: .type, + in: container, + debugDescription: "Unknown reference type: \(type)" + ) + } + } + + public func encode(to encoder: Encoder) throws { + switch self { + case .prompt(let ref): + try ref.encode(to: encoder) + case .resource(let ref): + try ref.encode(to: encoder) + } + } +} + +// MARK: - Completion Request + +/// To get completion suggestions, clients send a `completion/complete` request. +/// - SeeAlso: https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/completion/ +public enum Complete: Method { + public static let name = "completion/complete" + + public struct Parameters: Hashable, Codable, Sendable { + /// The reference to what is being completed + public let ref: CompletionReference + /// The argument being completed + public let argument: Argument + /// Optional context with already-resolved arguments + public let context: Context? + + public init( + ref: CompletionReference, + argument: Argument, + context: Context? = nil + ) { + self.ref = ref + self.argument = argument + self.context = context + } + + /// The argument being completed + public struct Argument: Hashable, Codable, Sendable { + /// The argument name + public let name: String + /// The current value (partial or complete) + public let value: String + + public init(name: String, value: String) { + self.name = name + self.value = value + } + } + + /// Context containing already-resolved arguments + public struct Context: Hashable, Codable, Sendable { + /// A mapping of already-resolved argument names to their values + public let arguments: [String: Value] + + public init(arguments: [String: Value]) { + self.arguments = arguments + } + } + } + + public struct Result: Hashable, Codable, Sendable { + /// The completion result + public let completion: Completion + + public init(completion: Completion) { + self.completion = completion + } + + /// Completion result containing suggested values + public struct Completion: Hashable, Codable, Sendable { + /// Array of completion values (max 100 items) + public let values: [String] + /// Optional total number of available matches + public let total: Int? + /// Whether additional results exist + public let hasMore: Bool? + + public init( + values: [String], + total: Int? = nil, + hasMore: Bool? = nil + ) { + self.values = values + self.total = total + self.hasMore = hasMore + } + } + } +} diff --git a/Sources/MCP/Server/Logging.swift b/Sources/MCP/Server/Logging.swift new file mode 100644 index 00000000..cb1b8be4 --- /dev/null +++ b/Sources/MCP/Server/Logging.swift @@ -0,0 +1,72 @@ +import Foundation + +/// The Model Context Protocol (MCP) provides a standardized way for servers to send +/// structured log messages to clients. Clients can control logging verbosity by setting +/// minimum log levels, with servers sending notifications containing severity levels, +/// optional logger names, and arbitrary JSON-serializable data. +/// +/// - SeeAlso: https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/logging/ +public enum LogLevel: String, Hashable, Codable, Sendable, CaseIterable { + /// Detailed debugging information + case debug + /// General informational messages + case info + /// Normal but significant events + case notice + /// Warning conditions + case warning + /// Error conditions + case error + /// Critical conditions + case critical + /// Action must be taken immediately + case alert + /// System is unusable + case emergency +} + +// MARK: - Set Log Level + +/// To configure the minimum log level, clients MAY send a `logging/setLevel` request. +/// - SeeAlso: https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/logging/ +public enum SetLoggingLevel: Method { + public static let name = "logging/setLevel" + + public struct Parameters: Hashable, Codable, Sendable { + /// The minimum log level to set + public let level: LogLevel + + public init(level: LogLevel) { + self.level = level + } + } + + public typealias Result = Empty +} + +// MARK: - Log Message Notification + +/// Servers send log messages using `notifications/message` notifications. +/// - SeeAlso: https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/logging/ +public struct LogMessageNotification: Notification { + public static let name = "notifications/message" + + public struct Parameters: Hashable, Codable, Sendable { + /// The severity level of the log message + public let level: LogLevel + /// Optional logger name to identify the source + public let logger: String? + /// Arbitrary JSON-serializable data for the log message + public let data: Value + + public init( + level: LogLevel, + logger: String? = nil, + data: Value + ) { + self.level = level + self.logger = logger + self.data = data + } + } +} diff --git a/Sources/MCP/Server/Prompts.swift b/Sources/MCP/Server/Prompts.swift index 7be8f568..261f79d2 100644 --- a/Sources/MCP/Server/Prompts.swift +++ b/Sources/MCP/Server/Prompts.swift @@ -134,8 +134,8 @@ public struct Prompt: Hashable, Codable, Sendable { case image(data: String, mimeType: String) /// Audio content case audio(data: String, mimeType: String) - /// Embedded resource content - case resource(uri: String, mimeType: String, text: String?, blob: String?) + /// Embedded resource content (EmbeddedResource from spec) + case resource(resource: Resource.Content, annotations: Resource.Annotations? = nil, _meta: Metadata? = nil) } } @@ -175,7 +175,7 @@ public struct Prompt: Hashable, Codable, Sendable { extension Prompt.Message.Content: Codable { private enum CodingKeys: String, CodingKey { - case type, text, data, mimeType, uri, blob + case type, text, data, mimeType, resource, annotations, _meta } public func encode(to encoder: Encoder) throws { @@ -193,12 +193,11 @@ extension Prompt.Message.Content: Codable { try container.encode("audio", forKey: .type) try container.encode(data, forKey: .data) try container.encode(mimeType, forKey: .mimeType) - case .resource(let uri, let mimeType, let text, let blob): + case .resource(let resourceContent, let annotations, let _meta): try container.encode("resource", forKey: .type) - try container.encode(uri, forKey: .uri) - try container.encode(mimeType, forKey: .mimeType) - try container.encodeIfPresent(text, forKey: .text) - try container.encodeIfPresent(blob, forKey: .blob) + try container.encode(resourceContent, forKey: .resource) + try container.encodeIfPresent(annotations, forKey: .annotations) + try container.encodeIfPresent(_meta, forKey: ._meta) } } @@ -219,11 +218,10 @@ extension Prompt.Message.Content: Codable { let mimeType = try container.decode(String.self, forKey: .mimeType) self = .audio(data: data, mimeType: mimeType) case "resource": - let uri = try container.decode(String.self, forKey: .uri) - let mimeType = try container.decode(String.self, forKey: .mimeType) - let text = try container.decodeIfPresent(String.self, forKey: .text) - let blob = try container.decodeIfPresent(String.self, forKey: .blob) - self = .resource(uri: uri, mimeType: mimeType, text: text, blob: blob) + let resourceContent = try container.decode(Resource.Content.self, forKey: .resource) + let annotations = try container.decodeIfPresent(Resource.Annotations.self, forKey: .annotations) + let _meta = try container.decodeIfPresent(Metadata.self, forKey: ._meta) + self = .resource(resource: resourceContent, annotations: annotations, _meta: _meta) default: throw DecodingError.dataCorruptedError( forKey: .type, @@ -343,7 +341,7 @@ public enum GetPrompt: Method { var container = encoder.container(keyedBy: CodingKeys.self) try container.encodeIfPresent(description, forKey: .description) try container.encode(messages, forKey: .messages) - try container.encode(_meta, forKey: ._meta) + try container.encodeIfPresent(_meta, forKey: ._meta) } public init(from decoder: Decoder) throws { diff --git a/Sources/MCP/Server/Resources.swift b/Sources/MCP/Server/Resources.swift index 000e9648..24be3395 100644 --- a/Sources/MCP/Server/Resources.swift +++ b/Sources/MCP/Server/Resources.swift @@ -212,16 +212,16 @@ public struct Resource: Hashable, Codable, Sendable { } /// An array indicating the intended audience(s) for this resource. For example, `[.user, .assistant]` indicates content useful for both. - public let audience: [Audience] - /// A number from 0.0 to 1.0 indicating the importance of this resource. A value of 1 means “most important” (effectively required), while 0 means “least important”. + public let audience: [Audience]? + /// A number from 0.0 to 1.0 indicating the importance of this resource. A value of 1 means "most important" (effectively required), while 0 means "least important". public let priority: Double? /// An ISO 8601 formatted timestamp indicating when the resource was last modified (e.g., "2025-01-12T15:00:58Z"). - public let lastModified: String + public let lastModified: String? public init( - audience: [Audience], + audience: [Audience]? = nil, priority: Double? = nil, - lastModified: String + lastModified: String? = nil ) { self.audience = audience self.priority = priority @@ -403,6 +403,18 @@ public enum ResourceSubscribe: Method { public typealias Result = Empty } +/// Sent from the client to request cancellation of resources/updated notifications from the server. This should follow a previous resources/subscribe request. +/// - SeeAlso: https://modelcontextprotocol.io/specification/2025-06-18/schema#unsubscriberequest +public enum ResourceUnsubscribe: Method { + public static let name: String = "resources/unsubscribe" + + public struct Parameters: Hashable, Codable, Sendable { + public let uri: String + } + + public typealias Result = Empty +} + /// When a resource changes, servers that declared the updated capability SHOULD send a notification to subscribed clients. /// - SeeAlso: https://spec.modelcontextprotocol.io/specification/2025-06-18/server/resources/#subscriptions public struct ResourceUpdatedNotification: Notification { diff --git a/Sources/MCP/Server/Server.swift b/Sources/MCP/Server/Server.swift index 2b03f3f2..cea0ebc7 100644 --- a/Sources/MCP/Server/Server.swift +++ b/Sources/MCP/Server/Server.swift @@ -90,6 +90,13 @@ public actor Server { public init() {} } + /// Completions capabilities + public struct Completions: Hashable, Codable, Sendable { + public init() {} + } + + /// Completions capabilities + public var completions: Completions? /// Logging capabilities public var logging: Logging? /// Prompts capabilities @@ -102,12 +109,14 @@ public actor Server { public var tools: Tools? public init( + completions: Completions? = nil, logging: Logging? = nil, prompts: Prompts? = nil, resources: Resources? = nil, sampling: Sampling? = nil, tools: Tools? = nil ) { + self.completions = completions self.logging = logging self.prompts = prompts self.resources = resources @@ -193,7 +202,8 @@ public actor Server { try await transport.connect() await logger?.debug( - "Server started", metadata: ["name": "\(name)", "version": "\(version)"]) + "Server started", metadata: ["name": "\(name)", "version": "\(version)"] + ) // Start message handling loop task = Task { @@ -380,6 +390,50 @@ public actor Server { "Bidirectional sampling requests not yet implemented in transport layer") } + // MARK: - Logging + + /// Send a log message notification to connected clients. + /// + /// Servers that declare the `logging` capability can send structured log messages + /// to clients. The client controls which severity levels it wants to receive via + /// the `logging/setLevel` request. + /// + /// - Parameters: + /// - level: The severity level of the log message + /// - logger: Optional logger name to identify the source + /// - data: Arbitrary JSON-serializable data for the log message + /// - Throws: MCPError if the server is not connected + /// - SeeAlso: https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/logging/ + public func log( + level: LogLevel, + logger: String? = nil, + data: Value + ) async throws { + let notification = LogMessageNotification.message( + .init(level: level, logger: logger, data: data) + ) + try await notify(notification) + } + + /// Send a log message notification with codable data. + /// + /// Convenience method that encodes data to JSON before sending. + /// + /// - Parameters: + /// - level: The severity level of the log message + /// - logger: Optional logger name to identify the source + /// - data: Any codable data for the log message + /// - Throws: MCPError if the server is not connected or encoding fails + /// - SeeAlso: https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/logging/ + public func log( + level: LogLevel, + logger: String? = nil, + data: T + ) async throws { + let value = try Value(data) + try await log(level: level, logger: logger, data: value) + } + /// A JSON-RPC batch containing multiple requests and/or notifications struct Batch: Sendable { /// An item in a JSON-RPC batch diff --git a/Sources/MCP/Server/Tools.swift b/Sources/MCP/Server/Tools.swift index 8cd0b06f..51403260 100644 --- a/Sources/MCP/Server/Tools.swift +++ b/Sources/MCP/Server/Tools.swift @@ -119,11 +119,8 @@ public struct Tool: Hashable, Codable, Sendable { case image(data: String, mimeType: String, metadata: Metadata?) /// Audio content case audio(data: String, mimeType: String) - /// Embedded resource content - case resource( - uri: String, mimeType: String, text: String?, title: String? = nil, - annotations: Resource.Annotations? = nil - ) + /// Embedded resource content (EmbeddedResource from spec) + case resource(resource: Resource.Content, annotations: Resource.Annotations? = nil, _meta: Metadata? = nil) /// Resource link case resourceLink( uri: String, name: String, title: String? = nil, description: String? = nil, @@ -167,15 +164,10 @@ public struct Tool: Hashable, Codable, Sendable { let mimeType = try container.decode(String.self, forKey: .mimeType) self = .audio(data: data, mimeType: mimeType) case "resource": - let uri = try container.decode(String.self, forKey: .uri) - let title = try container.decodeIfPresent(String.self, forKey: .title) - let mimeType = try container.decode(String.self, forKey: .mimeType) - let text = try container.decodeIfPresent(String.self, forKey: .text) - let annotations = try container.decodeIfPresent( - Resource.Annotations.self, forKey: .annotations) - self = .resource( - uri: uri, mimeType: mimeType, text: text, title: title, annotations: annotations - ) + let resourceContent = try container.decode(Resource.Content.self, forKey: .resource) + let annotations = try container.decodeIfPresent(Resource.Annotations.self, forKey: .annotations) + let _meta = try container.decodeIfPresent(Metadata.self, forKey: ._meta) + self = .resource(resource: resourceContent, annotations: annotations, _meta: _meta) case "resourceLink": let uri = try container.decode(String.self, forKey: .uri) let name = try container.decode(String.self, forKey: .name) @@ -209,13 +201,11 @@ public struct Tool: Hashable, Codable, Sendable { try container.encode("audio", forKey: .type) try container.encode(data, forKey: .data) try container.encode(mimeType, forKey: .mimeType) - case .resource(let uri, let mimeType, let text, let title, let annotations): + case .resource(let resourceContent, let annotations, let _meta): try container.encode("resource", forKey: .type) - try container.encode(uri, forKey: .uri) - try container.encode(mimeType, forKey: .mimeType) - try container.encodeIfPresent(text, forKey: .text) - try container.encodeIfPresent(title, forKey: .title) + try container.encode(resourceContent, forKey: .resource) try container.encodeIfPresent(annotations, forKey: .annotations) + try container.encodeIfPresent(_meta, forKey: ._meta) case .resourceLink( let uri, let name, let title, let description, let mimeType, let annotations): try container.encode("resourceLink", forKey: .type) diff --git a/Sources/MCPConformance/Client/main.swift b/Sources/MCPConformance/Client/main.swift new file mode 100644 index 00000000..60be6a53 --- /dev/null +++ b/Sources/MCPConformance/Client/main.swift @@ -0,0 +1,291 @@ +/** + * Everything client - a single conformance test client that handles all scenarios. + * + * Usage: mcp-everything-client + * + * The scenario name is read from the MCP_CONFORMANCE_SCENARIO environment variable, + * which is set by the conformance test runner. + * + * This client routes to the appropriate behavior based on the scenario name, + * consolidating all the individual test clients into one. + */ + +import Foundation +import Logging +import MCP + +// MARK: - Scenario Handlers + +typealias ScenarioHandler = ([String]) async throws -> Void + +// MARK: - Basic Scenarios + +/// Basic client that connects, initializes, and lists tools +func runInitializeScenario(_ args: [String]) async throws { + var logger = Logger( + label: "mcp.conformance.client.initialize", + factory: { StreamLogHandler.standardError(label: $0) } + ) + logger.logLevel = .debug + + logger.debug("Starting initialize scenario") + + // Get server URL from args + guard let serverURLString = args.last, + let serverURL = URL(string: serverURLString) else { + throw ConformanceError.invalidArguments("Valid server URL is required") + } + + // Create HTTP transport + let transport = HTTPClientTransport( + endpoint: serverURL, + logger: logger + ) + + // Create client + let client = Client(name: "test-client", version: "1.0.0") + + // Connect + let initResult = try await client.connect(transport: transport) + logger.debug("Successfully connected to MCP server", metadata: [ + "serverName": "\(initResult.serverInfo.name)", + "serverVersion": "\(initResult.serverInfo.version)" + ]) + + // List tools + let (tools, _) = try await client.listTools() + logger.debug("Successfully listed tools", metadata: [ + "toolCount": "\(tools.count)" + ]) + + // Disconnect + await client.disconnect() + + logger.debug("Initialize scenario completed successfully") +} + +/// Client that calls the add_numbers tool +func runToolsCallScenario(_ args: [String]) async throws { + var logger = Logger( + label: "mcp.conformance.client.tools_call", + factory: { StreamLogHandler.standardError(label: $0) } + ) + logger.logLevel = .debug + + logger.debug("Starting tools_call scenario") + + // Get server URL from args + guard let serverURLString = args.last, + let serverURL = URL(string: serverURLString) else { + throw ConformanceError.invalidArguments("Valid server URL is required") + } + + // Create HTTP transport + let transport = HTTPClientTransport( + endpoint: serverURL, + logger: logger + ) + + // Create client + let client = Client(name: "test-client", version: "1.0.0") + + // Connect + try await client.connect(transport: transport) + logger.debug("Successfully connected to MCP server") + + // List tools + let (tools, _) = try await client.listTools() + logger.debug("Successfully listed tools", metadata: [ + "toolCount": "\(tools.count)" + ]) + + // Call the add_numbers tool + if tools.contains(where: { $0.name == "add_numbers" }) { + let result = try await client.callTool( + name: "add_numbers", + arguments: ["a": 5, "b": 3] + ) + logger.debug("Tool call result", metadata: [ + "isError": "\(result.isError ?? false)", + "contentCount": "\(result.content.count)" + ]) + } else { + logger.warning("add_numbers tool not found") + } + + // Disconnect + await client.disconnect() + + logger.debug("Tools call scenario completed successfully") +} + +// MARK: - SSE Scenarios + +/// Handler for SSE-related scenarios (retry, reconnection, etc.) +func runSSEScenario(_ args: [String]) async throws { + var logger = Logger( + label: "mcp.conformance.client.sse", + factory: { StreamLogHandler.standardError(label: $0) } + ) + logger.logLevel = .debug + + logger.debug("Starting SSE scenario") + + // Get server URL from args + guard let serverURLString = args.last, + let serverURL = URL(string: serverURLString) else { + throw ConformanceError.invalidArguments("Valid server URL is required") + } + + // Create HTTP transport with streaming enabled + let transport = HTTPClientTransport( + endpoint: serverURL, + streaming: true, + logger: logger + ) + + // Create client + let client = Client(name: "test-client", version: "1.0.0") + + // Connect - this will start the SSE stream in the background + let initResult = try await client.connect(transport: transport) + logger.debug("Successfully connected to MCP server", metadata: [ + "serverName": "\(initResult.serverInfo.name)", + "serverVersion": "\(initResult.serverInfo.version)" + ]) + + // Give the GET SSE stream time to establish + try await Task.sleep(for: .milliseconds(500)) + + // Call the test_reconnection tool to trigger SSE stream closure and retry test. + // The server will close the POST SSE stream without the response, + // then deliver it on the GET SSE stream after we reconnect. + logger.debug("Calling test_reconnection tool...") + let result = try await client.callTool(name: "test_reconnection", arguments: [:]) + logger.debug("Tool call result received", metadata: [ + "isError": "\(result.isError ?? false)", + "contentCount": "\(result.content.count)" + ]) + + // Keep the connection open briefly for the test to collect timing data + try await Task.sleep(for: .seconds(2)) + + // Disconnect + await client.disconnect() + + logger.debug("SSE scenario completed") +} + +// MARK: - Default Handler for Unimplemented Scenarios + +/// Default handler that performs basic connection test for unimplemented scenarios +func runDefaultScenario(_ args: [String]) async throws { + var logger = Logger( + label: "mcp.conformance.client.default", + factory: { StreamLogHandler.standardError(label: $0) } + ) + logger.logLevel = .debug + + logger.debug("Running default scenario handler") + + // Get server URL from args + guard let serverURLString = args.last, + let serverURL = URL(string: serverURLString) else { + throw ConformanceError.invalidArguments("Valid server URL is required") + } + + // Create HTTP transport + let transport = HTTPClientTransport( + endpoint: serverURL, + logger: logger + ) + + // Create client + let client = Client(name: "test-client", version: "1.0.0") + + // Connect + let initResult = try await client.connect(transport: transport) + logger.debug("Successfully connected to MCP server", metadata: [ + "serverName": "\(initResult.serverInfo.name)", + "serverVersion": "\(initResult.serverInfo.version)" + ]) + + // Disconnect + await client.disconnect() + + logger.debug("Default scenario completed successfully") +} + +// MARK: - Scenario Registry + +nonisolated(unsafe) let scenarioHandlers: [String: ScenarioHandler] = [ + "initialize": runInitializeScenario, + "tools_call": runToolsCallScenario, + "sse-retry": runSSEScenario, + // Note: Other scenarios (elicitation, auth/*) will use the default handler +] + +// MARK: - Error Types + +enum ConformanceError: Error, CustomStringConvertible { + case missingScenario + case invalidArguments(String) + + var description: String { + switch self { + case .missingScenario: + return "MCP_CONFORMANCE_SCENARIO environment variable not set" + case .invalidArguments(let message): + return "Invalid arguments: \(message)" + } + } +} + +struct ConformanceClient { + static func run() async { + do { + // Get scenario from environment + guard let scenario = ProcessInfo.processInfo.environment["MCP_CONFORMANCE_SCENARIO"] else { + var stderr = StandardError() + print("Error: MCP_CONFORMANCE_SCENARIO environment variable not set", to: &stderr) + Foundation.exit(1) + } + + // Get server URL from arguments (last argument) + let args = Array(CommandLine.arguments.dropFirst()) + guard !args.isEmpty else { + var stderr = StandardError() + print("Usage: mcp-everything-client ", to: &stderr) + print("Error: Server URL is required", to: &stderr) + Foundation.exit(1) + } + + // Get handler for scenario, or use default if not implemented + let handler = scenarioHandlers[scenario] ?? runDefaultScenario + + // Log if using default handler + if scenarioHandlers[scenario] == nil { + var stderr = StandardError() + print("⚠️ Scenario '\(scenario)' not fully implemented - using default handler", to: &stderr) + } + + // Run the scenario + try await handler(args) + Foundation.exit(0) + } catch { + var stderr = StandardError() + print("Error: \(error)", to: &stderr) + Foundation.exit(1) + } + } +} + +// MARK: - Helpers + +struct StandardError: TextOutputStream { + mutating func write(_ string: String) { + FileHandle.standardError.write(Data(string.utf8)) + } +} + +await ConformanceClient.run() diff --git a/Sources/MCPConformance/Server/HTTPApp.swift b/Sources/MCPConformance/Server/HTTPApp.swift new file mode 100644 index 00000000..6145300f --- /dev/null +++ b/Sources/MCPConformance/Server/HTTPApp.swift @@ -0,0 +1,429 @@ +import Foundation +import Logging +import MCP +@preconcurrency import NIOCore +@preconcurrency import NIOPosix +@preconcurrency import NIOHTTP1 + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +actor HTTPApp { + /// Configuration for the HTTP application. + struct Configuration: Sendable { + /// The host address to bind to. + var host: String + + /// The port to bind to. + var port: Int + + /// The MCP endpoint path. + var endpoint: String + + /// Session timeout in seconds. + var sessionTimeout: TimeInterval + + /// SSE retry interval in milliseconds for priming events. + var retryInterval: Int? + + init( + host: String = "127.0.0.1", + port: Int = 3000, + endpoint: String = "/mcp", + sessionTimeout: TimeInterval = 3600, + retryInterval: Int? = nil + ) { + self.host = host + self.port = port + self.endpoint = endpoint + self.sessionTimeout = sessionTimeout + self.retryInterval = retryInterval + } + } + + /// Factory function to create MCP Server instances for each session. + typealias ServerFactory = @Sendable (String) async throws -> Server + + private let configuration: Configuration + private let serverFactory: ServerFactory + private let validationPipeline: (any HTTPRequestValidationPipeline)? + private var channel: Channel? + private var sessions: [String: SessionContext] = [:] + + nonisolated let logger: Logger + + struct SessionContext { + let server: Server + let transport: StatefulHTTPServerTransport + let createdAt: Date + var lastAccessedAt: Date + } + + // MARK: - Init + + /// Creates a new HTTP application. + /// + /// - Parameters: + /// - configuration: Application configuration. + /// - validationPipeline: Custom validation pipeline passed to each transport. + /// If `nil`, transports use their sensible defaults. + /// - serverFactory: Factory function to create Server instances for each session. + /// - logger: Optional logger instance. + init( + configuration: Configuration = Configuration(), + validationPipeline: (any HTTPRequestValidationPipeline)? = nil, + serverFactory: @escaping ServerFactory, + logger: Logger? = nil + ) { + self.configuration = configuration + self.serverFactory = serverFactory + self.validationPipeline = validationPipeline + self.logger = logger ?? Logger( + label: "mcp.http.app", + factory: { _ in SwiftLogNoOpLogHandler() } + ) + } + + /// Convenience initializer with individual parameters. + init( + host: String = "127.0.0.1", + port: Int = 3000, + endpoint: String = "/mcp", + serverFactory: @escaping ServerFactory, + logger: Logger? = nil + ) { + self.init( + configuration: Configuration(host: host, port: port, endpoint: endpoint), + serverFactory: serverFactory, + logger: logger + ) + } + + // MARK: - Lifecycle + + /// Starts the HTTP application. + /// + /// This starts the NIO HTTP server and begins accepting connections. + /// The call blocks until the server is shut down via ``stop()``. + func start() async throws { + let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) + + let bootstrap = ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.backlog, value: 256) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline().flatMap { + channel.pipeline.addHandler(HTTPHandler(app: self)) + } + } + .childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1) + + logger.info( + "Starting MCP HTTP application", + metadata: [ + "host": "\(configuration.host)", + "port": "\(configuration.port)", + "endpoint": "\(configuration.endpoint)", + ] + ) + + let channel = try await bootstrap.bind(host: configuration.host, port: configuration.port).get() + self.channel = channel + + Task { await sessionCleanupLoop() } + + try await channel.closeFuture.get() + } + + /// Stops the HTTP application gracefully, closing all sessions. + func stop() async { + await closeAllSessions() + try? await channel?.close() + channel = nil + logger.info("MCP HTTP application stopped") + } + + // MARK: - Request Routing + + var endpoint: String { configuration.endpoint } + + /// Routes an incoming HTTP request to the appropriate session transport. + /// + /// - Requests with a valid `Mcp-Session-Id` are forwarded to the matching transport. + /// - POST requests with an `initialize` body create a new session. + /// - All other requests without a session return an error. + func handleHTTPRequest(_ request: HTTPRequest) async -> HTTPResponse { + let sessionID = request.header(HTTPHeaderName.sessionID) + + // Route to existing session + if let sessionID, var session = sessions[sessionID] { + session.lastAccessedAt = Date() + sessions[sessionID] = session + + let response = await session.transport.handleRequest(request) + + // Clean up on successful DELETE + if request.method.uppercased() == "DELETE" && response.statusCode == 200 { + sessions.removeValue(forKey: sessionID) + } + + return response + } + + // No session — check for initialize request + if request.method.uppercased() == "POST", + let body = request.body, + let kind = JSONRPCMessageKind(data: body), + kind.isInitializeRequest + { + return await createSessionAndHandle(request) + } + + // No session and not initialize + if sessionID != nil { + return .error(statusCode: 404, .invalidRequest("Not Found: Session not found or expired")) + } + return .error( + statusCode: 400, + .invalidRequest("Bad Request: Missing \(HTTPHeaderName.sessionID) header") + ) + } + + // MARK: - Session Management + + private struct FixedSessionIDGenerator: SessionIDGenerator { + let sessionID: String + func generateSessionID() -> String { sessionID } + } + + private func createSessionAndHandle(_ request: HTTPRequest) async -> HTTPResponse { + let sessionID = UUID().uuidString + + let transport = StatefulHTTPServerTransport( + sessionIDGenerator: FixedSessionIDGenerator(sessionID: sessionID), + validationPipeline: validationPipeline, + retryInterval: configuration.retryInterval, + logger: logger + ) + + do { + let server = try await serverFactory(sessionID) + try await server.start(transport: transport) + + sessions[sessionID] = SessionContext( + server: server, + transport: transport, + createdAt: Date(), + lastAccessedAt: Date() + ) + + let response = await transport.handleRequest(request) + + // If transport returned an error, clean up + if case .error = response { + sessions.removeValue(forKey: sessionID) + await transport.disconnect() + } + + return response + } catch { + await transport.disconnect() + return .error( + statusCode: 500, + .internalError("Failed to create session: \(error.localizedDescription)") + ) + } + } + + private func closeSession(_ sessionID: String) async { + guard let session = sessions.removeValue(forKey: sessionID) else { return } + await session.transport.disconnect() + logger.info("Closed session", metadata: ["sessionID": "\(sessionID)"]) + } + + private func closeAllSessions() async { + for sessionID in sessions.keys { + await closeSession(sessionID) + } + } + + private func sessionCleanupLoop() async { + while true { + try? await Task.sleep(for: .seconds(60)) + + let now = Date() + let expired = sessions.filter { _, context in + now.timeIntervalSince(context.lastAccessedAt) > configuration.sessionTimeout + } + + for (sessionID, _) in expired { + logger.info("Session expired", metadata: ["sessionID": "\(sessionID)"]) + await closeSession(sessionID) + } + } + } +} + +// MARK: - NIO HTTP Handler + +/// Thin NIO adapter that converts between NIO HTTP types and the framework-agnostic +/// `HTTPRequest`/`HTTPResponse` types, delegating all logic to the `HTTPApp`. +private final class HTTPHandler: ChannelInboundHandler, @unchecked Sendable { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + private let app: HTTPApp + + private struct RequestState { + var head: HTTPRequestHead + var bodyBuffer: ByteBuffer + } + + private var requestState: RequestState? + + init(app: HTTPApp) { + self.app = app + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let part = unwrapInboundIn(data) + + switch part { + case .head(let head): + requestState = RequestState( + head: head, + bodyBuffer: context.channel.allocator.buffer(capacity: 0) + ) + case .body(var buffer): + requestState?.bodyBuffer.writeBuffer(&buffer) + case .end: + guard let state = requestState else { return } + requestState = nil + + nonisolated(unsafe) let ctx = context + Task { @MainActor in + await self.handleRequest(state: state, context: ctx) + } + } + } + + // MARK: - Request Processing + + private func handleRequest(state: RequestState, context: ChannelHandlerContext) async { + let head = state.head + let path = head.uri.split(separator: "?").first.map(String.init) ?? head.uri + let endpoint = await app.endpoint + + guard path == endpoint else { + await writeResponse( + .error(statusCode: 404, .invalidRequest("Not Found")), + version: head.version, + context: context + ) + return + } + + let httpRequest = makeHTTPRequest(from: state) + let response = await app.handleHTTPRequest(httpRequest) + await writeResponse(response, version: head.version, context: context) + } + + // MARK: - NIO ↔ HTTPRequest/HTTPResponse Conversion + + private func makeHTTPRequest(from state: RequestState) -> HTTPRequest { + // Combine multiple header values per RFC 7230 + var headers: [String: String] = [:] + for (name, value) in state.head.headers { + if let existing = headers[name] { + headers[name] = existing + ", " + value + } else { + headers[name] = value + } + } + + let body: Data? + if state.bodyBuffer.readableBytes > 0, + let bytes = state.bodyBuffer.getBytes(at: 0, length: state.bodyBuffer.readableBytes) + { + body = Data(bytes) + } else { + body = nil + } + + return HTTPRequest( + method: state.head.method.rawValue, + headers: headers, + body: body + ) + } + + private func writeResponse( + _ response: HTTPResponse, + version: HTTPVersion, + context: ChannelHandlerContext + ) async { + nonisolated(unsafe) let ctx = context + let eventLoop = ctx.eventLoop + + // Write response head + let statusCode = response.statusCode + let headers = response.headers + + switch response { + case .stream(let stream, _): + eventLoop.execute { + var head = HTTPResponseHead( + version: version, + status: HTTPResponseStatus(statusCode: statusCode) + ) + for (name, value) in headers { + head.headers.add(name: name, value: value) + } + ctx.write(self.wrapOutboundOut(.head(head)), promise: nil) + ctx.flush() + } + + // Await the SSE stream directly — no Task needed since we're already in one + do { + for try await chunk in stream { + eventLoop.execute { + var buffer = ctx.channel.allocator.buffer(capacity: chunk.count) + buffer.writeBytes(chunk) + ctx.writeAndFlush( + self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) + } + } + } catch { + // Stream ended with error — close connection + } + + eventLoop.execute { + ctx.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + + default: + let bodyData = response.bodyData + eventLoop.execute { + var head = HTTPResponseHead( + version: version, + status: HTTPResponseStatus(statusCode: statusCode) + ) + for (name, value) in headers { + head.headers.add(name: name, value: value) + } + + ctx.write(self.wrapOutboundOut(.head(head)), promise: nil) + + if let body = bodyData { + var buffer = ctx.channel.allocator.buffer(capacity: body.count) + buffer.writeBytes(body) + ctx.write(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) + } + + ctx.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + } + } +} diff --git a/Sources/MCPConformance/Server/main.swift b/Sources/MCPConformance/Server/main.swift new file mode 100644 index 00000000..e2d7b84f --- /dev/null +++ b/Sources/MCPConformance/Server/main.swift @@ -0,0 +1,309 @@ +/** + * MCP HTTP Server Wrapper + * + * HTTP server that wraps the MCP conformance server for testing with the + * official conformance framework. + * + * Usage: mcp-http-server [--port PORT] + */ + +import Foundation +import Logging +import MCP + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +// MARK: - Test Data + +private let testImageBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" +private let testAudioBase64 = "UklGRiYAAABXQVZFZm10IBAAAAABAAEAQB8AAAB9AAACABAAZGF0YQIAAAA=" + +// MARK: - Server State + +actor ServerState { + var resourceSubscriptions: Set = [] + var watchedResourceContent = "Watched resource content" + + func subscribe(to uri: String) { + resourceSubscriptions.insert(uri) + } + + func unsubscribe(from uri: String) { + resourceSubscriptions.remove(uri) + } + + func isSubscribed(to uri: String) -> Bool { + resourceSubscriptions.contains(uri) + } + + func updateWatchedResource(_ newContent: String) { + watchedResourceContent = newContent + } +} + +// MARK: - Server Setup + +func createConformanceServer(state: ServerState) async -> Server { + let server = Server( + name: "mcp-conformance-test-server", + version: "1.0.0", + capabilities: Server.Capabilities( + logging: .init(), + prompts: .init(listChanged: true), + resources: .init(subscribe: true, listChanged: true), + tools: .init(listChanged: true) + ) + ) + + // Tools + await server.withMethodHandler(ListTools.self) { _ in + .init(tools: [ + Tool(name: "test_simple_text", description: "Tests simple text content response", inputSchema: .object(["type": "object", "properties": [:]])), + Tool(name: "test_image_content", description: "Tests image content response", inputSchema: .object(["type": "object", "properties": [:]])), + Tool(name: "test_audio_content", description: "Tests audio content response", inputSchema: .object(["type": "object", "properties": [:]])), + Tool(name: "test_embedded_resource", description: "Tests embedded resource content response", inputSchema: .object(["type": "object", "properties": [:]])), + Tool(name: "test_multiple_content_types", description: "Tests response with multiple content types", inputSchema: .object(["type": "object", "properties": [:]])), + Tool(name: "test_error_handling", description: "Tests error response handling", inputSchema: .object(["type": "object", "properties": [:]])), + Tool(name: "test_logging", description: "Tests logging capabilities", inputSchema: .object(["type": "object", "properties": [:]])), + Tool(name: "test_progress", description: "Tests progress notifications", inputSchema: .object(["type": "object", "properties": ["duration_ms": ["type": "number", "description": "Duration in milliseconds to report progress"]]])), + Tool(name: "add_numbers", description: "Adds two numbers together", inputSchema: .object(["type": "object", "properties": ["a": ["type": "number", "description": "First number"], "b": ["type": "number", "description": "Second number"]]])), + Tool(name: "test_tool_with_progress", description: "Tool reports progress notifications", inputSchema: .object(["type": "object", "properties": [:]])), + Tool(name: "test_tool_with_logging", description: "Tool sends log messages during execution", inputSchema: .object(["type": "object", "properties": [:]])), + Tool(name: "test_reconnection", description: "Tests SSE reconnection and resumption with Last-Event-ID", inputSchema: .object(["type": "object", "properties": [:]])) + ]) + } + + await server.withMethodHandler(CallTool.self) { [weak server] params in + switch params.name { + case "test_simple_text": + return .init(content: [.text("This is a simple text response for testing.")], isError: false) + case "test_image_content": + return .init(content: [.image(data: testImageBase64, mimeType: "image/png", metadata: nil)], isError: false) + case "test_audio_content": + return .init(content: [.audio(data: testAudioBase64, mimeType: "audio/wav")], isError: false) + case "test_embedded_resource": + return .init(content: [.resource(resource: .text("This is an embedded resource content.", uri: "test://embedded-resource", mimeType: "text/plain"))], isError: false) + case "test_multiple_content_types": + return .init(content: [ + .text("Multiple content types test:"), + .image(data: testImageBase64, mimeType: "image/png", metadata: nil), + .resource(resource: .text("{\"test\":\"data\",\"value\":123}", uri: "test://mixed-content-resource", mimeType: "application/json"))], isError: false) + case "test_error_handling": + return .init(content: [.text("An error occurred during tool execution")], isError: true) + case "test_logging": + return .init(content: [.text("Logging test completed")], isError: false) + case "test_progress": + let duration = params.arguments?["duration_ms"]?.intValue ?? 1000 + try? await Task.sleep(for: .milliseconds(duration)) + return .init(content: [.text("Progress test completed")], isError: false) + case "add_numbers": + guard let a = params.arguments?["a"]?.intValue, let b = params.arguments?["b"]?.intValue else { + return .init(content: [.text("Invalid arguments: expected numbers a and b")], isError: true) + } + return .init(content: [.text("\(a + b)")], isError: false) + case "test_tool_with_progress": + if let token = params._meta?.progressToken { + let notification1 = ProgressNotification.message( + .init(progressToken: token, progress: 0, total: 100) + ) + try await server?.notify(notification1) + try await Task.sleep(for: .microseconds(50)) + + let notification2 = ProgressNotification.message( + .init(progressToken: token, progress: 50, total: 100) + ) + try await server?.notify(notification2) + try await Task.sleep(for: .microseconds(50)) + + let notification3 = ProgressNotification.message( + .init(progressToken: token, progress: 100, total: 100) + ) + try await server?.notify(notification3) + } + + return .init(content: [.text("This is a simple text response for testing.")], isError: false) + case "test_tool_with_logging": + // Send first log message + let log1 = LogMessageNotification.message( + .init(level: .info, data: .string("Tool execution started")) + ) + try await server?.notify(log1) + + // Wait 50ms + try await Task.sleep(for: .milliseconds(50)) + + // Send second log message + let log2 = LogMessageNotification.message( + .init(level: .info, data: .string("Tool processing data")) + ) + try await server?.notify(log2) + + // Wait another 50ms + try await Task.sleep(for: .milliseconds(50)) + + // Send third log message + let log3 = LogMessageNotification.message( + .init(level: .info, data: .string("Tool execution completed")) + ) + try await server?.notify(log3) + + return .init(content: [.text("Logging test completed")], isError: false) + case "test_reconnection": + // This tool tests SSE reconnection behavior (SEP-1699) + // In a full implementation, the server would close the SSE stream mid-call + // and the client would need to reconnect with Last-Event-ID to get the result. + // For now, we return a simple success response. + return .init(content: [.text("Reconnection test completed successfully")], isError: false) + default: + return .init(content: [.text("Unknown tool: \(params.name)")], isError: true) + } + } + + // Resources + await server.withMethodHandler(ListResources.self) { _ in + .init(resources: [ + Resource(name: "Static Text Resource", uri: "test://static-text", description: "A simple static text resource", mimeType: "text/plain"), + Resource(name: "Static Binary Resource", uri: "test://static-binary", description: "A simple static binary resource", mimeType: "application/octet-stream"), + Resource(name: "Watched Resource", uri: "test://watched", description: "A resource that can be subscribed to for updates", mimeType: "text/plain"), + Resource(name: "Template Resource", uri: "test://template/{id}", description: "A resource template with URI parameters", mimeType: "text/plain"), + ]) + } + + await server.withMethodHandler(ReadResource.self) { params in + switch params.uri { + case "test://static-text": + return .init(contents: [.text("This is static text content for testing.", uri: params.uri, mimeType: "text/plain")]) + case "test://static-binary": + guard let imageData = Data(base64Encoded: testImageBase64) else { + return .init(contents: [.text("Failed to decode binary data", uri: params.uri)]) + } + return .init(contents: [.binary(imageData, uri: params.uri, mimeType: "application/octet-stream")]) + case "test://watched": + let content = await state.watchedResourceContent + return .init(contents: [.text(content, uri: params.uri)]) + default: + if params.uri.hasPrefix("test://template/") { + let id = String(params.uri.dropFirst("test://template/".count)) + return .init(contents: [.text("Template resource with id: \(id)", uri: params.uri)]) + } + return .init(contents: [.text("Resource not found: \(params.uri)", uri: params.uri)]) + } + } + + await server.withMethodHandler(ResourceSubscribe.self) { params in + await state.subscribe(to: params.uri) + return Empty() + } + + await server.withMethodHandler(ResourceUnsubscribe.self) { params in + await state.unsubscribe(from: params.uri) + return Empty() + } + + // Prompts + await server.withMethodHandler(ListPrompts.self) { _ in + .init(prompts: [ + Prompt(name: "test_simple_prompt", description: "A simple prompt without arguments"), + Prompt(name: "test_prompt_with_arguments", description: "A prompt that accepts arguments", arguments: [Prompt.Argument(name: "arg1", description: "First test argument", required: true), Prompt.Argument(name: "arg2", description: "Second test argument", required: true)]), + Prompt(name: "test_prompt_with_embedded_resource", description: "A prompt that includes embedded resources", arguments: [Prompt.Argument(name: "resourceUri", description: "URI of the resource to embed", required: true)]), + Prompt(name: "test_prompt_with_image", description: "A prompt with image content"), + ]) + } + + await server.withMethodHandler(GetPrompt.self) { params in + switch params.name { + case "test_simple_prompt": + return .init(description: "Simple prompt response", messages: [.user(.text(text: "This is a simple prompt for testing."))]) + case "test_prompt_with_arguments": + let arg1 = params.arguments?["arg1"]?.stringValue ?? "default1" + let arg2 = params.arguments?["arg2"]?.stringValue ?? "default2" + return .init(description: "Prompt with arguments", messages: [.user(.text(text: "Prompt with arguments: arg1='\(arg1)', arg2='\(arg2)'"))]) + case "test_prompt_with_embedded_resource": + let resourceUri = params.arguments?["resourceUri"]?.stringValue ?? "test://default" + return .init(description: "Prompt with embedded resource", messages: [ + .user(.resource(resource: .text("Embedded resource content for testing.", uri: resourceUri, mimeType: "text/plain"))), + .user(.text(text: "Please process the embedded resource above.")) + ]) + case "test_prompt_with_image": + return .init(description: "Prompt with image", messages: [ + .user(.image(data: testImageBase64, mimeType: "image/png")), + .user(.text(text: "Please analyze the image above.")) + ]) + default: + throw MCPError.invalidRequest("Unknown prompt: \(params.name)") + } + } + + await server.withMethodHandler(SetLoggingLevel.self) { _ in + // Accept any logging level (debug, info, notice, warning, error, critical, alert, emergency) + // For conformance testing, we just accept it without doing anything + return Empty() + } + + await server.withMethodHandler(Complete.self) { _ in + return .init(completion: .init(values: [])) + } + + return server +} + +// MARK: - HTTP Server + +// HTTPApp handles all HTTP server functionality + +// MARK: - Main + +struct MCPHTTPServer { + static func run() async throws { + let args = CommandLine.arguments + var port = 3001 + + for (index, arg) in args.enumerated() { + if arg == "--port" && index + 1 < args.count { + if let p = Int(args[index + 1]) { + port = p + } + } + } + + var loggerConfig = Logger(label: "mcp.http.server", factory: { StreamLogHandler.standardError(label: $0) }) + loggerConfig.logLevel = .trace + let logger = loggerConfig + + let state = ServerState() + + logger.info("Starting MCP HTTP Server...", metadata: ["port": "\(port)"]) + + // Create HTTPApp with server factory + let app = HTTPApp( + configuration: .init( + host: "127.0.0.1", + port: port, + endpoint: "/mcp" + ), + validationPipeline: StandardValidationPipeline(validators: [ + OriginValidator.localhost(port: port), + AcceptHeaderValidator(mode: .sseRequired), + ContentTypeValidator(), + ProtocolVersionValidator(), + SessionValidator(), + ]), + serverFactory: { sessionID in + logger.debug("Creating server for session", metadata: ["sessionID": "\(sessionID)"]) + return await createConformanceServer(state: state) + }, + logger: logger + ) + + try await app.start() + } +} + +do { + try await MCPHTTPServer.run() +} catch { + print(error) + exit(1) +} diff --git a/Tests/MCPTests/CompletionTests.swift b/Tests/MCPTests/CompletionTests.swift new file mode 100644 index 00000000..12ef983e --- /dev/null +++ b/Tests/MCPTests/CompletionTests.swift @@ -0,0 +1,598 @@ +import Foundation +import Testing + +@testable import MCP + +@Suite("Completion Tests") +struct CompletionTests { + // MARK: - Reference Types Tests + + @Test("PromptReference initialization and encoding") + func testPromptReferenceEncodingDecoding() throws { + let ref = PromptReference(name: "code_review") + + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes] + + let data = try encoder.encode(ref) + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] + + #expect(json?["type"] as? String == "ref/prompt") + #expect(json?["name"] as? String == "code_review") + + // Test decoding + let decoder = JSONDecoder() + let decoded = try decoder.decode(PromptReference.self, from: data) + #expect(decoded.name == "code_review") + } + + @Test("ResourceReference initialization and encoding") + func testResourceReferenceEncodingDecoding() throws { + let ref = ResourceReference(uri: "file:///path/to/resource") + + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes] + + let data = try encoder.encode(ref) + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] + + #expect(json?["type"] as? String == "ref/resource") + #expect(json?["uri"] as? String == "file:///path/to/resource") + + // Test decoding + let decoder = JSONDecoder() + let decoded = try decoder.decode(ResourceReference.self, from: data) + #expect(decoded.uri == "file:///path/to/resource") + } + + @Test("CompletionReference prompt case encoding") + func testCompletionReferencePromptEncoding() throws { + let ref = CompletionReference.prompt(PromptReference(name: "test")) + + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes] + + let data = try encoder.encode(ref) + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] + + #expect(json?["type"] as? String == "ref/prompt") + #expect(json?["name"] as? String == "test") + } + + @Test("CompletionReference resource case encoding") + func testCompletionReferenceResourceEncoding() throws { + let ref = CompletionReference.resource(ResourceReference(uri: "file:///test")) + + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes] + + let data = try encoder.encode(ref) + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] + + #expect(json?["type"] as? String == "ref/resource") + #expect(json?["uri"] as? String == "file:///test") + } + + @Test("CompletionReference decoding prompt type") + func testCompletionReferenceDecodingPrompt() throws { + let json = """ + { + "type": "ref/prompt", + "name": "code_review" + } + """ + + let decoder = JSONDecoder() + let ref = try decoder.decode(CompletionReference.self, from: json.data(using: .utf8)!) + + if case .prompt(let promptRef) = ref { + #expect(promptRef.name == "code_review") + } else { + Issue.record("Expected prompt reference") + } + } + + @Test("CompletionReference decoding resource type") + func testCompletionReferenceDecodingResource() throws { + let json = """ + { + "type": "ref/resource", + "uri": "file:///path" + } + """ + + let decoder = JSONDecoder() + let ref = try decoder.decode(CompletionReference.self, from: json.data(using: .utf8)!) + + if case .resource(let resourceRef) = ref { + #expect(resourceRef.uri == "file:///path") + } else { + Issue.record("Expected resource reference") + } + } + + // MARK: - Complete Request Tests + + @Test("Complete request initialization") + func testCompleteRequestInitialization() throws { + let ref = CompletionReference.prompt(PromptReference(name: "code_review")) + let argument = Complete.Parameters.Argument(name: "language", value: "py") + let request = Complete.request(.init(ref: ref, argument: argument)) + + #expect(request.method == "completion/complete") + #expect(request.params.argument.name == "language") + #expect(request.params.argument.value == "py") + } + + @Test("Complete request with context") + func testCompleteRequestWithContext() throws { + let ref = CompletionReference.prompt(PromptReference(name: "code_review")) + let argument = Complete.Parameters.Argument(name: "framework", value: "fla") + let context = Complete.Parameters.Context(arguments: ["language": .string("python")]) + + let request = Complete.request(.init(ref: ref, argument: argument, context: context)) + + #expect(request.params.context != nil) + #expect(request.params.context?.arguments["language"] == .string("python")) + } + + @Test("Complete request encoding") + func testCompleteRequestEncoding() throws { + let ref = CompletionReference.prompt(PromptReference(name: "code_review")) + let argument = Complete.Parameters.Argument(name: "language", value: "py") + let request = Complete.request(.init(ref: ref, argument: argument)) + + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes] + + let data = try encoder.encode(request) + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] + + #expect(json?["jsonrpc"] as? String == "2.0") + #expect(json?["method"] as? String == "completion/complete") + + guard let params = json?["params"] as? [String: Any] else { + Issue.record("Failed to get params") + return + } + guard let refDict = params["ref"] as? [String: Any] else { + Issue.record("Failed to get ref") + return + } + #expect(refDict["type"] as? String == "ref/prompt") + #expect(refDict["name"] as? String == "code_review") + + guard let arg = params["argument"] as? [String: Any] else { + Issue.record("Failed to get argument") + return + } + #expect(arg["name"] as? String == "language") + #expect(arg["value"] as? String == "py") + } + + @Test("Complete request decoding") + func testCompleteRequestDecoding() throws { + let json = """ + { + "jsonrpc": "2.0", + "id": "test-id", + "method": "completion/complete", + "params": { + "ref": { + "type": "ref/prompt", + "name": "code_review" + }, + "argument": { + "name": "language", + "value": "py" + } + } + } + """ + + let decoder = JSONDecoder() + let request = try decoder.decode(Request.self, from: json.data(using: .utf8)!) + + #expect(request.method == "completion/complete") + #expect(request.params.argument.name == "language") + #expect(request.params.argument.value == "py") + + if case .prompt(let promptRef) = request.params.ref { + #expect(promptRef.name == "code_review") + } else { + Issue.record("Expected prompt reference") + } + } + + // MARK: - Complete Result Tests + + @Test("Complete result initialization") + func testCompleteResultInitialization() throws { + let completion = Complete.Result.Completion( + values: ["python", "pytorch", "pyside"], + total: 10, + hasMore: true + ) + let result = Complete.Result(completion: completion) + + #expect(result.completion.values.count == 3) + #expect(result.completion.values[0] == "python") + #expect(result.completion.total == 10) + #expect(result.completion.hasMore == true) + } + + @Test("Complete result encoding") + func testCompleteResultEncoding() throws { + let completion = Complete.Result.Completion( + values: ["python", "pytorch"], + total: 2, + hasMore: false + ) + let result = Complete.Result(completion: completion) + + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes] + + let data = try encoder.encode(result) + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] + + let completionDict = json?["completion"] as? [String: Any] + let values = completionDict?["values"] as? [String] + #expect(values == ["python", "pytorch"]) + #expect(completionDict?["total"] as? Int == 2) + #expect(completionDict?["hasMore"] as? Bool == false) + } + + @Test("Complete result decoding") + func testCompleteResultDecoding() throws { + let json = """ + { + "completion": { + "values": ["python", "pytorch", "pyside"], + "total": 10, + "hasMore": true + } + } + """ + + let decoder = JSONDecoder() + let result = try decoder.decode(Complete.Result.self, from: json.data(using: .utf8)!) + + #expect(result.completion.values.count == 3) + #expect(result.completion.values == ["python", "pytorch", "pyside"]) + #expect(result.completion.total == 10) + #expect(result.completion.hasMore == true) + } + + @Test("Complete result with optional fields") + func testCompleteResultWithOptionalFields() throws { + let completion = Complete.Result.Completion( + values: ["value1"], + total: nil, + hasMore: nil + ) + + #expect(completion.values == ["value1"]) + #expect(completion.total == nil) + #expect(completion.hasMore == nil) + } + + // MARK: - Client Integration Tests + + @Test("Client complete for prompt argument") + func testClientCompleteForPrompt() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0") + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init(completions: .init()) + ) + + // Register handler for complete on server + await server.withMethodHandler(Complete.self) { params in + #expect(params.argument.name == "language") + #expect(params.argument.value == "py") + + if case .prompt(let promptRef) = params.ref { + #expect(promptRef.name == "code_review") + } else { + Issue.record("Expected prompt reference") + } + + return .init( + completion: .init( + values: ["python", "pytorch", "pyside"], + total: 10, + hasMore: true + ) + ) + } + + try await server.start(transport: serverTransport) + let initResult = try await client.connect(transport: clientTransport) + + // Verify completions capability is advertised + #expect(initResult.capabilities.completions != nil) + + // Request completions + let completion = try await client.complete( + promptName: "code_review", + argumentName: "language", + argumentValue: "py" + ) + + #expect(completion.values == ["python", "pytorch", "pyside"]) + #expect(completion.total == 10) + #expect(completion.hasMore == true) + + await client.disconnect() + await server.stop() + } + + @Test("Client complete for resource argument") + func testClientCompleteForResource() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0") + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init(completions: .init()) + ) + + // Register handler for complete on server + await server.withMethodHandler(Complete.self) { params in + #expect(params.argument.name == "path") + #expect(params.argument.value == "/usr/") + + if case .resource(let resourceRef) = params.ref { + #expect(resourceRef.uri == "file:///{path}") + } else { + Issue.record("Expected resource reference") + } + + return .init( + completion: .init( + values: ["/usr/bin", "/usr/lib", "/usr/local"], + total: 3, + hasMore: false + ) + ) + } + + try await server.start(transport: serverTransport) + _ = try await client.connect(transport: clientTransport) + + // Request completions for resource + let completion = try await client.complete( + resourceURI: "file:///{path}", + argumentName: "path", + argumentValue: "/usr/" + ) + + #expect(completion.values == ["/usr/bin", "/usr/lib", "/usr/local"]) + #expect(completion.total == 3) + #expect(completion.hasMore == false) + + await client.disconnect() + await server.stop() + } + + @Test("Client complete with context") + func testClientCompleteWithContext() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0") + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init(completions: .init()) + ) + + // Register handler for complete on server + await server.withMethodHandler(Complete.self) { params in + #expect(params.argument.name == "framework") + #expect(params.argument.value == "fla") + #expect(params.context != nil) + #expect(params.context?.arguments["language"] == .string("python")) + + return .init( + completion: .init( + values: ["flask"], + total: 1, + hasMore: false + ) + ) + } + + try await server.start(transport: serverTransport) + _ = try await client.connect(transport: clientTransport) + + // Request completions with context + let completion = try await client.complete( + promptName: "code_review", + argumentName: "framework", + argumentValue: "fla", + context: ["language": .string("python")] + ) + + #expect(completion.values == ["flask"]) + #expect(completion.total == 1) + #expect(completion.hasMore == false) + + await client.disconnect() + await server.stop() + } + + @Test("Client complete fails without completions capability") + func testClientCompleteFailsWithoutCapability() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0", configuration: .strict) + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init() // No completions capability + ) + + try await server.start(transport: serverTransport) + let initResult = try await client.connect(transport: clientTransport) + + // Verify completions capability is NOT advertised + #expect(initResult.capabilities.completions == nil) + + // Attempt to request completions should fail in strict mode + await #expect(throws: MCPError.self) { + try await client.complete( + promptName: "test", + argumentName: "arg", + argumentValue: "val" + ) + } + + await client.disconnect() + await server.stop() + } + + @Test("Empty completion values") + func testEmptyCompletionValues() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0") + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init(completions: .init()) + ) + + // Register handler that returns empty results + await server.withMethodHandler(Complete.self) { _ in + return .init( + completion: .init( + values: [], + total: 0, + hasMore: false + ) + ) + } + + try await server.start(transport: serverTransport) + _ = try await client.connect(transport: clientTransport) + + let completion = try await client.complete( + promptName: "test", + argumentName: "arg", + argumentValue: "xyz" + ) + + #expect(completion.values.isEmpty) + #expect(completion.total == 0) + #expect(completion.hasMore == false) + + await client.disconnect() + await server.stop() + } + + @Test("Maximum completion values (100 items)") + func testMaximumCompletionValues() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0") + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init(completions: .init()) + ) + + // Register handler that returns 100 items + await server.withMethodHandler(Complete.self) { _ in + let values = (1...100).map { "value\($0)" } + return .init( + completion: .init( + values: values, + total: 200, + hasMore: true + ) + ) + } + + try await server.start(transport: serverTransport) + _ = try await client.connect(transport: clientTransport) + + let completion = try await client.complete( + promptName: "test", + argumentName: "arg", + argumentValue: "" + ) + + #expect(completion.values.count == 100) + #expect(completion.values.first == "value1") + #expect(completion.values.last == "value100") + #expect(completion.total == 200) + #expect(completion.hasMore == true) + + await client.disconnect() + await server.stop() + } + + @Test("Fuzzy matching completion scenario") + func testFuzzyMatchingScenario() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0") + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init(completions: .init()) + ) + + // Register handler that implements fuzzy matching + await server.withMethodHandler(Complete.self) { params in + let input = params.argument.value.lowercased() + let allLanguages = ["python", "perl", "php", "pascal", "prolog", "javascript", "java"] + + // Simple prefix matching + let matches = allLanguages.filter { $0.lowercased().hasPrefix(input) } + + return .init( + completion: .init( + values: matches, + total: matches.count, + hasMore: false + ) + ) + } + + try await server.start(transport: serverTransport) + _ = try await client.connect(transport: clientTransport) + + // Test with "p" prefix + let completion1 = try await client.complete( + promptName: "language_selector", + argumentName: "language", + argumentValue: "p" + ) + #expect(completion1.values.count == 5) // python, perl, php, pascal, prolog + + // Test with "py" prefix + let completion2 = try await client.complete( + promptName: "language_selector", + argumentName: "language", + argumentValue: "py" + ) + #expect(completion2.values == ["python"]) + + // Test with "ja" prefix + let completion3 = try await client.complete( + promptName: "language_selector", + argumentName: "language", + argumentValue: "ja" + ) + #expect(completion3.values == ["javascript", "java"]) + + await client.disconnect() + await server.stop() + } +} diff --git a/Tests/MCPTests/HTTPClientTransportTests.swift b/Tests/MCPTests/HTTPClientTransportTests.swift index cf2a25d8..b1867740 100644 --- a/Tests/MCPTests/HTTPClientTransportTests.swift +++ b/Tests/MCPTests/HTTPClientTransportTests.swift @@ -209,12 +209,12 @@ import Testing await MockURLProtocol.requestHandlerStorage.setHandler { [testEndpoint] (request: URLRequest) in - #expect(request.value(forHTTPHeaderField: "Mcp-Session-Id") == nil) + #expect(request.value(forHTTPHeaderField: "MCP-Session-Id") == nil) let response = HTTPURLResponse( url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", headerFields: [ "Content-Type": "application/json", - "Mcp-Session-Id": newSessionID, + "MCP-Session-Id": newSessionID, ])! return (response, Data()) } @@ -247,12 +247,12 @@ import Testing await MockURLProtocol.requestHandlerStorage.setHandler { [testEndpoint] (request: URLRequest) in #expect(request.readBody() == firstMessageData) - #expect(request.value(forHTTPHeaderField: "Mcp-Session-Id") == nil) + #expect(request.value(forHTTPHeaderField: "MCP-Session-Id") == nil) let response = HTTPURLResponse( url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", headerFields: [ "Content-Type": "application/json", - "Mcp-Session-Id": initialSessionID, + "MCP-Session-Id": initialSessionID, ])! return (response, Data()) } @@ -262,7 +262,7 @@ import Testing await MockURLProtocol.requestHandlerStorage.setHandler { [testEndpoint] (request: URLRequest) in #expect(request.readBody() == secondMessageData) - #expect(request.value(forHTTPHeaderField: "Mcp-Session-Id") == initialSessionID) + #expect(request.value(forHTTPHeaderField: "MCP-Session-Id") == initialSessionID) let response = HTTPURLResponse( url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", @@ -368,7 +368,7 @@ import Testing url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", headerFields: [ "Content-Type": "application/json", - "Mcp-Session-Id": initialSessionID, + "MCP-Session-Id": initialSessionID, ])! return (response, Data()) } @@ -387,7 +387,7 @@ import Testing // Set up the second handler for the 404 response await MockURLProtocol.requestHandlerStorage.setHandler { [testEndpoint, initialSessionID] (request: URLRequest) in - #expect(request.value(forHTTPHeaderField: "Mcp-Session-Id") == initialSessionID) + #expect(request.value(forHTTPHeaderField: "MCP-Session-Id") == initialSessionID) let response = HTTPURLResponse( url: testEndpoint, statusCode: 404, httpVersion: "HTTP/1.1", headerFields: nil)! return (response, Data("Not Found".utf8)) @@ -450,7 +450,7 @@ import Testing #expect(request.httpMethod == "GET") #expect(request.value(forHTTPHeaderField: "Accept") == "text/event-stream") #expect( - request.value(forHTTPHeaderField: "Mcp-Session-Id") == "test-session-123") + request.value(forHTTPHeaderField: "MCP-Session-Id") == "test-session-123") let response = HTTPURLResponse( url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", @@ -512,7 +512,7 @@ import Testing #expect(request.httpMethod == "GET") #expect(request.value(forHTTPHeaderField: "Accept") == "text/event-stream") #expect( - request.value(forHTTPHeaderField: "Mcp-Session-Id") == "test-session-123") + request.value(forHTTPHeaderField: "MCP-Session-Id") == "test-session-123") let response = HTTPURLResponse( url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", @@ -724,6 +724,40 @@ import Testing try await transport.send(messageData) await transport.disconnect() } + + @Test("Send With Protocol Version Header", .httpClientTransportSetup) + func testProtocolVersionHeader() async throws { + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockURLProtocol.self] + + let protocolVersion = "2025-11-25" + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + protocolVersion: protocolVersion, + logger: nil + ) + try await transport.connect() + + let messageData = #"{"jsonrpc":"2.0","method":"test","id":6}"#.data(using: .utf8)! + + await MockURLProtocol.requestHandlerStorage.setHandler { + [testEndpoint, protocolVersion] (request: URLRequest) in + // Verify the protocol version header is present + #expect( + request.value(forHTTPHeaderField: "MCP-Protocol-Version") + == protocolVersion) + + let response = HTTPURLResponse( + url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (response, Data()) + } + + try await transport.send(messageData) + await transport.disconnect() + } #endif // !canImport(FoundationNetworking) } #endif // swift(>=6.1) diff --git a/Tests/MCPTests/HTTPServerTransportTests.swift b/Tests/MCPTests/HTTPServerTransportTests.swift new file mode 100644 index 00000000..d92b6973 --- /dev/null +++ b/Tests/MCPTests/HTTPServerTransportTests.swift @@ -0,0 +1,899 @@ +import Foundation +import Testing + +@testable import MCP + +// MARK: - Test Helpers + +private struct FixedSessionIDGenerator: SessionIDGenerator { + let sessionID: String + func generateSessionID() -> String { sessionID } +} + +private func makeInitializeBody(id: String = "1") -> Data { + let json: [String: Any] = [ + "jsonrpc": "2.0", + "id": id, + "method": "initialize", + "params": [ + "protocolVersion": "2025-11-25", + "capabilities": [:] as [String: Any], + "clientInfo": ["name": "test", "version": "1.0"], + ] as [String: Any], + ] + return try! JSONSerialization.data(withJSONObject: json) +} + +private func makeNotificationBody(method: String = "notifications/initialized") -> Data { + let json: [String: Any] = ["jsonrpc": "2.0", "method": method] + return try! JSONSerialization.data(withJSONObject: json) +} + +private func makeRequestBody(id: String = "2", method: String = "tools/list") -> Data { + let json: [String: Any] = [ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": [:] as [String: Any], + ] + return try! JSONSerialization.data(withJSONObject: json) +} + +private func makeResponseBody(id: String = "2") -> Data { + let json: [String: Any] = [ + "jsonrpc": "2.0", + "id": id, + "result": ["tools": []] as [String: Any], + ] + return try! JSONSerialization.data(withJSONObject: json) +} + +private func makeStatefulPOSTRequest(body: Data, sessionID: String? = nil) -> HTTPRequest { + var headers: [String: String] = [ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + ] + if let sessionID { + headers["Mcp-Session-Id"] = sessionID + } + return HTTPRequest(method: "POST", headers: headers, body: body) +} + +private func makeGETRequest(sessionID: String, lastEventID: String? = nil) -> HTTPRequest { + var headers: [String: String] = [ + "Accept": "text/event-stream", + "Mcp-Session-Id": sessionID, + ] + if let lastEventID { + headers["Last-Event-Id"] = lastEventID + } + return HTTPRequest(method: "GET", headers: headers) +} + +private func makeDELETERequest(sessionID: String) -> HTTPRequest { + HTTPRequest( + method: "DELETE", + headers: ["Mcp-Session-Id": sessionID] + ) +} + +private func makeStatelessPOSTRequest(body: Data) -> HTTPRequest { + HTTPRequest( + method: "POST", + headers: [ + "Content-Type": "application/json", + "Accept": "application/json", + ], + body: body + ) +} + +private func makeStatefulTransport( + sessionIDGenerator: any SessionIDGenerator = UUIDSessionIDGenerator() +) -> StatefulHTTPServerTransport { + StatefulHTTPServerTransport( + sessionIDGenerator: sessionIDGenerator, + validationPipeline: StandardValidationPipeline(validators: []) + ) +} + +private func makeStatelessTransport() -> StatelessHTTPServerTransport { + StatelessHTTPServerTransport( + validationPipeline: StandardValidationPipeline(validators: []) + ) +} + +/// Drains an SSE stream, collecting raw SSE chunks. +private actor ChunkCollector { + var chunks: [Data] = [] + func append(_ data: Data) { chunks.append(data) } + func getChunks() -> [Data] { chunks } +} + +private func drainSSEStream( + _ response: HTTPResponse, + maxChunks: Int = 10, + timeout: Duration = .seconds(2) +) async -> [Data] { + guard case .stream(let stream, _) = response else { return [] } + let collector = ChunkCollector() + let task = Task { + for try await chunk in stream { + await collector.append(chunk) + if await collector.getChunks().count >= maxChunks { break } + } + } + // Wait for stream to finish or timeout + try? await Task.sleep(for: timeout) + task.cancel() + return await collector.getChunks() +} + +/// Initializes a stateful transport session and returns the session ID. +/// Spawns a background task to consume the receive stream and send the init response. +private func initializeSession( + transport: StatefulHTTPServerTransport, + sessionID: String? = nil +) async throws -> String { + try await transport.connect() + + let initBody = makeInitializeBody() + + // Background task: read the init request from receive() and send back a response + let respondTask = Task { + let stream = await transport.receive() + for try await data in stream { + // Check if this is the initialize request + if let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], + let method = json["method"] as? String, method == "initialize", + let id = json["id"] + { + let idString: String + if let s = id as? String { idString = s } + else if let n = id as? Int { idString = String(n) } + else { continue } + + let responseJSON: [String: Any] = [ + "jsonrpc": "2.0", + "id": idString, + "result": [ + "protocolVersion": "2025-11-25", + "serverInfo": ["name": "test", "version": "1.0"], + "capabilities": [:] as [String: Any], + ] as [String: Any], + ] + let responseData = try JSONSerialization.data(withJSONObject: responseJSON) + try await transport.send(responseData) + return + } + } + } + + let response = await transport.handleRequest( + makeStatefulPOSTRequest(body: initBody) + ) + + // Extract session ID + guard let sid = response.headers[HTTPHeaderName.sessionID] else { + throw MCPError.internalError("No session ID in init response") + } + + // Drain the SSE stream so the response task can complete + if case .stream(let stream, _) = response { + Task { for try await _ in stream {} } + } + + // Wait for the respond task + try? await respondTask.value + + return sid +} + +// MARK: - StatefulHTTPServerTransport Tests + +@Suite("StatefulHTTPServerTransport Tests") +struct StatefulHTTPServerTransportTests { + + // MARK: - Lifecycle + + @Test("Connect succeeds") + func testConnectSucceeds() async throws { + let transport = makeStatefulTransport() + try await transport.connect() + await transport.disconnect() + } + + @Test("Double connect throws") + func testDoubleConnectThrows() async throws { + let transport = makeStatefulTransport() + try await transport.connect() + do { + try await transport.connect() + Issue.record("Expected error on double connect") + } catch { + // Expected + } + await transport.disconnect() + } + + @Test("Send after disconnect throws connectionClosed") + func testSendAfterDisconnectThrows() async throws { + let transport = makeStatefulTransport() + try await transport.connect() + await transport.disconnect() + do { + try await transport.send(Data("test".utf8)) + Issue.record("Expected connectionClosed error") + } catch let error as MCPError { + #expect(error == .connectionClosed) + } + } + + // MARK: - POST Initialize + + @Test("Initialize creates session and returns SSE stream") + func testInitializeCreatesSession() async throws { + let transport = makeStatefulTransport( + sessionIDGenerator: FixedSessionIDGenerator(sessionID: "test-session-42") + ) + try await transport.connect() + + let response = await transport.handleRequest( + makeStatefulPOSTRequest(body: makeInitializeBody()) + ) + + #expect(response.statusCode == 200) + #expect(response.headers[HTTPHeaderName.sessionID] == "test-session-42") + + if case .stream = response { + // Expected + } else { + Issue.record("Expected .stream response, got \(response)") + } + + // Drain stream + if case .stream(let stream, _) = response { + Task { for try await _ in stream {} } + } + await transport.disconnect() + } + + @Test("Initialize with invalid session ID returns 500") + func testInitializeWithInvalidSessionIDReturns500() async throws { + // Control character \t is 0x09, outside valid range 0x21-0x7E + let transport = makeStatefulTransport( + sessionIDGenerator: FixedSessionIDGenerator(sessionID: "bad\tsession") + ) + try await transport.connect() + + let response = await transport.handleRequest( + makeStatefulPOSTRequest(body: makeInitializeBody()) + ) + + #expect(response.statusCode == 500) + } + + @Test("Custom SessionIDGenerator is used") + func testCustomSessionIDGenerator() async throws { + let transport = makeStatefulTransport( + sessionIDGenerator: FixedSessionIDGenerator(sessionID: "custom-id-abc") + ) + try await transport.connect() + + let response = await transport.handleRequest( + makeStatefulPOSTRequest(body: makeInitializeBody()) + ) + + #expect(response.headers[HTTPHeaderName.sessionID] == "custom-id-abc") + + if case .stream(let stream, _) = response { + Task { for try await _ in stream {} } + } + await transport.disconnect() + } + + @Test("Default UUIDSessionIDGenerator produces valid session ID") + func testDefaultGeneratorProducesUUID() async throws { + let transport = makeStatefulTransport() + try await transport.connect() + + let response = await transport.handleRequest( + makeStatefulPOSTRequest(body: makeInitializeBody()) + ) + + let sessionID = response.headers[HTTPHeaderName.sessionID] + #expect(sessionID != nil) + // UUID format: 8-4-4-4-12 hex chars + if let sid = sessionID { + #expect(sid.count == 36) + #expect(sid.contains("-")) + } + + if case .stream(let stream, _) = response { + Task { for try await _ in stream {} } + } + await transport.disconnect() + } + + // MARK: - POST Notification + + @Test("Notification returns 202 Accepted") + func testNotificationReturns202() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + let response = await transport.handleRequest( + makeStatefulPOSTRequest( + body: makeNotificationBody(), + sessionID: sessionID + ) + ) + + #expect(response.statusCode == 202) + await transport.disconnect() + } + + @Test("Notification yields to receive stream") + func testNotificationYieldsToReceive() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + let notificationBody = makeNotificationBody(method: "notifications/test") + + // Start receiving + let receiveTask = Task { + let stream = await transport.receive() + for try await data in stream { + // Skip init request if still in stream + if let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], + let method = json["method"] as? String, method == "notifications/test" + { + return data + } + } + return nil + } + + // Small delay to let receive() start + try await Task.sleep(for: .milliseconds(50)) + + _ = await transport.handleRequest( + makeStatefulPOSTRequest(body: notificationBody, sessionID: sessionID) + ) + + let received = try await receiveTask.value + #expect(received != nil) + + await transport.disconnect() + } + + // MARK: - POST Request/Response + + @Test("POST request returns SSE stream") + func testRequestReturnsSSEStream() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + let response = await transport.handleRequest( + makeStatefulPOSTRequest( + body: makeRequestBody(id: "req-1"), + sessionID: sessionID + ) + ) + + #expect(response.statusCode == 200) + if case .stream = response { + // Expected + } else { + Issue.record("Expected .stream response") + } + + if case .stream(let stream, _) = response { + Task { for try await _ in stream {} } + } + await transport.disconnect() + } + + @Test("Response is routed to matching request SSE stream") + func testResponseRoutedToRequestStream() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + let requestID = "route-test-1" + + // POST a request + let response = await transport.handleRequest( + makeStatefulPOSTRequest( + body: makeRequestBody(id: requestID, method: "tools/list"), + sessionID: sessionID + ) + ) + + guard case .stream(let stream, _) = response else { + Issue.record("Expected .stream response") + return + } + + // Collect SSE chunks in background + let collectTask = Task { + var chunks: [Data] = [] + for try await chunk in stream { + chunks.append(chunk) + } + return chunks + } + + // Give stream time to start + try await Task.sleep(for: .milliseconds(50)) + + // Consume the request from receive and send the response + let responseBody = makeResponseBody(id: requestID) + try await transport.send(responseBody) + + // Collect all SSE chunks + let chunks = try await collectTask.value + + // Should have at least one chunk containing the response data + let allText = chunks.map { String(decoding: $0, as: UTF8.self) }.joined() + #expect(allText.contains("data:")) + #expect(allText.contains(requestID)) + + await transport.disconnect() + } + + // MARK: - GET Stream + + @Test("GET returns standalone SSE stream") + func testGetReturnsSSEStream() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + let response = await transport.handleRequest( + makeGETRequest(sessionID: sessionID) + ) + + #expect(response.statusCode == 200) + if case .stream = response { + // Expected + } else { + Issue.record("Expected .stream response for GET") + } + + if case .stream(let stream, _) = response { + Task { for try await _ in stream {} } + } + await transport.disconnect() + } + + @Test("Server-initiated message routed to GET stream") + func testServerMessageRoutedToGetStream() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + // Open GET stream + let getResponse = await transport.handleRequest( + makeGETRequest(sessionID: sessionID) + ) + + guard case .stream(let stream, _) = getResponse else { + Issue.record("Expected .stream response for GET") + return + } + + // Collect chunks + let collectTask = Task { + var chunks: [Data] = [] + for try await chunk in stream { + chunks.append(chunk) + // priming + message + if chunks.count >= 2 { break } + } + return chunks + } + + try await Task.sleep(for: .milliseconds(50)) + + // Send a notification (server-initiated) + let notification: [String: Any] = [ + "jsonrpc": "2.0", + "method": "notifications/test", + "params": [:] as [String: Any], + ] + let notifData = try JSONSerialization.data(withJSONObject: notification) + try await transport.send(notifData) + + let chunks = try await collectTask.value + let allText = chunks.map { String(decoding: $0, as: UTF8.self) }.joined() + #expect(allText.contains("data:")) + // JSONSerialization may escape "/" as "\/" in some configurations + #expect(allText.contains("notifications/test") || allText.contains("notifications\\/test")) + + await transport.disconnect() + } + + @Test("Second GET returns 409 Conflict") + func testSecondGetReturns409() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + // First GET + let first = await transport.handleRequest(makeGETRequest(sessionID: sessionID)) + #expect(first.statusCode == 200) + + // Second GET + let second = await transport.handleRequest(makeGETRequest(sessionID: sessionID)) + #expect(second.statusCode == 409) + + if case .stream(let stream, _) = first { + Task { for try await _ in stream {} } + } + await transport.disconnect() + } + + // MARK: - DELETE + + @Test("DELETE terminates session") + func testDeleteTerminatesSession() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + let response = await transport.handleRequest( + makeDELETERequest(sessionID: sessionID) + ) + + #expect(response.statusCode == 200) + } + + @Test("Requests after DELETE return 404") + func testRequestsAfterDeleteReturn404() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + // DELETE + _ = await transport.handleRequest(makeDELETERequest(sessionID: sessionID)) + + // POST after delete + let response = await transport.handleRequest( + makeStatefulPOSTRequest( + body: makeRequestBody(), + sessionID: sessionID + ) + ) + + #expect(response.statusCode == 404) + } + + // MARK: - Terminated State + + @Test("All methods return 404 when terminated") + func testAllMethodsReturn404WhenTerminated() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + await transport.disconnect() + + let post = await transport.handleRequest( + makeStatefulPOSTRequest(body: makeRequestBody(), sessionID: sessionID) + ) + #expect(post.statusCode == 404) + + let get = await transport.handleRequest(makeGETRequest(sessionID: sessionID)) + #expect(get.statusCode == 404) + + let delete = await transport.handleRequest(makeDELETERequest(sessionID: sessionID)) + #expect(delete.statusCode == 404) + } + + // MARK: - Error Cases + + @Test("Unsupported method returns 405") + func testUnsupportedMethodReturns405() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + let response = await transport.handleRequest( + HTTPRequest( + method: "PUT", + headers: ["Mcp-Session-Id": sessionID], + body: Data("test".utf8) + ) + ) + + #expect(response.statusCode == 405) + await transport.disconnect() + } + + @Test("Empty body returns 400") + func testEmptyBodyReturns400() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + let response = await transport.handleRequest( + makeStatefulPOSTRequest(body: Data(), sessionID: sessionID) + ) + + #expect(response.statusCode == 400) + await transport.disconnect() + } + + @Test("Invalid JSON body returns 400") + func testInvalidJSONReturns400() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + let response = await transport.handleRequest( + makeStatefulPOSTRequest(body: Data("not json".utf8), sessionID: sessionID) + ) + + #expect(response.statusCode == 400) + await transport.disconnect() + } + + // MARK: - Resumability + + @Test("GET with Last-Event-ID replays stored events") + func testGetWithLastEventIDReplaysEvents() async throws { + let transport = makeStatefulTransport() + let sessionID = try await initializeSession(transport: transport) + + // POST a request to create events in the store + let requestID = "resume-test" + let postResponse = await transport.handleRequest( + makeStatefulPOSTRequest( + body: makeRequestBody(id: requestID), + sessionID: sessionID + ) + ) + + guard case .stream(let postStream, _) = postResponse else { + Issue.record("Expected .stream") + return + } + + // Collect the priming event to get its ID + let eventIDHolder = ChunkCollector() + let collectTask = Task { + for try await chunk in postStream { + await eventIDHolder.append(chunk) + break // Just get the first chunk (priming) + } + } + + try await Task.sleep(for: .milliseconds(50)) + + // Send the response to create a stored event + try await transport.send(makeResponseBody(id: requestID)) + + try? await collectTask.value + + // Parse event ID from the collected priming event + let collectedChunks = await eventIDHolder.getChunks() + let primingEventID: String? = collectedChunks.first.flatMap { chunk in + let text = String(decoding: chunk, as: UTF8.self) + guard let range = text.range(of: "id: ") else { return nil } + let afterID = text[range.upperBound...] + guard let newline = afterID.firstIndex(of: "\n") else { return nil } + return String(afterID[...self, from: json.data(using: .utf8)!) + + #expect(request.method == "logging/setLevel") + #expect(request.params.level == .warning) + } + + @Test("SetLoggingLevel response") + func testSetLoggingLevelResponse() throws { + let response = SetLoggingLevel.response(id: .random) + + if case .success = response.result { + // Success case + } else { + Issue.record("Expected success result") + } + } + + // MARK: - LogMessageNotification Tests + + @Test("LogMessageNotification initialization") + func testLogMessageNotificationInitialization() throws { + let data = Value.object([ + "message": Value.string("Test log message"), + "code": Value.int(42) + ]) + + let params = LogMessageNotification.Parameters( + level: .info, + logger: "test-logger", + data: data + ) + + #expect(params.level == LogLevel.info) + #expect(params.logger == "test-logger") + #expect(params.data == data) + } + + @Test("LogMessageNotification with nil logger") + func testLogMessageNotificationWithNilLogger() throws { + let data = Value.object(["message": Value.string("Test")]) + + let params = LogMessageNotification.Parameters( + level: .debug, + logger: nil, + data: data + ) + + #expect(params.level == LogLevel.debug) + #expect(params.logger == nil) + #expect(params.data == data) + } + + @Test("LogMessageNotification encoding") + func testLogMessageNotificationEncoding() throws { + let data = Value.object([ + "error": Value.string("Connection failed"), + "details": .object([ + "host": Value.string("localhost"), + "port": Value.int(5432) + ]) + ]) + + let notification = LogMessageNotification.message( + .init(level: .error, logger: "database", data: data) + ) + + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes] + + let encodedData = try encoder.encode(notification) + let json = try JSONSerialization.jsonObject(with: encodedData) as? [String: Any] + + guard let jsonValue = json else { + Issue.record("Failed to parse JSON") + return + } + + #expect(jsonValue["jsonrpc"] as? String == "2.0") + #expect(jsonValue["method"] as? String == "notifications/message") + + guard let params = jsonValue["params"] as? [String: Any] else { + Issue.record("Failed to get params") + return + } + #expect(params["level"] as? String == "error") + #expect(params["logger"] as? String == "database") + + guard let dataDict = params["data"] as? [String: Any] else { + Issue.record("Failed to get data") + return + } + #expect(dataDict["error"] as? String == "Connection failed") + } + + @Test("LogMessageNotification decoding") + func testLogMessageNotificationDecoding() throws { + let json = """ + { + "jsonrpc": "2.0", + "method": "notifications/message", + "params": { + "level": "info", + "logger": "app", + "data": { + "message": "Server started", + "port": 8080 + } + } + } + """ + + let decoder = JSONDecoder() + let notification = try decoder.decode(Message.self, from: json.data(using: .utf8)!) + + #expect(notification.method == "notifications/message") + #expect(notification.params.level == LogLevel.info) + #expect(notification.params.logger == "app") + + if case .object(let dataDict) = notification.params.data { + #expect(dataDict["message"] == Value.string("Server started")) + #expect(dataDict["port"] == Value.int(8080)) + } else { + Issue.record("Expected object data") + } + } + + // MARK: - Client Integration Tests + + @Test("Client setLoggingLevel sends correct request") + func testClientSetLoggingLevel() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0") + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init(logging: .init()) + ) + + actor TestState { + var receivedLevel: LogLevel? + func setLevel(_ level: LogLevel) { receivedLevel = level } + func getLevel() -> LogLevel? { receivedLevel } + } + + let state = TestState() + + // Register handler for setLoggingLevel on server + await server.withMethodHandler(SetLoggingLevel.self) { params in + await state.setLevel(params.level) + return Empty() + } + + try await server.start(transport: serverTransport) + let initResult = try await client.connect(transport: clientTransport) + + // Verify logging capability is advertised + #expect(initResult.capabilities.logging != nil) + + // Call setLoggingLevel + try await client.setLoggingLevel(.warning) + + // Give time for message processing + try await Task.sleep(for: .milliseconds(100)) + + // Verify the handler was called + #expect(await state.getLevel() == .warning) + + await client.disconnect() + await server.stop() + } + + @Test("Client setLoggingLevel fails without logging capability") + func testClientSetLoggingLevelFailsWithoutCapability() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0", configuration: .strict) + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init() // No logging capability + ) + + try await server.start(transport: serverTransport) + let initResult = try await client.connect(transport: clientTransport) + + // Verify logging capability is NOT advertised + #expect(initResult.capabilities.logging == nil) + + // Attempt to set logging level should fail in strict mode + await #expect(throws: MCPError.self) { + try await client.setLoggingLevel(.info) + } + + await client.disconnect() + await server.stop() + } + + // MARK: - Server Integration Tests + + @Test("Server log method sends notification") + func testServerLogMethod() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0") + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init(logging: .init()) + ) + + actor TestState { + var logMessages: [(level: LogLevel, logger: String?, data: Value)] = [] + func addLog(level: LogLevel, logger: String?, data: Value) { + logMessages.append((level, logger, data)) + } + func getLogs() -> [(level: LogLevel, logger: String?, data: Value)] { logMessages } + } + + let state = TestState() + + // Register handler for log notifications on client + await client.onNotification(LogMessageNotification.self) { message in + await state.addLog( + level: message.params.level, + logger: message.params.logger, + data: message.params.data + ) + } + + try await server.start(transport: serverTransport) + _ = try await client.connect(transport: clientTransport) + + // Send a log message + let logData = Value.object([ + "message": Value.string("Test log"), + "count": Value.int(42) + ]) + + try await server.log(level: .info, logger: "test", data: logData) + + // Wait for message processing + try await Task.sleep(for: .milliseconds(100)) + + // Verify the notification was received + let logs = await state.getLogs() + #expect(logs.count == 1) + #expect(logs[0].level == LogLevel.info) + #expect(logs[0].logger == "test") + #expect(logs[0].data == logData) + + await client.disconnect() + await server.stop() + } + + @Test("Server log method with codable data") + func testServerLogMethodWithCodableData() async throws { + struct LogData: Codable, Hashable { + let message: String + let timestamp: String + let code: Int + } + + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0") + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init(logging: .init()) + ) + + actor TestState { + var logMessages: [(level: LogLevel, logger: String?, data: Value)] = [] + func addLog(level: LogLevel, logger: String?, data: Value) { + logMessages.append((level, logger, data)) + } + func getLogs() -> [(level: LogLevel, logger: String?, data: Value)] { logMessages } + } + + let state = TestState() + + // Register handler for log notifications on client + await client.onNotification(LogMessageNotification.self) { message in + await state.addLog( + level: message.params.level, + logger: message.params.logger, + data: message.params.data + ) + } + + try await server.start(transport: serverTransport) + _ = try await client.connect(transport: clientTransport) + + // Send a log message with codable data + let logData = LogData( + message: "Error occurred", + timestamp: "2025-01-29T12:00:00Z", + code: 500 + ) + + try await server.log(level: .error, logger: "api", data: logData) + + // Wait for message processing + try await Task.sleep(for: .milliseconds(100)) + + // Verify the notification was received + let logs = await state.getLogs() + #expect(logs.count == 1) + #expect(logs[0].level == LogLevel.error) + #expect(logs[0].logger == "api") + + // Verify data content + if case .object(let dataDict) = logs[0].data { + #expect(dataDict["message"] == Value.string("Error occurred")) + #expect(dataDict["timestamp"] == Value.string("2025-01-29T12:00:00Z")) + #expect(dataDict["code"] == Value.int(500)) + } else { + Issue.record("Expected object data") + } + + await client.disconnect() + await server.stop() + } + + @Test("Server log without logger name") + func testServerLogWithoutLoggerName() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0") + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init(logging: .init()) + ) + + actor TestState { + var logMessages: [(level: LogLevel, logger: String?, data: Value)] = [] + func addLog(level: LogLevel, logger: String?, data: Value) { + logMessages.append((level, logger, data)) + } + func getLogs() -> [(level: LogLevel, logger: String?, data: Value)] { logMessages } + } + + let state = TestState() + + // Register handler for log notifications on client + await client.onNotification(LogMessageNotification.self) { message in + await state.addLog( + level: message.params.level, + logger: message.params.logger, + data: message.params.data + ) + } + + try await server.start(transport: serverTransport) + _ = try await client.connect(transport: clientTransport) + + // Send a log message without logger name + let logData = Value.object(["message": Value.string("Generic log")]) + try await server.log(level: .debug, data: logData) + + // Wait for message processing + try await Task.sleep(for: .milliseconds(100)) + + // Verify the notification was received + let logs = await state.getLogs() + #expect(logs.count == 1) + #expect(logs[0].level == LogLevel.debug) + #expect(logs[0].logger == nil) + + await client.disconnect() + await server.stop() + } + + @Test("Multiple log levels sent correctly") + func testMultipleLogLevels() async throws { + let (clientTransport, serverTransport) = await InMemoryTransport.createConnectedPair() + + let client = Client(name: "TestClient", version: "1.0") + let server = Server( + name: "TestServer", + version: "1.0", + capabilities: .init(logging: .init()) + ) + + actor TestState { + var logMessages: [(level: LogLevel, logger: String?, data: Value)] = [] + func addLog(level: LogLevel, logger: String?, data: Value) { + logMessages.append((level, logger, data)) + } + func getLogs() -> [(level: LogLevel, logger: String?, data: Value)] { logMessages } + } + + let state = TestState() + + // Register handler for log notifications on client + await client.onNotification(LogMessageNotification.self) { message in + await state.addLog( + level: message.params.level, + logger: message.params.logger, + data: message.params.data + ) + } + + try await server.start(transport: serverTransport) + _ = try await client.connect(transport: clientTransport) + + // Send log messages at different levels + try await server.log(level: .debug, data: Value.object(["msg": Value.string("Debug message")])) + try await server.log(level: .info, data: Value.object(["msg": Value.string("Info message")])) + try await server.log(level: .warning, data: Value.object(["msg": Value.string("Warning message")])) + try await server.log(level: .error, data: Value.object(["msg": Value.string("Error message")])) + try await server.log(level: .critical, data: Value.object(["msg": Value.string("Critical message")])) + + // Wait for message processing + try await Task.sleep(for: .milliseconds(200)) + + // Verify all notifications were received + let logs = await state.getLogs() + #expect(logs.count == 5) + #expect(logs[0].level == LogLevel.debug) + #expect(logs[1].level == LogLevel.info) + #expect(logs[2].level == LogLevel.warning) + #expect(logs[3].level == LogLevel.error) + #expect(logs[4].level == LogLevel.critical) + + await client.disconnect() + await server.stop() + } +} diff --git a/Tests/MCPTests/PromptTests.swift b/Tests/MCPTests/PromptTests.swift index 561bdfca..6eb864fb 100644 --- a/Tests/MCPTests/PromptTests.swift +++ b/Tests/MCPTests/PromptTests.swift @@ -88,19 +88,20 @@ struct PromptTests { } // Test resource content - let resourceContent = Prompt.Message.Content.resource( + let textResourceContent = Resource.Content.text( + "Sample text", uri: "file://test.txt", - mimeType: "text/plain", - text: "Sample text", - blob: "blob_data" + mimeType: "text/plain" ) + let resourceContent = Prompt.Message.Content.resource(resource: textResourceContent, annotations: nil, _meta: nil) let resourceData = try encoder.encode(resourceContent) let decodedResource = try decoder.decode(Prompt.Message.Content.self, from: resourceData) - if case .resource(let uri, let mimeType, let text, let blob) = decodedResource { - #expect(uri == "file://test.txt") - #expect(mimeType == "text/plain") - #expect(text == "Sample text") - #expect(blob == "blob_data") + if case .resource(let resourceData, let annotations, let _meta) = decodedResource { + #expect(resourceData.uri == "file://test.txt") + #expect(resourceData.mimeType == "text/plain") + #expect(resourceData.text == "Sample text") + #expect(annotations == nil) + #expect(_meta == nil) } else { #expect(Bool(false), "Expected resource content") } @@ -249,15 +250,19 @@ struct PromptTests { } // Test with resource content - let resourceMessage: Prompt.Message = .user( - .resource( - uri: "file://test.txt", mimeType: "text/plain", text: "Sample text", blob: nil)) + let resourceContent = Resource.Content.text( + "Sample text", + uri: "file://test.txt", + mimeType: "text/plain" + ) + let resourceMessage: Prompt.Message = .user(.resource(resource: resourceContent, annotations: nil, _meta: nil)) #expect(resourceMessage.role == .user) - if case .resource(let uri, let mimeType, let text, let blob) = resourceMessage.content { - #expect(uri == "file://test.txt") - #expect(mimeType == "text/plain") - #expect(text == "Sample text") - #expect(blob == nil) + if case .resource(let resource, let annotations, let _meta) = resourceMessage.content { + #expect(resource.uri == "file://test.txt") + #expect(resource.mimeType == "text/plain") + #expect(resource.text == "Sample text") + #expect(annotations == nil) + #expect(_meta == nil) } else { #expect(Bool(false), "Expected resource content") } diff --git a/Tests/MCPTests/ToolTests.swift b/Tests/MCPTests/ToolTests.swift index 41367b66..73ee1b6e 100644 --- a/Tests/MCPTests/ToolTests.swift +++ b/Tests/MCPTests/ToolTests.swift @@ -279,23 +279,24 @@ struct ToolTests { @Test("Resource content encoding and decoding") func testToolContentResourceEncoding() throws { - let content = Tool.Content.resource( + let resourceContent = Resource.Content.text( + "Sample text", uri: "file://test.txt", - mimeType: "text/plain", - text: "Sample text" + mimeType: "text/plain" ) + let content = Tool.Content.resource(resource: resourceContent, annotations: nil, _meta: nil) let encoder = JSONEncoder() let decoder = JSONDecoder() let data = try encoder.encode(content) let decoded = try decoder.decode(Tool.Content.self, from: data) - if case .resource(let uri, let mimeType, let text, let title, let annotations) = decoded { - #expect(uri == "file://test.txt") - #expect(mimeType == "text/plain") - #expect(text == "Sample text") - #expect(title == nil) + if case .resource(let resource, let annotations, let _meta) = decoded { + #expect(resource.uri == "file://test.txt") + #expect(resource.mimeType == "text/plain") + #expect(resource.text == "Sample text") #expect(annotations == nil) + #expect(_meta == nil) } else { #expect(Bool(false), "Expected resource content") } diff --git a/conformance-baseline.yml b/conformance-baseline.yml new file mode 100644 index 00000000..1421cec6 --- /dev/null +++ b/conformance-baseline.yml @@ -0,0 +1,28 @@ +client: + - elicitation-sep1034-client-defaults + - auth/metadata-default + - auth/metadata-var1 + - auth/metadata-var2 + - auth/metadata-var3 + - auth/basic-cimd + - auth/scope-from-www-authenticate + - auth/scope-from-scopes-supported + - auth/scope-omitted-when-undefined + - auth/scope-step-up + - auth/scope-retry-limit + - auth/token-endpoint-auth-basic + - auth/token-endpoint-auth-post + - auth/token-endpoint-auth-none + - auth/pre-registration + - auth/2025-03-26-oauth-metadata-backcompat + - auth/2025-03-26-oauth-endpoint-fallback + - auth/client-credentials-jwt + - auth/client-credentials-basic + +server: + - tools-call-sampling + - tools-call-elicitation + - json-schema-2020-12 + - elicitation-sep1034-defaults + - server-sse-polling + - elicitation-sep1330-enums diff --git a/scripts/run-conformance.sh b/scripts/run-conformance.sh new file mode 100755 index 00000000..4c8c54d4 --- /dev/null +++ b/scripts/run-conformance.sh @@ -0,0 +1,113 @@ +#!/bin/bash +set -euo pipefail + +# Color output helpers +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +log_info() { echo -e "${GREEN}[INFO]${NC} $*"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $*"; } +log_error() { echo -e "${RED}[ERROR]${NC} $*"; } + +# Configuration +CONFORMANCE_PKG="@modelcontextprotocol/conformance" +CLIENT_EXEC="mcp-everything-client" +SERVER_EXEC="mcp-everything-server" +BASELINE_FILE="${BASELINE_FILE:-conformance-baseline.yml}" +MODE="${MODE:-both}" + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + --mode) + MODE="$2" + shift 2 + ;; + --baseline) + BASELINE_FILE="$2" + shift 2 + ;; + *) + log_error "Unknown option: $1" + exit 1 + ;; + esac +done + +# Validate mode +if [[ ! "$MODE" =~ ^(client|server|both)$ ]]; then + log_error "Invalid mode: $MODE. Must be one of: client, server, both" + exit 1 +fi + +# Build Swift executables +log_info "Building Swift executables..." +swift build --product "$CLIENT_EXEC" || { + log_error "Failed to build client" + exit 1 +} +swift build --product "$SERVER_EXEC" || { + log_error "Failed to build server" + exit 1 +} + +CLIENT_PATH="$(swift build --show-bin-path)/$CLIENT_EXEC" +SERVER_PATH="$(swift build --show-bin-path)/$SERVER_EXEC" + +log_info "Client executable: $CLIENT_PATH" +log_info "Server executable: $SERVER_PATH" + +# Check for baseline file +BASELINE_ARG="" +if [[ -f "$BASELINE_FILE" ]]; then + log_info "Using baseline file: $BASELINE_FILE" + BASELINE_ARG="--expected-failures $BASELINE_FILE" +else + log_warn "No baseline file found at $BASELINE_FILE" +fi + +# Run client tests +if [[ "$MODE" == "client" || "$MODE" == "both" ]]; then + log_info "Running client conformance tests..." + npx "$CONFORMANCE_PKG" client \ + --command "$CLIENT_PATH" \ + --suite core \ + $BASELINE_ARG || { + log_error "Client conformance tests failed" + exit 1 + } + log_info "Client tests completed" +fi + +# Run server tests +if [[ "$MODE" == "server" || "$MODE" == "both" ]]; then + log_info "Starting server for conformance testing..." + + # Start server in background + "$SERVER_PATH" & + SERVER_PID=$! + + # Wait for server to be ready + log_info "Waiting for server to start (PID: $SERVER_PID)..." + sleep 3 + + # Run server tests + log_info "Running server conformance tests..." + npx "$CONFORMANCE_PKG" server \ + --url http://localhost:3001/mcp \ + --suite core \ + $BASELINE_ARG || { + log_error "Server conformance tests failed" + kill $SERVER_PID 2>/dev/null || true + exit 1 + } + + # Cleanup + log_info "Stopping server..." + kill $SERVER_PID 2>/dev/null || true + log_info "Server tests completed" +fi + +log_info "All conformance tests completed successfully"