Skip to content

Commit

Permalink
Support short time furier transform
Browse files Browse the repository at this point in the history
  • Loading branch information
TakanoTaiga committed Nov 7, 2024
1 parent f470699 commit f10a868
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 24 deletions.
15 changes: 10 additions & 5 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gradio as gr
from moviepy.editor import VideoFileClip
from lab_tools import wavelet
from lab_tools import spectrogram
from lab_tools import labutils
from lab_tools import analyze1f
from lab_tools import ytutil
Expand Down Expand Up @@ -45,10 +45,15 @@ def update_slidar_range_video_file(file_path):


with gr.Blocks() as main_ui:
with gr.Tab("Wavelet"):
with gr.Tab("Spectrogram analyze"):
with gr.Row():
with gr.Column():
file_input = gr.File(label="CSVファイルをアップロードしてください。", file_count="single", file_types=["csv"])
analysis_method = gr.Radio(
["Short-Time Fourier Transform", "Wavelet"],
label="Analysis method",
value="Short-Time Fourier Transform",
)
fs_slider = gr.Slider(minimum=0, maximum=10000, value=1000, label="サンプリング周波数", step=10, info="単位はHz。")
fmax_slider = gr.Slider(minimum=0, maximum=200, value=60, label="wavelet 最大周波数", step=10, info="単位はHz。")
column_dropdown = gr.Dropdown(["Fp1", "Fp2", "T7", "T8", "O1", "O2"], value="Fp2", label="使用する信号データ", allow_custom_value=True, info="使用する信号データを選んでください。デフォルトはFp2です。")
Expand Down Expand Up @@ -76,13 +81,13 @@ def update_slidar_range_video_file(file_path):
wavelet_image = gr.Image(type="filepath", label="Wavelet")
signal_image = gr.Image(type="filepath", label="Signal")

submit_button.click(wavelet.wavelet_ui, inputs=[
file_input,
submit_button.click(spectrogram.spectrogram_ui, inputs=[
file_input, analysis_method,
fs_slider, fmax_slider, column_dropdown, start_time, end_time,
filter_setting, fp_hp, fs_hp, gpass, gstop],
outputs=[wavelet_image, signal_image])

with gr.Tab("1f Noise Search"):
with gr.Tab("1f noise analyze"):
with gr.Row():
with gr.Column():
mode_setting = gr.Radio(
Expand Down
56 changes: 37 additions & 19 deletions lab_tools/wavelet.py → lab_tools/spectrogram.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import numpy as np
import matplotlib.pyplot as plt
import math
import tempfile

import scipy.signal as signal
from lab_tools import labutils
from lab_tools import filter
import math


# モルレーウェーブレット関数
def morlet(x, f, width):
def morlet_wavelet(x, f, width):
sf = f / width
st = 1 / (2 * math.pi * sf)
A = 1 / (st * math.sqrt(2 * math.pi))
Expand All @@ -17,21 +15,19 @@ def morlet(x, f, width):
return A * np.exp(co1) * np.exp(h)


# 連続ウェーブレット変換
def continuous_wavelet_transform(Fs, data, fmax, width=48, wavelet_R=0.5):
Ts = 1 / Fs
wavelet_length = np.arange(-wavelet_R, wavelet_R, Ts)
data_length = len(data)
cwt_result = np.zeros([fmax, data_length])

for i in range(fmax):
conv_result = np.convolve(data, morlet(wavelet_length, i + 1, width), mode='same')
conv_result = np.convolve(data, morlet_wavelet(wavelet_length, i + 1, width), mode='same')
cwt_result[i, :] = (2 * np.abs(conv_result) / Fs) ** 2

return cwt_result


# 連続ウェーブレット変換結果をカラーマップとしてプロット
def plot_cwt(cwt_result, time_data, fmax):
plt.imshow(cwt_result, cmap='jet', aspect='auto',
extent=[time_data[0], time_data[-1], fmax, 0],
Expand All @@ -43,9 +39,23 @@ def plot_cwt(cwt_result, time_data, fmax):
plt.gca().invert_yaxis()


# グラフ描画とCWTの処理を行う関数
def wavelet_ui(
uploaded_file,
def stft_plot_spectrogram(data, Fs, N, freq_limit=None):
freqs, times, Zxx = signal.stft(data, fs=Fs, window='hann', nperseg=N, noverlap=None)
amp = np.abs(Zxx)
amp[amp == 0] = np.finfo(float).eps
fig, ax = plt.subplots(figsize=(12, 6))
spectrogram = ax.pcolormesh(times, freqs, np.log10(amp), shading="auto", vmin=0, vmax=5)
fig.colorbar(spectrogram, ax=ax, orientation="vertical").set_label("Amplitude (dB)")
ax.set_xlabel("Time [s]")
ax.set_ylabel("Frequency [Hz]")
if freq_limit:
ax.set_ylim([0, freq_limit])
plt.show()


# グラフ描画とスペクトログラムの処理を行う関数
def spectrogram_ui(
uploaded_file, analysis_method,
Fs, fmax, column_name, start_time, end_time,
filter_setting, fp_hp, fs_hp, gpass, gstop):
filepath = uploaded_file.name
Expand All @@ -71,19 +81,27 @@ def wavelet_ui(
signal = signal[start_idx:end_idx]
t_data = t_data[start_idx:end_idx]

signal_filename = tempfile.NamedTemporaryFile(delete=False, suffix='.png').name
# 信号をプロットして保存
plt.figure(dpi=200)
plt.title("Signal")
plt.plot(t_data, signal)
plt.xlim(start_time, end_time)
plt.xlabel("Time [sec]")
plt.ylabel("Voltage [uV]")
signal_filename = "signal_plot.png"
plt.savefig(signal_filename)

cwt_signal_filename = tempfile.NamedTemporaryFile(delete=False, suffix='.png').name
cwt_signal = continuous_wavelet_transform(Fs=Fs, data=signal, fmax=fmax)
plt.figure(dpi=200)
plot_cwt(cwt_signal, t_data, fmax)
plt.savefig(cwt_signal_filename)

return cwt_signal_filename, signal_filename
# スペクトログラムをプロットして保存
if analysis_method == "Short-Time Fourier Transform":
plt.figure(dpi=200)
stft_plot_spectrogram(data=signal, Fs=Fs, N=256, freq_limit=fmax)
spectrogram_filename = "stft_spectrogram_plot.png"
plt.savefig(spectrogram_filename)
else:
spectrogram_filename = "wavelet_spectrogram_plot.png"
cwt_signal = continuous_wavelet_transform(Fs=Fs, data=signal, fmax=fmax)
plt.figure(dpi=200)
plot_cwt(cwt_signal, t_data, fmax)
plt.savefig(spectrogram_filename)

return spectrogram_filename, signal_filename

0 comments on commit f10a868

Please sign in to comment.