Skip to content

Commit

Permalink
speex codec optimisation
Browse files Browse the repository at this point in the history
  • Loading branch information
crc-32 committed Oct 30, 2024
1 parent 050188c commit 014c8f5
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.*
import org.koin.core.component.KoinComponent
import org.koin.core.component.inject
import java.io.InputStream
import java.io.PipedInputStream
import java.io.PipedOutputStream
import java.nio.ByteBuffer
import kotlin.math.roundToInt


@RequiresApi(VERSION_CODES.TIRAMISU)
class SpeechRecognizerDictationService: DictationService, KoinComponent {
Expand All @@ -36,7 +33,7 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent {
class Results(val results: List<Pair<Float, String>>): SpeechRecognizerStatus()
}

private fun beginSpeechRecognition(speechRecognizer: SpeechRecognizer, intent: Intent) = callbackFlow<SpeechRecognizerStatus> {
private fun beginSpeechRecognition(speechRecognizer: SpeechRecognizer, intent: Intent) = callbackFlow {
speechRecognizer.setRecognitionListener(object : RecognitionListener {
private var lastPartials = emptyList<Pair<Float, String>>()
override fun onReadyForSpeech(params: Bundle?) {
Expand All @@ -48,7 +45,7 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent {
}

override fun onRmsChanged(rmsdB: Float) {

//Logging.d("RMS: $rmsdB")
}

override fun onBufferReceived(buffer: ByteArray?) {
Expand Down Expand Up @@ -107,10 +104,11 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent {
return@flow
}
val decoder = SpeexCodec(speexEncoderInfo.sampleRate, speexEncoderInfo.bitRate)
val decodedBuf = ByteArray(speexEncoderInfo.frameSize * Short.SIZE_BYTES)
val decodeBufLength = Short.SIZE_BYTES * speexEncoderInfo.frameSize
val decodedBuf = ByteBuffer.allocateDirect(decodeBufLength)
val recognizerPipes = ParcelFileDescriptor.createSocketPair()
val recognizerReadPipe = recognizerPipes[0]
val recognizerWritePipe = ParcelFileDescriptor.AutoCloseOutputStream(recognizerPipes[1]).buffered(320 * Short.SIZE_BYTES)
val recognizerWritePipe = ParcelFileDescriptor.AutoCloseOutputStream(recognizerPipes[1])
val recognizerIntent = buildRecognizerIntent(recognizerReadPipe, AudioFormat.ENCODING_PCM_16BIT, speexEncoderInfo.sampleRate.toInt())
//val recognizerIntent = buildRecognizerIntent()
val speechRecognizer = withContext(Dispatchers.Main) {
Expand All @@ -131,30 +129,29 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent {
emit(DictationServiceResponse.Error(Result.FailServiceUnavailable))
return@flow
}

val audioJob = scope.launch {
audioStreamFrames
.onEach { frame ->
if (frame is AudioStreamFrame.Stop) {
//Logging.v("Stop")
withContext(Dispatchers.IO) {
recognizerWritePipe.flush()
}
recognizerWritePipe.flush()
withContext(Dispatchers.Main) {
//XXX: Shouldn't use main here for I/O call but recognizer has weird thread behaviour
recognizerWritePipe.close()
recognizerReadPipe.close()
speechRecognizer.stopListening()
}
} else if (frame is AudioStreamFrame.AudioData) {
decodedBuf.rewind()
val result = decoder.decodeFrame(frame.data, decodedBuf, hasHeaderByte = true)
if (result != SpeexDecodeResult.Success) {
Logging.e("Speex decode error: ${result.name}")
}
withContext(Dispatchers.IO) {
recognizerWritePipe.write(decodedBuf)
}
recognizerWritePipe.write(decodedBuf.array(), decodedBuf.arrayOffset(), decodeBufLength)
}
}
.flowOn(Dispatchers.IO)
.catch {
Logging.e("Error in audio stream: $it")
}
Expand Down
7 changes: 6 additions & 1 deletion android/speex_codec/src/main/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,19 @@ add_library(${CMAKE_PROJECT_NAME} SHARED
speex/libspeex/gain_table_lbr.c
)

target_compile_options(${CMAKE_PROJECT_NAME} PRIVATE
"$<$<CONFIG:RELEASE>:-O3>"
"$<$<CONFIG:DEBUG>:-O3>"
)

target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
# List paths to include headers from
include
speex/include
)

target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
FLOATING_POINT
FIXED_POINT
"EXPORT=/* */"
)

Expand Down
35 changes: 17 additions & 18 deletions android/speex_codec/src/main/cpp/speex_codec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,21 @@
#include <string>
#include <speex/speex.h>

static jfieldID speexDecBits;
static jfieldID speexDecState;

extern "C"
JNIEXPORT jint JNICALL
Java_com_example_speex_1codec_SpeexCodec_decode(JNIEnv *env, jobject thiz,
jbyteArray encoded_frame, jbyteArray out_frame, jboolean has_header_byte) {
jbyteArray encoded_frame, jobject out_frame, jboolean has_header_byte) {
jbyte *encoded_frame_data = env->GetByteArrayElements(encoded_frame, nullptr);
jsize encoded_frame_length = env->GetArrayLength(encoded_frame);
if (has_header_byte) {
// Skip the first byte
encoded_frame_data++;
encoded_frame_length--;
}
auto *bits = reinterpret_cast<SpeexBits *>(env->GetLongField(thiz, env->GetFieldID(env->GetObjectClass(thiz), "speexDecBits", "J")));
auto *dec_state = reinterpret_cast<void *>(env->GetLongField(thiz, env->GetFieldID(env->GetObjectClass(thiz), "speexDecState", "J")));
jshort pcm_frame[320];
speex_bits_read_from(bits, reinterpret_cast<char *>(encoded_frame_data), encoded_frame_length);
int result = speex_decode_int(dec_state, bits, pcm_frame);
if (result == 0) {
env->SetByteArrayRegion(out_frame, 0, sizeof(pcm_frame), reinterpret_cast<jbyte *>(pcm_frame));
}
if (has_header_byte) {
// Restore the first byte, so that the encoded_frame_data pointer points to the original address
encoded_frame_data--;
}
auto *out_frame_data = reinterpret_cast<spx_int16_t *>(env->GetDirectBufferAddress(out_frame));
auto *bits = reinterpret_cast<SpeexBits *>(env->GetLongField(thiz, speexDecBits));
auto *dec_state = reinterpret_cast<void *>(env->GetLongField(thiz, speexDecState));
int offset = has_header_byte ? 1 : 0;
speex_bits_read_from(bits, reinterpret_cast<char *>(encoded_frame_data)+offset, encoded_frame_length-offset);
int result = speex_decode_int(dec_state, bits, out_frame_data);
env->ReleaseByteArrayElements(encoded_frame, encoded_frame_data, 0);
return result;
}
Expand Down Expand Up @@ -56,4 +48,11 @@ Java_com_example_speex_1codec_SpeexCodec_destroyDecState(JNIEnv *env, jobject th
jlong dec_state) {
auto *state = reinterpret_cast<void *>(dec_state);
speex_decoder_destroy(state);
}
extern "C"
JNIEXPORT void JNICALL
Java_com_example_speex_1codec_SpeexCodec_initNative(JNIEnv *env, jobject thiz) {
jclass clazz = env->GetObjectClass(thiz);
speexDecBits = env->GetFieldID(clazz, "speexDecBits", "J");
speexDecState = env->GetFieldID(clazz, "speexDecState", "J");
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package com.example.speex_codec

import android.media.MediaCodec
import java.nio.ByteBuffer

class SpeexCodec(private val sampleRate: Long, private val bitRate: Int): AutoCloseable {
init {
initNative()
}
private val speexDecBits: Long = initSpeexBits()
private val speexDecState: Long = initDecState(sampleRate, bitRate)

Expand All @@ -12,7 +16,7 @@ class SpeexCodec(private val sampleRate: Long, private val bitRate: Int): AutoCl
* @param decodedFrame The buffer to store the decoded frame in.
*
*/
fun decodeFrame(encodedFrame: ByteArray, decodedFrame: ByteArray, hasHeaderByte: Boolean = true): SpeexDecodeResult {
fun decodeFrame(encodedFrame: ByteArray, decodedFrame: ByteBuffer, hasHeaderByte: Boolean = true): SpeexDecodeResult {
return SpeexDecodeResult.fromInt(decode(encodedFrame, decodedFrame, hasHeaderByte))
}

Expand All @@ -21,7 +25,8 @@ class SpeexCodec(private val sampleRate: Long, private val bitRate: Int): AutoCl
destroyDecState(speexDecState)
}

private external fun decode(encodedFrame: ByteArray, decodedFrame: ByteArray, hasHeaderByte: Boolean): Int
private external fun initNative()
private external fun decode(encodedFrame: ByteArray, decodedFrame: ByteBuffer, hasHeaderByte: Boolean): Int
private external fun initSpeexBits(): Long
private external fun initDecState(sampleRate: Long, bitRate: Int): Long
private external fun destroySpeexBits(speexBits: Long)
Expand Down

0 comments on commit 014c8f5

Please sign in to comment.