Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 120 additions & 60 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
#[cfg(feature = "early-data")]
use std::task::Waker;
use std::task::{Context, Poll};

use rustls::ClientConnection;
Expand All @@ -20,7 +22,7 @@ pub struct TlsStream<IO> {
pub(crate) state: TlsState,

#[cfg(feature = "early-data")]
pub(crate) early_waker: Option<std::task::Waker>,
pub(crate) early_waker: Option<Waker>,
}

impl<IO> TlsStream<IO> {
Expand Down Expand Up @@ -152,78 +154,70 @@ where
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

#[allow(clippy::match_single_binding)]
match this.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData(ref mut pos, ref mut data) => {
use std::io::Write;

// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = match early_data.write(buf) {
Ok(n) => n,
Err(err) => return Poll::Ready(Err(err)),
};
if len != 0 {
data.extend_from_slice(&buf[..len]);
return Poll::Ready(Ok(len));
}
}

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}

// end
this.state = TlsState::Stream;

if let Some(waker) = this.early_waker.take() {
waker.wake();
}

stream.as_mut_pin().poll_write(cx, buf)
#[cfg(feature = "early-data")]
{
let bufs = [io::IoSlice::new(buf)];
let written = ready!(poll_handle_early_data(
&mut this.state,
&mut stream,
&mut this.early_waker,
cx,
&bufs
))?;
if written != 0 {
return Poll::Ready(Ok(written));
}
_ => stream.as_mut_pin().poll_write(cx, buf),
}

stream.as_mut_pin().poll_write(cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
/// Note: that it does not guarantee the final data to be sent.
/// To be cautious, you must manually call `flush`.
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

#[cfg(feature = "early-data")]
{
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}
let written = ready!(poll_handle_early_data(
&mut this.state,
&mut stream,
&mut this.early_waker,
cx,
bufs
))?;
if written != 0 {
return Poll::Ready(Ok(written));
}
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}
stream.as_mut_pin().poll_write_vectored(cx, bufs)
}

this.state = TlsState::Stream;
#[inline]
fn is_write_vectored(&self) -> bool {
true
}

if let Some(waker) = this.early_waker.take() {
waker.wake();
}
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

#[cfg(feature = "early-data")]
ready!(poll_handle_early_data(
&mut this.state,
&mut stream,
&mut this.early_waker,
cx,
&[]
))?;

stream.as_mut_pin().poll_flush(cx)
}
Expand All @@ -248,3 +242,69 @@ where
stream.as_mut_pin().poll_shutdown(cx)
}
}

#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
state: &mut TlsState,
stream: &mut Stream<IO, ClientConnection>,
early_waker: &mut Option<Waker>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
if let TlsState::EarlyData(pos, data) = state {
use std::io::Write;

// write early data
if let Some(mut early_data) = stream.session.early_data() {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little different than what I thought it would be, but it's also good.

I actually think it should be outside the function.

let mut written = 0;

for buf in bufs {
if buf.is_empty() {
continue;
}

let len = match early_data.write(buf) {
Ok(0) => break,
Ok(n) => n,
Err(err) => return Poll::Ready(Err(err)),
};

written += len;
data.extend_from_slice(&buf[..len]);

if len < buf.len() {
break;
}
}

if written != 0 {
return Poll::Ready(Ok(written));
}
}

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}

// end
*state = TlsState::Stream;

if let Some(waker) = early_waker.take() {
waker.wake();
}
}

Poll::Ready(Ok(0))
}
37 changes: 37 additions & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,43 @@ where
Poll::Ready(Ok(pos))
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
if bufs.iter().all(|buf| buf.is_empty()) {
return Poll::Ready(Ok(0));
}

loop {
let mut would_block = false;
let written = self.session.writer().write_vectored(bufs)?;

while self.session.wants_write() {
match self.write_io(cx) {
Poll::Ready(Ok(0)) | Poll::Pending => {
would_block = true;
break;
}
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}

return match (written, would_block) {
(0, true) => Poll::Pending,
(0, false) => continue,
Comment thread
paolobarbolini marked this conversation as resolved.
(n, _) => Poll::Ready(Ok(n)),
};
}
}

#[inline]
fn is_write_vectored(&self) -> bool {
true
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.session.writer().flush()?;
while self.session.wants_write() {
Expand Down
11 changes: 10 additions & 1 deletion src/common/test_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ impl AsyncWrite for Expected {

#[tokio::test]
async fn stream_good() -> io::Result<()> {
stream_good_impl(false).await
}

#[tokio::test]
async fn stream_good_vectored() -> io::Result<()> {
stream_good_impl(true).await
}

async fn stream_good_impl(vectored: bool) -> io::Result<()> {
const FILE: &[u8] = include_bytes!("../../README.md");

let (server, mut client) = make_pair();
Expand All @@ -139,7 +148,7 @@ async fn stream_good() -> io::Result<()> {
dbg!(stream.read_to_end(&mut buf).await)?;
assert_eq!(buf, FILE);

dbg!(stream.write_all(b"Hello World!").await)?;
dbg!(utils::write(&mut stream, b"Hello World!", vectored).await)?;
stream.session.send_close_notify();

dbg!(stream.shutdown().await)?;
Expand Down
20 changes: 20 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,26 @@ where
}
}

#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
}
}

#[inline]
fn is_write_vectored(&self) -> bool {
match self {
TlsStream::Client(x) => x.is_write_vectored(),
TlsStream::Server(x) => x.is_write_vectored(),
}
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Expand Down
18 changes: 18 additions & 0 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ where
stream.as_mut_pin().poll_write(cx, buf)
}

/// Note: that it does not guarantee the final data to be sent.
/// To be cautious, you must manually call `flush`.
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_write_vectored(cx, bufs)
}

#[inline]
fn is_write_vectored(&self) -> bool {
true
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream =
Expand Down
Loading