From 3010d2890144681cd8643f4ba50b8c86a6a7649d Mon Sep 17 00:00:00 2001 From: link2xt Date: Fri, 18 Jul 2025 04:59:15 +0000 Subject: [PATCH] fix: prevent reuse of the stream after an error When a stream timeouts, `tokio_io_timeout::TimeoutStream` returns an error once, but then allows to keep using the stream, e.g. calling `poll_read()` again. This can be dangerous if the error is ignored. For example in case of IMAP stream, if IMAP command is sent, but then reading the response times out and the error is ignored, it is possible to send another IMAP command. In this case leftover response from a previous command may be read and interpreted as the response to the new IMAP command. ErrorCapturingStream wraps the stream to prevent its reuse after an error. --- src/net.rs | 10 ++- src/net/error_capturing_stream.rs | 136 ++++++++++++++++++++++++++++++ src/net/proxy.rs | 4 +- src/net/session.rs | 8 +- 4 files changed, 150 insertions(+), 8 deletions(-) create mode 100644 src/net/error_capturing_stream.rs diff --git a/src/net.rs b/src/net.rs index 01bd8ed9b2..168e63be00 100644 --- a/src/net.rs +++ b/src/net.rs @@ -16,12 +16,14 @@ use crate::sql::Sql; use crate::tools::time; pub(crate) mod dns; +pub(crate) mod error_capturing_stream; pub(crate) mod http; pub(crate) mod proxy; pub(crate) mod session; pub(crate) mod tls; use dns::lookup_host_with_cache; +pub(crate) use error_capturing_stream::ErrorCapturingStream; pub use http::{Response as HttpResponse, read_url, read_url_blob}; use tls::wrap_tls; @@ -105,7 +107,7 @@ pub(crate) async fn load_connection_timestamp( /// to the network, which is important to reduce the latency of interactive protocols such as IMAP. pub(crate) async fn connect_tcp_inner( addr: SocketAddr, -) -> Result>>> { +) -> Result>>>> { let tcp_stream = timeout(TIMEOUT, TcpStream::connect(addr)) .await .context("connection timeout")? @@ -118,7 +120,9 @@ pub(crate) async fn connect_tcp_inner( timeout_stream.set_write_timeout(Some(TIMEOUT)); timeout_stream.set_read_timeout(Some(TIMEOUT)); - Ok(Box::pin(timeout_stream)) + let error_capturing_stream = ErrorCapturingStream::new(timeout_stream); + + Ok(Box::pin(error_capturing_stream)) } /// Attempts to establish TLS connection @@ -235,7 +239,7 @@ pub(crate) async fn connect_tcp( host: &str, port: u16, load_cache: bool, -) -> Result>>> { +) -> Result>>>> { let connection_futures = lookup_host_with_cache(context, host, port, "", load_cache) .await? .into_iter() diff --git a/src/net/error_capturing_stream.rs b/src/net/error_capturing_stream.rs new file mode 100644 index 0000000000..4edbb5bfce --- /dev/null +++ b/src/net/error_capturing_stream.rs @@ -0,0 +1,136 @@ +use std::io::IoSlice; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; +use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf}; + +use pin_project::pin_project; + +use crate::net::SessionStream; + +/// Stream that remembers the first error +/// and keeps returning it afterwards. +/// +/// It is needed to avoid accidentally using +/// the stream after read timeout. +#[derive(Debug)] +#[pin_project] +pub(crate) struct ErrorCapturingStream { + #[pin] + inner: T, + + /// If true, the stream has already returned an error once. + /// + /// All read and write operations return error in this case. + is_broken: bool, +} + +impl ErrorCapturingStream { + pub fn new(inner: T) -> Self { + Self { + inner, + is_broken: false, + } + } + + /// Gets a reference to the underlying stream. + pub fn get_ref(&self) -> &T { + &self.inner + } + + /// Gets a pinned mutable reference to the underlying stream. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner + } +} + +impl AsyncRead for ErrorCapturingStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf, + ) -> Poll> { + let this = self.project(); + if *this.is_broken { + return Poll::Ready(Err(io::Error::other("Broken stream"))); + } + let res = this.inner.poll_read(cx, buf); + if let Poll::Ready(Err(_)) = res { + *this.is_broken = true; + } + res + } +} + +impl AsyncWrite for ErrorCapturingStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + if *this.is_broken { + return Poll::Ready(Err(io::Error::other("Broken stream"))); + } + let res = this.inner.poll_write(cx, buf); + if let Poll::Ready(Err(_)) = res { + *this.is_broken = true; + } + res + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if *this.is_broken { + return Poll::Ready(Err(io::Error::other("Broken stream"))); + } + let res = this.inner.poll_flush(cx); + if let Poll::Ready(Err(_)) = res { + *this.is_broken = true; + } + res + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if *this.is_broken { + return Poll::Ready(Err(io::Error::other("Broken stream"))); + } + let res = this.inner.poll_shutdown(cx); + if let Poll::Ready(Err(_)) = res { + *this.is_broken = true; + } + res + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let this = self.project(); + if *this.is_broken { + return Poll::Ready(Err(io::Error::other("Broken stream"))); + } + let res = this.inner.poll_write_vectored(cx, bufs); + if let Poll::Ready(Err(_)) = res { + *this.is_broken = true; + } + res + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +impl SessionStream for ErrorCapturingStream { + fn set_read_timeout(&mut self, timeout: Option) { + self.inner.set_read_timeout(timeout) + } + + fn peer_addr(&self) -> anyhow::Result { + self.inner.peer_addr() + } +} diff --git a/src/net/proxy.rs b/src/net/proxy.rs index 0f657b5439..6c4797e9fb 100644 --- a/src/net/proxy.rs +++ b/src/net/proxy.rs @@ -21,9 +21,9 @@ use url::Url; use crate::config::Config; use crate::constants::NON_ALPHANUMERIC_WITHOUT_DOT; use crate::context::Context; -use crate::net::connect_tcp; use crate::net::session::SessionStream; use crate::net::tls::wrap_rustls; +use crate::net::{ErrorCapturingStream, connect_tcp}; use crate::sql::Sql; /// Default SOCKS5 port according to [RFC 1928](https://tools.ietf.org/html/rfc1928). @@ -118,7 +118,7 @@ impl Socks5Config { target_host: &str, target_port: u16, load_dns_cache: bool, - ) -> Result>>>> { + ) -> Result>>>>> { let tcp_stream = connect_tcp(context, &self.host, self.port, load_dns_cache) .await .context("Failed to connect to SOCKS5 proxy")?; diff --git a/src/net/session.rs b/src/net/session.rs index 981e01fd4d..f3d16dc2bc 100644 --- a/src/net/session.rs +++ b/src/net/session.rs @@ -7,6 +7,8 @@ use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufStream, BufWriter}; use tokio::net::TcpStream; use tokio_io_timeout::TimeoutStream; +use crate::net::ErrorCapturingStream; + pub(crate) trait SessionStream: AsyncRead + AsyncWrite + Unpin + Send + Sync + std::fmt::Debug { @@ -61,13 +63,13 @@ impl SessionStream for BufWriter { self.get_ref().peer_addr() } } -impl SessionStream for Pin>> { +impl SessionStream for Pin>>> { fn set_read_timeout(&mut self, timeout: Option) { - self.as_mut().set_read_timeout_pinned(timeout); + self.as_mut().get_pin_mut().set_read_timeout_pinned(timeout); } fn peer_addr(&self) -> Result { - Ok(self.get_ref().peer_addr()?) + Ok(self.get_ref().get_ref().peer_addr()?) } } impl SessionStream for Socks5Stream {