Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Significantly improve whisper.cpp inference quality #1148

Merged
merged 25 commits into from
Aug 27, 2023
Merged

Significantly improve whisper.cpp inference quality #1148

merged 25 commits into from
Aug 27, 2023

Conversation

bobqianic
Copy link
Collaborator

@bobqianic bobqianic commented Aug 1, 2023

We've wrapped up our analysis comparing the log_mel_spectrogram generation between whisper.cpp and OpenAI's Whisper.

To summarize the main issues we found in whisper.cpp:

  1. The Stage-1 padding (zero padding) is inadequate. While OpenAI's Whisper uses a padding of 480,000 samples, whisper.cpp only goes between 240,000 and 480,000.

  2. Stage-2 (reflective padding) is missing. This oversight can introduce edge effects when whisper.cpp processes the STFT, potentially causing spectral leakage.

  3. The frame count calculation isn't accurate.

  4. After performing the FFT, it mistakenly aggregates the amplitudes from symmetrical frequency bins.

On top of these, whisper.cpp presents a couple of secondary concerns:

  1. The trig functions in the C++ library default to FP64 computations, in contrast to PyTorch's FP32. This leads to significant discrepancies, especially with smaller angles.

  2. While whisper.cpp offers a feature to shift from mono to a kind of simulated stereo, it's uncertain if the model fully supports stereo audio.

  3. whisper.cpp doesn't remove the last frame like OpenAI's Whisper does, which could lead to some potential issues.

With these findings in hand, we're set to fix whisper.cpp.

Uncovering these issues wouldn't have been possible without the collective support and assistance from everyone. A heartfelt thanks to @regularfry for his invaluable technical expertise in signal processing and for suggesting solutions. I'm grateful to @gauvainjl for introducing the WER testing method. A special shoutout to @ggerganov for his unwavering support. And last but not least, a big thank you to the entire community for their continued attention to this project.

Full report

Whisper_FFT

After conducting some experiments, I found that the smaller the model size, the more significant the improvement in inference quality. There is almost no noticeable difference when using the large model.

diffusion2023-07-03.wav ggml-model-tiny.bin
A complete comparison is here.

image

diffusion2023-07-03.wav ggml-model-base.bin
A complete comparison is here.

image

a13.wav
A complete comparison is currently unavailable.

image

Instead of simply using 'add_compile_options(/utf-8)' to address the MSVC compile error C3688, a better approach would be to handle it in a way that prevents passing '/utf-8' to NVCC.
In the function `log_mel_spectrogram_worker_thread`, there's an array out-of-bounds issue occurring during the calculation of complex number moduli. This issue is causing disruptions in the FFT spectrum, which, in turn, is reducing the quality of inference.
At last, I've pinpointed the actual source of the problem. Given that the frequency spectrum generated from real input data is symmetrical around the Nyquist frequency, there's a for-loop within the `log_mel_spectrogram_worker_thread` function that attempts to fold the frequency spectrum. Regrettably, a bug within this for-loop is causing a frame shift in the frequency spectrum. The previous attempt to remedy this, which involved using `fft_size + 1` when calculating the modulus, was merely a band-aid solution and did not address the underlying issue.
@bobqianic bobqianic closed this Aug 2, 2023
Fixed the issue of `fft_out` continuously expanding. Resolved the fallback caused by using 'break' instead of `fft_in[j] = 0`.
@bobqianic bobqianic reopened this Aug 2, 2023
@bobqianic bobqianic closed this Aug 2, 2023
Thanks for your patience everyone. It's finally sorted out. Now, the right side of the FFT spectrum is being flipped over to the left, and the amplitudes at corresponding positions on the left and right are added together (the spectrum on the left needs to be shifted by one position), then the average is calculated. FFT_OUT[0] is no longer discarded, making full use of the limited space to pack in more information.
@bobqianic bobqianic reopened this Aug 2, 2023
@bobqianic
Copy link
Collaborator Author

bobqianic commented Aug 3, 2023

Interestingly, when the Speed-up mode is switched on – and it's worth noting that it's off by default – the system occasionally generates numerous duplicate sentences. Importantly, this issue doesn't appear when the Speed-up mode is deactivated. The cause of this problem remains unclear at the moment, but rest assured, I'm working diligently to resolve it as quickly as possible.

image

Significant quality improvement speedup mode ON ggml-model-tiny.bin

image

What's that?

image

It seems there might be a bug in the Hann window calculation within whisper.cpp. According to OpenAI's Whisper, they employ torch.hann_window(). Upon examining how Torch implements this function, it becomes evident that they use N - 1 as the denominator, where N is the size of the Hann window. Consequently, we should be using FFT_SIZE - 1 in our calculations.

image

whisper.cpp/whisper.cpp

Lines 2507 to 2512 in a792c40

// Hanning window
std::vector<float> hann;
hann.resize(fft_size);
for (int i = 0; i < fft_size; i++) {
hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
}

Copy link
Owner

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking into this. I always suspected that there could be a problem in the mel spectrogram computation. I even created an issue about this topic: #568

Cannot test at the moment, but would be great if more people give this a try and confirm the improvement.

whisper.cpp Outdated
int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);

for (int i = ith; i < mel.n_len; i += n_threads) {
std::vector<float> fft_in(fft_size, 0.0);
std::vector<float> fft_out(2 * fft_size);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the benefit of moving these vectors inside the loop?

Copy link
Collaborator Author

@bobqianic bobqianic Aug 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Firstly, moving std::vector<float> fft_in(fft_size, 0.0); inside the loop can take advantage of the built-in padding feature Fastest way to reset every value of std::vector to 0 when assigning vectors, which theoretically should be faster than manually using a for loop. This section of code could theoretically be optimized further by removing the if statement. The reason for moving std::vector<float> fft_out(2 * fft_size); is that during debugging, I found that fft_out would continuously grow in length after the FFT computation, far exceeding the theoretical length of 800.

Sorry... I've conducted further testing, and it appears that my initial reasons for moving them into the loop were not correct.

This comment was marked as off-topic.

@bobqianic
Copy link
Collaborator Author

I believe the most effective solution to address these issues is a complete rewrite of the log_mel_spectrogram function in whisper.cpp. We should do this in line with the guidelines provided in audio.py and the Torch documentation. Generating a correct log_mel_spectrogram is crucial as the Whisper model uses this as its input.

@bobqianic
Copy link
Collaborator Author

Regarding the issue of slow FFT computation for non-N^2 cases, I came across an article yesterday. It suggested a mixed computational approach using three different algorithms, which could potentially speed up non-N^2 FFT computations by nearly 84 times compared to the Radix-2 Cooley-Tukey algorithm. The top layer initially uses the Mixed-Radix Cooley-Tukey algorithm for computation, followed by the PFA (Prime Factor Algorithm) for 63 and 60-point computations in the next layer. Finally, the WFTA (Winograd Fourier Transform Algorithm) is used for computations of 7, 9, 3, 4, and 5 points.

@bobqianic
Copy link
Collaborator Author

I would like to share some exciting news with you all. Now, in the Log_Mel function, when FFT_IN is all zeros, the FFT will not be computed, and the result will be output directly, significantly accelerating the calculation speed of the Log_Mel_Spectrogram. On the basis of the Sin/Cos_Cache optimization in #1142, the calculation time has been reduced by approximately 42%, which is a huge improvement! @ggerganov @AlexandrGraschenkov

Ryzen 7 5700X a13.wav

Original (ms) Sin/Cos_Cache (ms) Sin/Cos_Cache + Log_Mel_opt_v1 (ms) Sin/Cos_Cache + Log_Mel_opt_v2 (ms)
652.68 190.68 101.16 94.24
640.01 175.93 92.70 98.43
647.10 168.24 103.22 92.45
649.84 172.56 105.32 93.28
637.95 176.11 94.04 100.21
637.75 176.59 106.42 95.28
635.10 171.02 95.02 93.74
632.21 169.23 98.07 93.02
635.55 167.07 105.57 98.53
635.26 168.94 109.16 93.75
645.66 170.45 110.75 94.10

Ryzen 7 5700X diffusion2023-07-03.wav

Original (ms) Sin/Cos_Cache (ms) Sin/Cos_Cache + Log_Mel_opt_v1 (ms) Sin/Cos_Cache + Log_Mel_opt_v2 (ms)
5609.77 1478.99 859.58 857.02
5630.08 1467.94 859.60 844.49
5662.63 1467.00 849.40 858.90

@vadi2
Copy link
Contributor

vadi2 commented Aug 4, 2023

Excellent work!

To get a better perspective, how does it affect the overall transcription speed from start to end?

whisper.cpp Outdated
fft_in[j] = hann[j] * samples[offset + j];
}
// fill the rest with zeros
if (n_samples - offset < fft_size) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't doing a zero fill here problematic? Let's say you've got an fft_size of 256, and n_samples - offset happens to come out to 128. That means you're creating a cliff edge between fft_in[127] and fft_in[128]. The Hanning window is at its maximum value there, so you get no smoothing at all, it's going to be all artefacty.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question! So, I am writing a more detailed code analysis note, comparing the differences in the methods of generating log mel spectrograms between OpenAI's whisper and whisper.cpp. I have already completed part of it, and I will publish this finished portion shortly.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've got a sneaking suspicion it might not be easy to get this right without doing some rework on the audio pipeline. Ideally you'd want to be saving unprocessed samples for the next round so you're always working with a full buffer.

@bobqianic
Copy link
Collaborator Author

bobqianic commented Aug 6, 2023

This note will compare the differences between OpenAI and whisper.cpp in generating log mel spectrograms, following the sequence of steps used by OpenAI's whisper to create the log mel spectrogram.

Click me

Part-0: Introduce the Hyperparameters

whisper/audio.py

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token

whisper.cpp/whisper.h

Lines 22 to 26 in a32c4aa

#define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT 400
#define WHISPER_N_MEL 80
#define WHISPER_HOP_LENGTH 160
#define WHISPER_CHUNK_SIZE 30

How are these hardcoded hyperparameters obtained? We can find the answers by looking at OpenAI's paper on Whisper. In the paper, it mentions that the signal is 're-sampled to 16,000Hz,' so we know that SAMPLE_RATE equals 16000. Similarly, there's an '80-channel log-magnitude Mel spectrogram,' which means N_MELS must be 80. The paper also talks about '25-millisecond windows,' so N_FFT can be calculated as 16,000 multiplied by 0.025, resulting in 400. Lastly, it mentions 'a stride of 10 milliseconds,' allowing us to calculate HOP_LENGTH by multiplying 16,000 by 0.010, resulting in 160. The most important description is a stride of 10 milliseconds. That is to say, when a window of 25 milliseconds moves across the signal, it moves 10 milliseconds each time, meaning that adjacent windows will have an overlap of 15 milliseconds. This ensures that no information will be lost, and the captured information will be more complete and continuous.

image

Part-1: Sample Preprocessing

Although OpenAI has not specified whether the whisper model can support higher sample rate audio (greater than 16KHz), for convenience, we will follow OpenAI's standard and assume that whisper can only support up to 16KHz. To facilitate subsequent calculations, we need to decompress the compressed audio first.

whisper/audio.py

def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary.  Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0

const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (stereo) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}

Notice that np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 reads the raw bytes and converts them into a 16-bit integer array. The .flatten() method ensures the array is 1-dimensional. Then, the values are converted to 32-bit floating-point numbers (FP32) and normalized to the range between -1 and 0.999969, which corresponds to the range of 16-bit signed integers divided by 32768.0

image

A strange thing is that whisper.cpp, based on the settings, will try to convert stereo audio into mono, or turn mono into fake stereo without using ffmpeg, whereas OpenAI's whisper directly uses ffmpeg to convert it into mono. Although I can confirm that the method whisper.cpp uses to change the channel count is correct (according to Multimedia Programming Interface and Data Specifications 1.0 (Page 59)), I am not clear whether Whisper actually supports stereo audio or not.

image

Part-2: Sample Padding

In OpenAI's Whisper, padding of the sample is divided into two stages. The first stage is carried out in the log_mel_spectrogram function, and the second stage is in the stft function.

Sample Padding (Stage 1)

whisper/transcribe.py

    # Pad 30-seconds of silence to the input audio, for slicing
    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
    content_frames = mel.shape[-1] - N_FRAMES

whisper/audio.py

import torch.nn.functional as F

"""
Here a lot of code is omitted.
"""

def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = N_MELS,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec

We can find that the log_mel_spectrogram function is called in whisper/transcribe.py to generate the log mel spectrogram. This function is located in whisper/audio.py, where we can see that as long as padding > 0, the audio will be padded. Since we passed the argument padding=N_SAMPLES when calling it in whisper/transcribe.py, the audio will definitely be padded.

image
image

According to Torch's documentation and the code in whisper/audio.py, we can conclude that its padding strategy is quite simple, just appending a lot of zeros (480,000 samples) to the end of the audio, equivalent to 30 seconds of blank audio, since the default values are value=0, mode='constant'. In fact, we can conduct a small experiment to verify our idea.

import torch.nn.functional as F
import torch


if __name__ == '__main__':
    t1d = torch.empty(5)
    print("Original:", t1d)
    t1d_new = F.pad(t1d, (0, 5))
    print("New:", t1d_new)
Original: tensor([1.0010e-38, 1.0469e-38, 9.7347e-39, 9.0919e-39, 1.0561e-38])
New: tensor([1.0010e-38, 1.0469e-38, 9.7347e-39, 9.0919e-39, 1.0561e-38, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])

Let's see how whisper.cpp implements Sample Padding?

if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10;
}

whisper.cpp/whisper.cpp

Lines 4774 to 4829 in b948361

int whisper_full_parallel(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * samples,
int n_samples,
int n_processors) {
if (n_processors == 1) {
return whisper_full(ctx, params, samples, n_samples);
}
int ret = 0;
// prepare separate states for each thread
std::vector<whisper_state*> states;
const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
// the calling thread will process the first chunk
// while the other threads will process the remaining chunks
std::vector<std::thread> workers(n_processors - 1);
for (int i = 0; i < n_processors - 1; ++i) {
// create a new state for each thread
states.push_back(whisper_init_state(ctx));
const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
auto params_cur = params;
params_cur.offset_ms = 0;
params_cur.print_progress = false;
params_cur.print_realtime = false;
params_cur.new_segment_callback = nullptr;
params_cur.new_segment_callback_user_data = nullptr;
params_cur.progress_callback = nullptr;
params_cur.progress_callback_user_data = nullptr;
workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
}
{
auto params_cur = params;
// We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
params_cur.print_realtime = false;
// Run the first transformation using default state but only for the first chunk.
ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
}
for (int i = 0; i < n_processors - 1; ++i) {
workers[i].join();
}

whisper.cpp/whisper.cpp

Lines 4027 to 4049 in b948361

int whisper_full_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
const float * samples,
int n_samples) {
// clear old results
auto & result_all = state->result_all;
result_all.clear();
// compute log mel spectrogram
if (params.speed_up) {
if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
log("%s: failed to compute log mel spectrogram\n", __func__);
return -1;
}
} else {
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
log("%s: failed to compute log mel spectrogram\n", __func__);
return -2;
}
}

whisper.cpp/whisper.cpp

Lines 2997 to 3004 in b948361

int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
log("%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
return 0;
}

whisper.cpp/whisper.cpp

Lines 2493 to 2582 in b948361

static bool log_mel_spectrogram(
whisper_state & wstate,
const float * samples,
const int n_samples,
const int /*sample_rate*/,
const int fft_size,
const int fft_step,
const int n_mel,
const int n_threads,
const whisper_filters & filters,
const bool speed_up,
whisper_mel & mel) {
const int64_t t_start_us = ggml_time_us();
// Hanning window
std::vector<float> hann;
hann.resize(fft_size);
for (int i = 0; i < fft_size; i++) {
hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
}
mel.n_mel = n_mel;
mel.n_len = n_samples/fft_step;
mel.n_len_org = mel.n_len;
std::vector<float> samples_padded;
// pad audio with at least one extra chunk of zeros
{
const int pad = (100*WHISPER_CHUNK_SIZE)/2;
if (mel.n_len % pad != 0) {
mel.n_len = (mel.n_len/pad + 1)*pad;
}
mel.n_len += pad;
samples_padded.resize(mel.n_len*fft_step);
memcpy(samples_padded.data(), samples, n_samples*sizeof(float));
memset(samples_padded.data() + n_samples, 0, (mel.n_len*fft_step - n_samples)*sizeof(float));
samples = samples_padded.data();
}
mel.data.resize(mel.n_mel*mel.n_len);
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
{
std::vector<std::thread> workers(n_threads - 1);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw] = std::thread(
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples,
n_samples, fft_size, fft_step, n_threads,
std::cref(filters), speed_up, std::ref(mel));
}
// main thread
log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw].join();
}
}
// clamping and normalization
double mmax = -1e20;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] > mmax) {
mmax = mel.data[i];
}
}
//printf("%s: max = %f\n", __func__, mmax);
mmax -= 8.0;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] < mmax) {
mel.data[i] = mmax;
}
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
wstate.t_mel_us += ggml_time_us() - t_start_us;
//printf("mel.n_len() = %d, divided by 1500: %f, n_samples / fft_step: %d\n", mel.n_len, mel.n_len / 1500.0, n_samples / fft_step);
return true;
}

To make it clear for everyone, I have pasted the most critical code for Sample Padding here.

     mel.n_mel     = n_mel; 
     mel.n_len     = n_samples/fft_step; 
     mel.n_len_org = mel.n_len; 
  
     std::vector<float> samples_padded; 
  
     // pad audio with at least one extra chunk of zeros 
     { 
         const int pad = (100*WHISPER_CHUNK_SIZE)/2; 
  
         if (mel.n_len % pad != 0) { 
             mel.n_len = (mel.n_len/pad + 1)*pad; 
         } 
         mel.n_len += pad; 
  
         samples_padded.resize(mel.n_len*fft_step); 
         memcpy(samples_padded.data(), samples, n_samples*sizeof(float)); 
         memset(samples_padded.data() + n_samples, 0, (mel.n_len*fft_step - n_samples)*sizeof(float)); 
  
         samples = samples_padded.data(); 
     } 
  
     mel.data.resize(mel.n_mel*mel.n_len); 

Actually, this code has many problems, but in this chapter, we are only discussing Sample Padding. In C/C++, the integer division operator '/' itself has a truncation function, meaning it rounds down by discarding the decimal part for numbers greater than 0. So, mel.n_len = (mel.n_len/pad + 1)*pad; will increase mel.n_len by at most 1500, and mel.n_len += pad; will further increase mel.n_len by 1500 on this basis. Overall, this will increase mel.n_len by 1500 to 3000, padding 240,000 to 480,000 samples, equivalent to 15 seconds to 30 seconds of blank audio. This is significantly less than the 480,000 samples padded by OpenAI's whisper.

Sample Padding (Stage 2)

stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
         win_length: Optional[int] = None, window: Optional[Tensor] = None,
         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
         onesided: Optional[bool] = None,
         return_complex: Optional[bool] = None) -> Tensor:

    if has_torch_function_unary(input):
        return handle_torch_function(
            stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
            window=window, center=center, pad_mode=pad_mode, normalized=normalized,
            onesided=onesided, return_complex=return_complex)
    # NOTE: Do not edit. This code will be removed once the forward-compatibility
    #       period is over for PR #73432
    if center:
        signal_dim = input.dim()
        extended_shape = [1] * (3 - signal_dim) + list(input.size())
        pad = int(n_fft // 2)
        input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
        input = input.view(input.shape[-signal_dim:])
    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
                    normalized, onesided, return_complex)

In OpenAI's Whisper, in addition to appending many zeros at the end of the Sample, there is also a padding hidden within the stft function. However, this padding strategy is more complex, adding 200 samples to both the beginning and the end of the sample, using the padding mode reflect. But what is reflect? I drew a graph to describe it simply.

image

For example, if I now want to pad this array of length 8 at both the beginning and end, adding 2 samples to each, and the mode is 'reflect'. It will end up looking like this, reflecting around the first sample at both the beginning and the end as the axis of symmetry. Similarly, we can conduct an experiment to verify our idea. The reason why it carries out such a complex operation here, first increasing the dimension to pad, then reducing it, is because torch.nn.functional.pad currently does not support direct reflect padding on a 1-D tensor. Otherwise, you will receive the following error message: 'NotImplementedError: Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now.'

import torch.nn.functional as F
import torch


if __name__ == '__main__':
    t1d = torch.empty(10)
    for x in range(10):
        t1d[x] = x
    print("Original:", t1d)
    extended_shape = [1] * (3 - t1d.dim()) + list(t1d.size())
    t1d_new = F.pad(t1d.view(extended_shape), (5, 5), mode="reflect")
    print("New:", t1d_new.view(t1d_new.shape[-t1d.dim():]))
Original: tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
New: tensor([5., 4., 3., 2., 1., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 8., 7., 6., 5., 4.])

How does whisper.cpp achieve this? Unfortunately, I carefully checked the code of whisper.cpp and found that it does not carry out this padding stage. If anyone finds that it does, please tell me where it is?

Part-3: Hann Window Generation

image

whisper/audio.py

window = torch.hann_window(N_FFT).to(audio.device)

OpenAI's Whisper directly uses the built-in torch.hann_window in PyTorch to generate a Hann Window. Since PyTorch invokes backend code for computation, and the calling process uses automatically generated code that is very complex, we are unable to understand how the internal calculation works. Fortunately, the official Torch documentation provides a detailed explanation.

image

We can conduct an experiment to check the output of torch.hann_window

import torch


if __name__ == '__main__':
    print(torch.hann_window(400))
