Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Significantly improve whisper.cpp inference quality #1148

Merged
merged 25 commits into from
Aug 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4767223
Fix MSVC compile error C3688
bobqianic Jul 26, 2023
38eaeff
Significantly improve inference quality
bobqianic Aug 1, 2023
2dd6884
Merge branch 'ggerganov:master' into master
bobqianic Aug 1, 2023
4ebe450
Significantly improve inference quality
bobqianic Aug 2, 2023
6f445d1
Addressed a few minor issues
bobqianic Aug 2, 2023
7f690dd
Significantly improve inference quality
bobqianic Aug 2, 2023
527d7c6
Merge branch 'ggerganov:master' into master
bobqianic Aug 3, 2023
f3e7774
Add annotation and performance improvement
bobqianic Aug 3, 2023
95be6dc
Calculate FFT only when fft_in are not all zero
bobqianic Aug 3, 2023
bd1dbd1
Some minor performance improvement
bobqianic Aug 4, 2023
2c49c9b
Fixed a bug impacting inference quality
bobqianic Aug 4, 2023
5df242c
Merge branch 'ggerganov:master' into master
bobqianic Aug 11, 2023
e40ec27
The first version after all the analysis is completed.
bobqianic Aug 11, 2023
715bf61
Fix some bugs and add debug mode
bobqianic Aug 12, 2023
3fe41d5
Fixed several bugs
bobqianic Aug 12, 2023
36b0df7
Temporarily disable speed-up mode and add debug mode.
bobqianic Aug 13, 2023
444b59a
Add debug mode
bobqianic Aug 13, 2023
308f490
Disable speed-up mode and add debug mode
bobqianic Aug 13, 2023
252f807
Fix CI error (#1)
bobqianic Aug 13, 2023
0a5f435
Fixed several bugs including [BLANK_AUDIO] problem
bobqianic Aug 13, 2023
65fd0e1
Remove Hard-coded hann window
bobqianic Aug 13, 2023
386ef32
Some Final Fix (#2)
bobqianic Aug 14, 2023
241df86
Merge branch 'master' into master
bobqianic Aug 25, 2023
22d348c
whisper : minor coding style changes
ggerganov Aug 27, 2023
590a12e
whisper : remove debug from public API
ggerganov Aug 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct whisper_params {
float logprob_thold = -1.00f;

bool speed_up = false;
bool debug_mode = false;
bool translate = false;
bool detect_language = false;
bool diarize = false;
Expand Down Expand Up @@ -134,7 +135,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
Expand Down Expand Up @@ -188,7 +190,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
Expand Down Expand Up @@ -893,6 +896,7 @@ int main(int argc, char ** argv) {
wparams.split_on_word = params.split_on_word;

wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode;

wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

Expand Down
196 changes: 114 additions & 82 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2445,40 +2445,50 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
}
}

static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> &hann, const float *samples,
int n_samples, int fft_size, int fft_step, int n_threads,
const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
std::vector<float> fft_in(fft_size, 0.0);
std::vector<float> fft_out(2 * fft_size);
int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);

for (int i = ith; i < mel.n_len; i += n_threads) {
const int offset = i * fft_step;

// apply Hanning window
for (int j = 0; j < fft_size; j++) {
if (offset + j < n_samples) {
fft_in[j] = hann[j] * samples[offset + j];
} else {
fft_in[j] = 0.0;
}
}
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
if (output.size() < length) {
output.resize(length);
}
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
}

// FFT -> mag^2
fft(fft_in, fft_out);
return true;
}

for (int j = 0; j < fft_size; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
int n_samples, int frame_size, int frame_step, int n_threads,
const whisper_filters & filters, whisper_mel & mel) {
std::vector<float> fft_in(frame_size, 0.0);
std::vector<float> fft_out(2 * frame_step);
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
int n_fft = 1 + (frame_size / 2);
int i = ith;

// calculate the FFT only for frames where fft_in is not all zero
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
const int offset = i * frame_step;

// apply Hanning window (~10% faster)
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
fft_in[j] = hann[j] * samples[offset + j];
}
for (int j = 1; j < fft_size / 2; j++) {
fft_out[j] += fft_out[fft_size - j];
// fill the rest with zeros
if (n_samples - offset < frame_size) {
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
}

if (speed_up) {
// scale down in the frequency domain results in a speed up in the time domain
for (int j = 0; j < n_fft; j++) {
fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
}
// FFT
fft(fft_in, fft_out);

// Calculate modulus^2 of complex numbers
// Using pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) here appears to cause an inference quality problem, so the squares are computed by explicit multiplication instead.
for (int j = 0; j < frame_size; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
}

// mel spectrogram
Expand All @@ -2489,10 +2499,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
int k = 0;
for (k = 0; k < n_fft - 3; k += 4) {
sum +=
fft_out[k + 0] * filters.data[j*n_fft + k + 0] +
fft_out[k + 1] * filters.data[j*n_fft + k + 1] +
fft_out[k + 2] * filters.data[j*n_fft + k + 2] +
fft_out[k + 3] * filters.data[j*n_fft + k + 3];
fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
}

// handle n_fft remainder
Expand All @@ -2505,68 +2515,73 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
mel.data[j * mel.n_len + i] = sum;
}
}

// Otherwise fft_out are all zero
double sum = log10(1e-10);
for (; i < mel.n_len; i += n_threads) {
for (int j = 0; j < mel.n_mel; j++) {
mel.data[j * mel.n_len + i] = sum;
}
}
}

// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
static bool log_mel_spectrogram(
whisper_state & wstate,
const float * samples,
whisper_state & wstate,
const float * samples,
const int n_samples,
const int /*sample_rate*/,
const int fft_size,
const int fft_step,
const int frame_size,
const int frame_step,
const int n_mel,
const int n_threads,
const whisper_filters & filters,
const bool speed_up,
whisper_mel & mel) {
const whisper_filters & filters,
const bool debug,
whisper_mel & mel) {
const int64_t t_start_us = ggml_time_us();

// Hanning window
// Hanning window (Use cosf to eliminate difference)
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
std::vector<float> hann;
hann.resize(fft_size);
for (int i = 0; i < fft_size; i++) {
hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
}

mel.n_mel = n_mel;
mel.n_len = n_samples/fft_step;
mel.n_len_org = mel.n_len;
hann_window(frame_size, true, hann);

std::vector<float> samples_padded;

// pad audio with at least one extra chunk of zeros
{
const int pad = (100*WHISPER_CHUNK_SIZE)/2;
// Calculate the length of padding
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
int64_t stage_2_pad = frame_size / 2;

if (mel.n_len % pad != 0) {
mel.n_len = (mel.n_len/pad + 1)*pad;
}
mel.n_len += pad;
// Initialize a vector and copy data from C array to it.
std::vector<float> samples_padded;
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);

samples_padded.resize(mel.n_len*fft_step);
memcpy(samples_padded.data(), samples, n_samples*sizeof(float));
memset(samples_padded.data() + n_samples, 0, (mel.n_len*fft_step - n_samples)*sizeof(float));
// pad 30 seconds of zeros at the end of audio (480,000 samples), followed by a further frame_size / 2 zeros (200 samples) in place of a reflective pad at the end
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);

samples = samples_padded.data();
}
// reflective pad 200 samples at the beginning of audio
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());

mel.data.resize(mel.n_mel*mel.n_len);
mel.n_mel = n_mel;
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
// Calculate number of frames + remove the last frame
mel.n_len = (samples_padded.size() - frame_size) / frame_step;
// Calculate semi-padded sample length to ensure compatibility
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
mel.data.resize(mel.n_mel * mel.n_len);

//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);

{
std::vector<std::thread> workers(n_threads - 1);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw] = std::thread(
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples,
n_samples, fft_size, fft_step, n_threads,
std::cref(filters), speed_up, std::ref(mel));
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
std::cref(filters), std::ref(mel));
}

// main thread
log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);

for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw].join();
Expand All @@ -2580,7 +2595,6 @@ static bool log_mel_spectrogram(
mmax = mel.data[i];
}
}
//printf("%s: max = %f\n", __func__, mmax);

mmax -= 8.0;

Expand All @@ -2594,7 +2608,16 @@ static bool log_mel_spectrogram(

wstate.t_mel_us += ggml_time_us() - t_start_us;

//printf("mel.n_len() = %d, divided by 1500: %f, n_samples / fft_step: %d\n", mel.n_len, mel.n_len / 1500.0, n_samples / fft_step);
// Dump log_mel_spectrogram
if (debug) {
std::ofstream outFile("log_mel_spectrogram.json");
outFile << "[";
for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
outFile << mel.data[i] << ", ";
}
outFile << mel.data[mel.data.size() - 1] << "]";
outFile.close();
}

return true;
}
Expand Down Expand Up @@ -3026,21 +3049,30 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
}

// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
log("%s: failed to compute mel spectrogram\n", __func__);
return -1;
}

return 0;
}

// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
}

// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
// TODO

// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
// TODO

// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
// TODO

int whisper_set_mel_with_state(
struct whisper_context * /*ctx*/,
struct whisper_state * state,
Expand Down Expand Up @@ -3492,6 +3524,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.max_tokens =*/ 0,

/*.speed_up =*/ false,
/*.debug_mode =*/ false,
/*.audio_ctx =*/ 0,

/*.tdrz_enable =*/ false,
Expand Down Expand Up @@ -3653,7 +3686,7 @@ static void whisper_process_logits(
WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);

// extract the logits for the last token
// we will be mutating and therefore we don't want to use the ctx.logits buffer directly
// we will be mutating, and therefore we don't want to use the ctx.logits buffer directly
auto & probs = decoder.probs;
auto & logits = decoder.logits;
auto & logprobs = decoder.logprobs;
Expand Down Expand Up @@ -4056,10 +4089,9 @@ int whisper_full_with_state(

// compute log mel spectrogram
if (params.speed_up) {
if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
log("%s: failed to compute log mel spectrogram\n", __func__);
return -1;
}
// TODO: Replace PV with more advanced algorithm
log("%s: failed to compute log mel spectrogram\n", __func__);
return -1;
} else {
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
log("%s: failed to compute log mel spectrogram\n", __func__);
Expand Down Expand Up @@ -4095,8 +4127,8 @@ int whisper_full_with_state(
const int seek_start = params.offset_ms/10;
const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;

// if length of spectrogram is less than 1s (100 samples), then return
// basically don't process anything that is less than 1s
// if length of spectrogram is less than 1.0s (100 frames), then return
// basically don't process anything that is less than 1.0s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
return 0;
Expand Down
1 change: 1 addition & 0 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ extern "C" {
// [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel)
int audio_ctx; // overwrite the audio context size (0 = use default)

// [EXPERIMENTAL] [TDRZ] tinydiarize
Expand Down
Loading