Skip to content

Commit

Permalink
Merge pull request #61 from bugadani/end
Browse files Browse the repository at this point in the history
Allow read_to_end with ChunkedEncoding
  • Loading branch information
bugadani authored Nov 28, 2023
2 parents ec9435a + 9c2e257 commit ab10c14
Show file tree
Hide file tree
Showing 8 changed files with 625 additions and 379 deletions.
33 changes: 33 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,36 @@ impl From<nourl::Error> for Error {
Error::InvalidUrl(e)
}
}

/// Trait for types that may optionally implement [`embedded_io_async::BufRead`]
pub trait TryBufRead: embedded_io_async::Read {
async fn try_fill_buf(&mut self) -> Option<Result<&[u8], Self::Error>> {
None
}

fn try_consume(&mut self, _amt: usize) {}
}

impl<C> TryBufRead for crate::client::HttpConnection<'_, C>
where
C: embedded_io_async::Read + embedded_io_async::Write,
{
async fn try_fill_buf(&mut self) -> Option<Result<&[u8], Self::Error>> {
// embedded-tls has its own internal buffer, let's prefer that if we can
#[cfg(feature = "embedded-tls")]
if let Self::Tls(ref mut tls) = self {
use embedded_io_async::{BufRead, Error};
return Some(tls.fill_buf().await.map_err(|e| e.kind()));
}

None
}

fn try_consume(&mut self, amt: usize) {
#[cfg(feature = "embedded-tls")]
if let Self::Tls(tls) = self {
use embedded_io_async::BufRead;
tls.consume(amt);
}
}
}
35 changes: 17 additions & 18 deletions src/reader.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use embedded_io::{Error, ErrorKind, ErrorType};
use embedded_io_async::{BufRead, Read, Write};
use embedded_io_async::{BufRead, Read};

use crate::client::HttpConnection;
use crate::TryBufRead;

struct ReadBuffer<'buf> {
buffer: &'buf mut [u8],
loaded: usize,
pub(crate) struct ReadBuffer<'buf> {
pub buffer: &'buf mut [u8],
pub loaded: usize,
}

impl<'buf> ReadBuffer<'buf> {
Expand Down Expand Up @@ -46,8 +46,8 @@ pub struct BufferingReader<'resp, 'buf, B>
where
B: Read,
{
buffer: ReadBuffer<'buf>,
stream: &'resp mut B,
pub(crate) buffer: ReadBuffer<'buf>,
pub(crate) stream: &'resp mut B,
}

impl<'resp, 'buf, B> BufferingReader<'resp, 'buf, B>
Expand Down Expand Up @@ -83,20 +83,22 @@ where
}
}

impl<C> BufRead for BufferingReader<'_, '_, HttpConnection<'_, C>>
impl<C> BufRead for BufferingReader<'_, '_, C>
where
C: Read + Write,
C: TryBufRead,
{
async fn fill_buf(&mut self) -> Result<&[u8], ErrorKind> {
// We need to consume the loaded bytes before we read mode.
if self.buffer.is_empty() {
// embedded-tls has its own internal buffer, let's prefer that if we can
#[cfg(feature = "embedded-tls")]
if let HttpConnection::Tls(ref mut tls) = self.stream {
return tls.fill_buf().await.map_err(|e| e.kind());
// The matches/if let dance is to fix lifetime of the borrowed inner connection.
if self.stream.try_fill_buf().await.is_some() {
if let Some(result) = self.stream.try_fill_buf().await {
return result.map_err(|e| e.kind());
}
unreachable!()
}

self.buffer.loaded = self.stream.read(&mut self.buffer.buffer).await?;
self.buffer.loaded = self.stream.read(&mut self.buffer.buffer).await.map_err(|e| e.kind())?;
}

self.buffer.fill_buf()
Expand All @@ -109,10 +111,7 @@ where
let unconsumed = self.buffer.consume(amt);

if unconsumed > 0 {
#[cfg(feature = "embedded-tls")]
if let HttpConnection::Tls(tls) = &mut self.stream {
tls.consume(unconsumed);
}
self.stream.try_consume(unconsumed);
}
}
}
235 changes: 235 additions & 0 deletions src/response/chunked.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
use embedded_io_async::{BufRead, Error as _, ErrorType, Read};

use crate::{
reader::{BufferingReader, ReadBuffer},
Error, TryBufRead,
};

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum ChunkState {
NoChunk,
NotEmpty(u32),
Empty,
}

impl ChunkState {
fn consume(&mut self, amt: usize) -> usize {
if let ChunkState::NotEmpty(remaining) = self {
let consumed = (amt as u32).min(*remaining);
*remaining -= consumed;
consumed as usize
} else {
0
}
}

fn len(self) -> usize {
if let ChunkState::NotEmpty(len) = self {
len as usize
} else {
0
}
}
}

/// Chunked response body reader
pub struct ChunkedBodyReader<B> {
pub raw_body: B,
chunk_remaining: ChunkState,
}

impl<C> ChunkedBodyReader<C>
where
C: Read,
{
pub fn new(raw_body: C) -> Self {
Self {
raw_body,
chunk_remaining: ChunkState::NoChunk,
}
}

pub fn is_done(&self) -> bool {
self.chunk_remaining == ChunkState::Empty
}

async fn read_next_chunk_length(&mut self) -> Result<(), Error> {
let mut header_buf = [0; 8 + 2]; // 32 bit hex + \r + \n
let mut total_read = 0;

'read_size: loop {
let mut byte = 0;
self.raw_body
.read_exact(core::slice::from_mut(&mut byte))
.await
.map_err(|e| Error::from(e).kind())?;

if byte != b'\n' {
header_buf[total_read] = byte;
total_read += 1;

if total_read == header_buf.len() {
return Err(Error::Codec);
}
} else {
if total_read == 0 || header_buf[total_read - 1] != b'\r' {
return Err(Error::Codec);
}
break 'read_size;
}
}

let hex_digits = total_read - 1;

// Prepend hex with zeros
let mut hex = [b'0'; 8];
hex[8 - hex_digits..].copy_from_slice(&header_buf[..hex_digits]);

let mut bytes = [0; 4];
hex::decode_to_slice(hex, &mut bytes).map_err(|_| Error::Codec)?;

let chunk_length = u32::from_be_bytes(bytes);

debug!("Chunk length: {}", chunk_length);

self.chunk_remaining = match chunk_length {
0 => ChunkState::Empty,
other => ChunkState::NotEmpty(other),
};

Ok(())
}

async fn read_chunk_end(&mut self) -> Result<(), Error> {
// All chunks are terminated with a \r\n
let mut newline_buf = [0; 2];
self.raw_body.read_exact(&mut newline_buf).await?;

if newline_buf != [b'\r', b'\n'] {
return Err(Error::Codec);
}
Ok(())
}

/// Handles chunk boundary and returns the number of bytes in the current (or new) chunk.
async fn handle_chunk_boundary(&mut self) -> Result<usize, Error> {
match self.chunk_remaining {
ChunkState::NoChunk => self.read_next_chunk_length().await?,

ChunkState::NotEmpty(0) => {
// The current chunk is currently empty, advance into a new chunk...
self.read_chunk_end().await?;
self.read_next_chunk_length().await?;
}

ChunkState::NotEmpty(_) => {}

ChunkState::Empty => return Ok(0),
}

if self.chunk_remaining == ChunkState::Empty {
// Read final chunk termination
self.read_chunk_end().await?;
}

Ok(self.chunk_remaining.len())
}
}

