diff --git a/neqo-bin/src/udp.rs b/neqo-bin/src/udp.rs index 148ff43175..c418f5ee3c 100644 --- a/neqo-bin/src/udp.rs +++ b/neqo-bin/src/udp.rs @@ -7,6 +7,7 @@ use std::{io, net::SocketAddr}; use neqo_common::Datagram; +use neqo_transport::RECV_BUFFER_SIZE; /// Ideally this would live in [`neqo-udp`]. [`neqo-udp`] is used in Firefox. /// @@ -56,10 +57,12 @@ impl Socket { /// Receive a batch of [`Datagram`]s on the given [`Socket`], each set with /// the provided local address. pub fn recv(&self, local_address: &SocketAddr) -> Result, io::Error> { + let mut recv_buf = vec![0; RECV_BUFFER_SIZE]; self.inner .try_io(tokio::io::Interest::READABLE, || { - neqo_udp::recv_inner(local_address, &self.state, &self.inner) + neqo_udp::recv_inner(local_address, &self.state, &self.inner, &mut recv_buf) }) + .map(|dgrams| dgrams.map(|d| d.to_owned()).collect()) .or_else(|e| { if e.kind() == io::ErrorKind::WouldBlock { Ok(vec![]) diff --git a/neqo-udp/src/lib.rs b/neqo-udp/src/lib.rs index 5f1fb3dbe6..e6ae78d9ae 100644 --- a/neqo-udp/src/lib.rs +++ b/neqo-udp/src/lib.rs @@ -7,10 +7,9 @@ #![allow(clippy::missing_errors_doc)] // Functions simply delegate to tokio and quinn-udp. use std::{ - cell::RefCell, io::{self, IoSliceMut}, net::SocketAddr, - slice, + slice::{self, Chunks}, }; use neqo_common::{qdebug, qtrace, Datagram, IpTos}; @@ -21,11 +20,7 @@ use quinn_udp::{EcnCodepoint, RecvMeta, Transmit, UdpSocketState}; /// Allows reading multiple datagrams in a single [`Socket::recv`] call. // // TODO: Experiment with different values across platforms. -const RECV_BUF_SIZE: usize = u16::MAX as usize; - -std::thread_local! { - static RECV_BUF: RefCell> = RefCell::new(vec![0; RECV_BUF_SIZE]); -} +pub const RECV_BUF_SIZE: usize = u16::MAX as usize; pub fn send_inner( state: &UdpSocketState, @@ -57,63 +52,89 @@ use std::os::fd::AsFd as SocketRef; #[cfg(windows)] use std::os::windows::io::AsSocket as SocketRef; -pub fn recv_inner( +pub fn recv_inner<'a>( local_address: &SocketAddr, state: &UdpSocketState, socket: impl SocketRef, -) -> Result, io::Error> { - let dgrams = RECV_BUF.with_borrow_mut(|recv_buf| -> Result, io::Error> { - let mut meta; - - loop { - meta = RecvMeta::default(); - - state.recv( - (&socket).into(), - &mut [IoSliceMut::new(recv_buf)], - slice::from_mut(&mut meta), - )?; - - if meta.len == 0 || meta.stride == 0 { - qdebug!( - "ignoring datagram from {} to {} len {} stride {}", - meta.addr, - local_address, - meta.len, - meta.stride - ); - continue; - } - - break; + recv_buf: &'a mut [u8], +) -> Result, io::Error> { + let mut meta; + + let data = loop { + meta = RecvMeta::default(); + + state.recv( + (&socket).into(), + &mut [IoSliceMut::new(recv_buf)], + slice::from_mut(&mut meta), + )?; + + if meta.len == 0 || meta.stride == 0 { + qdebug!( + "ignoring datagram from {} to {} len {} stride {}", + meta.addr, + local_address, + meta.len, + meta.stride + ); + continue; } - Ok(recv_buf[0..meta.len] - .chunks(meta.stride) - .map(|d| { - qtrace!( - "received {} bytes from {} to {}", - d.len(), - meta.addr, - local_address, - ); - Datagram::new( - meta.addr, - *local_address, - meta.ecn.map(|n| IpTos::from(n as u8)).unwrap_or_default(), - d, - ) - }) - .collect()) - })?; + break &recv_buf[..meta.len]; + }; qtrace!( - "received {} datagrams ({:?})", - dgrams.len(), - dgrams.iter().map(|d| d.len()).collect::>(), + "received {} bytes from {} to {} in {} segments", + data.len(), + meta.addr, + local_address, + data.len().div_ceil(meta.stride), ); - Ok(dgrams) + Ok(DatagramIter { + meta, + datagrams: data.chunks(meta.stride), + local_address: *local_address, + }) +} + +pub struct DatagramIter<'a> { + meta: RecvMeta, + datagrams: Chunks<'a, u8>, + local_address: SocketAddr, +} + +impl<'a> std::fmt::Debug for DatagramIter<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Datagrams") + .field("meta", &self.meta) + .field("local_address", &self.local_address) + .finish() + } +} + +impl<'a> Iterator for DatagramIter<'a> { + type Item = Datagram<&'a [u8]>; + + fn next(&mut self) -> Option { + self.datagrams.next().map(|d| { + Datagram::from_slice( + self.meta.addr, + self.local_address, + self.meta + .ecn + .map(|n| IpTos::from(n as u8)) + .unwrap_or_default(), + d, + ) + }) + } +} + +impl<'a> ExactSizeIterator for DatagramIter<'a> { + fn len(&self) -> usize { + self.datagrams.len() + } } /// A wrapper around a UDP socket, sending and receiving [`Datagram`]s. @@ -138,8 +159,12 @@ impl Socket { /// Receive a batch of [`Datagram`]s on the given [`Socket`], each /// set with the provided local address. - pub fn recv(&self, local_address: &SocketAddr) -> Result, io::Error> { - recv_inner(local_address, &self.state, &self.inner) + pub fn recv<'a>( + &self, + local_address: &SocketAddr, + recv_buf: &'a mut [u8], + ) -> Result, io::Error> { + recv_inner(local_address, &self.state, &self.inner, recv_buf) } } @@ -170,7 +195,8 @@ mod tests { ); sender.send(&datagram)?; - let res = receiver.recv(&receiver_addr); + let mut recv_buf = vec![0; RECV_BUF_SIZE]; + let res = receiver.recv(&receiver_addr, &mut recv_buf); assert_eq!(res.unwrap_err().kind(), std::io::ErrorKind::WouldBlock); Ok(()) @@ -191,17 +217,15 @@ mod tests { sender.send(&datagram)?; - let received_datagram = receiver - .recv(&receiver_addr) - .expect("receive to succeed") - .into_iter() - .next() - .expect("receive to yield datagram"); + let mut recv_buf = vec![0; RECV_BUF_SIZE]; + let mut received_datagrams = receiver + .recv(&receiver_addr, &mut recv_buf) + .expect("receive to succeed"); // Assert that the ECN is correct. assert_eq!( IpTosEcn::from(datagram.tos()), - IpTosEcn::from(received_datagram.tos()) + IpTosEcn::from(received_datagrams.next().unwrap().tos()) ); Ok(()) @@ -236,11 +260,11 @@ mod tests { // Allow for one GSO sendmmsg to result in multiple GRO recvmmsg. let mut num_received = 0; + let mut recv_buf = vec![0; RECV_BUF_SIZE]; while num_received < max_gso_segments { receiver - .recv(&receiver_addr) + .recv(&receiver_addr, &mut recv_buf) .expect("receive to succeed") - .into_iter() .for_each(|d| { assert_eq!( SEGMENT_SIZE,