Skip to content

Commit

Permalink
io: support vectored writes for DuplexStream (#5985)
Browse files Browse the repository at this point in the history
  • Loading branch information
maminrayej authored Sep 8, 2023
1 parent fb3ae0a commit 9fafe78
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
75 changes: 75 additions & 0 deletions tokio/src/io/util/mem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ impl AsyncWrite for DuplexStream {
Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut *self.write.lock()).poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
true
}

#[allow(unused_mut)]
fn poll_flush(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -224,6 +236,37 @@ impl Pipe {
}
Poll::Ready(Ok(len))
}

fn poll_write_vectored_internal(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
if self.is_closed {
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
}
let avail = self.max_buf_size - self.buffer.len();
if avail == 0 {
self.write_waker = Some(cx.waker().clone());
return Poll::Pending;
}

let mut rem = avail;
for buf in bufs {
if rem == 0 {
break;
}

let len = buf.len().min(rem);
self.buffer.extend_from_slice(&buf[..len]);
rem -= len;
}

if let Some(waker) = self.read_waker.take() {
waker.wake();
}
Poll::Ready(Ok(avail - rem))
}
}

impl AsyncRead for Pipe {
Expand Down Expand Up @@ -285,6 +328,38 @@ impl AsyncWrite for Pipe {
}
}

cfg_coop! {
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
ready!(crate::trace::trace_leaf(cx));
let coop = ready!(crate::runtime::coop::poll_proceed(cx));

let ret = self.poll_write_vectored_internal(cx, bufs);
if ret.is_ready() {
coop.made_progress();
}
ret
}
}

cfg_not_coop! {
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
ready!(crate::trace::trace_leaf(cx));
self.poll_write_vectored_internal(cx, bufs)
}
}

fn is_write_vectored(&self) -> bool {
true
}

fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
Expand Down
47 changes: 47 additions & 0 deletions tokio/tests/duplex_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use std::io::IoSlice;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

const HELLO: &[u8] = b"hello world...";

#[tokio::test]
async fn write_vectored() {
let (mut client, mut server) = tokio::io::duplex(64);

let ret = client
.write_vectored(&[IoSlice::new(HELLO), IoSlice::new(HELLO)])
.await
.unwrap();
assert_eq!(ret, HELLO.len() * 2);

client.flush().await.unwrap();
drop(client);

let mut buf = Vec::with_capacity(HELLO.len() * 2);
let bytes_read = server.read_to_end(&mut buf).await.unwrap();

assert_eq!(bytes_read, HELLO.len() * 2);
assert_eq!(buf, [HELLO, HELLO].concat());
}

#[tokio::test]
async fn write_vectored_and_shutdown() {
let (mut client, mut server) = tokio::io::duplex(64);

let ret = client
.write_vectored(&[IoSlice::new(HELLO), IoSlice::new(HELLO)])
.await
.unwrap();
assert_eq!(ret, HELLO.len() * 2);

client.shutdown().await.unwrap();
drop(client);

let mut buf = Vec::with_capacity(HELLO.len() * 2);
let bytes_read = server.read_to_end(&mut buf).await.unwrap();

assert_eq!(bytes_read, HELLO.len() * 2);
assert_eq!(buf, [HELLO, HELLO].concat());
}

0 comments on commit 9fafe78

Please sign in to comment.