From 2aa9500ad36c28b694cab671d742bfb01db5d75b Mon Sep 17 00:00:00 2001 From: Ben Cherry Date: Thu, 24 Oct 2024 14:29:01 -0700 Subject: [PATCH] RPC implementation + FFI (#461) * Add types * gen proto * callback * wip * wip * proto * error * wip * Builds and passes * close * string * string * Fixes * more fixes * initial example * wip * compiling * almost * close * somewhat working * working sample * cleanup * ffi building * delegate method * ffi * Fixes * remove dead code * Logging * fixes * fix * fix * logging * comment out' * handle * room handle * Revert "room handle" This reverts commit c62dc0e7ada53a0eb6e3254a8aaf3cbee0fe131a. * handle * cb * cleanup example * cleanup debug logs * cleanup other logs * remove some logging * SimplifyW * errors * fixes * fixes * fmt * waiter * 10k * sender->caller * update example * perform_rpc_request->perform_rpc * move methods to participant * fmt * cleanup * ms * opt * panics * fmt * remove conn * use webrtc uuid * Move waiter * refactor for readability * Simplify * uui * flat * fmt * fix? * fmt * unused imports * start time * better * fix * store rpc state in one spot * macro * opts * rusty * simplify * rm * fmt * remove initial wait * Revert "Merge remote-tracking branch 'origin/main' into bcherry/rpc-full" This reverts commit 961f3b69ed5d67a2767402d2b2648e47e4d379a1, reversing changes made to 73106cfb0211508f9564002c0eac1a5683d49d75. * fix * v * fix pb2 * proto * commit * p * 123 * p * stats * wip * wip * wip * import * add min version check * remove empty callbacks * fixes * rm * 1 * fmt * jr * fmt --- Cargo.lock | 1 + examples/Cargo.lock | 62 ++- examples/Cargo.toml | 1 + examples/basic_room/src/main.rs | 4 +- examples/rpc/Cargo.toml | 13 + examples/rpc/src/main.rs | 268 +++++++++++++ livekit-api/src/services/sip.rs | 3 +- livekit-ffi/generate_proto.sh | 3 +- livekit-ffi/protocol/ffi.proto | 17 +- livekit-ffi/protocol/rpc.proto | 81 ++++ livekit-ffi/src/conversion/participant.rs | 2 +- livekit-ffi/src/livekit.proto.rs | 132 ++++++- livekit-ffi/src/server/mod.rs | 7 +- livekit-ffi/src/server/participant.rs | 171 +++++++++ livekit-ffi/src/server/requests.rs | 76 +++- livekit-ffi/src/server/room.rs | 38 +- livekit-ffi/src/server/utils.rs | 2 +- livekit/Cargo.toml | 1 + livekit/src/prelude.rs | 4 +- livekit/src/room/mod.rs | 53 +++ .../src/room/participant/local_participant.rs | 355 ++++++++++++++++-- livekit/src/room/participant/mod.rs | 2 + livekit/src/room/participant/rpc.rs | 112 ++++++ livekit/src/rtc_engine/mod.rs | 44 +++ livekit/src/rtc_engine/rtc_session.rs | 52 +++ soxr-sys/src/lib.rs | 2 - 26 files changed, 1441 insertions(+), 65 deletions(-) create mode 100644 examples/rpc/Cargo.toml create mode 100644 examples/rpc/src/main.rs create mode 100644 livekit-ffi/protocol/rpc.proto create mode 100644 livekit-ffi/src/server/participant.rs create mode 100644 livekit/src/room/participant/rpc.rs diff --git a/Cargo.lock b/Cargo.lock index 6b2c90388..4c109ac0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1607,6 +1607,7 @@ dependencies = [ "log", "parking_lot", "prost 0.12.3", + "semver", "serde", "serde_json", "thiserror", diff --git a/examples/Cargo.lock b/examples/Cargo.lock index 94d6b303f..06ad21b8f 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -106,6 +106,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7eb209b1518d6bb87b283c20095f5228ecda460da70b44f0802523dea6da04" +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + [[package]] name = "android_log-sys" version = "0.3.1" @@ -658,11 +664,16 @@ checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" [[package]] name = "chrono" -version = "0.4.31" +version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", "num-traits", + "wasm-bindgen", + "windows-targets 0.52.0", ] [[package]] @@ -1842,6 +1853,29 @@ dependencies = [ "tokio-native-tls", ] +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "idna" version = "0.5.0" @@ -2153,6 +2187,7 @@ checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" name = "livekit" version = "0.6.0" dependencies = [ + "chrono", "futures-util", "lazy_static", "libwebrtc", @@ -2162,6 +2197,7 @@ dependencies = [ "log", "parking_lot", "prost 0.12.3", + "semver", "serde", "serde_json", "thiserror", @@ -3281,6 +3317,19 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "rpc_example" +version = "0.1.0" +dependencies = [ + "env_logger", + "livekit", + "livekit-api", + "log", + "rand", + "serde_json", + "tokio", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -4670,6 +4719,15 @@ dependencies = [ "windows-targets 0.42.2", ] +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.0", +] + [[package]] name = "windows-sys" version = "0.45.0" diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 0f0cb3d4d..c1336ef69 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -6,4 +6,5 @@ members = [ "wgpu_room", "webhooks", "api", + "rpc", ] diff --git a/examples/basic_room/src/main.rs b/examples/basic_room/src/main.rs index b023feb57..e82b784c2 100644 --- a/examples/basic_room/src/main.rs +++ b/examples/basic_room/src/main.rs @@ -24,9 +24,7 @@ async fn main() { .to_jwt() .unwrap(); - let (room, mut rx) = Room::connect(&url, &token, RoomOptions::default()) - .await - .unwrap(); + let (room, mut rx) = Room::connect(&url, &token, RoomOptions::default()).await.unwrap(); log::info!("Connected to room: {} - {}", room.name(), String::from(room.sid().await)); room.local_participant() diff --git a/examples/rpc/Cargo.toml b/examples/rpc/Cargo.toml new file mode 100644 index 000000000..85b6b2ab8 --- /dev/null +++ b/examples/rpc/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "rpc_example" +version = "0.1.0" +edition = "2021" + +[dependencies] +tokio = { version = "1", features = ["full"] } +env_logger = "0.10" +livekit = { path = "../../livekit", features = ["native-tls"]} +livekit-api = { path = "../../livekit-api"} +log = "0.4" +rand = "0.8" +serde_json = "1.0" diff --git a/examples/rpc/src/main.rs b/examples/rpc/src/main.rs new file mode 100644 index 000000000..152f380e1 --- /dev/null +++ b/examples/rpc/src/main.rs @@ -0,0 +1,268 @@ +use livekit::prelude::*; +use livekit_api::access_token; +use rand::Rng; +use serde_json::{json, Value}; +use std::env; +use std::sync::Once; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, +}; +use std::time::{Duration, Instant}; +use tokio::time::sleep; + +// Example usage of RPC calls between participants +// (In a real app, you'd have one participant per client/device such as an agent and a browser app) +// +// Try it with `LIVEKIT_URL= LIVEKIT_API_KEY= LIVEKIT_API_SECRET= cargo run` + +static START_TIME: Once = Once::new(); +static mut START_INSTANT: Option = None; + +fn get_start_time() -> Instant { + unsafe { + START_TIME.call_once(|| { + START_INSTANT = Some(Instant::now()); + }); + START_INSTANT.unwrap() + } +} + +fn elapsed_time() -> String { + let start = get_start_time(); + let elapsed = Instant::now().duration_since(start); + format!("+{:.3}s", elapsed.as_secs_f64()) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + + // Initialize START_TIME + get_start_time(); + + let url = env::var("LIVEKIT_URL").expect("LIVEKIT_URL is not set"); + let api_key = env::var("LIVEKIT_API_KEY").expect("LIVEKIT_API_KEY is not set"); + let api_secret = env::var("LIVEKIT_API_SECRET").expect("LIVEKIT_API_SECRET is not set"); + + let room_name = format!("rpc-test-{:x}", rand::thread_rng().gen::()); + println!("[{}] Connecting participants to room: {}", elapsed_time(), room_name); + + let (callers_room, greeters_room, math_genius_room) = tokio::try_join!( + connect_participant("caller", &room_name, &url, &api_key, &api_secret), + connect_participant("greeter", &room_name, &url, &api_key, &api_secret), + connect_participant("math-genius", &room_name, &url, &api_key, &api_secret) + )?; + + register_receiver_methods(&greeters_room, &math_genius_room).await; + + println!("\n\nRunning greeting example..."); + perform_greeting(&callers_room).await?; + + println!("\n\nRunning error handling example..."); + perform_division(&callers_room).await?; + + println!("\n\nRunning math example..."); + perform_square_root(&callers_room).await?; + sleep(Duration::from_secs(2)).await; + perform_quantum_hypergeometric_series(&callers_room).await?; + + println!("\n\nParticipants done, disconnecting..."); + callers_room.close().await?; + greeters_room.close().await?; + math_genius_room.close().await?; + + println!("Participants disconnected. Example completed."); + + Ok(()) +} + +async fn register_receiver_methods(greeters_room: &Arc, math_genius_room: &Arc) { + greeters_room.local_participant().register_rpc_method( + "arrival".to_string(), + |_, caller_identity, payload, _| { + Box::pin(async move { + println!( + "[{}] [Greeter] Oh {} arrived and said \"{}\"", + elapsed_time(), + caller_identity, + payload + ); + sleep(Duration::from_secs(2)).await; + Ok("Welcome and have a wonderful day!".to_string()) + }) + }, + ); + + math_genius_room.local_participant().register_rpc_method("square-root".to_string(), |_, caller_identity, payload, response_timeout_ms| { + Box::pin(async move { + let json_data: Value = serde_json::from_str(&payload).unwrap(); + let number = json_data["number"].as_f64().unwrap(); + println!( + "[{}] [Math Genius] I guess {} wants the square root of {}. I've only got {} seconds to respond but I think I can pull it off.", + elapsed_time(), + caller_identity, + number, + response_timeout_ms.as_secs() + ); + + println!("[{}] [Math Genius] *doing math*…", elapsed_time()); + sleep(Duration::from_secs(2)).await; + + let result = number.sqrt(); + println!("[{}] [Math Genius] Aha! It's {}", elapsed_time(), result); + Ok(json!({"result": result}).to_string()) + }) + }); + + math_genius_room.local_participant().register_rpc_method( + "divide".to_string(), + |_, caller_identity, payload, _| { + Box::pin(async move { + let json_data: Value = serde_json::from_str(&payload).unwrap(); + let dividend = json_data["dividend"].as_i64().unwrap(); + let divisor = json_data["divisor"].as_i64().unwrap(); + println!( + "[{}] [Math Genius] {} wants me to divide {} by {}.", + elapsed_time(), + caller_identity, + dividend, + divisor + ); + + let result = dividend / divisor; + println!("[{}] [Math Genius] The result is {}", elapsed_time(), result); + Ok(json!({"result": result}).to_string()) + }) + }, + ); +} + +async fn perform_greeting(room: &Arc) -> Result<(), Box> { + println!("[{}] Letting the greeter know that I've arrived", elapsed_time()); + match room + .local_participant() + .perform_rpc("greeter".to_string(), "arrival".to_string(), "Hello".to_string(), None) + .await + { + Ok(response) => { + println!("[{}] That's nice, the greeter said: \"{}\"", elapsed_time(), response) + } + Err(e) => println!("[{}] RPC call failed: {:?}", elapsed_time(), e), + } + Ok(()) +} + +async fn perform_square_root(room: &Arc) -> Result<(), Box> { + println!("[{}] What's the square root of 16?", elapsed_time()); + match room + .local_participant() + .perform_rpc( + "math-genius".to_string(), + "square-root".to_string(), + json!({"number": 16}).to_string(), + None, + ) + .await + { + Ok(response) => { + let parsed_response: Value = serde_json::from_str(&response)?; + println!("[{}] Nice, the answer was {}", elapsed_time(), parsed_response["result"]); + } + Err(e) => log::error!("[{}] RPC call failed: {:?}", elapsed_time(), e), + } + Ok(()) +} + +async fn perform_quantum_hypergeometric_series( + room: &Arc, +) -> Result<(), Box> { + println!("[{}] What's the quantum hypergeometric series of 42?", elapsed_time()); + match room + .local_participant() + .perform_rpc( + "math-genius".to_string(), + "quantum-hypergeometric-series".to_string(), + json!({"number": 42}).to_string(), + None, + ) + .await + { + Ok(response) => { + let parsed_response: Value = serde_json::from_str(&response)?; + println!("[{}] genius says {}!", elapsed_time(), parsed_response["result"]); + } + Err(e) => { + if e.code == RpcErrorCode::UnsupportedMethod as u32 { + println!("[{}] Aww looks like the genius doesn't know that one.", elapsed_time()); + return Ok(()); + } + log::error!("[{}] RPC error: {} (code: {})", elapsed_time(), e.message, e.code); + } + } + Ok(()) +} + +async fn perform_division(room: &Arc) -> Result<(), Box> { + println!("[{}] Let's try dividing 5 by 0", elapsed_time()); + match room + .local_participant() + .perform_rpc( + "math-genius".to_string(), + "divide".to_string(), + json!({"dividend": 5, "divisor": 0}).to_string(), + None, + ) + .await + { + Ok(response) => { + let parsed_response: Value = serde_json::from_str(&response)?; + println!("[{}] The result is {}", elapsed_time(), parsed_response["result"]); + } + Err(e) => { + println!("[{}] Oops! Dividing by zero didn't work. That's ok...", elapsed_time()); + log::error!("[{}] RPC error: {} (code: {})", elapsed_time(), e.message, e.code); + } + } + + Ok(()) +} + +async fn connect_participant( + identity: &str, + room_name: &str, + url: &str, + api_key: &str, + api_secret: &str, +) -> Result, Box> { + let token = access_token::AccessToken::with_api_key(api_key, api_secret) + .with_identity(identity) + .with_name(identity) + .with_grants(access_token::VideoGrants { + room_join: true, + room: room_name.to_string(), + ..Default::default() + }) + .to_jwt()?; + + println!("[{}] [{}] Connecting...", elapsed_time(), identity); + let (room, mut rx) = Room::connect(url, &token, RoomOptions::default()).await?; + + let room = Arc::new(room); + + tokio::spawn({ + let identity = identity.to_string(); + let room_clone = Arc::clone(&room); + async move { + while let Some(event) = rx.recv().await { + if let RoomEvent::Disconnected { .. } = event { + println!("[{}] Disconnected from room", identity); + break; + } + } + room_clone.close().await.ok(); + } + }); + + Ok(room) +} diff --git a/livekit-api/src/services/sip.rs b/livekit-api/src/services/sip.rs index aaf796f5d..e2517be9d 100644 --- a/livekit-api/src/services/sip.rs +++ b/livekit-api/src/services/sip.rs @@ -14,9 +14,8 @@ use livekit_protocol as proto; use std::collections::HashMap; -use std::ptr::null; -use crate::access_token::{SIPGrants, VideoGrants}; +use crate::access_token::SIPGrants; use crate::get_env_keys; use crate::services::twirp_client::TwirpClient; use crate::services::{ServiceBase, ServiceResult, LIVEKIT_PACKAGE}; diff --git a/livekit-ffi/generate_proto.sh b/livekit-ffi/generate_proto.sh index 19a1ccc15..b1d5e9d6a 100755 --- a/livekit-ffi/generate_proto.sh +++ b/livekit-ffi/generate_proto.sh @@ -28,4 +28,5 @@ protoc \ $PROTOCOL/video_frame.proto \ $PROTOCOL/audio_frame.proto \ $PROTOCOL/e2ee.proto \ - $PROTOCOL/stats.proto + $PROTOCOL/stats.proto \ + $PROTOCOL/rpc.proto diff --git a/livekit-ffi/protocol/ffi.proto b/livekit-ffi/protocol/ffi.proto index 32bfe4416..e8c1fda4a 100644 --- a/livekit-ffi/protocol/ffi.proto +++ b/livekit-ffi/protocol/ffi.proto @@ -23,6 +23,7 @@ import "track.proto"; import "room.proto"; import "video_frame.proto"; import "audio_frame.proto"; +import "rpc.proto"; // **How is the livekit-ffi working: // We refer as the ffi server the Rust server that is running the LiveKit client implementation, and we @@ -69,6 +70,7 @@ message FfiRequest { GetSessionStatsRequest get_session_stats = 12; PublishTranscriptionRequest publish_transcription = 13; PublishSipDtmfRequest publish_sip_dtmf = 14; + // Track CreateVideoTrackRequest create_video_track = 15; @@ -96,9 +98,14 @@ message FfiRequest { NewSoxResamplerRequest new_sox_resampler = 33; PushSoxResamplerRequest push_sox_resampler = 34; FlushSoxResamplerRequest flush_sox_resampler = 35; - SendChatMessageRequest send_chat_message = 36; EditChatMessageRequest edit_chat_message = 37; + + // RPC + PerformRpcRequest perform_rpc = 38; + RegisterRpcMethodRequest register_rpc_method = 39; + UnregisterRpcMethodRequest unregister_rpc_method = 40; + RpcMethodInvocationResponseRequest rpc_method_invocation_response = 41; } } @@ -147,8 +154,12 @@ message FfiResponse { NewSoxResamplerResponse new_sox_resampler = 33; PushSoxResamplerResponse push_sox_resampler = 34; FlushSoxResamplerResponse flush_sox_resampler = 35; - SendChatMessageResponse send_chat_message = 36; + // RPC + PerformRpcResponse perform_rpc = 37; + RegisterRpcMethodResponse register_rpc_method = 38; + UnregisterRpcMethodResponse unregister_rpc_method = 39; + RpcMethodInvocationResponseResponse rpc_method_invocation_response = 40; } } @@ -178,6 +189,8 @@ message FfiEvent { Panic panic = 20; PublishSipDtmfCallback publish_sip_dtmf = 21; SendChatMessageCallback chat_message = 22; + PerformRpcCallback perform_rpc = 23; + RpcMethodInvocationEvent rpc_method_invocation = 24; } } diff --git a/livekit-ffi/protocol/rpc.proto b/livekit-ffi/protocol/rpc.proto new file mode 100644 index 000000000..19fd75f11 --- /dev/null +++ b/livekit-ffi/protocol/rpc.proto @@ -0,0 +1,81 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package livekit.proto; +option csharp_namespace = "LiveKit.Proto"; + +message RpcError { + required uint32 code = 1; + required string message = 2; + optional string data = 3; +} + +// FFI Requests +message PerformRpcRequest { + required uint64 local_participant_handle = 1; + required string destination_identity = 2; + required string method = 3; + required string payload = 4; + optional uint32 response_timeout_ms = 5; +} + +message RegisterRpcMethodRequest { + required uint64 local_participant_handle = 1; + required string method = 2; +} + +message UnregisterRpcMethodRequest { + required uint64 local_participant_handle = 1; + required string method = 2; +} + +message RpcMethodInvocationResponseRequest { + required uint64 local_participant_handle = 1; + required uint64 invocation_id = 2; + optional string payload = 3; + optional RpcError error = 4; +} + +// FFI Responses +message PerformRpcResponse { + required uint64 async_id = 1; +} + +message RegisterRpcMethodResponse {} + +message UnregisterRpcMethodResponse {} + +message RpcMethodInvocationResponseResponse { + optional string error = 1; +} + +// FFI Callbacks +message PerformRpcCallback { + required uint64 async_id = 1; + optional string payload = 2; + optional RpcError error = 3; +} + +// FFI Events +message RpcMethodInvocationEvent { + required uint64 local_participant_handle = 1; + required uint64 invocation_id = 2; + required string method = 3; + required string request_id = 4; + required string caller_identity = 5; + required string payload = 6; + required uint32 response_timeout_ms = 7; +} diff --git a/livekit-ffi/src/conversion/participant.rs b/livekit-ffi/src/conversion/participant.rs index f8041f8d7..755f0bd06 100644 --- a/livekit-ffi/src/conversion/participant.rs +++ b/livekit-ffi/src/conversion/participant.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::{proto, server::room::FfiParticipant}; +use crate::{proto, server::participant::FfiParticipant}; use livekit::ParticipantKind; impl From<&FfiParticipant> for proto::ParticipantInfo { diff --git a/livekit-ffi/src/livekit.proto.rs b/livekit-ffi/src/livekit.proto.rs index 367e02741..a068e502d 100644 --- a/livekit-ffi/src/livekit.proto.rs +++ b/livekit-ffi/src/livekit.proto.rs @@ -3551,6 +3551,110 @@ impl AudioSourceType { } } } +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RpcError { + #[prost(uint32, required, tag="1")] + pub code: u32, + #[prost(string, required, tag="2")] + pub message: ::prost::alloc::string::String, + #[prost(string, optional, tag="3")] + pub data: ::core::option::Option<::prost::alloc::string::String>, +} +/// FFI Requests +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PerformRpcRequest { + #[prost(uint64, required, tag="1")] + pub local_participant_handle: u64, + #[prost(string, required, tag="2")] + pub destination_identity: ::prost::alloc::string::String, + #[prost(string, required, tag="3")] + pub method: ::prost::alloc::string::String, + #[prost(string, required, tag="4")] + pub payload: ::prost::alloc::string::String, + #[prost(uint32, optional, tag="5")] + pub response_timeout_ms: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RegisterRpcMethodRequest { + #[prost(uint64, required, tag="1")] + pub local_participant_handle: u64, + #[prost(string, required, tag="2")] + pub method: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnregisterRpcMethodRequest { + #[prost(uint64, required, tag="1")] + pub local_participant_handle: u64, + #[prost(string, required, tag="2")] + pub method: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RpcMethodInvocationResponseRequest { + #[prost(uint64, required, tag="1")] + pub local_participant_handle: u64, + #[prost(uint64, required, tag="2")] + pub invocation_id: u64, + #[prost(string, optional, tag="3")] + pub payload: ::core::option::Option<::prost::alloc::string::String>, + #[prost(message, optional, tag="4")] + pub error: ::core::option::Option, +} +/// FFI Responses +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PerformRpcResponse { + #[prost(uint64, required, tag="1")] + pub async_id: u64, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RegisterRpcMethodResponse { +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnregisterRpcMethodResponse { +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RpcMethodInvocationResponseResponse { + #[prost(string, optional, tag="1")] + pub error: ::core::option::Option<::prost::alloc::string::String>, +} +/// FFI Callbacks +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PerformRpcCallback { + #[prost(uint64, required, tag="1")] + pub async_id: u64, + #[prost(string, optional, tag="2")] + pub payload: ::core::option::Option<::prost::alloc::string::String>, + #[prost(message, optional, tag="3")] + pub error: ::core::option::Option, +} +/// FFI Events +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RpcMethodInvocationEvent { + #[prost(uint64, required, tag="1")] + pub local_participant_handle: u64, + #[prost(uint64, required, tag="2")] + pub invocation_id: u64, + #[prost(string, required, tag="3")] + pub method: ::prost::alloc::string::String, + #[prost(string, required, tag="4")] + pub request_id: ::prost::alloc::string::String, + #[prost(string, required, tag="5")] + pub caller_identity: ::prost::alloc::string::String, + #[prost(string, required, tag="6")] + pub payload: ::prost::alloc::string::String, + #[prost(uint32, required, tag="7")] + pub response_timeout_ms: u32, +} // **How is the livekit-ffi working: // We refer as the ffi server the Rust server that is running the LiveKit client implementation, and we // refer as the ffi client the foreign language that commumicates with the ffi server. (e.g Python SDK, Unity SDK, etc...) @@ -3582,7 +3686,7 @@ impl AudioSourceType { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FfiRequest { - #[prost(oneof="ffi_request::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37")] + #[prost(oneof="ffi_request::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41")] pub message: ::core::option::Option, } /// Nested message and enum types in `FfiRequest`. @@ -3666,13 +3770,22 @@ pub mod ffi_request { SendChatMessage(super::SendChatMessageRequest), #[prost(message, tag="37")] EditChatMessage(super::EditChatMessageRequest), + /// RPC + #[prost(message, tag="38")] + PerformRpc(super::PerformRpcRequest), + #[prost(message, tag="39")] + RegisterRpcMethod(super::RegisterRpcMethodRequest), + #[prost(message, tag="40")] + UnregisterRpcMethod(super::UnregisterRpcMethodRequest), + #[prost(message, tag="41")] + RpcMethodInvocationResponse(super::RpcMethodInvocationResponseRequest), } } /// This is the output of livekit_ffi_request function. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FfiResponse { - #[prost(oneof="ffi_response::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36")] + #[prost(oneof="ffi_response::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40")] pub message: ::core::option::Option, } /// Nested message and enum types in `FfiResponse`. @@ -3754,6 +3867,15 @@ pub mod ffi_response { FlushSoxResampler(super::FlushSoxResamplerResponse), #[prost(message, tag="36")] SendChatMessage(super::SendChatMessageResponse), + /// RPC + #[prost(message, tag="37")] + PerformRpc(super::PerformRpcResponse), + #[prost(message, tag="38")] + RegisterRpcMethod(super::RegisterRpcMethodResponse), + #[prost(message, tag="39")] + UnregisterRpcMethod(super::UnregisterRpcMethodResponse), + #[prost(message, tag="40")] + RpcMethodInvocationResponse(super::RpcMethodInvocationResponseResponse), } } /// To minimize complexity, participant events are not included in the protocol. @@ -3762,7 +3884,7 @@ pub mod ffi_response { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FfiEvent { - #[prost(oneof="ffi_event::Message", tags="1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22")] + #[prost(oneof="ffi_event::Message", tags="1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24")] pub message: ::core::option::Option, } /// Nested message and enum types in `FfiEvent`. @@ -3812,6 +3934,10 @@ pub mod ffi_event { PublishSipDtmf(super::PublishSipDtmfCallback), #[prost(message, tag="22")] ChatMessage(super::SendChatMessageCallback), + #[prost(message, tag="23")] + PerformRpc(super::PerformRpcCallback), + #[prost(message, tag="24")] + RpcMethodInvocation(super::RpcMethodInvocationEvent), } } /// Stop all rooms synchronously (Do we need async here?). diff --git a/livekit-ffi/src/server/mod.rs b/livekit-ffi/src/server/mod.rs index 461439ab9..cd015b5f8 100644 --- a/livekit-ffi/src/server/mod.rs +++ b/livekit-ffi/src/server/mod.rs @@ -13,7 +13,6 @@ // limitations under the License. use std::{ - collections::HashMap, error::Error, sync::{ atomic::{AtomicU64, Ordering}, @@ -27,10 +26,7 @@ use dashmap::{mapref::one::MappedRef, DashMap}; use downcast_rs::{impl_downcast, Downcast}; use livekit::webrtc::{native::audio_resampler::AudioResampler, prelude::*}; use parking_lot::{deadlock, Mutex}; -use tokio::{ - sync::{broadcast, oneshot}, - task::JoinHandle, -}; +use tokio::{sync::oneshot, task::JoinHandle}; use crate::{proto, proto::FfiEvent, FfiError, FfiHandleId, FfiResult, INVALID_HANDLE}; @@ -38,6 +34,7 @@ pub mod audio_source; pub mod audio_stream; pub mod colorcvt; pub mod logger; +pub mod participant; pub mod requests; pub mod resampler; pub mod room; diff --git a/livekit-ffi/src/server/participant.rs b/livekit-ffi/src/server/participant.rs new file mode 100644 index 000000000..fdf54bd60 --- /dev/null +++ b/livekit-ffi/src/server/participant.rs @@ -0,0 +1,171 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use livekit::prelude::*; +use std::time::Duration; +use tokio::sync::oneshot; + +use crate::{ + proto, + server::room::RoomInner, + server::{FfiHandle, FfiServer}, + FfiError, FfiHandleId, FfiResult, +}; + +#[derive(Clone)] +pub struct FfiParticipant { + pub handle: FfiHandleId, + pub participant: Participant, + pub room: Arc, +} + +impl FfiHandle for FfiParticipant {} + +impl FfiParticipant { + pub fn perform_rpc( + &self, + server: &'static FfiServer, + request: proto::PerformRpcRequest, + ) -> FfiResult { + let async_id = server.next_id(); + + let local = match &self.participant { + Participant::Local(local) => local.clone(), + Participant::Remote(_) => { + return Err(FfiError::InvalidRequest("Expected local participant".into())) + } + }; + + let handle = server.async_runtime.spawn(async move { + let result = local + .perform_rpc( + request.destination_identity.to_string(), + request.method, + request.payload, + request.response_timeout_ms, + ) + .await; + + let callback = proto::PerformRpcCallback { + async_id, + payload: result.as_ref().ok().cloned(), + error: result.as_ref().err().map(|error| proto::RpcError { + code: error.code, + message: error.message.clone(), + data: error.data.clone(), + }), + }; + + let _ = server.send_event(proto::ffi_event::Message::PerformRpc(callback)); + }); + server.watch_panic(handle); + Ok(proto::PerformRpcResponse { async_id }) + } + + pub fn register_rpc_method( + &self, + server: &'static FfiServer, + request: proto::RegisterRpcMethodRequest, + ) -> FfiResult { + let method = request.method.clone(); + + let local = match &self.participant { + Participant::Local(local) => local.clone(), + Participant::Remote(_) => { + return Err(FfiError::InvalidRequest("Expected local participant".into())) + } + }; + + let local_participant_handle = self.handle.clone(); + let room: Arc = self.room.clone(); + local.register_rpc_method( + method.clone(), + move |request_id, caller_identity, payload, response_timeout| { + Box::pin({ + let room = room.clone(); + let method = method.clone(); + async move { + forward_rpc_method_invocation( + server, + room, + local_participant_handle, + method, + request_id, + caller_identity, + payload, + response_timeout, + ) + .await + } + }) + }, + ); + Ok(proto::RegisterRpcMethodResponse {}) + } + + pub fn unregister_rpc_method( + &self, + server: &'static FfiServer, + request: proto::UnregisterRpcMethodRequest, + ) -> FfiResult { + let local = match &self.participant { + Participant::Local(local) => local.clone(), + Participant::Remote(_) => { + return Err(FfiError::InvalidRequest("Expected local participant".into())) + } + }; + + local.unregister_rpc_method(request.method); + + Ok(proto::UnregisterRpcMethodResponse {}) + } +} + +async fn forward_rpc_method_invocation( + server: &'static FfiServer, + room: Arc, + local_participant_handle: FfiHandleId, + method: String, + request_id: String, + caller_identity: ParticipantIdentity, + payload: String, + response_timeout: Duration, +) -> Result { + let (tx, rx) = oneshot::channel(); + let invocation_id = server.next_id(); + + let _ = server.send_event(proto::ffi_event::Message::RpcMethodInvocation( + proto::RpcMethodInvocationEvent { + local_participant_handle: local_participant_handle as u64, + invocation_id, + method, + request_id, + caller_identity: caller_identity.into(), + payload, + response_timeout_ms: response_timeout.as_millis() as u32, + }, + )); + + room.store_rpc_method_invocation_waiter(invocation_id, tx); + + rx.await.unwrap_or_else(|_| { + Err(RpcError { + code: RpcErrorCode::ApplicationError as u32, + message: "Error from method handler".to_string(), + data: None, + }) + }) +} diff --git a/livekit-ffi/src/server/requests.rs b/livekit-ffi/src/server/requests.rs index a61fe44ea..cb3b6a715 100644 --- a/livekit-ffi/src/server/requests.rs +++ b/livekit-ffi/src/server/requests.rs @@ -22,8 +22,10 @@ use livekit::{ use parking_lot::Mutex; use super::{ - audio_source, audio_stream, colorcvt, resampler, - room::{self, FfiParticipant, FfiPublication, FfiTrack}, + audio_source, audio_stream, colorcvt, + participant::FfiParticipant, + resampler, + room::{self, FfiPublication, FfiTrack}, video_source, video_stream, FfiError, FfiResult, FfiServer, }; use crate::proto; @@ -788,6 +790,58 @@ fn on_flush_sox_resampler( } } +fn on_perform_rpc( + server: &'static FfiServer, + request: proto::PerformRpcRequest, +) -> FfiResult { + let ffi_participant = + server.retrieve_handle::(request.local_participant_handle)?.clone(); + return ffi_participant.perform_rpc(server, request); +} + +fn on_register_rpc_method( + server: &'static FfiServer, + request: proto::RegisterRpcMethodRequest, +) -> FfiResult { + let ffi_participant = + server.retrieve_handle::(request.local_participant_handle)?.clone(); + return ffi_participant.register_rpc_method(server, request); +} + +fn on_unregister_rpc_method( + server: &'static FfiServer, + request: proto::UnregisterRpcMethodRequest, +) -> FfiResult { + let ffi_participant = + server.retrieve_handle::(request.local_participant_handle)?.clone(); + return ffi_participant.unregister_rpc_method(server, request); +} + +fn on_rpc_method_invocation_response( + server: &'static FfiServer, + request: proto::RpcMethodInvocationResponseRequest, +) -> FfiResult { + let ffi_participant = + server.retrieve_handle::(request.local_participant_handle)?.clone(); + + let room = ffi_participant.room; + + let mut error: Option = None; + + if let Some(waiter) = room.take_rpc_method_invocation_waiter(request.invocation_id) { + let result = if let Some(error) = request.error.clone() { + Err(RpcError { code: error.code, message: error.message, data: error.data }) + } else { + Ok(request.payload.unwrap_or_default()) + }; + let _ = waiter.send(result); + } else { + error = Some("No caller found".to_string()); + } + + Ok(proto::RpcMethodInvocationResponseResponse { error }) +} + #[allow(clippy::field_reassign_with_default)] // Avoid uggly format pub fn handle_request( server: &'static FfiServer, @@ -922,6 +976,24 @@ pub fn handle_request( server, flush_soxr, )?) } + proto::ffi_request::Message::PerformRpc(request) => { + proto::ffi_response::Message::PerformRpc(on_perform_rpc(server, request)?) + } + proto::ffi_request::Message::RegisterRpcMethod(request) => { + proto::ffi_response::Message::RegisterRpcMethod(on_register_rpc_method( + server, request, + )?) + } + proto::ffi_request::Message::UnregisterRpcMethod(request) => { + proto::ffi_response::Message::UnregisterRpcMethod(on_unregister_rpc_method( + server, request, + )?) + } + proto::ffi_request::Message::RpcMethodInvocationResponse(request) => { + proto::ffi_response::Message::RpcMethodInvocationResponse( + on_rpc_method_invocation_response(server, request)?, + ) + } }); Ok(res) diff --git a/livekit-ffi/src/server/room.rs b/livekit-ffi/src/server/room.rs index a6da9b634..1367897e8 100644 --- a/livekit-ffi/src/server/room.rs +++ b/livekit-ffi/src/server/room.rs @@ -13,29 +13,23 @@ // limitations under the License. use std::collections::HashMap; -use std::{collections::HashSet, slice, sync::Arc, time::Duration}; +use std::time::Duration; +use std::{collections::HashSet, slice, sync::Arc}; use livekit::prelude::*; -use livekit::{participant, track, ChatMessage}; +use livekit::ChatMessage; use parking_lot::Mutex; use tokio::sync::{broadcast, mpsc, oneshot, Mutex as AsyncMutex}; use tokio::task::JoinHandle; use super::FfiDataBuffer; -use crate::conversion::room; use crate::{ proto, + server::participant::FfiParticipant, server::{FfiHandle, FfiServer}, FfiError, FfiHandleId, FfiResult, }; -#[derive(Clone)] -pub struct FfiParticipant { - pub handle: FfiHandleId, - pub participant: Participant, - pub room: Arc, -} - #[derive(Clone)] pub struct FfiPublication { pub handle: FfiHandleId, @@ -50,7 +44,6 @@ pub struct FfiTrack { impl FfiHandle for FfiTrack {} impl FfiHandle for FfiPublication {} -impl FfiHandle for FfiParticipant {} impl FfiHandle for FfiRoom {} #[derive(Clone)] @@ -73,6 +66,9 @@ pub struct RoomInner { pending_unpublished_tracks: Mutex>, track_handle_lookup: Arc>>, + + // Used to forward RPC method invocation to the FfiClient and collect their results + rpc_method_invocation_waiters: Mutex>>>, } struct Handle { @@ -140,7 +136,6 @@ impl FfiRoom { let (data_tx, data_rx) = mpsc::unbounded_channel(); let (transcription_tx, transcription_rx) = mpsc::unbounded_channel(); let (dtmf_tx, dtmf_rx) = mpsc::unbounded_channel(); - let (close_tx, close_rx) = broadcast::channel(1); let handle_id = server.next_id(); @@ -153,6 +148,7 @@ impl FfiRoom { pending_published_tracks: Default::default(), pending_unpublished_tracks: Default::default(), track_handle_lookup: Default::default(), + rpc_method_invocation_waiters: Default::default(), }); let (local_info, remote_infos) = @@ -595,7 +591,6 @@ impl RoomInner { ) .await; let sent_message = res.as_ref().unwrap().clone(); - match res { Ok(message) => { let _ = server.send_event(proto::ffi_event::Message::ChatMessage( @@ -670,6 +665,21 @@ impl RoomInner { server.watch_panic(handle); proto::SendChatMessageResponse { async_id } } + + pub fn store_rpc_method_invocation_waiter( + &self, + invocation_id: u64, + waiter: oneshot::Sender>, + ) { + self.rpc_method_invocation_waiters.lock().insert(invocation_id, waiter); + } + + pub fn take_rpc_method_invocation_waiter( + &self, + invocation_id: u64, + ) -> Option>> { + return self.rpc_method_invocation_waiters.lock().remove(&invocation_id); + } } // Task used to publish data without blocking the client thread @@ -1086,6 +1096,7 @@ async fn forward_event( }, )); } + RoomEvent::ChatMessage { message, participant } => { let (sid, identity) = match participant { Some(p) => (Some(p.sid().to_string()), p.identity().to_string()), @@ -1097,6 +1108,7 @@ async fn forward_event( participant_identity: identity, })); } + RoomEvent::ConnectionStateChanged(state) => { let _ = send_event(proto::room_event::Message::ConnectionStateChanged( proto::ConnectionStateChanged { state: proto::ConnectionState::from(state).into() }, diff --git a/livekit-ffi/src/server/utils.rs b/livekit-ffi/src/server/utils.rs index 59f0e037d..7e95cd5fd 100644 --- a/livekit-ffi/src/server/utils.rs +++ b/livekit-ffi/src/server/utils.rs @@ -1,7 +1,7 @@ use livekit::prelude::{RoomEvent, Track, TrackSource}; use tokio::sync::{broadcast, mpsc}; -use super::room::FfiParticipant; +use super::participant::FfiParticipant; use crate::{server, FfiError, FfiHandleId}; pub async fn track_changed_trigger( diff --git a/livekit/Cargo.toml b/livekit/Cargo.toml index 269ffdf1a..6e150330b 100644 --- a/livekit/Cargo.toml +++ b/livekit/Cargo.toml @@ -41,3 +41,4 @@ thiserror = "1.0" lazy_static = "1.4" log = "0.4" chrono = "0.4.38" +semver = "1.0" diff --git a/livekit/src/prelude.rs b/livekit/src/prelude.rs index f53fdda17..a7b86d62a 100644 --- a/livekit/src/prelude.rs +++ b/livekit/src/prelude.rs @@ -14,7 +14,9 @@ pub use crate::{ id::*, - participant::{ConnectionQuality, LocalParticipant, Participant, RemoteParticipant}, + participant::{ + ConnectionQuality, LocalParticipant, Participant, RemoteParticipant, RpcError, RpcErrorCode, + }, publication::{LocalTrackPublication, RemoteTrackPublication, TrackPublication}, track::{ AudioTrack, LocalAudioTrack, LocalTrack, LocalVideoTrack, RemoteAudioTrack, RemoteTrack, diff --git a/livekit/src/room/mod.rs b/livekit/src/room/mod.rs index ab296b186..28d51902f 100644 --- a/livekit/src/room/mod.rs +++ b/livekit/src/room/mod.rs @@ -255,6 +255,30 @@ pub struct ChatMessage { pub generated: Option, } +#[derive(Debug, Clone)] +pub struct RpcRequest { + pub destination_identity: String, + pub id: String, + pub method: String, + pub payload: String, + pub response_timeout_ms: u32, + pub version: u32, +} + +#[derive(Debug, Clone)] +pub struct RpcResponse { + destination_identity: String, + request_id: String, + payload: Option, + error: Option, +} + +#[derive(Debug, Clone)] +pub struct RpcAck { + destination_identity: String, + request_id: String, +} + #[derive(Debug, Clone)] #[non_exhaustive] pub struct RoomSdkOptions { @@ -660,6 +684,34 @@ impl RoomSession { EngineEvent::SipDTMF { code, digit, participant_identity } => { self.handle_dtmf(code, digit, participant_identity); } + EngineEvent::RpcRequest { + caller_identity, + request_id, + method, + payload, + response_timeout_ms, + version, + } => { + if caller_identity.is_none() { + log::warn!("Received RPC request with null caller identity"); + return Ok(()); + } + self.local_participant + .handle_incoming_rpc_request( + caller_identity.unwrap(), + request_id, + method, + payload, + response_timeout_ms, + ) + .await; + } + EngineEvent::RpcResponse { request_id, payload, error } => { + self.local_participant.handle_incoming_rpc_response(request_id, payload, error); + } + EngineEvent::RpcAck { request_id } => { + self.local_participant.handle_incoming_rpc_ack(request_id); + } EngineEvent::SpeakersChanged { speakers } => self.handle_speakers_changed(speakers), EngineEvent::ConnectionQuality { updates } => { self.handle_connection_quality_update(updates) @@ -667,6 +719,7 @@ impl RoomSession { EngineEvent::LocalTrackSubscribed { track_sid } => { self.handle_track_subscribed(track_sid) } + _ => {} } Ok(()) diff --git a/livekit/src/room/participant/local_participant.rs b/livekit/src/room/participant/local_participant.rs index 6b334abb4..8c4596763 100644 --- a/livekit/src/room/participant/local_participant.rs +++ b/livekit/src/room/participant/local_participant.rs @@ -12,30 +12,39 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - collections::HashMap, - fmt::Debug, - sync::{self, Arc}, - time::Duration, -}; - -use chrono::{TimeZone, Utc}; - -use libwebrtc::{native::create_random_uuid, rtp_parameters::RtpEncodingParameters}; -use livekit_api::signal_client::SignalError; -use livekit_protocol as proto; -use livekit_runtime::timeout; -use parking_lot::Mutex; -use proto::request_response::Reason; +use std::{collections::HashMap, fmt::Debug, pin::Pin, sync::Arc, time::Duration}; use super::{ConnectionQuality, ParticipantInner, ParticipantKind}; use crate::{ e2ee::EncryptionType, options::{self, compute_video_encodings, video_layers_from_encodings, TrackPublishOptions}, prelude::*, + room::participant::rpc::{RpcError, RpcErrorCode, MAX_PAYLOAD_BYTES}, rtc_engine::{EngineError, RtcEngine}, - ChatMessage, DataPacket, SipDTMF, Transcription, + ChatMessage, DataPacket, RpcAck, RpcRequest, RpcResponse, SipDTMF, Transcription, }; +use chrono::Utc; +use futures_util::Future; + +use libwebrtc::{native::create_random_uuid, rtp_parameters::RtpEncodingParameters}; +use livekit_api::signal_client::SignalError; +use livekit_protocol as proto; +use livekit_runtime::timeout; +use parking_lot::Mutex; +use proto::request_response::Reason; +use semver::Version; +use tokio::sync::oneshot; + +type RpcHandler = Arc< + dyn Fn( + String, // request_id + ParticipantIdentity, // caller_identity + String, // payload + Duration, // response_timeout_ms + ) -> Pin> + Send>> + + Send + + Sync, +>; const REQUEST_TIMEOUT: Duration = Duration::from_secs(5); @@ -48,9 +57,26 @@ struct LocalEvents { local_track_unpublished: Mutex>, } +struct RpcState { + pending_acks: HashMap>, + pending_responses: HashMap>>, + handlers: HashMap, +} + +impl RpcState { + fn new() -> Self { + Self { + pending_acks: HashMap::new(), + pending_responses: HashMap::new(), + handlers: HashMap::new(), + } + } +} + struct LocalInfo { events: LocalEvents, encryption_type: EncryptionType, + rpc_state: Mutex, } #[derive(Clone)] @@ -82,7 +108,11 @@ impl LocalParticipant { ) -> Self { Self { inner: super::new_inner(rtc_engine, sid, identity, name, metadata, attributes, kind), - local: Arc::new(LocalInfo { events: LocalEvents::default(), encryption_type }), + local: Arc::new(LocalInfo { + events: LocalEvents::default(), + encryption_type, + rpc_state: Mutex::new(RpcState::new()), + }), } } @@ -438,16 +468,14 @@ impl LocalParticipant { let segments: Vec = packet .segments .into_iter() - .map( - (|segment| proto::TranscriptionSegment { - id: segment.id, - start_time: segment.start_time, - end_time: segment.end_time, - text: segment.text, - r#final: segment.r#final, - language: segment.language, - }), - ) + .map(|segment| proto::TranscriptionSegment { + id: segment.id, + start_time: segment.start_time, + end_time: segment.end_time, + text: segment.text, + r#final: segment.r#final, + language: segment.language, + }) .collect(); let transcription_packet = proto::Transcription { transcribed_participant_identity: packet.participant_identity, @@ -483,6 +511,76 @@ impl LocalParticipant { .map_err(Into::into) } + async fn publish_rpc_request(&self, rpc_request: RpcRequest) -> RoomResult<()> { + let destination_identities = vec![rpc_request.destination_identity]; + let rpc_request_message = proto::RpcRequest { + id: rpc_request.id, + method: rpc_request.method, + payload: rpc_request.payload, + response_timeout_ms: rpc_request.response_timeout_ms, + version: rpc_request.version, + ..Default::default() + }; + + let data = proto::DataPacket { + value: Some(proto::data_packet::Value::RpcRequest(rpc_request_message)), + destination_identities, + ..Default::default() + }; + + self.inner + .rtc_engine + .publish_data(&data, DataPacketKind::Reliable) + .await + .map_err(Into::into) + } + + async fn publish_rpc_response(&self, rpc_response: RpcResponse) -> RoomResult<()> { + let destination_identities = vec![rpc_response.destination_identity]; + let rpc_response_message = proto::RpcResponse { + request_id: rpc_response.request_id, + value: Some(match rpc_response.error { + Some(error) => proto::rpc_response::Value::Error(proto::RpcError { + code: error.code, + message: error.message, + data: error.data, + }), + None => proto::rpc_response::Value::Payload(rpc_response.payload.unwrap()), + }), + ..Default::default() + }; + + let data = proto::DataPacket { + value: Some(proto::data_packet::Value::RpcResponse(rpc_response_message)), + destination_identities: destination_identities.clone(), + ..Default::default() + }; + + self.inner + .rtc_engine + .publish_data(&data, DataPacketKind::Reliable) + .await + .map_err(Into::into) + } + + async fn publish_rpc_ack(&self, rpc_ack: RpcAck) -> RoomResult<()> { + let destination_identities = vec![rpc_ack.destination_identity]; + let rpc_ack_message = + proto::RpcAck { request_id: rpc_ack.request_id, ..Default::default() }; + + let data = proto::DataPacket { + value: Some(proto::data_packet::Value::RpcAck(rpc_ack_message)), + destination_identities: destination_identities.clone(), + ..Default::default() + }; + + self.inner + .rtc_engine + .publish_data(&data, DataPacketKind::Reliable) + .await + .map_err(Into::into) + } + pub fn get_track_publication(&self, sid: &TrackSid) -> Option { self.inner.track_publications.read().get(sid).map(|track| { if let TrackPublication::Local(local) = track { @@ -544,4 +642,207 @@ impl LocalParticipant { pub fn kind(&self) -> ParticipantKind { self.inner.info.read().kind } + + pub async fn perform_rpc( + &self, + destination_identity: String, + method: String, + payload: String, + response_timeout_ms: Option, + ) -> Result { + let response_timeout = Duration::from_millis(response_timeout_ms.unwrap_or(10000) as u64); + let max_round_trip_latency = Duration::from_millis(2000); + + if payload.len() > MAX_PAYLOAD_BYTES { + return Err(RpcError::built_in(RpcErrorCode::RequestPayloadTooLarge, None)); + } + + if let Some(server_info) = + self.inner.rtc_engine.session().signal_client().join_response().server_info + { + if !server_info.version.is_empty() { + let server_version = Version::parse(&server_info.version).unwrap(); + let min_required_version = Version::parse("1.8.0").unwrap(); + if server_version < min_required_version { + return Err(RpcError::built_in(RpcErrorCode::UnsupportedServer, None)); + } + } + } + + let id = create_random_uuid(); + let (ack_tx, ack_rx) = oneshot::channel(); + let (response_tx, response_rx) = oneshot::channel(); + + match self + .publish_rpc_request(RpcRequest { + destination_identity: destination_identity.clone(), + id: id.clone(), + method: method.clone(), + payload: payload.clone(), + response_timeout_ms: (response_timeout - max_round_trip_latency).as_millis() as u32, + version: 1, + }) + .await + { + Ok(_) => { + let mut rpc_state = self.local.rpc_state.lock(); + rpc_state.pending_acks.insert(id.clone(), ack_tx); + rpc_state.pending_responses.insert(id.clone(), response_tx); + } + Err(e) => { + log::error!("Failed to publish RPC request: {}", e); + return Err(RpcError::built_in(RpcErrorCode::SendFailed, Some(e.to_string()))); + } + } + + // Wait for ack timeout + match tokio::time::timeout(max_round_trip_latency, ack_rx).await { + Err(_) => { + let mut rpc_state = self.local.rpc_state.lock(); + rpc_state.pending_acks.remove(&id); + rpc_state.pending_responses.remove(&id); + return Err(RpcError::built_in(RpcErrorCode::ConnectionTimeout, None)); + } + Ok(_) => { + // Ack received, continue to wait for response + } + } + + // Wait for response timout + let response = match tokio::time::timeout(response_timeout, response_rx).await { + Err(_) => { + self.local.rpc_state.lock().pending_responses.remove(&id); + return Err(RpcError::built_in(RpcErrorCode::ResponseTimeout, None)); + } + Ok(result) => result, + }; + + match response { + Err(_) => { + // Something went wrong locally + Err(RpcError::built_in(RpcErrorCode::RecipientDisconnected, None)) + } + Ok(Err(e)) => { + // RPC error from remote, forward it + Err(e) + } + Ok(Ok(payload)) => { + // Successful response + Ok(payload) + } + } + } + + pub fn register_rpc_method( + &self, + method: String, + handler: impl Fn( + String, + ParticipantIdentity, + String, + Duration, + ) -> Pin> + Send>> + + Send + + Sync + + 'static, + ) { + self.local.rpc_state.lock().handlers.insert(method, Arc::new(handler)); + } + + pub fn unregister_rpc_method(&self, method: String) { + self.local.rpc_state.lock().handlers.remove(&method); + } + + pub(crate) fn handle_incoming_rpc_ack(&self, request_id: String) { + let mut rpc_state = self.local.rpc_state.lock(); + if let Some(tx) = rpc_state.pending_acks.remove(&request_id) { + let _ = tx.send(()); + } else { + log::error!("Ack received for unexpected RPC request: {}", request_id); + } + } + + pub(crate) fn handle_incoming_rpc_response( + &self, + request_id: String, + payload: Option, + error: Option, + ) { + let mut rpc_state = self.local.rpc_state.lock(); + if let Some(tx) = rpc_state.pending_responses.remove(&request_id) { + let _ = tx.send(match error { + Some(e) => Err(RpcError::from_proto(e)), + None => Ok(payload.unwrap_or_default()), + }); + } else { + log::error!("Response received for unexpected RPC request: {}", request_id); + } + } + + pub(crate) async fn handle_incoming_rpc_request( + &self, + caller_identity: ParticipantIdentity, + request_id: String, + method: String, + payload: String, + response_timeout_ms: u32, + ) { + if let Err(e) = self + .publish_rpc_ack(RpcAck { + destination_identity: caller_identity.to_string(), + request_id: request_id.clone(), + }) + .await + { + log::error!("Failed to publish RPC ACK: {:?}", e); + } + + let handler = self.local.rpc_state.lock().handlers.get(&method).cloned(); + + let caller_identity_2 = caller_identity.clone(); + let request_id_2 = request_id.clone(); + + let response = match handler { + Some(handler) => { + match tokio::task::spawn(async move { + handler( + request_id.clone(), + caller_identity.clone(), + payload.clone(), + Duration::from_millis(response_timeout_ms as u64), + ) + .await + }) + .await + { + Ok(result) => result, + Err(e) => { + log::error!("RPC method handler returned an error: {:?}", e); + Err(RpcError::built_in(RpcErrorCode::ApplicationError, None)) + } + } + } + None => Err(RpcError::built_in(RpcErrorCode::UnsupportedMethod, None)), + }; + + let (payload, error) = match response { + Ok(response_payload) if response_payload.len() <= MAX_PAYLOAD_BYTES => { + (Some(response_payload), None) + } + Ok(_) => (None, Some(RpcError::built_in(RpcErrorCode::ResponsePayloadTooLarge, None))), + Err(e) => (None, Some(e.into())), + }; + + if let Err(e) = self + .publish_rpc_response(RpcResponse { + destination_identity: caller_identity_2.to_string(), + request_id: request_id_2, + payload, + error: error.map(|e| e.to_proto()), + }) + .await + { + log::error!("Failed to publish RPC response: {:?}", e); + } + } } diff --git a/livekit/src/room/participant/mod.rs b/livekit/src/room/participant/mod.rs index c535111cf..f8f7a7460 100644 --- a/livekit/src/room/participant/mod.rs +++ b/livekit/src/room/participant/mod.rs @@ -22,10 +22,12 @@ use crate::{prelude::*, rtc_engine::RtcEngine}; mod local_participant; mod remote_participant; +mod rpc; use crate::room::utils; pub use local_participant::*; pub use remote_participant::*; +pub use rpc::*; #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum ConnectionQuality { diff --git a/livekit/src/room/participant/rpc.rs b/livekit/src/room/participant/rpc.rs new file mode 100644 index 000000000..3ddb0f545 --- /dev/null +++ b/livekit/src/room/participant/rpc.rs @@ -0,0 +1,112 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +use livekit_protocol::RpcError as RpcError_Proto; + +/// Specialized error handling for RPC methods. +/// +/// Instances of this type, when thrown in a method handler, will have their `message` +/// serialized and sent across the wire. The caller will receive an equivalent error on the other side. +/// +/// Build-in types are included but developers may use any string, with a max length of 256 bytes. +#[derive(Debug, Clone)] +pub struct RpcError { + pub code: u32, + pub message: String, + pub data: Option, +} + +impl RpcError { + pub const MAX_MESSAGE_BYTES: usize = 256; + pub const MAX_DATA_BYTES: usize = 15360; // 15 KB + + /// Creates an error object with the given code and message, plus an optional data payload. + /// + /// If thrown in an RPC method handler, the error will be sent back to the caller. + /// + /// Error codes 1001-1999 are reserved for built-in errors (see RpcErrorCode for their meanings). + pub fn new(code: u32, message: String, data: Option) -> Self { + Self { + code, + message: truncate_bytes(&message, Self::MAX_MESSAGE_BYTES), + data: data.map(|d| truncate_bytes(&d, Self::MAX_DATA_BYTES)), + } + } + + pub fn from_proto(proto: RpcError_Proto) -> Self { + Self::new(proto.code, proto.message, Some(proto.data)) + } + + pub fn to_proto(&self) -> RpcError_Proto { + RpcError_Proto { + code: self.code, + message: self.message.clone(), + data: self.data.clone().unwrap_or_default(), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub enum RpcErrorCode { + ApplicationError = 1500, + ConnectionTimeout = 1501, + ResponseTimeout = 1502, + RecipientDisconnected = 1503, + ResponsePayloadTooLarge = 1504, + SendFailed = 1505, + + UnsupportedMethod = 1400, + RecipientNotFound = 1401, + RequestPayloadTooLarge = 1402, + UnsupportedServer = 1403, +} + +impl RpcErrorCode { + pub(crate) fn message(&self) -> &'static str { + match self { + Self::ApplicationError => "Application error in method handler", + Self::ConnectionTimeout => "Connection timeout", + Self::ResponseTimeout => "Response timeout", + Self::RecipientDisconnected => "Recipient disconnected", + Self::ResponsePayloadTooLarge => "Response payload too large", + Self::SendFailed => "Failed to send", + + Self::UnsupportedMethod => "Method not supported at destination", + Self::RecipientNotFound => "Recipient not found", + Self::RequestPayloadTooLarge => "Request payload too large", + Self::UnsupportedServer => "RPC not supported by server", + } + } +} + +impl RpcError { + /// Creates an error object from the code, with an auto-populated message. + pub(crate) fn built_in(code: RpcErrorCode, data: Option) -> Self { + Self::new(code as u32, code.message().to_string(), data) + } +} + +/// Maximum payload size in bytes +pub const MAX_PAYLOAD_BYTES: usize = 15360; // 15 KB + +/// Calculate the byte length of a string +pub(crate) fn byte_length(s: &str) -> usize { + s.as_bytes().len() +} + +/// Truncate a string to a maximum number of bytes +pub(crate) fn truncate_bytes(s: &str, max_bytes: usize) -> String { + if byte_length(s) <= max_bytes { + return s.to_string(); + } + + let mut result = String::new(); + for c in s.chars() { + if byte_length(&(result.clone() + &c.to_string())) > max_bytes { + break; + } + result.push(c); + } + result +} diff --git a/livekit/src/rtc_engine/mod.rs b/livekit/src/rtc_engine/mod.rs index efaf28740..d848571b6 100644 --- a/livekit/src/rtc_engine/mod.rs +++ b/livekit/src/rtc_engine/mod.rs @@ -113,6 +113,22 @@ pub enum EngineEvent { code: u32, digit: Option, }, + RpcRequest { + caller_identity: Option, + request_id: String, + method: String, + payload: String, + response_timeout_ms: u32, + version: u32, + }, + RpcResponse { + request_id: String, + payload: Option, + error: Option, + }, + RpcAck { + request_id: String, + }, SpeakersChanged { speakers: Vec, }, @@ -462,6 +478,34 @@ impl EngineInner { segments, }); } + SessionEvent::SipDTMF { participant_identity, code, digit } => { + let _ = + self.engine_tx.send(EngineEvent::SipDTMF { participant_identity, code, digit }); + } + SessionEvent::RpcRequest { + caller_identity, + request_id, + method, + payload, + response_timeout_ms, + version, + } => { + let _ = self.engine_tx.send(EngineEvent::RpcRequest { + caller_identity, + request_id, + method, + payload, + response_timeout_ms, + version, + }); + } + SessionEvent::RpcResponse { request_id, payload, error } => { + let _ = + self.engine_tx.send(EngineEvent::RpcResponse { request_id, payload, error }); + } + SessionEvent::RpcAck { request_id } => { + let _ = self.engine_tx.send(EngineEvent::RpcAck { request_id }); + } SessionEvent::MediaTrack { track, stream, transceiver } => { let _ = self.engine_tx.send(EngineEvent::MediaTrack { track, stream, transceiver }); } diff --git a/livekit/src/rtc_engine/rtc_session.rs b/livekit/src/rtc_engine/rtc_session.rs index 58168359b..5377be510 100644 --- a/livekit/src/rtc_engine/rtc_session.rs +++ b/livekit/src/rtc_engine/rtc_session.rs @@ -96,6 +96,22 @@ pub enum SessionEvent { code: u32, digit: Option, }, + RpcRequest { + caller_identity: Option, + request_id: String, + method: String, + payload: String, + response_timeout_ms: u32, + version: u32, + }, + RpcResponse { + request_id: String, + payload: Option, + error: Option, + }, + RpcAck { + request_id: String, + }, MediaTrack { track: MediaStreamTrack, stream: MediaStream, @@ -661,6 +677,42 @@ impl SessionInner { segments, }); } + proto::data_packet::Value::RpcRequest(rpc_request) => { + let caller_identity = data + .participant_identity + .is_empty() + .not() + .then_some(data.participant_identity.clone()) + .map(|s| s.try_into().unwrap()); + let _ = self.emitter.send(SessionEvent::RpcRequest { + caller_identity, + request_id: rpc_request.id.clone(), + method: rpc_request.method.clone(), + payload: rpc_request.payload.clone(), + response_timeout_ms: rpc_request.response_timeout_ms, + version: rpc_request.version, + }); + } + proto::data_packet::Value::RpcResponse(rpc_response) => { + let _ = self.emitter.send(SessionEvent::RpcResponse { + request_id: rpc_response.request_id.clone(), + payload: rpc_response.value.as_ref().and_then(|v| match v { + proto::rpc_response::Value::Payload(payload) => { + Some(payload.clone()) + } + _ => None, + }), + error: rpc_response.value.as_ref().and_then(|v| match v { + proto::rpc_response::Value::Error(error) => Some(error.clone()), + _ => None, + }), + }); + } + proto::data_packet::Value::RpcAck(rpc_ack) => { + let _ = self.emitter.send(SessionEvent::RpcAck { + request_id: rpc_ack.request_id.clone(), + }); + } proto::data_packet::Value::ChatMessage(message) => { let _ = self.emitter.send(SessionEvent::ChatMessage { participant_identity: ParticipantIdentity( diff --git a/soxr-sys/src/lib.rs b/soxr-sys/src/lib.rs index 545f21c78..1a59d4db3 100644 --- a/soxr-sys/src/lib.rs +++ b/soxr-sys/src/lib.rs @@ -6,8 +6,6 @@ include!("soxr.rs"); #[cfg(test)] mod tests { use super::*; - use std::fs::File; - use std::io::{Read, Seek, SeekFrom}; #[test] fn it_works() {