Skip to content

Commit

Permalink
Merge pull request #17 from xd009642/feat/vad-version
Browse files Browse the repository at this point in the history
Initial pass at a VAD version
  • Loading branch information
xd009642 authored Sep 7, 2024
2 parents cec485a + 0fface1 commit 09915cf
Show file tree
Hide file tree
Showing 8 changed files with 678 additions and 265 deletions.
599 changes: 490 additions & 109 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ opentelemetry_sdk = { version = "0.23.0", features = ["rt-tokio", "trace"] }
rubato = "0.15.0"
serde = { version = "1.0.200", features = ["derive"] }
serde_json = "1.0.117"
silero = { git = "https://github.com/emotechlab/silero-rs" }
tokio = { version = "1.37.0", features = ["macros", "signal", "sync", "rt-multi-thread"] }
tokio-metrics = "0.3.1"
tokio-stream = { version = "0.1.15", features = ["sync"] }
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ model we have the decison on whether we:
1. Process all of the incoming audio
2. Detect segments of interest and process them (VAD/energy filtering)

There's also a choice on whether we can process segments concurrently or if the
result from one segment needs to be applied to the future segment for various
reasons i.e. smoothing/hiding seams generative outputs from the audio.
For the first options there's also a choice on whether we can process segments
concurrently or if the result from one segment needs to be applied to the future
segment for various reasons i.e. smoothing/hiding seams generative outputs from
the audio.

Enumerating these patterns and representing them all in the code is a WIP.
Currently, I process everything and assume no relationship between utterances.
Expand Down
32 changes: 21 additions & 11 deletions src/api_types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{model, OutputEvent};
use crate::model;
use opentelemetry::propagation::Extractor;
use serde::{Deserialize, Serialize};

Expand All @@ -8,6 +8,8 @@ pub struct StartMessage {
pub trace_id: Option<String>,
/// Format information for the audio samples
pub format: AudioFormat,
// TODO here we likely need some configuration to let people do things like configure the VAD
// sensitivity.
}

/// Describes the PCM samples coming in. I could have gone for an enum instead of bit_depth +
Expand Down Expand Up @@ -53,24 +55,32 @@ pub enum RequestMessage {
Stop(StopMessage),
}

#[derive(Serialize, Deserialize)]
/// If we're processing segments of audio we
#[derive(Debug, Serialize, Deserialize)]
pub struct SegmentOutput {
/// Start time of the segment in seconds
pub start_time: f32,
/// End time of the segment in seconds
pub end_time: f32,
/// Some APIs may do the inverse check of "is_partial" where the last request in an utterance
/// would be `false`
#[serde(skip_serializing_if = "Option::is_none")]
pub is_final: Option<bool>,
/// The output from our ML model
#[serde(flatten)]
pub output: model::Output,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Event {
Data(model::Output),
Segment(SegmentOutput),
Error(String),
Active,
Inactive,
}

impl From<OutputEvent> for Event {
fn from(event: OutputEvent) -> Self {
match event {
OutputEvent::Response(o) => Event::Data(o),
OutputEvent::ModelError(e) => Event::Error(e),
}
}
}

#[derive(Serialize, Deserialize)]
#[serde(tag = "event", rename_all = "snake_case")]
pub struct ResponseMessage {
Expand Down
2 changes: 1 addition & 1 deletion src/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub async fn decode_audio(
anyhow::bail!("No output sinks for channel data");
}

const RESAMPLER_SIZE: usize = 4086;
const RESAMPLER_SIZE: usize = 4096;

let resample_ratio = 16000.0 / audio_format.sample_rate as f64;

Expand Down
36 changes: 27 additions & 9 deletions src/axum_server.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::api_types::*;
use crate::audio::decode_audio;
use crate::metrics::*;
use crate::{OutputEvent, StreamingContext};
use crate::StreamingContext;
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Expand All @@ -26,11 +26,14 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;

async fn ws_handler(
ws: WebSocketUpgrade,
vad_processing: bool,
Extension(state): Extension<Arc<StreamingContext>>,
Extension(metrics): Extension<Arc<AppMetricsEncoder>>,
) -> impl IntoResponse {
let current = Span::current();
ws.on_upgrade(move |socket| handle_socket(socket, state, metrics).instrument(current))
ws.on_upgrade(move |socket| {
handle_socket(socket, vad_processing, state, metrics).instrument(current)
})
}

async fn handle_initial_start<S, E>(receiver: &mut S) -> Option<StartMessage>
Expand Down Expand Up @@ -60,8 +63,7 @@ where
start
}

fn create_websocket_message(output: OutputEvent) -> Result<Message, axum::Error> {
let event = Event::from(output);
fn create_websocket_message(event: Event) -> Result<Message, axum::Error> {
let string = serde_json::to_string(&event).unwrap();
Ok(Message::Text(string))
}
Expand All @@ -72,6 +74,7 @@ fn create_websocket_message(output: OutputEvent) -> Result<Message, axum::Error>
/// tracing harder RE otel context propagation.
async fn handle_socket(
socket: WebSocket,
vad_processing: bool,
state: Arc<StreamingContext>,
metrics_enc: Arc<AppMetricsEncoder>,
) {
Expand Down Expand Up @@ -121,9 +124,15 @@ async fn handle_socket(
let inference_task = TaskMonitor::instrument(
&monitors.inference,
async move {
context
.inference_runner(samples_rx, client_sender_clone)
.await
if vad_processing {
context
.segmented_runner(samples_rx, client_sender_clone)
.await
} else {
context
.inference_runner(samples_rx, client_sender_clone)
.await
}
}
.in_current_span(),
);
Expand Down Expand Up @@ -212,11 +221,20 @@ pub fn make_service_router(app_state: Arc<StreamingContext>) -> Router {
});
Router::new()
.route(
"/api/v1/stream",
"/api/v1/simple",
get({
move |ws, app_state, metrics_enc: Extension<Arc<AppMetricsEncoder>>| {
let route = metrics_enc.metrics.route.clone();
TaskMonitor::instrument(&route, ws_handler(ws, false, app_state, metrics_enc))
}
}),
)
.route(
"/api/v1/segmented",
get({
move |ws, app_state, metrics_enc: Extension<Arc<AppMetricsEncoder>>| {
let route = metrics_enc.metrics.route.clone();
TaskMonitor::instrument(&route, ws_handler(ws, app_state, metrics_enc))
TaskMonitor::instrument(&route, ws_handler(ws, true, app_state, metrics_enc))
}
}),
)
Expand Down
5 changes: 3 additions & 2 deletions src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ struct Cli {
#[clap(long, default_value = "256")]
/// Size of audio chunks to send to the server
chunk_size: usize,
#[clap(short, long, default_value = "ws://localhost:8080/api/v1/stream")]
/// Address of the streaming server
#[clap(short, long, default_value = "ws://localhost:8080/api/v1/segmented")]
/// Address of the streaming server (/api/v1/segmented or /api/v1/simple for vad or non-vad
/// options)
addr: String,
#[clap(long)]
/// Attempts to simulate real time streaming by adding a pause between sending proportional to
Expand Down
Loading

0 comments on commit 09915cf

Please sign in to comment.