Show Console
tensor([0.0000e+00, 6.1691e-05, 2.4673e-04, 5.5507e-04, 9.8664e-04, 1.5413e-03,
        2.2190e-03, 3.0195e-03, 3.9426e-03, 4.9882e-03, 6.1558e-03, 7.4453e-03,
        8.8564e-03, 1.0389e-02, 1.2042e-02, 1.3815e-02, 1.5708e-02, 1.7721e-02,
        1.9853e-02, 2.2103e-02, 2.4472e-02, 2.6957e-02, 2.9560e-02, 3.2278e-02,
        3.5112e-02, 3.8060e-02, 4.1123e-02, 4.4298e-02, 4.7586e-02, 5.0986e-02,
        5.4497e-02, 5.8117e-02, 6.1847e-02, 6.5684e-02, 6.9629e-02, 7.3680e-02,
        7.7836e-02, 8.2096e-02, 8.6460e-02, 9.0925e-02, 9.5491e-02, 1.0016e-01,
        1.0492e-01, 1.0978e-01, 1.1474e-01, 1.1980e-01, 1.2494e-01, 1.3018e-01,
        1.3552e-01, 1.4094e-01, 1.4645e-01, 1.5204e-01, 1.5773e-01, 1.6349e-01,
        1.6934e-01, 1.7528e-01, 1.8129e-01, 1.8738e-01, 1.9355e-01, 1.9979e-01,
        2.0611e-01, 2.1250e-01, 2.1896e-01, 2.2549e-01, 2.3209e-01, 2.3875e-01,
        2.4548e-01, 2.5227e-01, 2.5912e-01, 2.6604e-01, 2.7300e-01, 2.8003e-01,
        2.8711e-01, 2.9424e-01, 3.0143e-01, 3.0866e-01, 3.1594e-01, 3.2326e-01,
        3.3063e-01, 3.3804e-01, 3.4549e-01, 3.5298e-01, 3.6050e-01, 3.6806e-01,
        3.7566e-01, 3.8328e-01, 3.9093e-01, 3.9861e-01, 4.0631e-01, 4.1404e-01,
        4.2178e-01, 4.2955e-01, 4.3733e-01, 4.4513e-01, 4.5295e-01, 4.6077e-01,
        4.6860e-01, 4.7645e-01, 4.8429e-01, 4.9215e-01, 5.0000e-01, 5.0785e-01,
        5.1571e-01, 5.2355e-01, 5.3140e-01, 5.3923e-01, 5.4705e-01, 5.5487e-01,
        5.6267e-01, 5.7045e-01, 5.7822e-01, 5.8596e-01, 5.9369e-01, 6.0139e-01,
        6.0907e-01, 6.1672e-01, 6.2435e-01, 6.3194e-01, 6.3950e-01, 6.4702e-01,
        6.5451e-01, 6.6196e-01, 6.6937e-01, 6.7674e-01, 6.8406e-01, 6.9134e-01,
        6.9857e-01, 7.0576e-01, 7.1289e-01, 7.1997e-01, 7.2700e-01, 7.3396e-01,
        7.4088e-01, 7.4773e-01, 7.5452e-01, 7.6125e-01, 7.6791e-01, 7.7451e-01,
        7.8104e-01, 7.8750e-01, 7.9389e-01, 8.0021e-01, 8.0645e-01, 8.1262e-01,
        8.1871e-01, 8.2472e-01, 8.3066e-01, 8.3651e-01, 8.4227e-01, 8.4796e-01,
        8.5355e-01, 8.5906e-01, 8.6448e-01, 8.6982e-01, 8.7506e-01, 8.8020e-01,
        8.8526e-01, 8.9022e-01, 8.9508e-01, 8.9984e-01, 9.0451e-01, 9.0907e-01,
        9.1354e-01, 9.1790e-01, 9.2216e-01, 9.2632e-01, 9.3037e-01, 9.3432e-01,
        9.3815e-01, 9.4188e-01, 9.4550e-01, 9.4901e-01, 9.5241e-01, 9.5570e-01,
        9.5888e-01, 9.6194e-01, 9.6489e-01, 9.6772e-01, 9.7044e-01, 9.7304e-01,
        9.7553e-01, 9.7790e-01, 9.8015e-01, 9.8228e-01, 9.8429e-01, 9.8618e-01,
        9.8796e-01, 9.8961e-01, 9.9114e-01, 9.9255e-01, 9.9384e-01, 9.9501e-01,
        9.9606e-01, 9.9698e-01, 9.9778e-01, 9.9846e-01, 9.9901e-01, 9.9944e-01,
        9.9975e-01, 9.9994e-01, 1.0000e+00, 9.9994e-01, 9.9975e-01, 9.9944e-01,
        9.9901e-01, 9.9846e-01, 9.9778e-01, 9.9698e-01, 9.9606e-01, 9.9501e-01,
        9.9384e-01, 9.9255e-01, 9.9114e-01, 9.8961e-01, 9.8796e-01, 9.8618e-01,
        9.8429e-01, 9.8228e-01, 9.8015e-01, 9.7790e-01, 9.7553e-01, 9.7304e-01,
        9.7044e-01, 9.6772e-01, 9.6489e-01, 9.6194e-01, 9.5888e-01, 9.5570e-01,
        9.5241e-01, 9.4901e-01, 9.4550e-01, 9.4188e-01, 9.3815e-01, 9.3432e-01,
        9.3037e-01, 9.2632e-01, 9.2216e-01, 9.1790e-01, 9.1354e-01, 9.0907e-01,
        9.0451e-01, 8.9984e-01, 8.9508e-01, 8.9022e-01, 8.8526e-01, 8.8020e-01,
        8.7506e-01, 8.6982e-01, 8.6448e-01, 8.5906e-01, 8.5355e-01, 8.4796e-01,
        8.4227e-01, 8.3651e-01, 8.3066e-01, 8.2472e-01, 8.1871e-01, 8.1262e-01,
        8.0645e-01, 8.0021e-01, 7.9389e-01, 7.8750e-01, 7.8104e-01, 7.7451e-01,
        7.6791e-01, 7.6125e-01, 7.5452e-01, 7.4773e-01, 7.4088e-01, 7.3396e-01,
        7.2700e-01, 7.1997e-01, 7.1289e-01, 7.0576e-01, 6.9857e-01, 6.9134e-01,
        6.8406e-01, 6.7674e-01, 6.6937e-01, 6.6196e-01, 6.5451e-01, 6.4702e-01,
        6.3950e-01, 6.3194e-01, 6.2434e-01, 6.1672e-01, 6.0907e-01, 6.0139e-01,
        5.9369e-01, 5.8596e-01, 5.7822e-01, 5.7045e-01, 5.6267e-01, 5.5487e-01,
        5.4705e-01, 5.3923e-01, 5.3140e-01, 5.2355e-01, 5.1571e-01, 5.0785e-01,
        5.0000e-01, 4.9215e-01, 4.8429e-01, 4.7645e-01, 4.6860e-01, 4.6077e-01,
        4.5295e-01, 4.4513e-01, 4.3733e-01, 4.2955e-01, 4.2178e-01, 4.1404e-01,
        4.0631e-01, 3.9861e-01, 3.9093e-01, 3.8328e-01, 3.7565e-01, 3.6806e-01,
        3.6050e-01, 3.5298e-01, 3.4549e-01, 3.3804e-01, 3.3063e-01, 3.2326e-01,
        3.1594e-01, 3.0866e-01, 3.0143e-01, 2.9424e-01, 2.8711e-01, 2.8003e-01,
        2.7300e-01, 2.6604e-01, 2.5912e-01, 2.5227e-01, 2.4548e-01, 2.3875e-01,
        2.3209e-01, 2.2549e-01, 2.1896e-01, 2.1250e-01, 2.0611e-01, 1.9979e-01,
        1.9355e-01, 1.8738e-01, 1.8129e-01, 1.7528e-01, 1.6934e-01, 1.6349e-01,
        1.5773e-01, 1.5204e-01, 1.4645e-01, 1.4094e-01, 1.3552e-01, 1.3018e-01,
        1.2494e-01, 1.1980e-01, 1.1474e-01, 1.0978e-01, 1.0492e-01, 1.0016e-01,
        9.5491e-02, 9.0925e-02, 8.6460e-02, 8.2096e-02, 7.7836e-02, 7.3680e-02,
        6.9629e-02, 6.5684e-02, 6.1847e-02, 5.8117e-02, 5.4497e-02, 5.0986e-02,
        4.7586e-02, 4.4298e-02, 4.1123e-02, 3.8060e-02, 3.5112e-02, 3.2278e-02,
        2.9560e-02, 2.6957e-02, 2.4472e-02, 2.2103e-02, 1.9853e-02, 1.7721e-02,
        1.5708e-02, 1.3815e-02, 1.2042e-02, 1.0389e-02, 8.8564e-03, 7.4453e-03,
        6.1558e-03, 4.9882e-03, 3.9426e-03, 3.0195e-03, 2.2190e-03, 1.5413e-03,
        9.8664e-04, 5.5507e-04, 2.4673e-04, 6.1691e-05])

whisper.cpp/whisper.cpp

Lines 2507 to 2512 in b948361

// Hanning window
std::vector<float> hann;
hann.resize(fft_size);
for (int i = 0; i < fft_size; i++) {
hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
}

I had always thought that the implementation of generating the Hann window in whisper.cpp was problematic. But today, after taking a closer look, I realized that it is actually not problematic. This is because torch.hann_window is set to periodic=True by default, so N becomes window_size + 1, and therefore the denominator should indeed be fft_size. But we still need to compare the output results with the results output by torch.

#include <cmath>
#include <iostream>
#include <vector>
#include <corecrt_math_defines.h>


int main() {
    const int fft_size = 400;
    std::vector<float> hann;
    hann.resize(fft_size);
    for (int i = 0; i < fft_size; i++) {
        hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
    }

    // print results
    for (int i = 0; i < fft_size; i++) {
        if (i > 0 && i % 6 == 0) {
            std::cout << std::endl;
        }
        std::cout << hann[i];
        if (i + 1 < fft_size) {
            std::cout << ", ";
        }
    }
}
Show Console
0, 6.16838e-05, 0.00024672, 0.000555062, 0.000986636, 0.00154133,
0.00221902, 0.00301952, 0.00394265, 0.00498817, 0.00615583, 0.00744534,
0.00885637, 0.0103886, 0.0120416, 0.013815, 0.0157084, 0.0177213,
0.0198532, 0.0221035, 0.0244717, 0.0269573, 0.0295596, 0.032278,
0.0351118, 0.0380602, 0.0411227, 0.0442984, 0.0475865, 0.0509862,
0.0544967, 0.0581172, 0.0618467, 0.0656842, 0.069629, 0.0736799,
0.077836, 0.0820963, 0.0864597, 0.0909251, 0.0954915, 0.100158,
0.104922, 0.109785, 0.114743, 0.119797, 0.124944, 0.130184,
0.135516, 0.140937, 0.146447, 0.152044, 0.157726, 0.163494,
0.169344, 0.175276, 0.181288, 0.187379, 0.193546, 0.19979,
0.206107, 0.212497, 0.218958, 0.225489, 0.232087, 0.238751,
0.245479, 0.252271, 0.259123, 0.266035, 0.273005, 0.28003,
0.28711, 0.294243, 0.301426, 0.308658, 0.315938, 0.323263,
0.330631, 0.338041, 0.345491, 0.35298, 0.360504, 0.368063,
0.375655, 0.383277, 0.390928, 0.398606, 0.406309, 0.414035,
0.421783, 0.429549, 0.437333, 0.445133, 0.452946, 0.46077,
0.468605, 0.476447, 0.484295, 0.492146, 0.5, 0.507854,
0.515705, 0.523553, 0.531395, 0.53923, 0.547054, 0.554867,
0.562667, 0.570451, 0.578217, 0.585965, 0.593691, 0.601394,
0.609072, 0.616723, 0.624345, 0.631937, 0.639496, 0.64702,
0.654508, 0.661959, 0.669369, 0.676737, 0.684062, 0.691342,
0.698574, 0.705757, 0.71289, 0.71997, 0.726995, 0.733965,
0.740877, 0.747729, 0.754521, 0.761249, 0.767913, 0.774511,
0.781042, 0.787503, 0.793893, 0.80021, 0.806454, 0.812621, 
0.818712, 0.824724, 0.830656, 0.836506, 0.842274, 0.847956,
0.853553, 0.859063, 0.864484, 0.869816, 0.875056, 0.880203,
0.885257, 0.890215, 0.895078, 0.899842, 0.904508, 0.909075,
0.91354, 0.917904, 0.922164, 0.92632, 0.930371, 0.934316,
0.938153, 0.941883, 0.945503, 0.949014, 0.952413, 0.955702,
0.958877, 0.96194, 0.964888, 0.967722, 0.97044, 0.973043,
0.975528, 0.977897, 0.980147, 0.982279, 0.984292, 0.986185,
0.987958, 0.989611, 0.991144, 0.992555, 0.993844, 0.995012,
0.996057, 0.99698, 0.997781, 0.998459, 0.999013, 0.999445,
0.999753, 0.999938, 1, 0.999938, 0.999753, 0.999445,
0.999013, 0.998459, 0.997781, 0.99698, 0.996057, 0.995012,
0.993844, 0.992555, 0.991144, 0.989611, 0.987958, 0.986185,
0.984292, 0.982279, 0.980147, 0.977897, 0.975528, 0.973043,
0.97044, 0.967722, 0.964888, 0.96194, 0.958877, 0.955702,
0.952413, 0.949014, 0.945503, 0.941883, 0.938153, 0.934316,
0.930371, 0.92632, 0.922164, 0.917904, 0.91354, 0.909075, 
0.904508, 0.899842, 0.895078, 0.890215, 0.885257, 0.880203,
0.875056, 0.869816, 0.864484, 0.859063, 0.853553, 0.847956,
0.842274, 0.836506, 0.830656, 0.824724, 0.818712, 0.812621,
0.806454, 0.80021, 0.793893, 0.787503, 0.781042, 0.774511,
0.767913, 0.761249, 0.754521, 0.747729, 0.740877, 0.733965,
0.726995, 0.71997, 0.71289, 0.705757, 0.698574, 0.691342,
0.684062, 0.676737, 0.669369, 0.661959, 0.654508, 0.64702,
0.639496, 0.631937, 0.624345, 0.616723, 0.609072, 0.601394,
0.593691, 0.585965, 0.578217, 0.570451, 0.562667, 0.554867,
0.547054, 0.53923, 0.531395, 0.523553, 0.515705, 0.507854,
0.5, 0.492146, 0.484295, 0.476447, 0.468605, 0.46077,
0.452946, 0.445133, 0.437333, 0.429549, 0.421783, 0.414035,
0.406309, 0.398606, 0.390928, 0.383277, 0.375655, 0.368063,
0.360504, 0.35298, 0.345491, 0.338041, 0.330631, 0.323263,
0.315938, 0.308658, 0.301426, 0.294243, 0.28711, 0.28003,
0.273005, 0.266035, 0.259123, 0.252271, 0.245479, 0.238751, 
0.232087, 0.225489, 0.218958, 0.212497, 0.206107, 0.19979,
0.193546, 0.187379, 0.181288, 0.175276, 0.169344, 0.163494,
0.157726, 0.152044, 0.146447, 0.140937, 0.135516, 0.130184,
0.124944, 0.119797, 0.114743, 0.109785, 0.104922, 0.100158,
0.0954915, 0.0909251, 0.0864597, 0.0820963, 0.077836, 0.0736799,
0.069629, 0.0656842, 0.0618467, 0.0581172, 0.0544967, 0.0509862,
0.0475865, 0.0442984, 0.0411227, 0.0380602, 0.0351118, 0.032278,
0.0295596, 0.0269573, 0.0244717, 0.0221035, 0.0198532, 0.0177213,
0.0157084, 0.013815, 0.0120416, 0.0103886, 0.00885637, 0.00744534,
0.00615583, 0.00498817, 0.00394265, 0.00301952, 0.00221902, 0.00154133,
0.000986636, 0.000555062, 0.00024672, 6.16838e-05

I used the results output by torch as the numerator and the results from whisper.cpp as the denominator, dividing the corresponding items, and obtained the following results. I have tried using the PI defined by torch, 3.141592653589793, but the results did not change.

Show Console
-nan(ind), 1.00011, 1.00006, 1.00001, 1, 1.00001,
1, 0.999997, 0.999997, 0.999999, 0.999998, 1,
1, 0.999999, 1, 1, 1, 1,
0.999999, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1,
1, 1, 0.999999, 1, 0.999999, 0.999999,
0.999999, 0.999999, 0.999999, 0.999999, 0.999999, 0.999999,
0.999999, 0.999999, 0.999999, 0.999999, 0.999999, 0.999999,
0.999999, 0.999999, 0.999999, 0.999999, 0.999999, 0.999999,
0.999999, 0.999999, 0.999999, 0.999999, 0.999998, 0.999999,
0.999998, 0.999998, 0.999998, 0.999998, 0.999998, 0.999998,
0.999998, 0.999998, 0.999998, 0.999997, 0.999997, 0.999997,
0.999997, 0.999997, 0.999997, 0.999996, 0.999996, 0.999995,
0.999996, 0.999998, 0.999998, 0.999999, 0.999998, 1,
0.999998, 0.999999, 0.999997, 0.999997, 1, 0.999989,
1, 1.00001, 1.00006, 1.00011

regularfry: Torch's default internal representation is float32, and the C++ code is using float64. If you use torch.empty(1, dtype=torch.float64) you'll get matching results. It looks like torch uses std::cos under the hood so they'll match if the precision does.

import torch


if __name__ == '__main__':
    t = torch.empty(1)
    t[0] = 0.0157079632679
    print(float(torch.cos(t)))
0.9998766183853149
import torch


if __name__ == '__main__':
    t = torch.empty(1, dtype=torch.float64)
    t[0] = 0.0157079632679
    print(float(torch.cos(t)))
0.9998766324816614
#include <cmath>
#include <iostream>
#include <iomanip>


int main() {
    double t = 0.0157079632679;
    std::cout << std::setprecision(16) << cos(t) << std::endl;
}
0.9998766324816614

However, I think overall, the problem shouldn't be too significant, since we are using the float type, which only has 6 to 7 significant digits?

Part-4: Short-Time Fourier Transform

Traditional Fourier transforms describe the frequency components in a signal averaged over all time. The most important and informative aspects of signals like speech and music, however, is how these frequency components evolve over time.

whisper/audio.py

stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)

torch/functional.py

return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
                    normalized, onesided, return_complex)
image

OpenAI's Whisper directly uses PyTorch's torch.stft to perform STFT calculations. Most of the computations in torch.stft are implemented in C++, and although PyTorch provides documentation, the STFT is rather complex, and the documentation does not fully describe how the STFT is calculated, with many details left out. Fortunately, I successfully found the C++ code that PyTorch uses to calculate the STFT.

pytorch/aten/src/ATen/native/SpectralOps.cpp

Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
            const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt,
            const bool center, c10::string_view mode, const bool normalized,
            const optional<bool> onesidedOpt, const optional<bool> return_complexOpt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> window_maybe_owned = at::borrow_from_optional_tensor(window_opt);
  const Tensor& window = *window_maybe_owned;

  #define REPR(SS) \
    SS << "stft(" << self.toString() << self.sizes() << ", n_fft=" << n_fft \
       << ", hop_length=" << hop_length << ", win_length=" << win_length \
       << ", window="; \
    if (window.defined()) { \
      SS << window.toString() << "{" << window.sizes() << "}"; \
    } else { \
      SS << "None"; \
    } \
    SS << ", normalized=" << normalized << ", onesided="; \
    write_opt(SS, onesidedOpt) << ", return_complex="; \
    write_opt(SS, return_complexOpt) << ") "

  TORCH_CHECK(!window.defined() || window.device() == self.device(),
              "stft input and window must be on the same device but got self on ",
              self.device(), " and window on ", window.device())

  // default_init hop_length and win_length
  auto hop_length = hop_lengthOpt.value_or(n_fft >> 2);
  auto win_length = win_lengthOpt.value_or(n_fft);
  const bool return_complex = return_complexOpt.value_or(
      self.is_complex() || (window.defined() && window.is_complex()));
  if (!return_complex) {
    TORCH_CHECK(return_complexOpt.has_value(),
        "stft requires the return_complex parameter be given for real inputs, "
        "and will further require that return_complex=True in a future PyTorch release.");


    TORCH_WARN_ONCE(
        "stft with return_complex=False is deprecated. In a future pytorch "
        "release, stft will return complex tensors for all inputs, and "
        "return_complex=False will raise an error.\n"
        "Note: you can still call torch.view_as_real on the complex output to "
        "recover the old return format.");
  }

  if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) {
    std::ostringstream ss;
    REPR(ss) << ": expected a tensor of floating point or complex values";
    AT_ERROR(ss.str());
  }
  if (self.dim() > 2 || self.dim() < 1) {
    std::ostringstream ss;
    REPR(ss) << ": expected a 1D or 2D tensor";
    AT_ERROR(ss.str());
  }
  Tensor input = self;
  if (self.dim() == 1) {
    input = input.unsqueeze(0);
  }

  if (center) {
    const auto input_shape = input.sizes();
    const auto input_dim = input_shape.size();
    const auto extra_dims = std::max(size_t{3}, input_dim) - input_dim;
    const auto pad_amount = n_fft / 2;

    DimVector extended_shape(extra_dims, 1);
    extended_shape.append(input_shape.begin(), input_shape.end());
    input = at::pad(input.view(extended_shape), {pad_amount, pad_amount}, mode);
    input = input.view(IntArrayRef(input.sizes()).slice(extra_dims));
  }

  int64_t batch = input.size(0);
  int64_t len = input.size(1);
  if (n_fft <= 0 || n_fft > len) {
    std::ostringstream ss;
    REPR(ss) << ": expected 0 < n_fft < " << len
             << ", but got n_fft=" << win_length;
    AT_ERROR(ss.str());
  }
  if (hop_length <= 0) {
    std::ostringstream ss;
    REPR(ss) << ": expected hop_length > 0, but got hop_length=" << hop_length;
    AT_ERROR(ss.str());
  }
  if (win_length <= 0 || win_length > n_fft) {
    std::ostringstream ss;
    REPR(ss) << ": expected 0 < win_length <= n_fft, but got win_length="
             << win_length;
    AT_ERROR(ss.str());
  }
  if (window.defined() && (window.dim() != 1 || window.size(0) != win_length)) {
    std::ostringstream ss;
    REPR(ss) << ": expected a 1D window tensor of size equal to win_length="
             << win_length << ", but got window with size " << window.sizes();
    AT_ERROR(ss.str());
  }
  #undef REPR
  auto window_ = window;
  if (win_length < n_fft) {
    // pad center
    auto left = (n_fft - win_length) / 2;
    if (window.defined()) {
      window_ = at::zeros({n_fft}, window.options());
      window_.narrow(0, left, win_length).copy_(window);
    } else {
      window_ = at::zeros({n_fft}, self.options());
      window_.narrow(0, left, win_length).fill_(1);
    }
  }
  int64_t n_frames = 1 + (len - n_fft) / hop_length;
  // time2col
  input = input.as_strided(
    {batch, n_frames, n_fft},
    {input.stride(0), hop_length * input.stride(1), input.stride(1)}
  );
  if (window_.defined()) {
    input = input.mul(window_);
  }

  // FFT and transpose to get (batch x fft_size x num_frames)
  const bool complex_fft = input.is_complex();
  const auto onesided = onesidedOpt.value_or(!complex_fft);

  const fft_norm_mode norm = normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none;
  Tensor out;
  if (complex_fft) {
    TORCH_CHECK(!onesided, "Cannot have onesided output if window or input is complex");
    out = at::_fft_c2c(input, input.dim() - 1, static_cast<int64_t>(norm), /*forward=*/true);
  } else {
    out = at::_fft_r2c(input, input.dim() - 1, static_cast<int64_t>(norm), onesided);
  }
  out.transpose_(1, 2);

  if (self.dim() == 1) {
    out.squeeze_(0);
  }

  if (return_complex) {
    return out;
  } else {
    return at::view_as_real(out);
  }
}

