Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions rs/src/connections/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ pub enum TunnelError {
#[error("proxy connection failed: {0}")]
ProxyConnectionFailed(std::io::Error),

#[error("proxy address invalid: {0}")]
ProxyAddressInvalid(url::ParseError),

#[error("proxy handshake failed: {0}")]
ProxyHandshakeFailed(hyper::Error),

Expand Down
18 changes: 14 additions & 4 deletions rs/src/connections/ws.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

use std::{io, pin::Pin, task::Poll, time::Duration};
use std::{io, net::SocketAddr, pin::Pin, task::Poll, time::Duration};

use futures::{Future, Sink, Stream};
use tokio::{
Expand Down Expand Up @@ -237,9 +237,19 @@ pub(crate) async fn connect_via_proxy(
format!("{}:{}", hostname, port)
};

let stream = TcpStream::connect(proxy_addr)
.await
.map_err(TunnelError::ProxyConnectionFailed)?;
let stream = match proxy_addr.parse::<SocketAddr>() {
Ok(addr) => TcpStream::connect(addr).await,
Err(_) => {
let as_uri = url::Url::parse(proxy_addr).map_err(TunnelError::ProxyAddressInvalid)?;
TcpStream::connect((
as_uri.host_str().unwrap_or("localhost"),
as_uri.port().unwrap_or(80),
))
.await
}
};

let stream = stream.map_err(TunnelError::ProxyConnectionFailed)?;

let (mut request_sender, conn) = hyper::client::conn::handshake(stream)
.await
Expand Down