diff --git a/Cargo.lock b/Cargo.lock index 449c5e4..4a5574f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "ahash" version = "0.8.11" @@ -26,7 +32,7 @@ dependencies = [ "cfg-if", "once_cell", "version_check", - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -123,9 +129,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.81" +version = "0.1.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" +checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1" dependencies = [ "proc-macro2", "quote", @@ -161,7 +167,7 @@ dependencies = [ "rustversion", "serde", "sync_wrapper 0.1.2", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", ] @@ -174,7 +180,7 @@ checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" dependencies = [ "async-trait", "axum-core 0.4.3", - "base64", + "base64 0.21.7", "bytes", "futures-util", "http 1.1.0", @@ -197,7 +203,7 @@ dependencies = [ "sync_wrapper 1.0.1", "tokio", "tokio-tungstenite", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -243,19 +249,19 @@ dependencies = [ [[package]] name = "axum-tracing-opentelemetry" -version = "0.19.0" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "164dc5772777b14dbd4e3d5c0b8fb4c2f34cfde658337d89b166c50d32a81aff" +checksum = "561a0967337dfeaf3e28700d23e791712cd7d5e97ab335e0f4a0c3ac62e6ece0" dependencies = [ "axum 0.7.5", "futures-core", "futures-util", "http 1.1.0", - "opentelemetry", + 
"opentelemetry 0.24.0", "pin-project-lite", - "tower", + "tower 0.5.0", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.25.0", "tracing-opentelemetry-instrumentation-sdk", ] @@ -269,7 +275,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.4", "object", "rustc-demangle", ] @@ -280,6 +286,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "1.3.2" @@ -321,9 +333,12 @@ checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" [[package]] name = "cc" -version = "1.1.7" +version = "1.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" +checksum = "e9d013ecb737093c0e86b151a7b837993cf9ec6c502946cfb44bedc392421e0b" +dependencies = [ + "shlex", +] [[package]] name = "cfg-if" @@ -333,9 +348,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.13" +version = "4.5.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fbb260a053428790f3de475e304ff84cdbc4face759ea7a3e64c1edd938a7fc" +checksum = "ed6719fffa43d0d87e5fd8caeab59be1554fb028cd30edc88fc4369b17971019" dependencies = [ "clap_builder", "clap_derive", @@ -343,9 +358,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.13" +version = "4.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64b17d7ea74e9f833c7dbf2cbe4fb12ff26783eda4782a8975b72f895c9b4d99" +checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6" dependencies = [ "anstream", 
"anstyle", @@ -389,25 +404,40 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "51e852e6dc9a5bed1fae92dd2375037bf2b768725bf3be87811edee3249d09ad" dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam-utils" version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "crypto-common" version = "0.1.6" @@ -574,11 +604,43 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "fastrand" -version = "2.1.0" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" + +[[package]] +name = "filetime" +version = "0.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +checksum = "35c0522e981e68cbfa8c3f978441a5f34b30b96e146b33cd3359176b50fe8586" +dependencies = [ + "cfg-if", + "libc", + "libredox", + "windows-sys 0.59.0", +] + +[[package]] +name = "flate2" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" +dependencies = [ + "crc32fast", + "miniz_oxide 0.8.0", +] [[package]] name = "fnv" @@ -729,13 +791,23 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.3.0", + "indexmap 2.5.0", "slab", "tokio", "tokio-util", "tracing", ] +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -895,9 +967,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" dependencies = [ "bytes", "futures-util", @@ -930,9 +1002,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.3.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3fc2e30ba82dd1b3911c8de1ffc143c74a914a14e99514d7637e3099df5ea0" +checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" dependencies = [ "equivalent", "hashbrown 0.14.5", @@ -944,14 +1016,14 @@ version = "0.19.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "c2f12893a5f71a99623566df4ec1811a3b0d38a103b0792741aec0688663f042" dependencies = [ - "opentelemetry", + "opentelemetry 0.23.0", "opentelemetry-otlp", "opentelemetry-resource-detectors", "opentelemetry-semantic-conventions", - "opentelemetry_sdk", + "opentelemetry_sdk 0.23.0", "thiserror", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.24.0", "tracing-subscriber", ] @@ -984,9 +1056,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" dependencies = [ "wasm-bindgen", ] @@ -999,9 +1071,26 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" + +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.6.0", + "libc", + "redox_syscall", +] + +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "lock_api" @@ -1034,6 +1123,16 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matrixmultiply" 
+version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "measured" version = "0.0.22" @@ -1084,11 +1183,20 @@ dependencies = [ "adler", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "mio" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ "hermit-abi", "libc", @@ -1096,6 +1204,19 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1135,9 +1256,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.2" +version = "0.36.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" +checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" dependencies = [ "memchr", ] @@ -1162,6 +1283,20 @@ dependencies = [ "thiserror", ] +[[package]] +name = "opentelemetry" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "once_cell", + 
"pin-project-lite", + "thiserror", +] + [[package]] name = "opentelemetry-http" version = "0.12.0" @@ -1171,7 +1306,7 @@ dependencies = [ "async-trait", "bytes", "http 0.2.12", - "opentelemetry", + "opentelemetry 0.23.0", "reqwest", ] @@ -1184,10 +1319,10 @@ dependencies = [ "async-trait", "futures-core", "http 0.2.12", - "opentelemetry", + "opentelemetry 0.23.0", "opentelemetry-http", "opentelemetry-proto", - "opentelemetry_sdk", + "opentelemetry_sdk 0.23.0", "prost", "reqwest", "thiserror", @@ -1201,8 +1336,8 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "984806e6cf27f2b49282e2a05e288f30594f3dbc74eb7a6e99422bc48ed78162" dependencies = [ - "opentelemetry", - "opentelemetry_sdk", + "opentelemetry 0.23.0", + "opentelemetry_sdk 0.23.0", "prost", "tonic", ] @@ -1213,9 +1348,9 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5cd98b7277913e22e95b6fd3a5f7413438005471a6e33e8a4ae7b9a20be36ad" dependencies = [ - "opentelemetry", + "opentelemetry 0.23.0", "opentelemetry-semantic-conventions", - "opentelemetry_sdk", + "opentelemetry_sdk 0.23.0", ] [[package]] @@ -1237,7 +1372,7 @@ dependencies = [ "glob", "lazy_static", "once_cell", - "opentelemetry", + "opentelemetry 0.23.0", "ordered-float", "percent-encoding", "rand", @@ -1246,6 +1381,24 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "opentelemetry_sdk" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df" +dependencies = [ + "async-trait", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry 0.24.0", + "percent-encoding", + "rand", + "thiserror", +] + [[package]] name = "ordered-float" version = "4.2.2" @@ -1255,6 +1408,34 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ort" +version = "2.0.0-rc.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "86d83095ae3c1258738d70ae7a06195c94d966a8e546f0d3609dc90885fb61f5" +dependencies = [ + "half", + "js-sys", + "ndarray", + "ort-sys", + "thiserror", + "tracing", + "web-sys", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec6fe264a9467cd0c19cbee07afe689fae9480c4706c4a1a00b5e64ff99ea83a" +dependencies = [ + "flate2", + "pkg-config", + "sha2", + "tar", + "ureq", +] + [[package]] name = "overload" version = "0.1.1" @@ -1322,13 +1503,19 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + [[package]] name = "ppv-lite86" -version = "0.2.18" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee4364d9f3b902ef14fab8a1ddffb783a1cb6b4bba3bfc1fa3922732c7de97f" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy 0.6.6", + "zerocopy", ] [[package]] @@ -1374,9 +1561,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -1411,6 +1598,12 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "realfft" version = "3.3.0" @@ -1479,7 
+1672,7 @@ version = "0.11.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ - "base64", + "base64 0.21.7", "bytes", "encoding_rs", "futures-core", @@ -1509,6 +1702,21 @@ dependencies = [ "winreg", ] +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rubato" version = "0.15.0" @@ -1548,6 +1756,51 @@ dependencies = [ "version_check", ] +[[package]] +name = "rustix" +version = "0.38.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a85d50532239da68e9addb745ba38ff4612a242c1c7ceea689c4bc7c2f43c36f" +dependencies = [ + "bitflags 2.6.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustls" +version = "0.23.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" + +[[package]] +name = "rustls-webpki" +version = "0.102.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84678086bd54edf2b415183ed7a94d0efb049f1b646a33e22a36f3794be6ae56" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.17" @@ -1568,18 +1821,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] 
name = "serde" -version = "1.0.204" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" dependencies = [ "proc-macro2", "quote", @@ -1588,9 +1841,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.122" +version = "1.0.127" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784b6203951c57ff748476b126ccb5e8e2959a5c19e5c617ab1956be3dbc68da" +checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" dependencies = [ "itoa", "memchr", @@ -1631,6 +1884,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1640,6 +1904,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -1649,6 +1919,16 @@ dependencies = [ "libc", ] +[[package]] +name = "silero" +version = "0.1.0" +source = "git+https://github.com/emotechlab/silero-rs#562df14fc710f9f378b8aeeda613ab7470c5a77f" +dependencies = [ + "anyhow", + "ndarray", + "ort", +] + [[package]] name = "slab" version = "0.4.9" @@ 
-1674,6 +1954,23 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "streamer-template" version = "0.2.0" @@ -1689,20 +1986,21 @@ dependencies = [ "hound", "init-tracing-opentelemetry", "measured", - "opentelemetry", + "opentelemetry 0.23.0", "opentelemetry-otlp", "opentelemetry-semantic-conventions", - "opentelemetry_sdk", + "opentelemetry_sdk 0.23.0", "rubato", "serde", "serde_json", + "silero", "tokio", "tokio-metrics", "tokio-stream", "tokio-tungstenite", "tower-http", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.24.0", "tracing-subscriber", "tracing-test", ] @@ -1719,11 +2017,17 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" -version = "2.0.72" +version = "2.0.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" dependencies = [ "proc-macro2", "quote", @@ -1763,6 +2067,17 @@ dependencies = [ "libc", ] +[[package]] +name = "tar" +version = "0.4.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"cb797dad5fb5b76fcf519e702f4a589483b5ef06567f160c392832c1f5e44909" +dependencies = [ + "filetime", + "libc", + "xattr", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -1810,9 +2125,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.2" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" dependencies = [ "backtrace", "bytes", @@ -1904,7 +2219,7 @@ dependencies = [ "async-stream", "async-trait", "axum 0.6.20", - "base64", + "base64 0.21.7", "bytes", "h2", "http 0.2.12", @@ -1916,7 +2231,7 @@ dependencies = [ "prost", "tokio", "tokio-stream", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -1942,6 +2257,16 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36b837f86b25d7c0d7988f00a54e74739be6477f2aac6201b8f429a7569991b7" +dependencies = [ + "tower-layer", + "tower-service", +] + [[package]] name = "tower-http" version = "0.5.2" @@ -1961,15 +2286,15 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -2023,8 +2348,26 @@ checksum = 
"f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4" dependencies = [ "js-sys", "once_cell", - "opentelemetry", - "opentelemetry_sdk", + "opentelemetry 0.23.0", + "opentelemetry_sdk 0.23.0", + "smallvec", + "tracing", + "tracing-core", + "tracing-log", + "tracing-subscriber", + "web-time", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry 0.24.0", + "opentelemetry_sdk 0.24.1", "smallvec", "tracing", "tracing-core", @@ -2035,14 +2378,14 @@ dependencies = [ [[package]] name = "tracing-opentelemetry-instrumentation-sdk" -version = "0.19.0" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065f4c337874edb2ba504cb1e487b3bb4f1533a5bb6fcdf72da1575564814c" +checksum = "8159fbb3bd93e20342e7e6ef45b96c5d122cd88043f37ad0e4b5bb052f0f4483" dependencies = [ "http 1.1.0", - "opentelemetry", + "opentelemetry 0.24.0", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.25.0", ] [[package]] @@ -2159,6 +2502,28 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b74fc6b57825be3373f7054754755f03ac3a8f5d70015ccad699ba2029956f4a" +dependencies = [ + "base64 0.22.1", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "socks", + "url", + "webpki-roots", +] + [[package]] name = "url" version = "2.5.2" @@ -2211,19 +2576,20 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.93" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" dependencies = [ "bumpalo", "log", @@ -2236,9 +2602,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.42" +version = "0.4.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" dependencies = [ "cfg-if", "js-sys", @@ -2248,9 +2614,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2258,9 +2624,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", @@ -2271,15 +2637,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" dependencies = [ "js-sys", "wasm-bindgen", @@ -2295,6 +2661,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.26.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bd24728e5af82c6c4ec1b66ac4844bdf8156257fccda846ec58b42cd0cdbe6a" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "winapi" version = "0.3.9" @@ -2335,6 +2710,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -2467,13 +2851,14 @@ dependencies = [ ] [[package]] -name = "zerocopy" -version = "0.6.6" +name = "xattr" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "854e949ac82d619ee9a14c66a1b674ac730422372ccb759ce0c39cabcf2bf8e6" +checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" dependencies = [ - "byteorder", - "zerocopy-derive 0.6.6", + "libc", + "linux-raw-sys", + "rustix", ] [[package]] @@ -2482,14 +2867,15 @@ version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ - "zerocopy-derive 0.7.35", + "byteorder", + "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.6.6" 
+version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "125139de3f6b9d625c39e2efdd73d41bdac468ccd556556440e322be0e1bbd91" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", @@ -2497,12 +2883,7 @@ dependencies = [ ] [[package]] -name = "zerocopy-derive" -version = "0.7.35" +name = "zeroize" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" diff --git a/Cargo.toml b/Cargo.toml index 028a17f..70fc8ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ opentelemetry_sdk = { version = "0.23.0", features = ["rt-tokio", "trace"] } rubato = "0.15.0" serde = { version = "1.0.200", features = ["derive"] } serde_json = "1.0.117" +silero = { git = "https://github.com/emotechlab/silero-rs" } tokio = { version = "1.37.0", features = ["macros", "signal", "sync", "rt-multi-thread"] } tokio-metrics = "0.3.1" tokio-stream = { version = "0.1.15", features = ["sync"] } diff --git a/README.md b/README.md index e7daf9b..ad6c166 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,10 @@ model we have the decison on whether we: 1. Process all of the incoming audio 2. Detect segments of interest and process them (VAD/energy filtering) -There's also a choice on whether we can process segments concurrently or if the -result from one segment needs to be applied to the future segment for various -reasons i.e. smoothing/hiding seams generative outputs from the audio. +For the first option there's also a choice on whether we can process segments +concurrently or if the result from one segment needs to be applied to the future +segment for various reasons, e.g. smoothing/hiding seams in generative outputs +from the audio.
Enumerating these patterns and representing them all in the code is a WIP. Currently, I process everything and assume no relationship between utterances. diff --git a/src/api_types.rs b/src/api_types.rs index 8265eae..70da365 100644 --- a/src/api_types.rs +++ b/src/api_types.rs @@ -1,4 +1,4 @@ -use crate::{model, OutputEvent}; +use crate::model; use opentelemetry::propagation::Extractor; use serde::{Deserialize, Serialize}; @@ -8,6 +8,8 @@ pub struct StartMessage { pub trace_id: Option, /// Format information for the audio samples pub format: AudioFormat, + // TODO here we likely need some configuration to let people do things like configure the VAD + // sensitivity. } /// Describes the PCM samples coming in. I could have gone for an enum instead of bit_depth + @@ -53,24 +55,32 @@ pub enum RequestMessage { Stop(StopMessage), } -#[derive(Serialize, Deserialize)] +/// If we're processing segments of audio we report each segment's timing alongside its output +#[derive(Debug, Serialize, Deserialize)] +pub struct SegmentOutput { + /// Start time of the segment in seconds + pub start_time: f32, + /// End time of the segment in seconds + pub end_time: f32, + /// Some APIs may do the inverse check of "is_partial" where the last request in an utterance + /// would be `false` + #[serde(skip_serializing_if = "Option::is_none")] + pub is_final: Option, + /// The output from our ML model + #[serde(flatten)] + pub output: model::Output, +} + +#[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum Event { Data(model::Output), + Segment(SegmentOutput), Error(String), Active, Inactive, } -impl From for Event { - fn from(event: OutputEvent) -> Self { - match event { - OutputEvent::Response(o) => Event::Data(o), - OutputEvent::ModelError(e) => Event::Error(e), - } - } -} - #[derive(Serialize, Deserialize)] #[serde(tag = "event", rename_all = "snake_case")] pub struct ResponseMessage { diff --git a/src/audio.rs b/src/audio.rs index dfca0e6..c28ca63 100644 --- a/src/audio.rs +++
b/src/audio.rs @@ -25,7 +25,7 @@ pub async fn decode_audio( anyhow::bail!("No output sinks for channel data"); } - const RESAMPLER_SIZE: usize = 4086; + const RESAMPLER_SIZE: usize = 4096; let resample_ratio = 16000.0 / audio_format.sample_rate as f64; diff --git a/src/axum_server.rs b/src/axum_server.rs index ff2b810..efb5e7d 100644 --- a/src/axum_server.rs +++ b/src/axum_server.rs @@ -1,7 +1,7 @@ use crate::api_types::*; use crate::audio::decode_audio; use crate::metrics::*; -use crate::{OutputEvent, StreamingContext}; +use crate::StreamingContext; use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, @@ -26,11 +26,14 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; async fn ws_handler( ws: WebSocketUpgrade, + vad_processing: bool, Extension(state): Extension>, Extension(metrics): Extension>, ) -> impl IntoResponse { let current = Span::current(); - ws.on_upgrade(move |socket| handle_socket(socket, state, metrics).instrument(current)) + ws.on_upgrade(move |socket| { + handle_socket(socket, vad_processing, state, metrics).instrument(current) + }) } async fn handle_initial_start(receiver: &mut S) -> Option @@ -60,8 +63,7 @@ where start } -fn create_websocket_message(output: OutputEvent) -> Result { - let event = Event::from(output); +fn create_websocket_message(event: Event) -> Result { let string = serde_json::to_string(&event).unwrap(); Ok(Message::Text(string)) } @@ -72,6 +74,7 @@ fn create_websocket_message(output: OutputEvent) -> Result /// tracing harder RE otel context propagation. 
async fn handle_socket( socket: WebSocket, + vad_processing: bool, state: Arc, metrics_enc: Arc, ) { @@ -121,9 +124,15 @@ async fn handle_socket( let inference_task = TaskMonitor::instrument( &monitors.inference, async move { - context - .inference_runner(samples_rx, client_sender_clone) - .await + if vad_processing { + context + .segmented_runner(samples_rx, client_sender_clone) + .await + } else { + context + .inference_runner(samples_rx, client_sender_clone) + .await + } } .in_current_span(), ); @@ -212,11 +221,20 @@ pub fn make_service_router(app_state: Arc) -> Router { }); Router::new() .route( - "/api/v1/stream", + "/api/v1/simple", + get({ + move |ws, app_state, metrics_enc: Extension>| { + let route = metrics_enc.metrics.route.clone(); + TaskMonitor::instrument(&route, ws_handler(ws, false, app_state, metrics_enc)) + } + }), + ) + .route( + "/api/v1/segmented", get({ move |ws, app_state, metrics_enc: Extension>| { let route = metrics_enc.metrics.route.clone(); - TaskMonitor::instrument(&route, ws_handler(ws, app_state, metrics_enc)) + TaskMonitor::instrument(&route, ws_handler(ws, true, app_state, metrics_enc)) } }), ) diff --git a/src/bin/client.rs b/src/bin/client.rs index 0bb129f..460c928 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -28,8 +28,9 @@ struct Cli { #[clap(long, default_value = "256")] /// Size of audio chunks to send to the server chunk_size: usize, - #[clap(short, long, default_value = "ws://localhost:8080/api/v1/stream")] - /// Address of the streaming server + #[clap(short, long, default_value = "ws://localhost:8080/api/v1/segmented")] + /// Address of the streaming server (/api/v1/segmented or /api/v1/simple for vad or non-vad + /// options) addr: String, #[clap(long)] /// Attempts to simulate real time streaming by adding a pause between sending proportional to diff --git a/src/lib.rs b/src/lib.rs index cf18242..b2711e0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ -use crate::model::{Model, Output}; +use 
crate::api_types::{Event, SegmentOutput}; +use crate::model::Model; use futures::{stream::FuturesOrdered, StreamExt}; +use silero::*; use std::sync::Arc; use std::thread; use tokio::sync::mpsc; @@ -23,17 +25,6 @@ pub async fn launch_server() { info!("Server exiting"); } -pub enum InputEvent { - Start, - Data(Arc>), - Stop, -} - -pub enum OutputEvent { - Response(Output), - ModelError(String), -} - #[derive(Clone)] pub struct StreamingContext { model: Model, @@ -79,7 +70,7 @@ impl StreamingContext { pub async fn inference_runner( self: Arc, mut inference: mpsc::Receiver>>, - output: mpsc::Sender, + output: mpsc::Sender, ) -> anyhow::Result<()> { let mut runners = FuturesOrdered::new(); let mut still_receiving = true; @@ -116,10 +107,10 @@ impl StreamingContext { received_results += 1; debug!("Received inference result: {}", received_results); let msg = match data { - Some(Ok(Ok(output))) => OutputEvent::Response(output), + Some(Ok(Ok(output))) => Event::Data(output), Some(Ok(Err(e))) => { error!("Failed inference event: {}", e); - OutputEvent::ModelError(e.to_string()) + Event::Error(e.to_string()) } Some(Err(_)) => unreachable!("Spawn blocking cannot error"), None => { @@ -135,67 +126,139 @@ impl StreamingContext { } #[instrument(skip_all)] - pub async fn simple( + pub async fn segmented_runner( self: Arc, - mut input: mpsc::Receiver, - output: mpsc::Sender, + mut inference: mpsc::Receiver>>, + output: mpsc::Sender, ) -> anyhow::Result<()> { - let mut data_store = Vec::new(); - let mut is_running = false; - - let (tx, rx) = mpsc::channel(self.max_futures); - let other_self = self.clone(); - let inf_task = task::spawn(async move { - if let Err(e) = other_self.inference_runner(rx, output).await { - error!("Inference failed: {}", e); - } - }); + let mut vad = VadSession::new(VadConfig::default())?; + let mut still_receiving = true; - while let Some(msg) = input.recv().await { - match msg { - InputEvent::Start => { - debug!("Received start message"); - is_running = 
true; - } - InputEvent::Stop => { - debug!("Received stop message"); - is_running = false; + let mut recv_buffer = Vec::with_capacity(inference.max_capacity()); + + let mut current_start = None; + let mut current_end = None; + + // Need to test and prove this doesn't lose any data! + while still_receiving { + let msg_len = inference + .recv_many(&mut recv_buffer, inference.max_capacity()) + .await; + if msg_len == 0 { + info!("No longer receiving any messages"); + still_receiving = false; + } else { + let mut audio = vec![]; + for samples in recv_buffer.drain(..) { + audio.extend_from_slice(&samples); } - InputEvent::Data(bytes) => { - debug!("Receiving data len: {}", bytes.len()); - if is_running { - // Here we should probably actually do a filtering of the data and push it - data_store.extend(bytes.iter()); - } else { - warn!("Data sent when in stop mode. Discarding"); + let events = vad.process(&audio)?; + + let mut found_endpoint = false; + let mut last_segment = None; + for event in &events { + match event { + VadTransition::SpeechStart { timestamp_ms } => { + info!(time_ms = timestamp_ms, "Detected start of speech"); + match (current_start, current_end) { + (Some(start), Some(end)) if found_endpoint => { + if last_segment.is_some() { + // More than 2x start/end pairs found in a single chunk. + // Something is going wrong! + error!("Found another endpoint but already had a last segment! 
Losing segment {:?}", last_segment); + } + last_segment = Some((start, end)); + } + (None, _) if found_endpoint => { + error!("Found an endpoint with no start time"); + } + _ => {} + } + current_start = Some(*timestamp_ms); + current_end = None; + found_endpoint = false; + } + VadTransition::SpeechEnd { timestamp_ms } => { + info!(time_ms = timestamp_ms, "Detected end of speech"); + current_end = Some(*timestamp_ms); + found_endpoint = true; + } } } - } - // Check if we want to do an inference - if self.should_run_inference(&data_store, false) { - tx.send(data_store.into()).await?; - data_store = Vec::new(); + if let Some((start, end)) = last_segment { + let audio = vad.get_speech(start, Some(end)).to_vec(); + let msg = self.spawned_inference(audio, Some((start, end))).await; + output.send(msg).await?; + } + + if found_endpoint { + // We actually don't need the start/end if we've got an endpoint! + let audio = vad.get_current_speech().to_vec(); + let msg = self + .spawned_inference(audio, current_start.zip(current_end)) + .await; + output.send(msg).await?; + current_start = None; + current_end = None; + } } } - // Check if we want to do an inference - if self.should_run_inference(&data_store, true) { - tx.send(data_store.into()).await?; - } else { - // Do any final message stuff! 
+ // If we're speaking then we haven't endpointed so do the final inference + if vad.is_speaking() { + let audio = vad.get_current_speech().to_vec(); + let msg = self + .spawned_inference( + audio, + current_start.zip(Some(vad.session_time().as_millis() as usize)), + ) + .await; + output.send(msg).await?; } - std::mem::drop(tx); - let end = inf_task.await; - info!("Finished inference: {:?}", end); + info!("Inference finished"); Ok(()) } + + async fn spawned_inference(&self, audio: Vec, bounds_ms: Option<(usize, usize)>) -> Event { + let current = Span::current(); + let temp_model = self.model.clone(); + let result = task::spawn_blocking(move || { + let span = info_span!(parent: ¤t, "inference_task"); + let _guard = span.enter(); + temp_model.infer(&audio) + }) + .await; + match result { + Ok(Ok(output)) => { + if let Some((start, end)) = bounds_ms { + let start_time = start as f32 / 1000.0; + let end_time = end as f32 / 1000.0; + let seg = SegmentOutput { + start_time, + end_time, + is_final: Some(true), + output, + }; + Event::Segment(seg) + } else { + Event::Data(output) + } + } + Ok(Err(e)) => { + error!("Failed inference event: {}", e); + Event::Error(e.to_string()) + } + Err(_) => unreachable!("Spawn blocking cannot error"), + } + } } #[cfg(test)] mod tests { use super::*; + use crate::model::Output; use tracing_test::traced_test; #[tokio::test] @@ -212,22 +275,17 @@ mod tests { max_futures: 4, }); - let inference = context.simple(input_rx, output_tx); + let inference = context.inference_runner(input_rx, output_tx); let sender = task::spawn(async move { let mut bytes_sent = 0; - input_tx.send(InputEvent::Start).await.unwrap(); for _ in 0..100 { let data = fastrand::u8(5..); bytes_sent += data as usize; let to_send = (0..data).map(|x| x as f32).collect::>(); - input_tx - .send(InputEvent::Data(Arc::new(to_send))) - .await - .unwrap(); + input_tx.send(Arc::new(to_send)).await.unwrap(); } - input_tx.send(InputEvent::Stop).await.unwrap(); info!("Finished sender 
task"); bytes_sent }); @@ -236,10 +294,10 @@ mod tests { let mut received = 0; while let Some(msg) = output_rx.recv().await { match msg { - OutputEvent::Response(Output { count }) => { + Event::Data(Output { count }) => { received += count; } - OutputEvent::ModelError(e) => panic!("{}", e), + e => panic!("Unexpected: {:?}", e), } } info!("Finished receiver task"); @@ -250,62 +308,10 @@ mod tests { run.unwrap(); assert_eq!(bytes_sent.unwrap(), count_received.unwrap()); - assert!(logs_contain("Received start message")); - assert!(logs_contain("Received stop message")); + assert!(logs_contain("Adding to inference runner task")); assert!(logs_contain("Inference finished")); } - #[tokio::test] - #[traced_test] - async fn no_start() { - let (input_tx, input_rx) = mpsc::channel(10); - let (output_tx, mut output_rx) = mpsc::channel(10); - - let context = Arc::new(StreamingContext::new()); - - let inference = context.simple(input_rx, output_tx); - - let sender = task::spawn(async move { - let mut bytes_sent = 0; - for _ in 0..100 { - let data = fastrand::u8(5..); - bytes_sent += data as usize; - let to_send = (0..data).map(|x| x as f32).collect::>(); - - input_tx - .send(InputEvent::Data(Arc::new(to_send))) - .await - .unwrap(); - } - info!("Finished sender task"); - bytes_sent - }); - - let receiver = task::spawn(async move { - let mut received = 0; - while let Some(msg) = output_rx.recv().await { - match msg { - OutputEvent::Response(Output { count }) => { - received += count; - } - OutputEvent::ModelError(e) => panic!("{}", e), - } - } - info!("Finished receiver task"); - received - }); - - let (bytes_sent, count_received, run) = tokio::join!(sender, receiver, inference); - - run.unwrap(); - assert_eq!(count_received.unwrap(), 0); - assert!(bytes_sent.unwrap() > 0); - assert!(!logs_contain("Received start message")); - assert!(logs_contain("Data sent when in stop mode. 
Discarding")); - assert!(logs_contain("No longer receiving any messages")); - assert!(!logs_contain("Adding to inference runner task")); - } - #[tokio::test] #[traced_test] async fn broken_model() { @@ -320,17 +326,13 @@ mod tests { max_futures: 4, }); - let inference = context.simple(input_rx, output_tx); + let inference = context.inference_runner(input_rx, output_tx); let sender = task::spawn(async move { - input_tx.send(InputEvent::Start).await.unwrap(); for _ in 0..100 { let to_send = (0..10).map(|x| x as f32).collect::>(); - input_tx - .send(InputEvent::Data(Arc::new(to_send))) - .await - .unwrap(); + input_tx.send(Arc::new(to_send)).await.unwrap(); } info!("Finished sender task"); }); @@ -339,12 +341,12 @@ mod tests { let mut received_errors = 0; while let Some(msg) = output_rx.recv().await { match msg { - OutputEvent::Response(Output { count }) => { - panic!("Didn't expect actual messages back!"); - } - OutputEvent::ModelError(e) => { + Event::Error(_e) => { received_errors += 1; } + _ => { + panic!("Didn't expect actual messages back!"); + } } } info!("Finished receiver task"); @@ -355,7 +357,6 @@ mod tests { run.unwrap(); assert!(count_received.unwrap() > 1); - assert!(logs_contain("Received start message")); assert!(logs_contain("Adding to inference runner task")); assert!(logs_contain("Failed inference event")); }