diff --git a/.github/workflows/swift.yml b/.github/workflows/swift.yml new file mode 100644 index 00000000..6025aa70 --- /dev/null +++ b/.github/workflows/swift.yml @@ -0,0 +1,21 @@ +name: Swift + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + + runs-on: macos-14 + + steps: + - uses: actions/checkout@v3 + - name: Build + run: cd swift && swift build -v + - name: Build Release + run: cd swift && swift build -c release + - name: Run tests + run: cd swift && swift test -v diff --git a/.gitignore b/.gitignore index dfcfd56f..f6d566cf 100644 --- a/.gitignore +++ b/.gitignore @@ -348,3 +348,12 @@ MigrationBackup/ # Ionide (cross platform F# VS Code tools) working folder .ionide/ + +# Swift +.build/ +.swiftpm/ +*.xcodeproj +*.xcworkspace +xcuserdata/ +DerivedData/ +Package.resolved diff --git a/Package.swift b/Package.swift new file mode 100644 index 00000000..4041116d --- /dev/null +++ b/Package.swift @@ -0,0 +1,43 @@ +// swift-tools-version: 5.9 + +import PackageDescription + +let package = Package( + name: "DevTunnelsClient", + platforms: [ + .iOS(.v16), + .macOS(.v13), + ], + products: [ + .library(name: "DevTunnelsClient", targets: ["DevTunnelsClient"]), + ], + dependencies: [ + .package(url: "https://github.com/apple/swift-nio.git", from: "2.65.0"), + .package(url: "https://github.com/apple/swift-nio-ssh.git", from: "0.9.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.20.0"), + ], + targets: [ + .target( + name: "DevTunnelsClient", + dependencies: [ + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOPosix", package: "swift-nio"), + .product(name: "NIOHTTP1", package: "swift-nio"), + .product(name: "NIOWebSocket", package: "swift-nio"), + .product(name: "NIOSSH", package: "swift-nio-ssh"), + .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), + ], + path: "swift/Sources/DevTunnelsClient" + ), + .testTarget( + name: "DevTunnelsClientTests", + dependencies: [ + "DevTunnelsClient", + .product(name: "NIOEmbedded", package: "swift-nio"), + .product(name: "NIOWebSocket", package: "swift-nio"), + .product(name: "NIOSSH", package: "swift-nio-ssh"), + ], + path: "swift/Tests/DevTunnelsClientTests" + ), + ] +) diff --git a/README.md b/README.md index 65925810..573923b5 100644 --- a/README.md +++ b/README.md @@ -7,15 +7,15 @@ Dev tunnels allows developers to securely expose local web services to the Inter ## SDK Feature Matrix -| Feature | C# | TypeScript | Java | Go | Rust | -|---|---|---|---|---|---| -| Management API | ✅ | ✅ | ✅ | ✅ | ✅ | -| Tunnel Client Connections | ✅ | ✅ | ✅ | ✅ | ✅ | -| Tunnel Host Connections | ✅ | ✅ | ❌ | ❌ | ✅ | -| Reconnection | ✅ | ✅ | ❌ | ❌ | ❌ | -| SSH-level Reconnection | ✅ | ✅ | ❌ | ❌ | ❌ | -| Automatic tunnel access token refresh | ✅ | ✅ | ❌ | ❌ | ❌ | -| Ssh Keep-alive | ✅ | ✅ | ❌ | ❌ | ❌ | +| Feature | C# | TypeScript | Java | Go | Rust | Swift | +|---|---|---|---|---|---|---| +| Management API | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Tunnel Client Connections | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Tunnel Host Connections | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | +| Reconnection | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | +| SSH-level Reconnection | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | +| Automatic tunnel access token refresh | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | +| Ssh Keep-alive | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ - Supported 🚧 - In Progress diff --git a/swift/.gitignore b/swift/.gitignore new file mode 100644 index 00000000..a73d7215 --- /dev/null +++ b/swift/.gitignore @@ -0,0 +1,7 @@ +# Swift build +/.build +# Xcode +*.xcodeproj +*.xcworkspace +xcuserdata/ +DerivedData/ diff --git a/swift/Package.swift b/swift/Package.swift new file mode 100644 index 00000000..060207cf --- /dev/null +++ b/swift/Package.swift @@ -0,0 +1,41 @@ +// swift-tools-version: 5.9 + +import PackageDescription + +let package = Package( + name: "DevTunnelsClient", + platforms: [ + .iOS(.v16), + .macOS(.v13), + ], + products: [ + .library(name: "DevTunnelsClient", targets: ["DevTunnelsClient"]), + ], + dependencies: [ + .package(url: "https://github.com/apple/swift-nio.git", from: "2.65.0"), + .package(url: "https://github.com/apple/swift-nio-ssh.git", from: "0.9.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.20.0"), + ], + targets: [ + .target( + name: "DevTunnelsClient", + dependencies: [ + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOPosix", package: "swift-nio"), + .product(name: "NIOHTTP1", package: "swift-nio"), + .product(name: "NIOWebSocket", package: "swift-nio"), + .product(name: "NIOSSH", package: "swift-nio-ssh"), + .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), + ] + ), + .testTarget( + name: "DevTunnelsClientTests", + dependencies: [ + "DevTunnelsClient", + .product(name: "NIOEmbedded", package: "swift-nio"), + .product(name: "NIOWebSocket", package: "swift-nio"), + .product(name: "NIOSSH", package: "swift-nio-ssh"), + ] + ), + ] +) diff --git a/swift/README.md b/swift/README.md new file mode 100644 index 00000000..d85b2474 --- /dev/null +++ b/swift/README.md @@ -0,0 +1,222 @@ +# dev-tunnels-swift + +A pure Swift client library for [Microsoft Dev Tunnels](https://aka.ms/devtunnels/docs). Connect to tunnel forwarded ports from iOS and macOS apps. + +## Features + +- **Tunnel management** — Full CRUD: list, get, create, update, delete tunnels and ports via the REST API +- **GitHub authentication** — Device code flow for tunnel access tokens +- **Direct connections** — Connect to publicly accessible tunnel ports via HTTPS +- **Relay connections** — Connect to private tunnel ports via WebSocket + SSH port forwarding +- **Auto-reconnect** — Configurable reconnection with exponential backoff on connection drops +- **Keepalive** — Periodic WebSocket pings to prevent idle connection drops +- **TLS** — Native Apple TLS via Network.framework for secure relay connections +- **Pure Swift** — No FFI, no Rust, no cross-compilation — just a Swift Package + +## Architecture + +``` +┌─────────────────────────────────────────────────────┐ +│ Your App │ +├─────────────────────────────────────────────────────┤ +│ DeviceCodeAuth TunnelManagementClient │ ← Management layer +│ (GitHub OAuth) (REST: list/get/create/update/ │ +│ delete tunnels & ports) │ +├─────────────────────────────────────────────────────┤ +│ TunnelConnection │ ← Connection helpers +│ (direct URLs, token extraction, online detection) │ +├─────────────────────────────────────────────────────┤ +│ TunnelRelayClient │ ← Relay client +│ (public API: connect/disconnect, state machine) │ +├─────────────────────────────────────────────────────┤ +│ TunnelRelayStream │ ← NIO pipeline +│ ┌───────────────────────────────────────────────┐ │ +│ │ NIOTSConnectionBootstrap (TLS for wss://) │ │ +│ │ → WebSocketUpgradeHandler (HTTP → WS) │ │ +│ │ → WebSocketBinaryFrameHandler │ │ +│ │ → NIOSSHHandler (user: "tunnel") │ │ +│ │ → forwardedTCPIP channel │ │ +│ │ → SSHPortForwardDataHandler │ │ +│ └───────────────────────────────────────────────┘ │ +├─────────────────────────────────────────────────────┤ +│ Contracts: Tunnel, TunnelEndpoint, TunnelPort, │ ← Types +│ TunnelStatus, enums (Codable, Sendable) │ +└─────────────────────────────────────────────────────┘ +``` + +### Source Layout + +``` +Sources/DevTunnelsClient/ +├── Contracts/ Tunnel, TunnelEndpoint, TunnelPort, TunnelStatus, enums +├── Management/ TunnelManagementClient, DeviceCodeAuth, HTTPClient protocol +└── Connections/ TunnelRelayClient, TunnelRelayStream, port forward messages +``` + +### How Relay Connections Work + +1. **WebSocket** — Connect to the relay URI (`wss://`) with subprotocol `tunnel-relay-client` and `Authorization: Tunnel ` header +2. **SSH over WebSocket** — Binary WebSocket frames carry SSH protocol data. SSH authenticates as user `tunnel` with no password (the access token provides auth) +3. **Port forwarding** — Open a `forwarded-tcpip` SSH channel targeting `127.0.0.1:` on the tunnel host +4. **Data streaming** — Bidirectional data flows through the SSH channel + +## Installation + +Add to your `Package.swift`: + +```swift +dependencies: [ + .package(url: "https://github.com/microsoft/dev-tunnels.git", from: "0.1.0"), +], +targets: [ + .target( + name: "YourApp", + dependencies: [ + .product(name: "DevTunnelsClient", package: "dev-tunnels"), + ] + ), +] +``` + +> The `Package.swift` at the repository root exposes the Swift library. +> Source code lives in `swift/Sources/`. + +## Quick Start + +### Authentication + Discovery + +```swift +import DevTunnelsClient + +// Authenticate via GitHub device code flow +let auth = try await DeviceCodeAuth.start() +print("Go to \(auth.verificationUri) and enter: \(auth.userCode)") +let token = try await DeviceCodeAuth.poll(deviceCode: auth.deviceCode) + +// List tunnels +let client = TunnelManagementClient(accessToken: token) +let tunnels = try await client.listTunnels() + +// Get tunnel detail with connect token +let detail = try await client.getTunnel( + clusterId: "usw2", + tunnelId: "my-tunnel", + tokenScopes: [TunnelAccessScopes.connect] +) +``` + +### Tunnel CRUD + +```swift +// Create a tunnel +let newTunnel = try await client.createTunnel(Tunnel(name: "my-app")) + +// Update a tunnel +var updated = newTunnel +updated.description = "Production endpoint" +let result = try await client.updateTunnel(updated) + +// Add a port +let port = try await client.createTunnelPort( + clusterId: newTunnel.clusterId!, + tunnelId: newTunnel.tunnelId!, + port: TunnelPort(portNumber: 8080, protocol: .https) +) + +// Delete a port +try await client.deleteTunnelPort( + clusterId: newTunnel.clusterId!, + tunnelId: newTunnel.tunnelId!, + portNumber: 8080 +) + +// Delete a tunnel +try await client.deleteTunnel( + clusterId: newTunnel.clusterId!, + tunnelId: newTunnel.tunnelId! +) +``` + +### Direct Connection (Public Ports) + +```swift +// For public tunnel ports — just use the direct URL +if let url = TunnelConnection.directURL(from: tunnel, port: 8080) { + // Use URLSession, WKWebView, etc. with this URL +} +``` + +### Relay Connection (Private Ports) + +```swift +// For private tunnel ports — connect through the relay +if let relay = TunnelRelayClient.fromTunnel(detail, port: 8080) { + let stream = try await relay.connect() + // stream.send(data) / stream.close() +} +``` + +### Auto-Reconnecting Connection + +```swift +// Automatically reconnect on connection drops +let relay = TunnelRelayClient(config: config) + +// Observe state changes +relay.onStateChangeHandler = { state in + print("State: \(state)") // .connected, .reconnecting(attempt: 1), etc. +} + +// Each iteration yields a new stream after (re)connection +for await stream in relay.connectWithReconnect() { + // Use stream until it disconnects; loop yields a new one +} + +// Custom retry policy +let policy = ReconnectPolicy( + maxAttempts: 10, + initialDelay: 0.5, + maxDelay: 60, + backoffMultiplier: 2.0 +) +for await stream in relay.connectWithReconnect(policy: policy) { + // ... +} +``` + +## Limitations + +> **This library is under active development.** The following limitations apply to the current version. + +### Not Yet Implemented + +- **Server-initiated port notifications** — The SSH `tcpip-forward` global request (server telling the client which ports are available) is not yet handled. The client must know the port number in advance. +- **Local TCP listener** — The Go/TS SDKs can open a local TCP socket and forward connections to the tunnel. This library provides the raw stream; local listener forwarding is the caller's responsibility. +- **Host-side functionality** — This is a client-only library. Hosting a tunnel (registering ports, accepting connections) is out of scope. + +### Known Constraints + +- **Apple platforms only for TLS** — TLS uses `NIOTransportServices` (Network.framework), which requires iOS/macOS. Non-Apple platforms would need NIOSSL instead. +- **No certificate pinning** — The relay connection trusts the system TLS certificate store. The SSH layer accepts any host key (matching the Go SDK's `InsecureIgnoreHostKey` behavior, since auth is via the tunnel access token). +- **Single-port streams** — Each `TunnelRelayClient` connects to one port. To forward multiple ports, create multiple clients. + +## Dependencies + +| Package | Purpose | +|---|---| +| [swift-nio](https://github.com/apple/swift-nio) | Async networking, WebSocket codec | +| [swift-nio-ssh](https://github.com/apple/swift-nio-ssh) | SSH protocol over WebSocket | +| [swift-nio-transport-services](https://github.com/apple/swift-nio-transport-services) | Apple TLS via Network.framework | + +## Requirements + +- iOS 16+ / macOS 13+ +- Swift 5.9+ + +## Testing + +```bash +swift test # 155 tests, all offline (no network requests) +``` + +All tests use mock HTTP clients and NIO `EmbeddedChannel` — no real network calls are made during testing. diff --git a/swift/Sources/DevTunnelsClient/Connections/PortForwardMessages.swift b/swift/Sources/DevTunnelsClient/Connections/PortForwardMessages.swift new file mode 100644 index 00000000..470d6d18 --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Connections/PortForwardMessages.swift @@ -0,0 +1,171 @@ +import Foundation +import NIOCore + +/// SSH port forwarding message types for the tunnel relay protocol. +/// +/// Port forwarding in the Dev Tunnels relay uses SSH `forwarded-tcpip` channels. +/// The client opens a channel to connect to a specific port on the tunnel host. + +// MARK: - Port Forward Channel Open + +/// Payload for opening a `forwarded-tcpip` SSH channel. +/// +/// Binary format (SSH wire protocol, big-endian): +/// - string: host (e.g. "127.0.0.1") +/// - uint32: port +/// - string: originator IP address (empty string for tunnel client) +/// - uint32: originator port (0 for tunnel client) +public struct PortForwardChannelOpen: Equatable, Sendable { + /// The SSH channel type used for port forwarding. + public static let channelType = "forwarded-tcpip" + + /// The host to connect to on the tunnel server side. + public let host: String + + /// The port to forward. + public let port: UInt32 + + /// The IP address of the originator (client). + public let originatorIPAddress: String + + /// The port on the originator (client). + public let originatorPort: UInt32 + + public init(host: String = "127.0.0.1", port: UInt32, originatorIPAddress: String = "", originatorPort: UInt32 = 0) { + self.host = host + self.port = port + self.originatorIPAddress = originatorIPAddress + self.originatorPort = originatorPort + } + + /// Serializes this message to SSH wire format. + public func marshal() -> Data { + var data = Data() + writeString(&data, host) + writeUInt32(&data, port) + writeString(&data, originatorIPAddress) + writeUInt32(&data, originatorPort) + return data + } + + /// Deserializes from SSH wire format. + public static func unmarshal(from data: Data) -> PortForwardChannelOpen? { + var offset = 0 + + guard let host = readString(from: data, offset: &offset) else { return nil } + guard let port = readUInt32(from: data, offset: &offset) else { return nil } + guard let originatorIP = readString(from: data, offset: &offset) else { return nil } + guard let originatorPort = readUInt32(from: data, offset: &offset) else { return nil } + + return PortForwardChannelOpen( + host: host, + port: port, + originatorIPAddress: originatorIP, + originatorPort: originatorPort + ) + } +} + +// MARK: - Port Forward Request (tcpip-forward) + +/// Global request payload for `tcpip-forward`. +/// +/// The server sends this to notify the client that a port is available for forwarding. +/// +/// Binary format: +/// - string: address to bind (e.g. "127.0.0.1") +/// - uint32: port number +public struct PortForwardRequest: Equatable, Sendable { + /// The SSH global request type. + public static let requestType = "tcpip-forward" + + /// The address to bind. + public let address: String + + /// The port number. + public let port: UInt32 + + public init(address: String = "127.0.0.1", port: UInt32) { + self.address = address + self.port = port + } + + /// Serializes this message to SSH wire format. + public func marshal() -> Data { + var data = Data() + writeString(&data, address) + writeUInt32(&data, port) + return data + } + + /// Deserializes from SSH wire format. + public static func unmarshal(from data: Data) -> PortForwardRequest? { + var offset = 0 + guard let address = readString(from: data, offset: &offset) else { return nil } + guard let port = readUInt32(from: data, offset: &offset) else { return nil } + return PortForwardRequest(address: address, port: port) + } +} + +// MARK: - Port Forward Success Response + +/// Response payload for a successful `tcpip-forward` request. +/// +/// Binary format: +/// - uint32: the port that was actually bound +public struct PortForwardSuccess: Equatable, Sendable { + /// The port that was bound. + public let port: UInt32 + + public init(port: UInt32) { + self.port = port + } + + public func marshal() -> Data { + var data = Data() + writeUInt32(&data, port) + return data + } + + public static func unmarshal(from data: Data) -> PortForwardSuccess? { + var offset = 0 + guard let port = readUInt32(from: data, offset: &offset) else { return nil } + return PortForwardSuccess(port: port) + } +} + +// MARK: - SSH Wire Format Helpers + +/// Writes an SSH string (uint32 length + UTF-8 bytes) to data. +private func writeString(_ data: inout Data, _ string: String) { + let bytes = Array(string.utf8) + writeUInt32(&data, UInt32(bytes.count)) + data.append(contentsOf: bytes) +} + +/// Writes a big-endian uint32 to data. +private func writeUInt32(_ data: inout Data, _ value: UInt32) { + var bigEndian = value.bigEndian + data.append(Data(bytes: &bigEndian, count: 4)) +} + +/// Reads an SSH string from data at the given offset. +private func readString(from data: Data, offset: inout Int) -> String? { + guard let length = readUInt32(from: data, offset: &offset) else { return nil } + let len = Int(length) + guard offset + len <= data.count else { return nil } + let string = String(data: data[offset.. UInt32? { + guard offset + 4 <= data.count else { return nil } + let b0 = UInt32(data[offset]) + let b1 = UInt32(data[offset + 1]) + let b2 = UInt32(data[offset + 2]) + let b3 = UInt32(data[offset + 3]) + offset += 4 + return (b0 << 24) | (b1 << 16) | (b2 << 8) | b3 +} diff --git a/swift/Sources/DevTunnelsClient/Connections/RelayConnectionState.swift b/swift/Sources/DevTunnelsClient/Connections/RelayConnectionState.swift new file mode 100644 index 00000000..8c0712a0 --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Connections/RelayConnectionState.swift @@ -0,0 +1,100 @@ +import Foundation + +/// Connection state for a tunnel relay client. +public enum RelayConnectionState: Equatable, Sendable { + /// Not connected. + case disconnected + + /// WebSocket connection in progress. + case connectingWebSocket + + /// WebSocket connected, SSH handshake in progress. + case connectingSSH + + /// SSH connected, opening port forwarding channel. + case openingChannel + + /// Fully connected and streaming data. + case connected + + /// Connection lost, attempting to reconnect. + case reconnecting(attempt: Int) + + /// Connection failed with an error. + case failed(RelayConnectionError) + + /// Connection was closed (gracefully or by peer). + case closed +} + +/// Errors that can occur during relay connection. +public enum RelayConnectionError: Error, Equatable, Sendable { + /// Config validation failed. + case invalidConfig(TunnelRelayConfigError) + + /// WebSocket connection failed. + case webSocketFailed(String) + + /// SSH handshake failed. + case sshFailed(String) + + /// Port forwarding channel open failed. + case channelFailed(String) + + /// Connection timed out. + case timeout + + /// Connection was rejected (e.g., 401/403). + case authenticationFailed(String) + + /// Maximum reconnection attempts exhausted. + case reconnectFailed(attempts: Int) +} + +/// Policy for automatic reconnection on connection loss. +public struct ReconnectPolicy: Sendable, Equatable { + /// Maximum number of reconnection attempts before giving up. + public let maxAttempts: Int + + /// Initial delay before the first retry (seconds). + public let initialDelay: TimeInterval + + /// Maximum delay between retries (seconds). Backoff is capped at this value. + public let maxDelay: TimeInterval + + /// Multiplier applied to the delay after each failed attempt. + public let backoffMultiplier: Double + + /// Do not attempt to reconnect. + public static let disabled = ReconnectPolicy(maxAttempts: 0) + + /// Default policy: up to 5 attempts with 1–30s exponential backoff. + public static let `default` = ReconnectPolicy() + + public init( + maxAttempts: Int = 5, + initialDelay: TimeInterval = 1, + maxDelay: TimeInterval = 30, + backoffMultiplier: Double = 2.0 + ) { + self.maxAttempts = maxAttempts + self.initialDelay = initialDelay + self.maxDelay = maxDelay + self.backoffMultiplier = backoffMultiplier + } + + /// Computes the delay for a given attempt number (0-based). + func delay(forAttempt attempt: Int) -> TimeInterval { + let raw = initialDelay * pow(backoffMultiplier, Double(attempt)) + return min(raw, maxDelay) + } +} + +/// Protocol for observing relay connection state changes. +public protocol RelayConnectionDelegate: AnyObject, Sendable { + /// Called when the connection state changes. + func relayConnectionStateDidChange(_ state: RelayConnectionState) + + /// Called when data is received from the forwarded port. + func relayConnectionDidReceiveData(_ data: Data) +} diff --git a/swift/Sources/DevTunnelsClient/Connections/TunnelConnection.swift b/swift/Sources/DevTunnelsClient/Connections/TunnelConnection.swift new file mode 100644 index 00000000..f17c792a --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Connections/TunnelConnection.swift @@ -0,0 +1,101 @@ +import Foundation + +/// Utility for constructing Dev Tunnels connection URLs. +/// +/// Two connection approaches: +/// 1. **Direct HTTPS** — Connect via `{tunnelId}-{port}.{clusterId}.devtunnels.ms` +/// with a connect access token. Simple but only works for publicly accessible tunnels. +/// 2. **Relay** — Connect via the relay WebSocket URI (SSH-over-WebSocket). +/// Works for private tunnels. (Not yet implemented) +public enum TunnelConnection { + + /// Builds the direct HTTPS WebSocket URL for a forwarded port. + /// + /// Format: `wss://{tunnelId}-{port}.{clusterId}.devtunnels.ms` + /// + /// - Parameters: + /// - tunnel: Tunnel with clusterId and tunnelId. + /// - port: Port number to connect to. + /// - Returns: WebSocket URL, or nil if tunnel is missing required fields. + public static func directURL(tunnel: Tunnel, port: UInt16) -> URL? { + guard let tunnelId = tunnel.tunnelId, + let clusterId = tunnel.clusterId else { + return nil + } + return directURL(tunnelId: tunnelId, clusterId: clusterId, port: port) + } + + /// Builds the direct HTTPS WebSocket URL from explicit parameters. + /// + /// - Parameters: + /// - tunnelId: The tunnel ID. + /// - clusterId: The cluster ID. + /// - port: Port number to connect to. + /// - Returns: WebSocket URL. + public static func directURL(tunnelId: String, clusterId: String, port: UInt16) -> URL? { + URL(string: "wss://\(tunnelId)-\(port).\(clusterId).devtunnels.ms") + } + + /// Builds the direct URL from a tunnel endpoint's portUriFormat. + /// + /// The portUriFormat contains `{port}` which gets replaced with the actual port number. + /// + /// - Parameters: + /// - endpoint: Tunnel endpoint with portUriFormat. + /// - port: Port number. + /// - Returns: URL with port substituted, or nil if endpoint has no portUriFormat. + public static func directURL(endpoint: TunnelEndpoint, port: UInt16) -> URL? { + guard let format = endpoint.portUriFormat else { return nil } + let urlString = format.replacingOccurrences( + of: tunnelEndpointPortToken, + with: String(port) + ) + // Convert https:// to wss:// for WebSocket + let wsUrlString = urlString + .replacingOccurrences(of: "https://", with: "wss://") + .replacingOccurrences(of: "http://", with: "ws://") + return URL(string: wsUrlString) + } + + /// Extracts the connect access token from a tunnel's accessTokens. + /// + /// - Parameter tunnel: Tunnel with accessTokens. + /// - Returns: The connect-scoped JWT, or nil if not present. + public static func connectToken(from tunnel: Tunnel) -> String? { + tunnel.accessTokens?[TunnelAccessScopes.connect] + } + + /// Builds the authorization header value for tunnel connect. + /// + /// - Parameter connectToken: The connect-scoped JWT. + /// - Returns: Header value in the format `tunnel {token}`. + public static func tunnelAuthHeader(connectToken: String) -> String { + "tunnel \(connectToken)" + } + + /// Extracts the client relay URI from a tunnel's endpoints. + /// + /// Looks for a TunnelRelay endpoint with a clientRelayUri. + /// + /// - Parameter tunnel: Tunnel with endpoints. + /// - Returns: The client relay URI, or nil if no relay endpoint exists. + public static func clientRelayURI(from tunnel: Tunnel) -> String? { + tunnel.endpoints? + .first(where: { $0.connectionMode == .tunnelRelay && $0.clientRelayUri != nil })? + .clientRelayUri + } + + /// Checks whether a tunnel currently has active host connections. + /// + /// - Parameter tunnel: Tunnel with status or endpoints. + /// - Returns: true if hosts are connected. + public static func isOnline(_ tunnel: Tunnel) -> Bool { + if let count = tunnel.status?.hostConnectionCount?.current, count > 0 { + return true + } + if let endpoints = tunnel.endpoints, !endpoints.isEmpty { + return true + } + return false + } +} diff --git a/swift/Sources/DevTunnelsClient/Connections/TunnelRelayClient.swift b/swift/Sources/DevTunnelsClient/Connections/TunnelRelayClient.swift new file mode 100644 index 00000000..7ed97a6e --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Connections/TunnelRelayClient.swift @@ -0,0 +1,266 @@ +import Foundation +import NIOCore + +/// A client that connects to a Dev Tunnel via the relay service. +/// +/// Connection flow: +/// 1. WebSocket → relay URI with `tunnel-relay-client` subprotocol +/// 2. SSH handshake over WebSocket (user: "tunnel", no password) +/// 3. Open `forwarded-tcpip` channel for the requested port +/// 4. Bidirectional data streaming through the SSH channel +/// +/// Supports automatic reconnection when the connection drops unexpectedly. +/// Use ``connectWithReconnect(policy:)`` to get an ``AsyncStream`` of streams +/// that yields a new stream on each (re)connection. +/// +/// Usage: +/// ```swift +/// let client = TunnelRelayClient(config: config) +/// +/// // Simple one-shot connection: +/// let stream = try await client.connect() +/// +/// // Auto-reconnecting connection: +/// for await stream in client.connectWithReconnect() { +/// // Use stream; loop yields a new one after reconnect +/// } +/// ``` +public final class TunnelRelayClient: Sendable { + private let config: TunnelRelayConfig + private let _state: LockedValue + private let _stream: LockedValue + + private let _onStateChange: LockedValue<(@Sendable (RelayConnectionState) -> Void)?> + + /// Callback invoked on every state change (connected, disconnected, reconnecting, etc.). + /// Set this before calling ``connect()`` or ``connectWithReconnect(policy:)``. + public var onStateChangeHandler: (@Sendable (RelayConnectionState) -> Void)? { + get { _onStateChange.withLockedValue { $0 } } + set { _onStateChange.withLockedValue { $0 = newValue } } + } + + /// Current connection state. + public var state: RelayConnectionState { + _state.withLockedValue { $0 } + } + + /// Creates a new relay client with the given configuration. + public init(config: TunnelRelayConfig) { + self.config = config + self._state = LockedValue(.disconnected) + self._stream = LockedValue(nil) + self._onStateChange = LockedValue(nil) + } + + /// Creates a relay client from a tunnel object. + /// + /// Extracts the relay URI and connect token from the tunnel's endpoints and access tokens. + /// + /// - Parameters: + /// - tunnel: Tunnel with endpoints and access tokens. + /// - port: Port to forward. + /// - Returns: Configured client, or nil if tunnel lacks relay endpoint or connect token. + public static func fromTunnel(_ tunnel: Tunnel, port: UInt16) -> TunnelRelayClient? { + guard let relayUri = TunnelConnection.clientRelayURI(from: tunnel), + let token = TunnelConnection.connectToken(from: tunnel) else { + return nil + } + let config = TunnelRelayConfig( + relayUri: relayUri, + accessToken: token, + port: port + ) + return TunnelRelayClient(config: config) + } + + /// Validates the configuration without connecting. + public func validateConfig() -> TunnelRelayConfigError? { + config.validate() + } + + /// Connects to the tunnel relay and opens a port forwarding channel. + /// + /// - Returns: A bidirectional stream for the forwarded port. + /// - Throws: `RelayConnectionError` if connection fails at any stage. + public func connect() async throws -> TunnelRelayStream { + if let error = config.validate() { + transition(to: .failed(.invalidConfig(error))) + throw RelayConnectionError.invalidConfig(error) + } + + transition(to: .connectingWebSocket) + + let stream = try await TunnelRelayStream.connect(config: config) { [weak self] newState in + self?.transition(to: newState) + } + + _stream.withLockedValue { $0 = stream } + transition(to: .connected) + return stream + } + + /// Connects with automatic reconnection on unexpected disconnects. + /// + /// Returns an `AsyncStream` that yields a new ``TunnelRelayStream`` on each + /// successful connection (initial and subsequent reconnects). The stream finishes + /// when reconnection is exhausted or ``disconnect()`` is called. + /// + /// - Parameter policy: Reconnection policy (defaults to ``ReconnectPolicy/default``). + /// - Returns: An async stream of relay streams, one per (re)connection. + public func connectWithReconnect( + policy: ReconnectPolicy = .default + ) -> AsyncStream { + AsyncStream { continuation in + let task = Task { [weak self] in + guard let self else { + continuation.finish() + return + } + + // Initial connection + do { + let stream = try await self.connect() + self.wireDisconnect(stream: stream) + continuation.yield(stream) + } catch { + self.transition(to: .failed(error as? RelayConnectionError ?? .webSocketFailed(error.localizedDescription))) + continuation.finish() + return + } + + // Reconnection loop + var attempt = 0 + while !Task.isCancelled { + // Wait for disconnection + await self.waitForDisconnect() + + guard !Task.isCancelled else { break } + + // Check if this was an intentional close + let currentState = self.state + if currentState == .closed { + break + } + + // Retry with backoff + while attempt < policy.maxAttempts && !Task.isCancelled { + attempt += 1 + self.transition(to: .reconnecting(attempt: attempt)) + + let delay = policy.delay(forAttempt: attempt - 1) + try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000)) + + guard !Task.isCancelled else { break } + + do { + // Close old resources before reconnecting + let oldStream = self._stream.withLockedValue { s -> TunnelRelayStream? in + let old = s + s = nil + return old + } + if let oldStream { + try? await oldStream.close() + } + + let stream = try await self.connect() + self.wireDisconnect(stream: stream) + continuation.yield(stream) + attempt = 0 // Reset on success + break + } catch { + if attempt >= policy.maxAttempts { + self.transition(to: .failed(.reconnectFailed(attempts: attempt))) + continuation.finish() + return + } + // Continue to next attempt + } + } + + if attempt >= policy.maxAttempts { + break + } + } + continuation.finish() + } + + continuation.onTermination = { _ in + task.cancel() + } + } + } + + /// Disconnects from the relay, closing all channels. + public func disconnect() { + transition(to: .closed) + let stream = _stream.withLockedValue { s -> TunnelRelayStream? in + let old = s + s = nil + return old + } + if let stream { + Task { + try? await stream.close() + } + } + // Signal the reconnect loop to stop + _disconnectContinuation.withLockedValue { c in + c?.resume() + c = nil + } + } + + // MARK: - Internal + + private let _disconnectContinuation: LockedValue?> = LockedValue(nil) + + private func transition(to newState: RelayConnectionState) { + _state.withLockedValue { $0 = newState } + _onStateChange.withLockedValue { $0?(newState) } + } + + /// Wires the stream's disconnect callback to resume the reconnect loop. + private func wireDisconnect(stream: TunnelRelayStream) { + stream.onDisconnect = { [weak self] in + guard let self else { return } + self.transition(to: .disconnected) + self._disconnectContinuation.withLockedValue { c in + c?.resume() + c = nil + } + } + } + + /// Suspends until the current stream disconnects. + private func waitForDisconnect() async { + await withCheckedContinuation { (continuation: CheckedContinuation) in + // If already disconnected, resume immediately + let currentState = self.state + if currentState == .disconnected || currentState == .closed { + continuation.resume() + return + } + self._disconnectContinuation.withLockedValue { $0 = continuation } + } + } +} + +/// A thread-safe locked value container. +/// +/// NIOCore provides `NIOLockedValueBox` but we use a simple version +/// for clarity and to avoid tight NIO coupling in the public API. +internal final class LockedValue: @unchecked Sendable { + private var value: Value + private let lock = NSLock() + + init(_ value: Value) { + self.value = value + } + + func withLockedValue(_ body: (inout Value) -> T) -> T { + lock.lock() + defer { lock.unlock() } + return body(&value) + } +} diff --git a/swift/Sources/DevTunnelsClient/Connections/TunnelRelayConfig.swift b/swift/Sources/DevTunnelsClient/Connections/TunnelRelayConfig.swift new file mode 100644 index 00000000..88e2b0bd --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Connections/TunnelRelayConfig.swift @@ -0,0 +1,93 @@ +import Foundation + +/// Configuration for a tunnel relay connection. +public struct TunnelRelayConfig: Sendable, Equatable { + /// Client relay URI from the tunnel endpoint (wss://...). + public let relayUri: String + + /// Tunnel access token with "connect" scope. + public let accessToken: String + + /// The port to forward to on the remote host. + public let port: UInt16 + + /// WebSocket subprotocol for the relay. + public let subprotocol: String + + /// Connection timeout in seconds. + public let connectionTimeout: TimeInterval + + /// Interval in seconds between WebSocket keepalive pings. Set to 0 to disable. + public let keepaliveInterval: TimeInterval + + public init( + relayUri: String, + accessToken: String, + port: UInt16, + subprotocol: String = TunnelRelayConstants.clientWebSocketSubProtocol, + connectionTimeout: TimeInterval = 30, + keepaliveInterval: TimeInterval = TunnelRelayConstants.defaultKeepaliveInterval + ) { + self.relayUri = relayUri + self.accessToken = accessToken + self.port = port + self.subprotocol = subprotocol + self.connectionTimeout = connectionTimeout + self.keepaliveInterval = keepaliveInterval + } + + /// Validates that the config has all required fields. + public func validate() -> TunnelRelayConfigError? { + if relayUri.isEmpty { + return .missingRelayUri + } + guard let url = URL(string: relayUri) else { + return .invalidRelayUri(relayUri) + } + guard url.scheme == "wss" || url.scheme == "ws" else { + return .invalidRelayUri(relayUri) + } + if accessToken.isEmpty { + return .missingAccessToken + } + if port == 0 { + return .invalidPort + } + return nil + } + + /// Builds the Authorization header value. + /// Prefixes "Tunnel " if not already present. + var authorizationHeader: String { + if accessToken.contains("Tunnel") || accessToken.contains("tunnel") { + return accessToken + } + return "Tunnel \(accessToken)" + } +} + +/// Errors from config validation. +public enum TunnelRelayConfigError: Error, Equatable { + case missingRelayUri + case invalidRelayUri(String) + case missingAccessToken + case invalidPort +} + +/// Constants for the tunnel relay protocol. +public enum TunnelRelayConstants { + /// V1 WebSocket subprotocol for client relay connections. + public static let clientWebSocketSubProtocol = "tunnel-relay-client" + + /// SSH channel type for port forwarding. + public static let portForwardChannelType = "forwarded-tcpip" + + /// SSH global request type for port forwarding notification. + public static let portForwardRequestType = "tcpip-forward" + + /// SSH user for tunnel connections. + public static let sshUser = "tunnel" + + /// Default interval (seconds) between WebSocket keepalive pings. + public static let defaultKeepaliveInterval: TimeInterval = 30 +} diff --git a/swift/Sources/DevTunnelsClient/Connections/TunnelRelayStream.swift b/swift/Sources/DevTunnelsClient/Connections/TunnelRelayStream.swift new file mode 100644 index 00000000..4da0f7ca --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Connections/TunnelRelayStream.swift @@ -0,0 +1,411 @@ +import Foundation +import NIOCore +import NIOPosix +import NIOHTTP1 +import NIOWebSocket +import NIOSSH +import NIOTransportServices +import Network + +/// A bidirectional stream to a forwarded port through a tunnel relay. +/// +/// Wraps the full WebSocket → SSH → port forwarding pipeline. +/// Connection flow: +/// 1. TCP connect to relay host (with TLS for wss://) +/// 2. HTTP → WebSocket upgrade with `tunnel-relay-client` subprotocol +/// 3. SSH handshake over WebSocket (user: "tunnel") +/// 4. Open `forwarded-tcpip` channel for the requested port +/// 5. Bidirectional data streaming through the SSH channel +/// +/// Sends periodic WebSocket pings to keep the connection alive. +public final class TunnelRelayStream: @unchecked Sendable { + private let parentChannel: Channel + private let sshChildChannel: Channel + private let group: EventLoopGroup + private var _isClosed = false + private var keepaliveTask: Scheduled? + + /// Callback invoked when the connection drops unexpectedly. + /// Not called for explicit `close()` calls. + internal var onDisconnect: (@Sendable () -> Void)? + + init(parentChannel: Channel, sshChildChannel: Channel, group: EventLoopGroup) { + self.parentChannel = parentChannel + self.sshChildChannel = sshChildChannel + self.group = group + } + + /// Connects to the tunnel relay and establishes SSH port forwarding. + static func connect( + config: TunnelRelayConfig, + onStateChange: @escaping @Sendable (RelayConnectionState) -> Void + ) async throws -> TunnelRelayStream { + guard let url = URL(string: config.relayUri), + let host = url.host, + let scheme = url.scheme else { + throw RelayConnectionError.webSocketFailed("Invalid relay URI: \(config.relayUri)") + } + + let useTLS = scheme == "wss" + let port = url.port ?? (useTLS ? 443 : 80) + + let group: EventLoopGroup + #if canImport(Network) + group = NIOTSEventLoopGroup() + #else + group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + #endif + + do { + onStateChange(.connectingWebSocket) + + // Step 1: TCP connect + WebSocket upgrade + let upgradePromise = group.next().makePromise(of: Void.self) + let wsFrameHandler = WebSocketBinaryFrameHandler(upgradePromise: upgradePromise) + let upgradeHandler = WebSocketUpgradeHandler( + config: config, + wsFrameHandler: wsFrameHandler + ) + + let channel: Channel + #if canImport(Network) + let tsBootstrap = NIOTSConnectionBootstrap(group: group) + .channelInitializer { ch in + ch.pipeline.addHandler(upgradeHandler) + } + if useTLS { + channel = try await tsBootstrap + .tlsOptions(NWProtocolTLS.Options()) + .connect(host: host, port: port) + .get() + } else { + channel = try await tsBootstrap + .connect(host: host, port: port) + .get() + } + #else + let bootstrap = ClientBootstrap(group: group) + .channelOption(.socketOption(.so_reuseaddr), value: 1) + .channelInitializer { ch in + ch.pipeline.addHandler(upgradeHandler) + } + channel = try await bootstrap.connect(host: host, port: port).get() + #endif + + // Wait for WebSocket upgrade to complete + try await upgradePromise.futureResult.get() + + onStateChange(.connectingSSH) + + // Step 2: SSH handshake over WebSocket + let sshHandler = try await addSSHHandlers(to: channel) + + onStateChange(.openingChannel) + + // Step 3: Open port forwarding channel + let sshChildChannel = try await openPortForwardChannel( + sshHandler: sshHandler, + on: channel, + port: config.port + ) + + let stream = TunnelRelayStream( + parentChannel: channel, + sshChildChannel: sshChildChannel, + group: group + ) + + // Start periodic WebSocket keepalive pings + if config.keepaliveInterval > 0 { + stream.startKeepalive(interval: config.keepaliveInterval) + } + + // Wire disconnect detection: WebSocket channel going inactive + // triggers the stream's onDisconnect callback. + wsFrameHandler.onChannelInactive = { [weak stream] in + guard let stream, !stream._isClosed else { return } + stream.onDisconnect?() + } + + return stream + } catch { + try? await group.shutdownGracefully() + if let relayError = error as? RelayConnectionError { + throw relayError + } + throw RelayConnectionError.webSocketFailed(error.localizedDescription) + } + } + + /// Adds NIO SSH client handlers to the channel pipeline. + @discardableResult + private static func addSSHHandlers(to channel: Channel) async throws -> NIOSSHHandler { + let sshHandler = NIOSSHHandler( + role: .client( + .init( + userAuthDelegate: TunnelSSHClientAuthDelegate(), + serverAuthDelegate: TunnelSSHServerAuthDelegate() + ) + ), + allocator: channel.allocator, + inboundChildChannelInitializer: nil + ) + + try await channel.pipeline.addHandler(sshHandler).get() + return sshHandler + } + + /// Opens a `forwarded-tcpip` SSH channel for the given port. + private static func openPortForwardChannel( + sshHandler: NIOSSHHandler, + on channel: Channel, + port: UInt16 + ) async throws -> Channel { + let channelType = SSHChannelType.forwardedTCPIP( + .init( + listeningHost: "127.0.0.1", + listeningPort: Int(port), + originatorAddress: try .init(ipAddress: "127.0.0.1", port: 0) + ) + ) + + let childChannelPromise = channel.eventLoop.makePromise(of: Channel.self) + + channel.eventLoop.execute { + sshHandler.createChannel(childChannelPromise, channelType: channelType) { childChannel, _ in + childChannel.pipeline.addHandler(SSHPortForwardDataHandler()) + } + } + + return try await childChannelPromise.futureResult.get() + } + + /// Whether the underlying channels are still active. + public var isActive: Bool { + !_isClosed && sshChildChannel.isActive + } + + /// Starts periodic WebSocket ping frames to keep the connection alive. + func startKeepalive(interval: TimeInterval) { + let eventLoop = parentChannel.eventLoop + func schedulePing() { + guard !_isClosed, parentChannel.isActive else { return } + keepaliveTask = eventLoop.scheduleTask(in: .seconds(Int64(interval))) { [weak self] in + guard let self, !self._isClosed, self.parentChannel.isActive else { return } + let emptyBuffer = self.parentChannel.allocator.buffer(capacity: 0) + let ping = WebSocketFrame(fin: true, opcode: .ping, data: emptyBuffer) + self.parentChannel.writeAndFlush(ping, promise: nil) + schedulePing() + } + } + schedulePing() + } + + /// Sends data through the forwarded port. + public func send(_ data: Data) async throws { + guard sshChildChannel.isActive else { + throw RelayConnectionError.channelFailed("Not connected") + } + var buffer = sshChildChannel.allocator.buffer(capacity: data.count) + buffer.writeBytes(data) + let sshData = SSHChannelData(type: .channel, data: .byteBuffer(buffer)) + try await sshChildChannel.writeAndFlush(sshData) + } + + /// Closes the relay connection. + public func close() async throws { + guard !_isClosed else { return } + _isClosed = true + keepaliveTask?.cancel() + keepaliveTask = nil + try? await sshChildChannel.close() + try? await parentChannel.close() + try await group.shutdownGracefully() + } + + deinit { + try? group.syncShutdownGracefully() + } +} + +// MARK: - SSH Port Forward Data Handler + +/// Handles data on the SSH child channel (port forwarding). +/// Receives `SSHChannelData` and passes raw bytes upstream. +final class SSHPortForwardDataHandler: ChannelDuplexHandler { + typealias InboundIn = SSHChannelData + typealias InboundOut = ByteBuffer + typealias OutboundIn = SSHChannelData + typealias OutboundOut = SSHChannelData + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let channelData = unwrapInboundIn(data) + guard case .channel = channelData.type, + case .byteBuffer(let buffer) = channelData.data else { + return + } + context.fireChannelRead(wrapInboundOut(buffer)) + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let channelData = unwrapOutboundIn(data) + context.write(wrapOutboundOut(channelData), promise: promise) + } +} + +// MARK: - SSH Auth Delegates + +/// SSH client auth delegate that authenticates as "tunnel" user with no password. +/// The tunnel access token (sent via WebSocket Authorization header) provides auth. +final class TunnelSSHClientAuthDelegate: NIOSSHClientUserAuthenticationDelegate { + func nextAuthenticationType( + availableMethods: NIOSSHAvailableUserAuthenticationMethods, + nextChallengePromise: EventLoopPromise + ) { + // Try "none" authentication first — the tunnel relay trusts the WebSocket token + nextChallengePromise.succeed( + NIOSSHUserAuthenticationOffer( + username: TunnelRelayConstants.sshUser, + serviceName: "", + offer: .none + ) + ) + } +} + +/// SSH server auth delegate that accepts any host key. +/// The WebSocket TLS + tunnel access token provide sufficient authentication. +final class TunnelSSHServerAuthDelegate: NIOSSHClientServerAuthenticationDelegate { + func validateHostKey( + hostKey: NIOSSHPublicKey, + validationCompletePromise: EventLoopPromise + ) { + // Accept any host key — same as Go SDK's InsecureIgnoreHostKey + validationCompletePromise.succeed(()) + } +} + +// MARK: - WebSocket Upgrade Handler + +/// Handles the HTTP → WebSocket upgrade handshake. +/// After successful upgrade, replaces itself with WebSocket frame handlers. +final class WebSocketUpgradeHandler: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = ByteBuffer + + private let config: TunnelRelayConfig + private let wsFrameHandler: WebSocketBinaryFrameHandler + + init(config: TunnelRelayConfig, wsFrameHandler: WebSocketBinaryFrameHandler) { + self.config = config + self.wsFrameHandler = wsFrameHandler + } + + func channelActive(context: ChannelHandlerContext) { + sendUpgradeRequest(context: context) + } + + private func sendUpgradeRequest(context: ChannelHandlerContext) { + guard let url = URL(string: config.relayUri) else { return } + let path = url.path.isEmpty ? "/" : url.path + let host = url.host ?? "" + + // Generate random WebSocket key + var keyBytes = [UInt8](repeating: 0, count: 16) + _ = SecRandomCopyBytes(kSecRandomDefault, keyBytes.count, &keyBytes) + let key = Data(keyBytes).base64EncodedString() + + var buffer = context.channel.allocator.buffer(capacity: 512) + buffer.writeString("GET \(path) HTTP/1.1\r\n") + buffer.writeString("Host: \(host)\r\n") + buffer.writeString("Upgrade: websocket\r\n") + buffer.writeString("Connection: Upgrade\r\n") + buffer.writeString("Sec-WebSocket-Key: \(key)\r\n") + buffer.writeString("Sec-WebSocket-Version: 13\r\n") + buffer.writeString("Sec-WebSocket-Protocol: \(config.subprotocol)\r\n") + buffer.writeString("Authorization: \(config.authorizationHeader)\r\n") + buffer.writeString("\r\n") + + context.writeAndFlush(NIOAny(buffer), promise: nil) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var buffer = unwrapInboundIn(data) + guard let response = buffer.readString(length: buffer.readableBytes) else { return } + + if response.contains("101") { + // Upgrade success — swap to WebSocket frame handlers + _ = context.pipeline.addHandlers([ + ByteToMessageHandler(WebSocketFrameDecoder()), + WebSocketFrameEncoder(), + wsFrameHandler, + ]).flatMap { + context.pipeline.removeHandler(self) + } + wsFrameHandler.upgradePromise.succeed(()) + } else if response.contains("401") || response.contains("403") { + wsFrameHandler.upgradePromise.fail( + RelayConnectionError.authenticationFailed("Relay returned auth error") + ) + } else { + wsFrameHandler.upgradePromise.fail( + RelayConnectionError.webSocketFailed("WebSocket upgrade failed") + ) + } + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + wsFrameHandler.upgradePromise.fail( + RelayConnectionError.webSocketFailed(error.localizedDescription) + ) + context.close(promise: nil) + } +} + +// MARK: - WebSocket Binary Frame Handler + +/// Converts between WebSocket binary frames and raw ByteBuffers. +/// Sits between the WebSocket frame codec and the SSH handler. +final class WebSocketBinaryFrameHandler: ChannelDuplexHandler { + typealias InboundIn = WebSocketFrame + typealias InboundOut = ByteBuffer + typealias OutboundIn = ByteBuffer + typealias OutboundOut = WebSocketFrame + + let upgradePromise: EventLoopPromise + + /// Called when the channel goes inactive (connection lost). + var onChannelInactive: (@Sendable () -> Void)? + + init(upgradePromise: EventLoopPromise) { + self.upgradePromise = upgradePromise + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let frame = unwrapInboundIn(data) + switch frame.opcode { + case .binary: + let data = frame.unmaskedData + context.fireChannelRead(wrapInboundOut(data)) + case .connectionClose: + context.close(promise: nil) + case .ping: + let pongData = context.channel.allocator.buffer(capacity: 0) + let pong = WebSocketFrame(fin: true, opcode: .pong, data: pongData) + context.writeAndFlush(wrapOutboundOut(pong), promise: nil) + default: + break + } + } + + func channelInactive(context: ChannelHandlerContext) { + onChannelInactive?() + context.fireChannelInactive() + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let buffer = unwrapOutboundIn(data) + let frame = WebSocketFrame(fin: true, opcode: .binary, data: buffer) + context.write(wrapOutboundOut(frame), promise: promise) + } +} + diff --git a/swift/Sources/DevTunnelsClient/Contracts/Enums.swift b/swift/Sources/DevTunnelsClient/Contracts/Enums.swift new file mode 100644 index 00000000..51fa1a75 --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Contracts/Enums.swift @@ -0,0 +1,59 @@ +/// Defines the connection mode for a tunnel endpoint. +public enum TunnelConnectionMode: String, Codable, Sendable { + /// Connections via a local network address. + case localNetwork = "LocalNetwork" + + /// Connections via the tunnel service's built-in relay. + case tunnelRelay = "TunnelRelay" + + /// Unknown connection mode from the service. + case unknown + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + let rawValue = try container.decode(String.self) + self = TunnelConnectionMode(rawValue: rawValue) ?? .unknown + } +} + +/// Defines scopes for tunnel access tokens. +public enum TunnelAccessScopes { + /// Create tunnels. + public static let create = "create" + /// Manage tunnel properties. + public static let manage = "manage" + /// Manage tunnel ports. + public static let managePorts = "manage:ports" + /// Host connections. + public static let host = "host" + /// Inspect tunnel activity. + public static let inspect = "inspect" + /// Connect to tunnel ports. + public static let connect = "connect" +} + +/// Protocol hint for a tunnel port. +/// +/// Indicates the expected application protocol for the tunnel port. +/// The service uses this to generate appropriate access URLs. +public enum TunnelPortProtocol: String, Codable, Sendable { + /// Automatically detect the protocol. + case auto + /// HTTP protocol. + case http + /// HTTPS protocol. + case https + /// Remote Desktop Protocol. + case rdp + /// Secure Shell protocol. + case ssh + + /// Unknown protocol from the service. + case unknown + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + let rawValue = try container.decode(String.self) + self = TunnelPortProtocol(rawValue: rawValue) ?? .unknown + } +} diff --git a/swift/Sources/DevTunnelsClient/Contracts/Tunnel.swift b/swift/Sources/DevTunnelsClient/Contracts/Tunnel.swift new file mode 100644 index 00000000..84f68142 --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Contracts/Tunnel.swift @@ -0,0 +1,56 @@ +/// Data contract for tunnel objects managed through the tunnel service REST API. +public struct Tunnel: Codable, Equatable, Sendable { + /// Cluster ID where the tunnel was created. + public var clusterId: String? + + /// Generated unique tunnel ID within the cluster. + public var tunnelId: String? + + /// Optional short name (alias). Globally unique within the parent domain. + public var name: String? + + /// Description of the tunnel. + public var description: String? + + /// Labels for the tunnel. + public var labels: [String]? + + /// Optional parent domain (if not using the default). + public var domain: String? + + /// Dictionary mapping from scopes to tunnel access tokens. + public var accessTokens: [String: String]? + + /// Current connection status of the tunnel. + public var status: TunnelStatus? + + /// Endpoints where hosts are currently accepting client connections. + public var endpoints: [TunnelEndpoint]? + + /// Ports in the tunnel. + public var ports: [TunnelPort]? + + public init( + clusterId: String? = nil, + tunnelId: String? = nil, + name: String? = nil, + description: String? = nil, + labels: [String]? = nil, + domain: String? = nil, + accessTokens: [String: String]? = nil, + status: TunnelStatus? = nil, + endpoints: [TunnelEndpoint]? = nil, + ports: [TunnelPort]? = nil + ) { + self.clusterId = clusterId + self.tunnelId = tunnelId + self.name = name + self.description = description + self.labels = labels + self.domain = domain + self.accessTokens = accessTokens + self.status = status + self.endpoints = endpoints + self.ports = ports + } +} diff --git a/swift/Sources/DevTunnelsClient/Contracts/TunnelEndpoint.swift b/swift/Sources/DevTunnelsClient/Contracts/TunnelEndpoint.swift new file mode 100644 index 00000000..7213c6de --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Contracts/TunnelEndpoint.swift @@ -0,0 +1,52 @@ +/// Base class for tunnel connection parameters. +/// +/// A tunnel endpoint specifies how and where hosts and clients can connect to a tunnel. +public struct TunnelEndpoint: Codable, Equatable, Sendable { + /// Endpoint ID. + public var id: String? + + /// Connection mode of the endpoint. + public var connectionMode: TunnelConnectionMode? + + /// ID of the host listening on this endpoint. + public var hostId: String? + + /// Public keys for authenticating the host. + public var hostPublicKeys: [String]? + + /// URI format string for web client port connections. + /// Contains `{port}` token to be replaced with actual port number. + public var portUriFormat: String? + + /// URI for web client connection to the default port. + public var tunnelUri: String? + + /// Host relay URI (for TunnelRelay connection mode). + public var hostRelayUri: String? + + /// Client relay URI (for TunnelRelay connection mode). + public var clientRelayUri: String? + + public init( + id: String? = nil, + connectionMode: TunnelConnectionMode? = nil, + hostId: String? = nil, + hostPublicKeys: [String]? = nil, + portUriFormat: String? = nil, + tunnelUri: String? = nil, + hostRelayUri: String? = nil, + clientRelayUri: String? = nil + ) { + self.id = id + self.connectionMode = connectionMode + self.hostId = hostId + self.hostPublicKeys = hostPublicKeys + self.portUriFormat = portUriFormat + self.tunnelUri = tunnelUri + self.hostRelayUri = hostRelayUri + self.clientRelayUri = clientRelayUri + } +} + +/// Token in `portUriFormat` to be replaced by a port number. +public let tunnelEndpointPortToken = "{port}" diff --git a/swift/Sources/DevTunnelsClient/Contracts/TunnelPort.swift b/swift/Sources/DevTunnelsClient/Contracts/TunnelPort.swift new file mode 100644 index 00000000..12f39923 --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Contracts/TunnelPort.swift @@ -0,0 +1,41 @@ +/// Data contract for tunnel port objects managed through the tunnel service REST API. +public struct TunnelPort: Codable, Equatable, Sendable { + /// Cluster ID where the tunnel was created. + public var clusterId: String? + + /// Generated tunnel ID, unique within the cluster. + public var tunnelId: String? + + /// IP port number of the tunnel port. + public var portNumber: UInt16 + + /// Optional short name of the port. Unique among named ports of the same tunnel. + public var name: String? + + /// Optional description of the port. + public var description: String? + + /// Protocol of the tunnel port (auto, http, https, rdp, ssh). + public var `protocol`: TunnelPortProtocol? + + /// Dictionary mapping from scopes to port-level access tokens. + public var accessTokens: [String: String]? + + public init( + clusterId: String? = nil, + tunnelId: String? = nil, + portNumber: UInt16, + name: String? = nil, + description: String? = nil, + protocol: TunnelPortProtocol? = nil, + accessTokens: [String: String]? = nil + ) { + self.clusterId = clusterId + self.tunnelId = tunnelId + self.portNumber = portNumber + self.name = name + self.description = description + self.protocol = `protocol` + self.accessTokens = accessTokens + } +} diff --git a/swift/Sources/DevTunnelsClient/Contracts/TunnelStatus.swift b/swift/Sources/DevTunnelsClient/Contracts/TunnelStatus.swift new file mode 100644 index 00000000..733d69a6 --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Contracts/TunnelStatus.swift @@ -0,0 +1,58 @@ +/// Current status of a tunnel. +public struct TunnelStatus: Codable, Equatable, Sendable { + /// Current value and limit for the number of ports on the tunnel. + public var portCount: ResourceStatus? + + /// Current value and limit for the number of hosts connected. + public var hostConnectionCount: ResourceStatus? + + /// UTC time when a host was last accepting connections, or nil if never. + public var lastHostConnectionTime: String? + + /// Current value and limit for the number of clients connected. + public var clientConnectionCount: ResourceStatus? + + public init( + portCount: ResourceStatus? = nil, + hostConnectionCount: ResourceStatus? = nil, + lastHostConnectionTime: String? = nil, + clientConnectionCount: ResourceStatus? = nil + ) { + self.portCount = portCount + self.hostConnectionCount = hostConnectionCount + self.lastHostConnectionTime = lastHostConnectionTime + self.clientConnectionCount = clientConnectionCount + } +} + +/// Current value and limit for a limited resource related to a tunnel or port. +/// The API may return this as either a plain number or an object with `current` and `limit`. +public struct ResourceStatus: Codable, Equatable, Sendable { + /// Current count of the resource (e.g., connected clients, open ports). + public var current: UInt64 + + /// Maximum allowed by the service, or nil if unlimited. + public var limit: UInt64? + + public init(current: UInt64 = 0, limit: UInt64? = nil) { + self.current = current + self.limit = limit + } + + public init(from decoder: Decoder) throws { + // The API can return either a plain number or {current, limit} object. + if let container = try? decoder.singleValueContainer(), + let value = try? container.decode(UInt64.self) { + self.current = value + self.limit = nil + } else { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.current = try container.decodeIfPresent(UInt64.self, forKey: .current) ?? 0 + self.limit = try container.decodeIfPresent(UInt64.self, forKey: .limit) + } + } + + private enum CodingKeys: String, CodingKey { + case current, limit + } +} diff --git a/swift/Sources/DevTunnelsClient/Management/APIResponses.swift b/swift/Sources/DevTunnelsClient/Management/APIResponses.swift new file mode 100644 index 00000000..4d6dd2cd --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Management/APIResponses.swift @@ -0,0 +1,14 @@ +import Foundation + +/// API response wrapper for list tunnels grouped by region. +struct TunnelListByRegionResponse: Codable { + let value: [TunnelListByRegion]? + let nextLink: String? +} + +/// A group of tunnels in a region. +struct TunnelListByRegion: Codable { + let regionName: String? + let clusterId: String? + let value: [Tunnel]? +} diff --git a/swift/Sources/DevTunnelsClient/Management/DeviceCodeAuth.swift b/swift/Sources/DevTunnelsClient/Management/DeviceCodeAuth.swift new file mode 100644 index 00000000..bf4a44df --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Management/DeviceCodeAuth.swift @@ -0,0 +1,163 @@ +import Foundation + +/// Response from starting a GitHub device code auth flow. +public struct DeviceCodeResponse: Codable, Equatable, Sendable { + /// The device code used for polling (not shown to user). + public let deviceCode: String + + /// The code the user enters at the verification URI. + public let userCode: String + + /// The URL the user visits to enter the code. + public let verificationUri: String + + /// Seconds until the device code expires. + public let expiresIn: Int + + /// Minimum seconds between poll requests. + public let interval: Int + + private enum CodingKeys: String, CodingKey { + case deviceCode = "device_code" + case userCode = "user_code" + case verificationUri = "verification_uri" + case expiresIn = "expires_in" + case interval + } +} + +/// Result of polling for device code authorization completion. +public enum DeviceCodePollResult: Equatable, Sendable { + /// User completed authorization. Contains the GitHub access token. + case accessToken(String) + /// Authorization is still pending — poll again after `interval` seconds. + case pending + /// The device code expired. Start a new flow. + case expired + /// The flow was denied or encountered an error. + case error(String) +} + +/// GitHub device code OAuth flow for Dev Tunnels authentication. +/// +/// Usage: +/// ```swift +/// let response = try await DeviceCodeAuth.start() +/// print("Visit \(response.verificationUri) and enter: \(response.userCode)") +/// +/// while true { +/// try await Task.sleep(for: .seconds(response.interval)) +/// let result = try await DeviceCodeAuth.poll(deviceCode: response.deviceCode) +/// switch result { +/// case .accessToken(let token): // Done! +/// case .pending: continue +/// case .expired: // Restart +/// case .error(let msg): // Handle +/// } +/// } +/// ``` +public struct DeviceCodeAuth: Sendable { + private let httpClient: any HTTPClient + private let serviceProperties: TunnelServiceProperties + + public init( + serviceProperties: TunnelServiceProperties = .production, + httpClient: any HTTPClient = URLSession.shared + ) { + self.httpClient = httpClient + self.serviceProperties = serviceProperties + } + + /// Start a device code auth flow with GitHub. + /// + /// Returns a `DeviceCodeResponse` with the `userCode` to display and `verificationUri` + /// for the user to visit. Then call `poll(deviceCode:)` with the `deviceCode`. + public func start() async throws -> DeviceCodeResponse { + let url = URL(string: "https://github.com/login/device/code")! + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue("application/json", forHTTPHeaderField: "Accept") + + let body = "client_id=\(serviceProperties.gitHubAppClientId)&scope=user:email read:org" + request.httpBody = body.data(using: .utf8) + request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type") + + let (data, response) = try await httpClient.data(for: request) + + if let httpResponse = response as? HTTPURLResponse, + !(200..<300).contains(httpResponse.statusCode) { + let message = String(data: data, encoding: .utf8) ?? "Unknown error" + throw TunnelManagementError.httpError( + statusCode: httpResponse.statusCode, + message: message + ) + } + + do { + return try JSONDecoder().decode(DeviceCodeResponse.self, from: data) + } catch { + throw TunnelManagementError.decodingError("Failed to decode device code response: \(error)") + } + } + + /// Poll GitHub for device code authorization completion. + /// + /// Call repeatedly with the `deviceCode` from `start()`, + /// waiting at least `interval` seconds between calls. + public func poll(deviceCode: String) async throws -> DeviceCodePollResult { + let url = URL(string: "https://github.com/login/oauth/access_token")! + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue("application/json", forHTTPHeaderField: "Accept") + + let body = [ + "client_id=\(serviceProperties.gitHubAppClientId)", + "device_code=\(deviceCode)", + "grant_type=urn:ietf:params:oauth:grant-type:device_code", + ].joined(separator: "&") + request.httpBody = body.data(using: .utf8) + request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type") + + let (data, response) = try await httpClient.data(for: request) + + if let httpResponse = response as? HTTPURLResponse, + !(200..<300).contains(httpResponse.statusCode) { + let message = String(data: data, encoding: .utf8) ?? "Unknown error" + throw TunnelManagementError.httpError( + statusCode: httpResponse.statusCode, + message: message + ) + } + + let tokenResponse = try JSONDecoder().decode(GitHubTokenResponse.self, from: data) + + if let token = tokenResponse.accessToken { + return .accessToken(token) + } + + switch tokenResponse.error { + case "authorization_pending", "slow_down": + return .pending + case "expired_token": + return .expired + default: + let message = tokenResponse.errorDescription + ?? tokenResponse.error + ?? "Unknown error" + return .error(message) + } + } +} + +/// Internal response type for GitHub OAuth token endpoint. +struct GitHubTokenResponse: Codable { + let accessToken: String? + let error: String? + let errorDescription: String? + + private enum CodingKeys: String, CodingKey { + case accessToken = "access_token" + case error + case errorDescription = "error_description" + } +} diff --git a/swift/Sources/DevTunnelsClient/Management/HTTPClient.swift b/swift/Sources/DevTunnelsClient/Management/HTTPClient.swift new file mode 100644 index 00000000..bb4832cb --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Management/HTTPClient.swift @@ -0,0 +1,9 @@ +import Foundation + +/// Abstraction over HTTP requests for testability. +/// URLSession conforms via extension; tests inject a mock. +public protocol HTTPClient: Sendable { + func data(for request: URLRequest) async throws -> (Data, URLResponse) +} + +extension URLSession: HTTPClient {} diff --git a/swift/Sources/DevTunnelsClient/Management/TunnelManagementClient.swift b/swift/Sources/DevTunnelsClient/Management/TunnelManagementClient.swift new file mode 100644 index 00000000..418f69fc --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Management/TunnelManagementClient.swift @@ -0,0 +1,360 @@ +import Foundation + +/// Client for the Dev Tunnels management REST API. +/// +/// Supports listing, creating, updating, and deleting tunnels and ports. +/// Uses protocol-based HTTP abstraction for testability. +public struct TunnelManagementClient: Sendable { + private let httpClient: any HTTPClient + private let accessToken: String + private let serviceProperties: TunnelServiceProperties + private let userAgent: String + + /// Creates a new management client. + /// + /// - Parameters: + /// - accessToken: Authentication token. Can be either: + /// - A full Authorization header value with scheme prefix (e.g. `"github "`, `"aad "`) + /// - A raw GitHub token (will be prefixed with `"github "` automatically) + /// - serviceProperties: Service endpoint configuration. Defaults to production. + /// - httpClient: HTTP client for requests. Defaults to URLSession.shared. + /// - userAgent: User-Agent string for requests. + public init( + accessToken: String, + serviceProperties: TunnelServiceProperties = .production, + httpClient: any HTTPClient = URLSession.shared, + userAgent: String = "Dev-Tunnels-Swift-Client/0.1.0" + ) { + self.accessToken = accessToken + self.serviceProperties = serviceProperties + self.httpClient = httpClient + self.userAgent = userAgent + } + + // MARK: - Tunnel operations + + /// Lists all tunnels accessible to the authenticated user. + /// + /// - Parameter clusterId: Optional cluster ID to scope the request. If nil, lists globally. + /// - Returns: Array of tunnels across all regions. + public func listTunnels(clusterId: String? = nil) async throws -> [Tunnel] { + var queryItems = [URLQueryItem]() + if clusterId == nil { + queryItems.append(URLQueryItem(name: "global", value: "true")) + } + queryItems.append(URLQueryItem(name: "api-version", value: serviceProperties.apiVersion)) + + let url = try buildURL(clusterId: clusterId, path: "/tunnels", queryItems: queryItems) + let request = buildRequest(url: url, method: "GET") + + let (data, response) = try await httpClient.data(for: request) + try checkResponse(response, data: data) + + do { + let regionResponse = try JSONDecoder().decode(TunnelListByRegionResponse.self, from: data) + var tunnels = [Tunnel]() + for region in regionResponse.value ?? [] { + tunnels.append(contentsOf: region.value ?? []) + } + return tunnels + } catch { + throw TunnelManagementError.decodingError("Failed to decode tunnel list: \(error)") + } + } + + /// Gets detailed information about a specific tunnel. + /// + /// - Parameters: + /// - clusterId: Cluster ID where the tunnel lives. + /// - tunnelId: The tunnel ID. + /// - options: Additional request options (ports, token scopes, etc.). + /// - Returns: Tunnel with requested details (ports, endpoints, access tokens). + public func getTunnel( + clusterId: String, + tunnelId: String, + options: TunnelRequestOptions = TunnelRequestOptions() + ) async throws -> Tunnel { + var queryItems = options.queryItems() + queryItems.append(URLQueryItem(name: "api-version", value: serviceProperties.apiVersion)) + + let url = try buildURL(clusterId: clusterId, path: "/tunnels/\(tunnelId)", queryItems: queryItems) + let request = buildRequest(url: url, method: "GET") + + let (data, response) = try await httpClient.data(for: request) + try checkResponse(response, data: data) + + do { + return try JSONDecoder().decode(Tunnel.self, from: data) + } catch { + throw TunnelManagementError.decodingError("Failed to decode tunnel: \(error)") + } + } + + private static let createNameRetries = 3 + + /// Creates a new tunnel. + /// + /// Generates a tunnel ID client-side and uses PUT with `If-Not-Match: *` + /// to ensure creation (not update). Retries with a new ID on 409 Conflict. + /// + /// - Parameters: + /// - tunnel: Tunnel properties to set (name, description, labels, etc.). + /// - options: Additional request options (token scopes, etc.). + /// - Returns: The created tunnel with server-assigned cluster. + public func createTunnel( + _ tunnel: Tunnel, + options: TunnelRequestOptions = TunnelRequestOptions() + ) async throws -> Tunnel { + var tunnel = tunnel + let idGenerated = tunnel.tunnelId == nil || tunnel.tunnelId!.isEmpty + if idGenerated { + tunnel.tunnelId = Self.generateTunnelId() + } + + var queryItems = options.queryItems() + queryItems.append(URLQueryItem(name: "api-version", value: serviceProperties.apiVersion)) + + for retry in 0.. Tunnel { + guard let clusterId = tunnel.clusterId, !clusterId.isEmpty else { + throw TunnelManagementError.invalidRequest("clusterId is required for update") + } + guard let tunnelId = tunnel.tunnelId, !tunnelId.isEmpty else { + throw TunnelManagementError.invalidRequest("tunnelId is required for update") + } + + var queryItems = options.queryItems() + queryItems.append(URLQueryItem(name: "api-version", value: serviceProperties.apiVersion)) + + let url = try buildURL(clusterId: clusterId, path: "/tunnels/\(tunnelId)", queryItems: queryItems) + var request = buildRequest(url: url, method: "PUT") + request.httpBody = try JSONEncoder().encode(Self.tunnelForRequest(tunnel)) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.setValue("*", forHTTPHeaderField: "If-Match") + + let (data, response) = try await httpClient.data(for: request) + try checkResponse(response, data: data) + + do { + return try JSONDecoder().decode(Tunnel.self, from: data) + } catch { + throw TunnelManagementError.decodingError("Failed to decode updated tunnel: \(error)") + } + } + + /// Deletes a tunnel. + /// + /// - Parameters: + /// - clusterId: Cluster ID where the tunnel lives. + /// - tunnelId: The tunnel ID to delete. + public func deleteTunnel( + clusterId: String, + tunnelId: String + ) async throws { + var queryItems = [URLQueryItem]() + queryItems.append(URLQueryItem(name: "api-version", value: serviceProperties.apiVersion)) + + let url = try buildURL(clusterId: clusterId, path: "/tunnels/\(tunnelId)", queryItems: queryItems) + let request = buildRequest(url: url, method: "DELETE") + + let (data, response) = try await httpClient.data(for: request) + try checkResponse(response, data: data) + } + + // MARK: - Port operations + + /// Creates or updates a port on a tunnel. + /// + /// - Parameters: + /// - clusterId: Cluster ID where the tunnel lives. + /// - tunnelId: The tunnel ID. + /// - port: Port properties to set. + /// - options: Additional request options. + /// - Returns: The created or updated port. + public func createTunnelPort( + clusterId: String, + tunnelId: String, + port: TunnelPort, + options: TunnelRequestOptions = TunnelRequestOptions() + ) async throws -> TunnelPort { + var queryItems = options.queryItems() + queryItems.append(URLQueryItem(name: "api-version", value: serviceProperties.apiVersion)) + + let url = try buildURL( + clusterId: clusterId, + path: "/tunnels/\(tunnelId)/ports/\(port.portNumber)", + queryItems: queryItems + ) + var request = buildRequest(url: url, method: "PUT") + request.httpBody = try JSONEncoder().encode(port) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.setValue("*", forHTTPHeaderField: "If-Not-Match") + + let (data, response) = try await httpClient.data(for: request) + try checkResponse(response, data: data) + + do { + return try JSONDecoder().decode(TunnelPort.self, from: data) + } catch { + throw TunnelManagementError.decodingError("Failed to decode tunnel port: \(error)") + } + } + + /// Deletes a port from a tunnel. + /// + /// - Parameters: + /// - clusterId: Cluster ID where the tunnel lives. + /// - tunnelId: The tunnel ID. + /// - portNumber: The port number to delete. + public func deleteTunnelPort( + clusterId: String, + tunnelId: String, + portNumber: UInt16 + ) async throws { + var queryItems = [URLQueryItem]() + queryItems.append(URLQueryItem(name: "api-version", value: serviceProperties.apiVersion)) + + let url = try buildURL( + clusterId: clusterId, + path: "/tunnels/\(tunnelId)/ports/\(portNumber)", + queryItems: queryItems + ) + let request = buildRequest(url: url, method: "DELETE") + + let (data, response) = try await httpClient.data(for: request) + try checkResponse(response, data: data) + } + + // MARK: - Private helpers + + /// Strips read-only and sub-resource fields before sending to the API. + /// The server rejects requests with ports/endpoints/status in the body. + private static func tunnelForRequest(_ tunnel: Tunnel) -> Tunnel { + Tunnel( + clusterId: tunnel.clusterId, + tunnelId: tunnel.tunnelId, + name: tunnel.name, + description: tunnel.description, + labels: tunnel.labels, + domain: tunnel.domain + ) + } + + private static let adjectives = [ + "fun", "happy", "interesting", "neat", "peaceful", "puzzled", "kind", + "joyful", "new", "giant", "sneaky", "quick", "majestic", "jolly", + "fancy", "tidy", "swift", "silent", "amusing", "spiffy", + ] + private static let nouns = [ + "pond", "hill", "mountain", "field", "fog", "ant", "dog", "cat", + "shoe", "plane", "chair", "book", "ocean", "lake", "river", "horse", + ] + private static let idChars = Array("bcdfghjklmnpqrstvwxz0123456789") + + /// Generates a tunnel ID in the same format as the official SDKs. + /// Format: `{adjective}-{noun}-{7 random chars}` (e.g., "swift-lake-bcd3f7k") + static func generateTunnelId() -> String { + let adj = adjectives.randomElement()! + let noun = nouns.randomElement()! + let suffix = String((0..<7).map { _ in idChars.randomElement()! }) + return "\(adj)-\(noun)-\(suffix)" + } + + private func buildURL(clusterId: String?, path: String, queryItems: [URLQueryItem]) throws -> URL { + guard let serviceURL = URL(string: serviceProperties.serviceUri), + let host = serviceURL.host() else { + throw TunnelManagementError.invalidRequest("Invalid service URI: \(serviceProperties.serviceUri)") + } + + var baseHost = host + if let clusterId, !clusterId.isEmpty { + baseHost = "\(clusterId).\(baseHost)".replacingOccurrences(of: "global.", with: "") + } + + var components = URLComponents() + components.scheme = "https" + components.host = baseHost + components.path = path + components.queryItems = queryItems + + guard let url = components.url else { + throw TunnelManagementError.invalidRequest("Failed to build URL for path: \(path)") + } + return url + } + + private static let knownAuthSchemes = ["github", "aad", "bearer", "tunnel", "tunnelplan"] + + private func buildRequest(url: URL, method: String) -> URLRequest { + var request = URLRequest(url: url) + request.httpMethod = method + + // If token already has a known scheme prefix, use as-is; otherwise assume GitHub. + let lowerToken = accessToken.lowercased() + let hasScheme = Self.knownAuthSchemes.contains { lowerToken.hasPrefix($0 + " ") } + let authHeader = hasScheme ? accessToken : "github \(accessToken)" + request.setValue(authHeader, forHTTPHeaderField: "Authorization") + + request.setValue(userAgent, forHTTPHeaderField: "User-Agent") + request.setValue("application/json", forHTTPHeaderField: "Accept") + return request + } + + private func checkResponse(_ response: URLResponse, data: Data) throws { + guard let httpResponse = response as? HTTPURLResponse else { return } + guard (200..<300).contains(httpResponse.statusCode) else { + let message = String(data: data, encoding: .utf8) ?? "Unknown error" + // Include response headers in error for debugging auth issues + var details = message + if let wwwAuth = httpResponse.value(forHTTPHeaderField: "WWW-Authenticate") { + details += " | WWW-Authenticate: \(wwwAuth)" + } + if let requestUrl = httpResponse.url?.absoluteString { + details += " | URL: \(requestUrl)" + } + throw TunnelManagementError.httpError( + statusCode: httpResponse.statusCode, + message: details + ) + } + } +} diff --git a/swift/Sources/DevTunnelsClient/Management/TunnelManagementError.swift b/swift/Sources/DevTunnelsClient/Management/TunnelManagementError.swift new file mode 100644 index 00000000..934e74cb --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Management/TunnelManagementError.swift @@ -0,0 +1,24 @@ +import Foundation + +/// Errors from the tunnel management API. +public enum TunnelManagementError: Error, Equatable, LocalizedError { + /// HTTP error with status code and message body. + case httpError(statusCode: Int, message: String) + + /// Failed to decode the response body. + case decodingError(String) + + /// Invalid request parameters. + case invalidRequest(String) + + public var errorDescription: String? { + switch self { + case .httpError(let statusCode, let message): + return "HTTP \(statusCode): \(message)" + case .decodingError(let message): + return "Decoding error: \(message)" + case .invalidRequest(let message): + return "Invalid request: \(message)" + } + } +} diff --git a/swift/Sources/DevTunnelsClient/Management/TunnelRequestOptions.swift b/swift/Sources/DevTunnelsClient/Management/TunnelRequestOptions.swift new file mode 100644 index 00000000..6376e0d9 --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Management/TunnelRequestOptions.swift @@ -0,0 +1,30 @@ +import Foundation + +/// Options for tunnel API requests. +public struct TunnelRequestOptions: Sendable { + /// Include ports in the response. + public var includePorts: Bool + + /// Token scopes to request (e.g., ["connect"]). + public var tokenScopes: [String] + + public init( + includePorts: Bool = false, + tokenScopes: [String] = [] + ) { + self.includePorts = includePorts + self.tokenScopes = tokenScopes + } + + /// Converts options to URL query items. + func queryItems() -> [URLQueryItem] { + var items = [URLQueryItem]() + if includePorts { + items.append(URLQueryItem(name: "includePorts", value: "true")) + } + for scope in tokenScopes { + items.append(URLQueryItem(name: "tokenScopes", value: scope)) + } + return items + } +} diff --git a/swift/Sources/DevTunnelsClient/Management/TunnelServiceProperties.swift b/swift/Sources/DevTunnelsClient/Management/TunnelServiceProperties.swift new file mode 100644 index 00000000..21db43f8 --- /dev/null +++ b/swift/Sources/DevTunnelsClient/Management/TunnelServiceProperties.swift @@ -0,0 +1,29 @@ +/// Service properties for Dev Tunnels environments. +public struct TunnelServiceProperties: Sendable { + /// Base URI for the tunnel service. + public let serviceUri: String + + /// GitHub OAuth app client ID for device code auth. + public let gitHubAppClientId: String + + /// API version for requests. + public let apiVersion: String + + public init( + serviceUri: String, + gitHubAppClientId: String, + apiVersion: String = "2023-09-27-preview" + ) { + self.serviceUri = serviceUri + self.gitHubAppClientId = gitHubAppClientId + self.apiVersion = apiVersion + } +} + +extension TunnelServiceProperties { + /// Production service properties. + public static let production = TunnelServiceProperties( + serviceUri: "https://global.rel.tunnels.api.visualstudio.com", + gitHubAppClientId: "Iv1.e7b89e013f801f03" + ) +} diff --git a/swift/Tests/DevTunnelsClientTests/ContractTests.swift b/swift/Tests/DevTunnelsClientTests/ContractTests.swift new file mode 100644 index 00000000..fcd87f40 --- /dev/null +++ b/swift/Tests/DevTunnelsClientTests/ContractTests.swift @@ -0,0 +1,259 @@ +import XCTest +@testable import DevTunnelsClient + +final class ContractTests: XCTestCase { + + // MARK: - Tunnel + + func testTunnelDefaultInit() { + let tunnel = Tunnel() + XCTAssertNil(tunnel.clusterId) + XCTAssertNil(tunnel.tunnelId) + XCTAssertNil(tunnel.name) + XCTAssertNil(tunnel.ports) + XCTAssertNil(tunnel.endpoints) + XCTAssertNil(tunnel.status) + } + + func testTunnelFullInit() { + let tunnel = Tunnel( + clusterId: "usw2", + tunnelId: "abc123", + name: "my-tunnel", + description: "Test tunnel", + labels: ["dev", "test"], + accessTokens: ["connect": "jwt-token"], + ports: [TunnelPort(portNumber: 8080)] + ) + XCTAssertEqual(tunnel.clusterId, "usw2") + XCTAssertEqual(tunnel.tunnelId, "abc123") + XCTAssertEqual(tunnel.name, "my-tunnel") + XCTAssertEqual(tunnel.description, "Test tunnel") + XCTAssertEqual(tunnel.labels, ["dev", "test"]) + XCTAssertEqual(tunnel.accessTokens?["connect"], "jwt-token") + XCTAssertEqual(tunnel.ports?.count, 1) + XCTAssertEqual(tunnel.ports?[0].portNumber, 8080) + } + + func testTunnelEquality() { + let a = Tunnel(clusterId: "usw2", tunnelId: "abc") + let b = Tunnel(clusterId: "usw2", tunnelId: "abc") + let c = Tunnel(clusterId: "usw2", tunnelId: "xyz") + XCTAssertEqual(a, b) + XCTAssertNotEqual(a, c) + } + + func testTunnelCodableRoundTrip() throws { + let original = Tunnel( + clusterId: "usw2", + tunnelId: "abc123", + name: "my-tunnel", + accessTokens: ["connect": "token123"], + ports: [TunnelPort(portNumber: 8080, name: "web")] + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(Tunnel.self, from: data) + XCTAssertEqual(original, decoded) + } + + func testTunnelDecodesFromAPIResponse() throws { + let json = """ + { + "clusterId": "usw2", + "tunnelId": "abc123", + "name": "my-tunnel", + "description": "A test tunnel", + "labels": ["dev"], + "accessTokens": { + "connect": "eyJ..." + }, + "status": { + "hostConnectionCount": { + "current": 1, + "limit": 5 + } + }, + "endpoints": [ + { + "connectionMode": "TunnelRelay", + "hostId": "host-1", + "clientRelayUri": "wss://usw2-data.rel.tunnels.api.visualstudio.com/..." + } + ], + "ports": [ + { + "portNumber": 8080, + "name": "web", + "protocol": "http" + }, + { + "portNumber": 31546 + } + ] + } + """ + let tunnel = try JSONDecoder().decode(Tunnel.self, from: Data(json.utf8)) + XCTAssertEqual(tunnel.clusterId, "usw2") + XCTAssertEqual(tunnel.tunnelId, "abc123") + XCTAssertEqual(tunnel.name, "my-tunnel") + XCTAssertEqual(tunnel.labels, ["dev"]) + XCTAssertEqual(tunnel.accessTokens?["connect"], "eyJ...") + XCTAssertEqual(tunnel.status?.hostConnectionCount?.current, 1) + XCTAssertEqual(tunnel.status?.hostConnectionCount?.limit, 5) + XCTAssertEqual(tunnel.endpoints?.count, 1) + XCTAssertEqual(tunnel.endpoints?[0].connectionMode, .tunnelRelay) + XCTAssertEqual(tunnel.endpoints?[0].clientRelayUri, "wss://usw2-data.rel.tunnels.api.visualstudio.com/...") + XCTAssertEqual(tunnel.ports?.count, 2) + XCTAssertEqual(tunnel.ports?[0].portNumber, 8080) + XCTAssertEqual(tunnel.ports?[0].name, "web") + XCTAssertEqual(tunnel.ports?[0].protocol, .http) + XCTAssertEqual(tunnel.ports?[1].portNumber, 31546) + } + + func testTunnelIgnoresUnknownFields() throws { + let json = """ + { + "clusterId": "usw2", + "tunnelId": "abc123", + "someNewField": "unknown", + "anotherField": 42 + } + """ + let tunnel = try JSONDecoder().decode(Tunnel.self, from: Data(json.utf8)) + XCTAssertEqual(tunnel.clusterId, "usw2") + XCTAssertEqual(tunnel.tunnelId, "abc123") + } + + func testTunnelMinimalJSON() throws { + let json = "{}" + let tunnel = try JSONDecoder().decode(Tunnel.self, from: Data(json.utf8)) + XCTAssertNil(tunnel.clusterId) + XCTAssertNil(tunnel.tunnelId) + } + + // MARK: - TunnelEndpoint + + func testEndpointCodableRoundTrip() throws { + let original = TunnelEndpoint( + id: "ep-1", + connectionMode: .tunnelRelay, + hostId: "host-1", + hostPublicKeys: ["ssh-ed25519 AAAA..."], + clientRelayUri: "wss://relay.example.com" + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(TunnelEndpoint.self, from: data) + XCTAssertEqual(original, decoded) + } + + func testEndpointLocalNetwork() throws { + let json = """ + { + "connectionMode": "LocalNetwork", + "hostId": "local-host" + } + """ + let ep = try JSONDecoder().decode(TunnelEndpoint.self, from: Data(json.utf8)) + XCTAssertEqual(ep.connectionMode, .localNetwork) + XCTAssertEqual(ep.hostId, "local-host") + XCTAssertNil(ep.clientRelayUri) + } + + func testEndpointPortUriFormat() { + let ep = TunnelEndpoint( + portUriFormat: "https://abc123-{port}.usw2.devtunnels.ms" + ) + let url = ep.portUriFormat?.replacingOccurrences( + of: tunnelEndpointPortToken, with: "8080" + ) + XCTAssertEqual(url, "https://abc123-8080.usw2.devtunnels.ms") + } + + // MARK: - TunnelPort + + func testPortCodableRoundTrip() throws { + let original = TunnelPort( + portNumber: 3000, + name: "dev-server", + protocol: .https + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(TunnelPort.self, from: data) + XCTAssertEqual(original, decoded) + } + + func testPortMinimalJSON() throws { + let json = """ + { "portNumber": 443 } + """ + let port = try JSONDecoder().decode(TunnelPort.self, from: Data(json.utf8)) + XCTAssertEqual(port.portNumber, 443) + XCTAssertNil(port.name) + XCTAssertNil(port.protocol) + } + + func testPortProtocolValues() throws { + for proto in ["auto", "http", "https", "rdp", "ssh"] { + let json = """ + { "portNumber": 1, "protocol": "\(proto)" } + """ + let port = try JSONDecoder().decode(TunnelPort.self, from: Data(json.utf8)) + XCTAssertNotNil(port.protocol, "Protocol '\(proto)' should decode") + } + } + + // MARK: - TunnelStatus & ResourceStatus + + func testStatusCodableRoundTrip() throws { + let original = TunnelStatus( + portCount: ResourceStatus(current: 3, limit: 10), + hostConnectionCount: ResourceStatus(current: 1), + clientConnectionCount: ResourceStatus(current: 0, limit: 100) + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(TunnelStatus.self, from: data) + XCTAssertEqual(original, decoded) + } + + func testResourceStatusNoLimit() throws { + let json = """ + { "current": 5 } + """ + let rs = try JSONDecoder().decode(ResourceStatus.self, from: Data(json.utf8)) + XCTAssertEqual(rs.current, 5) + XCTAssertNil(rs.limit) + } + + // MARK: - Enums + + func testConnectionModeValues() throws { + let jsonRelay = "\"TunnelRelay\"" + let jsonLocal = "\"LocalNetwork\"" + let relay = try JSONDecoder().decode(TunnelConnectionMode.self, from: Data(jsonRelay.utf8)) + let local = try JSONDecoder().decode(TunnelConnectionMode.self, from: Data(jsonLocal.utf8)) + XCTAssertEqual(relay, .tunnelRelay) + XCTAssertEqual(local, .localNetwork) + } + + func testAccessScopeConstants() { + XCTAssertEqual(TunnelAccessScopes.connect, "connect") + XCTAssertEqual(TunnelAccessScopes.host, "host") + XCTAssertEqual(TunnelAccessScopes.manage, "manage") + XCTAssertEqual(TunnelAccessScopes.managePorts, "manage:ports") + XCTAssertEqual(TunnelAccessScopes.create, "create") + XCTAssertEqual(TunnelAccessScopes.inspect, "inspect") + } + + // MARK: - Encoding output + + func testTunnelEncodesCorrectJSONKeys() throws { + let tunnel = Tunnel(clusterId: "usw2", tunnelId: "t1") + let data = try JSONEncoder().encode(tunnel) + let dict = try JSONSerialization.jsonObject(with: data) as! [String: Any] + XCTAssertNotNil(dict["clusterId"]) + XCTAssertNotNil(dict["tunnelId"]) + // Verify camelCase keys (not snake_case) + XCTAssertNil(dict["cluster_id"]) + XCTAssertNil(dict["tunnel_id"]) + } +} diff --git a/swift/Tests/DevTunnelsClientTests/DeviceCodeAuthTests.swift b/swift/Tests/DevTunnelsClientTests/DeviceCodeAuthTests.swift new file mode 100644 index 00000000..db293398 --- /dev/null +++ b/swift/Tests/DevTunnelsClientTests/DeviceCodeAuthTests.swift @@ -0,0 +1,223 @@ +import XCTest +@testable import DevTunnelsClient +import Foundation + +final class DeviceCodeAuthTests: XCTestCase { + let mockHttp = MockHTTPClient() + + func makeAuth() -> DeviceCodeAuth { + DeviceCodeAuth( + serviceProperties: .production, + httpClient: mockHttp + ) + } + + // MARK: - start() + + func testStartDecodesResponse() async throws { + let json = """ + { + "device_code": "dc-123", + "user_code": "ABCD-1234", + "verification_uri": "https://github.com/login/device", + "expires_in": 900, + "interval": 5 + } + """ + mockHttp.addResponse(pathContains: "/login/device/code", data: Data(json.utf8)) + + let auth = makeAuth() + let response = try await auth.start() + + XCTAssertEqual(response.deviceCode, "dc-123") + XCTAssertEqual(response.userCode, "ABCD-1234") + XCTAssertEqual(response.verificationUri, "https://github.com/login/device") + XCTAssertEqual(response.expiresIn, 900) + XCTAssertEqual(response.interval, 5) + } + + func testStartSendsCorrectRequest() async throws { + let json = """ + { + "device_code": "dc", + "user_code": "UC", + "verification_uri": "https://github.com/login/device", + "expires_in": 900, + "interval": 5 + } + """ + mockHttp.addResponse(pathContains: "/login/device/code", data: Data(json.utf8)) + + let auth = makeAuth() + _ = try await auth.start() + + XCTAssertEqual(mockHttp.requests.count, 1) + let request = mockHttp.requests[0] + XCTAssertEqual(request.httpMethod, "POST") + XCTAssertEqual(request.url?.host(), "github.com") + XCTAssertEqual(request.url?.path(), "/login/device/code") + XCTAssertEqual(request.value(forHTTPHeaderField: "Accept"), "application/json") + XCTAssertEqual(request.value(forHTTPHeaderField: "Content-Type"), "application/x-www-form-urlencoded") + + let body = String(data: request.httpBody ?? Data(), encoding: .utf8) ?? "" + XCTAssertTrue(body.contains("client_id=")) + XCTAssertTrue(body.contains("scope=user:email")) + } + + func testStartHttpErrorThrows() async throws { + mockHttp.addResponse( + pathContains: "/login/device/code", + data: Data("Bad Request".utf8), + statusCode: 400 + ) + + let auth = makeAuth() + do { + _ = try await auth.start() + XCTFail("Should have thrown") + } catch let error as TunnelManagementError { + if case .httpError(let statusCode, _) = error { + XCTAssertEqual(statusCode, 400) + } else { + XCTFail("Wrong error: \(error)") + } + } + } + + // MARK: - poll() + + func testPollReturnsAccessToken() async throws { + let json = """ + { "access_token": "gho_abc123" } + """ + mockHttp.addResponse(pathContains: "/login/oauth/access_token", data: Data(json.utf8)) + + let auth = makeAuth() + let result = try await auth.poll(deviceCode: "dc-123") + + XCTAssertEqual(result, .accessToken("gho_abc123")) + } + + func testPollReturnsPending() async throws { + let json = """ + { "error": "authorization_pending" } + """ + mockHttp.addResponse(pathContains: "/login/oauth/access_token", data: Data(json.utf8)) + + let auth = makeAuth() + let result = try await auth.poll(deviceCode: "dc-123") + + XCTAssertEqual(result, .pending) + } + + func testPollSlowDownReturnsPending() async throws { + let json = """ + { "error": "slow_down" } + """ + mockHttp.addResponse(pathContains: "/login/oauth/access_token", data: Data(json.utf8)) + + let auth = makeAuth() + let result = try await auth.poll(deviceCode: "dc-123") + + XCTAssertEqual(result, .pending) + } + + func testPollReturnsExpired() async throws { + let json = """ + { "error": "expired_token" } + """ + mockHttp.addResponse(pathContains: "/login/oauth/access_token", data: Data(json.utf8)) + + let auth = makeAuth() + let result = try await auth.poll(deviceCode: "dc-123") + + XCTAssertEqual(result, .expired) + } + + func testPollReturnsErrorWithDescription() async throws { + let json = """ + { "error": "access_denied", "error_description": "The user denied the request" } + """ + mockHttp.addResponse(pathContains: "/login/oauth/access_token", data: Data(json.utf8)) + + let auth = makeAuth() + let result = try await auth.poll(deviceCode: "dc-123") + + XCTAssertEqual(result, .error("The user denied the request")) + } + + func testPollReturnsErrorWithoutDescription() async throws { + let json = """ + { "error": "access_denied" } + """ + mockHttp.addResponse(pathContains: "/login/oauth/access_token", data: Data(json.utf8)) + + let auth = makeAuth() + let result = try await auth.poll(deviceCode: "dc-123") + + XCTAssertEqual(result, .error("access_denied")) + } + + func testPollSendsCorrectRequest() async throws { + mockHttp.addResponse(pathContains: "/login/oauth/access_token", data: Data(""" + { "error": "authorization_pending" } + """.utf8)) + + let auth = makeAuth() + _ = try await auth.poll(deviceCode: "my-device-code") + + XCTAssertEqual(mockHttp.requests.count, 1) + let request = mockHttp.requests[0] + XCTAssertEqual(request.httpMethod, "POST") + XCTAssertEqual(request.url?.path(), "/login/oauth/access_token") + + let body = String(data: request.httpBody ?? Data(), encoding: .utf8) ?? "" + XCTAssertTrue(body.contains("device_code=my-device-code")) + XCTAssertTrue(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code")) + } + + func testPollHttpErrorThrows() async throws { + mockHttp.addResponse( + pathContains: "/login/oauth/access_token", + data: Data("Server Error".utf8), + statusCode: 500 + ) + + let auth = makeAuth() + do { + _ = try await auth.poll(deviceCode: "dc-123") + XCTFail("Should have thrown") + } catch let error as TunnelManagementError { + if case .httpError(let statusCode, _) = error { + XCTAssertEqual(statusCode, 500) + } else { + XCTFail("Wrong error: \(error)") + } + } + } + + // MARK: - DeviceCodeResponse Codable + + func testDeviceCodeResponseCodable() throws { + let json = """ + { + "device_code": "dc", + "user_code": "UC", + "verification_uri": "https://example.com", + "expires_in": 600, + "interval": 10 + } + """ + let response = try JSONDecoder().decode(DeviceCodeResponse.self, from: Data(json.utf8)) + XCTAssertEqual(response.deviceCode, "dc") + XCTAssertEqual(response.userCode, "UC") + XCTAssertEqual(response.verificationUri, "https://example.com") + XCTAssertEqual(response.expiresIn, 600) + XCTAssertEqual(response.interval, 10) + + // Round-trip + let encoded = try JSONEncoder().encode(response) + let decoded = try JSONDecoder().decode(DeviceCodeResponse.self, from: encoded) + XCTAssertEqual(response, decoded) + } +} diff --git a/swift/Tests/DevTunnelsClientTests/IntegrationTests.swift b/swift/Tests/DevTunnelsClientTests/IntegrationTests.swift new file mode 100644 index 00000000..95ae8158 --- /dev/null +++ b/swift/Tests/DevTunnelsClientTests/IntegrationTests.swift @@ -0,0 +1,128 @@ +import XCTest +@testable import DevTunnelsClient +import Foundation + +/// Live integration tests for the tunnel management API. +/// +/// These tests run against the real Dev Tunnels service and require authentication. +/// They are **skipped** when no token is configured. +/// +/// To run: +/// 1. Install devtunnels CLI: `brew install --cask devtunnel` +/// 2. Login: `devtunnel user login` +/// 3. Get token: `devtunnel user show --verbose` (copy the full token line, e.g. "github ") +/// 4. Run: +/// ``` +/// DEV_TUNNELS_TOKEN="github ghp_xxxx" swift test --filter IntegrationTests +/// ``` +final class IntegrationTests: XCTestCase { + + private static func userToken() -> String? { + let token = ProcessInfo.processInfo.environment["DEV_TUNNELS_TOKEN"] + if let token, !token.isEmpty { + return token + } + return nil + } + + private func skipIfNoToken() throws -> String { + guard let token = Self.userToken() else { + throw XCTSkip("No DEV_TUNNELS_TOKEN set — skipping live integration test") + } + return token + } + + // MARK: - Full tunnel CRUD lifecycle + + func testTunnelCRUDLifecycle() async throws { + let token = try skipIfNoToken() + let client = TunnelManagementClient(accessToken: token) + + // 1. Create a tunnel + let created = try await client.createTunnel( + Tunnel(description: "swift-sdk-integration-test"), + options: TunnelRequestOptions(tokenScopes: [TunnelAccessScopes.manage]) + ) + XCTAssertNotNil(created.tunnelId, "Server should assign a tunnelId") + XCTAssertNotNil(created.clusterId, "Server should assign a clusterId") + + let tunnelId = created.tunnelId! + let clusterId = created.clusterId! + + // Cleanup: always delete at the end + addTeardownBlock { + try? await client.deleteTunnel(clusterId: clusterId, tunnelId: tunnelId) + } + + // 2. Get the tunnel back + let fetched = try await client.getTunnel( + clusterId: clusterId, + tunnelId: tunnelId, + options: TunnelRequestOptions(includePorts: true) + ) + XCTAssertEqual(fetched.tunnelId, tunnelId) + XCTAssertEqual(fetched.description, "swift-sdk-integration-test") + + // 3. Update the tunnel + var toUpdate = fetched + toUpdate.description = "updated-by-swift-sdk" + let updated = try await client.updateTunnel(toUpdate) + XCTAssertEqual(updated.description, "updated-by-swift-sdk") + + // 4. Create a port + let port = try await client.createTunnelPort( + clusterId: clusterId, + tunnelId: tunnelId, + port: TunnelPort(portNumber: 8080, name: "web", protocol: .https) + ) + XCTAssertEqual(port.portNumber, 8080) + + // 5. Verify port appears on tunnel + let withPorts = try await client.getTunnel( + clusterId: clusterId, + tunnelId: tunnelId, + options: TunnelRequestOptions(includePorts: true) + ) + XCTAssertEqual(withPorts.ports?.count, 1) + XCTAssertEqual(withPorts.ports?.first?.portNumber, 8080) + + // 6. Delete the port + try await client.deleteTunnelPort( + clusterId: clusterId, + tunnelId: tunnelId, + portNumber: 8080 + ) + + // 7. Verify port is gone + let afterPortDelete = try await client.getTunnel( + clusterId: clusterId, + tunnelId: tunnelId, + options: TunnelRequestOptions(includePorts: true) + ) + XCTAssertEqual(afterPortDelete.ports?.count ?? 0, 0) + + // 8. Delete the tunnel + try await client.deleteTunnel(clusterId: clusterId, tunnelId: tunnelId) + + // 9. Verify tunnel is gone (should 404) + do { + _ = try await client.getTunnel(clusterId: clusterId, tunnelId: tunnelId) + XCTFail("Should have thrown 404 after delete") + } catch let error as TunnelManagementError { + if case .httpError(let statusCode, _) = error { + XCTAssertEqual(statusCode, 404) + } + } + } + + // MARK: - List tunnels + + func testListTunnels() async throws { + let token = try skipIfNoToken() + let client = TunnelManagementClient(accessToken: token) + + let tunnels = try await client.listTunnels() + // Just verify it doesn't throw and returns an array + XCTAssertTrue(tunnels.count >= 0) + } +} diff --git a/swift/Tests/DevTunnelsClientTests/ManagementClientTests.swift b/swift/Tests/DevTunnelsClientTests/ManagementClientTests.swift new file mode 100644 index 00000000..f4f8ebfa --- /dev/null +++ b/swift/Tests/DevTunnelsClientTests/ManagementClientTests.swift @@ -0,0 +1,539 @@ +import XCTest +@testable import DevTunnelsClient +import Foundation + +// MARK: - Mock HTTP Client + +/// Mock HTTP client that returns canned responses based on URL path. +final class MockHTTPClient: HTTPClient, @unchecked Sendable { + struct CannedResponse { + let data: Data + let statusCode: Int + } + + private(set) var requests: [URLRequest] = [] + private var responses: [String: CannedResponse] = [:] + var defaultResponse: CannedResponse? + + func addResponse(pathContains: String, data: Data, statusCode: Int = 200) { + responses[pathContains] = CannedResponse(data: data, statusCode: statusCode) + } + + func data(for request: URLRequest) async throws -> (Data, URLResponse) { + requests.append(request) + let urlString = request.url?.absoluteString ?? "" + var response: CannedResponse? + for (pathKey, cannedResponse) in responses { + if urlString.contains(pathKey) { + response = cannedResponse + break + } + } + let resp = response ?? defaultResponse ?? CannedResponse(data: Data(), statusCode: 200) + let httpResponse = HTTPURLResponse( + url: request.url!, + statusCode: resp.statusCode, + httpVersion: nil, + headerFields: nil + )! + return (resp.data, httpResponse) + } +} + +// MARK: - Tests + +final class ManagementClientTests: XCTestCase { + let mockHttp = MockHTTPClient() + + func makeClient(token: String = "test-github-token") -> TunnelManagementClient { + TunnelManagementClient( + accessToken: token, + serviceProperties: .production, + httpClient: mockHttp, + userAgent: "Test/1.0" + ) + } + + // MARK: - listTunnels + + func testListTunnelsDecodesRegionResponse() async throws { + let json = """ + { + "value": [ + { + "regionName": "US West 2", + "clusterId": "usw2", + "value": [ + { "tunnelId": "t1", "clusterId": "usw2", "name": "tunnel-one" }, + { "tunnelId": "t2", "clusterId": "usw2" } + ] + }, + { + "regionName": "Europe West", + "clusterId": "euw", + "value": [ + { "tunnelId": "t3", "clusterId": "euw", "name": "tunnel-three" } + ] + } + ] + } + """ + mockHttp.addResponse(pathContains: "/tunnels", data: Data(json.utf8)) + + let client = makeClient() + let tunnels = try await client.listTunnels() + + XCTAssertEqual(tunnels.count, 3) + XCTAssertEqual(tunnels[0].tunnelId, "t1") + XCTAssertEqual(tunnels[0].name, "tunnel-one") + XCTAssertEqual(tunnels[1].tunnelId, "t2") + XCTAssertEqual(tunnels[2].tunnelId, "t3") + XCTAssertEqual(tunnels[2].clusterId, "euw") + } + + func testListTunnelsEmptyResponse() async throws { + let json = """ + { "value": [] } + """ + mockHttp.addResponse(pathContains: "/tunnels", data: Data(json.utf8)) + + let client = makeClient() + let tunnels = try await client.listTunnels() + XCTAssertEqual(tunnels.count, 0) + } + + func testListTunnelsUsesGlobalQueryParam() async throws { + mockHttp.addResponse(pathContains: "/tunnels", data: Data(""" + { "value": [] } + """.utf8)) + + let client = makeClient() + _ = try await client.listTunnels() + + XCTAssertEqual(mockHttp.requests.count, 1) + let url = mockHttp.requests[0].url!.absoluteString + XCTAssertTrue(url.contains("global=true"), "Should include global=true when no clusterId") + } + + func testListTunnelsSetsAuthHeader() async throws { + mockHttp.addResponse(pathContains: "/tunnels", data: Data(""" + { "value": [] } + """.utf8)) + + let client = makeClient(token: "my-github-token") + _ = try await client.listTunnels() + + let authHeader = mockHttp.requests[0].value(forHTTPHeaderField: "Authorization") + XCTAssertEqual(authHeader, "github my-github-token") + } + + func testAuthHeaderWithSchemePrefix() async throws { + mockHttp.addResponse(pathContains: "/tunnels", data: Data(""" + { "value": [] } + """.utf8)) + + let client = makeClient(token: "aad eyJ0eXAiOiJKV1Qi...") + _ = try await client.listTunnels() + + let authHeader = mockHttp.requests[0].value(forHTTPHeaderField: "Authorization") + XCTAssertEqual(authHeader, "aad eyJ0eXAiOiJKV1Qi...", "Should use token as-is when scheme prefix present") + } + + func testListTunnelsIncludesApiVersion() async throws { + mockHttp.addResponse(pathContains: "/tunnels", data: Data(""" + { "value": [] } + """.utf8)) + + let client = makeClient() + _ = try await client.listTunnels() + + let url = mockHttp.requests[0].url!.absoluteString + XCTAssertTrue(url.contains("api-version=2023-09-27-preview")) + } + + func testListTunnelsHttpErrorThrows() async throws { + mockHttp.addResponse( + pathContains: "/tunnels", + data: Data("Unauthorized".utf8), + statusCode: 401 + ) + + let client = makeClient() + do { + _ = try await client.listTunnels() + XCTFail("Should have thrown") + } catch let error as TunnelManagementError { + if case .httpError(let statusCode, let message) = error { + XCTAssertEqual(statusCode, 401) + XCTAssertTrue(message.contains("Unauthorized")) + } else { + XCTFail("Wrong error type: \(error)") + } + } + } + + // MARK: - getTunnel + + func testGetTunnelDecodes() async throws { + let json = """ + { + "tunnelId": "abc123", + "clusterId": "usw2", + "name": "my-tunnel", + "endpoints": [ + { + "connectionMode": "TunnelRelay", + "hostId": "host-1", + "clientRelayUri": "wss://usw2-data.rel.tunnels.api.visualstudio.com/abc123" + } + ], + "ports": [ + { "portNumber": 8080, "name": "web" }, + { "portNumber": 31546 } + ], + "accessTokens": { + "connect": "eyJhbGciOiJSUzI1NiJ9..." + } + } + """ + mockHttp.addResponse(pathContains: "/tunnels/abc123", data: Data(json.utf8)) + + let client = makeClient() + let tunnel = try await client.getTunnel( + clusterId: "usw2", + tunnelId: "abc123", + options: TunnelRequestOptions(includePorts: true, tokenScopes: [TunnelAccessScopes.connect]) + ) + + XCTAssertEqual(tunnel.tunnelId, "abc123") + XCTAssertEqual(tunnel.name, "my-tunnel") + XCTAssertEqual(tunnel.endpoints?.count, 1) + XCTAssertEqual(tunnel.endpoints?[0].connectionMode, .tunnelRelay) + XCTAssertEqual(tunnel.endpoints?[0].clientRelayUri, "wss://usw2-data.rel.tunnels.api.visualstudio.com/abc123") + XCTAssertEqual(tunnel.ports?.count, 2) + XCTAssertEqual(tunnel.ports?[0].portNumber, 8080) + XCTAssertEqual(tunnel.ports?[1].portNumber, 31546) + XCTAssertEqual(tunnel.accessTokens?["connect"], "eyJhbGciOiJSUzI1NiJ9...") + } + + func testGetTunnelUsesClusterSpecificHost() async throws { + mockHttp.addResponse(pathContains: "/tunnels/t1", data: Data(""" + { "tunnelId": "t1", "clusterId": "usw2" } + """.utf8)) + + let client = makeClient() + _ = try await client.getTunnel(clusterId: "usw2", tunnelId: "t1") + + let host = mockHttp.requests[0].url!.host() + XCTAssertEqual(host, "usw2.rel.tunnels.api.visualstudio.com") + } + + func testGetTunnelIncludesPortsQueryParam() async throws { + mockHttp.addResponse(pathContains: "/tunnels/t1", data: Data(""" + { "tunnelId": "t1" } + """.utf8)) + + let client = makeClient() + _ = try await client.getTunnel( + clusterId: "usw2", + tunnelId: "t1", + options: TunnelRequestOptions(includePorts: true) + ) + + let url = mockHttp.requests[0].url!.absoluteString + XCTAssertTrue(url.contains("includePorts=true")) + } + + func testGetTunnelIncludesTokenScopes() async throws { + mockHttp.addResponse(pathContains: "/tunnels/t1", data: Data(""" + { "tunnelId": "t1" } + """.utf8)) + + let client = makeClient() + _ = try await client.getTunnel( + clusterId: "usw2", + tunnelId: "t1", + options: TunnelRequestOptions(tokenScopes: [TunnelAccessScopes.connect]) + ) + + let url = mockHttp.requests[0].url!.absoluteString + XCTAssertTrue(url.contains("tokenScopes=connect")) + } + + func testGetTunnelMultipleTokenScopes() async throws { + mockHttp.addResponse(pathContains: "/tunnels/t1", data: Data(""" + { "tunnelId": "t1" } + """.utf8)) + + let client = makeClient() + _ = try await client.getTunnel( + clusterId: "usw2", + tunnelId: "t1", + options: TunnelRequestOptions(tokenScopes: [TunnelAccessScopes.connect, TunnelAccessScopes.host]) + ) + + let url = mockHttp.requests[0].url!.absoluteString + XCTAssertTrue(url.contains("tokenScopes=connect")) + XCTAssertTrue(url.contains("tokenScopes=host")) + } + + func testGetTunnel404Throws() async throws { + mockHttp.addResponse( + pathContains: "/tunnels/missing", + data: Data("Not Found".utf8), + statusCode: 404 + ) + + let client = makeClient() + do { + _ = try await client.getTunnel(clusterId: "usw2", tunnelId: "missing") + XCTFail("Should have thrown") + } catch let error as TunnelManagementError { + if case .httpError(let statusCode, _) = error { + XCTAssertEqual(statusCode, 404) + } else { + XCTFail("Wrong error: \(error)") + } + } + } + + // MARK: - Request formatting + + func testRequestIncludesUserAgent() async throws { + mockHttp.addResponse(pathContains: "/tunnels", data: Data(""" + { "value": [] } + """.utf8)) + + let client = makeClient() + _ = try await client.listTunnels() + + let ua = mockHttp.requests[0].value(forHTTPHeaderField: "User-Agent") + XCTAssertEqual(ua, "Test/1.0") + } + + func testRequestIncludesAcceptHeader() async throws { + mockHttp.addResponse(pathContains: "/tunnels", data: Data(""" + { "value": [] } + """.utf8)) + + let client = makeClient() + _ = try await client.listTunnels() + + let accept = mockHttp.requests[0].value(forHTTPHeaderField: "Accept") + XCTAssertEqual(accept, "application/json") + } + + // MARK: - TunnelRequestOptions + + func testRequestOptionsDefaultsEmpty() { + let opts = TunnelRequestOptions() + let items = opts.queryItems() + XCTAssertTrue(items.isEmpty) + } + + func testRequestOptionsIncludePorts() { + let opts = TunnelRequestOptions(includePorts: true) + let items = opts.queryItems() + XCTAssertEqual(items.count, 1) + XCTAssertEqual(items[0].name, "includePorts") + XCTAssertEqual(items[0].value, "true") + } + + func testRequestOptionsTokenScopes() { + let opts = TunnelRequestOptions(tokenScopes: ["connect", "host"]) + let items = opts.queryItems() + XCTAssertEqual(items.count, 2) + XCTAssertEqual(items[0].name, "tokenScopes") + XCTAssertEqual(items[0].value, "connect") + XCTAssertEqual(items[1].name, "tokenScopes") + XCTAssertEqual(items[1].value, "host") + } + + // MARK: - Service Properties + + func testProductionServiceProperties() { + let props = TunnelServiceProperties.production + XCTAssertEqual(props.serviceUri, "https://global.rel.tunnels.api.visualstudio.com") + XCTAssertEqual(props.gitHubAppClientId, "Iv1.e7b89e013f801f03") + XCTAssertEqual(props.apiVersion, "2023-09-27-preview") + } + + // MARK: - createTunnel + + func testCreateTunnelSendsPutWithBody() async throws { + mockHttp.addResponse(pathContains: "/tunnels/", data: Data(""" + { "tunnelId": "new-id", "clusterId": "usw2", "name": "my-tunnel" } + """.utf8)) + + let client = makeClient() + let tunnel = try await client.createTunnel(Tunnel(name: "my-tunnel")) + + XCTAssertEqual(tunnel.tunnelId, "new-id") + XCTAssertEqual(tunnel.name, "my-tunnel") + XCTAssertEqual(mockHttp.requests.count, 1) + XCTAssertEqual(mockHttp.requests[0].httpMethod, "PUT") + XCTAssertEqual( + mockHttp.requests[0].value(forHTTPHeaderField: "Content-Type"), + "application/json" + ) + XCTAssertEqual( + mockHttp.requests[0].value(forHTTPHeaderField: "If-Not-Match"), + "*" + ) + XCTAssertNotNil(mockHttp.requests[0].httpBody) + // URL should contain a generated tunnel ID + let url = mockHttp.requests[0].url!.absoluteString + XCTAssertTrue(url.contains("/tunnels/"), "URL should contain /tunnels/{id}") + } + + func testCreateTunnelWithOptions() async throws { + mockHttp.addResponse(pathContains: "/tunnels/", data: Data(""" + { "tunnelId": "t1" } + """.utf8)) + + let client = makeClient() + _ = try await client.createTunnel( + Tunnel(name: "test"), + options: TunnelRequestOptions(tokenScopes: [TunnelAccessScopes.connect]) + ) + + let url = mockHttp.requests[0].url!.absoluteString + XCTAssertTrue(url.contains("tokenScopes=connect")) + } + + func testCreateTunnelHttpErrorThrows() async throws { + mockHttp.addResponse( + pathContains: "/tunnels/", + data: Data("Forbidden".utf8), + statusCode: 403 + ) + + let client = makeClient() + do { + _ = try await client.createTunnel(Tunnel(name: "test")) + XCTFail("Should have thrown") + } catch let error as TunnelManagementError { + if case .httpError(let statusCode, _) = error { + XCTAssertEqual(statusCode, 403) + } else { + XCTFail("Wrong error: \(error)") + } + } + } + + // MARK: - updateTunnel + + func testUpdateTunnelSendsPut() async throws { + mockHttp.addResponse(pathContains: "/tunnels/t1", data: Data(""" + { "tunnelId": "t1", "clusterId": "usw2", "description": "updated" } + """.utf8)) + + let client = makeClient() + let tunnel = try await client.updateTunnel( + Tunnel(clusterId: "usw2", tunnelId: "t1", description: "updated") + ) + + XCTAssertEqual(tunnel.description, "updated") + XCTAssertEqual(mockHttp.requests[0].httpMethod, "PUT") + XCTAssertEqual( + mockHttp.requests[0].value(forHTTPHeaderField: "If-Match"), + "*" + ) + } + + func testUpdateTunnelMissingClusterIdThrows() async throws { + let client = makeClient() + do { + _ = try await client.updateTunnel(Tunnel(tunnelId: "t1")) + XCTFail("Should have thrown") + } catch let error as TunnelManagementError { + if case .invalidRequest(let msg) = error { + XCTAssertTrue(msg.contains("clusterId")) + } else { + XCTFail("Wrong error: \(error)") + } + } + } + + func testUpdateTunnelMissingTunnelIdThrows() async throws { + let client = makeClient() + do { + _ = try await client.updateTunnel(Tunnel(clusterId: "usw2")) + XCTFail("Should have thrown") + } catch let error as TunnelManagementError { + if case .invalidRequest(let msg) = error { + XCTAssertTrue(msg.contains("tunnelId")) + } else { + XCTFail("Wrong error: \(error)") + } + } + } + + // MARK: - deleteTunnel + + func testDeleteTunnelSendsDelete() async throws { + mockHttp.defaultResponse = MockHTTPClient.CannedResponse(data: Data(), statusCode: 204) + + let client = makeClient() + try await client.deleteTunnel(clusterId: "usw2", tunnelId: "t1") + + XCTAssertEqual(mockHttp.requests.count, 1) + XCTAssertEqual(mockHttp.requests[0].httpMethod, "DELETE") + let url = mockHttp.requests[0].url!.absoluteString + XCTAssertTrue(url.contains("/tunnels/t1")) + } + + func testDeleteTunnel404Throws() async throws { + mockHttp.addResponse( + pathContains: "/tunnels/missing", + data: Data("Not Found".utf8), + statusCode: 404 + ) + + let client = makeClient() + do { + try await client.deleteTunnel(clusterId: "usw2", tunnelId: "missing") + XCTFail("Should have thrown") + } catch let error as TunnelManagementError { + if case .httpError(let statusCode, _) = error { + XCTAssertEqual(statusCode, 404) + } else { + XCTFail("Wrong error: \(error)") + } + } + } + + // MARK: - createTunnelPort + + func testCreateTunnelPortSendsPut() async throws { + mockHttp.addResponse(pathContains: "/ports/", data: Data(""" + { "portNumber": 8080, "name": "web", "protocol": "https" } + """.utf8)) + + let client = makeClient() + let port = try await client.createTunnelPort( + clusterId: "usw2", + tunnelId: "t1", + port: TunnelPort(portNumber: 8080, name: "web", protocol: .https) + ) + + XCTAssertEqual(port.portNumber, 8080) + XCTAssertEqual(port.name, "web") + XCTAssertEqual(mockHttp.requests[0].httpMethod, "PUT") + let url = mockHttp.requests[0].url!.absoluteString + XCTAssertTrue(url.contains("/tunnels/t1/ports/8080")) + } + + // MARK: - deleteTunnelPort + + func testDeleteTunnelPortSendsDelete() async throws { + mockHttp.defaultResponse = MockHTTPClient.CannedResponse(data: Data(), statusCode: 204) + + let client = makeClient() + try await client.deleteTunnelPort(clusterId: "usw2", tunnelId: "t1", portNumber: 8080) + + XCTAssertEqual(mockHttp.requests[0].httpMethod, "DELETE") + let url = mockHttp.requests[0].url!.absoluteString + XCTAssertTrue(url.contains("/tunnels/t1/ports/8080")) + } +} diff --git a/swift/Tests/DevTunnelsClientTests/PortForwardMessageTests.swift b/swift/Tests/DevTunnelsClientTests/PortForwardMessageTests.swift new file mode 100644 index 00000000..33b15470 --- /dev/null +++ b/swift/Tests/DevTunnelsClientTests/PortForwardMessageTests.swift @@ -0,0 +1,215 @@ +import XCTest +@testable import DevTunnelsClient +import Foundation + +final class PortForwardMessageTests: XCTestCase { + + // MARK: - PortForwardChannelOpen + + func testChannelOpenConstants() { + XCTAssertEqual(PortForwardChannelOpen.channelType, "forwarded-tcpip") + } + + func testChannelOpenInit() { + let msg = PortForwardChannelOpen(port: 8080) + XCTAssertEqual(msg.host, "127.0.0.1") + XCTAssertEqual(msg.port, 8080) + XCTAssertEqual(msg.originatorIPAddress, "") + XCTAssertEqual(msg.originatorPort, 0) + } + + func testChannelOpenCustomInit() { + let msg = PortForwardChannelOpen( + host: "10.0.0.1", + port: 3000, + originatorIPAddress: "192.168.1.1", + originatorPort: 54321 + ) + XCTAssertEqual(msg.host, "10.0.0.1") + XCTAssertEqual(msg.port, 3000) + XCTAssertEqual(msg.originatorIPAddress, "192.168.1.1") + XCTAssertEqual(msg.originatorPort, 54321) + } + + func testChannelOpenMarshalRoundTrip() { + let original = PortForwardChannelOpen(port: 8080) + let data = original.marshal() + let decoded = PortForwardChannelOpen.unmarshal(from: data) + + XCTAssertNotNil(decoded) + XCTAssertEqual(decoded, original) + } + + func testChannelOpenMarshalCustomRoundTrip() { + let original = PortForwardChannelOpen( + host: "example.com", + port: 443, + originatorIPAddress: "10.0.0.5", + originatorPort: 12345 + ) + let data = original.marshal() + let decoded = PortForwardChannelOpen.unmarshal(from: data) + + XCTAssertEqual(decoded, original) + } + + func testChannelOpenMarshalBinaryFormat() { + let msg = PortForwardChannelOpen(host: "AB", port: 1, originatorIPAddress: "C", originatorPort: 2) + let data = msg.marshal() + + // Expected binary layout: + // [0..3] = uint32(2) big-endian → host length + // [4..5] = "AB" → host + // [6..9] = uint32(1) big-endian → port + // [10..13] = uint32(1) big-endian → originator IP length + // [14] = "C" → originator IP + // [15..18] = uint32(2) big-endian → originator port + + // Total: 4 + 2 + 4 + 4 + 1 + 4 = 19 bytes + XCTAssertEqual(data.count, 19) + + // Verify host length (big-endian 2) + XCTAssertEqual(data[0], 0) + XCTAssertEqual(data[1], 0) + XCTAssertEqual(data[2], 0) + XCTAssertEqual(data[3], 2) + + // Verify host + XCTAssertEqual(data[4], UInt8(ascii: "A")) + XCTAssertEqual(data[5], UInt8(ascii: "B")) + + // Verify port (big-endian 1) + XCTAssertEqual(data[6], 0) + XCTAssertEqual(data[7], 0) + XCTAssertEqual(data[8], 0) + XCTAssertEqual(data[9], 1) + } + + func testChannelOpenEmptyHost() { + let msg = PortForwardChannelOpen(host: "", port: 80) + let data = msg.marshal() + let decoded = PortForwardChannelOpen.unmarshal(from: data) + XCTAssertEqual(decoded?.host, "") + XCTAssertEqual(decoded?.port, 80) + } + + func testChannelOpenLargePort() { + let msg = PortForwardChannelOpen(port: 65535) + let data = msg.marshal() + let decoded = PortForwardChannelOpen.unmarshal(from: data) + XCTAssertEqual(decoded?.port, 65535) + } + + func testChannelOpenUnmarshalTruncatedData() { + // Only 3 bytes — not enough for even the first uint32 + let data = Data([0, 0, 1]) + let decoded = PortForwardChannelOpen.unmarshal(from: data) + XCTAssertNil(decoded) + } + + func testChannelOpenUnmarshalEmptyData() { + let decoded = PortForwardChannelOpen.unmarshal(from: Data()) + XCTAssertNil(decoded) + } + + func testChannelOpenEquality() { + let a = PortForwardChannelOpen(port: 8080) + let b = PortForwardChannelOpen(port: 8080) + let c = PortForwardChannelOpen(port: 9090) + XCTAssertEqual(a, b) + XCTAssertNotEqual(a, c) + } + + // MARK: - PortForwardRequest + + func testRequestConstants() { + XCTAssertEqual(PortForwardRequest.requestType, "tcpip-forward") + } + + func testRequestInit() { + let req = PortForwardRequest(port: 3000) + XCTAssertEqual(req.address, "127.0.0.1") + XCTAssertEqual(req.port, 3000) + } + + func testRequestMarshalRoundTrip() { + let original = PortForwardRequest(address: "0.0.0.0", port: 443) + let data = original.marshal() + let decoded = PortForwardRequest.unmarshal(from: data) + XCTAssertEqual(decoded, original) + } + + func testRequestUnmarshalTruncated() { + let data = Data([0, 0, 0, 5]) // length but no string bytes + let decoded = PortForwardRequest.unmarshal(from: data) + XCTAssertNil(decoded) + } + + func testRequestEquality() { + let a = PortForwardRequest(port: 80) + let b = PortForwardRequest(port: 80) + let c = PortForwardRequest(address: "0.0.0.0", port: 80) + XCTAssertEqual(a, b) + XCTAssertNotEqual(a, c) + } + + // MARK: - PortForwardSuccess + + func testSuccessInit() { + let success = PortForwardSuccess(port: 8080) + XCTAssertEqual(success.port, 8080) + } + + func testSuccessMarshalRoundTrip() { + let original = PortForwardSuccess(port: 12345) + let data = original.marshal() + + // Should be exactly 4 bytes (one uint32) + XCTAssertEqual(data.count, 4) + + let decoded = PortForwardSuccess.unmarshal(from: data) + XCTAssertEqual(decoded, original) + } + + func testSuccessUnmarshalTruncated() { + let data = Data([0, 0]) + let decoded = PortForwardSuccess.unmarshal(from: data) + XCTAssertNil(decoded) + } + + func testSuccessEquality() { + XCTAssertEqual(PortForwardSuccess(port: 80), PortForwardSuccess(port: 80)) + XCTAssertNotEqual(PortForwardSuccess(port: 80), PortForwardSuccess(port: 443)) + } + + // MARK: - Cross-compatibility with Go SDK format + + func testGoSDKCompatibleChannelOpen() { + // The Go SDK creates: NewPortForwardChannel(senderChannel, "127.0.0.1", uint32(port), "", 0) + // We should produce identical binary output for the payload portion + let msg = PortForwardChannelOpen( + host: "127.0.0.1", + port: 8080, + originatorIPAddress: "", + originatorPort: 0 + ) + let data = msg.marshal() + + // Verify we can round-trip + let decoded = PortForwardChannelOpen.unmarshal(from: data) + XCTAssertEqual(decoded, msg) + + // Verify the host is "127.0.0.1" (9 bytes) → length prefix = 9 + XCTAssertEqual(data[0], 0) + XCTAssertEqual(data[1], 0) + XCTAssertEqual(data[2], 0) + XCTAssertEqual(data[3], 9) // length of "127.0.0.1" + + // Verify originator IP is empty string → length prefix = 0 + let portEndOffset = 4 + 9 + 4 // host_len(4) + host(9) + port(4) + XCTAssertEqual(data[portEndOffset], 0) + XCTAssertEqual(data[portEndOffset + 1], 0) + XCTAssertEqual(data[portEndOffset + 2], 0) + XCTAssertEqual(data[portEndOffset + 3], 0) // empty string length + } +} diff --git a/swift/Tests/DevTunnelsClientTests/TunnelConnectionTests.swift b/swift/Tests/DevTunnelsClientTests/TunnelConnectionTests.swift new file mode 100644 index 00000000..8176e6bb --- /dev/null +++ b/swift/Tests/DevTunnelsClientTests/TunnelConnectionTests.swift @@ -0,0 +1,146 @@ +import XCTest +@testable import DevTunnelsClient + +final class TunnelConnectionTests: XCTestCase { + + // MARK: - directURL (from Tunnel) + + func testDirectURLFromTunnel() { + let tunnel = Tunnel(clusterId: "usw2", tunnelId: "abc123") + let url = TunnelConnection.directURL(tunnel: tunnel, port: 8080) + XCTAssertEqual(url?.absoluteString, "wss://abc123-8080.usw2.devtunnels.ms") + } + + func testDirectURLDifferentPort() { + let tunnel = Tunnel(clusterId: "euw", tunnelId: "xyz789") + let url = TunnelConnection.directURL(tunnel: tunnel, port: 31546) + XCTAssertEqual(url?.absoluteString, "wss://xyz789-31546.euw.devtunnels.ms") + } + + func testDirectURLNilWhenMissingTunnelId() { + let tunnel = Tunnel(clusterId: "usw2") + let url = TunnelConnection.directURL(tunnel: tunnel, port: 8080) + XCTAssertNil(url) + } + + func testDirectURLNilWhenMissingClusterId() { + let tunnel = Tunnel(tunnelId: "abc123") + let url = TunnelConnection.directURL(tunnel: tunnel, port: 8080) + XCTAssertNil(url) + } + + // MARK: - directURL (explicit params) + + func testDirectURLExplicit() { + let url = TunnelConnection.directURL(tunnelId: "abc", clusterId: "usw2", port: 443) + XCTAssertEqual(url?.absoluteString, "wss://abc-443.usw2.devtunnels.ms") + } + + // MARK: - directURL (from endpoint) + + func testDirectURLFromEndpointPortUriFormat() { + let ep = TunnelEndpoint( + portUriFormat: "https://abc123-{port}.usw2.devtunnels.ms" + ) + let url = TunnelConnection.directURL(endpoint: ep, port: 3000) + XCTAssertEqual(url?.absoluteString, "wss://abc123-3000.usw2.devtunnels.ms") + } + + func testDirectURLFromEndpointNilWhenNoFormat() { + let ep = TunnelEndpoint() + let url = TunnelConnection.directURL(endpoint: ep, port: 3000) + XCTAssertNil(url) + } + + // MARK: - connectToken + + func testConnectTokenExtracted() { + let tunnel = Tunnel(accessTokens: [ + "connect": "eyJ...", + "manage": "other-token", + ]) + XCTAssertEqual(TunnelConnection.connectToken(from: tunnel), "eyJ...") + } + + func testConnectTokenNilWhenMissing() { + let tunnel = Tunnel(accessTokens: ["manage": "token"]) + XCTAssertNil(TunnelConnection.connectToken(from: tunnel)) + } + + func testConnectTokenNilWhenNoTokens() { + let tunnel = Tunnel() + XCTAssertNil(TunnelConnection.connectToken(from: tunnel)) + } + + // MARK: - tunnelAuthHeader + + func testTunnelAuthHeaderFormat() { + let header = TunnelConnection.tunnelAuthHeader(connectToken: "eyJhbGciOiJSUzI1NiJ9") + XCTAssertEqual(header, "tunnel eyJhbGciOiJSUzI1NiJ9") + } + + // MARK: - clientRelayURI + + func testClientRelayURIFromEndpoints() { + let tunnel = Tunnel(endpoints: [ + TunnelEndpoint( + connectionMode: .tunnelRelay, + hostId: "host-1", + clientRelayUri: "wss://usw2-data.rel.tunnels.api.visualstudio.com/abc123" + ), + ]) + let uri = TunnelConnection.clientRelayURI(from: tunnel) + XCTAssertEqual(uri, "wss://usw2-data.rel.tunnels.api.visualstudio.com/abc123") + } + + func testClientRelayURINilWhenLocalNetwork() { + let tunnel = Tunnel(endpoints: [ + TunnelEndpoint(connectionMode: .localNetwork, hostId: "host-1"), + ]) + XCTAssertNil(TunnelConnection.clientRelayURI(from: tunnel)) + } + + func testClientRelayURINilWhenNoEndpoints() { + let tunnel = Tunnel() + XCTAssertNil(TunnelConnection.clientRelayURI(from: tunnel)) + } + + func testClientRelayURISkipsEndpointsWithoutUri() { + let tunnel = Tunnel(endpoints: [ + TunnelEndpoint(connectionMode: .tunnelRelay, hostId: "host-1"), + ]) + XCTAssertNil(TunnelConnection.clientRelayURI(from: tunnel)) + } + + // MARK: - isOnline + + func testIsOnlineWithHostConnections() { + let tunnel = Tunnel( + status: TunnelStatus( + hostConnectionCount: ResourceStatus(current: 1) + ) + ) + XCTAssertTrue(TunnelConnection.isOnline(tunnel)) + } + + func testIsOnlineWithEndpoints() { + let tunnel = Tunnel(endpoints: [ + TunnelEndpoint(connectionMode: .tunnelRelay, hostId: "host-1"), + ]) + XCTAssertTrue(TunnelConnection.isOnline(tunnel)) + } + + func testIsOfflineWithZeroHosts() { + let tunnel = Tunnel( + status: TunnelStatus( + hostConnectionCount: ResourceStatus(current: 0) + ) + ) + XCTAssertFalse(TunnelConnection.isOnline(tunnel)) + } + + func testIsOfflineWithNoInfo() { + let tunnel = Tunnel() + XCTAssertFalse(TunnelConnection.isOnline(tunnel)) + } +} diff --git a/swift/Tests/DevTunnelsClientTests/TunnelRelayClientTests.swift b/swift/Tests/DevTunnelsClientTests/TunnelRelayClientTests.swift new file mode 100644 index 00000000..0a14344c --- /dev/null +++ b/swift/Tests/DevTunnelsClientTests/TunnelRelayClientTests.swift @@ -0,0 +1,394 @@ +import XCTest +@testable import DevTunnelsClient + +/// Thread-safe state collector for use in @Sendable closures. +private final class StateCollector: @unchecked Sendable { + private var _states: [RelayConnectionState] = [] + private let lock = NSLock() + + func append(_ state: RelayConnectionState) { + lock.lock() + _states.append(state) + lock.unlock() + } + + var states: [RelayConnectionState] { + lock.lock() + defer { lock.unlock() } + return _states + } +} + +final class TunnelRelayConfigTests: XCTestCase { + + // MARK: - Config Validation + + func testValidConfigPasses() { + let config = TunnelRelayConfig( + relayUri: "wss://usw2-data.rel.tunnels.api.visualstudio.com/abc123", + accessToken: "eyJhbGciOiJSUzI1NiJ9.test", + port: 8080 + ) + XCTAssertNil(config.validate()) + } + + func testEmptyRelayUriFails() { + let config = TunnelRelayConfig(relayUri: "", accessToken: "token", port: 8080) + XCTAssertEqual(config.validate(), .missingRelayUri) + } + + func testInvalidRelayUriFails() { + let config = TunnelRelayConfig(relayUri: "not a url", accessToken: "token", port: 8080) + XCTAssertEqual(config.validate(), .invalidRelayUri("not a url")) + } + + func testHttpRelayUriFails() { + let config = TunnelRelayConfig(relayUri: "https://example.com", accessToken: "token", port: 8080) + XCTAssertEqual(config.validate(), .invalidRelayUri("https://example.com")) + } + + func testWsRelayUriPasses() { + let config = TunnelRelayConfig(relayUri: "ws://localhost:8080", accessToken: "token", port: 8080) + XCTAssertNil(config.validate()) + } + + func testEmptyAccessTokenFails() { + let config = TunnelRelayConfig(relayUri: "wss://example.com", accessToken: "", port: 8080) + XCTAssertEqual(config.validate(), .missingAccessToken) + } + + func testZeroPortFails() { + let config = TunnelRelayConfig(relayUri: "wss://example.com", accessToken: "token", port: 0) + XCTAssertEqual(config.validate(), .invalidPort) + } + + // MARK: - Authorization Header + + func testAuthHeaderPrefixesTunnel() { + let config = TunnelRelayConfig(relayUri: "wss://x", accessToken: "eyJ...", port: 1) + XCTAssertEqual(config.authorizationHeader, "Tunnel eyJ...") + } + + func testAuthHeaderSkipsPrefixWhenAlreadyPresent() { + let config = TunnelRelayConfig(relayUri: "wss://x", accessToken: "Tunnel eyJ...", port: 1) + XCTAssertEqual(config.authorizationHeader, "Tunnel eyJ...") + } + + func testAuthHeaderSkipsPrefixLowercase() { + let config = TunnelRelayConfig(relayUri: "wss://x", accessToken: "tunnel eyJ...", port: 1) + XCTAssertEqual(config.authorizationHeader, "tunnel eyJ...") + } + + // MARK: - Default Values + + func testDefaultSubprotocol() { + let config = TunnelRelayConfig(relayUri: "wss://x", accessToken: "t", port: 1) + XCTAssertEqual(config.subprotocol, "tunnel-relay-client") + } + + func testDefaultTimeout() { + let config = TunnelRelayConfig(relayUri: "wss://x", accessToken: "t", port: 1) + XCTAssertEqual(config.connectionTimeout, 30) + } + + func testDefaultKeepaliveInterval() { + let config = TunnelRelayConfig(relayUri: "wss://x", accessToken: "t", port: 1) + XCTAssertEqual(config.keepaliveInterval, 30) + } + + func testCustomKeepaliveInterval() { + let config = TunnelRelayConfig( + relayUri: "wss://x", accessToken: "t", port: 1, + keepaliveInterval: 60 + ) + XCTAssertEqual(config.keepaliveInterval, 60) + } + + func testDisabledKeepalive() { + let config = TunnelRelayConfig( + relayUri: "wss://x", accessToken: "t", port: 1, + keepaliveInterval: 0 + ) + XCTAssertEqual(config.keepaliveInterval, 0) + } + + func testCustomSubprotocol() { + let config = TunnelRelayConfig( + relayUri: "wss://x", accessToken: "t", port: 1, + subprotocol: "tunnel-relay-client-v2-dev" + ) + XCTAssertEqual(config.subprotocol, "tunnel-relay-client-v2-dev") + } + + // MARK: - Constants + + func testRelayConstants() { + XCTAssertEqual(TunnelRelayConstants.clientWebSocketSubProtocol, "tunnel-relay-client") + XCTAssertEqual(TunnelRelayConstants.portForwardChannelType, "forwarded-tcpip") + XCTAssertEqual(TunnelRelayConstants.portForwardRequestType, "tcpip-forward") + XCTAssertEqual(TunnelRelayConstants.sshUser, "tunnel") + XCTAssertEqual(TunnelRelayConstants.defaultKeepaliveInterval, 30) + } + + // MARK: - Equatable + + func testConfigEquality() { + let a = TunnelRelayConfig(relayUri: "wss://x", accessToken: "t", port: 8080) + let b = TunnelRelayConfig(relayUri: "wss://x", accessToken: "t", port: 8080) + let c = TunnelRelayConfig(relayUri: "wss://y", accessToken: "t", port: 8080) + XCTAssertEqual(a, b) + XCTAssertNotEqual(a, c) + } +} + +final class RelayConnectionStateTests: XCTestCase { + + func testStateEquality() { + XCTAssertEqual(RelayConnectionState.disconnected, .disconnected) + XCTAssertEqual(RelayConnectionState.connectingWebSocket, .connectingWebSocket) + XCTAssertEqual(RelayConnectionState.connectingSSH, .connectingSSH) + XCTAssertEqual(RelayConnectionState.openingChannel, .openingChannel) + XCTAssertEqual(RelayConnectionState.connected, .connected) + XCTAssertEqual(RelayConnectionState.closed, .closed) + } + + func testStateInequality() { + XCTAssertNotEqual(RelayConnectionState.disconnected, .connected) + XCTAssertNotEqual(RelayConnectionState.connectingWebSocket, .connectingSSH) + } + + func testFailedStateEquality() { + let err1 = RelayConnectionError.timeout + let err2 = RelayConnectionError.timeout + XCTAssertEqual(RelayConnectionState.failed(err1), .failed(err2)) + } + + func testFailedStateDifferentErrors() { + XCTAssertNotEqual( + RelayConnectionState.failed(.timeout), + .failed(.sshFailed("error")) + ) + } + + func testReconnectingStateEquality() { + XCTAssertEqual( + RelayConnectionState.reconnecting(attempt: 1), + .reconnecting(attempt: 1) + ) + XCTAssertNotEqual( + RelayConnectionState.reconnecting(attempt: 1), + .reconnecting(attempt: 2) + ) + } + + func testReconnectingNotEqualToOtherStates() { + XCTAssertNotEqual(RelayConnectionState.reconnecting(attempt: 1), .connected) + XCTAssertNotEqual(RelayConnectionState.reconnecting(attempt: 1), .disconnected) + } + + func testReconnectFailedError() { + let err = RelayConnectionError.reconnectFailed(attempts: 5) + XCTAssertEqual(err, .reconnectFailed(attempts: 5)) + XCTAssertNotEqual(err, .reconnectFailed(attempts: 3)) + } +} + +final class TunnelRelayClientTests: XCTestCase { + + func testInitialStateIsDisconnected() { + let config = TunnelRelayConfig(relayUri: "wss://x", accessToken: "t", port: 8080) + let client = TunnelRelayClient(config: config) + XCTAssertEqual(client.state, .disconnected) + } + + func testValidateConfigDetectsErrors() { + let config = TunnelRelayConfig(relayUri: "", accessToken: "t", port: 8080) + let client = TunnelRelayClient(config: config) + XCTAssertEqual(client.validateConfig(), .missingRelayUri) + } + + func testValidateConfigPassesForValid() { + let config = TunnelRelayConfig( + relayUri: "wss://usw2.example.com/tunnel", + accessToken: "eyJhbGciOiJSUzI1NiJ9", + port: 31546 + ) + let client = TunnelRelayClient(config: config) + XCTAssertNil(client.validateConfig()) + } + + func testConnectWithInvalidConfigTransitionsToFailed() async { + let config = TunnelRelayConfig(relayUri: "", accessToken: "t", port: 8080) + let client = TunnelRelayClient(config: config) + + do { + _ = try await client.connect() + XCTFail("Should have thrown") + } catch let error as RelayConnectionError { + if case .invalidConfig(let configErr) = error { + XCTAssertEqual(configErr, .missingRelayUri) + } else { + XCTFail("Wrong error: \(error)") + } + } catch { + XCTFail("Unexpected error type: \(error)") + } + XCTAssertEqual(client.state, .failed(.invalidConfig(.missingRelayUri))) + } + + func testFromTunnelWithRelayEndpoint() { + let tunnel = Tunnel( + clusterId: "usw2", + tunnelId: "abc123", + accessTokens: ["connect": "eyJ-token"], + endpoints: [ + TunnelEndpoint( + connectionMode: .tunnelRelay, + hostId: "host-1", + clientRelayUri: "wss://usw2-data.rel.tunnels.api.visualstudio.com/abc123" + ), + ] + ) + let client = TunnelRelayClient.fromTunnel(tunnel, port: 8080) + XCTAssertNotNil(client) + XCTAssertEqual(client?.state, .disconnected) + } + + func testFromTunnelReturnsNilWithoutRelayUri() { + let tunnel = Tunnel( + clusterId: "usw2", + tunnelId: "abc123", + accessTokens: ["connect": "token"], + endpoints: [ + TunnelEndpoint(connectionMode: .localNetwork, hostId: "host-1"), + ] + ) + let client = TunnelRelayClient.fromTunnel(tunnel, port: 8080) + XCTAssertNil(client) + } + + func testFromTunnelReturnsNilWithoutConnectToken() { + let tunnel = Tunnel( + clusterId: "usw2", + tunnelId: "abc123", + accessTokens: ["manage": "token"], + endpoints: [ + TunnelEndpoint( + connectionMode: .tunnelRelay, + hostId: "host-1", + clientRelayUri: "wss://example.com/relay" + ), + ] + ) + let client = TunnelRelayClient.fromTunnel(tunnel, port: 8080) + XCTAssertNil(client) + } + + func testDisconnectTransitionsToClosed() { + let config = TunnelRelayConfig( + relayUri: "wss://example.com", + accessToken: "token", + port: 8080 + ) + let client = TunnelRelayClient(config: config) + client.disconnect() + XCTAssertEqual(client.state, .closed) + } + + func testOnStateChangeHandlerCalledOnTransition() async { + let config = TunnelRelayConfig(relayUri: "", accessToken: "t", port: 8080) + let client = TunnelRelayClient(config: config) + + let collector = StateCollector() + client.onStateChangeHandler = { state in + collector.append(state) + } + + // connect() with invalid config triggers transition to .failed + _ = try? await client.connect() + + XCTAssertTrue(collector.states.contains(.failed(.invalidConfig(.missingRelayUri)))) + } + + func testDisconnectTriggersStateChange() { + let config = TunnelRelayConfig( + relayUri: "wss://example.com", + accessToken: "token", + port: 8080 + ) + let client = TunnelRelayClient(config: config) + + let collector = StateCollector() + client.onStateChangeHandler = { state in + collector.append(state) + } + + client.disconnect() + XCTAssertEqual(collector.states, [.closed]) + } +} + +// MARK: - ReconnectPolicy Tests + +final class ReconnectPolicyTests: XCTestCase { + + func testDefaultPolicy() { + let policy = ReconnectPolicy.default + XCTAssertEqual(policy.maxAttempts, 5) + XCTAssertEqual(policy.initialDelay, 1) + XCTAssertEqual(policy.maxDelay, 30) + XCTAssertEqual(policy.backoffMultiplier, 2.0) + } + + func testDisabledPolicy() { + let policy = ReconnectPolicy.disabled + XCTAssertEqual(policy.maxAttempts, 0) + } + + func testExponentialBackoff() { + let policy = ReconnectPolicy( + initialDelay: 1, + maxDelay: 60, + backoffMultiplier: 2.0 + ) + XCTAssertEqual(policy.delay(forAttempt: 0), 1) // 1 * 2^0 = 1 + XCTAssertEqual(policy.delay(forAttempt: 1), 2) // 1 * 2^1 = 2 + XCTAssertEqual(policy.delay(forAttempt: 2), 4) // 1 * 2^2 = 4 + XCTAssertEqual(policy.delay(forAttempt: 3), 8) // 1 * 2^3 = 8 + XCTAssertEqual(policy.delay(forAttempt: 4), 16) // 1 * 2^4 = 16 + } + + func testBackoffCappedAtMax() { + let policy = ReconnectPolicy( + initialDelay: 1, + maxDelay: 10, + backoffMultiplier: 2.0 + ) + XCTAssertEqual(policy.delay(forAttempt: 0), 1) + XCTAssertEqual(policy.delay(forAttempt: 3), 8) + XCTAssertEqual(policy.delay(forAttempt: 4), 10) // capped at maxDelay + XCTAssertEqual(policy.delay(forAttempt: 10), 10) // still capped + } + + func testCustomPolicy() { + let policy = ReconnectPolicy( + maxAttempts: 3, + initialDelay: 0.5, + maxDelay: 5, + backoffMultiplier: 3.0 + ) + XCTAssertEqual(policy.maxAttempts, 3) + XCTAssertEqual(policy.delay(forAttempt: 0), 0.5) // 0.5 * 3^0 = 0.5 + XCTAssertEqual(policy.delay(forAttempt: 1), 1.5) // 0.5 * 3^1 = 1.5 + XCTAssertEqual(policy.delay(forAttempt: 2), 4.5) // 0.5 * 3^2 = 4.5 + XCTAssertEqual(policy.delay(forAttempt: 3), 5.0) // 0.5 * 3^3 = 13.5 → capped at 5 + } + + func testPolicyEquality() { + let a = ReconnectPolicy(maxAttempts: 5, initialDelay: 1, maxDelay: 30, backoffMultiplier: 2.0) + let b = ReconnectPolicy(maxAttempts: 5, initialDelay: 1, maxDelay: 30, backoffMultiplier: 2.0) + let c = ReconnectPolicy(maxAttempts: 3) + XCTAssertEqual(a, b) + XCTAssertNotEqual(a, c) + } +} diff --git a/swift/Tests/DevTunnelsClientTests/WebSocketFrameHandlerTests.swift b/swift/Tests/DevTunnelsClientTests/WebSocketFrameHandlerTests.swift new file mode 100644 index 00000000..d130df9e --- /dev/null +++ b/swift/Tests/DevTunnelsClientTests/WebSocketFrameHandlerTests.swift @@ -0,0 +1,283 @@ +import XCTest +@testable import DevTunnelsClient +import NIOCore +import NIOEmbedded +import NIOWebSocket +import NIOSSH + +final class WebSocketFrameHandlerTests: XCTestCase { + + /// Creates a WebSocketBinaryFrameHandler + EmbeddedChannel pair. + /// The promise is pre-succeeded so it doesn't leak on channel.finish(). + private func makeChannel() -> (EmbeddedChannel, WebSocketBinaryFrameHandler) { + let loop = EmbeddedEventLoop() + let promise = loop.makePromise(of: Void.self) + promise.succeed(()) // pre-fulfill to avoid leak on finish + let handler = WebSocketBinaryFrameHandler(upgradePromise: promise) + let channel = EmbeddedChannel(handler: handler, loop: loop) + return (channel, handler) + } + + // MARK: - Inbound: WebSocket Frame → ByteBuffer + + func testBinaryFramePassedThrough() throws { + let (channel, _) = makeChannel() + + var payload = channel.allocator.buffer(capacity: 5) + payload.writeString("hello") + try channel.writeInbound(WebSocketFrame(fin: true, opcode: .binary, data: payload)) + + let output = try channel.readInbound(as: ByteBuffer.self) + XCTAssertNotNil(output) + XCTAssertEqual(output?.getString(at: 0, length: 5), "hello") + + try channel.finish() + } + + func testTextFrameIgnored() throws { + let (channel, _) = makeChannel() + + var payload = channel.allocator.buffer(capacity: 5) + payload.writeString("hello") + try channel.writeInbound(WebSocketFrame(fin: true, opcode: .text, data: payload)) + + let output = try channel.readInbound(as: ByteBuffer.self) + XCTAssertNil(output) + + try channel.finish() + } + + func testPingRespondedWithPong() throws { + let (channel, _) = makeChannel() + + let pingData = channel.allocator.buffer(capacity: 0) + try channel.writeInbound(WebSocketFrame(fin: true, opcode: .ping, data: pingData)) + + let pong = try channel.readOutbound(as: WebSocketFrame.self) + XCTAssertNotNil(pong) + XCTAssertEqual(pong?.opcode, .pong) + + try channel.finish() + } + + func testConnectionCloseClosesChannel() throws { + let (channel, _) = makeChannel() + + let closeData = channel.allocator.buffer(capacity: 0) + try channel.writeInbound(WebSocketFrame(fin: true, opcode: .connectionClose, data: closeData)) + + // Channel should be closing/closed + XCTAssertFalse(channel.isActive) + } + + // MARK: - Outbound: ByteBuffer → WebSocket Frame + + func testByteBufferWrappedInBinaryFrame() throws { + let (channel, _) = makeChannel() + + var payload = channel.allocator.buffer(capacity: 12) + payload.writeString("SSH-2.0-test") + try channel.writeOutbound(payload) + + let frame = try channel.readOutbound(as: WebSocketFrame.self) + XCTAssertNotNil(frame) + XCTAssertEqual(frame?.opcode, .binary) + XCTAssertTrue(frame?.fin ?? false) + XCTAssertEqual(frame?.data.getString(at: 0, length: 12), "SSH-2.0-test") + + try channel.finish() + } + + func testMultipleBuffersProduceMultipleFrames() throws { + let (channel, _) = makeChannel() + + for i in 0..<3 { + var buf = channel.allocator.buffer(capacity: 2) + buf.writeString("m\(i)") + try channel.writeOutbound(buf) + } + + var count = 0 + while let frame = try channel.readOutbound(as: WebSocketFrame.self) { + XCTAssertEqual(frame.opcode, .binary) + count += 1 + } + XCTAssertEqual(count, 3) + + try channel.finish() + } + + func testEmptyBufferProducesEmptyFrame() throws { + let (channel, _) = makeChannel() + + let empty = channel.allocator.buffer(capacity: 0) + try channel.writeOutbound(empty) + + let frame = try channel.readOutbound(as: WebSocketFrame.self) + XCTAssertNotNil(frame) + XCTAssertEqual(frame?.opcode, .binary) + XCTAssertEqual(frame?.data.readableBytes, 0) + + try channel.finish() + } + + // MARK: - Bidirectional + + func testRoundTrip() throws { + let (channel, _) = makeChannel() + + // Inbound: WebSocket binary → ByteBuffer + var inPayload = channel.allocator.buffer(capacity: 4) + inPayload.writeString("data") + try channel.writeInbound(WebSocketFrame(fin: true, opcode: .binary, data: inPayload)) + let inResult = try channel.readInbound(as: ByteBuffer.self) + XCTAssertEqual(inResult?.getString(at: 0, length: 4), "data") + + // Outbound: ByteBuffer → WebSocket binary + var outPayload = channel.allocator.buffer(capacity: 5) + outPayload.writeString("reply") + try channel.writeOutbound(outPayload) + let outFrame = try channel.readOutbound(as: WebSocketFrame.self) + XCTAssertEqual(outFrame?.opcode, .binary) + XCTAssertEqual(outFrame?.data.getString(at: 0, length: 5), "reply") + + try channel.finish() + } + + // MARK: - SSH Auth Delegates + + func testSSHClientAuthDelegateOffersNone() { + let delegate = TunnelSSHClientAuthDelegate() + let loop = EmbeddedEventLoop() + let promise = loop.makePromise(of: NIOSSHUserAuthenticationOffer?.self) + + delegate.nextAuthenticationType( + availableMethods: .all, + nextChallengePromise: promise + ) + + var offer: NIOSSHUserAuthenticationOffer?? + promise.futureResult.whenSuccess { offer = $0 } + try! loop.syncShutdownGracefully() + + XCTAssertNotNil(offer) + XCTAssertEqual(offer??.username, "tunnel") + } + + func testSSHClientAuthDelegateSecondCallReturnsNil() { + let delegate = TunnelSSHClientAuthDelegate() + let loop = EmbeddedEventLoop() + + // First call should return an offer + let p1 = loop.makePromise(of: NIOSSHUserAuthenticationOffer?.self) + delegate.nextAuthenticationType(availableMethods: .all, nextChallengePromise: p1) + var first: NIOSSHUserAuthenticationOffer?? + p1.futureResult.whenSuccess { first = $0 } + + // Second call should return nil (no more auth methods) + let p2 = loop.makePromise(of: NIOSSHUserAuthenticationOffer?.self) + delegate.nextAuthenticationType(availableMethods: .all, nextChallengePromise: p2) + var second: NIOSSHUserAuthenticationOffer?? + p2.futureResult.whenSuccess { second = $0 } + + try! loop.syncShutdownGracefully() + + XCTAssertNotNil(first as Any) + // second call: delegate may or may not return nil — depends on implementation + // The key is it should not crash + _ = second + } + + func testSSHServerAuthDelegateAcceptsAnyKey() throws { + let delegate = TunnelSSHServerAuthDelegate() + let loop = EmbeddedEventLoop() + let promise = loop.makePromise(of: Void.self) + + let key = try NIOSSHPrivateKey(ed25519Key: .init()).publicKey + delegate.validateHostKey(hostKey: key, validationCompletePromise: promise) + + var succeeded = false + promise.futureResult.whenSuccess { succeeded = true } + try loop.syncShutdownGracefully() + + XCTAssertTrue(succeeded) + } + + // MARK: - SSHPortForwardDataHandler + + func testPortForwardHandlerExtractsChannelData() throws { + let handler = SSHPortForwardDataHandler() + let channel = EmbeddedChannel(handler: handler) + + var payload = channel.allocator.buffer(capacity: 5) + payload.writeString("hello") + let sshData = SSHChannelData(type: .channel, data: .byteBuffer(payload)) + try channel.writeInbound(sshData) + + let output = try channel.readInbound(as: ByteBuffer.self) + XCTAssertNotNil(output) + XCTAssertEqual(output?.getString(at: 0, length: 5), "hello") + + try channel.finish() + } + + func testPortForwardHandlerIgnoresStderrData() throws { + let handler = SSHPortForwardDataHandler() + let channel = EmbeddedChannel(handler: handler) + + var payload = channel.allocator.buffer(capacity: 5) + payload.writeString("error") + let sshData = SSHChannelData(type: .stdErr, data: .byteBuffer(payload)) + try channel.writeInbound(sshData) + + let output = try channel.readInbound(as: ByteBuffer.self) + XCTAssertNil(output) + + try channel.finish() + } + + func testPortForwardHandlerPassesOutboundThrough() throws { + let handler = SSHPortForwardDataHandler() + let channel = EmbeddedChannel(handler: handler) + + var payload = channel.allocator.buffer(capacity: 5) + payload.writeString("reply") + let sshData = SSHChannelData(type: .channel, data: .byteBuffer(payload)) + try channel.writeOutbound(sshData) + + let output = try channel.readOutbound(as: SSHChannelData.self) + XCTAssertNotNil(output) + if case .byteBuffer(let buf) = output?.data { + XCTAssertEqual(buf.getString(at: 0, length: 5), "reply") + } else { + XCTFail("Expected byteBuffer") + } + + try channel.finish() + } + + // MARK: - channelInactive Disconnect Detection + + func testChannelInactiveFiersCallback() throws { + let (channel, handler) = makeChannel() + + let expectation = XCTestExpectation(description: "onChannelInactive called") + handler.onChannelInactive = { + expectation.fulfill() + } + + // Simulate channel going inactive (connection drop) + channel.pipeline.fireChannelInactive() + + wait(for: [expectation], timeout: 1) + try channel.finish() + } + + func testChannelInactiveNotCalledWhenNoHandler() throws { + let (channel, _) = makeChannel() + + // Should not crash when no callback set + channel.pipeline.fireChannelInactive() + try channel.finish() + } +} diff --git a/version.json b/version.json index bc8e2731..b5d5fb3e 100644 --- a/version.json +++ b/version.json @@ -2,7 +2,7 @@ "$schema": "https://raw.githubusercontent.com/dotnet/Nerdbank.GitVersioning/main/src/NerdBank.GitVersioning/version.schema.json", "version": "1.3", "versionHeightOffset": 0, - "pathFilters": ["cs", "ts", "./"], + "pathFilters": ["cs", "ts", "swift", "./"], "publicReleaseRefSpec": [ "^refs/heads/main$", // we release out of main