diff --git a/rs/src/connections/errors.rs b/rs/src/connections/errors.rs index 65047a33..ec1c83cf 100644 --- a/rs/src/connections/errors.rs +++ b/rs/src/connections/errors.rs @@ -30,6 +30,9 @@ pub enum TunnelError { #[error("proxy connection failed: {0}")] ProxyConnectionFailed(std::io::Error), + #[error("proxy address invalid: {0}")] + ProxyAddressInvalid(url::ParseError), + #[error("proxy handshake failed: {0}")] ProxyHandshakeFailed(hyper::Error), diff --git a/rs/src/connections/ws.rs b/rs/src/connections/ws.rs index 12460197..d94c1e8b 100644 --- a/rs/src/connections/ws.rs +++ b/rs/src/connections/ws.rs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -use std::{io, pin::Pin, task::Poll, time::Duration}; +use std::{io, net::SocketAddr, pin::Pin, task::Poll, time::Duration}; use futures::{Future, Sink, Stream}; use tokio::{ @@ -237,9 +237,19 @@ pub(crate) async fn connect_via_proxy( format!("{}:{}", hostname, port) }; - let stream = TcpStream::connect(proxy_addr) - .await - .map_err(TunnelError::ProxyConnectionFailed)?; + let stream = match proxy_addr.parse::() { + Ok(addr) => TcpStream::connect(addr).await, + Err(_) => { + let as_uri = url::Url::parse(proxy_addr).map_err(TunnelError::ProxyAddressInvalid)?; + TcpStream::connect(( + as_uri.host_str().unwrap_or("localhost"), + as_uri.port().unwrap_or(80), + )) + .await + } + }; + + let stream = stream.map_err(TunnelError::ProxyConnectionFailed)?; let (mut request_sender, conn) = hyper::client::conn::handshake(stream) .await