diff --git a/README.md b/README.md index 65925810..284ec0a9 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,10 @@ Dev tunnels allows developers to securely expose local web services to the Inter | Management API | ✅ | ✅ | ✅ | ✅ | ✅ | | Tunnel Client Connections | ✅ | ✅ | ✅ | ✅ | ✅ | | Tunnel Host Connections | ✅ | ✅ | ❌ | ❌ | ✅ | -| Reconnection | ✅ | ✅ | ❌ | ❌ | ❌ | -| SSH-level Reconnection | ✅ | ✅ | ❌ | ❌ | ❌ | -| Automatic tunnel access token refresh | ✅ | ✅ | ❌ | ❌ | ❌ | -| Ssh Keep-alive | ✅ | ✅ | ❌ | ❌ | ❌ | +| Reconnection | ✅ | ✅ | ❌ | ❌ | ✅ | +| SSH-level Reconnection | ✅ | ✅ | ❌ | ❌ | ✅ | +| Automatic tunnel access token refresh | ✅ | ✅ | ❌ | ❌ | ✅ | +| SSH keep-alive | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ - Supported 🚧 - In Progress diff --git a/rs/src/connections/errors.rs b/rs/src/connections/errors.rs index ec1c83cf..1a7cd62e 100644 --- a/rs/src/connections/errors.rs +++ b/rs/src/connections/errors.rs @@ -27,6 +27,11 @@ pub enum TunnelError { #[error("port {0} already exists in the relay")] PortAlreadyExists(u32), + #[error("max reconnect attempts ({0}) exceeded")] + MaxReconnectAttemptsExceeded(u32), + #[error("tunnel access token refresh failed")] + TokenRefreshFailed, + #[error("proxy connection failed: {0}")] ProxyConnectionFailed(std::io::Error), diff --git a/rs/src/connections/relay_tunnel_host.rs b/rs/src/connections/relay_tunnel_host.rs index b800164a..67c56a65 100644 --- a/rs/src/connections/relay_tunnel_host.rs +++ b/rs/src/connections/relay_tunnel_host.rs @@ -19,7 +19,7 @@ use crate::{ }, }; use async_trait::async_trait; -use futures::{stream::FuturesUnordered, StreamExt, TryFutureExt}; +use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt, TryFutureExt}; use russh::{server::Server as ServerTrait, CryptoVec}; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, @@ -39,6 +39,93 @@ use super::{ /// sent. Shared by the host relay to each connected session. 
type PortMap = HashMap>; +// @group Reconnection : Types for automatic reconnection with exponential backoff + +/// The connection state of a persistent relay host. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RelayConnectionState { + /// Actively connected to the relay. + Connected, + /// Connection was lost; waiting before the next reconnect attempt. + Reconnecting { + /// 1-based attempt counter. + attempt: u32, + /// Milliseconds until the next connection attempt. + delay_ms: u64, + }, + /// Permanently disconnected (clean shutdown or max retries exceeded). + Disconnected, +} + +/// Observable state of the SSH keep-alive probing for a [`PersistentRelayHandle`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum KeepAliveState { + /// Keep-alive is not configured (default). + NotConfigured, + /// The most recent keep-alive probe succeeded. + Succeeded { + /// Number of successful probes so far. + count: u32, + }, + /// The most recent keep-alive probe failed or timed out. + Failed { + /// Number of failed probes so far. + count: u32, + }, +} + +/// Controls the back-off behaviour of [`RelayTunnelHost::connect_persistent`]. +pub struct ReconnectOptions { + /// Maximum number of reconnect attempts before giving up. + /// `None` (default) retries indefinitely. + pub max_attempts: Option<u32>, + /// Delay before the first retry, in milliseconds. Default: 1 000 ms. + pub initial_delay_ms: u64, + /// Upper bound on retry delay, in milliseconds. Default: 13 000 ms. + pub max_delay_ms: u64, + /// Interval between SSH keep-alive probes. `None` (default) disables keep-alive + /// (WebSocket-level pings still run regardless). + pub keep_alive_interval: Option<Duration>, + /// Async callback invoked when the access token is rejected (HTTP 401). + /// Should return a fresh token, or `None` if a new token cannot be obtained. + /// When `None` (default), unauthorized errors follow normal back-off. 
+ pub token_refresher: Option<Arc<dyn Fn() -> BoxFuture<'static, Option<String>> + Send + Sync>>, +} + +impl Default for ReconnectOptions { + fn default() -> Self { + Self { + max_attempts: None, + initial_delay_ms: 1_000, + max_delay_ms: 13_000, + keep_alive_interval: None, + token_refresher: None, + } + } +} + +/// Handle returned by [`RelayTunnelHost::connect_persistent`]. +/// +/// Drop this value (or call [`PersistentRelayHandle::stop`]) to request a +/// clean shutdown of the reconnect loop. +pub struct PersistentRelayHandle { + /// Observe connection-state changes as they happen. + pub state: watch::Receiver<RelayConnectionState>, + /// Observe keep-alive probe state changes as they happen. + pub keep_alive: watch::Receiver<KeepAliveState>, + /// Dropping this sender signals the reconnect loop to exit. + _stop_tx: mpsc::Sender<()>, + join: JoinHandle<Result<(), TunnelError>>, +} + +impl PersistentRelayHandle { + /// Signals the reconnect loop to stop and waits for a clean exit. + pub async fn stop(self) -> Result<(), TunnelError> { + drop(self._stop_tx); + self.join.await.unwrap_or(Ok(())) + } +} + /// The RelayTunnelHost can host connections via the tunneling service. After /// creating it, you will generally want to run `connect()` to create a new /// a new connection. @@ -172,65 +259,201 @@ impl RelayTunnelHost { /// reconnect if this happens, and they can reconnect using the same /// RelayTunnelHost. 
pub async fn connect(&mut self, host_token: &str) -> Result<RelayHandle, TunnelError> { - let (cnx, endpoint) = self.create_websocket(host_token).await?; - let cnx = AsyncRWWebSocket::new(super::ws::AsyncRWWebSocketOptions { - websocket: cnx, - ping_interval: Duration::from_secs(60), - ping_timeout: Duration::from_secs(10), - }); + relay_connect_once( + &self.mgmt, + &self.locator, + self.host_id, + &self.proxy, + self.host_keypair.clone(), + self.ports_rx.clone(), + host_token, + None, // keep-alive not configured for single connect + ) + .await + } - let (client_session, mut rx) = RelayTunnelHost::make_ssh_client(cnx) - .await - .map_err(TunnelError::TunnelRelayDisconnected)?; - let client_session = Arc::new(client_session); - let client_session_ret = client_session.clone(); + /// Connects to the relay and automatically reconnects on disconnection. + /// + /// Unlike [`connect`], this method retries indefinitely (or up to + /// `options.max_attempts` times) with exponential back-off. + /// + /// The first connection attempt is made eagerly so callers surface + /// configuration errors immediately. Drop the returned + /// [`PersistentRelayHandle`] (or call [`PersistentRelayHandle::stop`]) to + /// request a clean shutdown. + // @group Reconnection : Persistent connection with automatic exponential backoff + pub async fn connect_persistent( + &mut self, + host_token: String, + options: ReconnectOptions, + ) -> Result<PersistentRelayHandle, TunnelError> { + // Fail-fast: establish the first connection eagerly. 
+ let (ka_tx, ka_rx) = watch::channel(KeepAliveState::NotConfigured); + let ka_tx_arc = Arc::new(ka_tx); + + let initial_handle = relay_connect_once( + &self.mgmt, + &self.locator, + self.host_id, + &self.proxy, + self.host_keypair.clone(), + self.ports_rx.clone(), + &host_token, + options.keep_alive_interval.map(|d| (d, ka_tx_arc.clone())), + ) + .await?; - log::debug!("established host relay primary session"); + let (state_tx, state_rx) = watch::channel(RelayConnectionState::Connected); + let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1); - let mut channels = HashMap::new(); - let ports_rx = self.ports_rx.clone(); + let mgmt = self.mgmt.clone(); + let locator = self.locator.clone(); + let host_id = self.host_id; + let proxy = self.proxy.clone(); let host_keypair = self.host_keypair.clone(); + let ports_rx = self.ports_rx.clone(); + let join = tokio::spawn(async move { - let mut server = RelayTunnelHost::make_ssh_server(host_keypair.clone()); - loop { + let mut current_join = initial_handle.join; + let mut delay_ms = options.initial_delay_ms; + // @group Reconnection > Token Refresh : Track single-attempt token refresh per session + let mut current_host_token = host_token; + + 'reconnect: loop { + // Wait for the current connection to finish or a stop signal. tokio::select! { - Some(op) = rx.recv() => match op { - ChannelOp::Open(id) => { - let (rw, sender) = AsyncRWChannel::new(id, client_session.clone()); - server.run_stream(rw, ports_rx.clone()); - // do we need to store the JoinHandle for any reason? 
- channels.insert(id, sender); - log::info!("Opened new client on channel {}", id); - }, - ChannelOp::Close(id) => { - channels.remove(&id); - }, - ChannelOp::Data(id, data) => { - if let Some(ch) = channels.get(&id) { - if ch.send(data).is_err() { // rx was dropped - channels.remove(&id); + r = &mut current_join => { + match r { + Ok(Ok(())) => log::debug!("relay connection ended cleanly"), + Ok(Err(e)) => log::warn!("relay connection ended with error: {}", e), + Err(_) => log::warn!("relay task panicked"), + } + } + _ = stop_rx.recv() => { + current_join.abort(); + let _ = current_join.await; + break 'reconnect; + } + } + + // Reconnect inner loop: retry with exponential back-off. + let mut attempt: u32 = 0; + // @group Reconnection > SSH-level Reconnection : Skip delay after SSH protocol failures + let mut skip_delay = false; + // @group Reconnection > Token Refresh : Single refresh per reconnect session + let mut token_refreshed = false; + loop { + attempt += 1; + if let Some(max) = options.max_attempts { + if attempt > max { + let _ = state_tx.send(RelayConnectionState::Disconnected); + return Err(TunnelError::MaxReconnectAttemptsExceeded(max)); + } + } + + let effective_delay = if skip_delay { 0 } else { delay_ms }; + skip_delay = false; + let _ = state_tx.send(RelayConnectionState::Reconnecting { + attempt, + delay_ms: effective_delay, + }); + + if effective_delay > 0 { + log::info!( + "waiting {}ms before reconnect attempt {}", + effective_delay, attempt + ); + tokio::select! 
{ + _ = tokio::time::sleep(Duration::from_millis(effective_delay)) => {} + _ = stop_rx.recv() => { break 'reconnect; } + } + } + + delay_ms = (delay_ms * 2).min(options.max_delay_ms); + + match relay_connect_once( + &mgmt, + &locator, + host_id, + &proxy, + host_keypair.clone(), + ports_rx.clone(), + &current_host_token, + options.keep_alive_interval.map(|d| (d, ka_tx_arc.clone())), + ) + .await + { + Ok(handle) => { + log::info!("reconnected to relay on attempt {}", attempt); + let _ = state_tx.send(RelayConnectionState::Connected); + current_join = handle.join; + delay_ms = options.initial_delay_ms; + break; // exit inner loop, wait for new connection + } + // @group Reconnection > SSH-level Reconnection : SSH error, retry once immediately + Err(TunnelError::TunnelRelayDisconnected(_)) => { + log::warn!( + "SSH-level failure on attempt {}, retrying immediately", + attempt + ); + delay_ms = options.initial_delay_ms; + skip_delay = true; + } + // @group Reconnection > Token Refresh : HTTP 401, call token_refresher + Err(TunnelError::HttpError { + error: HttpError::ResponseError(ref resp_err), + .. 
+ }) if resp_err.status_code == reqwest::StatusCode::UNAUTHORIZED => { + if let Some(refresher) = &options.token_refresher { + if !token_refreshed { + log::info!( + "access token rejected (HTTP 401), refreshing" + ); + match refresher().await { + Some(new_token) => { + current_host_token = new_token; + token_refreshed = true; + skip_delay = true; + } + None => { + log::warn!("token refresher returned None"); + let _ = state_tx.send( + RelayConnectionState::Disconnected, + ); + return Err(TunnelError::TokenRefreshFailed); + } + } + } else { + log::warn!("still unauthorized after token refresh"); + let _ = state_tx.send( + RelayConnectionState::Disconnected, + ); + return Err(TunnelError::TokenRefreshFailed); } + } else { + log::warn!( + "reconnect attempt {} failed: unauthorized (no token refresher)", + attempt + ); } - }, - }, - else => break, + } + Err(e) => { + log::warn!("reconnect attempt {} failed: {}", attempt, e); + // loop continues with next attempt + } + } } } - client_session - .disconnect(russh::Disconnect::ByApplication, "going away", "en") - .await - .ok(); - - log::debug!("disconnected primary session after EOF"); - + let _ = state_tx.send(RelayConnectionState::Disconnected); Ok(()) }); - Ok(RelayHandle { - endpoint, + Ok(PersistentRelayHandle { + state: state_rx, + keep_alive: ka_rx, + _stop_tx: stop_tx, join, - session: client_session_ret, }) } @@ -372,70 +595,164 @@ impl RelayTunnelHost { Ok((session, rx)) } +} - async fn create_websocket( - &self, - host_token: &str, - ) -> Result< - ( - WebSocketStream>, - TunnelRelayTunnelEndpoint, - ), - TunnelError, - > { - let endpoint = self - .mgmt - .update_tunnel_relay_endpoints( - &self.locator, - &TunnelRelayTunnelEndpoint { - base: TunnelEndpoint { - id: Some(format!("{}-relay", self.host_id)), - connection_mode: TunnelConnectionMode::TunnelRelay, - host_id: self.host_id.to_string(), - host_public_keys: vec![], - port_uri_format: None, - port_ssh_command_format: None, - ssh_gateway_public_key: None, - 
tunnel_ssh_command: None, - tunnel_uri: None, - }, - client_relay_uri: None, - host_relay_uri: None, +// @group Reconnection : Free helper functions backing connect() and connect_persistent() + +async fn create_relay_websocket( + mgmt: &TunnelManagementClient, + locator: &TunnelLocator, + host_id: Uuid, + proxy: &Option, + host_token: &str, +) -> Result< + ( + WebSocketStream>, + TunnelRelayTunnelEndpoint, + ), + TunnelError, +> { + let endpoint = mgmt + .update_tunnel_relay_endpoints( + locator, + &TunnelRelayTunnelEndpoint { + base: TunnelEndpoint { + id: Some(format!("{}-relay", host_id)), + connection_mode: TunnelConnectionMode::TunnelRelay, + host_id: host_id.to_string(), + host_public_keys: vec![], + port_uri_format: None, + port_ssh_command_format: None, + ssh_gateway_public_key: None, + tunnel_ssh_command: None, + tunnel_uri: None, }, - &TunnelRequestOptions { - authorization: Some(Authorization::Tunnel(host_token.to_string())), - ..TunnelRequestOptions::default() + client_relay_uri: None, + host_relay_uri: None, + }, + &TunnelRequestOptions { + authorization: Some(Authorization::Tunnel(host_token.to_string())), + ..TunnelRequestOptions::default() + }, + ) + .await + .map_err(|e| TunnelError::HttpError { + error: e, + reason: "failed to update tunnel endpoint for hosting", + })?; + + let url = endpoint + .host_relay_uri + .as_deref() + .ok_or(TunnelError::MissingHostEndpoint)?; + + let req = build_websocket_request( + url, + &[ + ("Sec-WebSocket-Protocol", "tunnel-relay-host"), + ("Authorization", &format!("tunnel {}", host_token)), + ("User-Agent", mgmt.user_agent.to_str().unwrap()), + ], + )?; + + let cnx = if let Some(proxy) = proxy { + log::debug!("connecting via http_proxy on {}", proxy); + connect_via_proxy(req, proxy).await? + } else { + connect_directly(req).await? 
+ }; + + Ok((cnx, endpoint)) +} + +async fn relay_connect_once( + mgmt: &TunnelManagementClient, + locator: &TunnelLocator, + host_id: Uuid, + proxy: &Option, + host_keypair: russh_keys::key::KeyPair, + ports_rx: watch::Receiver<PortMap>, + host_token: &str, + keep_alive: Option<(Duration, Arc<watch::Sender<KeepAliveState>>)>, +) -> Result<RelayHandle, TunnelError> { + let (cnx, endpoint) = + create_relay_websocket(mgmt, locator, host_id, proxy, host_token).await?; + let cnx = AsyncRWWebSocket::new(super::ws::AsyncRWWebSocketOptions { + websocket: cnx, + ping_interval: Duration::from_secs(60), + ping_timeout: Duration::from_secs(10), + }); + + let (client_session, mut rx) = RelayTunnelHost::make_ssh_client(cnx) + .await + .map_err(TunnelError::TunnelRelayDisconnected)?; + let client_session = Arc::new(client_session); + let client_session_ret = client_session.clone(); + + // @group SSH Keep-alive : Periodic liveness probe via is_closed() + if let Some((interval, ka_tx)) = keep_alive { + let ka_tx = ka_tx.clone(); + let session_check = client_session_ret.clone(); + tokio::spawn(async move { + let mut count: u32 = 0; + loop { + tokio::time::sleep(interval).await; + count = count.saturating_add(1); + if session_check.is_closed() { + let _ = ka_tx.send(KeepAliveState::Failed { count }); + break; + } else { + let _ = ka_tx.send(KeepAliveState::Succeeded { count }); + } + } + }); + } + + + log::debug!("established host relay primary session"); + + let mut channels = HashMap::new(); + let join = tokio::spawn(async move { + let mut server = RelayTunnelHost::make_ssh_server(host_keypair.clone()); + loop { + tokio::select! 
{ + Some(op) = rx.recv() => match op { + ChannelOp::Open(id) => { + let (rw, sender) = AsyncRWChannel::new(id, client_session.clone()); + server.run_stream(rw, ports_rx.clone()); + channels.insert(id, sender); + log::info!("Opened new client on channel {}", id); + }, + ChannelOp::Close(id) => { + channels.remove(&id); + }, + ChannelOp::Data(id, data) => { + if let Some(ch) = channels.get(&id) { + if ch.send(data).is_err() { + channels.remove(&id); + } + } + }, }, - ) + else => break, + } + } + + client_session + .disconnect(russh::Disconnect::ByApplication, "going away", "en") .await - .map_err(|e| TunnelError::HttpError { - error: e, - reason: "failed to update tunnel endpoint for hosting", - })?; + .ok(); - let url = endpoint - .host_relay_uri - .as_deref() - .ok_or(TunnelError::MissingHostEndpoint)?; - - let req = build_websocket_request( - url, - &[ - ("Sec-WebSocket-Protocol", "tunnel-relay-host"), - ("Authorization", &format!("tunnel {}", host_token)), - ("User-Agent", self.mgmt.user_agent.to_str().unwrap()), - ], - )?; - - let cnx = if let Some(proxy) = &self.proxy { - log::debug!("connecting via http_proxy on {}", proxy); - connect_via_proxy(req, proxy).await? - } else { - connect_directly(req).await? - }; + log::debug!("disconnected primary session after EOF"); - Ok((cnx, endpoint)) - } + Ok(()) + }); + + Ok(RelayHandle { + endpoint, + join, + session: client_session_ret, + }) } /// Type returned in a channel from `add_forwarded_port_raw`, implementing diff --git a/rs/test/tunnels_test.rs b/rs/test/tunnels_test.rs index 20919425..e3b1c220 100644 --- a/rs/test/tunnels_test.rs +++ b/rs/test/tunnels_test.rs @@ -1,5 +1,112 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +// @group TestSetup : Feature-gated imports for connection tests +#[cfg(feature = "connections")] +use tunnels::connections::{KeepAliveState, ReconnectOptions}; + +// @group UnitTests > Pure Logic : Exponential backoff cap test (no crate types needed) #[test] -fn it_works() { - let result = 2 + 2; - assert_eq!(result, 4); +fn test_exponential_backoff_cap() { + let initial = 1_000u64; + let max = 13_000u64; + let mut delay = initial; + let steps: Vec = { + let mut v = vec![delay]; + for _ in 0..10 { + delay = (delay * 2).min(max); + v.push(delay); + } + v + }; + assert_eq!(steps[0], 1_000); + assert_eq!(steps[1], 2_000); + assert_eq!(steps[2], 4_000); + assert_eq!(steps[3], 8_000); + assert_eq!(steps[4], 13_000); // capped + assert_eq!(steps[5], 13_000); // stays at cap } + +// @group UnitTests > Pure Logic : SSH error sets skip_delay and resets backoff +#[test] +fn test_skip_delay_resets_delay_on_ssh_error() { + let initial_delay = 1_000u64; + let mut delay = 8_000u64; // simulate ramped-up delay + let mut skip_delay = false; + + // Simulate SSH error path sets skip_delay and resets delay + delay = initial_delay; + skip_delay = true; + + let effective_delay = if skip_delay { 0 } else { delay }; + assert_eq!(effective_delay, 0, "SSH error should skip the wait"); + assert_eq!(delay, initial_delay, "delay should reset to initial after SSH error"); +} + +// @group UnitTests > ReconnectOptions : Default field values +#[cfg(feature = "connections")] +#[test] +fn test_reconnect_options_defaults() { + let opts = ReconnectOptions::default(); + assert_eq!(opts.initial_delay_ms, 1_000); + assert_eq!(opts.max_delay_ms, 13_000); + assert!(opts.max_attempts.is_none()); + assert!(opts.keep_alive_interval.is_none(), "keep_alive_interval should be None by default"); + assert!(opts.token_refresher.is_none(), "token_refresher should be None by default"); +} + +// @group UnitTests > ReconnectOptions : max_attempts=0 is preserved in configuration +#[cfg(feature = "connections")] 
+#[test] +fn test_max_attempts_zero_is_preserved_in_options() { + let opts = ReconnectOptions { + max_attempts: Some(0), + ..Default::default() + }; + + assert_eq!(opts.max_attempts, Some(0)); + assert_eq!(opts.initial_delay_ms, 1_000); + assert_eq!(opts.max_delay_ms, 13_000); + assert!(opts.keep_alive_interval.is_none(), "keep_alive_interval should remain None by default"); + assert!(opts.token_refresher.is_none(), "token_refresher should remain None by default"); +} + +// @group UnitTests > KeepAliveState : All variants construct and compare correctly +#[cfg(feature = "connections")] +#[test] +fn test_keep_alive_state_variants() { + assert_eq!(KeepAliveState::NotConfigured, KeepAliveState::NotConfigured); + assert_eq!( + KeepAliveState::Succeeded { count: 42 }, + KeepAliveState::Succeeded { count: 42 } + ); + assert_eq!( + KeepAliveState::Failed { count: 7 }, + KeepAliveState::Failed { count: 7 } + ); + assert_ne!( + KeepAliveState::Succeeded { count: 1 }, + KeepAliveState::Failed { count: 1 } + ); +} + +// @group UnitTests > KeepAliveState : Clone preserves value +#[cfg(feature = "connections")] +#[test] +fn test_keep_alive_state_clone() { + let original = KeepAliveState::Succeeded { count: 3 }; + let cloned = original.clone(); + assert_eq!(original, cloned); +} + +// @group UnitTests > KeepAliveState : watch channel starts NotConfigured and updates +#[cfg(feature = "connections")] +#[tokio::test] +async fn test_keep_alive_state_watch_channel_updates() { + let (tx, rx) = tokio::sync::watch::channel(KeepAliveState::NotConfigured); + assert_eq!(*rx.borrow(), KeepAliveState::NotConfigured); + tx.send(KeepAliveState::Succeeded { count: 1 }).unwrap(); + assert_eq!(*rx.borrow(), KeepAliveState::Succeeded { count: 1 }); + tx.send(KeepAliveState::Failed { count: 2 }).unwrap(); + assert_eq!(*rx.borrow(), KeepAliveState::Failed { count: 2 }); +} \ No newline at end of file