From 014c8f57449542635baaa9ced91487c2baad7005 Mon Sep 17 00:00:00 2001 From: crc-32 Date: Wed, 30 Oct 2024 22:41:32 +0000 Subject: [PATCH] speex codec optimisation --- .../voice/SpeechRecognizerDictationService.kt | 25 ++++++------- .../speex_codec/src/main/cpp/CMakeLists.txt | 7 +++- .../speex_codec/src/main/cpp/speex_codec.cpp | 35 +++++++++---------- .../com/example/speex_codec/SpeexCodec.kt | 9 +++-- 4 files changed, 41 insertions(+), 35 deletions(-) diff --git a/android/shared/src/androidMain/kotlin/io/rebble/cobble/shared/domain/voice/SpeechRecognizerDictationService.kt b/android/shared/src/androidMain/kotlin/io/rebble/cobble/shared/domain/voice/SpeechRecognizerDictationService.kt index 404d480e..de28027a 100644 --- a/android/shared/src/androidMain/kotlin/io/rebble/cobble/shared/domain/voice/SpeechRecognizerDictationService.kt +++ b/android/shared/src/androidMain/kotlin/io/rebble/cobble/shared/domain/voice/SpeechRecognizerDictationService.kt @@ -19,11 +19,8 @@ import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.flow.* import org.koin.core.component.KoinComponent import org.koin.core.component.inject -import java.io.InputStream -import java.io.PipedInputStream -import java.io.PipedOutputStream import java.nio.ByteBuffer -import kotlin.math.roundToInt + @RequiresApi(VERSION_CODES.TIRAMISU) class SpeechRecognizerDictationService: DictationService, KoinComponent { @@ -36,7 +33,7 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent { class Results(val results: List>): SpeechRecognizerStatus() } - private fun beginSpeechRecognition(speechRecognizer: SpeechRecognizer, intent: Intent) = callbackFlow { + private fun beginSpeechRecognition(speechRecognizer: SpeechRecognizer, intent: Intent) = callbackFlow { speechRecognizer.setRecognitionListener(object : RecognitionListener { private var lastPartials = emptyList>() override fun onReadyForSpeech(params: Bundle?) { @@ -48,7 +45,7 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent { } override fun onRmsChanged(rmsdB: Float) { - + //Logging.d("RMS: $rmsdB") } override fun onBufferReceived(buffer: ByteArray?) { @@ -107,10 +104,11 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent { return@flow } val decoder = SpeexCodec(speexEncoderInfo.sampleRate, speexEncoderInfo.bitRate) - val decodedBuf = ByteArray(speexEncoderInfo.frameSize * Short.SIZE_BYTES) + val decodeBufLength = Short.SIZE_BYTES * speexEncoderInfo.frameSize + val decodedBuf = ByteBuffer.allocateDirect(decodeBufLength) val recognizerPipes = ParcelFileDescriptor.createSocketPair() val recognizerReadPipe = recognizerPipes[0] - val recognizerWritePipe = ParcelFileDescriptor.AutoCloseOutputStream(recognizerPipes[1]).buffered(320 * Short.SIZE_BYTES) + val recognizerWritePipe = ParcelFileDescriptor.AutoCloseOutputStream(recognizerPipes[1]) val recognizerIntent = buildRecognizerIntent(recognizerReadPipe, AudioFormat.ENCODING_PCM_16BIT, speexEncoderInfo.sampleRate.toInt()) //val recognizerIntent = buildRecognizerIntent() val speechRecognizer = withContext(Dispatchers.Main) { @@ -131,14 +129,13 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent { emit(DictationServiceResponse.Error(Result.FailServiceUnavailable)) return@flow } + val audioJob = scope.launch { audioStreamFrames .onEach { frame -> if (frame is AudioStreamFrame.Stop) { //Logging.v("Stop") - withContext(Dispatchers.IO) { - recognizerWritePipe.flush() - } + recognizerWritePipe.flush() withContext(Dispatchers.Main) { //XXX: Shouldn't use main here for I/O call but recognizer has weird thread behaviour recognizerWritePipe.close() @@ -146,15 +143,15 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent { speechRecognizer.stopListening() } } else if (frame is AudioStreamFrame.AudioData) { + decodedBuf.rewind() val result = decoder.decodeFrame(frame.data, decodedBuf, hasHeaderByte = true) if (result != SpeexDecodeResult.Success) { Logging.e("Speex decode error: ${result.name}") } - withContext(Dispatchers.IO) { - recognizerWritePipe.write(decodedBuf) - } + recognizerWritePipe.write(decodedBuf.array(), decodedBuf.arrayOffset(), decodeBufLength) } } + .flowOn(Dispatchers.IO) .catch { Logging.e("Error in audio stream: $it") } diff --git a/android/speex_codec/src/main/cpp/CMakeLists.txt b/android/speex_codec/src/main/cpp/CMakeLists.txt index 2b06aa75..e7586298 100644 --- a/android/speex_codec/src/main/cpp/CMakeLists.txt +++ b/android/speex_codec/src/main/cpp/CMakeLists.txt @@ -59,6 +59,11 @@ add_library(${CMAKE_PROJECT_NAME} SHARED speex/libspeex/gain_table_lbr.c ) +target_compile_options(${CMAKE_PROJECT_NAME} PRIVATE + "$<$:-O3>" + "$<$:-O3>" +) + target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE # List paths to include headers from include @@ -66,7 +71,7 @@ target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE ) target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE - FLOATING_POINT + FIXED_POINT "EXPORT=/* */" ) diff --git a/android/speex_codec/src/main/cpp/speex_codec.cpp b/android/speex_codec/src/main/cpp/speex_codec.cpp index b271f005..06a72e5b 100644 --- a/android/speex_codec/src/main/cpp/speex_codec.cpp +++ b/android/speex_codec/src/main/cpp/speex_codec.cpp @@ -2,29 +2,21 @@ #include #include +static jfieldID speexDecBits; +static jfieldID speexDecState; + extern "C" JNIEXPORT jint JNICALL Java_com_example_speex_1codec_SpeexCodec_decode(JNIEnv *env, jobject thiz, - jbyteArray encoded_frame, jbyteArray out_frame, jboolean has_header_byte) { + jbyteArray encoded_frame, jobject out_frame, jboolean has_header_byte) { jbyte *encoded_frame_data = env->GetByteArrayElements(encoded_frame, nullptr); jsize encoded_frame_length = env->GetArrayLength(encoded_frame); - if (has_header_byte) { - // Skip the first byte - encoded_frame_data++; - encoded_frame_length--; - } - auto *bits = reinterpret_cast(env->GetLongField(thiz, env->GetFieldID(env->GetObjectClass(thiz), "speexDecBits", "J"))); - auto *dec_state = reinterpret_cast(env->GetLongField(thiz, env->GetFieldID(env->GetObjectClass(thiz), "speexDecState", "J"))); - jshort pcm_frame[320]; - speex_bits_read_from(bits, reinterpret_cast(encoded_frame_data), encoded_frame_length); - int result = speex_decode_int(dec_state, bits, pcm_frame); - if (result == 0) { - env->SetByteArrayRegion(out_frame, 0, sizeof(pcm_frame), reinterpret_cast(pcm_frame)); - } - if (has_header_byte) { - // Restore the first byte, so that the encoded_frame_data pointer points to the original address - encoded_frame_data--; - } + auto *out_frame_data = reinterpret_cast(env->GetDirectBufferAddress(out_frame)); + auto *bits = reinterpret_cast(env->GetLongField(thiz, speexDecBits)); + auto *dec_state = reinterpret_cast(env->GetLongField(thiz, speexDecState)); + int offset = has_header_byte ? 1 : 0; + speex_bits_read_from(bits, reinterpret_cast(encoded_frame_data)+offset, encoded_frame_length-offset); + int result = speex_decode_int(dec_state, bits, out_frame_data); env->ReleaseByteArrayElements(encoded_frame, encoded_frame_data, 0); return result; } @@ -56,4 +48,11 @@ Java_com_example_speex_1codec_SpeexCodec_destroyDecState(JNIEnv *env, jobject th jlong dec_state) { auto *state = reinterpret_cast(dec_state); speex_decoder_destroy(state); +} +extern "C" +JNIEXPORT void JNICALL +Java_com_example_speex_1codec_SpeexCodec_initNative(JNIEnv *env, jobject thiz) { + jclass clazz = env->GetObjectClass(thiz); + speexDecBits = env->GetFieldID(clazz, "speexDecBits", "J"); + speexDecState = env->GetFieldID(clazz, "speexDecState", "J"); } \ No newline at end of file diff --git a/android/speex_codec/src/main/java/com/example/speex_codec/SpeexCodec.kt b/android/speex_codec/src/main/java/com/example/speex_codec/SpeexCodec.kt index c290ff2f..8fdd1988 100644 --- a/android/speex_codec/src/main/java/com/example/speex_codec/SpeexCodec.kt +++ b/android/speex_codec/src/main/java/com/example/speex_codec/SpeexCodec.kt @@ -1,8 +1,12 @@ package com.example.speex_codec import android.media.MediaCodec +import java.nio.ByteBuffer class SpeexCodec(private val sampleRate: Long, private val bitRate: Int): AutoCloseable { + init { + initNative() + } private val speexDecBits: Long = initSpeexBits() private val speexDecState: Long = initDecState(sampleRate, bitRate) @@ -12,7 +16,7 @@ class SpeexCodec(private val sampleRate: Long, private val bitRate: Int): AutoCl * @param decodedFrame The buffer to store the decoded frame in. * */ - fun decodeFrame(encodedFrame: ByteArray, decodedFrame: ByteArray, hasHeaderByte: Boolean = true): SpeexDecodeResult { + fun decodeFrame(encodedFrame: ByteArray, decodedFrame: ByteBuffer, hasHeaderByte: Boolean = true): SpeexDecodeResult { return SpeexDecodeResult.fromInt(decode(encodedFrame, decodedFrame, hasHeaderByte)) } @@ -21,7 +25,8 @@ class SpeexCodec(private val sampleRate: Long, private val bitRate: Int): AutoCl destroyDecState(speexDecState) } - private external fun decode(encodedFrame: ByteArray, decodedFrame: ByteArray, hasHeaderByte: Boolean): Int + private external fun initNative() + private external fun decode(encodedFrame: ByteArray, decodedFrame: ByteBuffer, hasHeaderByte: Boolean): Int private external fun initSpeexBits(): Long private external fun initDecState(sampleRate: Long, bitRate: Int): Long private external fun destroySpeexBits(speexBits: Long)