Skip to content

Commit

Permalink
Fix: pass mid-connection WS events to task (#269)
Browse files Browse the repository at this point in the history
Discord will send `GatewayEvent::Speaking` (opcode 5) messages
after the Hello+Ready exchange, but will happily interleave
them with crypto mode negotiation. We were previously not expecting
such messages and dropping them -- this hurts receive-based bots'
ability to map SSRCs to UserIds when joining a call with existing
users.

This PR feeds all unexpected messages into the WS task directly,
which will handle them once all tasks are fully started.
  • Loading branch information
FelixMcFelix authored Nov 26, 2024
1 parent 71535c5 commit 17993bc
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 12 deletions.
11 changes: 9 additions & 2 deletions examples/serenity/voice_receive/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,10 @@ async fn join(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
.expect("Songbird Voice client placed in at initialisation.")
.clone();

if let Ok(handler_lock) = manager.join(guild_id, connect_to).await {
// NOTE: this skips listening for the actual connection result.
// Some events relating to voice receive fire *while joining*.
// We must make sure that any event handlers are installed before we attempt to join.
{
let handler_lock = manager.get_or_insert(guild_id);
let mut handler = handler_lock.lock().await;

let evt_receiver = Receiver::new();
Expand All @@ -262,13 +264,18 @@ async fn join(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
handler.add_global_event(CoreEvent::RtcpPacket.into(), evt_receiver.clone());
handler.add_global_event(CoreEvent::ClientDisconnect.into(), evt_receiver.clone());
handler.add_global_event(CoreEvent::VoiceTick.into(), evt_receiver);
}

if let Ok(handler_lock) = manager.join(guild_id, connect_to).await {
check_msg(
msg.channel_id
.say(&ctx.http, &format!("Joined {}", connect_to.mention()))
.await,
);
} else {
// Although we failed to join, we need to clear out existing event handlers on the call.
_ = manager.remove(guild_id).await;

check_msg(
msg.channel_id
.say(&ctx.http, "Error joining the channel")
Expand Down
24 changes: 15 additions & 9 deletions src/driver/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ impl Connection {
let url = generate_url(&mut info.endpoint)?;

let mut client = WsStream::connect(url).await?;
let (ws_msg_tx, ws_msg_rx) = flume::unbounded();

let mut hello = None;
let mut ready = None;
Expand Down Expand Up @@ -93,7 +94,11 @@ impl Connection {
}
},
other => {
// Discord hold back per-user connection state until after this handshake.
// There's no guarantee that will remain the case, so buffer it like all
// subsequent steps where we know they *do* send these packets.
debug!("Expected ready/hello; got: {:?}", other);
ws_msg_tx.send(WsMessage::Deliver(other))?;
},
}
}
Expand Down Expand Up @@ -176,13 +181,12 @@ impl Connection {
.await?;
}

let cipher = init_cipher(&mut client, chosen_crypto).await?;
let cipher = init_cipher(&mut client, chosen_crypto, &ws_msg_tx).await?;

info!("Connected to: {}", info.endpoint);

info!("WS heartbeat duration {}ms.", hello.heartbeat_interval);

let (ws_msg_tx, ws_msg_rx) = flume::unbounded();
#[cfg(feature = "receive")]
let (udp_receiver_msg_tx, udp_receiver_msg_rx) = flume::unbounded();

Expand Down Expand Up @@ -304,7 +308,7 @@ impl Connection {
}
},
other => {
debug!("Expected resumed/hello; got: {:?}", other);
self.ws.send(WsMessage::Deliver(other))?;
},
}
}
Expand Down Expand Up @@ -338,7 +342,11 @@ fn generate_url(endpoint: &mut String) -> Result<Url> {
}

#[inline]
async fn init_cipher(client: &mut WsStream, mode: CryptoMode) -> Result<Cipher> {
async fn init_cipher(
client: &mut WsStream,
mode: CryptoMode,
tx: &Sender<WsMessage>,
) -> Result<Cipher> {
loop {
let Some(value) = client.recv_json().await? else {
continue;
Expand All @@ -355,11 +363,9 @@ async fn init_cipher(client: &mut WsStream, mode: CryptoMode) -> Result<Cipher>
.map_err(|_| Error::CryptoInvalidLength);
},
other => {
debug!(
"Expected ready for key; got: op{}/v{:?}",
other.kind() as u8,
other
);
// Discord can and will send user-specific payload packets during this time
// which are needed to map SSRCs to `UserId`s.
tx.send(WsMessage::Deliver(other))?;
},
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/driver/tasks/message/ws.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#![allow(missing_docs)]

use super::Interconnect;
use crate::ws::WsStream;
use crate::{model::Event as GatewayEvent, ws::WsStream};

pub enum WsMessage {
Ws(Box<WsStream>),
ReplaceInterconnect(Interconnect),
SetKeepalive(f64),
Speaking(bool),
Deliver(GatewayEvent),
}
3 changes: 3 additions & 0 deletions src/driver/tasks/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ impl AuxNetwork {
}
}
},
Ok(WsMessage::Deliver(msg)) => {
self.process_ws(interconnect, msg);
},
Err(flume::RecvError::Disconnected) => {
break;
},
Expand Down

0 comments on commit 17993bc

Please sign in to comment.