From 3f0da4c30fd73fd1fdff940cc84e499eba3034f9 Mon Sep 17 00:00:00 2001 From: moui66744 Date: Fri, 25 Nov 2022 03:06:36 -0500 Subject: [PATCH 01/11] tonic: add simple RDMA support Server handles both HTTP and RDMA requests through `serve_with_rdma`. Client has two options to communicate with the server: either HTTP via Channel or RDMA via RdmaChannel. RDMA communication based on `async-rdma` crate. --- tonic/Cargo.toml | 1 + tonic/src/transport/channel/mod.rs | 2 + tonic/src/transport/channel/rdma_channel.rs | 78 +++++++++++++++++++ tonic/src/transport/mod.rs | 3 +- tonic/src/transport/server/mod.rs | 84 +++++++++++++++++++++ 5 files changed, 167 insertions(+), 1 deletion(-) create mode 100644 tonic/src/transport/channel/rdma_channel.rs diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index c138aa184..54e379267 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -51,6 +51,7 @@ channel = [ # harness = false [dependencies] +async-rdma = "0.4" base64 = "0.13" bytes = "1.0" futures-core = {version = "0.3", default-features = false} diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 9254cf022..2adb3e30c 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -1,10 +1,12 @@ //! Client implementation and builder. +mod rdma_channel; mod endpoint; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; +pub use rdma_channel::RdmaChannel; pub use endpoint::Endpoint; #[cfg(feature = "tls")] pub use tls::ClientTlsConfig; diff --git a/tonic/src/transport/channel/rdma_channel.rs b/tonic/src/transport/channel/rdma_channel.rs new file mode 100644 index 000000000..4c28dc510 --- /dev/null +++ b/tonic/src/transport/channel/rdma_channel.rs @@ -0,0 +1,78 @@ +use crate::{ + body::BoxBody, + transport::{self, BoxFuture}, +}; +use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, Rdma}; +use http::{Request, Response}; +use http_body::Body; +use std::{alloc::Layout, io, sync::Arc}; +use tokio::net::ToSocketAddrs; +use tower::Service; + +/// A RDMA transport channel +#[derive(Debug, Clone)] +pub struct RdmaChannel { + rdma: Arc, +} + +const MAX_MSG_LEN: usize = 10240; + +impl RdmaChannel { + /// Create a new [`RdmaChannel`] + pub async fn new(addr: A) -> Result + where + A: ToSocketAddrs, + { + // TODO: expose these params to user (create a `RdmaEndpoint` maybe) + let rdma = Arc::new(Rdma::connect(addr, 1, 1, MAX_MSG_LEN).await?); + Ok(Self { rdma }) + } +} + +impl Service> for RdmaChannel { + type Response = Response; + type Error = transport::Error; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + // TODO: handle error + fn call(&mut self, req: Request) -> Self::Future { + let rdma = Arc::clone(&self.rdma); + Box::pin(async move { + tokio::spawn(rdma_send_req(Arc::clone(&rdma), req)); + + let (resp_mr, len) = rdma.receive_with_imm().await.unwrap(); + let len = len.unwrap() as usize; + let resp_vec = resp_mr.as_slice()[0..len].to_vec(); + + let body = hyper::Body::from(resp_vec); + let resp = http::Response::builder().body(body).unwrap(); + Ok(resp) + }) + } +} + +async fn rdma_send_req(rdma: Arc, mut req: Request) { + let mut req_header = req.uri().to_string().into_bytes(); + req_header.push(32u8); + let len_header = req_header.len(); + let mut req_mr = rdma + .alloc_local_mr(Layout::new::<[u8; MAX_MSG_LEN]>()) + .unwrap(); + req_mr.as_mut_slice()[0..len_header].copy_from_slice(req_header.as_slice()); + + let mut len = len_header; + while let Some(Ok(bytes)) = req.data().await { + let req_body = bytes.to_vec(); + let len_body = req_body.len(); + assert!(len + len_body <= MAX_MSG_LEN); // TODO: len > MAX_MSG_LEN + req_mr.as_mut_slice()[len..len + len_body].copy_from_slice(req_body.as_slice()); + len += len_body; + } + rdma.send_with_imm(&req_mr, len as u32).await.unwrap(); +} diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index e83b4c424..935a8dd9f 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -96,7 +96,8 @@ mod tls; #[doc(inline)] #[cfg(feature = "channel")] #[cfg_attr(docsrs, doc(cfg(feature = "channel")))] -pub use self::channel::{Channel, Endpoint}; +pub use self::channel::{Channel, Endpoint, RdmaChannel}; +pub use tokio::net::ToSocketAddrs; pub use self::error::Error; #[doc(inline)] pub use self::server::{NamedService, Server}; diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 5fa003afa..9f262f539 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -35,6 +35,7 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, ServerIo}; use crate::body::BoxBody; +use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, Rdma, RdmaBuilder}; use bytes::Bytes; use futures_core::Stream; use futures_util::{future, ready}; @@ -43,9 +44,11 @@ use http_body::Body as _; use hyper::{server::accept, Body}; use pin_project::pin_project; use std::{ + alloc::Layout, convert::Infallible, fmt, future::Future, + io, marker::PhantomData, net::SocketAddr, pin::Pin, @@ -596,6 +599,35 @@ impl Router { .await } + /// Listening for HTTP requests from addr + /// while listening for RDMA requests from rdma_addr + pub async fn serve_with_rdma( + self, + addr: SocketAddr, + rdma_addr: SocketAddr, + ) -> Result<(), super::Error> + where + L: Layer, + L::Service: Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: Into + Send, + ResBody: http_body::Body + Send + 'static, + ResBody::Error: Into, + { + // let routes_clone = self.routes.clone(); + tokio::spawn(rdma_serve(rdma_addr, self.routes.clone())); + + let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) + .map_err(super::Error::from_source)?; + self.server + .serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>( + self.routes, + incoming, + None, + ) + .await + } + /// Consume this [`Server`] creating a future that will execute the server /// on [tokio]'s default executor. And shutdown when the provided signal /// is received. @@ -846,3 +878,55 @@ where future::ready(Ok(svc)) } } + +// TODO: handle this error +async fn rdma_serve(addr: SocketAddr, routes: Routes) -> Result<(), io::Error> { + let rdma = RdmaBuilder::default() + .set_max_message_length(MAX_MSG_LEN) + .listen(addr) + .await?; + let routes_clone = routes.clone(); + rdma_serve_inner(&rdma, routes_clone).await?; + loop { + let rdma = rdma.listen().await?; + let routes = routes.clone(); + tokio::spawn(async move { rdma_serve_inner(&rdma, routes).await }); + } +} + +const MAX_MSG_LEN: usize = 10240; + +// TODO: handle this error +async fn rdma_serve_inner(rdma: &Rdma, mut routes: Routes) -> Result<(), io::Error> { + println!("[server] connected!"); + + loop { + let (req_mr, len) = rdma.receive_with_imm().await?; + let len = len.unwrap() as usize; + let req_vec = req_mr.as_slice()[0..len].to_vec(); + + // deserialize + let idx = req_vec.iter().position(|num| *num == 32).unwrap(); + let uri = &req_vec[0..idx]; + let body = hyper::Body::from(req_vec[idx + 1..len].to_vec()); + let req = http::Request::builder() + .method("POST") + .uri(uri) + .body(body) + .unwrap(); + + let mut resp = routes.call(req).await.unwrap(); + let mut resp_mr = rdma + .alloc_local_mr(Layout::new::<[u8; MAX_MSG_LEN]>()) + .unwrap(); + let mut len = 0; + while let Some(Ok(bytes)) = resp.data().await { + let resp_body = bytes.to_vec(); + let len_body = resp_body.len(); + resp_mr.as_mut_slice()[len..len + len_body].copy_from_slice(resp_body.as_slice()); + len += len_body; + } + rdma.send_with_imm(&resp_mr, len as u32).await.unwrap(); + } + // Ok(()) +} From cd7d711270b812394e733cd1724c5f2a451198a3 Mon Sep 17 00:00:00 2001 From: moui66744 Date: Fri, 25 Nov 2022 03:08:29 -0500 Subject: [PATCH 02/11] build: add `connect_rdma` Client establishess an RDMA connection to the server through `connect_rdma`. --- tonic-build/src/client.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index 0aff59d3d..f56d56c72 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -29,6 +29,7 @@ pub fn generate( ); let connect = generate_connect(&service_ident, build_transport); + let connect_rdma = generate_connect_rdma(&service_ident, build_transport); let package = if emit_package { service.package() } else { "" }; let path = format!( @@ -69,6 +70,7 @@ pub fn generate( } #connect + #connect_rdma impl #service_ident where @@ -146,6 +148,27 @@ fn generate_connect(service_ident: &syn::Ident, enabled: bool) -> TokenStream { } } +#[cfg(feature = "transport")] +fn generate_connect_rdma(service_ident: &syn::Ident, enable: bool) -> TokenStream { + let connect_impl = quote! { + impl #service_ident { + pub async fn connect_rdma(addr: A) -> Result + where + A: tonic::transport::ToSocketAddrs + { + let rdma_channel = tonic::transport::RdmaChannel::new(addr).await?; + Ok(Self::new(rdma_channel)) + } + } + }; + + if enable { + connect_impl + } else { + TokenStream::new() + } +} + #[cfg(not(feature = "transport"))] fn generate_connect(_service_ident: &syn::Ident, _enabled: bool) -> TokenStream { TokenStream::new() From 88baa20848abf953698ad5dcc0d7d5e3e54061fb Mon Sep 17 00:00:00 2001 From: moui66744 Date: Fri, 25 Nov 2022 03:10:14 -0500 Subject: [PATCH 03/11] add example of RDMA communication. --- examples/Cargo.toml | 8 + examples/src/routeguide_rdma/client.rs | 126 ++++++++++++++++ examples/src/routeguide_rdma/data.rs | 33 ++++ examples/src/routeguide_rdma/server.rs | 201 +++++++++++++++++++++++++ 4 files changed, 368 insertions(+) create mode 100644 examples/src/routeguide_rdma/client.rs create mode 100644 examples/src/routeguide_rdma/data.rs create mode 100644 examples/src/routeguide_rdma/server.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index a4b240a6c..78f4b58f6 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -26,6 +26,14 @@ path = "src/routeguide/server.rs" name = "routeguide-client" path = "src/routeguide/client.rs" +[[bin]] +name = "routeguide-rdma-server" +path = "src/routeguide_rdma/server.rs" + +[[bin]] +name = "routeguide-rdma-client" +path = "src/routeguide_rdma/client.rs" + [[bin]] name = "authentication-client" path = "src/authentication/client.rs" diff --git a/examples/src/routeguide_rdma/client.rs b/examples/src/routeguide_rdma/client.rs new file mode 100644 index 000000000..432e74874 --- /dev/null +++ b/examples/src/routeguide_rdma/client.rs @@ -0,0 +1,126 @@ +use std::error::Error; +use std::time::Duration; + +use futures::stream; +use rand::rngs::ThreadRng; +use rand::Rng; +use tokio::time; +use tonic::transport::RdmaChannel; +use tonic::Request; + +use routeguide::route_guide_client::RouteGuideClient; +use routeguide::{Point, Rectangle, RouteNote}; + +pub mod routeguide { + tonic::include_proto!("routeguide"); +} + +async fn print_features(client: &mut RouteGuideClient) -> Result<(), Box> { + let rectangle = Rectangle { + lo: Some(Point { + latitude: 400_000_000, + longitude: -750_000_000, + }), + hi: Some(Point { + latitude: 420_000_000, + longitude: -730_000_000, + }), + }; + + let mut stream = client + .list_features(Request::new(rectangle)) + .await? + .into_inner(); + + while let Some(feature) = stream.message().await? { + println!("NOTE = {:?}", feature); + } + + Ok(()) +} + +async fn run_record_route(client: &mut RouteGuideClient) -> Result<(), Box> { + let mut rng = rand::thread_rng(); + let point_count: i32 = rng.gen_range(2..100); + + let mut points = vec![]; + for _ in 0..=point_count { + points.push(random_point(&mut rng)) + } + + println!("Traversing {} points", points.len()); + let request = Request::new(stream::iter(points)); + + match client.record_route(request).await { + Ok(response) => println!("SUMMARY: {:?}", response.into_inner()), + Err(e) => println!("something went wrong: {:?}", e), + } + + Ok(()) +} + +async fn run_route_chat(client: &mut RouteGuideClient) -> Result<(), Box> { + let start = time::Instant::now(); + + let outbound = async_stream::stream! { + let mut interval = time::interval(Duration::from_secs(1)); + + loop { + let time = interval.tick().await; + let elapsed = time.duration_since(start); + let note = RouteNote { + location: Some(Point { + latitude: 409146138 + elapsed.as_secs() as i32, + longitude: -746188906, + }), + message: format!("at {:?}", elapsed), + }; + + yield note; + } + }; + + let response = client.route_chat(Request::new(outbound)).await?; + let mut inbound = response.into_inner(); + + while let Some(note) = inbound.message().await? { + println!("NOTE = {:?}", note); + } + + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let mut client = RouteGuideClient::connect_rdma("[::1]:10001").await?; + println!("rdma conn established!"); + + println!("*** SIMPLE RPC ***"); + let response = client + .get_feature(Request::new(Point { + latitude: 409_146_138, + longitude: -746_188_906, + })) + .await?; + println!("RESPONSE = {:?}", response); + + println!("\n*** SERVER STREAMING ***"); + print_features(&mut client).await?; + + println!("\n*** CLIENT STREAMING ***"); + run_record_route(&mut client).await?; + + println!("\n*** BIDIRECTIONAL STREAMING ***"); + run_route_chat(&mut client).await?; + + Ok(()) +} + +fn random_point(rng: &mut ThreadRng) -> Point { + let latitude = (rng.gen_range(0..180) - 90) * 10_000_000; + let longitude = (rng.gen_range(0..360) - 180) * 10_000_000; + Point { + latitude, + longitude, + } +} diff --git a/examples/src/routeguide_rdma/data.rs b/examples/src/routeguide_rdma/data.rs new file mode 100644 index 000000000..bec3805d3 --- /dev/null +++ b/examples/src/routeguide_rdma/data.rs @@ -0,0 +1,33 @@ +use serde::Deserialize; +use std::fs::File; + +#[derive(Debug, Deserialize)] +struct Feature { + location: Location, + name: String, +} + +#[derive(Debug, Deserialize)] +struct Location { + latitude: i32, + longitude: i32, +} + +#[allow(dead_code)] +pub fn load() -> Vec { + let file = File::open("examples/data/route_guide_db.json").expect("failed to open data file"); + + let decoded: Vec = + serde_json::from_reader(&file).expect("failed to deserialize features"); + + decoded + .into_iter() + .map(|feature| crate::routeguide::Feature { + name: feature.name, + location: Some(crate::routeguide::Point { + longitude: feature.location.longitude, + latitude: feature.location.latitude, + }), + }) + .collect() +} diff --git a/examples/src/routeguide_rdma/server.rs b/examples/src/routeguide_rdma/server.rs new file mode 100644 index 000000000..d6a15bda1 --- /dev/null +++ b/examples/src/routeguide_rdma/server.rs @@ -0,0 +1,201 @@ +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Instant; + +use futures::{Stream, StreamExt}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tonic::transport::Server; +use tonic::{Request, Response, Status}; + +use routeguide::route_guide_server::{RouteGuide, RouteGuideServer}; +use routeguide::{Feature, Point, Rectangle, RouteNote, RouteSummary}; + +pub mod routeguide { + tonic::include_proto!("routeguide"); +} + +mod data; + +#[derive(Debug)] +pub struct RouteGuideService { + features: Arc>, +} + +#[tonic::async_trait] +impl RouteGuide for RouteGuideService { + async fn get_feature(&self, request: Request) -> Result, Status> { + println!("GetFeature = {:?}", request); + + for feature in &self.features[..] { + if feature.location.as_ref() == Some(request.get_ref()) { + return Ok(Response::new(feature.clone())); + } + } + + Ok(Response::new(Feature::default())) + } + + type ListFeaturesStream = ReceiverStream>; + + async fn list_features( + &self, + request: Request, + ) -> Result, Status> { + println!("ListFeatures = {:?}", request); + + let (tx, rx) = mpsc::channel(4); + let features = self.features.clone(); + + tokio::spawn(async move { + for feature in &features[..] { + if in_range(feature.location.as_ref().unwrap(), request.get_ref()) { + println!(" => send {:?}", feature); + tx.send(Ok(feature.clone())).await.unwrap(); + } + } + + println!(" /// done sending"); + }); + + Ok(Response::new(ReceiverStream::new(rx))) + } + + async fn record_route( + &self, + request: Request>, + ) -> Result, Status> { + println!("RecordRoute"); + + let mut stream = request.into_inner(); + + let mut summary = RouteSummary::default(); + let mut last_point = None; + let now = Instant::now(); + + while let Some(point) = stream.next().await { + let point = point?; + + println!(" ==> Point = {:?}", point); + + // Increment the point count + summary.point_count += 1; + + // Find features + for feature in &self.features[..] { + if feature.location.as_ref() == Some(&point) { + summary.feature_count += 1; + } + } + + // Calculate the distance + if let Some(ref last_point) = last_point { + summary.distance += calc_distance(last_point, &point); + } + + last_point = Some(point); + } + + summary.elapsed_time = now.elapsed().as_secs() as i32; + + Ok(Response::new(summary)) + } + + type RouteChatStream = Pin> + Send + 'static>>; + + async fn route_chat( + &self, + request: Request>, + ) -> Result, Status> { + println!("RouteChat"); + + let mut notes = HashMap::new(); + let mut stream = request.into_inner(); + + let output = async_stream::try_stream! { + while let Some(note) = stream.next().await { + let note = note?; + + let location = note.location.clone().unwrap(); + + let location_notes = notes.entry(location).or_insert(vec![]); + location_notes.push(note); + + for note in location_notes { + yield note.clone(); + } + } + }; + + Ok(Response::new(Box::pin(output) as Self::RouteChatStream)) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = "[::1]:10000".parse().unwrap(); + let rdma_addr = "[::1]:10001".parse().unwrap(); + + println!( + "RouteGuideServer listening on: {}(HTTP) and {}(RDMA)", + addr, rdma_addr + ); + + let route_guide = RouteGuideService { + features: Arc::new(data::load()), + }; + + let svc = RouteGuideServer::new(route_guide); + + Server::builder() + .add_service(svc) + .serve_with_rdma(addr, rdma_addr) + .await?; + + Ok(()) +} + +impl Eq for Point {} + +fn in_range(point: &Point, rect: &Rectangle) -> bool { + use std::cmp; + + let lo = rect.lo.as_ref().unwrap(); + let hi = rect.hi.as_ref().unwrap(); + + let left = cmp::min(lo.longitude, hi.longitude); + let right = cmp::max(lo.longitude, hi.longitude); + let top = cmp::max(lo.latitude, hi.latitude); + let bottom = cmp::min(lo.latitude, hi.latitude); + + point.longitude >= left + && point.longitude <= right + && point.latitude >= bottom + && point.latitude <= top +} + +/// Calculates the distance between two points using the "haversine" formula. +/// This code was taken from http://www.movable-type.co.uk/scripts/latlong.html. +fn calc_distance(p1: &Point, p2: &Point) -> i32 { + const CORD_FACTOR: f64 = 1e7; + const R: f64 = 6_371_000.0; // meters + + let lat1 = p1.latitude as f64 / CORD_FACTOR; + let lat2 = p2.latitude as f64 / CORD_FACTOR; + let lng1 = p1.longitude as f64 / CORD_FACTOR; + let lng2 = p2.longitude as f64 / CORD_FACTOR; + + let lat_rad1 = lat1.to_radians(); + let lat_rad2 = lat2.to_radians(); + + let delta_lat = (lat2 - lat1).to_radians(); + let delta_lng = (lng2 - lng1).to_radians(); + + let a = (delta_lat / 2f64).sin() * (delta_lat / 2f64).sin() + + (lat_rad1).cos() * (lat_rad2).cos() * (delta_lng / 2f64).sin() * (delta_lng / 2f64).sin(); + + let c = 2f64 * a.sqrt().atan2((1f64 - a).sqrt()); + + (R * c) as i32 +} From 6da0c4e9c3110e70620ab809a3d6ed4455bb0174 Mon Sep 17 00:00:00 2001 From: moui66744 Date: Mon, 28 Nov 2022 00:09:39 -0500 Subject: [PATCH 04/11] relection, types: update generated code. --- .../src/generated/grpc.reflection.v1alpha.rs | 2 +- tonic-types/src/generated/google.rpc.rs | 43 +++++++++---------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs b/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs index ef3e013dc..d348a4d50 100644 --- a/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs +++ b/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs @@ -23,7 +23,7 @@ pub mod server_reflection_request { FileByFilename(::prost::alloc::string::String), /// Find the proto file that declares the given fully-qualified symbol name. /// This field should be a fully-qualified symbol name - /// (e.g. .\\[.\\] or .). + /// (e.g. .\[.\] or .). #[prost(string, tag = "4")] FileContainingSymbol(::prost::alloc::string::String), /// Find the proto file which defines an extension extending the given diff --git a/tonic-types/src/generated/google.rpc.rs b/tonic-types/src/generated/google.rpc.rs index 1937ae200..659d1fab4 100644 --- a/tonic-types/src/generated/google.rpc.rs +++ b/tonic-types/src/generated/google.rpc.rs @@ -7,12 +7,12 @@ /// [API Design Guide](). #[derive(Clone, PartialEq, ::prost::Message)] pub struct Status { - /// The status code, which should be an enum value of \\[google.rpc.Code\]\[google.rpc.Code\\]. + /// The status code, which should be an enum value of \[google.rpc.Code][google.rpc.Code\]. #[prost(int32, tag = "1")] pub code: i32, /// A developer-facing error message, which should be in English. Any /// user-facing error message should be localized and sent in the - /// \\[google.rpc.Status.details\]\[google.rpc.Status.details\\] field, or localized by the client. + /// \[google.rpc.Status.details][google.rpc.Status.details\] field, or localized by the client. #[prost(string, tag = "2")] pub message: ::prost::alloc::string::String, /// A list of messages that carry the error details. There is a common set of @@ -92,36 +92,33 @@ pub mod quota_failure { /// /// Example of an error when contacting the "pubsub.googleapis.com" API when it /// is not enabled: -/// -/// ```text,json -/// { "reason": "API_DISABLED" -/// "domain": "googleapis.com" -/// "metadata": { -/// "resource": "projects/123", -/// "service": "pubsub.googleapis.com" -/// } -/// } +/// ```json +/// { "reason": "API_DISABLED" +/// "domain": "googleapis.com" +/// "metadata": { +/// "resource": "projects/123", +/// "service": "pubsub.googleapis.com" +/// } +/// } /// ``` -/// /// This response indicates that the pubsub.googleapis.com API is not enabled. /// /// Example of an error that is returned when attempting to create a Spanner /// instance in a region that is out of stock: -/// -/// ```text,json -/// { "reason": "STOCKOUT" -/// "domain": "spanner.googleapis.com", -/// "metadata": { -/// "availableRegions": "us-central1,us-east2" -/// } -/// } +/// ```json +/// { "reason": "STOCKOUT" +/// "domain": "spanner.googleapis.com", +/// "metadata": { +/// "availableRegions": "us-central1,us-east2" +/// } +/// } /// ``` #[derive(Clone, PartialEq, ::prost::Message)] pub struct ErrorInfo { /// The reason of the error. This is a constant value that identifies the /// proximate cause of the error. Error reasons are unique within a particular /// domain of errors. This should be at most 63 characters and match - /// /\\[A-Z0-9\_\\]+/. + /// /\[A-Z0-9_\]+/. #[prost(string, tag = "1")] pub reason: ::prost::alloc::string::String, /// The logical grouping to which the "reason" belongs. The error domain @@ -134,7 +131,7 @@ pub struct ErrorInfo { pub domain: ::prost::alloc::string::String, /// Additional structured details about this error. /// - /// Keys should match /\\[a-zA-Z0-9-\_\\]/ and be limited to 64 characters in + /// Keys should match /\[a-zA-Z0-9-_\]/ and be limited to 64 characters in /// length. When identifying the current value of an exceeded limit, the units /// should be contained in the key, not the value. For example, rather than /// {"instanceLimit": "100/request"}, should be returned as, @@ -226,7 +223,7 @@ pub struct ResourceInfo { pub resource_type: ::prost::alloc::string::String, /// The name of the resource being accessed. For example, a shared calendar /// name: "example.com_4fghdhgsrgh@group.calendar.google.com", if the current - /// error is \\[google.rpc.Code.PERMISSION_DENIED\]\[google.rpc.Code.PERMISSION_DENIED\\]. + /// error is \[google.rpc.Code.PERMISSION_DENIED][google.rpc.Code.PERMISSION_DENIED\]. #[prost(string, tag = "2")] pub resource_name: ::prost::alloc::string::String, /// The owner of the resource (optional). From d73493ffaea815d216731d632ef991d6eb76fc37 Mon Sep 17 00:00:00 2001 From: moui66744 Date: Thu, 1 Dec 2022 05:59:12 -0500 Subject: [PATCH 05/11] feat: add methods to operate on slice for codec. --- examples/src/json-codec/common.rs | 18 +++++++++++++++ tonic/benches/decode.rs | 6 +++++ tonic/src/codec/mod.rs | 9 ++++++++ tonic/src/codec/prost.rs | 38 +++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+) diff --git a/examples/src/json-codec/common.rs b/examples/src/json-codec/common.rs index 9f0ffeb54..7bb03341d 100644 --- a/examples/src/json-codec/common.rs +++ b/examples/src/json-codec/common.rs @@ -30,6 +30,14 @@ impl Encoder for JsonEncoder { fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> { serde_json::to_writer(buf.writer(), &item).map_err(|e| Status::internal(e.to_string())) } + + fn encode_into_slice( + &mut self, + _item: Self::Item, + _buf: &mut [u8], + ) -> Result { + unimplemented!() + } } #[derive(Debug)] @@ -48,6 +56,16 @@ impl Decoder for JsonDecoder { serde_json::from_reader(buf.reader()).map_err(|e| Status::internal(e.to_string()))?; Ok(Some(item)) } + + fn decode_from_slice(&mut self, buf: &[u8]) -> Result, Self::Error> { + if !buf.has_remaining() { + return Ok(None); + } + + let item: Self::Item = + serde_json::from_reader(buf.reader()).map_err(|e| Status::internal(e.to_string()))?; + Ok(Some(item)) + } } /// A [`Codec`] that implements `application/grpc+json` via the serde library. diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs index 96f5b498d..afb75fb9a 100644 --- a/tonic/benches/decode.rs +++ b/tonic/benches/decode.rs @@ -105,6 +105,12 @@ impl Decoder for MockDecoder { buf.advance(self.message_size); Ok(Some(out)) } + + fn decode_from_slice(&mut self, mut buf: &[u8]) -> Result, Self::Error> { + let out = Vec::from(buf.chunk()); + buf.advance(self.message_size); + Ok(Some(out)) + } } fn make_payload(message_length: usize, message_count: usize) -> Bytes { diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index cc330b14c..6539595b0 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -59,6 +59,10 @@ pub trait Encoder { /// Encodes a message into the provided buffer. fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error>; + + /// Encodes a message into the provided buffer. + fn encode_into_slice(&mut self, item: Self::Item, dst: &mut [u8]) + -> Result; } /// Decodes gRPC message types @@ -75,4 +79,9 @@ pub trait Decoder { /// is no need to get the length from the bytes, gRPC framing is handled /// for you. fn decode(&mut self, src: &mut DecodeBuf<'_>) -> Result, Self::Error>; + + /// Decode a message from the buffer. + /// + /// For RDMA in-situ codec. + fn decode_from_slice(&mut self, src: &[u8]) -> Result, Self::Error>; } diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index 2facddcef..2f87005b1 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -1,6 +1,7 @@ use super::{Codec, DecodeBuf, Decoder, Encoder}; use crate::codec::EncodeBuf; use crate::{Code, Status}; +use bytes::BufMut; use prost1::Message; use std::marker::PhantomData; @@ -50,6 +51,21 @@ impl Encoder for ProstEncoder { Ok(()) } + + fn encode_into_slice( + &mut self, + item: Self::Item, + mut buf: &mut [u8], + ) -> Result { + let len = item.encoded_len(); + buf.put_u8(0); + buf.put_u32(len as u32); + + item.encode(&mut buf) + .expect("Message only errors if not enough space"); + + Ok(len + 5) + } } /// A [`Decoder`] that knows how to decode `U`. @@ -67,6 +83,14 @@ impl Decoder for ProstDecoder { Ok(item) } + + fn decode_from_slice(&mut self, buf: &[u8]) -> Result, Self::Error> { + let item = Message::decode(buf) + .map(Option::Some) + .map_err(from_decode_error)?; + + Ok(item) + } } fn from_decode_error(error: prost1::DecodeError) -> crate::Status { @@ -180,6 +204,14 @@ mod tests { buf.put(&item[..]); Ok(()) } + + fn encode_into_slice( + &mut self, + _item: Self::Item, + _buf: &mut [u8], + ) -> Result { + unimplemented!() + } } #[derive(Debug, Clone, Default)] @@ -194,6 +226,12 @@ mod tests { buf.advance(LEN); Ok(Some(out)) } + + fn decode_from_slice(&mut self, mut buf: &[u8]) -> Result, Self::Error> { + let out = Vec::from(buf.chunk()); + buf.advance(LEN); + Ok(Some(out)) + } } mod body { From 2316a9f50d3bdfa59d88f5b822d3b8de204c15c0 Mon Sep 17 00:00:00 2001 From: moui66744 Date: Thu, 1 Dec 2022 06:08:24 -0500 Subject: [PATCH 06/11] feat: add request/response/routes structs for RDMA --- tonic/src/lib.rs | 2 + tonic/src/rdma.rs | 76 +++++++++++++++++++++++++++ tonic/src/transport/service/mod.rs | 2 +- tonic/src/transport/service/router.rs | 59 ++++++++++++++++++++- 4 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 tonic/src/rdma.rs diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index 256574b3e..bf5e76655 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -99,6 +99,7 @@ pub mod transport; mod extensions; mod macros; +mod rdma; mod request; mod response; mod status; @@ -112,6 +113,7 @@ pub use async_trait::async_trait; #[doc(inline)] pub use codec::Streaming; pub use extensions::Extensions; +pub use rdma::{RdmaRequest, RdmaResponse}; pub use request::{IntoRequest, IntoStreamingRequest, Request}; pub use response::Response; pub use status::{Code, Status}; diff --git a/tonic/src/rdma.rs b/tonic/src/rdma.rs new file mode 100644 index 000000000..57c04c23c --- /dev/null +++ b/tonic/src/rdma.rs @@ -0,0 +1,76 @@ +use async_rdma::{LocalMr, LocalMrReadAccess, LocalMrWriteAccess}; + +/// RDMA request. Corresponding to gRPC http::Request +/// +/// Request content: +/// {Path of Service}{Blank Space}{Serialied Message} +#[derive(Debug)] +pub struct RdmaRequest { + req_mr: LocalMr, + len: usize, + resp_mr: LocalMr, +} + +/// RDMA response. Corresponding to gRPC http::Response +#[derive(Debug)] +pub struct RdmaResponse { + /// MR where the response message is stored + pub resp_mr: LocalMr, + /// length of message in MR + pub len: usize, +} + +impl RdmaRequest { + /// Create a new RdmaRequest + pub fn new(req_mr: LocalMr, len: usize, resp_mr: LocalMr) -> Self { + Self { + req_mr, + len, + resp_mr, + } + } + /// Get the index of separator(i.e. blank space). + fn separator_index(&self) -> usize { + self.req_mr + .as_slice() + .iter() + .position(|num| *num == ' ' as u8) // blank space + .unwrap() + } + /// Get service name. + pub fn service(&self) -> &str { + let pos = self.separator_index(); + let idx = self.req_mr.as_slice()[0..pos] + .iter() + .rev() + .position(|c| *c == '/' as u8) + .unwrap(); + std::str::from_utf8(&self.req_mr.as_slice()[0..pos - idx - 1]).unwrap() + } + /// Get path of service function. + pub fn path(&self) -> &str { + let pos = self.separator_index(); + std::str::from_utf8(&self.req_mr.as_slice()[0..pos]).unwrap() + } + /// Get serialized data from MR. + pub fn body(&self) -> &[u8] { + let pos = self.separator_index(); + let res = &self.req_mr.as_slice()[pos + 1..self.len]; + println!("body: {:?}", res); + res + } +} + +impl RdmaResponse { + /// Create RdmaResponse from RdmaRequest. + pub fn from_req(req: RdmaRequest) -> Self { + Self { + resp_mr: req.resp_mr, + len: 0, + } + } + /// Get mutable reference of MR slice. + pub fn buf(&mut self) -> &mut [u8] { + unsafe { self.resp_mr.as_mut_slice_unchecked() } + } +} diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 355aadf09..3916dee93 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -22,4 +22,4 @@ pub(crate) use self::io::ServerIo; pub(crate) use self::tls::{TlsAcceptor, TlsConnector}; pub(crate) use self::user_agent::UserAgent; -pub use self::router::Routes; +pub use self::router::{RdmaRoutes, Routes}; diff --git a/tonic/src/transport/service/router.rs b/tonic/src/transport/service/router.rs index 5f5e10d18..45093cc03 100644 --- a/tonic/src/transport/service/router.rs +++ b/tonic/src/transport/service/router.rs @@ -1,19 +1,22 @@ use crate::{ body::{boxed, BoxBody}, + codegen::BoxFuture, transport::NamedService, + RdmaRequest, RdmaResponse, }; use axum::handler::Handler; use http::{Request, Response}; use hyper::Body; use pin_project::pin_project; use std::{ + collections::HashMap, convert::Infallible, fmt, future::Future, pin::Pin, task::{Context, Poll}, }; -use tower::ServiceExt; +use tower::{util::BoxCloneService, ServiceExt}; use tower_service::Service; /// A [`Service`] router. @@ -93,3 +96,57 @@ impl Future for RoutesFuture { } } } + +/// RDMA router +#[derive(Debug, Default, Clone)] +pub struct RdmaRoutes { + router: HashMap>, +} + +impl RdmaRoutes { + pub(crate) fn new(svc: S) -> Self + where + S: Service + + NamedService + + Clone + + Send + + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + { + let router = HashMap::new(); + Self { router }.add_service(svc) + } + + pub(crate) fn add_service(mut self, svc: S) -> Self + where + S: Service + + NamedService + + Clone + + Send + + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + { + let v = Box::new(svc); + self.router + .insert(format!("/{}", S::NAME), BoxCloneService::new(v)); + self + } +} + +impl Service for RdmaRoutes { + type Response = RdmaResponse; + type Error = Infallible; + type Future = BoxFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: RdmaRequest) -> Self::Future { + let service_name = req.service(); + let svc = self.router.get_mut(service_name).unwrap(); + Box::pin(svc.call(req)) + } +} From 9cc457b4c5a101c9bc5b1dfcca9bc3684ac67ccc Mon Sep 17 00:00:00 2001 From: moui66744 Date: Thu, 1 Dec 2022 06:12:58 -0500 Subject: [PATCH 07/11] perf: eliminate copy when server codec. --- examples/src/routeguide_rdma/server.rs | 3 +- tonic-build/src/server.rs | 121 +++++++++++++++++++++++++ tonic/src/server/grpc.rs | 39 ++++++++ tonic/src/transport/server/mod.rs | 103 +++++++++++++-------- 4 files changed, 225 insertions(+), 41 deletions(-) diff --git a/examples/src/routeguide_rdma/server.rs b/examples/src/routeguide_rdma/server.rs index d6a15bda1..18486763a 100644 --- a/examples/src/routeguide_rdma/server.rs +++ b/examples/src/routeguide_rdma/server.rs @@ -149,7 +149,8 @@ async fn main() -> Result<(), Box> { let svc = RouteGuideServer::new(route_guide); Server::builder() - .add_service(svc) + .add_service(svc.clone()) + .add_service_rdma(svc) .serve_with_rdma(addr, rdma_addr) .await?; diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index 080a00fd7..71adc1004 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -19,6 +19,7 @@ pub fn generate( disable_comments: &HashSet, ) -> TokenStream { let methods = generate_methods(service, proto_path, compile_well_known_types); + let methods_rdma = generate_methods_rdma(service, proto_path, compile_well_known_types); let server_service = quote::format_ident!("{}Server", service.name()); let server_trait = quote::format_ident!("{}", service.name()); @@ -148,6 +149,30 @@ pub fn generate( } } + impl tonic::codegen::Service for #server_service + where + T: #server_trait, + { + type Response = tonic::RdmaResponse; + type Error = std::convert::Infallible; + type Future = BoxFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: tonic::RdmaRequest) -> Self::Future { + let inner = self.inner.clone(); + + let path = req.path(); + match path { + #methods_rdma + + _ => todo!(), + } + } + } + impl Clone for #server_service { fn clone(&self) -> Self { let inner = self.inner.clone(); @@ -369,6 +394,102 @@ fn generate_methods( stream } +fn generate_methods_rdma( + service: &T, + proto_path: &str, + compile_well_known_types: bool, +) -> TokenStream { + let mut stream = TokenStream::new(); + + for method in service.methods() { + let path = format!( + "/{}{}{}/{}", + service.package(), + if service.package().is_empty() { + "" + } else { + "." + }, + service.identifier(), + method.identifier() + ); + let method_path = Lit::Str(LitStr::new(&path, Span::call_site())); + let ident = quote::format_ident!("{}", method.name()); + let server_trait = quote::format_ident!("{}", service.name()); + + let method_stream = match (method.client_streaming(), method.server_streaming()) { + (false, false) => generate_unary_rdma( + method, + proto_path, + compile_well_known_types, + ident, + server_trait, + ), + _ => quote! { + todo!(); + } + }; + + let method = quote! { + #method_path => { + #method_stream + } + }; + stream.extend(method); + } + + stream +} + +fn generate_unary_rdma( + method: &T, + proto_path: &str, + compile_well_known_types: bool, + method_ident: Ident, + server_trait: Ident, +) -> TokenStream { + let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + + let service_ident = quote::format_ident!("{}Svc", method.identifier()); + + let (request, response) = method.request_response_name(proto_path, compile_well_known_types); + + quote! { + #[allow(non_camel_case_types)] + struct #service_ident(pub Arc); + + impl tonic::server::UnaryService<#request> for #service_ident { + type Response = #response; + type Future = BoxFuture, tonic::Status>; + + fn call(&mut self, request: tonic::Request<#request>) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { + (*inner).#method_ident(request).await + }; + Box::pin(fut) + } + } + + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = #service_ident(inner); + let codec = #codec_name::default(); + + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config(accept_compression_encodings, send_compression_encodings); + + let res = grpc.unary_rdma(method, req).await; + Ok(res) + }; + + Box::pin(fut) + } +} + fn generate_unary( method: &T, proto_path: &str, diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index cbe8450ff..4e3ec9b67 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -1,12 +1,14 @@ use crate::codec::compression::{ CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride, }; +use crate::codec::{Decoder, Encoder}; use crate::{ body::BoxBody, codec::{encode_server, Codec, Streaming}, server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService}, Code, Request, Status, }; +use crate::{RdmaRequest, RdmaResponse}; use futures_core::TryStream; use futures_util::{future, stream, TryStreamExt}; use http_body::Body; @@ -363,6 +365,43 @@ where self.accept_compression_encodings, ) } + + /// Handle a single unary request. + pub async fn unary_rdma<'a, S>(&mut self, mut service: S, req: RdmaRequest) -> RdmaResponse + where + S: UnaryService, + { + // decode request + let request = self.map_request_rdma(&req).unwrap(); + // call service + let response = service.call(request).await.unwrap(); + // encode response + self.map_response_rdma(response, RdmaResponse::from_req(req)) + } + + fn map_request_rdma(&mut self, req: &RdmaRequest) -> Result, Status> { + let message = self + .codec + .decoder() + .decode_from_slice(&req.body()[5..]) + .unwrap() + .unwrap(); + Ok(Request::new(message)) + } + + fn map_response_rdma( + &mut self, + response: crate::Response<::Encode>, + mut resp: RdmaResponse, + ) -> RdmaResponse { + let inner = response.into_inner(); + resp.len = self + .codec + .encoder() + .encode_into_slice(inner, &mut resp.buf()) + .unwrap(); + resp + } } impl fmt::Debug for Grpc { diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 9f262f539..25805a054 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -9,7 +9,7 @@ mod tls; #[cfg(unix)] mod unix; -pub use super::service::Routes; +pub use super::service::{RdmaRoutes, Routes}; pub use crate::server::NamedService; pub use conn::{Connected, TcpConnectInfo}; #[cfg(feature = "tls")] @@ -34,8 +34,8 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, ServerIo}; -use crate::body::BoxBody; -use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, Rdma, RdmaBuilder}; +use crate::{body::BoxBody, RdmaRequest, RdmaResponse}; +use async_rdma::{Rdma, RdmaBuilder}; use bytes::Bytes; use futures_core::Stream; use futures_util::{future, ready}; @@ -127,6 +127,7 @@ impl Default for Server { pub struct Router { server: Server, routes: Routes, + rdma_routes: RdmaRoutes, } impl NamedService for Either { @@ -354,7 +355,7 @@ impl Server { S::Future: Send + 'static, L: Clone, { - Router::new(self.clone(), Routes::new(svc)) + Router::new(self.clone(), Routes::new(svc), RdmaRoutes::default()) } /// Create a router with the optional `S` typed service as the first service. @@ -376,7 +377,23 @@ impl Server { L: Clone, { let routes = svc.map(Routes::new).unwrap_or_default(); - Router::new(self.clone(), routes) + Router::new(self.clone(), routes, RdmaRoutes::default()) + } + + /// Create a router with the `S` typed service as the first service. + /// + /// + pub fn add_service_rdma(&mut self, svc: S) -> Router + where + S: Service + + NamedService + + Clone + + Send + + 'static, + S::Future: Send + 'static, + L: Clone, + { + Router::new(self.clone(), Routes::default(), RdmaRoutes::new(svc)) } /// Set the [Tower] [`Layer`] all services will be wrapped in. @@ -533,8 +550,12 @@ impl Server { } impl Router { - pub(crate) fn new(server: Server, routes: Routes) -> Self { - Self { server, routes } + pub(crate) fn new(server: Server, routes: Routes, rdma_routes: RdmaRoutes) -> Self { + Self { + server, + routes, + rdma_routes, + } } } @@ -574,6 +595,20 @@ impl Router { self } + /// Add a new service to this router. + pub fn add_service_rdma(mut self, svc: S) -> Self + where + S: Service + + NamedService + + Clone + + Send + + 'static, + S::Future: Send + 'static, + { + self.rdma_routes = self.rdma_routes.add_service(svc); + self + } + /// Consume this [`Server`] creating a future that will execute the server /// on [tokio]'s default executor. /// @@ -599,7 +634,7 @@ impl Router { .await } - /// Listening for HTTP requests from addr + /// Listening for HTTP requests from addr /// while listening for RDMA requests from rdma_addr pub async fn serve_with_rdma( self, @@ -615,7 +650,7 @@ impl Router { ResBody::Error: Into, { // let routes_clone = self.routes.clone(); - tokio::spawn(rdma_serve(rdma_addr, self.routes.clone())); + tokio::spawn(rdma_serve(rdma_addr, self.rdma_routes.clone())); let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) .map_err(super::Error::from_source)?; @@ -880,53 +915,41 @@ where } // TODO: handle this error -async fn rdma_serve(addr: SocketAddr, routes: Routes) -> Result<(), io::Error> { - let rdma = RdmaBuilder::default() - .set_max_message_length(MAX_MSG_LEN) - .listen(addr) - .await?; - let routes_clone = routes.clone(); - rdma_serve_inner(&rdma, routes_clone).await?; +async fn rdma_serve(addr: SocketAddr, routes: RdmaRoutes) -> Result<(), io::Error> { + let rdma = Arc::new( + RdmaBuilder::default() + .set_max_message_length(MAX_MSG_LEN) + .listen(addr) + .await?, + ); + tokio::spawn(rdma_serve_inner(Arc::clone(&rdma), routes.clone())); loop { - let rdma = rdma.listen().await?; + let rdma = Arc::new(rdma.listen().await?); let routes = routes.clone(); - tokio::spawn(async move { rdma_serve_inner(&rdma, routes).await }); + tokio::spawn(rdma_serve_inner(rdma, routes.clone())); } } const MAX_MSG_LEN: usize = 10240; // TODO: handle this error -async fn rdma_serve_inner(rdma: &Rdma, mut routes: Routes) -> Result<(), io::Error> { +async fn rdma_serve_inner(rdma: Arc, mut routes: RdmaRoutes) -> Result<(), io::Error> { println!("[server] connected!"); loop { let (req_mr, len) = rdma.receive_with_imm().await?; let len = len.unwrap() as usize; - let req_vec = req_mr.as_slice()[0..len].to_vec(); - - // deserialize - let idx = req_vec.iter().position(|num| *num == 32).unwrap(); - let uri = &req_vec[0..idx]; - let body = hyper::Body::from(req_vec[idx + 1..len].to_vec()); - let req = http::Request::builder() - .method("POST") - .uri(uri) - .body(body) + + let resp_mr = rdma.alloc_local_mr(Layout::new::<[u8; 1024]>()).unwrap(); + + let resp = routes + .call(RdmaRequest::new(req_mr, len, resp_mr)) + .await .unwrap(); - let mut resp = routes.call(req).await.unwrap(); - let mut resp_mr = rdma - .alloc_local_mr(Layout::new::<[u8; MAX_MSG_LEN]>()) + rdma.send_with_imm(&resp.resp_mr, resp.len as u32) + .await .unwrap(); - let mut len = 0; - while let Some(Ok(bytes)) = resp.data().await { - let resp_body = bytes.to_vec(); - let len_body = resp_body.len(); - resp_mr.as_mut_slice()[len..len + len_body].copy_from_slice(resp_body.as_slice()); - len += len_body; - } - rdma.send_with_imm(&resp_mr, len as u32).await.unwrap(); } // Ok(()) } From 3c5fb6a18eb0476a960a62272cbf60d978be925f Mon Sep 17 00:00:00 2001 From: moui66744 Date: Thu, 1 Dec 2022 21:48:24 -0500 Subject: [PATCH 08/11] health, reflection: update generated code. --- tonic-health/src/generated/grpc.health.v1.rs | 62 +++++++++++++++++++ .../src/generated/grpc.reflection.v1alpha.rs | 24 +++++++ 2 files changed, 86 insertions(+) diff --git a/tonic-health/src/generated/grpc.health.v1.rs b/tonic-health/src/generated/grpc.health.v1.rs index 30dd88531..c78bd0f0e 100644 --- a/tonic-health/src/generated/grpc.health.v1.rs +++ b/tonic-health/src/generated/grpc.health.v1.rs @@ -352,6 +352,68 @@ pub mod health_server { } } } + impl tonic::codegen::Service for HealthServer + where + T: Health, + { + type Response = tonic::RdmaResponse; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: tonic::RdmaRequest) -> Self::Future { + let inner = self.inner.clone(); + let path = req.path(); + match path { + "/grpc.health.v1.Health/Check" => { + #[allow(non_camel_case_types)] + struct CheckSvc(pub Arc); + impl< + T: Health, + > tonic::server::UnaryService + for CheckSvc { + type Response = super::HealthCheckResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).check(request).await }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = CheckSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ); + let res = grpc.unary_rdma(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/grpc.health.v1.Health/Watch" => { + todo!(); + } + _ => todo!(), + } + } + } impl Clone for HealthServer { fn clone(&self) -> Self { let inner = self.inner.clone(); diff --git a/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs b/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs index d348a4d50..b01f594a6 100644 --- a/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs +++ b/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs @@ -366,6 +366,30 @@ pub mod server_reflection_server { } } } + impl tonic::codegen::Service for ServerReflectionServer + where + T: ServerReflection, + { + type Response = tonic::RdmaResponse; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: tonic::RdmaRequest) -> Self::Future { + let inner = self.inner.clone(); + let path = req.path(); + match path { + "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo" => { + todo!(); + } + _ => todo!(), + } + } + } impl Clone for ServerReflectionServer { fn clone(&self) -> Self { let inner = self.inner.clone(); From 110e75fbbf4386e63c74a7e18d3eb78ec122c578 Mon Sep 17 00:00:00 2001 From: moui66744 Date: Tue, 6 Dec 2022 23:18:14 -0500 Subject: [PATCH 09/11] feat: add `encode_rdma` used to encode msg into MR --- tonic/src/codec/encode.rs | 31 +++++++++++++++++++++++++++++++ tonic/src/codec/mod.rs | 2 +- tonic/src/rdma.rs | 1 - 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index d94a1f0ba..770520189 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -1,9 +1,11 @@ use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride}; use super::{EncodeBuf, Encoder, HEADER_SIZE}; use crate::{Code, Status}; +use async_rdma::{LocalMr, LocalMrWriteAccess}; use bytes::{BufMut, Bytes, BytesMut}; use futures_core::{Stream, TryStream}; use futures_util::{ready, StreamExt, TryStreamExt}; +use http::uri::PathAndQuery; use http::HeaderMap; use http_body::Body; use pin_project::pin_project; @@ -247,3 +249,32 @@ where Poll::Ready(self.project().state.trailers()) } } + +/// encode `source` and `path` into `mr` +pub(crate) async fn encode_rdma( + mut encoder: T, + source: U, + path: PathAndQuery, + mr: &mut LocalMr, +) -> usize +where + T: Encoder, + U: Stream, +{ + let mut buf = unsafe { mr.as_mut_slice_unchecked() }; + // path + let path = path.as_str(); + let mut len = path.len(); + buf[0..len].copy_from_slice(path.as_bytes()); + unsafe { buf.advance_mut(len) }; + buf.put_u8(' ' as u8); + len += 1; + // body + source + .for_each(|item| { + len += encoder.encode_into_slice(item, &mut buf).unwrap(); + std::future::ready(()) + }) + .await; + len +} diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index 6539595b0..9070a6670 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -13,7 +13,7 @@ mod prost; use crate::Status; use std::io; -pub(crate) use self::encode::{encode_client, encode_server}; +pub(crate) use self::encode::{encode_client, encode_rdma, encode_server}; pub use self::buffer::{DecodeBuf, EncodeBuf}; pub use self::compression::{CompressionEncoding, EnabledCompressionEncodings}; diff --git a/tonic/src/rdma.rs b/tonic/src/rdma.rs index 57c04c23c..e32a5dbb5 100644 --- a/tonic/src/rdma.rs +++ b/tonic/src/rdma.rs @@ -56,7 +56,6 @@ impl RdmaRequest { pub fn body(&self) -> &[u8] { let pos = self.separator_index(); let res = &self.req_mr.as_slice()[pos + 1..self.len]; - println!("body: {:?}", res); res } } From 3af6889db7e870d96dcc0f91d45c8e4742d387e6 Mon Sep 17 00:00:00 2001 From: moui66744 Date: Tue, 6 Dec 2022 23:21:22 -0500 Subject: [PATCH 10/11] perf: eliminate copy when client codec. --- examples/src/routeguide_rdma/client.rs | 8 +- tonic-build/src/client.rs | 155 +++++++++++++++++++- tonic/src/client/grpc.rs | 36 +++++ tonic/src/transport/channel/mod.rs | 2 +- tonic/src/transport/channel/rdma_channel.rs | 67 ++------- tonic/src/transport/mod.rs | 2 +- 6 files changed, 212 insertions(+), 58 deletions(-) diff --git a/examples/src/routeguide_rdma/client.rs b/examples/src/routeguide_rdma/client.rs index 432e74874..f49a08bd0 100644 --- a/examples/src/routeguide_rdma/client.rs +++ b/examples/src/routeguide_rdma/client.rs @@ -28,7 +28,7 @@ async fn print_features(client: &mut RouteGuideClient) -> Result<() }; let mut stream = client - .list_features(Request::new(rectangle)) + .list_features_rdma(Request::new(rectangle)) .await? .into_inner(); @@ -51,7 +51,7 @@ async fn run_record_route(client: &mut RouteGuideClient) -> Result< println!("Traversing {} points", points.len()); let request = Request::new(stream::iter(points)); - match client.record_route(request).await { + match client.record_route_rdma(request).await { Ok(response) => println!("SUMMARY: {:?}", response.into_inner()), Err(e) => println!("something went wrong: {:?}", e), } @@ -80,7 +80,7 @@ async fn run_route_chat(client: &mut RouteGuideClient) -> Result<() } }; - let response = client.route_chat(Request::new(outbound)).await?; + let response = client.route_chat_rdma(Request::new(outbound)).await?; let mut inbound = response.into_inner(); while let Some(note) = inbound.message().await? { @@ -97,7 +97,7 @@ async fn main() -> Result<(), Box> { println!("*** SIMPLE RPC ***"); let response = client - .get_feature(Request::new(Point { + .get_feature_rdma(Request::new(Point { latitude: 409_146_138, longitude: -746_188_906, })) diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index f56d56c72..840bb0896 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -27,6 +27,13 @@ pub fn generate( compile_well_known_types, disable_comments, ); + let methods_rdma = generate_methods_rdma( + service, + emit_package, + proto_path, + compile_well_known_types, + disable_comments, + ); let connect = generate_connect(&service_ident, build_transport); let connect_rdma = generate_connect_rdma(&service_ident, build_transport); @@ -121,6 +128,19 @@ pub fn generate( #methods } + + impl #service_ident + where + T: tonic::transport::RdmaService, + { + pub fn new_rdma(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + + #methods_rdma + } + } } } @@ -157,7 +177,7 @@ fn generate_connect_rdma(service_ident: &syn::Ident, enable: bool) -> TokenStrea A: tonic::transport::ToSocketAddrs { let rdma_channel = tonic::transport::RdmaChannel::new(addr).await?; - Ok(Self::new(rdma_channel)) + Ok(Self::new_rdma(rdma_channel)) } } }; @@ -220,6 +240,139 @@ fn generate_methods( stream } +fn generate_methods_rdma( + service: &T, + emit_package: bool, + proto_path: &str, + compile_well_known_types: bool, + disable_comments: &HashSet, +) -> TokenStream { + let mut stream = TokenStream::new(); + let package = if emit_package { service.package() } else { "" }; + + for method in service.methods() { + let path = format!( + "/{}{}{}/{}", + package, + if package.is_empty() { "" } else { "." }, + service.identifier(), + method.identifier() + ); + + // if !disable_comments.contains(&format!( + // "{}{}{}.{}", + // package, + // if package.is_empty() { "" } else { "." }, + // service.identifier(), + // method.identifier() + // )) { + // stream.extend(generate_doc_comments(method.comment())); + // } + + let method = match (method.client_streaming(), method.server_streaming()) { + (false, false) => { + generate_unary_rdma(method, proto_path, compile_well_known_types, path) + } + (false, true) => { + generate_server_streaming_rdma(method, proto_path, compile_well_known_types, path) + } + (true, false) => { + generate_client_streaming_rdma(method, proto_path, compile_well_known_types, path) + } + (true, true) => generate_streaming_rdma(method, proto_path, compile_well_known_types, path), + }; + + stream.extend(method); + } + + stream +} + +fn generate_unary_rdma( + method: &T, + proto_path: &str, + compile_well_known_types: bool, + path: String, +) -> TokenStream { + let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let ident = format_ident!("{}_rdma", method.name()); + let (request, response) = method.request_response_name(proto_path, compile_well_known_types); + + quote! { + pub async fn #ident( + &mut self, + request: impl tonic::IntoRequest<#request>, + ) -> Result, tonic::Status> { + let codec = #codec_name::default(); + let path = http::uri::PathAndQuery::from_static(#path); + self.inner.unary_rdma(request.into_request(), path, codec).await + } + } +} + +fn generate_server_streaming_rdma( + method: &T, + proto_path: &str, + compile_well_known_types: bool, + path: String, +) -> TokenStream { + let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let ident = format_ident!("{}_rdma", method.name()); + + let (request, response) = method.request_response_name(proto_path, compile_well_known_types); + + quote! { + pub async fn #ident( + &mut self, + request: impl tonic::IntoRequest<#request>, + ) -> Result>, tonic::Status> { + todo!() + } + } +} + +fn generate_client_streaming_rdma( + method: &T, + proto_path: &str, + compile_well_known_types: bool, + path: String, +) -> TokenStream { + let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let ident = format_ident!("{}_rdma", method.name()); + + let (request, response) = method.request_response_name(proto_path, compile_well_known_types); + + quote! { + pub async fn #ident( + &mut self, + request: impl tonic::IntoStreamingRequest + ) -> Result, tonic::Status> { + todo!() + } + } +} + +fn generate_streaming_rdma( + method: &T, + proto_path: &str, + compile_well_known_types: bool, + path: String, +) -> TokenStream { + let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let ident = format_ident!("{}_rdma", method.name()); + + let (request, response) = method.request_response_name(proto_path, compile_well_known_types); + + quote! { + pub async fn #ident( + &mut self, + request: impl tonic::IntoStreamingRequest + ) -> Result>, tonic::Status> { + todo!() + } + } +} + fn generate_unary( method: &T, proto_path: &str, diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 85134ceae..9c5a05a51 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -1,4 +1,6 @@ use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings}; +use crate::codec::encode_rdma; +use crate::transport::RdmaService; use crate::{ body::BoxBody, client::GrpcService, @@ -6,6 +8,7 @@ use crate::{ request::SanitizeHeaders, Code, Request, Response, Status, }; +use async_rdma::LocalMrReadAccess; use futures_core::Stream; use futures_util::{future, stream, TryStreamExt}; use http::{ @@ -286,6 +289,39 @@ impl Grpc { Ok(Response::from_http(response)) } + + /// Send a single unary gRPC request through RDMA. + pub async fn unary_rdma( + &mut self, + request: Request, + path: PathAndQuery, + mut codec: C, + ) -> Result, Status> + where + T: RdmaService, + C: Codec, + M1: Send + Sync + 'static, + M2: Send + Sync + 'static, + { + let request = request.map(|m| stream::once(future::ready(m))); + + // encode into MR. + let mut req_mr = self.inner.alloc_mr().unwrap(); + let len = encode_rdma(codec.encoder(), request.into_inner(), path, &mut req_mr).await; + + // send request, recv response + let (resp_mr, len) = self.inner.call(req_mr, len).await; + let len = len.unwrap() as usize; + + // decode response from mr + let item = codec + .decoder() + .decode_from_slice(&resp_mr.as_slice()[5..len]) + .unwrap() + .unwrap(); + + Ok(Response::new(item)) + } } impl GrpcConfig { diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 2adb3e30c..c057ade69 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -6,7 +6,7 @@ mod endpoint; #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; -pub use rdma_channel::RdmaChannel; +pub use rdma_channel::{ RdmaChannel, RdmaService }; pub use endpoint::Endpoint; #[cfg(feature = "tls")] pub use tls::ClientTlsConfig; diff --git a/tonic/src/transport/channel/rdma_channel.rs b/tonic/src/transport/channel/rdma_channel.rs index 4c28dc510..8fcb4a947 100644 --- a/tonic/src/transport/channel/rdma_channel.rs +++ b/tonic/src/transport/channel/rdma_channel.rs @@ -1,13 +1,6 @@ -use crate::{ - body::BoxBody, - transport::{self, BoxFuture}, -}; -use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, Rdma}; -use http::{Request, Response}; -use http_body::Body; +use async_rdma::{LocalMr, Rdma}; use std::{alloc::Layout, io, sync::Arc}; use tokio::net::ToSocketAddrs; -use tower::Service; /// A RDMA transport channel #[derive(Debug, Clone)] @@ -29,50 +22,22 @@ impl RdmaChannel { } } -impl Service> for RdmaChannel { - type Response = Response; - type Error = transport::Error; - type Future = BoxFuture; - fn poll_ready( - &mut self, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - std::task::Poll::Ready(Ok(())) - } - - // TODO: handle error - fn call(&mut self, req: Request) -> Self::Future { - let rdma = Arc::clone(&self.rdma); - Box::pin(async move { - tokio::spawn(rdma_send_req(Arc::clone(&rdma), req)); - - let (resp_mr, len) = rdma.receive_with_imm().await.unwrap(); - let len = len.unwrap() as usize; - let resp_vec = resp_mr.as_slice()[0..len].to_vec(); - - let body = hyper::Body::from(resp_vec); - let resp = http::Response::builder().body(body).unwrap(); - Ok(resp) - }) - } +/// RDMA service +#[async_trait::async_trait] +pub trait RdmaService { + /// + fn alloc_mr(&self) -> io::Result; + /// + async fn call(&mut self, req: LocalMr, len: usize) -> (LocalMr, Option); } -async fn rdma_send_req(rdma: Arc, mut req: Request) { - let mut req_header = req.uri().to_string().into_bytes(); - req_header.push(32u8); - let len_header = req_header.len(); - let mut req_mr = rdma - .alloc_local_mr(Layout::new::<[u8; MAX_MSG_LEN]>()) - .unwrap(); - req_mr.as_mut_slice()[0..len_header].copy_from_slice(req_header.as_slice()); - - let mut len = len_header; - while let Some(Ok(bytes)) = req.data().await { - let req_body = bytes.to_vec(); - let len_body = req_body.len(); - assert!(len + len_body <= MAX_MSG_LEN); // TODO: len > MAX_MSG_LEN - req_mr.as_mut_slice()[len..len + len_body].copy_from_slice(req_body.as_slice()); - len += len_body; +#[async_trait::async_trait] +impl RdmaService for RdmaChannel { + fn alloc_mr(&self) -> io::Result { + self.rdma.alloc_local_mr(Layout::new::<[u8; MAX_MSG_LEN]>()) + } + async fn call(&mut self, req: LocalMr, len: usize) -> (LocalMr, Option) { + self.rdma.send_with_imm(&req, len as u32).await.unwrap(); + self.rdma.receive_with_imm().await.unwrap() } - rdma.send_with_imm(&req_mr, len as u32).await.unwrap(); } diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 935a8dd9f..766f161eb 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -96,7 +96,7 @@ mod tls; #[doc(inline)] #[cfg(feature = "channel")] #[cfg_attr(docsrs, doc(cfg(feature = "channel")))] -pub use self::channel::{Channel, Endpoint, RdmaChannel}; +pub use self::channel::{Channel, Endpoint, RdmaChannel, RdmaService}; pub use tokio::net::ToSocketAddrs; pub use self::error::Error; #[doc(inline)] From 33726a3e5d77638ac638c738cdea66544f76735b Mon Sep 17 00:00:00 2001 From: moui66744 Date: Tue, 6 Dec 2022 23:22:11 -0500 Subject: [PATCH 11/11] health, reflection: update generated code. --- tonic-health/src/generated/grpc.health.v1.rs | 28 +++++++++++++++++++ .../src/generated/grpc.reflection.v1alpha.rs | 20 +++++++++++++ 2 files changed, 48 insertions(+) diff --git a/tonic-health/src/generated/grpc.health.v1.rs b/tonic-health/src/generated/grpc.health.v1.rs index c78bd0f0e..26c724300 100644 --- a/tonic-health/src/generated/grpc.health.v1.rs +++ b/tonic-health/src/generated/grpc.health.v1.rs @@ -161,6 +161,34 @@ pub mod health_client { self.inner.server_streaming(request.into_request(), path, codec).await } } + impl HealthClient + where + T: tonic::transport::RdmaService, + { + pub fn new_rdma(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub async fn check_rdma( + &mut self, + request: impl tonic::IntoRequest, + ) -> Result, tonic::Status> { + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/grpc.health.v1.Health/Check", + ); + self.inner.unary_rdma(request.into_request(), path, codec).await + } + pub async fn watch_rdma( + &mut self, + request: impl tonic::IntoRequest, + ) -> Result< + tonic::Response>, + tonic::Status, + > { + todo!() + } + } } /// Generated server implementations. pub mod health_server { diff --git a/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs b/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs index b01f594a6..7a04f303c 100644 --- a/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs +++ b/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs @@ -228,6 +228,26 @@ pub mod server_reflection_client { self.inner.streaming(request.into_streaming_request(), path, codec).await } } + impl ServerReflectionClient + where + T: tonic::transport::RdmaService, + { + pub fn new_rdma(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub async fn server_reflection_info_rdma( + &mut self, + request: impl tonic::IntoStreamingRequest< + Message = super::ServerReflectionRequest, + >, + ) -> Result< + tonic::Response>, + tonic::Status, + > { + todo!() + } + } } /// Generated server implementations. pub mod server_reflection_server {