forked from wz-bff/AutoComper
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sound_reader.py
188 lines (145 loc) · 5.61 KB
/
sound_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#!/usr/bin/env python
import subprocess
import sys
import numpy as np
# import onnx
import onnxruntime as ort
from typing import Generator, Any
from utils import FFMPEG_PATH
from proglog import default_bar_logger
# Sample rate in Hz: passed to ffmpeg as the output rate and used to size each audio block.
SAMPLE_RATE = 32000
# True on Windows, where ffmpeg is spawned with CREATE_NO_WINDOW to hide the console window.
is_windows = sys.platform.startswith('win')
def subsample(frame: np.ndarray, scale_factor: int) -> np.ndarray:
    """Downsample ``frame`` by max-pooling runs of ``scale_factor`` values.

    The input is cut into consecutive groups of ``scale_factor`` elements and
    each group is replaced by its maximum. When the length of ``frame`` is not
    a multiple of ``scale_factor``, the leftover tail forms one final, shorter
    group whose maximum is appended to the result.

    Args:
        frame: Values to pool.
        scale_factor: Number of consecutive values folded into one output value.

    Returns:
        An array holding one maximum per group.
    """
    leftover = len(frame) % scale_factor
    split_at = len(frame) - leftover

    # Complete groups: one row per group, reduced along the row.
    pooled = frame[:split_at].reshape(-1, scale_factor).max(axis=1)
    if leftover == 0:
        return pooled

    # Incomplete trailing group: pooled on its own and tacked onto the end.
    return np.append(pooled, frame[split_at:].max())
def get_segments(
    scores: np.ndarray,
    precision: int,
    threshold: float,
    offset: int
) -> Generator[Any, Any, Any]:
    """Yield one dict per run of consecutive scores above ``threshold``.

    Each yielded dict has the keys ``'start'`` and ``'end'`` (first and last
    index of the run, both inclusive) and ``'pred'`` (the highest score inside
    the run). Nothing is yielded when no score exceeds ``threshold``.

    Args:
        scores: Scores to scan.
        precision: Accepted for signature compatibility; not used here.
        threshold: A score must be strictly greater than this to count.
        offset: Accepted for signature compatibility; not used here.
    """
    hits = np.where(scores > threshold)[0]
    if len(hits) == 0:
        return

    # A new run begins wherever two neighbouring hit indices are not adjacent.
    run_boundaries = np.where(np.diff(hits) != 1)[0] + 1
    for run in np.split(hits, run_boundaries):
        yield {
            'start': run[0],
            'end': run[-1],
            'pred': scores[run].max(),
        }
def compute_timestamps(
    framewise_output: np.ndarray,
    precision: int,
    threshold: float,
    focus_idx: int,
    offset: int,
):
    """Turn one column of framewise scores into timestamped detections.

    Column ``focus_idx`` of ``framewise_output`` is max-pooled in groups of
    ``precision`` values by ``subsample``; runs of pooled scores above
    ``threshold`` are then found by ``get_segments`` and their indices are
    converted to times.

    Args:
        framewise_output: 2-D array of scores, indexed as ``[frame, column]``.
        precision: Number of consecutive frames pooled into one timestamp
            sample; larger values give coarser timestamps.
        threshold: Minimum pooled score (exclusive) for a detection.
        focus_idx: Column of ``framewise_output`` to analyse.
        offset: Value added to every start and end time.

    Returns:
        A lazy ``map`` iterator of dicts with ``'start'``, ``'end'`` and
        ``'pred'`` (the run's peak score rounded to 6 decimals).
    """
    column_scores = framewise_output[:, focus_idx]
    pooled_scores = subsample(column_scores, precision)

    def to_timestamp(segment):
        # NOTE(review): dividing by 100 presumably reflects 100 model frames
        # per second, making the results seconds - confirm against the model.
        # NOTE(review): the fixed "+ 1" on 'end' equals one pooled sample only
        # when precision == 100 - confirm this is intended for other values.
        return {
            'start': segment['start'] * precision / 100 + offset,
            'end': segment['end'] * precision / 100 + offset + 1,
            'pred': round(float(segment['pred']), 6)
        }

    return map(
        to_timestamp,
        get_segments(pooled_scores, precision, threshold, offset),
    )
def pad_array_if_needed(arr, desired_size, pad_value=0):
    """Right-pad ``arr`` with ``pad_value`` until it has ``desired_size`` entries.

    Args:
        arr: Array to pad; its length is taken from ``arr.shape[0]``.
        desired_size: Minimum length of the returned array.
        pad_value: Fill value used for the added entries.

    Returns:
        ``arr`` itself when it is already at least ``desired_size`` long,
        otherwise a new, padded array.
    """
    missing = desired_size - arr.shape[0]
    if missing <= 0:
        # Already long enough: hand back the very same object, no copy.
        return arr
    return np.pad(
        arr, (0, missing), mode="constant", constant_values=(pad_value,)
    )