Tensor stft(
    const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
    const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt,
    const bool normalized,
    const optional<bool> onesidedOpt, const optional<bool> return_complexOpt) {
  return at::stft(
      self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
      /*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
      return_complexOpt);
}

The STFT C++ code is written very obscurely and is not very easy to understand. Fortunately, PyTorch provides a set of Python APIs that allow you to directly call the tensor's member functions in Python, which is very convenient. The following will use this method to conduct an in-depth analysis of this program.

Tensor stft(
    const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
    const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt,
    const bool normalized,
    const optional<bool> onesidedOpt, const optional<bool> return_complexOpt) {
  return at::stft(
      self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
      /*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
      return_complexOpt);
}

The code above was written by PyTorch for the sake of backward compatibility. Up until now, the Stage-2 Padding in STFT has always been implemented using Python, so the speed can be relatively slow. PyTorch plans to implement this part of the computation using C++, but it seems that the migration is not yet fully completed. The STFT function first checks the input variables to ensure that they all meet the definition; if not, an error will be reported. Then it handles the default values. After that, it enters the formal calculation phase.

  Tensor input = self;
  if (self.dim() == 1) {
    input = input.unsqueeze(0);
  }
  1. Copy the input Tensor (self) into a newly created Tensor (input), and check if the input Tensor (self) is one-dimensional. If it is, then the copied Tensor (input) is transformed from one dimension to two dimensions. We can conduct an experiment to verify our idea.
import torch


if __name__ == '__main__':
    input = torch.empty(10)
    print(input.size())
    input = input.unsqueeze(0)
    print(input.size())
torch.Size([10])
torch.Size([1, 10])

We can clearly see that the Tensor has added a dimension on top of its original basis.

  if (center) {
    const auto input_shape = input.sizes();
    const auto input_dim = input_shape.size();
    const auto extra_dims = std::max(size_t{3}, input_dim) - input_dim;
    const auto pad_amount = n_fft / 2;

    DimVector extended_shape(extra_dims, 1);
    extended_shape.append(input_shape.begin(), input_shape.end());
    input = at::pad(input.view(extended_shape), {pad_amount, pad_amount}, mode);
    input = input.view(IntArrayRef(input.sizes()).slice(extra_dims));
  }
  1. Check whether it is in Center mode; if it is, half of n_fft will be padded at both the beginning and the end, but since we have already padded it, this step is directly skipped.
  int64_t batch = input.size(0);
  int64_t len = input.size(1);
  1. Assign the size of the copied Tensor (input) to two variables. From the experiment results above, we can learn that for a one-dimensional Tensor input, the batch will be assigned a value of 1, and len will be assigned the length of the one-dimensional Tensor.
  auto window_ = window;
  if (win_length < n_fft) {
    // pad center
    auto left = (n_fft - win_length) / 2;
    if (window.defined()) {
      window_ = at::zeros({n_fft}, window.options());
      window_.narrow(0, left, win_length).copy_(window);
    } else {
      window_ = at::zeros({n_fft}, self.options());
      window_.narrow(0, left, win_length).fill_(1);
    }
  }
  1. Copy the Tensor (window) into a newly created Tensor (window_), and check whether the size of the window is consistent with the size of n_fft. If it is inconsistent, padding will be applied. If the window is found to be undefined, a window of all ones will be created, meaning it will have no effect (as if it doesn't exist). Since we have already ensured that the size of the input window is consistent with the size of n_fft, the last two steps are directly skipped.
int64_t n_frames = 1 + (len - n_fft) / hop_length;
  1. Calculate the total number of frames. Note that integer division in C++ comes with a built-in function that rounds down.
  input = input.as_strided(
    {batch, n_frames, n_fft},
    {input.stride(0), hop_length * input.stride(1), input.stride(1)}
  );
  1. Split the Tensor (input) into frames of the specified size and quantity. But what is as_strided? Before understanding as_strided, we must first understand a concept called Stride. In PyTorch, the Stride of a Tensor refers to the distance that needs to be moved in memory to access the next element within a dimension. This is because all Tensors, regardless of how many dimensions they have, are stored in memory in a one-dimensional form. as_strided can change the Size and Stride of a Tensor without copying in memory, achieving our goal of splitting frames. We can conduct an experiment to verify our idea.
import torch
import torch.nn.functional as F


if __name__ == '__main__':
    test = torch.empty(200)
    for i in range(200):
        test[i] = i

    signal_dim = test.dim()
    extended_shape = [1] * (3 - signal_dim) + list(test.size())
    pad = int(50 // 2)
    test = F.pad(test.view(extended_shape), [pad, pad], "reflect")
    test = test.view(test.shape[-signal_dim:])
    print("Input:", test)
    new = test.unsqueeze(0)
    print("Stride:", new.stride())
    new = new.as_strided((1, 11, 50), (250, 20, 1))
    print("Result:", new)
Show Console
Input: tensor([ 25.,  24.,  23.,  22.,  21.,  20.,  19.,  18.,  17.,  16.,  15.,  14.,
         13.,  12.,  11.,  10.,   9.,   8.,   7.,   6.,   5.,   4.,   3.,   2.,
          1.,   0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
         11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,
         23.,  24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,  34.,
         35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,  46.,
         47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,
         59.,  60.,  61.,  62.,  63.,  64.,  65.,  66.,  67.,  68.,  69.,  70.,
         71.,  72.,  73.,  74.,  75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,
         83.,  84.,  85.,  86.,  87.,  88.,  89.,  90.,  91.,  92.,  93.,  94.,
         95.,  96.,  97.,  98.,  99., 100., 101., 102., 103., 104., 105., 106.,
        107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118.,
        119., 120., 121., 122., 123., 124., 125., 126., 127., 128., 129., 130.,
        131., 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142.,
        143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153., 154.,
        155., 156., 157., 158., 159., 160., 161., 162., 163., 164., 165., 166.,
        167., 168., 169., 170., 171., 172., 173., 174., 175., 176., 177., 178.,
        179., 180., 181., 182., 183., 184., 185., 186., 187., 188., 189., 190.,
        191., 192., 193., 194., 195., 196., 197., 198., 199., 198., 197., 196.,
        195., 194., 193., 192., 191., 190., 189., 188., 187., 186., 185., 184.,
        183., 182., 181., 180., 179., 178., 177., 176., 175., 174.])
Stride: (250, 1)
Result: tensor([[[ 25.,  24.,  23.,  22.,  21.,  20.,  19.,  18.,  17.,  16.,  15.,
           14.,  13.,  12.,  11.,  10.,   9.,   8.,   7.,   6.,   5.,   4.,
            3.,   2.,   1.,   0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,
            8.,   9.,  10.,  11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,
           19.,  20.,  21.,  22.,  23.,  24.],
         [  5.,   4.,   3.,   2.,   1.,   0.,   1.,   2.,   3.,   4.,   5.,
            6.,   7.,   8.,   9.,  10.,  11.,  12.,  13.,  14.,  15.,  16.,
           17.,  18.,  19.,  20.,  21.,  22.,  23.,  24.,  25.,  26.,  27.,
           28.,  29.,  30.,  31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,
           39.,  40.,  41.,  42.,  43.,  44.],
         [ 15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,  24.,  25.,
           26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,  34.,  35.,  36.,
           37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,  46.,  47.,
           48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,
           59.,  60.,  61.,  62.,  63.,  64.],
         [ 35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,
           46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,
           57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,  66.,  67.,
           68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,  77.,  78.,
           79.,  80.,  81.,  82.,  83.,  84.],
         [ 55.,  56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,
           66.,  67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,
           77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,
           88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,
           99., 100., 101., 102., 103., 104.],
         [ 75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,
           86.,  87.,  88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,
           97.,  98.,  99., 100., 101., 102., 103., 104., 105., 106., 107.,
          108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118.,
          119., 120., 121., 122., 123., 124.],
         [ 95.,  96.,  97.,  98.,  99., 100., 101., 102., 103., 104., 105.,
          106., 107., 108., 109., 110., 111., 112., 113., 114., 115., 116.,
          117., 118., 119., 120., 121., 122., 123., 124., 125., 126., 127.,
          128., 129., 130., 131., 132., 133., 134., 135., 136., 137., 138.,
          139., 140., 141., 142., 143., 144.],
         [115., 116., 117., 118., 119., 120., 121., 122., 123., 124., 125.,
          126., 127., 128., 129., 130., 131., 132., 133., 134., 135., 136.,
          137., 138., 139., 140., 141., 142., 143., 144., 145., 146., 147.,
          148., 149., 150., 151., 152., 153., 154., 155., 156., 157., 158.,
          159., 160., 161., 162., 163., 164.],
         [135., 136., 137., 138., 139., 140., 141., 142., 143., 144., 145.,
          146., 147., 148., 149., 150., 151., 152., 153., 154., 155., 156.,
          157., 158., 159., 160., 161., 162., 163., 164., 165., 166., 167.,
          168., 169., 170., 171., 172., 173., 174., 175., 176., 177., 178.,
          179., 180., 181., 182., 183., 184.],
         [155., 156., 157., 158., 159., 160., 161., 162., 163., 164., 165.,
          166., 167., 168., 169., 170., 171., 172., 173., 174., 175., 176.,
          177., 178., 179., 180., 181., 182., 183., 184., 185., 186., 187.,
          188., 189., 190., 191., 192., 193., 194., 195., 196., 197., 198.,
          199., 198., 197., 196., 195., 194.],
         [175., 176., 177., 178., 179., 180., 181., 182., 183., 184., 185.,
          186., 187., 188., 189., 190., 191., 192., 193., 194., 195., 196.,
          197., 198., 199., 198., 197., 196., 195., 194., 193., 192., 191.,
          190., 189., 188., 187., 186., 185., 184., 183., 182., 181., 180.,
          179., 178., 177., 176., 175., 174.]]])

We first generated a one-dimensional Tensor that has already undergone Stage-2 Padding, and then used unsqueeze to increase its dimensionality from 1D to 2D, finally using as_strided to split it. In as_strided((1, 11, 50), (250, 20, 1)), the first 1 corresponds to the batch, 11 corresponds to n_frames, which we obtained through calculation, 50 corresponds to n_fft, which we arbitrarily set in our experiment, 250 corresponds to input.stride(0), which is the total length of our data, 20 corresponds to hop_length * input.stride(1), which is our moving step size, and the last 1 corresponds to input.stride(1). We can clearly see that the data has been divided into 11 frames, each with a length of 50. Since we used Stage-2 Padding, the data at index 0 has been placed near the center of frame 0, and the data at index 199 has been placed near the center of frame 10.

  if (window_.defined()) {
    input = input.mul(window_);
  }
  1. Apply a window function to each frame. We can conduct an experiment to verify our idea.
import torch
import torch.nn.functional as F


if __name__ == '__main__':
    test = torch.empty(200)
    for i in range(200):
        test[i] = i

    signal_dim = test.dim()
    extended_shape = [1] * (3 - signal_dim) + list(test.size())
    pad = int(50 // 2)
    test = F.pad(test.view(extended_shape), [pad, pad], "reflect")
    test = test.view(test.shape[-signal_dim:])
    new = test.unsqueeze(0)
    new = new.as_strided((1, 11, 50), (250, 20, 1))
    print("Before:", new)
    hnn = torch.hann_window(50)
    new = new.mul(hnn)
    print("After:", new)
Show Console
Before: tensor([[[ 25.,  24.,  23.,  22.,  21.,  20.,  19.,  18.,  17.,  16.,  15.,
           14.,  13.,  12.,  11.,  10.,   9.,   8.,   7.,   6.,   5.,   4.,
            3.,   2.,   1.,   0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,
            8.,   9.,  10.,  11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,
           19.,  20.,  21.,  22.,  23.,  24.],
         [  5.,   4.,   3.,   2.,   1.,   0.,   1.,   2.,   3.,   4.,   5.,
            6.,   7.,   8.,   9.,  10.,  11.,  12.,  13.,  14.,  15.,  16.,
           17.,  18.,  19.,  20.,  21.,  22.,  23.,  24.,  25.,  26.,  27.,
           28.,  29.,  30.,  31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,
           39.,  40.,  41.,  42.,  43.,  44.],
         [ 15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,  24.,  25.,
           26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,  34.,  35.,  36.,
           37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,  46.,  47.,
           48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,
           59.,  60.,  61.,  62.,  63.,  64.],
         [ 35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,
           46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,
           57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,  66.,  67.,
           68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,  77.,  78.,
           79.,  80.,  81.,  82.,  83.,  84.],
         [ 55.,  56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,
           66.,  67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,
           77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,
           88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,
           99., 100., 101., 102., 103., 104.],
         [ 75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,
           86.,  87.,  88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,
           97.,  98.,  99., 100., 101., 102., 103., 104., 105., 106., 107.,
          108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118.,
          119., 120., 121., 122., 123., 124.],
         [ 95.,  96.,  97.,  98.,  99., 100., 101., 102., 103., 104., 105.,
          106., 107., 108., 109., 110., 111., 112., 113., 114., 115., 116.,
          117., 118., 119., 120., 121., 122., 123., 124., 125., 126., 127.,
          128., 129., 130., 131., 132., 133., 134., 135., 136., 137., 138.,
          139., 140., 141., 142., 143., 144.],
         [115., 116., 117., 118., 119., 120., 121., 122., 123., 124., 125.,
          126., 127., 128., 129., 130., 131., 132., 133., 134., 135., 136.,
          137., 138., 139., 140., 141., 142., 143., 144., 145., 146., 147.,
          148., 149., 150., 151., 152., 153., 154., 155., 156., 157., 158.,
          159., 160., 161., 162., 163., 164.],
         [135., 136., 137., 138., 139., 140., 141., 142., 143., 144., 145.,
          146., 147., 148., 149., 150., 151., 152., 153., 154., 155., 156.,
          157., 158., 159., 160., 161., 162., 163., 164., 165., 166., 167.,
          168., 169., 170., 171., 172., 173., 174., 175., 176., 177., 178.,
          179., 180., 181., 182., 183., 184.],
         [155., 156., 157., 158., 159., 160., 161., 162., 163., 164., 165.,
          166., 167., 168., 169., 170., 171., 172., 173., 174., 175., 176.,
          177., 178., 179., 180., 181., 182., 183., 184., 185., 186., 187.,
          188., 189., 190., 191., 192., 193., 194., 195., 196., 197., 198.,
          199., 198., 197., 196., 195., 194.],
         [175., 176., 177., 178., 179., 180., 181., 182., 183., 184., 185.,
          186., 187., 188., 189., 190., 191., 192., 193., 194., 195., 196.,
          197., 198., 199., 198., 197., 196., 195., 194., 193., 192., 191.,
          190., 189., 188., 187., 186., 185., 184., 183., 182., 181., 180.,
          179., 178., 177., 176., 175., 174.]]])
After: tensor([[[0.0000e+00, 9.4623e-02, 3.6129e-01, 7.7246e-01, 1.2988e+00,
          1.9098e+00, 2.5748e+00, 3.2632e+00, 3.9455e+00, 4.5938e+00,
          5.1824e+00, 5.6883e+00, 6.0919e+00, 6.3767e+00, 6.5306e+00,
          6.5451e+00, 6.4160e+00, 6.1433e+00, 5.7310e+00, 5.1869e+00,
          4.5225e+00, 3.7526e+00, 2.8947e+00, 1.9686e+00, 9.9606e-01,
          0.0000e+00, 9.9606e-01, 1.9686e+00, 2.8947e+00, 3.7526e+00,
          4.5225e+00, 5.1869e+00, 5.7310e+00, 6.1433e+00, 6.4160e+00,
          6.5451e+00, 6.5306e+00, 6.3767e+00, 6.0919e+00, 5.6883e+00,
          5.1824e+00, 4.5938e+00, 3.9455e+00, 3.2632e+00, 2.5748e+00,
          1.9098e+00, 1.2988e+00, 7.7246e-01, 3.6129e-01, 9.4623e-02],
         [0.0000e+00, 1.5771e-02, 4.7125e-02, 7.0224e-02, 6.1847e-02,
          0.0000e+00, 1.3552e-01, 3.6258e-01, 6.9626e-01, 1.1484e+00,
          1.7275e+00, 2.4379e+00, 3.2802e+00, 4.2512e+00, 5.3432e+00,
          6.5451e+00, 7.8418e+00, 9.2150e+00, 1.0643e+01, 1.2103e+01,
          1.3568e+01, 1.5010e+01, 1.6403e+01, 1.7717e+01, 1.8925e+01,
          2.0000e+01, 2.0917e+01, 2.1654e+01, 2.2192e+01, 2.2516e+01,
          2.2613e+01, 2.2477e+01, 2.2105e+01, 2.1502e+01, 2.0674e+01,
          1.9635e+01, 1.8404e+01, 1.7005e+01, 1.5464e+01, 1.3815e+01,
          1.2092e+01, 1.0336e+01, 8.5872e+00, 6.8889e+00, 5.2851e+00,
          3.8197e+00, 2.5357e+00, 1.4747e+00, 6.7546e-01, 1.7348e-01],
         [0.0000e+00, 6.3082e-02, 2.6704e-01, 6.3201e-01, 1.1751e+00,
          1.9098e+00, 2.8458e+00, 3.9883e+00, 5.3380e+00, 6.8906e+00,
          8.6373e+00, 1.0564e+01, 1.2652e+01, 1.4879e+01, 1.7217e+01,
          1.9635e+01, 2.2100e+01, 2.4573e+01, 2.7017e+01, 2.9392e+01,
          3.1658e+01, 3.3774e+01, 3.5701e+01, 3.7403e+01, 3.8846e+01,
          4.0000e+01, 4.0838e+01, 4.1340e+01, 4.1490e+01, 4.1279e+01,
          4.0703e+01, 3.9766e+01, 3.8479e+01, 3.6860e+01, 3.4932e+01,
          3.2725e+01, 3.0278e+01, 2.7633e+01, 2.4836e+01, 2.1941e+01,
          1.9002e+01, 1.6078e+01, 1.3229e+01, 1.0515e+01, 7.9954e+00,
          5.7295e+00, 3.7726e+00, 2.1769e+00, 9.8963e-01, 2.5233e-01],
         [0.0000e+00, 1.4193e-01, 5.8121e-01, 1.3342e+00, 2.4120e+00,
          3.8197e+00, 5.5561e+00, 7.6141e+00, 9.9797e+00, 1.2633e+01,
          1.5547e+01, 1.8690e+01, 2.2024e+01, 2.5507e+01, 2.9091e+01,
          3.2725e+01, 3.6357e+01, 3.9931e+01, 4.3392e+01, 4.6682e+01,
          4.9748e+01, 5.2537e+01, 5.4999e+01, 5.7089e+01, 5.8767e+01,
          6.0000e+01, 6.0759e+01, 6.1026e+01, 6.0788e+01, 6.0042e+01,
          5.8793e+01, 5.7056e+01, 5.4854e+01, 5.2218e+01, 4.9189e+01,
          4.5816e+01, 4.2152e+01, 3.8260e+01, 3.4208e+01, 3.0067e+01,
          2.5912e+01, 2.1820e+01, 1.7871e+01, 1.4140e+01, 1.0706e+01,
          7.6393e+00, 5.0096e+00, 2.8792e+00, 1.3038e+00, 3.3118e-01],
         [0.0000e+00, 2.2079e-01, 8.9538e-01, 2.0365e+00, 3.6490e+00,
          5.7295e+00, 8.2665e+00, 1.1240e+01, 1.4621e+01, 1.8375e+01,
          2.2457e+01, 2.6816e+01, 3.1397e+01, 3.6135e+01, 4.0965e+01,
          4.5816e+01, 5.0615e+01, 5.5290e+01, 5.9766e+01, 6.3972e+01,
          6.7838e+01, 7.1300e+01, 7.4296e+01, 7.6775e+01, 7.8689e+01,
          8.0000e+01, 8.0681e+01, 8.0712e+01, 8.0086e+01, 7.8805e+01,
          7.6883e+01, 7.4346e+01, 7.1228e+01, 6.7576e+01, 6.3447e+01,
          5.8906e+01, 5.4026e+01, 4.8888e+01, 4.3580e+01, 3.8193e+01,
          3.2822e+01, 2.7563e+01, 2.2512e+01, 1.7766e+01, 1.3416e+01,
          9.5491e+00, 6.2465e+00, 3.5814e+00, 1.6180e+00, 4.1003e-01],
         [0.0000e+00, 2.9964e-01, 1.2095e+00, 2.7387e+00, 4.8859e+00,
          7.6393e+00, 1.0977e+01, 1.4866e+01, 1.9263e+01, 2.4117e+01,
          2.9367e+01, 3.4943e+01, 4.0769e+01, 4.6763e+01, 5.2838e+01,
          5.8906e+01, 6.4873e+01, 7.0648e+01, 7.6140e+01, 8.1262e+01,
          8.5928e+01, 9.0063e+01, 9.3594e+01, 9.6461e+01, 9.8610e+01,
          1.0000e+02, 1.0060e+02, 1.0040e+02, 9.9383e+01, 9.7568e+01,
          9.4973e+01, 9.1635e+01, 8.7602e+01, 8.2935e+01, 7.7705e+01,
          7.1996e+01, 6.5900e+01, 5.9516e+01, 5.2952e+01, 4.6319e+01,
          3.9732e+01, 3.3305e+01, 2.7154e+01, 2.1392e+01, 1.6126e+01,
          1.1459e+01, 7.4834e+00, 4.2836e+00, 1.9321e+00, 4.8889e-01],
         [0.0000e+00, 3.7849e-01, 1.5237e+00, 3.4410e+00, 6.1228e+00,
          9.5491e+00, 1.3687e+01, 1.8491e+01, 2.3905e+01, 2.9859e+01,
          3.6277e+01, 4.3069e+01, 5.0141e+01, 5.7391e+01, 6.4712e+01,
          7.1996e+01, 7.9131e+01, 8.6006e+01, 9.2514e+01, 9.8551e+01,
          1.0402e+02, 1.0883e+02, 1.1289e+02, 1.1615e+02, 1.1853e+02,
          1.2000e+02, 1.2052e+02, 1.2008e+02, 1.1868e+02, 1.1633e+02,
          1.1306e+02, 1.0893e+02, 1.0398e+02, 9.8293e+01, 9.1963e+01,
          8.5086e+01, 7.7773e+01, 7.0144e+01, 6.2324e+01, 5.4445e+01,
          4.6641e+01, 3.9047e+01, 3.1796e+01, 2.5018e+01, 1.8837e+01,
          1.3369e+01, 8.7204e+00, 4.9859e+00, 2.2463e+00, 5.6774e-01],
         [0.0000e+00, 4.5735e-01, 1.8379e+00, 4.1432e+00, 7.3598e+00,
          1.1459e+01, 1.6397e+01, 2.2117e+01, 2.8547e+01, 3.5602e+01,
          4.3186e+01, 5.1195e+01, 5.9513e+01, 6.8019e+01, 7.6586e+01,
          8.5086e+01, 9.3389e+01, 1.0136e+02, 1.0889e+02, 1.1584e+02,
          1.2211e+02, 1.2759e+02, 1.3219e+02, 1.3583e+02, 1.3845e+02,
          1.4000e+02, 1.4044e+02, 1.3977e+02, 1.3798e+02, 1.3509e+02,
          1.3115e+02, 1.2621e+02, 1.2035e+02, 1.1365e+02, 1.0622e+02,
          9.8176e+01, 8.9647e+01, 8.0772e+01, 7.1697e+01, 6.2572e+01,
          5.3551e+01, 4.4789e+01, 3.6438e+01, 2.8643e+01, 2.1547e+01,
          1.5279e+01, 9.9573e+00, 5.6881e+00, 2.5605e+00, 6.4659e-01],
         [0.0000e+00, 5.3620e-01, 2.1521e+00, 4.8454e+00, 8.5967e+00,
          1.3369e+01, 1.9108e+01, 2.5743e+01, 3.3188e+01, 4.1344e+01,
          5.0096e+01, 5.9321e+01, 6.8885e+01, 7.8647e+01, 8.8460e+01,
          9.8176e+01, 1.0765e+02, 1.1672e+02, 1.2526e+02, 1.3313e+02,
          1.4020e+02, 1.4635e+02, 1.5149e+02, 1.5552e+02, 1.5837e+02,
          1.6000e+02, 1.6037e+02, 1.5946e+02, 1.5728e+02, 1.5386e+02,
          1.4924e+02, 1.4350e+02, 1.3672e+02, 1.2901e+02, 1.2048e+02,
          1.1127e+02, 1.0152e+02, 9.1400e+01, 8.1069e+01, 7.0698e+01,
          6.0461e+01, 5.0531e+01, 4.1079e+01, 3.2269e+01, 2.4257e+01,
          1.7188e+01, 1.1194e+01, 6.3903e+00, 2.8746e+00, 7.2545e-01],
         [0.0000e+00, 6.1505e-01, 2.4662e+00, 5.5477e+00, 9.8336e+00,
          1.5279e+01, 2.1818e+01, 2.9369e+01, 3.7830e+01, 4.7086e+01,
          5.7006e+01, 6.7447e+01, 7.8257e+01, 8.9274e+01, 1.0033e+02,
          1.1127e+02, 1.2190e+02, 1.3208e+02, 1.4164e+02, 1.5042e+02,
          1.5829e+02, 1.6511e+02, 1.7079e+02, 1.7520e+02, 1.7829e+02,
          1.8000e+02, 1.8029e+02, 1.7914e+02, 1.7657e+02, 1.7262e+02,
          1.6733e+02, 1.6079e+02, 1.5310e+02, 1.4437e+02, 1.3474e+02,
          1.2436e+02, 1.1339e+02, 1.0203e+02, 9.0441e+01, 7.8824e+01,
          6.7371e+01, 5.6274e+01, 4.5721e+01, 3.5895e+01, 2.6968e+01,
          1.8907e+01, 1.2184e+01, 6.8819e+00, 3.0631e+00, 7.6487e-01],
         [0.0000e+00, 6.9390e-01, 2.7804e+00, 6.2499e+00, 1.1071e+01,
          1.7188e+01, 2.4528e+01, 3.2994e+01, 4.2472e+01, 5.2828e+01,
          6.3916e+01, 7.5574e+01, 8.7629e+01, 9.9902e+01, 1.1221e+02,
          1.2436e+02, 1.3616e+02, 1.4744e+02, 1.5801e+02, 1.6771e+02,
          1.7638e+02, 1.8388e+02, 1.9008e+02, 1.9489e+02, 1.9822e+02,
          1.9800e+02, 1.9622e+02, 1.9292e+02, 1.8815e+02, 1.8200e+02,
          1.7457e+02, 1.6598e+02, 1.5637e+02, 1.4590e+02, 1.3474e+02,
          1.2305e+02, 1.1102e+02, 9.8839e+01, 8.6692e+01, 7.4761e+01,
          6.3225e+01, 5.2254e+01, 4.2008e+01, 3.2632e+01, 2.4257e+01,
          1.6997e+01, 1.0947e+01, 6.1797e+00, 2.7490e+00, 6.8602e-01]]])

In the experiment, we used a Hann window, and we can clearly see that the Hann window has been applied to each frame.

  // FFT and transpose to get (batch x fft_size x num_frames)
  const bool complex_fft = input.is_complex();
  const auto onesided = onesidedOpt.value_or(!complex_fft);

  const fft_norm_mode norm = normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none;
  Tensor out;
  if (complex_fft) {
    TORCH_CHECK(!onesided, "Cannot have onesided output if window or input is complex");
    out = at::_fft_c2c(input, input.dim() - 1, static_cast<int64_t>(norm), /*forward=*/true);
  } else {
    out = at::_fft_r2c(input, input.dim() - 1, static_cast<int64_t>(norm), onesided);
  }
  out.transpose_(1, 2);
  1. Perform FFT (Fast Fourier Transform) computation on each frame, and finally, through the transpose operation, place the same frequency bin together in each frame. First, it will detect whether the input Tensor is real or complex. Since our input is real, onesided will be set to true, and in the end, only n_fft/2 + 1 bins will be returned. Then it will check whether normalized is true. Since the default value of normalized is false, and we did not set it to true, the norm will end up being fft_norm_mode::none, meaning that the result will not be normalized. Next is the actual computation, and finally, through the transpose operation, the same frequency bins from each frame are placed together. We can verify our idea through an experiment.

