feat: RC stream #118

Open · wants to merge 5 commits into master
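This pull request has no written description, so the summary below is reconstructed from the diff. It appears to add an `RCStream` type that wraps an established `Rdma` connection (obtained via `.into()`) and exposes a stream-style send/receive API, backed by new agent plumbing (`submit_send_data`, `RequestSubmitted`) in `src/agent.rs`. A minimal usage sketch, with method names taken verbatim from the updated examples (the helper function here is illustrative):

```rust
use async_rdma::{LocalMrWriteAccess, MrAccess, RCStream, Rdma};
use std::{alloc::Layout, io, io::Write};

// Sketch only: mirrors the flow in examples/client.rs below.
async fn demo(rdma: Rdma) -> io::Result<()> {
    // wrap an established reliable connection in a stream
    let mut stream: RCStream = rdma.into();
    // send one 8-byte datagram
    let mut lmr = stream.alloc_local_mr(Layout::new::<[u8; 8]>())?;
    let _num = lmr.as_mut_slice().write(&[1_u8; 8])?;
    stream.send_lmr(lmr).await?;
    // receive one 8-byte datagram (`recieve_lmr` is the method name as spelled in this PR)
    let mut lmr_vec = stream.recieve_lmr(8).await?;
    let lmr = lmr_vec.pop().expect("at least one local MR");
    assert_eq!(lmr.length(), 8);
    Ok(())
}
```
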
38 changes: 35 additions & 3 deletions examples/client.rs
@@ -3,13 +3,13 @@
//!
//! You can try this example by running:
//!
//! cargo run --example server
//! cargo run --example server <server_ip> <port>
//!
//! And then start client in another terminal by running:
//!
//! cargo run --example client
//! cargo run --example client <server_ip> <port>

use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, Rdma, RdmaBuilder};
use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, MrAccess, RCStream, Rdma, RdmaBuilder};
use std::{
alloc::Layout,
env,
@@ -118,6 +118,35 @@ async fn request_then_write_cas(rdma: &Rdma) -> io::Result<()> {
Ok(())
}

async fn rcstream_send(stream: &mut RCStream) -> io::Result<()> {
for i in 0..10 {
// alloc 8 bytes local memory
let mut lmr = stream.alloc_local_mr(Layout::new::<[u8; 8]>())?;
// write data into lmr
let _num = lmr.as_mut_slice().write(&[i as u8; 8])?;
// send data in mr to the remote end
stream.send_lmr(lmr).await?;
println!("stream send datagram {} ", i);
}
Ok(())
}

async fn rcstream_recv(stream: &mut RCStream) -> io::Result<()> {
for i in 0..10 {
// receive data from the remote end
let mut lmr_vec = stream.recieve_lmr(8).await?;
println!("stream receive datagram {}", i);
// check the length of the received data
assert_eq!(lmr_vec.len(), 1);
let lmr = lmr_vec.pop().unwrap();
assert_eq!(lmr.length(), 8);
let buff = *(lmr.as_slice());
// check the data
assert_eq!(buff, [i as u8; 8]);
}
Ok(())
}

#[tokio::main]
async fn main() {
println!("client start");
@@ -153,5 +182,8 @@ async fn main() {
request_then_write_with_imm(&rdma).await.unwrap();
request_then_write_cas(&rdma).await.unwrap();
}
let mut stream: RCStream = rdma.into();
rcstream_send(&mut stream).await.unwrap();
rcstream_recv(&mut stream).await.unwrap();
println!("client done");
}
39 changes: 36 additions & 3 deletions examples/server.rs
@@ -3,14 +3,15 @@
//!
//! You can try this example by running:
//!
//! cargo run --example server
//! cargo run --example server <server_ip> <port>
//!
//! And start client in another terminal by running:
//!
//! cargo run --example client
//! cargo run --example client <server_ip> <port>

use async_rdma::{LocalMrReadAccess, Rdma, RdmaBuilder};
use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, MrAccess, RCStream, Rdma, RdmaBuilder};
use clippy_utilities::Cast;
use std::io::Write;
use std::{alloc::Layout, env, io, process::exit};

/// receive data from client
@@ -90,6 +91,35 @@ async fn receive_mr_after_being_written_by_cas(rdma: &Rdma) -> io::Result<()> {
Ok(())
}

async fn rcstream_send(stream: &mut RCStream) -> io::Result<()> {
for i in 0..10 {
// alloc 8 bytes local memory
let mut lmr = stream.alloc_local_mr(Layout::new::<[u8; 8]>())?;
// write data into lmr
let _num = lmr.as_mut_slice().write(&[i as u8; 8])?;
// send data in mr to the remote end
stream.send_lmr(lmr).await?;
println!("stream send datagram {} ", i);
}
Ok(())
}

async fn rcstream_recv(stream: &mut RCStream) -> io::Result<()> {
for i in 0..10 {
// receive data from the remote end
let mut lmr_vec = stream.recieve_lmr(8).await?;
println!("stream receive datagram {}", i);
// check the length of the received data
assert_eq!(lmr_vec.len(), 1);
let lmr = lmr_vec.pop().unwrap();
assert_eq!(lmr.length(), 8);
let buff = *(lmr.as_slice());
// check the data
assert_eq!(buff, [i as u8; 8]);
}
Ok(())
}

#[tokio::main]
async fn main() {
println!("server start");
@@ -129,5 +159,8 @@
.unwrap();
receive_mr_after_being_written_by_cas(&rdma).await.unwrap();
}
let mut stream: RCStream = rdma.into();
rcstream_recv(&mut stream).await.unwrap();
rcstream_send(&mut stream).await.unwrap();
println!("server done");
}
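Taken together, the two examples form a simple ping-pong: the client sends ten 8-byte datagrams and then receives ten, while the server does the reverse. A sketch of driving both flows from one integration test, assuming the `RdmaBuilder::default().connect(..)`/`.listen(..)` setup used elsewhere in these examples and an RDMA-capable device (or SoftRoCE); the address, delay, and test name are illustrative:

```rust
// Sketch only: reuses rcstream_send/rcstream_recv from the examples above.
#[tokio::test]
async fn rcstream_ping_pong() -> io::Result<()> {
    let addr = "127.0.0.1:19000"; // illustrative address
    let server = tokio::spawn(async move {
        let rdma = RdmaBuilder::default().listen(addr).await?;
        let mut stream: RCStream = rdma.into();
        rcstream_recv(&mut stream).await?;
        rcstream_send(&mut stream).await
    });
    // crude wait for the listener; assumes no readiness signal is exposed
    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
    let rdma = RdmaBuilder::default().connect(addr).await?;
    let mut stream: RCStream = rdma.into();
    rcstream_send(&mut stream).await?;
    rcstream_recv(&mut stream).await?;
    server.await.unwrap()
}
```
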
103 changes: 96 additions & 7 deletions src/agent.rs
@@ -1,7 +1,7 @@
use crate::context::Context;
use crate::hashmap_extension::HashMapExtension;
use crate::ibv_event_listener::IbvEventListener;
use crate::queue_pair::MAX_RECV_WR;
use crate::queue_pair::{QPSendOwn, QueuePairOp, QueuePairOpsInflight, MAX_RECV_WR};
use crate::rmr_manager::RemoteMrManager;
use crate::RemoteMrReadAccess;
use crate::{
@@ -249,6 +249,23 @@ impl Agent {
Ok(())
}

/// Send the contents of the `lms` to the other side
pub(crate) async fn submit_send_data(
&self,
lms: Vec<LocalMr>,
imm: Option<u32>,
) -> io::Result<RequestSubmitted<QPSendOwn<LocalMr>>> {
let lm_len = lms.iter().map(|lm| lm.length()).sum::<usize>();
assert!(lm_len <= self.max_msg_len());
let kind = RequestKind::SendData(SendDataRequest { len: lm_len });
let req_submitted = self
.inner
// SAFETY: The input range is always valid
.submit_send_request_append_data(kind, lms, imm)
.await?;
Ok(req_submitted)
}

/// Receive content sent from the other side and stored in the `LocalMr`
pub(crate) async fn receive_data(&self) -> io::Result<(LocalMr, Option<u32>)> {
let (lmr, len, imm) = self
@@ -689,6 +706,45 @@ impl AgentInner {
}
}

/// Submit a send request with the given data appended
async fn submit_send_request_append_data(
&self,
kind: RequestKind,
data: Vec<LocalMr>,
imm: Option<u32>,
) -> io::Result<RequestSubmitted<QPSendOwn<LocalMr>>> {
let data_len: usize = data.iter().map(|l| l.length()).sum();
assert!(data_len <= self.max_sr_data_len);
let (tx, rx) = channel(2);
let req_id = self
.response_waits
.lock()
.insert_until_success(tx, AgentRequestId::new);
let req = Request {
request_id: req_id,
kind,
};
// SAFETY: ?
// TODO: check safety
let mut header_buf = self
.allocator
// alignment 1 is always correct
.alloc_zeroed_default(unsafe {
&Layout::from_size_align_unchecked(*REQUEST_HEADER_MAX_LEN, 1)
})?;
// SAFETY: the mr is writeable here without cancel safety issue
let cursor = Cursor::new(unsafe { header_buf.as_mut_slice_unchecked() });
let message = Message::Request(req);
// FIXME: serialize update
bincode::serialize_into(cursor, &message)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
// SAFETY: The input range is always valid
let mut lmrs = vec![header_buf];
lmrs.extend(data);
let inflight = self.qp.submit_send_sge(lmrs, imm).await?;
Ok(RequestSubmitted::new(inflight, rx))
}

/// Send a response to the other side
async fn send_response(&self, response: Response) -> io::Result<()> {
// SAFETY: ?
@@ -850,7 +906,7 @@ struct AllocMRRequest {

/// Response to the alloc MR request
#[derive(Debug, Serialize, Deserialize)]
struct AllocMRResponse {
pub(crate) struct AllocMRResponse {
/// The token to access the MR
token: MrToken,
}
@@ -864,7 +920,7 @@ struct ReleaseMRRequest {

/// Response to the release MR request
#[derive(Debug, Serialize, Deserialize)]
struct ReleaseMRResponse {
pub(crate) struct ReleaseMRResponse {
/// The status of the operation
status: usize,
}
@@ -887,7 +943,7 @@ struct SendMRRequest {

/// Response to the request of sending MR
#[derive(Debug, Serialize, Deserialize)]
struct SendMRResponse {
pub(crate) struct SendMRResponse {
/// The kinds of Response to the request of sending MR
kind: SendMRResponseKind,
}
@@ -911,9 +967,9 @@ struct SendDataRequest {

/// Response to the request of sending data
#[derive(Debug, Serialize, Deserialize)]
struct SendDataResponse {
pub(crate) struct SendDataResponse {
/// response status
status: usize,
pub(crate) status: usize,
}

/// Request type enumeration
@@ -941,7 +997,7 @@ struct Request {
/// Response type enumeration
#[derive(Serialize, Deserialize, Debug)]
#[allow(variant_size_differences)]
enum ResponseKind {
pub(crate) enum ResponseKind {
/// Allocate MR
AllocMR(AllocMRResponse),
/// Release MR
@@ -969,3 +1025,36 @@ enum Message {
/// Response
Response(Response),
}

/// Queue pair operation submitted to the work queue, waiting for its work completion and the peer's response
#[derive(Debug)]
pub(crate) struct RequestSubmitted<Op: QueuePairOp> {
/// the operation of the request
inflight: QueuePairOpsInflight<Op>,
/// receiver for the response of the request
rx: Receiver<Result<ResponseKind, io::Error>>,
}

impl<Op: QueuePairOp> RequestSubmitted<Op> {
/// Create a new `RequestSubmitted`
fn new(
inflight: QueuePairOpsInflight<Op>,
rx: Receiver<Result<ResponseKind, io::Error>>,
) -> Self {
Self { inflight, rx }
}

/// Wait for the response of the request
pub(crate) async fn response(mut self) -> io::Result<ResponseKind> {
let _ = self.inflight.result().await?;
match tokio::time::timeout(RESPONSE_TIMEOUT, self.rx.recv()).await {
Ok(resp) => {
resp.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "agent is dropped"))?
}
Err(_) => Err(io::Error::new(
io::ErrorKind::TimedOut,
"Timeout for waiting for a response.",
)),
}
}
}
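
A sketch of how a caller inside the crate (for example the `RCStream` send path) might drive the new submit/await split; only `submit_send_data` and `RequestSubmitted::response` come from this diff, while the `ResponseKind::SendData` variant name and the meaning of `status == 0` are assumptions:

```rust
// Sketch only: submit the send WR, then wait for the WC and the peer's response.
async fn send_lmrs(agent: &Agent, lms: Vec<LocalMr>) -> io::Result<()> {
    // the serialized request header is prepended to `lms` as an extra SGE
    let submitted = agent.submit_send_data(lms, None).await?;
    match submitted.response().await? {
        // variant name assumed; see ResponseKind above
        ResponseKind::SendData(resp) if resp.status == 0 => Ok(()),
        other => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("unexpected or failed response: {other:?}"),
        )),
    }
}
```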