forked from riebling/dscore
-
Notifications
You must be signed in to change notification settings - Fork 0
/
score_batch.py
executable file
·313 lines (258 loc) · 10.5 KB
/
score_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
#!/usr/bin/env python
"""Score diarization system output for a batch of files and write to a
dataframe.
To evaluate system output stored in RTTM files in the directory ``sys_dir``
against reference RTTM files stored in the directory ``ref_dir`` and write
the output to a file ``scores.df``:
python score_batch.py scores.df ref_dir sys_dir
This will scan both ``ref_dir`` and ``sys_dir`` for files with the ``.rttm``
extension, score each file found in both directories, and write the scores to
a tab-delimited file suitable for reading into R as a dataframe. Alternatively,
the file ids can be specified explicitly via a script file of ids
(one per line) using the ``-S`` flag:
python score_batch.py -S all.scp scores.df ref_dir sys_dir
Minimally, the output dataframe has the following columns:
- FID -- the file id
- DER -- diarization error rate
- B3Precision -- B-cubed precision
- B3Recall -- B-cubed recall
- B3F1 -- B-cubed F1
- TauRefSys -- Goodman-Kruskal tau in the direction of the reference
diarization to the system diarization
- TauSysRef -- Goodman-Kruskal tau in the direction of the system diarization
to the reference diarization
- CE -- conditional entropy of the reference diarization given the
system diarization (bits)
- MI -- mutual information (bits)
- NMI -- normalized mutual information
Optionally, it may contain additional columns specified via the
``--additional_columns`` flag, which takes a string containing semicolon
delimited column name/value pairs, each pair having the form:
CNAME=VAL
For instance, the string
Corpus=AMI;NClusters=4
would result in two additional columns, "Corpus" and "NClusters", being output
with the values "AMI" and 4 respectively in each row.
Diarization error rate (DER) is scored using the NIST ``md-eval.pl`` tool
with a default collar size of 0 ms and, by default, ignoring regions that
contain overlapping speech in the reference RTTM. If desired, this behavior
can be altered using the ``--collar`` and ``--score_overlaps`` flags.
For instance:
python score_batch.py --collar 0.100 --score_overlaps scores.df ref_dir sys_dir
would compute DER using a 100 ms collar and with overlapped speech included.
All other metrics are computed from frame-level labelings created from the
turns in the RTTM files **WITHOUT** any use of collars. The default frame
step is 10 ms, which may be altered via the ``--step`` flag.
"""
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import glob
import os
import sys
from multiprocessing import Pool
from scorelib import __version__ as VERSION
from scorelib.logging import getLogger
from scorelib.score import score
# Module-level logger obtained from scorelib.logging; _score_recordings uses
# it to report missing RTTM files.
logger = getLogger()
def _score_recordings(args):
    """Score a single recording; worker function for ``score_recordings``.

    Parameters
    ----------
    args : tuple
        ``(fid, ref_rttm_dir, sys_rttm_dir, collar, ignore_overlaps, step)``.
        Packed into a single tuple so the function can be passed to
        ``Pool.map``.

    Returns
    -------
    row : list or None
        ``[fid]`` followed by the values returned by ``score``, or None if
        the reference or the system RTTM file is missing.
    """
    fid, ref_rttm_dir, sys_rttm_dir, collar, ignore_overlaps, step = args
    ref_rttm_fn = os.path.join(ref_rttm_dir, fid + '.rttm')
    sys_rttm_fn = os.path.join(sys_rttm_dir, fid + '.rttm')

    # Check both files before bailing out so that every missing file is
    # reported, not just the first one.
    missing = False
    if not os.path.exists(ref_rttm_fn):
        logger.warn('Missing reference RTTM: %s. Skipping.' % ref_rttm_fn)
        missing = True
    if not os.path.exists(sys_rttm_fn):
        logger.warn('Missing system RTTM: %s. Skipping.' % sys_rttm_fn)
        missing = True
    if missing:
        return None

    # NOTE(review): collar, ignore_overlaps and step are unpacked above but
    # are not forwarded to ``score``, so --collar/--score_overlaps/--step
    # currently have no effect on these scores. Confirm scorelib.score's
    # signature before passing them through.
    row = [fid]
    row.extend(score(ref_rttm_fn, sys_rttm_fn))
    return row
