-
Notifications
You must be signed in to change notification settings - Fork 0
/
processing_functions.py
124 lines (99 loc) · 4.92 KB
/
processing_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# import matplotlib.pyplot as plt
import numpy as np
from scipy.fftpack import fft, fftfreq, ifft, fftshift, ifftshift
import pywt
import copy
# from wfdb import processing
# from scipy.signal import butter
# from scipy import signal
def baseline_removal_moving_median(signal, fs):
    """
    Remove the baseline of a signal by subtracting a moving average, twice.

    Despite the function name, the baseline estimate is a moving *average*
    (a uniform kernel applied with ``np.convolve``), not a moving median.

    Parameters:
    -----------
    signal : numpy array
        The 1-D signal to be filtered.
    fs : int or float
        Sampling frequency. The averaging window is ``2 * fs`` samples long
        (rounded to the nearest integer when ``fs`` is not integral).

    Returns:
    --------
    filtered_signal : numpy array
        The baseline-corrected signal.
    """
    # np.ones() requires an integer length; a float fs would raise TypeError.
    window_size = int(round(2 * fs))
    kernel = np.ones(window_size) / window_size
    # First pass: subtract the local mean from the raw signal.
    filtered_signal = signal - np.convolve(signal, kernel, mode='same')
    # Second pass: subtract the local mean of the already corrected signal.
    filtered_signal = filtered_signal - np.convolve(filtered_signal, kernel, mode='same')
    return filtered_signal
def low_pass_filter(cutoff_freq, signal, sampling_rate):
    """
    Low-pass filter a signal by zeroing the outer bins of its shifted FFT.

    The spectrum is shifted so that the zero-frequency bin is centred; the
    bins at both ends of the shifted spectrum (the largest positive and
    negative frequencies) are set to zero and the result is transformed back.

    Parameters:
    -----------
    cutoff_freq : float
        Cutoff frequency, in the same units as `sampling_rate`. A cutoff at or
        above ``sampling_rate / 2`` removes only the first bin of the shifted
        spectrum.
    signal : numpy array
        The signal to be filtered. It is not modified.
    sampling_rate : float
        Sampling frequency of `signal`.

    Returns:
    --------
    filter_signal : numpy array
        Output of the inverse FFT (complex-valued).
    freq : numpy array
        Shifted frequency axis matching `spectrum`.
    spectrum : numpy array
        The shifted spectrum after the stop-band bins were zeroed.
    """
    spectrum = fftshift(fft(signal))
    freq = fftshift(fftfreq(signal.shape[-1], 1 / sampling_rate))
    n_samples = signal.shape[-1]
    # Number of bins between each end of the shifted spectrum and the cutoff.
    # Clamp at zero: a cutoff above sampling_rate / 2 would otherwise give a
    # negative count, and spectrum[0:negative] would wipe out almost all bins.
    edge_bins = max(
        round((len(spectrum) * (sampling_rate / 2 - cutoff_freq)) / sampling_rate), 0)
    spectrum[0:edge_bins + 1] = 0
    spectrum[n_samples - edge_bins:n_samples] = 0
    filter_signal = ifft(ifftshift(spectrum))
    return filter_signal, freq, spectrum
def high_pass_filter(cutoff_freq, signal, sampling_rate):
    """
    High-pass filter a signal by zeroing the low-frequency bins of its FFT.

    Works on the unshifted spectrum, where the lowest frequencies sit at both
    ends of the array: the leading bins (zero frequency upwards) and the
    trailing bins (their negative-frequency counterparts) are set to zero.

    Parameters:
    -----------
    cutoff_freq : float
        Cutoff frequency, in the same units as `sampling_rate`.
    signal : numpy array
        The signal to be filtered. It is not modified.
    sampling_rate : float
        Sampling frequency of `signal`.

    Returns:
    --------
    filter_signal : numpy array
        Output of the inverse FFT (complex-valued).
    freq : numpy array
        Shifted frequency axis.
    spectrum : numpy array
        The filtered spectrum, shifted to match `freq`.
    """
    n_samples = signal.shape[-1]
    raw_spectrum = fft(signal)
    # Number of bins that lie between zero frequency and the cutoff.
    stop_bins = round((len(raw_spectrum) * cutoff_freq) / sampling_rate)
    raw_spectrum[:stop_bins + 1] = 0
    raw_spectrum[n_samples - stop_bins:n_samples] = 0
    freq_axis = fftshift(fftfreq(n_samples, 1 / sampling_rate))
    return ifft(raw_spectrum), freq_axis, fftshift(raw_spectrum)
def band_pass_filter(cutoff_freq_down, cutoff_freq_up, signal, sampling_rate):
    """
    Band-pass filter a signal by zeroing FFT bins outside a frequency band.

    The high-frequency bins are zeroed on the shifted spectrum (where they sit
    at both ends of the array); the spectrum is then unshifted and the
    low-frequency bins are zeroed the same way.

    Parameters:
    -----------
    cutoff_freq_down : float
        Lower cutoff frequency, in the same units as `sampling_rate`.
    cutoff_freq_up : float
        Upper cutoff frequency, in the same units as `sampling_rate`. A value
        at or above ``sampling_rate / 2`` removes only the first bin of the
        shifted spectrum on the high side.
    signal : numpy array
        The signal to be filtered. It is not modified.
    sampling_rate : float
        Sampling frequency of `signal`.

    Returns:
    --------
    filter_signal : numpy array
        Real part of the inverse FFT of the filtered spectrum.
    """
    n_samples = signal.shape[-1]

    # Low-pass stage on the shifted spectrum.
    spectrum = fftshift(fft(signal))
    # Clamp at zero: an upper cutoff above sampling_rate / 2 would otherwise
    # give a negative count, and spectrum[0:negative] would zero nearly all bins.
    upper_bins = max(
        round((len(spectrum) * (sampling_rate / 2 - cutoff_freq_up)) / sampling_rate), 0)
    spectrum[0:upper_bins + 1] = 0
    spectrum[n_samples - upper_bins:n_samples] = 0

    # High-pass stage on the unshifted spectrum.
    spectrum_without_shift = ifftshift(spectrum)
    # Same clamp, so a negative lower cutoff cannot turn into a negative slice end.
    lower_bins = max(
        round((len(spectrum_without_shift) * cutoff_freq_down) / sampling_rate), 0)
    spectrum_without_shift[0:lower_bins + 1] = 0
    spectrum_without_shift[n_samples - lower_bins:n_samples] = 0

    filter_signal = np.real(ifft(spectrum_without_shift))
    return filter_signal
def compute_fft(signal, sample_rate):
    """
    Compute the one-sided magnitude spectrum of a mean-removed signal.

    Parameters:
    -----------
    signal : numpy array
        The signal to analyse.
    sample_rate : float
        Sampling frequency of `signal`.

    Returns:
    --------
    fft_signal : numpy array
        Magnitudes of the first ``len(signal) // 2`` FFT bins, computed after
        subtracting the signal mean.
    frequency_bins : numpy array
        Frequencies of those bins.
    """
    n_samples = len(signal)
    half = n_samples // 2
    # Subtract the mean before transforming, then keep only the first half.
    centred = signal - np.mean(signal)
    magnitudes = np.abs(fft(centred)[:half])
    bin_frequencies = fftfreq(n_samples, 1 / sample_rate)[:half]
    return magnitudes, bin_frequencies
def wavelet_filter(signal):
    """
    Filter a signal by discarding bands of a 4-level 'db2' wavelet transform.

    The signal is decomposed into one approximation band (level 4) and four
    detail bands (levels 4 to 1). The approximation and the level-2 and
    level-1 detail bands are replaced by zeros; the signal is then rebuilt
    from the remaining level-4 and level-3 detail coefficients.

    Parameters:
    -----------
    signal : numpy array
        The signal to be filtered.

    Returns:
    --------
    filtered_signal : numpy array
        The reconstructed signal.
    """
    db2 = pywt.Wavelet('db2')
    approx4, detail4, detail3, detail2, detail1 = pywt.wavedec(signal, wavelet=db2, level=4)
    # Zero vectors with the same last-axis length as the discarded bands.
    blank_approx4 = np.zeros(approx4.shape[-1])
    blank_detail2 = np.zeros(detail2.shape[-1])
    blank_detail1 = np.zeros(detail1.shape[-1])
    kept_coefficients = [blank_approx4, detail4, detail3, blank_detail2, blank_detail1]
    return pywt.waverec(kept_coefficients, db2)
def ecg_pre_processing(ecg_dict):
    """
    Return a pre-processed deep copy of an ECG record dictionary.

    For every segment the baseline is removed with
    `baseline_removal_moving_median`, and the spectrum of the stored result is
    computed with `compute_fft`. The input dictionary is not modified; all
    writes go to the deep copy.

    Parameters:
    -----------
    ecg_dict : dict
        Record read through the keys 'fs', 'num_of_segments' and 'signal'.
        The keys 'fft' and 'frequency_bins' must already hold containers that
        accept assignment at each segment index.

    Returns:
    --------
    ecg_processed : dict
        Deep copy of `ecg_dict` with 'signal', 'fft' and 'frequency_bins'
        updated for each segment.
    """
    fs = ecg_dict['fs']
    ecg_processed = copy.deepcopy(ecg_dict)
    for i in range(ecg_dict['num_of_segments']):
        # Baseline removal is the only filtering stage applied here; powerline
        # notch, band-pass and wavelet filtering are not part of this pipeline.
        ecg_processed['signal'][i] = baseline_removal_moving_median(ecg_processed['signal'][i], fs)
        # Read the segment back from the container so the spectrum is computed
        # from exactly what was stored.
        ecg_processed['fft'][i], ecg_processed['frequency_bins'][i] = compute_fft(ecg_processed["signal"][i], fs)
    return ecg_processed
def _values_equal(a, b):
    """Return True when `a` and `b` compare equal, tolerating numpy arrays."""
    try:
        return bool(a == b)
    except ValueError:
        # Element-wise array comparison has no single truth value (or the
        # shapes do not match); compare the arrays as a whole instead.
        try:
            return bool(np.array_equal(a, b))
        except (ValueError, TypeError):
            # Values that cannot be compared at all are reported as different.
            return False


def dict_compare(d1, d2):
    """
    Print a summary of the differences between two dictionaries.

    Parameters:
    -----------
    d1 : dict
        First dictionary.
    d2 : dict
        Second dictionary.

    Prints:
    -------
    added : set
        Keys present in `d1` but not in `d2`.
    removed : set
        Keys present in `d2` but not in `d1`.
    modified : dict
        Shared keys whose values differ, mapped to ``(d1[key], d2[key])``.
    same : set
        Shared keys whose values are equal. Values holding numpy arrays are
        compared as whole arrays instead of raising on the ambiguous truth
        value of an element-wise comparison.

    Returns:
    --------
    None
    """
    d1_keys = set(d1)
    d2_keys = set(d2)
    shared_keys = d1_keys & d2_keys
    added = d1_keys - d2_keys
    removed = d2_keys - d1_keys
    same = {key for key in shared_keys if _values_equal(d1[key], d2[key])}
    modified = {key: (d1[key], d2[key]) for key in shared_keys if key not in same}
    print('dict compare result: \nadded: {},\nremoved: {},\nmodified: {},\nsame: {}'.format(added, removed, modified, same))