experiment

import torch
import torch.nn.functional as F


if __name__ == '__main__':
    test = torch.empty(200)
    for i in range(200):
        test[i] = i

    signal_dim = test.dim()
    extended_shape = [1] * (3 - signal_dim) + list(test.size())
    pad = int(50 // 2)
    test = F.pad(test.view(extended_shape), [pad, pad], "reflect")
    test = test.view(test.shape[-signal_dim:])
    new = test.unsqueeze(0)
    new = new.as_strided((1, 11, 50), (250, 20, 1))
    hnn = torch.hann_window(50)
    new = new.mul(hnn)
    new = torch.fft.fft(new)
    print("Before_tp:", new)
    new = new.transpose_(1, 2)
    print("After_tp", new)
Show Console
Before_tp: tensor([[[ 1.8568e+02+0.0000e+00j, -2.9432e+01-2.5749e-05j,
          -7.0529e+01-7.6294e-06j,  1.4240e+01-1.9073e-06j,
          -9.7382e+00-6.6757e-06j,  5.2361e+00-4.2915e-06j,
          -3.9971e+00+4.2915e-06j,  2.7580e+00+1.1921e-06j,
          -2.2498e+00-1.6093e-06j,  1.7415e+00-4.1723e-07j,
          -1.4860e+00+4.1127e-06j,  1.2306e+00+1.9372e-06j,
          -1.0858e+00-1.6242e-06j,  9.4092e-01+5.8115e-07j,
          -8.5243e-01+2.8908e-06j,  7.6393e-01+4.7684e-07j,
          -7.0752e-01+2.9802e-07j,  6.5112e-01-3.6359e-06j,
          -6.1475e-01+2.3842e-07j,  5.7838e-01+1.1921e-06j,
          -5.5567e-01-1.9073e-06j,  5.3296e-01+2.8610e-06j,
          -5.2047e-01+4.7684e-07j,  5.0798e-01+0.0000e+00j,
          -5.0399e-01-9.5367e-07j,  5.0000e-01+0.0000e+00j,
          -5.0399e-01+9.5367e-07j,  5.0798e-01-0.0000e+00j,
          -5.2047e-01-4.7684e-07j,  5.3296e-01-2.8610e-06j,
          -5.5567e-01+1.9073e-06j,  5.7838e-01-1.1921e-06j,
          -6.1475e-01-2.3842e-07j,  6.5112e-01+3.6359e-06j,
          -7.0752e-01-2.9802e-07j,  7.6393e-01-4.7684e-07j,
          -8.5243e-01-2.8908e-06j,  9.4092e-01-5.8115e-07j,
          -1.0858e+00+1.6242e-06j,  1.2306e+00-1.9372e-06j,
          -1.4860e+00-4.1127e-06j,  1.7415e+00+4.1723e-07j,
          -2.2498e+00+1.6093e-06j,  2.7580e+00-1.1921e-06j,
          -3.9971e+00-4.2915e-06j,  5.2361e+00+4.2915e-06j,
          -9.7382e+00+6.6757e-06j,  1.4240e+01+1.9073e-06j,
          -7.0529e+01+7.6294e-06j, -2.9432e+01+2.5749e-05j],
         [ 5.0039e+02+0.0000e+00j, -2.4964e+02+1.4907e+02j,
           2.8182e-01-3.3411e+01j,  1.6559e-01-8.6154e+00j,
           3.4311e-02-3.6618e+00j, -8.8814e-02-1.9713e+00j,
          -1.8333e-01-1.1854e+00j, -2.3564e-01-7.2956e-01j,
          -2.4141e-01-4.2715e-01j, -2.0585e-01-2.1897e-01j,
          -1.4190e-01-8.3810e-02j, -6.6800e-02-1.0473e-02j,
           1.9604e-03+1.2476e-02j,  5.0641e-02-2.0648e-04j,
           7.1879e-02-3.1504e-02j,  6.5845e-02-6.4718e-02j,
           3.9343e-02-8.6616e-02j,  3.4683e-03-8.9727e-02j,
          -2.9699e-02-7.3373e-02j, -5.0266e-02-4.2919e-02j,
          -5.3064e-02-7.5495e-03j, -3.8621e-02+2.2579e-02j,
          -1.2657e-02+3.9499e-02j,  1.5934e-02+3.9506e-02j,
           3.7811e-02+2.4242e-02j,  4.5944e-02+0.0000e+00j,
           3.7811e-02-2.4242e-02j,  1.5934e-02-3.9506e-02j,
          -1.2657e-02-3.9499e-02j, -3.8621e-02-2.2579e-02j,
          -5.3064e-02+7.5495e-03j, -5.0266e-02+4.2919e-02j,
          -2.9699e-02+7.3373e-02j,  3.4683e-03+8.9727e-02j,
           3.9343e-02+8.6616e-02j,  6.5845e-02+6.4718e-02j,
           7.1879e-02+3.1504e-02j,  5.0641e-02+2.0648e-04j,
           1.9604e-03-1.2476e-02j, -6.6800e-02+1.0473e-02j,
          -1.4190e-01+8.3810e-02j, -2.0585e-01+2.1897e-01j,
          -2.4141e-01+4.2715e-01j, -2.3564e-01+7.2956e-01j,
          -1.8333e-01+1.1854e+00j, -8.8814e-02+1.9713e+00j,
           3.4311e-02+3.6618e+00j,  1.6559e-01+8.6154e+00j,
           2.8182e-01+3.3411e+01j, -2.4964e+02-1.4907e+02j],
         [ 1.0000e+03+0.0000e+00j, -5.0000e+02+1.4921e+02j,
           3.0518e-05-3.3157e+01j, -1.0729e-05-8.2887e+00j,
           4.8876e-06-3.3149e+00j,  3.3736e-05-1.6568e+00j,
           1.9610e-05-9.4603e-01j, -1.0684e-05-5.9056e-01j,
           9.8869e-06-3.9294e-01j,  1.2480e-05-2.7424e-01j,
           1.4994e-05-1.9861e-01j,  1.3297e-05-1.4803e-01j,
          -4.5490e-06-1.1296e-01j,  9.1051e-06-8.7735e-02j,
           3.3081e-06-6.9134e-02j, -1.0330e-05-5.5033e-02j,
          -1.6905e-05-4.4136e-02j,  1.2197e-05-3.5464e-02j,
          -7.5996e-07-2.8453e-02j, -2.1398e-05-2.2675e-02j,
           7.3910e-06-1.7784e-02j,  5.1260e-06-1.3529e-02j,
          -9.0599e-06-9.7783e-03j,  0.0000e+00-6.3553e-03j,
           0.0000e+00-3.1281e-03j,  3.0518e-05+0.0000e+00j,
           0.0000e+00+3.1281e-03j,  0.0000e+00+6.3553e-03j,
          -9.0599e-06+9.7783e-03j,  5.1260e-06+1.3529e-02j,
           7.3910e-06+1.7784e-02j, -2.1398e-05+2.2675e-02j,
          -7.5996e-07+2.8453e-02j,  1.2197e-05+3.5464e-02j,
          -1.6905e-05+4.4136e-02j, -1.0330e-05+5.5033e-02j,
           3.3081e-06+6.9134e-02j,  9.1051e-06+8.7735e-02j,
          -4.5490e-06+1.1296e-01j,  1.3297e-05+1.4803e-01j,
           1.4994e-05+1.9861e-01j,  1.2480e-05+2.7424e-01j,
           9.8869e-06+3.9294e-01j, -1.0684e-05+5.9056e-01j,
           1.9610e-05+9.4603e-01j,  3.3736e-05+1.6568e+00j,
           4.8876e-06+3.3149e+00j, -1.0729e-05+8.2887e+00j,
           3.0518e-05+3.3157e+01j, -5.0000e+02-1.4921e+02j],
         [ 1.5000e+03+0.0000e+00j, -7.5000e+02+1.4921e+02j,
           3.6240e-05-3.3157e+01j, -2.5034e-05-8.2887e+00j,
           2.0504e-05-3.3149e+00j,  2.8312e-05-1.6568e+00j,
           2.4259e-05-9.4604e-01j, -2.0400e-05-5.9055e-01j,
           5.2154e-06-3.9294e-01j,  1.9066e-05-2.7426e-01j,
           1.6965e-05-1.9861e-01j,  1.0869e-05-1.4803e-01j,
          -3.5157e-07-1.1297e-01j,  1.3185e-05-8.7747e-02j,
           9.9167e-06-6.9107e-02j,  1.7628e-05-5.5049e-02j,
          -2.3469e-06-4.4123e-02j,  2.6584e-05-3.5476e-02j,
          -1.9222e-06-2.8454e-02j, -8.3148e-06-2.2661e-02j,
          -4.7624e-05-1.7806e-02j,  2.5034e-05-1.3548e-02j,
          -2.3842e-06-9.7733e-03j, -7.6294e-06-6.3553e-03j,
           3.0518e-05-3.1738e-03j,  6.1035e-05+0.0000e+00j,
           3.0518e-05+3.1738e-03j, -7.6294e-06+6.3553e-03j,
          -2.3842e-06+9.7733e-03j,  2.5034e-05+1.3548e-02j,
          -4.7624e-05+1.7806e-02j, -8.3148e-06+2.2661e-02j,
          -1.9222e-06+2.8454e-02j,  2.6584e-05+3.5476e-02j,
          -2.3469e-06+4.4123e-02j,  1.7628e-05+5.5049e-02j,
           9.9167e-06+6.9107e-02j,  1.3185e-05+8.7747e-02j,
          -3.5157e-07+1.1297e-01j,  1.0869e-05+1.4803e-01j,
           1.6965e-05+1.9861e-01j,  1.9066e-05+2.7426e-01j,
           5.2154e-06+3.9294e-01j, -2.0400e-05+5.9055e-01j,
           2.4259e-05+9.4604e-01j,  2.8312e-05+1.6568e+00j,
           2.0504e-05+3.3149e+00j, -2.5034e-05+8.2887e+00j,
           3.6240e-05+3.3157e+01j, -7.5000e+02-1.4921e+02j],
         [ 2.0000e+03+0.0000e+00j, -1.0000e+03+1.4921e+02j,
           4.3869e-05-3.3157e+01j, -2.9087e-05-8.2887e+00j,
           4.8637e-05-3.3149e+00j, -1.5199e-05-1.6568e+00j,
           4.4465e-05-9.4605e-01j, -2.7329e-05-5.9054e-01j,
           8.9034e-06-3.9294e-01j,  7.6815e-06-2.7427e-01j,
          -1.0461e-05-1.9855e-01j,  2.1271e-06-1.4803e-01j,
           5.3099e-06-1.1295e-01j,  6.1677e-06-8.7749e-02j,
           1.5119e-05-6.9120e-02j, -2.9594e-05-5.5070e-02j,
          -9.6485e-06-4.4076e-02j,  2.7277e-05-3.5498e-02j,
          -8.2105e-06-2.8441e-02j, -1.1921e-06-2.2653e-02j,
          -5.1558e-05-1.7799e-02j,  3.7670e-05-1.3556e-02j,
          -4.7684e-06-9.7816e-03j, -7.6294e-06-6.3429e-03j,
          -3.0518e-05-3.1738e-03j,  1.2207e-04+0.0000e+00j,
          -3.0518e-05+3.1738e-03j, -7.6294e-06+6.3429e-03j,
          -4.7684e-06+9.7816e-03j,  3.7670e-05+1.3556e-02j,
          -5.1558e-05+1.7799e-02j, -1.1921e-06+2.2653e-02j,
          -8.2105e-06+2.8441e-02j,  2.7277e-05+3.5498e-02j,
          -9.6485e-06+4.4076e-02j, -2.9594e-05+5.5070e-02j,
           1.5119e-05+6.9120e-02j,  6.1677e-06+8.7749e-02j,
           5.3099e-06+1.1295e-01j,  2.1271e-06+1.4803e-01j,
          -1.0461e-05+1.9855e-01j,  7.6815e-06+2.7427e-01j,
           8.9034e-06+3.9294e-01j, -2.7329e-05+5.9054e-01j,
           4.4465e-05+9.4605e-01j, -1.5199e-05+1.6568e+00j,
           4.8637e-05+3.3149e+00j, -2.9087e-05+8.2887e+00j,
           4.3869e-05+3.3157e+01j, -1.0000e+03-1.4921e+02j],
         [ 2.5000e+03+0.0000e+00j, -1.2500e+03+1.4921e+02j,
           4.9591e-05-3.3157e+01j, -7.3910e-06-8.2887e+00j,
          -9.0599e-06-3.3150e+00j,  9.6321e-05-1.6568e+00j,
           2.1726e-05-9.4602e-01j, -2.8864e-05-5.9056e-01j,
           8.4639e-06-3.9293e-01j,  6.5714e-06-2.7428e-01j,
          -2.1730e-05-1.9857e-01j, -5.2843e-05-1.4802e-01j,
           1.3079e-05-1.1299e-01j,  2.0702e-05-8.7752e-02j,
           1.1185e-05-6.9129e-02j, -1.3079e-05-5.4998e-02j,
           2.0258e-05-4.4148e-02j,  3.1456e-05-3.5485e-02j,
          -1.1548e-05-2.8456e-02j, -6.9141e-06-2.2656e-02j,
          -4.7684e-07-1.7781e-02j, -7.8678e-06-1.3556e-02j,
           2.6226e-06-9.7888e-03j, -2.0981e-05-6.3457e-03j,
           0.0000e+00-3.1738e-03j,  1.2207e-04+0.0000e+00j,
           0.0000e+00+3.1738e-03j, -2.0981e-05+6.3457e-03j,
           2.6226e-06+9.7888e-03j, -7.8678e-06+1.3556e-02j,
          -4.7684e-07+1.7781e-02j, -6.9141e-06+2.2656e-02j,
          -1.1548e-05+2.8456e-02j,  3.1456e-05+3.5485e-02j,
           2.0258e-05+4.4148e-02j, -1.3079e-05+5.4998e-02j,
           1.1185e-05+6.9129e-02j,  2.0702e-05+8.7752e-02j,
           1.3079e-05+1.1299e-01j, -5.2843e-05+1.4802e-01j,
          -2.1730e-05+1.9857e-01j,  6.5714e-06+2.7428e-01j,
           8.4639e-06+3.9293e-01j, -2.8864e-05+5.9056e-01j,
           2.1726e-05+9.4602e-01j,  9.6321e-05+1.6568e+00j,
          -9.0599e-06+3.3150e+00j, -7.3910e-06+8.2887e+00j,
           4.9591e-05+3.3157e+01j, -1.2500e+03-1.4921e+02j],
         [ 3.0000e+03+0.0000e+00j, -1.5000e+03+1.4921e+02j,
           6.8665e-05-3.3157e+01j, -2.9087e-05-8.2887e+00j,
           6.4969e-05-3.3150e+00j,  5.9903e-05-1.6569e+00j,
           6.1303e-05-9.4599e-01j, -1.6883e-05-5.9057e-01j,
           5.0887e-06-3.9293e-01j,  4.2900e-05-2.7426e-01j,
           9.3095e-06-1.9855e-01j, -6.6292e-06-1.4803e-01j,
           7.4464e-06-1.1299e-01j,  1.5848e-05-8.7733e-02j,
           5.0755e-05-6.9070e-02j, -3.8873e-05-5.5008e-02j,
           1.1474e-06-4.4102e-02j,  1.9751e-05-3.5476e-02j,
          -8.5086e-06-2.8434e-02j, -2.5421e-05-2.2670e-02j,
          -6.0856e-05-1.7741e-02j, -1.3232e-05-1.3592e-02j,
           1.6689e-06-9.7620e-03j, -1.7166e-05-6.3601e-03j,
           6.1035e-05-3.1128e-03j,  1.2207e-04+0.0000e+00j,
           6.1035e-05+3.1128e-03j, -1.7166e-05+6.3601e-03j,
           1.6689e-06+9.7620e-03j, -1.3232e-05+1.3592e-02j,
          -6.0856e-05+1.7741e-02j, -2.5421e-05+2.2670e-02j,
          -8.5086e-06+2.8434e-02j,  1.9751e-05+3.5476e-02j,
           1.1474e-06+4.4102e-02j, -3.8873e-05+5.5008e-02j,
           5.0755e-05+6.9070e-02j,  1.5848e-05+8.7733e-02j,
           7.4464e-06+1.1299e-01j, -6.6292e-06+1.4803e-01j,
           9.3095e-06+1.9855e-01j,  4.2900e-05+2.7426e-01j,
           5.0887e-06+3.9293e-01j, -1.6883e-05+5.9057e-01j,
           6.1303e-05+9.4599e-01j,  5.9903e-05+1.6569e+00j,
           6.4969e-05+3.3150e+00j, -2.9087e-05+8.2887e+00j,
           6.8665e-05+3.3157e+01j, -1.5000e+03-1.4921e+02j],
         [ 3.5000e+03+0.0000e+00j, -1.7500e+03+1.4921e+02j,
           8.3923e-05-3.3157e+01j, -3.6001e-05-8.2887e+00j,
           2.5034e-05-3.3150e+00j,  8.3923e-05-1.6568e+00j,
           4.7237e-05-9.4597e-01j, -2.8163e-05-5.9057e-01j,
           4.5449e-06-3.9292e-01j,  8.7544e-05-2.7428e-01j,
          -1.3746e-05-1.9852e-01j, -1.2556e-05-1.4803e-01j,
           2.7884e-05-1.1299e-01j,  1.8269e-05-8.7750e-02j,
           1.1617e-05-6.9100e-02j, -1.0572e-05-5.4994e-02j,
          -3.9563e-06-4.4092e-02j,  2.4512e-05-3.5495e-02j,
          -4.6492e-06-2.8434e-02j, -3.5852e-05-2.2684e-02j,
           1.4305e-06-1.7839e-02j, -8.8215e-06-1.3545e-02j,
           9.2983e-06-9.7985e-03j,  1.5259e-05-6.3667e-03j,
           6.1035e-05-3.0518e-03j,  1.2207e-04+0.0000e+00j,
           6.1035e-05+3.0518e-03j,  1.5259e-05+6.3667e-03j,
           9.2983e-06+9.7985e-03j, -8.8215e-06+1.3545e-02j,
           1.4305e-06+1.7839e-02j, -3.5852e-05+2.2684e-02j,
          -4.6492e-06+2.8434e-02j,  2.4512e-05+3.5495e-02j,
          -3.9563e-06+4.4092e-02j, -1.0572e-05+5.4994e-02j,
           1.1617e-05+6.9100e-02j,  1.8269e-05+8.7750e-02j,
           2.7884e-05+1.1299e-01j, -1.2556e-05+1.4803e-01j,
          -1.3746e-05+1.9852e-01j,  8.7544e-05+2.7428e-01j,
           4.5449e-06+3.9292e-01j, -2.8163e-05+5.9057e-01j,
           4.7237e-05+9.4597e-01j,  8.3923e-05+1.6568e+00j,
           2.5034e-05+3.3150e+00j, -3.6001e-05+8.2887e+00j,
           8.3923e-05+3.3157e+01j, -1.7500e+03-1.4921e+02j],
         [ 4.0000e+03+0.0000e+00j, -2.0000e+03+1.4921e+02j,
           9.1553e-05-3.3157e+01j, -6.8903e-05-8.2887e+00j,
           2.1815e-05-3.3150e+00j,  1.2493e-04-1.6568e+00j,
           7.2449e-05-9.4595e-01j, -5.8800e-05-5.9057e-01j,
           1.0714e-05-3.9295e-01j,  1.0174e-04-2.7430e-01j,
          -4.6659e-05-1.9854e-01j, -2.6269e-05-1.4798e-01j,
          -3.0464e-06-1.1303e-01j,  2.1524e-05-8.7722e-02j,
           6.6511e-05-6.9066e-02j, -3.3449e-05-5.4950e-02j,
          -3.1091e-05-4.4052e-02j,  3.3930e-05-3.5480e-02j,
          -3.3230e-05-2.8437e-02j, -1.0818e-05-2.2620e-02j,
           1.6212e-05-1.7867e-02j,  2.4676e-05-1.3526e-02j,
          -5.4836e-06-9.8081e-03j, -2.2888e-05-6.3391e-03j,
          -6.1035e-05-3.1128e-03j, -1.2207e-04+0.0000e+00j,
          -6.1035e-05+3.1128e-03j, -2.2888e-05+6.3391e-03j,
          -5.4836e-06+9.8081e-03j,  2.4676e-05+1.3526e-02j,
           1.6212e-05+1.7867e-02j, -1.0818e-05+2.2620e-02j,
          -3.3230e-05+2.8437e-02j,  3.3930e-05+3.5480e-02j,
          -3.1091e-05+4.4052e-02j, -3.3449e-05+5.4950e-02j,
           6.6511e-05+6.9066e-02j,  2.1524e-05+8.7722e-02j,
          -3.0464e-06+1.1303e-01j, -2.6269e-05+1.4798e-01j,
          -4.6659e-05+1.9854e-01j,  1.0174e-04+2.7430e-01j,
           1.0714e-05+3.9295e-01j, -5.8800e-05+5.9057e-01j,
           7.2449e-05+9.4595e-01j,  1.2493e-04+1.6568e+00j,
           2.1815e-05+3.3150e+00j, -6.8903e-05+8.2887e+00j,
           9.1553e-05+3.3157e+01j, -2.0000e+03-1.4921e+02j],
         [ 4.4992e+03+0.0000e+00j, -2.2507e+03+1.4886e+02j,
          -4.9339e-01-3.3762e+01j, -1.7455e-01-9.0084e+00j,
           1.4484e-01-3.9864e+00j,  3.8557e-01-2.1453e+00j,
           4.9770e-01-1.1795e+00j,  4.7201e-01-5.7314e-01j,
           3.4008e-01-1.9410e-01j,  1.5927e-01+8.4578e-04j,
          -7.4884e-03+4.9216e-02j, -1.1256e-01+1.9278e-03j,
          -1.3694e-01-8.1263e-02j, -9.2869e-02-1.4757e-01j,
          -1.4493e-02-1.6525e-01j,  5.7960e-02-1.3015e-01j,
           9.3750e-02-6.1317e-02j,  8.2363e-02+9.5680e-03j,
           3.5013e-02+5.3722e-02j, -2.2522e-02+5.6695e-02j,
          -6.2935e-02+2.3522e-02j, -6.8807e-02-2.5094e-02j,
          -3.9800e-02-6.3587e-02j,  8.0528e-03-7.2608e-02j,
           5.1025e-02-4.7607e-02j,  6.8115e-02+0.0000e+00j,
           5.1025e-02+4.7607e-02j,  8.0528e-03+7.2608e-02j,
          -3.9800e-02+6.3587e-02j, -6.8807e-02+2.5094e-02j,
          -6.2935e-02-2.3522e-02j, -2.2522e-02-5.6695e-02j,
           3.5013e-02-5.3722e-02j,  8.2363e-02-9.5680e-03j,
           9.3750e-02+6.1317e-02j,  5.7960e-02+1.3015e-01j,
          -1.4493e-02+1.6525e-01j, -9.2869e-02+1.4757e-01j,
          -1.3694e-01+8.1263e-02j, -1.1256e-01-1.9278e-03j,
          -7.4884e-03-4.9216e-02j,  1.5927e-01-8.4578e-04j,
           3.4008e-01+1.9410e-01j,  4.7201e-01+5.7314e-01j,
           4.9770e-01+1.1795e+00j,  3.8557e-01+2.1453e+00j,
           1.4484e-01+3.9864e+00j, -1.7455e-01+9.0084e+00j,
          -4.9339e-01+3.3762e+01j, -2.2507e+03-1.4886e+02j],
         [ 4.7883e+03+0.0000e+00j, -2.4571e+03-1.5895e+01j,
           6.9529e+01+1.0568e+01j, -1.3240e+01-5.2422e+00j,
           8.7382e+00+4.1599e+00j, -4.2360e+00-3.0777e+00j,
           2.9972e+00+2.6014e+00j, -1.7581e+00-2.1251e+00j,
           1.2498e+00+1.8505e+00j, -7.4142e-01-1.5758e+00j,
           4.8607e-01+1.3923e+00j, -2.3058e-01-1.2087e+00j,
           8.5761e-02+1.0738e+00j,  5.9072e-02-9.3906e-01j,
          -1.4752e-01+8.3284e-01j,  2.3605e-01-7.2658e-01j,
          -2.9239e-01+6.3817e-01j,  3.4893e-01-5.4978e-01j,
          -3.8530e-01+4.7289e-01j,  4.2160e-01-3.9590e-01j,
          -4.4431e-01+3.2632e-01j,  4.6710e-01-2.5672e-01j,
          -4.7955e-01+1.9158e-01j,  4.9201e-01-1.2634e-01j,
          -4.9609e-01+6.3110e-02j,  5.0000e-01+0.0000e+00j,
          -4.9609e-01-6.3110e-02j,  4.9201e-01+1.2634e-01j,
          -4.7955e-01-1.9158e-01j,  4.6710e-01+2.5672e-01j,
          -4.4431e-01-3.2632e-01j,  4.2160e-01+3.9590e-01j,
          -3.8530e-01-4.7289e-01j,  3.4893e-01+5.4978e-01j,
          -2.9239e-01-6.3817e-01j,  2.3605e-01+7.2658e-01j,
          -1.4752e-01-8.3284e-01j,  5.9072e-02+9.3906e-01j,
           8.5761e-02-1.0738e+00j, -2.3058e-01+1.2087e+00j,
           4.8607e-01-1.3923e+00j, -7.4142e-01+1.5758e+00j,
           1.2498e+00-1.8505e+00j, -1.7581e+00+2.1251e+00j,
           2.9972e+00-2.6014e+00j, -4.2360e+00+3.0777e+00j,
           8.7382e+00-4.1599e+00j, -1.3240e+01+5.2422e+00j,
           6.9529e+01-1.0568e+01j, -2.4571e+03+1.5895e+01j]]])
After_tp tensor([[[ 1.8568e+02+0.0000e+00j,  5.0039e+02+0.0000e+00j,
           1.0000e+03+0.0000e+00j,  1.5000e+03+0.0000e+00j,
           2.0000e+03+0.0000e+00j,  2.5000e+03+0.0000e+00j,
           3.0000e+03+0.0000e+00j,  3.5000e+03+0.0000e+00j,
           4.0000e+03+0.0000e+00j,  4.4992e+03+0.0000e+00j,
           4.7883e+03+0.0000e+00j],
         [-2.9432e+01-2.5749e-05j, -2.4964e+02+1.4907e+02j,
          -5.0000e+02+1.4921e+02j, -7.5000e+02+1.4921e+02j,
          -1.0000e+03+1.4921e+02j, -1.2500e+03+1.4921e+02j,
          -1.5000e+03+1.4921e+02j, -1.7500e+03+1.4921e+02j,
          -2.0000e+03+1.4921e+02j, -2.2507e+03+1.4886e+02j,
          -2.4571e+03-1.5895e+01j],
         [-7.0529e+01-7.6294e-06j,  2.8182e-01-3.3411e+01j,
           3.0518e-05-3.3157e+01j,  3.6240e-05-3.3157e+01j,
           4.3869e-05-3.3157e+01j,  4.9591e-05-3.3157e+01j,
           6.8665e-05-3.3157e+01j,  8.3923e-05-3.3157e+01j,
           9.1553e-05-3.3157e+01j, -4.9339e-01-3.3762e+01j,
           6.9529e+01+1.0568e+01j],
         [ 1.4240e+01-1.9073e-06j,  1.6559e-01-8.6154e+00j,
          -1.0729e-05-8.2887e+00j, -2.5034e-05-8.2887e+00j,
          -2.9087e-05-8.2887e+00j, -7.3910e-06-8.2887e+00j,
          -2.9087e-05-8.2887e+00j, -3.6001e-05-8.2887e+00j,
          -6.8903e-05-8.2887e+00j, -1.7455e-01-9.0084e+00j,
          -1.3240e+01-5.2422e+00j],
         [-9.7382e+00-6.6757e-06j,  3.4311e-02-3.6618e+00j,
           4.8876e-06-3.3149e+00j,  2.0504e-05-3.3149e+00j,
           4.8637e-05-3.3149e+00j, -9.0599e-06-3.3150e+00j,
           6.4969e-05-3.3150e+00j,  2.5034e-05-3.3150e+00j,
           2.1815e-05-3.3150e+00j,  1.4484e-01-3.9864e+00j,
           8.7382e+00+4.1599e+00j],
         [ 5.2361e+00-4.2915e-06j, -8.8814e-02-1.9713e+00j,
           3.3736e-05-1.6568e+00j,  2.8312e-05-1.6568e+00j,
          -1.5199e-05-1.6568e+00j,  9.6321e-05-1.6568e+00j,
           5.9903e-05-1.6569e+00j,  8.3923e-05-1.6568e+00j,
           1.2493e-04-1.6568e+00j,  3.8557e-01-2.1453e+00j,
          -4.2360e+00-3.0777e+00j],
         [-3.9971e+00+4.2915e-06j, -1.8333e-01-1.1854e+00j,
           1.9610e-05-9.4603e-01j,  2.4259e-05-9.4604e-01j,
           4.4465e-05-9.4605e-01j,  2.1726e-05-9.4602e-01j,
           6.1303e-05-9.4599e-01j,  4.7237e-05-9.4597e-01j,
           7.2449e-05-9.4595e-01j,  4.9770e-01-1.1795e+00j,
           2.9972e+00+2.6014e+00j],
         [ 2.7580e+00+1.1921e-06j, -2.3564e-01-7.2956e-01j,
          -1.0684e-05-5.9056e-01j, -2.0400e-05-5.9055e-01j,
          -2.7329e-05-5.9054e-01j, -2.8864e-05-5.9056e-01j,
          -1.6883e-05-5.9057e-01j, -2.8163e-05-5.9057e-01j,
          -5.8800e-05-5.9057e-01j,  4.7201e-01-5.7314e-01j,
          -1.7581e+00-2.1251e+00j],
         [-2.2498e+00-1.6093e-06j, -2.4141e-01-4.2715e-01j,
           9.8869e-06-3.9294e-01j,  5.2154e-06-3.9294e-01j,
           8.9034e-06-3.9294e-01j,  8.4639e-06-3.9293e-01j,
           5.0887e-06-3.9293e-01j,  4.5449e-06-3.9292e-01j,
           1.0714e-05-3.9295e-01j,  3.4008e-01-1.9410e-01j,
           1.2498e+00+1.8505e+00j],
         [ 1.7415e+00-4.1723e-07j, -2.0585e-01-2.1897e-01j,
           1.2480e-05-2.7424e-01j,  1.9066e-05-2.7426e-01j,
           7.6815e-06-2.7427e-01j,  6.5714e-06-2.7428e-01j,
           4.2900e-05-2.7426e-01j,  8.7544e-05-2.7428e-01j,
           1.0174e-04-2.7430e-01j,  1.5927e-01+8.4578e-04j,
          -7.4142e-01-1.5758e+00j],
         [-1.4860e+00+4.1127e-06j, -1.4190e-01-8.3810e-02j,
           1.4994e-05-1.9861e-01j,  1.6965e-05-1.9861e-01j,
          -1.0461e-05-1.9855e-01j, -2.1730e-05-1.9857e-01j,
           9.3095e-06-1.9855e-01j, -1.3746e-05-1.9852e-01j,
          -4.6659e-05-1.9854e-01j, -7.4884e-03+4.9216e-02j,
           4.8607e-01+1.3923e+00j],
         [ 1.2306e+00+1.9372e-06j, -6.6800e-02-1.0473e-02j,
           1.3297e-05-1.4803e-01j,  1.0869e-05-1.4803e-01j,
           2.1271e-06-1.4803e-01j, -5.2843e-05-1.4802e-01j,
          -6.6292e-06-1.4803e-01j, -1.2556e-05-1.4803e-01j,
          -2.6269e-05-1.4798e-01j, -1.1256e-01+1.9278e-03j,
          -2.3058e-01-1.2087e+00j],
         [-1.0858e+00-1.6242e-06j,  1.9604e-03+1.2476e-02j,
          -4.5490e-06-1.1296e-01j, -3.5157e-07-1.1297e-01j,
           5.3099e-06-1.1295e-01j,  1.3079e-05-1.1299e-01j,
           7.4464e-06-1.1299e-01j,  2.7884e-05-1.1299e-01j,
          -3.0464e-06-1.1303e-01j, -1.3694e-01-8.1263e-02j,
           8.5761e-02+1.0738e+00j],
         [ 9.4092e-01+5.8115e-07j,  5.0641e-02-2.0648e-04j,
           9.1051e-06-8.7735e-02j,  1.3185e-05-8.7747e-02j,
           6.1677e-06-8.7749e-02j,  2.0702e-05-8.7752e-02j,
           1.5848e-05-8.7733e-02j,  1.8269e-05-8.7750e-02j,
           2.1524e-05-8.7722e-02j, -9.2869e-02-1.4757e-01j,
           5.9072e-02-9.3906e-01j],
         [-8.5243e-01+2.8908e-06j,  7.1879e-02-3.1504e-02j,
           3.3081e-06-6.9134e-02j,  9.9167e-06-6.9107e-02j,
           1.5119e-05-6.9120e-02j,  1.1185e-05-6.9129e-02j,
           5.0755e-05-6.9070e-02j,  1.1617e-05-6.9100e-02j,
           6.6511e-05-6.9066e-02j, -1.4493e-02-1.6525e-01j,
          -1.4752e-01+8.3284e-01j],
         [ 7.6393e-01+4.7684e-07j,  6.5845e-02-6.4718e-02j,
          -1.0330e-05-5.5033e-02j,  1.7628e-05-5.5049e-02j,
          -2.9594e-05-5.5070e-02j, -1.3079e-05-5.4998e-02j,
          -3.8873e-05-5.5008e-02j, -1.0572e-05-5.4994e-02j,
          -3.3449e-05-5.4950e-02j,  5.7960e-02-1.3015e-01j,
           2.3605e-01-7.2658e-01j],
         [-7.0752e-01+2.9802e-07j,  3.9343e-02-8.6616e-02j,
          -1.6905e-05-4.4136e-02j, -2.3469e-06-4.4123e-02j,
          -9.6485e-06-4.4076e-02j,  2.0258e-05-4.4148e-02j,
           1.1474e-06-4.4102e-02j, -3.9563e-06-4.4092e-02j,
          -3.1091e-05-4.4052e-02j,  9.3750e-02-6.1317e-02j,
          -2.9239e-01+6.3817e-01j],
         [ 6.5112e-01-3.6359e-06j,  3.4683e-03-8.9727e-02j,
           1.2197e-05-3.5464e-02j,  2.6584e-05-3.5476e-02j,
           2.7277e-05-3.5498e-02j,  3.1456e-05-3.5485e-02j,
           1.9751e-05-3.5476e-02j,  2.4512e-05-3.5495e-02j,
           3.3930e-05-3.5480e-02j,  8.2363e-02+9.5680e-03j,
           3.4893e-01-5.4978e-01j],
         [-6.1475e-01+2.3842e-07j, -2.9699e-02-7.3373e-02j,
          -7.5996e-07-2.8453e-02j, -1.9222e-06-2.8454e-02j,
          -8.2105e-06-2.8441e-02j, -1.1548e-05-2.8456e-02j,
          -8.5086e-06-2.8434e-02j, -4.6492e-06-2.8434e-02j,
          -3.3230e-05-2.8437e-02j,  3.5013e-02+5.3722e-02j,
          -3.8530e-01+4.7289e-01j],
         [ 5.7838e-01+1.1921e-06j, -5.0266e-02-4.2919e-02j,
          -2.1398e-05-2.2675e-02j, -8.3148e-06-2.2661e-02j,
          -1.1921e-06-2.2653e-02j, -6.9141e-06-2.2656e-02j,
          -2.5421e-05-2.2670e-02j, -3.5852e-05-2.2684e-02j,
          -1.0818e-05-2.2620e-02j, -2.2522e-02+5.6695e-02j,
           4.2160e-01-3.9590e-01j],
         [-5.5567e-01-1.9073e-06j, -5.3064e-02-7.5495e-03j,
           7.3910e-06-1.7784e-02j, -4.7624e-05-1.7806e-02j,
          -5.1558e-05-1.7799e-02j, -4.7684e-07-1.7781e-02j,
          -6.0856e-05-1.7741e-02j,  1.4305e-06-1.7839e-02j,
           1.6212e-05-1.7867e-02j, -6.2935e-02+2.3522e-02j,
          -4.4431e-01+3.2632e-01j],
         [ 5.3296e-01+2.8610e-06j, -3.8621e-02+2.2579e-02j,
           5.1260e-06-1.3529e-02j,  2.5034e-05-1.3548e-02j,
           3.7670e-05-1.3556e-02j, -7.8678e-06-1.3556e-02j,
          -1.3232e-05-1.3592e-02j, -8.8215e-06-1.3545e-02j,
           2.4676e-05-1.3526e-02j, -6.8807e-02-2.5094e-02j,
           4.6710e-01-2.5672e-01j],
         [-5.2047e-01+4.7684e-07j, -1.2657e-02+3.9499e-02j,
          -9.0599e-06-9.7783e-03j, -2.3842e-06-9.7733e-03j,
          -4.7684e-06-9.7816e-03j,  2.6226e-06-9.7888e-03j,
           1.6689e-06-9.7620e-03j,  9.2983e-06-9.7985e-03j,
          -5.4836e-06-9.8081e-03j, -3.9800e-02-6.3587e-02j,
          -4.7955e-01+1.9158e-01j],
         [ 5.0798e-01+0.0000e+00j,  1.5934e-02+3.9506e-02j,
           0.0000e+00-6.3553e-03j, -7.6294e-06-6.3553e-03j,
          -7.6294e-06-6.3429e-03j, -2.0981e-05-6.3457e-03j,
          -1.7166e-05-6.3601e-03j,  1.5259e-05-6.3667e-03j,
          -2.2888e-05-6.3391e-03j,  8.0528e-03-7.2608e-02j,
           4.9201e-01-1.2634e-01j],
         [-5.0399e-01-9.5367e-07j,  3.7811e-02+2.4242e-02j,
           0.0000e+00-3.1281e-03j,  3.0518e-05-3.1738e-03j,
          -3.0518e-05-3.1738e-03j,  0.0000e+00-3.1738e-03j,
           6.1035e-05-3.1128e-03j,  6.1035e-05-3.0518e-03j,
          -6.1035e-05-3.1128e-03j,  5.1025e-02-4.7607e-02j,
          -4.9609e-01+6.3110e-02j],
         [ 5.0000e-01+0.0000e+00j,  4.5944e-02+0.0000e+00j,
           3.0518e-05+0.0000e+00j,  6.1035e-05+0.0000e+00j,
           1.2207e-04+0.0000e+00j,  1.2207e-04+0.0000e+00j,
           1.2207e-04+0.0000e+00j,  1.2207e-04+0.0000e+00j,
          -1.2207e-04+0.0000e+00j,  6.8115e-02+0.0000e+00j,
           5.0000e-01+0.0000e+00j],
         [-5.0399e-01+9.5367e-07j,  3.7811e-02-2.4242e-02j,
           0.0000e+00+3.1281e-03j,  3.0518e-05+3.1738e-03j,
          -3.0518e-05+3.1738e-03j,  0.0000e+00+3.1738e-03j,
           6.1035e-05+3.1128e-03j,  6.1035e-05+3.0518e-03j,
          -6.1035e-05+3.1128e-03j,  5.1025e-02+4.7607e-02j,
          -4.9609e-01-6.3110e-02j],
         [ 5.0798e-01-0.0000e+00j,  1.5934e-02-3.9506e-02j,
           0.0000e+00+6.3553e-03j, -7.6294e-06+6.3553e-03j,
          -7.6294e-06+6.3429e-03j, -2.0981e-05+6.3457e-03j,
          -1.7166e-05+6.3601e-03j,  1.5259e-05+6.3667e-03j,
          -2.2888e-05+6.3391e-03j,  8.0528e-03+7.2608e-02j,
           4.9201e-01+1.2634e-01j],
         [-5.2047e-01-4.7684e-07j, -1.2657e-02-3.9499e-02j,
          -9.0599e-06+9.7783e-03j, -2.3842e-06+9.7733e-03j,
          -4.7684e-06+9.7816e-03j,  2.6226e-06+9.7888e-03j,
           1.6689e-06+9.7620e-03j,  9.2983e-06+9.7985e-03j,
          -5.4836e-06+9.8081e-03j, -3.9800e-02+6.3587e-02j,
          -4.7955e-01-1.9158e-01j],
         [ 5.3296e-01-2.8610e-06j, -3.8621e-02-2.2579e-02j,
           5.1260e-06+1.3529e-02j,  2.5034e-05+1.3548e-02j,
           3.7670e-05+1.3556e-02j, -7.8678e-06+1.3556e-02j,
          -1.3232e-05+1.3592e-02j, -8.8215e-06+1.3545e-02j,
           2.4676e-05+1.3526e-02j, -6.8807e-02+2.5094e-02j,
           4.6710e-01+2.5672e-01j],
         [-5.5567e-01+1.9073e-06j, -5.3064e-02+7.5495e-03j,
           7.3910e-06+1.7784e-02j, -4.7624e-05+1.7806e-02j,
          -5.1558e-05+1.7799e-02j, -4.7684e-07+1.7781e-02j,
          -6.0856e-05+1.7741e-02j,  1.4305e-06+1.7839e-02j,
           1.6212e-05+1.7867e-02j, -6.2935e-02-2.3522e-02j,
          -4.4431e-01-3.2632e-01j],
         [ 5.7838e-01-1.1921e-06j, -5.0266e-02+4.2919e-02j,
          -2.1398e-05+2.2675e-02j, -8.3148e-06+2.2661e-02j,
          -1.1921e-06+2.2653e-02j, -6.9141e-06+2.2656e-02j,
          -2.5421e-05+2.2670e-02j, -3.5852e-05+2.2684e-02j,
          -1.0818e-05+2.2620e-02j, -2.2522e-02-5.6695e-02j,
           4.2160e-01+3.9590e-01j],
         [-6.1475e-01-2.3842e-07j, -2.9699e-02+7.3373e-02j,
          -7.5996e-07+2.8453e-02j, -1.9222e-06+2.8454e-02j,
          -8.2105e-06+2.8441e-02j, -1.1548e-05+2.8456e-02j,
          -8.5086e-06+2.8434e-02j, -4.6492e-06+2.8434e-02j,
          -3.3230e-05+2.8437e-02j,  3.5013e-02-5.3722e-02j,
          -3.8530e-01-4.7289e-01j],
         [ 6.5112e-01+3.6359e-06j,  3.4683e-03+8.9727e-02j,
           1.2197e-05+3.5464e-02j,  2.6584e-05+3.5476e-02j,
           2.7277e-05+3.5498e-02j,  3.1456e-05+3.5485e-02j,
           1.9751e-05+3.5476e-02j,  2.4512e-05+3.5495e-02j,
           3.3930e-05+3.5480e-02j,  8.2363e-02-9.5680e-03j,
           3.4893e-01+5.4978e-01j],
         [-7.0752e-01-2.9802e-07j,  3.9343e-02+8.6616e-02j,
          -1.6905e-05+4.4136e-02j, -2.3469e-06+4.4123e-02j,
          -9.6485e-06+4.4076e-02j,  2.0258e-05+4.4148e-02j,
           1.1474e-06+4.4102e-02j, -3.9563e-06+4.4092e-02j,
          -3.1091e-05+4.4052e-02j,  9.3750e-02+6.1317e-02j,
          -2.9239e-01-6.3817e-01j],
         [ 7.6393e-01-4.7684e-07j,  6.5845e-02+6.4718e-02j,
          -1.0330e-05+5.5033e-02j,  1.7628e-05+5.5049e-02j,
          -2.9594e-05+5.5070e-02j, -1.3079e-05+5.4998e-02j,
          -3.8873e-05+5.5008e-02j, -1.0572e-05+5.4994e-02j,
          -3.3449e-05+5.4950e-02j,  5.7960e-02+1.3015e-01j,
           2.3605e-01+7.2658e-01j],
         [-8.5243e-01-2.8908e-06j,  7.1879e-02+3.1504e-02j,
           3.3081e-06+6.9134e-02j,  9.9167e-06+6.9107e-02j,
           1.5119e-05+6.9120e-02j,  1.1185e-05+6.9129e-02j,
           5.0755e-05+6.9070e-02j,  1.1617e-05+6.9100e-02j,
           6.6511e-05+6.9066e-02j, -1.4493e-02+1.6525e-01j,
          -1.4752e-01-8.3284e-01j],
         [ 9.4092e-01-5.8115e-07j,  5.0641e-02+2.0648e-04j,
           9.1051e-06+8.7735e-02j,  1.3185e-05+8.7747e-02j,
           6.1677e-06+8.7749e-02j,  2.0702e-05+8.7752e-02j,
           1.5848e-05+8.7733e-02j,  1.8269e-05+8.7750e-02j,
           2.1524e-05+8.7722e-02j, -9.2869e-02+1.4757e-01j,
           5.9072e-02+9.3906e-01j],
         [-1.0858e+00+1.6242e-06j,  1.9604e-03-1.2476e-02j,
          -4.5490e-06+1.1296e-01j, -3.5157e-07+1.1297e-01j,
           5.3099e-06+1.1295e-01j,  1.3079e-05+1.1299e-01j,
           7.4464e-06+1.1299e-01j,  2.7884e-05+1.1299e-01j,
          -3.0464e-06+1.1303e-01j, -1.3694e-01+8.1263e-02j,
           8.5761e-02-1.0738e+00j],
         [ 1.2306e+00-1.9372e-06j, -6.6800e-02+1.0473e-02j,
           1.3297e-05+1.4803e-01j,  1.0869e-05+1.4803e-01j,
           2.1271e-06+1.4803e-01j, -5.2843e-05+1.4802e-01j,
          -6.6292e-06+1.4803e-01j, -1.2556e-05+1.4803e-01j,
          -2.6269e-05+1.4798e-01j, -1.1256e-01-1.9278e-03j,
          -2.3058e-01+1.2087e+00j],
         [-1.4860e+00-4.1127e-06j, -1.4190e-01+8.3810e-02j,
           1.4994e-05+1.9861e-01j,  1.6965e-05+1.9861e-01j,
          -1.0461e-05+1.9855e-01j, -2.1730e-05+1.9857e-01j,
           9.3095e-06+1.9855e-01j, -1.3746e-05+1.9852e-01j,
          -4.6659e-05+1.9854e-01j, -7.4884e-03-4.9216e-02j,
           4.8607e-01-1.3923e+00j],
         [ 1.7415e+00+4.1723e-07j, -2.0585e-01+2.1897e-01j,
           1.2480e-05+2.7424e-01j,  1.9066e-05+2.7426e-01j,
           7.6815e-06+2.7427e-01j,  6.5714e-06+2.7428e-01j,
           4.2900e-05+2.7426e-01j,  8.7544e-05+2.7428e-01j,
           1.0174e-04+2.7430e-01j,  1.5927e-01-8.4578e-04j,
          -7.4142e-01+1.5758e+00j],
         [-2.2498e+00+1.6093e-06j, -2.4141e-01+4.2715e-01j,
           9.8869e-06+3.9294e-01j,  5.2154e-06+3.9294e-01j,
           8.9034e-06+3.9294e-01j,  8.4639e-06+3.9293e-01j,
           5.0887e-06+3.9293e-01j,  4.5449e-06+3.9292e-01j,
           1.0714e-05+3.9295e-01j,  3.4008e-01+1.9410e-01j,
           1.2498e+00-1.8505e+00j],
         [ 2.7580e+00-1.1921e-06j, -2.3564e-01+7.2956e-01j,
          -1.0684e-05+5.9056e-01j, -2.0400e-05+5.9055e-01j,
          -2.7329e-05+5.9054e-01j, -2.8864e-05+5.9056e-01j,
          -1.6883e-05+5.9057e-01j, -2.8163e-05+5.9057e-01j,
          -5.8800e-05+5.9057e-01j,  4.7201e-01+5.7314e-01j,
          -1.7581e+00+2.1251e+00j],
         [-3.9971e+00-4.2915e-06j, -1.8333e-01+1.1854e+00j,
           1.9610e-05+9.4603e-01j,  2.4259e-05+9.4604e-01j,
           4.4465e-05+9.4605e-01j,  2.1726e-05+9.4602e-01j,
           6.1303e-05+9.4599e-01j,  4.7237e-05+9.4597e-01j,
           7.2449e-05+9.4595e-01j,  4.9770e-01+1.1795e+00j,
           2.9972e+00-2.6014e+00j],
         [ 5.2361e+00+4.2915e-06j, -8.8814e-02+1.9713e+00j,
           3.3736e-05+1.6568e+00j,  2.8312e-05+1.6568e+00j,
          -1.5199e-05+1.6568e+00j,  9.6321e-05+1.6568e+00j,
           5.9903e-05+1.6569e+00j,  8.3923e-05+1.6568e+00j,
           1.2493e-04+1.6568e+00j,  3.8557e-01+2.1453e+00j,
          -4.2360e+00+3.0777e+00j],
         [-9.7382e+00+6.6757e-06j,  3.4311e-02+3.6618e+00j,
           4.8876e-06+3.3149e+00j,  2.0504e-05+3.3149e+00j,
           4.8637e-05+3.3149e+00j, -9.0599e-06+3.3150e+00j,
           6.4969e-05+3.3150e+00j,  2.5034e-05+3.3150e+00j,
           2.1815e-05+3.3150e+00j,  1.4484e-01+3.9864e+00j,
           8.7382e+00-4.1599e+00j],
         [ 1.4240e+01+1.9073e-06j,  1.6559e-01+8.6154e+00j,
          -1.0729e-05+8.2887e+00j, -2.5034e-05+8.2887e+00j,
          -2.9087e-05+8.2887e+00j, -7.3910e-06+8.2887e+00j,
          -2.9087e-05+8.2887e+00j, -3.6001e-05+8.2887e+00j,
          -6.8903e-05+8.2887e+00j, -1.7455e-01+9.0084e+00j,
          -1.3240e+01+5.2422e+00j],
         [-7.0529e+01+7.6294e-06j,  2.8182e-01+3.3411e+01j,
           3.0518e-05+3.3157e+01j,  3.6240e-05+3.3157e+01j,
           4.3869e-05+3.3157e+01j,  4.9591e-05+3.3157e+01j,
           6.8665e-05+3.3157e+01j,  8.3923e-05+3.3157e+01j,
           9.1553e-05+3.3157e+01j, -4.9339e-01+3.3762e+01j,
           6.9529e+01-1.0568e+01j],
         [-2.9432e+01+2.5749e-05j, -2.4964e+02-1.4907e+02j,
          -5.0000e+02-1.4921e+02j, -7.5000e+02-1.4921e+02j,
          -1.0000e+03-1.4921e+02j, -1.2500e+03-1.4921e+02j,
          -1.5000e+03-1.4921e+02j, -1.7500e+03-1.4921e+02j,
          -2.0000e+03-1.4921e+02j, -2.2507e+03-1.4886e+02j,
          -2.4571e+03+1.5895e+01j]]])

reference

import torch


if __name__ == '__main__':
    hnn = torch.hann_window(50)
    test = torch.empty(200)
    for i in range(200):
        test[i] = i
    print("Reference", torch.stft(test, 50, 20, window=hnn, return_complex=True))
Show Console
Reference: tensor([[ 1.8568e+02+0.0000e+00j,  5.0039e+02+0.0000e+00j,
          1.0000e+03+0.0000e+00j,  1.5000e+03+0.0000e+00j,
          2.0000e+03+0.0000e+00j,  2.5000e+03+0.0000e+00j,
          3.0000e+03+0.0000e+00j,  3.5000e+03+0.0000e+00j,
          4.0000e+03+0.0000e+00j,  4.4992e+03+0.0000e+00j,
          4.7883e+03+0.0000e+00j],
        [-2.9432e+01-2.5749e-05j, -2.4964e+02+1.4907e+02j,
         -5.0000e+02+1.4921e+02j, -7.5000e+02+1.4921e+02j,
         -1.0000e+03+1.4921e+02j, -1.2500e+03+1.4921e+02j,
         -1.5000e+03+1.4921e+02j, -1.7500e+03+1.4921e+02j,
         -2.0000e+03+1.4921e+02j, -2.2507e+03+1.4886e+02j,
         -2.4571e+03-1.5895e+01j],
        [-7.0529e+01-7.6294e-06j,  2.8182e-01-3.3411e+01j,
          3.0518e-05-3.3157e+01j,  3.6240e-05-3.3157e+01j,
          4.3869e-05-3.3157e+01j,  4.9591e-05-3.3157e+01j,
          6.8665e-05-3.3157e+01j,  8.3923e-05-3.3157e+01j,
          9.1553e-05-3.3157e+01j, -4.9339e-01-3.3762e+01j,
          6.9529e+01+1.0568e+01j],
        [ 1.4240e+01-1.9073e-06j,  1.6559e-01-8.6154e+00j,
         -1.0729e-05-8.2887e+00j, -2.5034e-05-8.2887e+00j,
         -2.9087e-05-8.2887e+00j, -7.3910e-06-8.2887e+00j,
         -2.9087e-05-8.2887e+00j, -3.6001e-05-8.2887e+00j,
         -6.8903e-05-8.2887e+00j, -1.7455e-01-9.0084e+00j,
         -1.3240e+01-5.2422e+00j],
        [-9.7382e+00-6.6757e-06j,  3.4311e-02-3.6618e+00j,
          4.8876e-06-3.3149e+00j,  2.0504e-05-3.3149e+00j,
          4.8637e-05-3.3149e+00j, -9.0599e-06-3.3150e+00j,
          6.4969e-05-3.3150e+00j,  2.5034e-05-3.3150e+00j,
          2.1815e-05-3.3150e+00j,  1.4484e-01-3.9864e+00j,
          8.7382e+00+4.1599e+00j],
        [ 5.2361e+00-4.2915e-06j, -8.8814e-02-1.9713e+00j,
          3.3736e-05-1.6568e+00j,  2.8312e-05-1.6568e+00j,
         -1.5199e-05-1.6568e+00j,  9.6321e-05-1.6568e+00j,
          5.9903e-05-1.6569e+00j,  8.3923e-05-1.6568e+00j,
          1.2493e-04-1.6568e+00j,  3.8557e-01-2.1453e+00j,
         -4.2360e+00-3.0777e+00j],
        [-3.9971e+00+4.2915e-06j, -1.8333e-01-1.1854e+00j,
          1.9610e-05-9.4603e-01j,  2.4259e-05-9.4604e-01j,
          4.4465e-05-9.4605e-01j,  2.1726e-05-9.4602e-01j,
          6.1303e-05-9.4599e-01j,  4.7237e-05-9.4597e-01j,
          7.2449e-05-9.4595e-01j,  4.9770e-01-1.1795e+00j,
          2.9972e+00+2.6014e+00j],
        [ 2.7580e+00+1.1921e-06j, -2.3564e-01-7.2956e-01j,
         -1.0684e-05-5.9056e-01j, -2.0400e-05-5.9055e-01j,
         -2.7329e-05-5.9054e-01j, -2.8864e-05-5.9056e-01j,
         -1.6883e-05-5.9057e-01j, -2.8163e-05-5.9057e-01j,
         -5.8800e-05-5.9057e-01j,  4.7201e-01-5.7314e-01j,
         -1.7581e+00-2.1251e+00j],
        [-2.2498e+00-1.6093e-06j, -2.4141e-01-4.2715e-01j,
          9.8869e-06-3.9294e-01j,  5.2154e-06-3.9294e-01j,
          8.9034e-06-3.9294e-01j,  8.4639e-06-3.9293e-01j,
          5.0887e-06-3.9293e-01j,  4.5449e-06-3.9292e-01j,
          1.0714e-05-3.9295e-01j,  3.4008e-01-1.9410e-01j,
          1.2498e+00+1.8505e+00j],
        [ 1.7415e+00-4.1723e-07j, -2.0585e-01-2.1897e-01j,
          1.2480e-05-2.7424e-01j,  1.9066e-05-2.7426e-01j,
          7.6815e-06-2.7427e-01j,  6.5714e-06-2.7428e-01j,
          4.2900e-05-2.7426e-01j,  8.7544e-05-2.7428e-01j,
          1.0174e-04-2.7430e-01j,  1.5927e-01+8.4578e-04j,
         -7.4142e-01-1.5758e+00j],
        [-1.4860e+00+4.1127e-06j, -1.4190e-01-8.3810e-02j,
          1.4994e-05-1.9861e-01j,  1.6965e-05-1.9861e-01j,
         -1.0461e-05-1.9855e-01j, -2.1730e-05-1.9857e-01j,
          9.3095e-06-1.9855e-01j, -1.3746e-05-1.9852e-01j,
         -4.6659e-05-1.9854e-01j, -7.4884e-03+4.9216e-02j,
          4.8607e-01+1.3923e+00j],
        [ 1.2306e+00+1.9372e-06j, -6.6800e-02-1.0473e-02j,
          1.3297e-05-1.4803e-01j,  1.0869e-05-1.4803e-01j,
          2.1271e-06-1.4803e-01j, -5.2843e-05-1.4802e-01j,
         -6.6292e-06-1.4803e-01j, -1.2556e-05-1.4803e-01j,
         -2.6269e-05-1.4798e-01j, -1.1256e-01+1.9278e-03j,
         -2.3058e-01-1.2087e+00j],
        [-1.0858e+00-1.6242e-06j,  1.9604e-03+1.2476e-02j,
         -4.5490e-06-1.1296e-01j, -3.5157e-07-1.1297e-01j,
          5.3099e-06-1.1295e-01j,  1.3079e-05-1.1299e-01j,
          7.4464e-06-1.1299e-01j,  2.7884e-05-1.1299e-01j,
         -3.0464e-06-1.1303e-01j, -1.3694e-01-8.1263e-02j,
          8.5761e-02+1.0738e+00j],
        [ 9.4092e-01+5.8115e-07j,  5.0641e-02-2.0648e-04j,
          9.1051e-06-8.7735e-02j,  1.3185e-05-8.7747e-02j,
          6.1677e-06-8.7749e-02j,  2.0702e-05-8.7752e-02j,
          1.5848e-05-8.7733e-02j,  1.8269e-05-8.7750e-02j,
          2.1524e-05-8.7722e-02j, -9.2869e-02-1.4757e-01j,
          5.9072e-02-9.3906e-01j],
        [-8.5243e-01+2.8908e-06j,  7.1879e-02-3.1504e-02j,
          3.3081e-06-6.9134e-02j,  9.9167e-06-6.9107e-02j,
          1.5119e-05-6.9120e-02j,  1.1185e-05-6.9129e-02j,
          5.0755e-05-6.9070e-02j,  1.1617e-05-6.9100e-02j,
          6.6511e-05-6.9066e-02j, -1.4493e-02-1.6525e-01j,
         -1.4752e-01+8.3284e-01j],
        [ 7.6393e-01+4.7684e-07j,  6.5845e-02-6.4718e-02j,
         -1.0330e-05-5.5033e-02j,  1.7628e-05-5.5049e-02j,
         -2.9594e-05-5.5070e-02j, -1.3079e-05-5.4998e-02j,
         -3.8873e-05-5.5008e-02j, -1.0572e-05-5.4994e-02j,
         -3.3449e-05-5.4950e-02j,  5.7960e-02-1.3015e-01j,
          2.3605e-01-7.2658e-01j],
        [-7.0752e-01+2.9802e-07j,  3.9343e-02-8.6616e-02j,
         -1.6905e-05-4.4136e-02j, -2.3469e-06-4.4123e-02j,
         -9.6485e-06-4.4076e-02j,  2.0258e-05-4.4148e-02j,
          1.1474e-06-4.4102e-02j, -3.9563e-06-4.4092e-02j,
         -3.1091e-05-4.4052e-02j,  9.3750e-02-6.1317e-02j,
         -2.9239e-01+6.3817e-01j],
        [ 6.5112e-01-3.6359e-06j,  3.4683e-03-8.9727e-02j,
          1.2197e-05-3.5464e-02j,  2.6584e-05-3.5476e-02j,
          2.7277e-05-3.5498e-02j,  3.1456e-05-3.5485e-02j,
          1.9751e-05-3.5476e-02j,  2.4512e-05-3.5495e-02j,
          3.3930e-05-3.5480e-02j,  8.2363e-02+9.5680e-03j,
          3.4893e-01-5.4978e-01j],
        [-6.1475e-01+2.3842e-07j, -2.9699e-02-7.3373e-02j,
         -7.5996e-07-2.8453e-02j, -1.9222e-06-2.8454e-02j,
         -8.2105e-06-2.8441e-02j, -1.1548e-05-2.8456e-02j,
         -8.5086e-06-2.8434e-02j, -4.6492e-06-2.8434e-02j,
         -3.3230e-05-2.8437e-02j,  3.5013e-02+5.3722e-02j,
         -3.8530e-01+4.7289e-01j],
        [ 5.7838e-01+1.1921e-06j, -5.0266e-02-4.2919e-02j,
         -2.1398e-05-2.2675e-02j, -8.3148e-06-2.2661e-02j,
         -1.1921e-06-2.2653e-02j, -6.9141e-06-2.2656e-02j,
         -2.5421e-05-2.2670e-02j, -3.5852e-05-2.2684e-02j,
         -1.0818e-05-2.2620e-02j, -2.2522e-02+5.6695e-02j,
          4.2160e-01-3.9590e-01j],
        [-5.5567e-01-1.9073e-06j, -5.3064e-02-7.5495e-03j,
          7.3910e-06-1.7784e-02j, -4.7624e-05-1.7806e-02j,
         -5.1558e-05-1.7799e-02j, -4.7684e-07-1.7781e-02j,
         -6.0856e-05-1.7741e-02j,  1.4305e-06-1.7839e-02j,
          1.6212e-05-1.7867e-02j, -6.2935e-02+2.3522e-02j,
         -4.4431e-01+3.2632e-01j],
        [ 5.3296e-01+2.8610e-06j, -3.8621e-02+2.2579e-02j,
          5.1260e-06-1.3529e-02j,  2.5034e-05-1.3548e-02j,
          3.7670e-05-1.3556e-02j, -7.8678e-06-1.3556e-02j,
         -1.3232e-05-1.3592e-02j, -8.8215e-06-1.3545e-02j,
          2.4676e-05-1.3526e-02j, -6.8807e-02-2.5094e-02j,
          4.6710e-01-2.5672e-01j],
        [-5.2047e-01+4.7684e-07j, -1.2657e-02+3.9499e-02j,
         -9.0599e-06-9.7783e-03j, -2.3842e-06-9.7733e-03j,
         -4.7684e-06-9.7816e-03j,  2.6226e-06-9.7888e-03j,
          1.6689e-06-9.7620e-03j,  9.2983e-06-9.7985e-03j,
         -5.4836e-06-9.8081e-03j, -3.9800e-02-6.3587e-02j,
         -4.7955e-01+1.9158e-01j],
        [ 5.0798e-01+0.0000e+00j,  1.5934e-02+3.9506e-02j,
          0.0000e+00-6.3553e-03j, -7.6294e-06-6.3553e-03j,
         -7.6294e-06-6.3429e-03j, -2.0981e-05-6.3457e-03j,
         -1.7166e-05-6.3601e-03j,  1.5259e-05-6.3667e-03j,
         -2.2888e-05-6.3391e-03j,  8.0528e-03-7.2608e-02j,
          4.9201e-01-1.2634e-01j],
        [-5.0399e-01-9.5367e-07j,  3.7811e-02+2.4242e-02j,
          0.0000e+00-3.1281e-03j,  3.0518e-05-3.1738e-03j,
         -3.0518e-05-3.1738e-03j,  0.0000e+00-3.1738e-03j,
          6.1035e-05-3.1128e-03j,  6.1035e-05-3.0518e-03j,
         -6.1035e-05-3.1128e-03j,  5.1025e-02-4.7607e-02j,
         -4.9609e-01+6.3110e-02j],
        [ 5.0000e-01+0.0000e+00j,  4.5944e-02+0.0000e+00j,
          3.0518e-05+0.0000e+00j,  6.1035e-05+0.0000e+00j,
          1.2207e-04+0.0000e+00j,  1.2207e-04+0.0000e+00j,
          1.2207e-04+0.0000e+00j,  1.2207e-04+0.0000e+00j,
         -1.2207e-04+0.0000e+00j,  6.8115e-02+0.0000e+00j,
          5.0000e-01+0.0000e+00j]])
image

The left side is the output of torch.stft, and the right side is the output of the code used in this experiment. Except for the length of the output, everything else is exactly the same. torch.stft will only output from bin 0 to bin Nyquist, while the code used in the experiment will output the entire spectrum.

  if (self.dim() == 1) {
    out.squeeze_(0);
  }
  1. Convert the 3-dimensional Tensor back to 2-dimensional. Since our Tensor (self) is one-dimensional, we previously used unsqueeze to increase its dimension, and now we need to use squeeze_ to restore it back.

whisper.cpp/whisper.cpp

Lines 2514 to 2516 in a4bb2df

mel.n_mel = n_mel;
mel.n_len = n_samples/fft_step;
mel.n_len_org = mel.n_len;

whisper.cpp/whisper.cpp

Lines 2455 to 2457 in a4bb2df

for (int j = 1; j < fft_size / 2; j++) {
fft_out[j] += fft_out[fft_size - j];
}

The whisper.cpp does a rather poor job when calculating the STFT. For example, the formula for calculating frames is incorrect, resulting in an incorrect final count. The lack of Stage-2 Padding leads to potential edge effects. Also, after the FFT computation, it adds together the amplitudes of the symmetrical parts...

Part-5: Magnitudes of Bins

To get the power or magnitude spectrum, we compute the magnitude of each complex number from the FFT. Magnitude represents the amount of the frequency content present in that frame.

whisper/audio.py

magnitudes = stft[..., :-1].abs() ** 2

OpenAI's Whisper implementation computes mag^2 using just one line of code. The stft[..., :-1] is to remove the last frame from each frequency bin in the STFT computation results (I'm not clear why it's done this way either). Since the result of the STFT is complex, .abs() takes the absolute value of each complex number, calculating their magnitude. ** 2 then squares each magnitude. We can conduct a small experiment to verify our idea.

import torch
import torch.nn.functional as F


if __name__ == '__main__':
    hnn = torch.hann_window(50)
    test = torch.empty(200)
    for i in range(200):
        test[i] = i
    new = torch.stft(test, 50, 20, window=hnn, return_complex=True)
    print("Before:", new)
    print("After:", new[..., :-1])
Show Console
Before: tensor([[ 1.8568e+02+0.0000e+00j,  5.0039e+02+0.0000e+00j,
          1.0000e+03+0.0000e+00j,  1.5000e+03+0.0000e+00j,
          2.0000e+03+0.0000e+00j,  2.5000e+03+0.0000e+00j,
          3.0000e+03+0.0000e+00j,  3.5000e+03+0.0000e+00j,
          4.0000e+03+0.0000e+00j,  4.4992e+03+0.0000e+00j,
          4.7883e+03+0.0000e+00j],
        [-2.9432e+01-2.5749e-05j, -2.4964e+02+1.4907e+02j,
         -5.0000e+02+1.4921e+02j, -7.5000e+02+1.4921e+02j,
         -1.0000e+03+1.4921e+02j, -1.2500e+03+1.4921e+02j,
         -1.5000e+03+1.4921e+02j, -1.7500e+03+1.4921e+02j,
         -2.0000e+03+1.4921e+02j, -2.2507e+03+1.4886e+02j,
         -2.4571e+03-1.5895e+01j],
        [-7.0529e+01-7.6294e-06j,  2.8182e-01-3.3411e+01j,
          3.0518e-05-3.3157e+01j,  3.6240e-05-3.3157e+01j,
          4.3869e-05-3.3157e+01j,  4.9591e-05-3.3157e+01j,
          6.8665e-05-3.3157e+01j,  8.3923e-05-3.3157e+01j,
          9.1553e-05-3.3157e+01j, -4.9339e-01-3.3762e+01j,
          6.9529e+01+1.0568e+01j],
        [ 1.4240e+01-1.9073e-06j,  1.6559e-01-8.6154e+00j,
         -1.0729e-05-8.2887e+00j, -2.5034e-05-8.2887e+00j,
         -2.9087e-05-8.2887e+00j, -7.3910e-06-8.2887e+00j,
         -2.9087e-05-8.2887e+00j, -3.6001e-05-8.2887e+00j,
         -6.8903e-05-8.2887e+00j, -1.7455e-01-9.0084e+00j,
         -1.3240e+01-5.2422e+00j],
        [-9.7382e+00-6.6757e-06j,  3.4311e-02-3.6618e+00j,
          4.8876e-06-3.3149e+00j,  2.0504e-05-3.3149e+00j,
          4.8637e-05-3.3149e+00j, -9.0599e-06-3.3150e+00j,
          6.4969e-05-3.3150e+00j,  2.5034e-05-3.3150e+00j,
          2.1815e-05-3.3150e+00j,  1.4484e-01-3.9864e+00j,
          8.7382e+00+4.1599e+00j],
        [ 5.2361e+00-4.2915e-06j, -8.8814e-02-1.9713e+00j,
          3.3736e-05-1.6568e+00j,  2.8312e-05-1.6568e+00j,
         -1.5199e-05-1.6568e+00j,  9.6321e-05-1.6568e+00j,
          5.9903e-05-1.6569e+00j,  8.3923e-05-1.6568e+00j,
          1.2493e-04-1.6568e+00j,  3.8557e-01-2.1453e+00j,
         -4.2360e+00-3.0777e+00j],
        [-3.9971e+00+4.2915e-06j, -1.8333e-01-1.1854e+00j,
          1.9610e-05-9.4603e-01j,  2.4259e-05-9.4604e-01j,
          4.4465e-05-9.4605e-01j,  2.1726e-05-9.4602e-01j,
          6.1303e-05-9.4599e-01j,  4.7237e-05-9.4597e-01j,
          7.2449e-05-9.4595e-01j,  4.9770e-01-1.1795e+00j,
          2.9972e+00+2.6014e+00j],
        [ 2.7580e+00+1.1921e-06j, -2.3564e-01-7.2956e-01j,
         -1.0684e-05-5.9056e-01j, -2.0400e-05-5.9055e-01j,
         -2.7329e-05-5.9054e-01j, -2.8864e-05-5.9056e-01j,
         -1.6883e-05-5.9057e-01j, -2.8163e-05-5.9057e-01j,
         -5.8800e-05-5.9057e-01j,  4.7201e-01-5.7314e-01j,
         -1.7581e+00-2.1251e+00j],
        [-2.2498e+00-1.6093e-06j, -2.4141e-01-4.2715e-01j,
          9.8869e-06-3.9294e-01j,  5.2154e-06-3.9294e-01j,
          8.9034e-06-3.9294e-01j,  8.4639e-06-3.9293e-01j,
          5.0887e-06-3.9293e-01j,  4.5449e-06-3.9292e-01j,
          1.0714e-05-3.9295e-01j,  3.4008e-01-1.9410e-01j,
          1.2498e+00+1.8505e+00j],
        [ 1.7415e+00-4.1723e-07j, -2.0585e-01-2.1897e-01j,
          1.2480e-05-2.7424e-01j,  1.9066e-05-2.7426e-01j,
          7.6815e-06-2.7427e-01j,  6.5714e-06-2.7428e-01j,
          4.2900e-05-2.7426e-01j,  8.7544e-05-2.7428e-01j,
          1.0174e-04-2.7430e-01j,  1.5927e-01+8.4578e-04j,
         -7.4142e-01-1.5758e+00j],
        [-1.4860e+00+4.1127e-06j, -1.4190e-01-8.3810e-02j,
          1.4994e-05-1.9861e-01j,  1.6965e-05-1.9861e-01j,
         -1.0461e-05-1.9855e-01j, -2.1730e-05-1.9857e-01j,
          9.3095e-06-1.9855e-01j, -1.3746e-05-1.9852e-01j,
         -4.6659e-05-1.9854e-01j, -7.4884e-03+4.9216e-02j,
          4.8607e-01+1.3923e+00j],
        [ 1.2306e+00+1.9372e-06j, -6.6800e-02-1.0473e-02j,
          1.3297e-05-1.4803e-01j,  1.0869e-05-1.4803e-01j,
          2.1271e-06-1.4803e-01j, -5.2843e-05-1.4802e-01j,
         -6.6292e-06-1.4803e-01j, -1.2556e-05-1.4803e-01j,
         -2.6269e-05-1.4798e-01j, -1.1256e-01+1.9278e-03j,
         -2.3058e-01-1.2087e+00j],
        [-1.0858e+00-1.6242e-06j,  1.9604e-03+1.2476e-02j,
         -4.5490e-06-1.1296e-01j, -3.5157e-07-1.1297e-01j,
          5.3099e-06-1.1295e-01j,  1.3079e-05-1.1299e-01j,
          7.4464e-06-1.1299e-01j,  2.7884e-05-1.1299e-01j,
         -3.0464e-06-1.1303e-01j, -1.3694e-01-8.1263e-02j,
          8.5761e-02+1.0738e+00j],
        [ 9.4092e-01+5.8115e-07j,  5.0641e-02-2.0648e-04j,
          9.1051e-06-8.7735e-02j,  1.3185e-05-8.7747e-02j,
          6.1677e-06-8.7749e-02j,  2.0702e-05-8.7752e-02j,
          1.5848e-05-8.7733e-02j,  1.8269e-05-8.7750e-02j,
          2.1524e-05-8.7722e-02j, -9.2869e-02-1.4757e-01j,
          5.9072e-02-9.3906e-01j],
        [-8.5243e-01+2.8908e-06j,  7.1879e-02-3.1504e-02j,
          3.3081e-06-6.9134e-02j,  9.9167e-06-6.9107e-02j,
          1.5119e-05-6.9120e-02j,  1.1185e-05-6.9129e-02j,
          5.0755e-05-6.9070e-02j,  1.1617e-05-6.9100e-02j,
          6.6511e-05-6.9066e-02j, -1.4493e-02-1.6525e-01j,
         -1.4752e-01+8.3284e-01j],
        [ 7.6393e-01+4.7684e-07j,  6.5845e-02-6.4718e-02j,
         -1.0330e-05-5.5033e-02j,  1.7628e-05-5.5049e-02j,
         -2.9594e-05-5.5070e-02j, -1.3079e-05-5.4998e-02j,
         -3.8873e-05-5.5008e-02j, -1.0572e-05-5.4994e-02j,
         -3.3449e-05-5.4950e-02j,  5.7960e-02-1.3015e-01j,
          2.3605e-01-7.2658e-01j],
        [-7.0752e-01+2.9802e-07j,  3.9343e-02-8.6616e-02j,
         -1.6905e-05-4.4136e-02j, -2.3469e-06-4.4123e-02j,
         -9.6485e-06-4.4076e-02j,  2.0258e-05-4.4148e-02j,
          1.1474e-06-4.4102e-02j, -3.9563e-06-4.4092e-02j,
         -3.1091e-05-4.4052e-02j,  9.3750e-02-6.1317e-02j,
         -2.9239e-01+6.3817e-01j],
        [ 6.5112e-01-3.6359e-06j,  3.4683e-03-8.9727e-02j,
          1.2197e-05-3.5464e-02j,  2.6584e-05-3.5476e-02j,
          2.7277e-05-3.5498e-02j,  3.1456e-05-3.5485e-02j,
          1.9751e-05-3.5476e-02j,  2.4512e-05-3.5495e-02j,
          3.3930e-05-3.5480e-02j,  8.2363e-02+9.5680e-03j,
          3.4893e-01-5.4978e-01j],
        [-6.1475e-01+2.3842e-07j, -2.9699e-02-7.3373e-02j,
         -7.5996e-07-2.8453e-02j, -1.9222e-06-2.8454e-02j,
         -8.2105e-06-2.8441e-02j, -1.1548e-05-2.8456e-02j,
         -8.5086e-06-2.8434e-02j, -4.6492e-06-2.8434e-02j,
         -3.3230e-05-2.8437e-02j,  3.5013e-02+5.3722e-02j,
         -3.8530e-01+4.7289e-01j],
        [ 5.7838e-01+1.1921e-06j, -5.0266e-02-4.2919e-02j,
         -2.1398e-05-2.2675e-02j, -8.3148e-06-2.2661e-02j,
         -1.1921e-06-2.2653e-02j, -6.9141e-06-2.2656e-02j,
         -2.5421e-05-2.2670e-02j, -3.5852e-05-2.2684e-02j,
         -1.0818e-05-2.2620e-02j, -2.2522e-02+5.6695e-02j,
          4.2160e-01-3.9590e-01j],
        [-5.5567e-01-1.9073e-06j, -5.3064e-02-7.5495e-03j,
          7.3910e-06-1.7784e-02j, -4.7624e-05-1.7806e-02j,
         -5.1558e-05-1.7799e-02j, -4.7684e-07-1.7781e-02j,
         -6.0856e-05-1.7741e-02j,  1.4305e-06-1.7839e-02j,
          1.6212e-05-1.7867e-02j, -6.2935e-02+2.3522e-02j,
         -4.4431e-01+3.2632e-01j],
        [ 5.3296e-01+2.8610e-06j, -3.8621e-02+2.2579e-02j,
          5.1260e-06-1.3529e-02j,  2.5034e-05-1.3548e-02j,
          3.7670e-05-1.3556e-02j, -7.8678e-06-1.3556e-02j,
         -1.3232e-05-1.3592e-02j, -8.8215e-06-1.3545e-02j,
          2.4676e-05-1.3526e-02j, -6.8807e-02-2.5094e-02j,
          4.6710e-01-2.5672e-01j],
        [-5.2047e-01+4.7684e-07j, -1.2657e-02+3.9499e-02j,
         -9.0599e-06-9.7783e-03j, -2.3842e-06-9.7733e-03j,
         -4.7684e-06-9.7816e-03j,  2.6226e-06-9.7888e-03j,
          1.6689e-06-9.7620e-03j,  9.2983e-06-9.7985e-03j,
         -5.4836e-06-9.8081e-03j, -3.9800e-02-6.3587e-02j,
         -4.7955e-01+1.9158e-01j],
        [ 5.0798e-01+0.0000e+00j,  1.5934e-02+3.9506e-02j,
          0.0000e+00-6.3553e-03j, -7.6294e-06-6.3553e-03j,
         -7.6294e-06-6.3429e-03j, -2.0981e-05-6.3457e-03j,
         -1.7166e-05-6.3601e-03j,  1.5259e-05-6.3667e-03j,
         -2.2888e-05-6.3391e-03j,  8.0528e-03-7.2608e-02j,
          4.9201e-01-1.2634e-01j],
        [-5.0399e-01-9.5367e-07j,  3.7811e-02+2.4242e-02j,
          0.0000e+00-3.1281e-03j,  3.0518e-05-3.1738e-03j,
         -3.0518e-05-3.1738e-03j,  0.0000e+00-3.1738e-03j,
          6.1035e-05-3.1128e-03j,  6.1035e-05-3.0518e-03j,
         -6.1035e-05-3.1128e-03j,  5.1025e-02-4.7607e-02j,
         -4.9609e-01+6.3110e-02j],
        [ 5.0000e-01+0.0000e+00j,  4.5944e-02+0.0000e+00j,
          3.0518e-05+0.0000e+00j,  6.1035e-05+0.0000e+00j,
          1.2207e-04+0.0000e+00j,  1.2207e-04+0.0000e+00j,
          1.2207e-04+0.0000e+00j,  1.2207e-04+0.0000e+00j,
         -1.2207e-04+0.0000e+00j,  6.8115e-02+0.0000e+00j,
          5.0000e-01+0.0000e+00j]])
After: tensor([[ 1.8568e+02+0.0000e+00j,  5.0039e+02+0.0000e+00j,
          1.0000e+03+0.0000e+00j,  1.5000e+03+0.0000e+00j,
          2.0000e+03+0.0000e+00j,  2.5000e+03+0.0000e+00j,
          3.0000e+03+0.0000e+00j,  3.5000e+03+0.0000e+00j,
          4.0000e+03+0.0000e+00j,  4.4992e+03+0.0000e+00j],
        [-2.9432e+01-2.5749e-05j, -2.4964e+02+1.4907e+02j,
         -5.0000e+02+1.4921e+02j, -7.5000e+02+1.4921e+02j,
         -1.0000e+03+1.4921e+02j, -1.2500e+03+1.4921e+02j,
         -1.5000e+03+1.4921e+02j, -1.7500e+03+1.4921e+02j,
         -2.0000e+03+1.4921e+02j, -2.2507e+03+1.4886e+02j],
        [-7.0529e+01-7.6294e-06j,  2.8182e-01-3.3411e+01j,
          3.0518e-05-3.3157e+01j,  3.6240e-05-3.3157e+01j,
          4.3869e-05-3.3157e+01j,  4.9591e-05-3.3157e+01j,
          6.8665e-05-3.3157e+01j,  8.3923e-05-3.3157e+01j,
          9.1553e-05-3.3157e+01j, -4.9339e-01-3.3762e+01j],
        [ 1.4240e+01-1.9073e-06j,  1.6559e-01-8.6154e+00j,
         -1.0729e-05-8.2887e+00j, -2.5034e-05-8.2887e+00j,
         -2.9087e-05-8.2887e+00j, -7.3910e-06-8.2887e+00j,
         -2.9087e-05-8.2887e+00j, -3.6001e-05-8.2887e+00j,
         -6.8903e-05-8.2887e+00j, -1.7455e-01-9.0084e+00j],
        [-9.7382e+00-6.6757e-06j,  3.4311e-02-3.6618e+00j,
          4.8876e-06-3.3149e+00j,  2.0504e-05-3.3149e+00j,
          4.8637e-05-3.3149e+00j, -9.0599e-06-3.3150e+00j,
          6.4969e-05-3.3150e+00j,  2.5034e-05-3.3150e+00j,
          2.1815e-05-3.3150e+00j,  1.4484e-01-3.9864e+00j],
        [ 5.2361e+00-4.2915e-06j, -8.8814e-02-1.9713e+00j,
          3.3736e-05-1.6568e+00j,  2.8312e-05-1.6568e+00j,
         -1.5199e-05-1.6568e+00j,  9.6321e-05-1.6568e+00j,
          5.9903e-05-1.6569e+00j,  8.3923e-05-1.6568e+00j,
          1.2493e-04-1.6568e+00j,  3.8557e-01-2.1453e+00j],
        [-3.9971e+00+4.2915e-06j, -1.8333e-01-1.1854e+00j,
          1.9610e-05-9.4603e-01j,  2.4259e-05-9.4604e-01j,
          4.4465e-05-9.4605e-01j,  2.1726e-05-9.4602e-01j,
          6.1303e-05-9.4599e-01j,  4.7237e-05-9.4597e-01j,
          7.2449e-05-9.4595e-01j,  4.9770e-01-1.1795e+00j],
        [ 2.7580e+00+1.1921e-06j, -2.3564e-01-7.2956e-01j,
         -1.0684e-05-5.9056e-01j, -2.0400e-05-5.9055e-01j,
         -2.7329e-05-5.9054e-01j, -2.8864e-05-5.9056e-01j,
         -1.6883e-05-5.9057e-01j, -2.8163e-05-5.9057e-01j,
         -5.8800e-05-5.9057e-01j,  4.7201e-01-5.7314e-01j],
        [-2.2498e+00-1.6093e-06j, -2.4141e-01-4.2715e-01j,
          9.8869e-06-3.9294e-01j,  5.2154e-06-3.9294e-01j,
          8.9034e-06-3.9294e-01j,  8.4639e-06-3.9293e-01j,
          5.0887e-06-3.9293e-01j,  4.5449e-06-3.9292e-01j,
          1.0714e-05-3.9295e-01j,  3.4008e-01-1.9410e-01j],
        [ 1.7415e+00-4.1723e-07j, -2.0585e-01-2.1897e-01j,
          1.2480e-05-2.7424e-01j,  1.9066e-05-2.7426e-01j,
          7.6815e-06-2.7427e-01j,  6.5714e-06-2.7428e-01j,
          4.2900e-05-2.7426e-01j,  8.7544e-05-2.7428e-01j,
          1.0174e-04-2.7430e-01j,  1.5927e-01+8.4578e-04j],
        [-1.4860e+00+4.1127e-06j, -1.4190e-01-8.3810e-02j,
          1.4994e-05-1.9861e-01j,  1.6965e-05-1.9861e-01j,
         -1.0461e-05-1.9855e-01j, -2.1730e-05-1.9857e-01j,
          9.3095e-06-1.9855e-01j, -1.3746e-05-1.9852e-01j,
         -4.6659e-05-1.9854e-01j, -7.4884e-03+4.9216e-02j],
        [ 1.2306e+00+1.9372e-06j, -6.6800e-02-1.0473e-02j,
          1.3297e-05-1.4803e-01j,  1.0869e-05-1.4803e-01j,
          2.1271e-06-1.4803e-01j, -5.2843e-05-1.4802e-01j,
         -6.6292e-06-1.4803e-01j, -1.2556e-05-1.4803e-01j,
         -2.6269e-05-1.4798e-01j, -1.1256e-01+1.9278e-03j],
        [-1.0858e+00-1.6242e-06j,  1.9604e-03+1.2476e-02j,
         -4.5490e-06-1.1296e-01j, -3.5157e-07-1.1297e-01j,
          5.3099e-06-1.1295e-01j,  1.3079e-05-1.1299e-01j,
          7.4464e-06-1.1299e-01j,  2.7884e-05-1.1299e-01j,
         -3.0464e-06-1.1303e-01j, -1.3694e-01-8.1263e-02j],
        [ 9.4092e-01+5.8115e-07j,  5.0641e-02-2.0648e-04j,
          9.1051e-06-8.7735e-02j,  1.3185e-05-8.7747e-02j,
          6.1677e-06-8.7749e-02j,  2.0702e-05-8.7752e-02j,
          1.5848e-05-8.7733e-02j,  1.8269e-05-8.7750e-02j,
          2.1524e-05-8.7722e-02j, -9.2869e-02-1.4757e-01j],
        [-8.5243e-01+2.8908e-06j,  7.1879e-02-3.1504e-02j,
          3.3081e-06-6.9134e-02j,  9.9167e-06-6.9107e-02j,
          1.5119e-05-6.9120e-02j,  1.1185e-05-6.9129e-02j,
          5.0755e-05-6.9070e-02j,  1.1617e-05-6.9100e-02j,
          6.6511e-05-6.9066e-02j, -1.4493e-02-1.6525e-01j],
        [ 7.6393e-01+4.7684e-07j,  6.5845e-02-6.4718e-02j,
         -1.0330e-05-5.5033e-02j,  1.7628e-05-5.5049e-02j,
         -2.9594e-05-5.5070e-02j, -1.3079e-05-5.4998e-02j,
         -3.8873e-05-5.5008e-02j, -1.0572e-05-5.4994e-02j,
         -3.3449e-05-5.4950e-02j,  5.7960e-02-1.3015e-01j],
        [-7.0752e-01+2.9802e-07j,  3.9343e-02-8.6616e-02j,
         -1.6905e-05-4.4136e-02j, -2.3469e-06-4.4123e-02j,
         -9.6485e-06-4.4076e-02j,  2.0258e-05-4.4148e-02j,
          1.1474e-06-4.4102e-02j, -3.9563e-06-4.4092e-02j,
         -3.1091e-05-4.4052e-02j,  9.3750e-02-6.1317e-02j],
        [ 6.5112e-01-3.6359e-06j,  3.4683e-03-8.9727e-02j,
          1.2197e-05-3.5464e-02j,  2.6584e-05-3.5476e-02j,
          2.7277e-05-3.5498e-02j,  3.1456e-05-3.5485e-02j,
          1.9751e-05-3.5476e-02j,  2.4512e-05-3.5495e-02j,
          3.3930e-05-3.5480e-02j,  8.2363e-02+9.5680e-03j],
        [-6.1475e-01+2.3842e-07j, -2.9699e-02-7.3373e-02j,
         -7.5996e-07-2.8453e-02j, -1.9222e-06-2.8454e-02j,
         -8.2105e-06-2.8441e-02j, -1.1548e-05-2.8456e-02j,
         -8.5086e-06-2.8434e-02j, -4.6492e-06-2.8434e-02j,
         -3.3230e-05-2.8437e-02j,  3.5013e-02+5.3722e-02j],
        [ 5.7838e-01+1.1921e-06j, -5.0266e-02-4.2919e-02j,
         -2.1398e-05-2.2675e-02j, -8.3148e-06-2.2661e-02j,
         -1.1921e-06-2.2653e-02j, -6.9141e-06-2.2656e-02j,
         -2.5421e-05-2.2670e-02j, -3.5852e-05-2.2684e-02j,
         -1.0818e-05-2.2620e-02j, -2.2522e-02+5.6695e-02j],
        [-5.5567e-01-1.9073e-06j, -5.3064e-02-7.5495e-03j,
          7.3910e-06-1.7784e-02j, -4.7624e-05-1.7806e-02j,
         -5.1558e-05-1.7799e-02j, -4.7684e-07-1.7781e-02j,
         -6.0856e-05-1.7741e-02j,  1.4305e-06-1.7839e-02j,
          1.6212e-05-1.7867e-02j, -6.2935e-02+2.3522e-02j],
        [ 5.3296e-01+2.8610e-06j, -3.8621e-02+2.2579e-02j,
          5.1260e-06-1.3529e-02j,  2.5034e-05-1.3548e-02j,
          3.7670e-05-1.3556e-02j, -7.8678e-06-1.3556e-02j,
         -1.3232e-05-1.3592e-02j, -8.8215e-06-1.3545e-02j,
          2.4676e-05-1.3526e-02j, -6.8807e-02-2.5094e-02j],
        [-5.2047e-01+4.7684e-07j, -1.2657e-02+3.9499e-02j,
         -9.0599e-06-9.7783e-03j, -2.3842e-06-9.7733e-03j,
         -4.7684e-06-9.7816e-03j,  2.6226e-06-9.7888e-03j,
          1.6689e-06-9.7620e-03j,  9.2983e-06-9.7985e-03j,
         -5.4836e-06-9.8081e-03j, -3.9800e-02-6.3587e-02j],
        [ 5.0798e-01+0.0000e+00j,  1.5934e-02+3.9506e-02j,
          0.0000e+00-6.3553e-03j, -7.6294e-06-6.3553e-03j,
         -7.6294e-06-6.3429e-03j, -2.0981e-05-6.3457e-03j,
         -1.7166e-05-6.3601e-03j,  1.5259e-05-6.3667e-03j,
         -2.2888e-05-6.3391e-03j,  8.0528e-03-7.2608e-02j],
        [-5.0399e-01-9.5367e-07j,  3.7811e-02+2.4242e-02j,
          0.0000e+00-3.1281e-03j,  3.0518e-05-3.1738e-03j,
         -3.0518e-05-3.1738e-03j,  0.0000e+00-3.1738e-03j,
          6.1035e-05-3.1128e-03j,  6.1035e-05-3.0518e-03j,
         -6.1035e-05-3.1128e-03j,  5.1025e-02-4.7607e-02j],
        [ 5.0000e-01+0.0000e+00j,  4.5944e-02+0.0000e+00j,
          3.0518e-05+0.0000e+00j,  6.1035e-05+0.0000e+00j,
          1.2207e-04+0.0000e+00j,  1.2207e-04+0.0000e+00j,
          1.2207e-04+0.0000e+00j,  1.2207e-04+0.0000e+00j,
         -1.2207e-04+0.0000e+00j,  6.8115e-02+0.0000e+00j]])

We can see that the result is as we expected, the last frame of each frequency bin will be deleted.

Let's see how whisper.cpp implements this step?

whisper.cpp/whisper.cpp

Lines 2452 to 2454 in a4bb2df

for (int j = 0; j < fft_size; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
}

First, whisper.cpp doesn't remove the last frame like OpenAI's whisper does, which could lead to some potential issues? Secondly, whisper.cpp doesn't calculate the magnitude first and then square it. Instead, it combines these two steps, making the implementation more efficient than OpenAI's.

Part-6: Mel Filters

Mel filters are a collection of triangular filters that are used in signal processing to mimic the non-linear frequency resolution of the human ear.

whisper/audio.py

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    with np.load(
        os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    ) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
image

OpenAI's Whisper uses a very simple method with mel filters. It loads the precomputed mel filters matrix from the hard drive and then performs matrix multiplication to obtain the computed results.

whisper.cpp/whisper.cpp

Lines 2466 to 2490 in a4bb2df

// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0;
// unroll loop (suggested by GH user @lunixbochs)
int k = 0;
for (k = 0; k < n_fft - 3; k += 4) {
sum +=
fft_out[k + 0] * filters.data[j*n_fft + k + 0] +
fft_out[k + 1] * filters.data[j*n_fft + k + 1] +
fft_out[k + 2] * filters.data[j*n_fft + k + 2] +
fft_out[k + 3] * filters.data[j*n_fft + k + 3];
}
// handle n_fft remainder
for (; k < n_fft; k++) {
sum += fft_out[k] * filters.data[j * n_fft + k];
}
sum = log10(std::max(sum, 1e-10));
mel.data[j * mel.n_len + i] = sum;
}
}
}