def score_recordings(fids, ref_rttm_dir, sys_rttm_dir, collar, ignore_overlaps,
                     step, n_jobs=1):
    """Score batch of recordings.

    Parameters
    ----------
    fids : list of str
        File ids.
    ref_rttm_dir : str
        Path to directory containing reference RTTM files.
    sys_rttm_dir : str
        Path to directory containing system RTTM files.
    collar : float
        Size of forgiveness collar in seconds. Diarization output will not be
        evaluated within +/- ``collar`` seconds of reference speaker
        boundaries. Only relevant for computing DER.
    ignore_overlaps : bool
        If True, ignore regions in the reference diarization in which more
        than one speaker is speaking. Only relevant for computing DER.
    step : float
        Frame step size in seconds. Not relevant for computation of DER.
    n_jobs : int, optional
        Number of worker processes to use. When 1, recordings are scored
        serially in the current process.
        (Default: 1)

    Returns
    -------
    rows : list
        One row per scored recording, in the order of ``fids``. Recordings
        for which ``_score_recordings`` returns a falsy value (e.g. because
        an RTTM file is missing) are dropped.
    """
    job_args = [(fid, ref_rttm_dir, sys_rttm_dir, collar, ignore_overlaps,
                 step)
                for fid in fids]
    if n_jobs == 1:
        rows = [_score_recordings(job) for job in job_args]
    else:
        pool = Pool(n_jobs)
        try:
            rows = pool.map(_score_recordings, job_args)
        finally:
            # Shut the workers down explicitly; otherwise the processes
            # linger until the pool object is garbage collected.
            pool.close()
            pool.join()
    return [row for row in rows if row]
def write_dataframe(fn, rows, additional_columns=None, enc='utf-8'):
    """Write scores to a tab-delimited dataframe.

    Parameters
    ----------
    fn : str
        Output dataframe.
    rows : list of list
        Rows of dataframe. Each row is expected to be the file id followed by
        the metric values (as produced by ``_score_recordings``). The rows
        are not modified.
    additional_columns : list of tuple, optional
        List of column name/value pairs specifying additional columns to be
        appended to every row.
        (Default: None)
    enc : str, optional
        Character encoding.
        (Default: 'utf-8')
    """
    # Header: the file id column first, since every row starts with the fid.
    col_names = ['FID',  # File id.
                 'DER',  # Diarization error rate.
                 'B3Precision',  # B-cubed precision.
                 'B3Recall',  # B-cubed recall.
                 'B3F1',  # B-cubed F1.
                 'TauRefSys',  # Goodman-Kruskal tau ref --> sys.
                 'TauSysRef',  # Goodman-Kruskal tau sys --> ref.
                 'CE',  # H(ref | sys).
                 'MI',  # Mutual information between ref and sys.
                 'NMI',  # Normalized mutual information between ref/sys.
                 ]
    extra_names = []
    extra_vals = []
    if additional_columns:
        extra_names = [col_name for col_name, _ in additional_columns]
        extra_vals = [val for _, val in additional_columns]

    def _encode_line(vals):
        # Encode the whole line, including the newline, so that only bytes
        # are written to the binary file (writing text fails on Python 3).
        return ('\t'.join(str(val) for val in vals) + '\n').encode(enc)

    with open(fn, 'wb') as f:
        f.write(_encode_line(col_names + extra_names))
        for row in rows:
            # Concatenate rather than row.extend() so the caller's rows are
            # left untouched.
            f.write(_encode_line(list(row) + extra_vals))
def parse_additional_columns(spec_str):
    """Parse additional columns specification.

    The column specification should be a semicolon delimited list of column
    name/value pairs, each pair having the form

        CNAME=VAL

    For instance, the string

        Corpus=AMI;NClusters=4

    would be parsed as specifying two columns, "Corpus" and "NClusters",
    taking on the values "AMI" and "4" respectively. Only the first ``=`` of
    a pair separates name from value, so values may themselves contain
    ``=``. Empty segments (e.g. from a trailing semicolon) are ignored.

    Parameters
    ----------
    spec_str : str
        Additional columns specification. An empty string (or None) yields
        no additional columns.

    Returns
    -------
    additional_columns : list of tuple
        List of ``(column name, value)`` pairs.

    Raises
    ------
    ValueError
        If a non-empty segment does not contain ``=``.
    """
    additional_columns = []
    if not spec_str:
        return additional_columns
    for pair in spec_str.split(';'):
        if not pair.strip():
            continue
        col_name, sep, val = pair.partition('=')
        if not sep:
            raise ValueError(
                'Invalid additional column specification "{}": expected '
                'CNAME=VAL.'.format(pair))
        additional_columns.append((col_name, val))
    return additional_columns
