diff --git a/src/response/chunked.rs b/src/response/chunked.rs index e786319..29e3594 100644 --- a/src/response/chunked.rs +++ b/src/response/chunked.rs @@ -182,6 +182,10 @@ where reader.raw_body.buffer.buffer = &mut reader.raw_body.buffer.buffer[consumed_from_buffer..]; } + if !reader.is_done() { + return Err(Error::BufferTooSmall); + } + Ok(&mut buffer[..len]) } } diff --git a/src/response/mod.rs b/src/response/mod.rs index 90f24db..cfa1535 100644 --- a/src/response/mod.rs +++ b/src/response/mod.rs @@ -242,10 +242,16 @@ where ReaderHint::Empty => Ok(&mut []), ReaderHint::FixedLength(content_length) => { // Read into the buffer after the portion that was already received when parsing the header + let to_read = self.body_buf.len().min(content_length); self.conn - .read_exact(&mut self.body_buf[self.raw_body_read..content_length]) + .read_exact(&mut self.body_buf[self.raw_body_read..to_read]) .await?; + if content_length > self.body_buf.len() { + warn!("FixedLength: {} bytes remained", content_length - self.body_buf.len()); + return Err(Error::BufferTooSmall); + } + Ok(&mut self.body_buf[..content_length]) } ReaderHint::Chunked => { @@ -509,6 +515,22 @@ mod tests { assert!(conn.is_exhausted()); } + #[tokio::test] + async fn read_to_end_with_content_length_with_small_buffer() { + let mut conn = FakeSingleReadConnection::new( + b"HTTP/1.1 200 OK\r\nContent-Length: 52\r\n\r\nHELLO WORLD this is some longer response for testing", + ); + let mut header_buf = [0; 40]; + let response = Response::read(&mut conn, Method::GET, &mut header_buf).await.unwrap(); + + let body = response.body().read_to_end().await.expect_err("Failure expected"); + + match body { + Error::BufferTooSmall => {} + e => panic!("Unexpected error: {e:?}"), + } + } + #[tokio::test] async fn can_discard_with_content_length() { let mut conn = FakeSingleReadConnection::new(b"HTTP/1.1 200 OK\r\nContent-Length: 11\r\n\r\nHELLO WORLD"); diff --git a/tests/request.rs b/tests/request.rs index 8eb102b..111ccc5 100644 --- a/tests/request.rs +++ b/tests/request.rs @@ -3,6 +3,7 @@ use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Server}; use reqwless::client::HttpConnection; use reqwless::request::{Method, RequestBuilder}; +use reqwless::Error; use reqwless::{headers::ContentType, request::Request, response::Response}; use std::str::from_utf8; use std::sync::Once; @@ -91,6 +92,13 @@ async fn google_panic() { let mut rx_buf = [0; 8 * 1024]; let resp = Response::read(&mut conn, Method::GET, &mut rx_buf).await.unwrap(); - let body = resp.body().read_to_end().await.unwrap(); - println!("{} -> {}", body.len(), core::str::from_utf8(&body).unwrap()); + let result = resp.body().read_to_end().await; + + match result { + Ok(body) => { + println!("{} -> {}", body.len(), core::str::from_utf8(&body).unwrap()); + } + Err(Error::BufferTooSmall) => println!("Buffer too small"), + Err(e) => panic!("Unexpected error: {e:?}"), + } }