whisper.cpp takes a similar approach by performing matrix multiplication. However, its method of processing each frame individually as it's computed means there's potential for optimization. Additionally, thanks to the filters and model weights being bundled together in the ggml, we can skip reading from the hard drive during the actual computation. Everything is loaded upfront with the model weights when the program starts.

Part-7: Dynamic Normalization

whisper/audio.py

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0

OpenAI's Whisper first clamps the computed results to ensure no values are less than 1e-10. Then, it takes the base-10 logarithm of these results using log10. After that, it uses maximum to make sure no values are less than the maximum value minus 8. Finally, it adds 4 to all the values and then divides by 4.

sum = log10(std::max(sum, 1e-10));

whisper.cpp/whisper.cpp

Lines 2558 to 2575 in a4bb2df

// clamping and normalization
double mmax = -1e20;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] > mmax) {
mmax = mel.data[i];
}
}
//printf("%s: max = %f\n", __func__, mmax);
mmax -= 8.0;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] < mmax) {
mel.data[i] = mmax;
}
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}

whisper.cpp uses a virtually identical method, with no major issues.

The End

whisper.cpp Outdated
Comment on lines 2461 to 2465
// The frequency spectrum produced by real input data is symmetrical around the Nyquist frequency.
// This is where the actual issue lies
for (int j = 0; j < fft_size / 2; j++) {
fft_out[j] = (fft_out[fft_size - j - 1] + fft_out[j + 1]) / 2;
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Help me out a bit here - it's very unclear to me what this code is here for. We get the symmetrical spectrum in fft_out[0...fft_size] for free from the previous loop. Not only that, but inference seems to work fine with it commented out entirely. What's it supposed to be doing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your thinking is correct. Theoretically, this step is not needed at all. Because the FFT_SIZE is 400, FFT_OUT[200] is the Nyquist Component; FFT_OUT[0] is the DC Component, which is the average of all 400 samples in FFT_IN. The remaining parts, FFT_OUT[1...199] and FFT_OUT[201...399], are symmetrical. I guess the original intention of the code was just to make the result more accurate, so the symmetrical parts were added together? But this is obviously unreasonable, because the amplitude of FFT_OUT[0] does not match the rest. So I made a modification, turning it into adding them together and then dividing by 2. However, I made a guess that perhaps FFT_OUT[0], as the DC Component, has no actual meaning? So I shifted the whole thing forward by one position, and the effect was surprisingly good, so I kept it. But I don't have enough evidence to support this guess yet.

whisper.cpp/whisper.cpp

Lines 2455 to 2457 in a4bb2df

for (int j = 1; j < fft_size / 2; j++) {
fft_out[j] += fft_out[fft_size - j];
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the most important thing is to establish a complete and scientific WER (Word Error Rate) detection framework, so that we can quantitatively compare the quality. Otherwise, it would be too subjective.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are three things to play off here:

  1. The word error rate
  2. Byte-for-byte matching OpenAI's mel implementation choice
  3. Theoretical approach at this step in particular

While 1 is what we're after, and using that as a backstop is undoubtedly going to pay off, it feels like 2 would be a very useful stopping-off point to check, so we know we're at least putting the same things into the black box as they are; and 3 would help us avoid bugs on the way there.

What worries me here is this:

I guess the original intention of the code was just to make the result more accurate, so the symmetrical parts were added together?

If the things being averaged weren't already identical numbers, then wouldn't that point to a bug in the FFT implementation? Averaging them should always have been a no-op.

However, I made a guess that perhaps FFT_OUT[0], as the DC Component, has no actual meaning?

That's not quite true. fft_out[0] is referred to as the "dc component" as a shorthand, but remember that it's actually more accurately the bin which happens to centre at 0Hz. It contains information for all frequencies from 0 up to the half the bin width. With 400 samples at 16kHz the bin width is 40Hz, so fft_out[0] has everything up to 20Hz in it - and also it gets any signal power at the sample frequency aliased into it. I don't think I'd make a bet that neither of those things matters. I know human vocal chords don't typically go that low, but I wouldn't write off there being transient phoneme information in the DC component.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding how to measure the word error rate, I found that there are many well-annotated datasets online, such as Mozilla's Common Voice. An article is divided into many sentences, read by different users, ensuring a diversity of accents, and including not just English but many other languages. The dataset is under CC-0, so there won't be any potential legal issues. Therefore, we can write a script in Python to batch test the WER of whisper.cpp as well as OpenAI's whisper.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While 1 is what we're after, and using that as a backstop is undoubtedly going to pay off, it feels like 2 would be a very useful stopping-off point to check, so we know we're at least putting the same things into the black box as they are; and 3 would help us avoid bugs on the way there.

Agree. If needed, match the whisper.cpp precision so that we have identical input with PyTorch.
Before doing any kind of WER, we need identical input to the transformer and I know it is not the case currently. Not sure if it is just a matter of precision or if there is a bug in whisper.cpp - needs similar analysis to what @bobqianic did in the comment above.

If the things being averaged weren't already identical numbers, then wouldn't that point to a bug in the FFT implementation? Averaging them should always have been a no-op.

Yes, but I believe the values are identical and this step is basically a noop. Needs confirmation. I guess I've put it there due to some doubts I had back when working on ggwave and was new to audio processing, but indeed the spectrum should be symmetrical.

Don't think we need to discard fft_out[0]. Again - let's compare what PyTorch is doing to make sure.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Help me out a bit here - it's very unclear to me what this code is here for. We get the symmetrical spectrum in fft_out[0...fft_size] for free from the previous loop. Not only that, but inference seems to work fine with it commented out entirely. What's it supposed to be doing?

Confirmed. This step is unnecessary.

I carefully examined the C++ implementation of PyTorch's STFT, and it directly removes all the bins after the Nyquist bin.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to know I'm not going completely mad :-) If you have a look at @lunixbochs' pocketfft branch from a little while back you'll see that the real to complex pocketfft FFT implementation gives you back a buffer of length (fft_size/2)+1. Using std::complex directly might clarify things, even without swapping out the FFT implementation.

@bobqianic bobqianic marked this pull request as draft August 7, 2023 13:17
@bobqianic
Copy link
Collaborator Author

bobqianic commented Aug 8, 2023

Oh no, I've stumbled upon a significant issue. It appears that the cos and sin algorithms used in PyTorch don't align with the algorithms we're using in <cmath>. When given the same input, the output can be quite different, particularly with small angles. For example, the input value 0.0157079632679, or (2PI * 1) / 400. In PyTorch, using torch.cos() yields 0.9998766183853149, while the C++ cos() function returns 0.9998766324816614. @regularfry

import torch


if __name__ == '__main__':
    t = torch.empty(1)
    t[0] = 0.0157079632679
    print(float(torch.cos(t)))
0.9998766183853149
#include <cmath>
#include <iostream>
#include <iomanip>


int main() {
    double t = 0.0157079632679;
    std::cout << std::setprecision(16) << cos(t) << std::endl;
}
0.9998766324816614

However, I think overall, the problem shouldn't be too significant, since we are using the float type, which only has 6 to 7 significant digits.

@AlexandrGraschenkov
Copy link
Contributor

@bobqianic we can precalculate this values in PyTorch 😏

@regularfry
Copy link

Torch's default internal representation is float32, and the C++ code is using float64. If you use torch.empty(1, dtype=torch.float64) you'll get matching results. It looks like torch uses std::cos under the hood so they'll match if the precision does.

@ggerganov
Copy link
Owner

Thank you - I will take a look in details early next week

@jbrough
Copy link
Contributor

jbrough commented Aug 12, 2023

This is an interesting discussion.

I set out to exactly replicate OpenAI's / PyTorch / Whisper.cpp Mel filter bank and mel STFT implementations in Rust a few weeks ago - and did so. After reading this, I'm going to compare with the original OpenAI model again, and see if I've missed anything.

I also got caught out by f32 vs f64.

However, there comes a point in the processing where it no longer matters.

whisper.cpp produces mel spectrograms with 1.0e-6 precision. You can literally round these to 1.0e-1 and get the same results.

I'm not really sure padding makes any difference either, but I still need to test that properly and convince myself that it does make a difference as what I've read about the model suggests it might.

@jbrough
Copy link
Contributor

jbrough commented Aug 12, 2023

ps: it's critical that mel spectrograms passed to whisper.cpp have an even number of columns in total (after padding, if you're padding) otherwise it will hallucinate massively.

