-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
96 lines (84 loc) · 3.29 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from pathlib import Path
import shutil
import librosa
import torch
import soundfile
from modules.dc_crn import DCCRN as Model
from tools.lzf_utils.audio_utils import AudioUtils
def enchance_one(net: Model, in_wav_path: Path, out_wav_path: Path, out_input=False):
    """Run the enhancement model on a single wav file and write the result.

    Args:
        net: trained DCCRN model, already moved to `device` and in eval mode.
        in_wav_path: input wav file; loaded (and resampled) at the global
            sample rate `g_sr`.
        out_wav_path: destination wav path for the enhanced audio.
        out_input: when True, merge the (resampled) input with the enhanced
            output into a multi-channel file for easy A/B comparison.

    Best-effort: any per-file failure is reported and skipped so a batch run
    continues past one bad file.
    """
    try:
        # librosa loads mono float32 resampled to g_sr.
        in_data, _ = librosa.load(in_wav_path, sr=g_sr)
        # Add a batch dimension: (1, num_samples).
        inputs = torch.FloatTensor(in_data[None]).to(device)
        with torch.no_grad():
            # Model returns a pair; only the enhanced waveform is kept.
            _, output = net(inputs)
        output = output.cpu().numpy().squeeze()
        if out_input:
            # Pair the raw input with the enhanced output channel-wise.
            output = AudioUtils.merge_channels(in_data, output)
        soundfile.write(out_wav_path, output, g_sr)
        print(out_wav_path)
    except Exception as e:
        # Name the failing file so errors in a batch run are identifiable.
        print(f"error: {in_wav_path}: {e}")
def enhance(
    net: Model,
    idx: int,
    in_wav_list: list,
    out_wav_dir: Path,
    out_input: bool,
):
    """Enhance every wav in `in_wav_list`, writing results into `out_wav_dir`.

    Output filenames encode the source stem, experiment name, and checkpoint
    epoch: `<stem>;<experiment>_ep<idx><suffix>`.

    NOTE(review): relies on the module-level global `checkpoint_dir` to derive
    the experiment name — confirm it is defined before calling.
    """

    def get_experiment_name():
        # Experiment name is the checkpoint's parent directory name with any
        # leading "[...]" tags stripped,
        # e.g. "[server][full]DCCRN_x" -> "DCCRN_x".
        res = checkpoint_dir.parent.name
        return res[res.rfind("]") + 1 :]

    out_wav_dir.mkdir(parents=True, exist_ok=True)
    # Hoisted out of the loop: the experiment name is loop-invariant.
    exp_name = get_experiment_name()
    for in_f in map(Path, in_wav_list):
        out_f = out_wav_dir.joinpath(
            in_f.stem + ";" + exp_name + f"_ep{idx:03}" + in_f.suffix
        )
        enchance_one(net, in_f, out_f, out_input=out_input)
if __name__ == "__main__":
    ############ configuration start ############
    # Model hyper-parameters — must match the trained checkpoint.
    rnn_units = 128
    frame_len = 512  # STFT window / FFT length in samples
    frame_hop = 256  # STFT hop in samples
    masking_mode = "R"
    kernel_num = [8, 16]
    g_sr = 16000  # sample rate used for loading and writing audio
    # Either a single checkpoint epoch index, or a (start, stop) range.
    model_index_ranges = 28
    device = "cuda"
    out_input = False  # also merge the input channel into the output file
    out_wav_base_dir = Path(r"/home/featurize/train_output/enhanced/out/")
    checkpoint_dir = Path(
        r"/home/featurize/train_output/models/[server][full]DCCRN_0103_sisdr_dnsdrb_half_hamming_2kernel_128u/checkpoints"
    )
    in_wav_list = [
        # r"/home/featurize/data/from_lzf/evaluation_data/1.in_data/input.wav",
        # r"/home/featurize/data/from_lzf/evaluation_data/1.in_data/中会议室_女声_降噪去混响测试.wav",
        # r"/home/featurize/data/from_lzf/evaluation_data/1.in_data/小会议室_女声_降噪去混响测试.wav",
        # r"/home/featurize/data/from_lzf/evaluation_data/1.in_data/大会议室_男声_降噪去混响测试_RK降噪开启.wav",
        # r"/home/featurize/data/from_lzf/evaluation_data/1.in_data/大会议室_男声_降噪去混响测试_RK降噪开启_mic1.wav",
        r"/home/featurize/data/from_lzf/evaluation_data/1.in_data/large_meeting_room_input.wav",
    ]
    ############ configuration end ############
    # Start from a clean output directory.
    shutil.rmtree(out_wav_base_dir, ignore_errors=True)
    if isinstance(model_index_ranges, int):
        # A bare int means "evaluate just this one epoch".
        model_index_ranges = (model_index_ranges, model_index_ranges + 1)
    for idx in range(*model_index_ranges):
        ckpt_path = checkpoint_dir.joinpath(f"model_{idx:04}.pth")
        net = Model(
            rnn_units=rnn_units,
            win_len=frame_len,
            win_inc=frame_hop,
            fft_len=frame_len,
            masking_mode=masking_mode,
            kernel_num=kernel_num,
        ).to(device)
        # map_location keeps checkpoints loadable regardless of the device
        # they were saved from.
        net.load_state_dict(torch.load(ckpt_path, map_location=device))
        net.eval()
        # out_wav_dir = out_wav_base_dir.joinpath(f"ep{idx:03}")
        out_wav_dir = Path(out_wav_base_dir)
        enhance(net, idx, in_wav_list, out_wav_dir, out_input)