def _check_collar(ref_rttm, collar):
""" Check if the cumulated duration of the reference transcripted speech
is greater than two times the collar.
Parameters
----------
ref_rttm : str
path to the reference transcription (.rttm) that needs to be analyzed.
Returns
-------
No result. This function raises an error if the total duration
is lower than two times the collar.
"""
pass_collar_test = False
with open(ref_rttm) as fn:
for line in fn:
row = line.split()
t_dur = row[4]
t_dur = float(t_dur)
if t_dur > 2.0 * collar:
pass_collar_test = True
if not pass_collar_test:
raise ValueError(
"The transcription {} has no line whose duration is greater than two times the collar.\n".format(ref_rttm) +
"You should remove this file or set the collar to a lower value.\n")
def _get_fids(ref_rttm_dir, sys_rttm_dir, collar):
    """Return sorted ids of recordings having non-empty RTTMs in both dirs.

    Every non-empty reference RTTM is validated with ``_check_collar``.

    Parameters
    ----------
    ref_rttm_dir : str
        Path to directory containing reference RTTM files.
    sys_rttm_dir : str
        Path to directory containing system RTTM files.
    collar : float
        Collar size in seconds, passed to ``_check_collar``.

    Returns
    -------
    fids : list of str
        Sorted file ids (RTTM basenames without the extension).
    """
    ref_bns = set()
    for fn in glob.glob(os.path.join(ref_rttm_dir, '*.rttm')):
        if os.path.getsize(fn) > 0:
            _check_collar(fn, collar)
            ref_bns.add(os.path.basename(fn))
    sys_bns = {os.path.basename(fn)
               for fn in glob.glob(os.path.join(sys_rttm_dir, '*.rttm'))
               if os.path.getsize(fn) > 0}
    # Strip only the trailing extension; str.replace would also remove
    # '.rttm' occurring elsewhere in a file name.
    return sorted(os.path.splitext(bn)[0] for bn in ref_bns & sys_bns)
if __name__ == '__main__':
    # Parse command line arguments.
    parser = argparse.ArgumentParser(
        description='Score RTTMs.', add_help=True,
        usage='%(prog)s [options] scoresf ref_rttm_dir sys_rttm_dir')
    parser.add_argument(
        'scoresf', nargs=None, help='output dataframe')
    parser.add_argument(
        'ref_rttm_dir', nargs=None, help='reference RTTM directory')
    parser.add_argument(
        'sys_rttm_dir', nargs=None, help='system RTTM directory')
    parser.add_argument(
        '-S', nargs=None, default=None, metavar='FILE', dest='scpf',
        help='set script file (Default: None)')
    parser.add_argument(
        '--collar', nargs=None, default=0.0, type=float, metavar='FLOAT',
        help='collar size in seconds for DER computation '
             '(Default: %(default)s)')
    parser.add_argument(
        '--score_overlaps', action='store_false', default=True,
        dest='ignore_overlaps',
        help='score overlaps when computing DER')
    parser.add_argument(
        '--step', nargs=None, default=0.010, type=float, metavar='FLOAT',
        help='step size in seconds (Default: %(default)s)')
    parser.add_argument(
        '--additional_columns', nargs=None, default='',
        help='additional columns as semicolon-delimited CNAME=VAL pairs')
    parser.add_argument(
        '-j', nargs=None, default=1, type=int, metavar='N', dest='n_jobs',
        help='set number of worker processes to use (Default: 1)')
    parser.add_argument(
        '--version', action='version',
        version='%(prog)s ' + VERSION)
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()

    # Parse the additional columns up front so that a malformed spec fails
    # before any (potentially slow) scoring is done.
    additional_columns = parse_additional_columns(args.additional_columns)

    # Determine the file ids to score.
    if args.scpf is not None:
        # Text mode so that ids are str (not bytes) on Python 3; they are
        # concatenated with the '.rttm' extension later. Blank lines are
        # ignored.
        with open(args.scpf) as f:
            fids = [line.strip() for line in f if line.strip()]
    else:
        fids = _get_fids(args.ref_rttm_dir, args.sys_rttm_dir, args.collar)

    # Score the recordings and write the dataframe.
    rows = score_recordings(
        fids, args.ref_rttm_dir, args.sys_rttm_dir, args.collar,
        args.ignore_overlaps, args.step, args.n_jobs)
    write_dataframe(args.scoresf, rows, additional_columns)