@bobqianic bobqianic marked this pull request as ready for review August 13, 2023 09:29
@regularfry
Copy link

This is great stuff, well done!

@jbrough
Copy link
Contributor

jbrough commented Aug 14, 2023

would it be possible to get two mel spectrogram visual diffs, perhaps of the jfk test sample used by whisper, but anything will do - it should look like this:

quantized_mel

One pixel per mel frame. The examples above are highly compressed in the time domain and not what whisper processes. Due to the uncertainty principle, it's impossible for mel spectrogram data to "hide" useful information that isn't visually discernible, which is why precision is a red herring. Also, the differences we see above are well below variations in actual speech. Unless there's a structural problem with the spectrogram (did you mention there was mirroring?) it's difficult to see it altering results broadly in one direction or another.

I have a feeling that the transcription results you're seeing may be random and won't always favour better transcriptions. Or it might be to do with the padding changes, particularly the "contemplative" pause after the sample. It would be really too useful to isolate what, exactly, is having the effect. A visual diff will help.

@ggerganov
Copy link
Owner

Apologies for taking so long. Will try to find the time this week

@alexanderbluhm
Copy link

I compared outputs from whisper-cpp and openai-whisper for German audio files and I can confirm, that for smaller models, the differences were more significant. But even for the large(v2) model, openai-whisper produced slightly better results.
I tested this PR and found that it produces better results for my example compared to the current version.

