Skip to content

Commit

Permalink
Merge pull request #24 from xd009642/feat/partial-inference
Browse files Browse the repository at this point in the history
Implement simple partial inferencing
  • Loading branch information
xd009642 authored Sep 10, 2024
2 parents 67293ac + 2d653a7 commit cbb3b60
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 8 deletions.
4 changes: 4 additions & 0 deletions src/api_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ pub struct StartMessage {
pub trace_id: Option<String>,
/// Format information for the audio samples
pub format: AudioFormat,
/// Whether interim results should be provided. An alternative API to this would be to specify
/// the interval at which interim results are returned.
#[serde(default)]
pub interim_results: bool,
// TODO here we likely need some configuration to let people do things like configure the VAD
// sensitivity.
}
Expand Down
3 changes: 2 additions & 1 deletion src/axum_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,14 @@ async fn handle_socket(
let client_sender_clone = client_sender.clone();
let (samples_tx, samples_rx) = mpsc::channel(8);
let context = state.clone();
let start_cloned = start.clone();

let inference_task = TaskMonitor::instrument(
&monitors.inference,
async move {
if vad_processing {
context
.segmented_runner(samples_rx, client_sender_clone)
.segmented_runner(start_cloned, samples_rx, client_sender_clone)
.await
} else {
context
Expand Down
7 changes: 6 additions & 1 deletion src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ struct Cli {
/// Attempts to simulate real time streaming by adding a pause between sending proportional to
/// sample rate
real_time: bool,
/// Return interim results before an endpoint is detected
#[clap(long)]
interim_results: bool,
}

#[tokio::main]
Expand All @@ -44,6 +47,7 @@ async fn main() -> anyhow::Result<()> {
// Lets just start by loading the whole file, doing the messages and then sending them all in
// one go.
let args = Cli::parse();
info!("Config: {:?}", args);

run_client(args).await?;

Expand All @@ -66,7 +70,7 @@ fn get_otel_span_id(span: Span) -> Option<String> {
map.get("traceparent").cloned()
}

#[instrument]
#[instrument(skip_all)]
async fn run_client(args: Cli) -> anyhow::Result<()> {
info!("Connecting to: {}", args.addr);

Expand Down Expand Up @@ -108,6 +112,7 @@ async fn run_client(args: Cli) -> anyhow::Result<()> {
bit_depth: spec.bits_per_sample,
is_float: spec.sample_format == SampleFormat::Float,
},
interim_results: args.interim_results,
});
let delay = if real_time {
let n_samples = (chunk_size as f32 / (spec.bits_per_sample as f32 / 8.0)).ceil();
Expand Down
47 changes: 41 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::api_types::{Event, SegmentOutput};
use crate::api_types::{Event, SegmentOutput, StartMessage};
use crate::model::Model;
use futures::{stream::FuturesOrdered, StreamExt};
use silero::*;
use std::sync::Arc;
use std::thread;
use std::{thread, time::Duration};
use tokio::sync::mpsc;
use tokio::task;
use tracing::{debug, error, info, info_span, instrument, warn, Span};
Expand Down Expand Up @@ -142,6 +142,7 @@ impl StreamingContext {
#[instrument(skip_all)]
pub async fn segmented_runner(
self: Arc<Self>,
settings: StartMessage,
mut inference: mpsc::Receiver<Arc<Vec<f32>>>,
output: mpsc::Sender<Event>,
) -> anyhow::Result<()> {
Expand All @@ -152,6 +153,10 @@ impl StreamingContext {

let mut current_start = None;
let mut current_end = None;
let mut dur_since_inference = Duration::from_millis(0);
// So we're not allowing this to be configured via API. Instead we're setting it to the
// equivalent of every 500ms.
const INTERIM_THRESHOLD: Duration = Duration::from_millis(500);

// Need to test and prove this doesn't lose any data!
while still_receiving {
Expand Down Expand Up @@ -219,20 +224,44 @@ impl StreamingContext {
}
}
}
let current_vad_dur = vad.current_speech_duration();
if last_segment.is_none()
&& settings.interim_results
&& current_vad_dur > (dur_since_inference + INTERIM_THRESHOLD)
{
dur_since_inference = current_vad_dur;
let session_time = vad.session_time();
let audio = vad.get_current_speech().to_vec();
// So here we could do a bit of faffing to not block on this inference to keep
// things running but for now we're going to limit each request to a maximum of
// N_CHANNELS concurrent inferences.
let msg = self
.spawned_inference(
audio,
current_start.zip(Some(session_time.as_millis() as usize)),
false,
)
.await;
output.send(msg).await?;
}

if let Some((start, end)) = last_segment {
let audio = vad.get_speech(start, Some(end)).to_vec();
let msg = self.spawned_inference(audio, Some((start, end))).await;
let msg = self
.spawned_inference(audio, Some((start, end)), true)
.await;
output.send(msg).await?;
dur_since_inference = Duration::from_millis(0);
}

if found_endpoint {
// We actually don't need the start/end if we've got an endpoint!
let audio = vad.get_current_speech().to_vec();
let msg = self
.spawned_inference(audio, current_start.zip(current_end))
.spawned_inference(audio, current_start.zip(current_end), true)
.await;
output.send(msg).await?;
dur_since_inference = Duration::from_millis(0);
current_start = None;
current_end = None;
}
Expand All @@ -257,6 +286,7 @@ impl StreamingContext {
.spawned_inference(
audio,
current_start.zip(Some(session_time.as_millis() as usize)),
true,
)
.await;
output.send(msg).await?;
Expand All @@ -266,7 +296,12 @@ impl StreamingContext {
Ok(())
}

async fn spawned_inference(&self, audio: Vec<f32>, bounds_ms: Option<(usize, usize)>) -> Event {
async fn spawned_inference(
&self,
audio: Vec<f32>,
bounds_ms: Option<(usize, usize)>,
is_final: bool,
) -> Event {
let current = Span::current();
let temp_model = self.model.clone();
let result = task::spawn_blocking(move || {
Expand All @@ -283,7 +318,7 @@ impl StreamingContext {
let seg = SegmentOutput {
start_time,
end_time,
is_final: Some(true),
is_final: Some(is_final),
output,
};
Event::Segment(seg)
Expand Down

0 comments on commit cbb3b60

Please sign in to comment.