evaluate.py
import argparse
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Union

import librosa
import numpy as np

from datautils import get_audio_and_label_paths
from model import WhisperSegmenterFast
from train import evaluate
from utils import create_if_not_exists


def convert_numpy_to_regular(data: Union[np.generic, List, Dict]):
    """Recursively convert numpy scalar types to native Python types so the
    result can be serialized with ``json.dumps``.

    Args:
        data (Union[np.generic, List, Dict]): Data to be converted

    Returns:
        The same structure with numpy scalars replaced by Python ints/floats.
    """
    if isinstance(data, np.generic):  # numpy scalars, e.g. np.int32/64, np.float16/32/64
        return data.item()
    elif isinstance(data, dict):
        return {key: convert_numpy_to_regular(value) for key, value in data.items()}
    elif isinstance(data, list):
        return [convert_numpy_to_regular(item) for item in data]
    else:
        return data
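
# A minimal usage sketch (hypothetical values): numpy scalars such as
# np.float32 are not JSON serializable on their own, so
#     convert_numpy_to_regular({"F1": np.float32(0.93)})
# returns a plain-Python dict that json.dumps accepts without a TypeError.
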
def evaluate_dataset(dataset_path: str, model_path: str, num_trials: int, consolidation_method: str = "clustering",
                     max_length: int = 448, num_beams: int = 4, batch_size: int = 8, **kwargs) -> Dict:
    """Evaluate a trained WhisperSeg checkpoint on a dataset.

    Args:
        dataset_path (str): Path to the dataset to be evaluated
        model_path (str): Path to the trained model checkpoint
        num_trials (int): Number of segmentation trials whose results are combined by majority voting
        consolidation_method (str, optional): Method used for consolidating the results of majority voting. Defaults to "clustering".
        max_length (int, optional): Maximum number of tokens the Whisper decoder is allowed to generate, following the original implementation. Defaults to 448.
        num_beams (int, optional): Beam size during decoding. Defaults to 4.
        batch_size (int, optional): Batch size. Defaults to 8.
        **kwargs: Additional options; an "identifier" entry, if present, is used to name the confusion-matrix output files.

    Returns:
        Dict: Segment-wise and frame-wise evaluation scores
    """
    audio_list, label_list = [], []
    audio_paths, label_paths = get_audio_and_label_paths(dataset_path)
    for audio_path, label_path in zip(audio_paths, label_paths):
        with open(label_path, "r") as f:
            label = json.load(f)
        # Load each recording at the sampling rate stored in its label file.
        audio, _ = librosa.load(audio_path, sr=label["sr"])
        audio_list.append(audio)
        label_list.append(label)

    segmenter = WhisperSegmenterFast(model_path=model_path, device="cuda")

    # Use the identifier (if given) to name the confusion-matrix outputs;
    # kwargs.get avoids a KeyError when no identifier was passed.
    if kwargs.get("identifier"):
        cm_name = raw_data_name = kwargs["identifier"]
    else:
        cm_name = raw_data_name = None

    res = evaluate(audio_list, label_list, segmenter, batch_size, max_length, num_trials,
                   consolidation_method, num_beams, target_cluster=None,
                   confusion_matrix=cm_name, save_cm_data=raw_data_name)
    all_res = {
        "segment_wise_scores": {
            "N-true-positive": res["segment_wise"][0],
            "N-positive-in-prediction": res["segment_wise"][1],
            "N-positive-in-ground-truth": res["segment_wise"][2],
            "precision": res["segment_wise"][3],
            "recall": res["segment_wise"][4],
            "F1": res["segment_wise"][5],
        },
        "frame_wise_scores": {
            "N-true-positive": res["frame_wise"][0],
            "N-positive-in-prediction": res["frame_wise"][1],
            "N-positive-in-ground-truth": res["frame_wise"][2],
            "precision": res["frame_wise"][3],
            "recall": res["frame_wise"][4],
            "F1": res["frame_wise"][5],
        },
    }
    return convert_numpy_to_regular(all_res)
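
# A minimal usage sketch (paths and identifier are hypothetical):
#     scores = evaluate_dataset(dataset_path="data/zebra_finch/test",
#                               model_path="model/whisperseg-base",
#                               num_trials=3, identifier="zf-test")
#     print(scores["segment_wise_scores"]["F1"])
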
if __name__ == "__main__":
    logging.basicConfig()
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser(description="Evaluate a trained WhisperSeg checkpoint on a dataset.")
    parser.add_argument("-d", "--dataset_path", type=str, help="Path to the dataset to be evaluated", required=True)
    parser.add_argument("-m", "--model_path", type=str, help="Path to the trained model checkpoint", required=True)
    parser.add_argument("-o", "--output_dir", type=str, help="Path to a directory where the output files will be saved", default=None)
    parser.add_argument("-i", "--identifier", type=str, help="Unique identifier used for the output file name in case the model path and dataset name are not meaningful.", default=None)
    parser.add_argument("-n", "--num_trials", type=int, help="Number of trials for majority voting", default=3)
    args = parser.parse_args()

    eval_res = evaluate_dataset(**vars(args))

    # Default to a "results" directory under the current working directory;
    # either way, create the output directory if it does not exist yet.
    if args.output_dir is None:
        out_path = create_if_not_exists(os.path.join(os.getcwd(), "results"))
    else:
        out_path = create_if_not_exists(args.output_dir)

    # Name the output file after the identifier if given, otherwise after the
    # model checkpoint and dataset, prefixed with a timestamp.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    if args.identifier is None:
        out_name = os.path.join(out_path, f"{timestamp}_eval_{Path(args.model_path).stem}_{Path(args.dataset_path).stem}.txt")
    else:
        out_name = os.path.join(out_path, f"{timestamp}_eval_{args.identifier}.txt")

    with open(out_name, "w") as f:
        json.dump(eval_res, f, indent=2)
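
# Example invocation (hypothetical paths): evaluate a checkpoint on a test set
# and write the scores to ./results/<timestamp>_eval_<model>_<dataset>.txt:
#     python evaluate.py -d data/zebra_finch/test -m model/whisperseg-base -n 3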