Copy link
Owner

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do some testings now on my machine and after removing the debug flag from the public API, we can merge this

whisper.h Outdated Show resolved Hide resolved
@ggerganov ggerganov merged commit 7e54df4 into ggerganov:master Aug 27, 2023
34 of 38 checks passed
@ggerganov
Copy link
Owner

ggerganov commented Aug 27, 2023

@bobqianic

Amazing work! Really appreciate the effort and the detailed analysis.

would it be possible to get two mel spectrogram visual diffs, perhaps of the jfk test sample used by whisper, but anything will do - it should look like this:

quantized_mel

One pixel per mel frame. The examples above are highly compressed in the time domain and not what whisper processes. Due to the uncertainty principle, it's impossible for mel spectrogram data to "hide" useful information that isn't visually discernible, which is why precision is a red herring. Also, the differences we see above are well below variations in actual speech. Unless there's a structural problem with the spectrogram (did you mention there was mirroring?) it's difficult to see it altering results broadly in one direction or another.

I have a feeling that the transcription results you're seeing may be random and won't always favour better transcriptions. Or it might be to do with the padding changes, particularly the "contemplative" pause after the sample. It would be really too useful to isolate what, exactly, is having the effect. A visual diff will help.

I agree with @jbrough analysis. A visual diff would be interesting to see.
But in any case, having a matching input to whisper.cpp as the one in OpenAI is always a good thing, regardless of how big the impact on the transcription quality is