def load_audio(file: str, sr: int, frame_count: int):
    """Decode ``file`` with ffmpeg and yield it as raw mono 16-bit PCM chunks.

    ffmpeg is started as a subprocess that writes signed 16-bit little-endian
    mono samples at ``sr`` Hz to its stdout; this generator reads that stream
    in fixed-size pieces.

    Args:
        file: Path of the media file handed to ffmpeg via ``-i``.
        sr: Output sample rate in Hz.
        frame_count: Number of samples per yielded chunk.

    Yields:
        ``bytes`` objects of ``frame_count * 2`` bytes each; the final chunk
        may be shorter.

    Raises:
        Exception: If ffmpeg exits with a non-zero status after the stream
            has been read to the end.
    """
    cmd = [
        FFMPEG_PATH, '-hide_banner', '-loglevel', 'warning', '-i', file,
        '-filter_complex',
        # Resample straight to the requested rate instead of a fixed 32000.
        f'[0:a]aresample={sr}:async=1,asetpts=PTS-STARTPTS,atempo=1,pan=mono|c0=c0[audio]',
        '-map', '[audio]', '-f', 's16le', '-acodec', 'pcm_s16le',
        '-ar', str(sr), '-ac', '1', '-bufsize', '128k', '-'
    ]

    subprocess_options = {
        'stdout': subprocess.PIPE,
        # stderr is never read here; a pipe nobody drains would stall ffmpeg
        # as soon as its buffer filled up, so discard the output instead.
        'stderr': subprocess.DEVNULL,
    }
    if is_windows:
        # Suppress the command prompt window on Windows.
        subprocess_options['creationflags'] = subprocess.CREATE_NO_WINDOW

    chunk_size = frame_count * 2  # two bytes per 16-bit sample

    # Default buffering: line buffering (bufsize=1) is not valid for binary pipes.
    process = subprocess.Popen(cmd, **subprocess_options)

    reached_eof = False
    try:
        while True:
            chunk = process.stdout.read(chunk_size)
            if not chunk:
                break
            yield chunk
        reached_eof = True
    finally:
        # Runs on normal completion, on cancellation (GeneratorExit when the
        # consumer closes the generator) and on any other error, so the
        # ffmpeg process and its pipe are always released.
        if not reached_eof:
            process.terminate()
        process.stdout.close()
        return_code = process.wait()

    if return_code != 0:
        raise Exception(
            "Failed to process the file. Either the file does not exist or is corrupted.")
def get_timestamps(file, precision=100, block_size=600, threshold=0.90, focus_idx=58, model="bdetectionmodel_05_01_23", logger=None):
    """Run the detection model over ``file`` and collect detection timestamps.

    The audio is decoded by ``load_audio`` in blocks of ``block_size`` seconds
    at ``SAMPLE_RATE``. Each block is zero-padded to full length, scaled to
    floats by dividing by 2**15, and fed to the ONNX model; the framewise
    scores of the first batch item are converted to timestamps by
    ``compute_timestamps``.

    Args:
        file: Path of the media file to analyse.
        precision: Number of model frames pooled per timestamp sample; must
            be greater than zero.
        block_size: Length in seconds of each audio block sent to the model;
            must be greater than zero.
        threshold: Score threshold for a detection, between 0 and 1.
        focus_idx: Column of the model output to analyse.
        model: Model passed to ``onnxruntime.InferenceSession``.
        logger: Optional value passed to proglog's ``default_bar_logger`` to
            report progress over the blocks.

    Returns:
        A dict ``{'filename': file, 'timestamps': [...]}`` where each
        timestamp is a dict with ``'start'``, ``'end'`` and ``'pred'``.

    Raises:
        Exception: If ``precision``, ``threshold`` or ``block_size`` is out
            of range, or if decoding the file fails.
    """
    # Input checking. Zero is rejected as well: precision == 0 would divide by
    # zero while pooling, and block_size == 0 would request zero-byte reads.
    if precision <= 0:
        raise Exception("Precision must be a positive number!")
    if not 0 <= threshold <= 1:
        raise Exception("Threshold must be between 0 and 1!")
    if block_size <= 0:
        raise Exception("Block size must be a positive number!")

    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    ort_session = ort.InferenceSession(
        model,
        sess_options,
        providers=ort.get_available_providers()
    )

    frame_count = SAMPLE_RATE * block_size
    # The whole file is decoded up front, so decoding errors surface before
    # any inference and the progress bar receives a sized sequence.
    blocks = list(load_audio(file, SAMPLE_RATE, frame_count))

    info = {'filename': file, 'timestamps': []}

    if logger:
        bar_logger = default_bar_logger(logger)
        blocks = bar_logger.iter_bar(block=blocks)

    offset = 0
    for block in blocks:
        samples = np.frombuffer(block, dtype=np.int16)
        # The last block is usually short; the model is always fed full-length input.
        samples = pad_array_if_needed(samples, frame_count)
        samples = samples.reshape(1, -1)
        samples = samples / (2**15)
        samples = samples.astype(np.float32)

        ort_inputs = {"input": samples}
        framewise_output = ort_session.run(["output"], ort_inputs)[0]
        preds = framewise_output[0]

        info["timestamps"].extend(
            compute_timestamps(
                preds, precision, threshold, focus_idx, offset
            )
        )

        # Blocks are block_size seconds long, so advance the offset by that much.
        offset += block_size

    return info