impl<'conn, 'buf, C> ChunkedBodyReader<BufferingReader<'conn, 'buf, C>>
where
C: Read + TryBufRead,
{
pub(crate) async fn read_to_end(self) -> Result<&'buf mut [u8], Error> {
let buffer = self.raw_body.buffer.buffer;

// We reconstruct the reader to change the 'buf lifetime.
let mut reader = ChunkedBodyReader {
raw_body: BufferingReader {
buffer: ReadBuffer {
buffer: &mut buffer[..],
loaded: self.raw_body.buffer.loaded,
},
stream: self.raw_body.stream,
},
chunk_remaining: self.chunk_remaining,
};

let mut len = 0;
while !reader.raw_body.buffer.buffer.is_empty() {
// Read some
let read = reader.fill_buf().await?.len();
len += read;

// Make sure we don't erase the newly read data
let was_loaded = reader.raw_body.buffer.loaded;
let fake_loaded = read.min(was_loaded);
reader.raw_body.buffer.loaded = fake_loaded;

// Consume the returned buffer
reader.consume(read);

if reader.is_done() {
// If we're done, we don't care about the rest of the housekeeping.
break;
}

// How many bytes were actually consumed from the preloaded buffer?
let consumed_from_buffer = fake_loaded - reader.raw_body.buffer.loaded;

// ... move the buffer by that many bytes to avoid overwriting in the next iteration.
reader.raw_body.buffer.loaded = was_loaded - consumed_from_buffer;
reader.raw_body.buffer.buffer = &mut reader.raw_body.buffer.buffer[consumed_from_buffer..];
}

if !reader.is_done() {
return Err(Error::BufferTooSmall);
}

Ok(&mut buffer[..len])
}
}

impl<C> ErrorType for ChunkedBodyReader<C> {
type Error = Error;
}

impl<C> Read for ChunkedBodyReader<C>
where
C: Read,
{
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
let remaining = self.handle_chunk_boundary().await?;
let max_len = buf.len().min(remaining);

let len = self
.raw_body
.read(&mut buf[..max_len])
.await
.map_err(|e| Error::Network(e.kind()))?;

self.chunk_remaining.consume(len);

Ok(len)
}
}

impl<C> BufRead for ChunkedBodyReader<C>
where
C: BufRead + Read,
{
async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
let remaining = self.handle_chunk_boundary().await?;

let buf = self.raw_body.fill_buf().await.map_err(|e| Error::Network(e.kind()))?;

let len = buf.len().min(remaining);

Ok(&buf[..len])
}

fn consume(&mut self, amt: usize) {
let consumed = self.chunk_remaining.consume(amt);
self.raw_body.consume(consumed);
}
}
Loading

0 comments on commit ab10c14

Please sign in to comment.