@bobqianic
Copy link
Collaborator Author

Excellent work!

To get a better perspective, how does it affect the overall transcription speed from start to end?

With the recent optimizations to beam_search sampling (#1243), log_mel_spectrogram calculation (#1148), and FFT calculation (#1142), the latest version runs 10% faster end-to-end compared to the version from July.

jacobwu-b pushed a commit to jacobwu-b/Transcriptify-by-whisper.cpp that referenced this pull request Oct 24, 2023
* Fix MSVC compile error C3688

Instead of simply using 'add_compile_options(/utf-8)' to address the MSVC compile error C3688, a better approach would be to handle it in a way that prevents passing '/utf-8' to NVCC.

* Significantly improve inference quality

In the function `log_mel_spectrogram_worker_thread`, there's an array out-of-bounds issue occurring during the calculation of complex number moduli. This issue is causing disruptions in the FFT spectrum, which, in turn, is reducing the quality of inference.

* Significantly improve inference quality

At last, I've pinpointed the actual source of the problem. Given that the frequency spectrum generated from real input data is symmetrical around the Nyquist frequency, there's a for-loop within the `log_mel_spectrogram_worker_thread` function that attempts to fold the frequency spectrum. Regrettably, a bug within this for-loop is causing a frame shift in the frequency spectrum. The previous attempt to remedy this, which involved using `fft_size + 1` when calculating the modulus, was merely a band-aid solution and did not address the underlying issue.

* Addressed a few minor issues

Fixed the issue of `fft_out` continuously expanding. Resolved the fallback caused by using 'break' instead of `fft_in[j] = 0`.

* Significantly improve inference quality 

Thanks for your patience everyone. It's finally sorted out. Now, the right side of the FFT spectrum is being flipped over to the left, and the amplitudes at corresponding positions on the left and right are added together (the spectrum on the left needs to be shifted by one position), then the average is calculated. FFT_OUT[0] is no longer discarded, making full use of the limited space to pack in more information.

* Add annotation and performance improvement

* Calculate FFT only when fft_in are not all zero

* Some minor performance improvement

* Fixed a bug impacting inference quality

* The first version after all the analysis is completed.

* Fix some bugs and add debug mode

* Fixed several bugs

* Temporarily disable speed-up mode and add debug mode.

* Add debug mode

* Disable speed-up mode and add debug mode

* Fix CI error (ggerganov#1)

* Fix error

* Fix error

* Fixed several bugs including [BLANK_AUDIO] problem

* Remove Hard-coded hann window

* Some Final Fix (ggerganov#2)

* Fix error

* Fix error

* Probably the last commit

* Probably the last commit

* whisper : minor coding style changes

* whisper : remove debug from public API

---------

Co-authored-by: Georgi Gerganov <[email protected]>
jacobwu-b pushed a commit to jacobwu-b/Transcriptify-by-whisper.cpp that referenced this pull request Oct 24, 2023
* Fix MSVC compile error C3688

Instead of simply using 'add_compile_options(/utf-8)' to address the MSVC compile error C3688, a better approach would be to handle it in a way that prevents passing '/utf-8' to NVCC.

* Significantly improve inference quality

In the function `log_mel_spectrogram_worker_thread`, there's an array out-of-bounds issue occurring during the calculation of complex number moduli. This issue is causing disruptions in the FFT spectrum, which, in turn, is reducing the quality of inference.

* Significantly improve inference quality

At last, I've pinpointed the actual source of the problem. Given that the frequency spectrum generated from real input data is symmetrical around the Nyquist frequency, there's a for-loop within the `log_mel_spectrogram_worker_thread` function that attempts to fold the frequency spectrum. Regrettably, a bug within this for-loop is causing a frame shift in the frequency spectrum. The previous attempt to remedy this, which involved using `fft_size + 1` when calculating the modulus, was merely a band-aid solution and did not address the underlying issue.

* Addressed a few minor issues

Fixed the issue of `fft_out` continuously expanding. Resolved the fallback caused by using 'break' instead of `fft_in[j] = 0`.

* Significantly improve inference quality 

Thanks for your patience everyone. It's finally sorted out. Now, the right side of the FFT spectrum is being flipped over to the left, and the amplitudes at corresponding positions on the left and right are added together (the spectrum on the left needs to be shifted by one position), then the average is calculated. FFT_OUT[0] is no longer discarded, making full use of the limited space to pack in more information.

* Add annotation and performance improvement

* Calculate FFT only when fft_in are not all zero

* Some minor performance improvement

* Fixed a bug impacting inference quality

* The first version after all the analysis is completed.

* Fix some bugs and add debug mode

* Fixed several bugs

* Temporarily disable speed-up mode and add debug mode.

* Add debug mode

* Disable speed-up mode and add debug mode

* Fix CI error (ggerganov#1)

* Fix error

* Fix error

* Fixed several bugs including [BLANK_AUDIO] problem

* Remove Hard-coded hann window

* Some Final Fix (ggerganov#2)

* Fix error

* Fix error

* Probably the last commit

* Probably the last commit

* whisper : minor coding style changes

* whisper : remove debug from public API

---------

Co-authored-by: Georgi Gerganov <[email protected]>
vonstring pushed a commit to vonstring/whisper.cpp that referenced this pull request Nov 7, 2023
* Fix MSVC compile error C3688

Instead of simply using 'add_compile_options(/utf-8)' to address the MSVC compile error C3688, a better approach would be to handle it in a way that prevents passing '/utf-8' to NVCC.

* Significantly improve inference quality

In the function `log_mel_spectrogram_worker_thread`, there's an array out-of-bounds issue occurring during the calculation of complex number moduli. This issue is causing disruptions in the FFT spectrum, which, in turn, is reducing the quality of inference.

* Significantly improve inference quality

At last, I've pinpointed the actual source of the problem. Given that the frequency spectrum generated from real input data is symmetrical around the Nyquist frequency, there's a for-loop within the `log_mel_spectrogram_worker_thread` function that attempts to fold the frequency spectrum. Regrettably, a bug within this for-loop is causing a frame shift in the frequency spectrum. The previous attempt to remedy this, which involved using `fft_size + 1` when calculating the modulus, was merely a band-aid solution and did not address the underlying issue.

* Addressed a few minor issues

Fixed the issue of `fft_out` continuously expanding. Resolved the fallback caused by using 'break' instead of `fft_in[j] = 0`.

* Significantly improve inference quality 

Thanks for your patience everyone. It's finally sorted out. Now, the right side of the FFT spectrum is being flipped over to the left, and the amplitudes at corresponding positions on the left and right are added together (the spectrum on the left needs to be shifted by one position), then the average is calculated. FFT_OUT[0] is no longer discarded, making full use of the limited space to pack in more information.

* Add annotation and performance improvement

* Calculate FFT only when fft_in are not all zero

* Some minor performance improvement

* Fixed a bug impacting inference quality

* The first version after all the analysis is completed.

* Fix some bugs and add debug mode

* Fixed several bugs

* Temporarily disable speed-up mode and add debug mode.

* Add debug mode

* Disable speed-up mode and add debug mode

* Fix CI error (ggerganov#1)

* Fix error

* Fix error

* Fixed several bugs including [BLANK_AUDIO] problem

* Remove Hard-coded hann window

* Some Final Fix (ggerganov#2)

* Fix error

* Fix error

* Probably the last commit

* Probably the last commit

* whisper : minor coding style changes

* whisper : remove debug from public API

---------

Co-authored-by: Georgi Gerganov <[email protected]>
landtanin pushed a commit to landtanin/whisper.cpp that referenced this pull request Dec 16, 2023
* Fix MSVC compile error C3688

Instead of simply using 'add_compile_options(/utf-8)' to address the MSVC compile error C3688, a better approach would be to handle it in a way that prevents passing '/utf-8' to NVCC.

* Significantly improve inference quality

In the function `log_mel_spectrogram_worker_thread`, there's an array out-of-bounds issue occurring during the calculation of complex number moduli. This issue is causing disruptions in the FFT spectrum, which, in turn, is reducing the quality of inference.

* Significantly improve inference quality

At last, I've pinpointed the actual source of the problem. Given that the frequency spectrum generated from real input data is symmetrical around the Nyquist frequency, there's a for-loop within the `log_mel_spectrogram_worker_thread` function that attempts to fold the frequency spectrum. Regrettably, a bug within this for-loop is causing a frame shift in the frequency spectrum. The previous attempt to remedy this, which involved using `fft_size + 1` when calculating the modulus, was merely a band-aid solution and did not address the underlying issue.

* Addressed a few minor issues

Fixed the issue of `fft_out` continuously expanding. Resolved the fallback caused by using 'break' instead of `fft_in[j] = 0`.

* Significantly improve inference quality 

Thanks for your patience everyone. It's finally sorted out. Now, the right side of the FFT spectrum is being flipped over to the left, and the amplitudes at corresponding positions on the left and right are added together (the spectrum on the left needs to be shifted by one position), then the average is calculated. FFT_OUT[0] is no longer discarded, making full use of the limited space to pack in more information.

* Add annotation and performance improvement

* Calculate FFT only when fft_in are not all zero

* Some minor performance improvement

* Fixed a bug impacting inference quality

* The first version after all the analysis is completed.

* Fix some bugs and add debug mode

* Fixed several bugs

* Temporarily disable speed-up mode and add debug mode.

* Add debug mode

* Disable speed-up mode and add debug mode

* Fix CI error (ggerganov#1)

* Fix error

* Fix error

* Fixed several bugs including [BLANK_AUDIO] problem

* Remove Hard-coded hann window

* Some Final Fix (ggerganov#2)

* Fix error

* Fix error

* Probably the last commit

* Probably the last commit

* whisper : minor coding style changes

* whisper : remove debug from public API

---------

Co-authored-by: Georgi Gerganov <[email protected]>
iThalay pushed a commit to iThalay/whisper.cpp that referenced this pull request Sep 23, 2024
* Fix MSVC compile error C3688

Instead of simply using 'add_compile_options(/utf-8)' to address the MSVC compile error C3688, a better approach would be to handle it in a way that prevents passing '/utf-8' to NVCC.

* Significantly improve inference quality

In the function `log_mel_spectrogram_worker_thread`, there's an array out-of-bounds issue occurring during the calculation of complex number moduli. This issue is causing disruptions in the FFT spectrum, which, in turn, is reducing the quality of inference.

* Significantly improve inference quality

At last, I've pinpointed the actual source of the problem. Given that the frequency spectrum generated from real input data is symmetrical around the Nyquist frequency, there's a for-loop within the `log_mel_spectrogram_worker_thread` function that attempts to fold the frequency spectrum. Regrettably, a bug within this for-loop is causing a frame shift in the frequency spectrum. The previous attempt to remedy this, which involved using `fft_size + 1` when calculating the modulus, was merely a band-aid solution and did not address the underlying issue.

* Addressed a few minor issues

Fixed the issue of `fft_out` continuously expanding. Resolved the fallback caused by using 'break' instead of `fft_in[j] = 0`.

* Significantly improve inference quality 

Thanks for your patience everyone. It's finally sorted out. Now, the right side of the FFT spectrum is being flipped over to the left, and the amplitudes at corresponding positions on the left and right are added together (the spectrum on the left needs to be shifted by one position), then the average is calculated. FFT_OUT[0] is no longer discarded, making full use of the limited space to pack in more information.

* Add annotation and performance improvement

* Calculate FFT only when fft_in are not all zero

* Some minor performance improvement

* Fixed a bug impacting inference quality

* The first version after all the analysis is completed.

* Fix some bugs and add debug mode

* Fixed several bugs

* Temporarily disable speed-up mode and add debug mode.

* Add debug mode

* Disable speed-up mode and add debug mode

* Fix CI error (chidiwilliams#1)

* Fix error

* Fix error

* Fixed several bugs including [BLANK_AUDIO] problem

* Remove Hard-coded hann window

* Some Final Fix (chidiwilliams#2)

* Fix error

* Fix error

* Probably the last commit

* Probably the last commit

* whisper : minor coding style changes

* whisper : remove debug from public API

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants