diff --git a/tokio/src/fs/file.rs b/tokio/src/fs/file.rs index 1e88f30434e..cd140c6cc97 100644 --- a/tokio/src/fs/file.rs +++ b/tokio/src/fs/file.rs @@ -725,6 +725,81 @@ impl AsyncWrite for File { } } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + ready!(crate::trace::trace_leaf(cx)); + let me = self.get_mut(); + let inner = me.inner.get_mut(); + + if let Some(e) = inner.last_write_err.take() { + return Ready(Err(e.into())); + } + + loop { + match inner.state { + Idle(ref mut buf_cell) => { + let mut buf = buf_cell.take().unwrap(); + + let seek = if !buf.is_empty() { + Some(SeekFrom::Current(buf.discard_read())) + } else { + None + }; + + let n = buf.copy_from_bufs(bufs); + let std = me.std.clone(); + + let blocking_task_join_handle = spawn_mandatory_blocking(move || { + let res = if let Some(seek) = seek { + (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) + } else { + buf.write_to(&mut &*std) + }; + + (Operation::Write(res), buf) + }) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "background task failed") + })?; + + inner.state = Busy(blocking_task_join_handle); + + return Ready(Ok(n)); + } + Busy(ref mut rx) => { + let (op, buf) = ready!(Pin::new(rx).poll(cx))?; + inner.state = Idle(Some(buf)); + + match op { + Operation::Read(_) => { + // We don't care about the result here. The fact + // that the cursor has advanced will be reflected in + // the next iteration of the loop + continue; + } + Operation::Write(res) => { + // If the previous write was successful, continue. + // Otherwise, error. + res?; + continue; + } + Operation::Seek(_) => { + // Ignore the seek + continue; + } + } + } + } + } + } + + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(crate::trace::trace_leaf(cx)); let inner = self.inner.get_mut(); diff --git a/tokio/src/io/blocking.rs b/tokio/src/io/blocking.rs index 416573e9732..27d08b1fcc7 100644 --- a/tokio/src/io/blocking.rs +++ b/tokio/src/io/blocking.rs @@ -276,5 +276,22 @@ cfg_fs! { self.buf.truncate(0); ret } + + pub(crate) fn copy_from_bufs(&mut self, bufs: &[io::IoSlice<'_>]) -> usize { + assert!(self.is_empty()); + + let mut rem = MAX_BUF; + for buf in bufs { + if rem == 0 { + break + } + + let len = buf.len().min(rem); + self.buf.extend_from_slice(&buf[..len]); + rem -= len; + } + + MAX_BUF - rem + } } } diff --git a/tokio/tests/fs_file.rs b/tokio/tests/fs_file.rs index 40bd4fce564..6a8b07a7ffe 100644 --- a/tokio/tests/fs_file.rs +++ b/tokio/tests/fs_file.rs @@ -2,6 +2,7 @@ #![cfg(all(feature = "full", not(target_os = "wasi")))] // WASI does not support all fs operations use std::io::prelude::*; +use std::io::IoSlice; use tempfile::NamedTempFile; use tokio::fs::File; use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, SeekFrom}; @@ -49,6 +50,40 @@ async fn basic_write_and_shutdown() { assert_eq!(file, HELLO); } +#[tokio::test] +async fn write_vectored() { + let tempfile = tempfile(); + + let mut file = File::create(tempfile.path()).await.unwrap(); + + let ret = file + .write_vectored(&[IoSlice::new(HELLO), IoSlice::new(HELLO)]) + .await + .unwrap(); + assert_eq!(ret, HELLO.len() * 2); + file.flush().await.unwrap(); + + let file = std::fs::read(tempfile.path()).unwrap(); + assert_eq!(file, [HELLO, HELLO].concat()); +} + +#[tokio::test] +async fn write_vectored_and_shutdown() { + let tempfile = tempfile(); + + let mut file = File::create(tempfile.path()).await.unwrap(); + + let ret = file + .write_vectored(&[IoSlice::new(HELLO), IoSlice::new(HELLO)]) + .await + .unwrap(); + assert_eq!(ret, HELLO.len() * 2); + file.shutdown().await.unwrap(); + + let file = std::fs::read(tempfile.path()).unwrap(); + assert_eq!(file, [HELLO, HELLO].concat()); +} + #[tokio::test] async fn rewind_seek_position() { let tempfile = tempfile();