forked from CSTR-Edinburgh/ophelia
-
Notifications
You must be signed in to change notification settings - Fork 8
/
synthesise_validation_waveforms.py
executable file
·101 lines (76 loc) · 3.27 KB
/
synthesise_validation_waveforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python
# -*- coding: utf-8 -*-
## Project: SCRIPT - February 2018
## Contact: Oliver Watts - [email protected]
import sys
import os
import glob
from argparse import ArgumentParser
import imp
import numpy as np
from utils import spectrogram2wav
# from scipy.io.wavfile import write
import soundfile
import tqdm
from concurrent.futures import ProcessPoolExecutor
import tensorflow as tf
from architectures import SSRNGraph
from synthesize import make_mel_batch, split_batch, synth_mel2mag
from configuration import load_config
def synth_wave(hp, magfile):
    """Invert a saved spectrogram to a waveform and write it as a .wav file.

    Args:
        hp: hyperparameters from ``load_config``. The object is passed
            through to ``spectrogram2wav``, and ``hp.sr`` is used as the
            sample rate of the written file.
        magfile: path to a ``.npy`` (or ``.mag.npy``) file holding the
            spectrogram array to invert.

    The waveform is written next to the input, with the trailing
    ``.mag.npy`` / ``.npy`` extension replaced by ``.wav``. Only the
    extension is swapped; directory names are left untouched.

    Raises:
        ValueError: if ``magfile`` does not end in ``.npy``. This guarantees
            the output path can never coincide with (and overwrite) the input.
    """
    # Work out the destination first so a bad path fails fast, before the
    # (comparatively expensive) load and inversion.
    for ext in ('.mag.npy', '.npy'):
        if magfile.endswith(ext):
            outfile = magfile[:-len(ext)] + '.wav'
            break
    else:
        raise ValueError('expected a .npy file, got %r' % (magfile,))

    mag = np.load(magfile)
    wav = spectrogram2wav(hp, mag)
    soundfile.write(outfile, wav, hp.sr)
def main_work():
    """Regenerate waveforms for all saved validation outputs.

    Command line:
        -c CONFIG   configuration file passed to ``load_config`` (required).
        -ncores N   number of worker processes used for waveform synthesis
                    (default 1; values <= 1 run serially).

    Steps:
      1) Restore the most recent SSRN checkpoint from ``hp.logdir + '-ssrn'``
         and use it to convert every coarse mel saved under
         ``hp.logdir + '-t2m/validation_epoch_*/'`` to a magnitude
         spectrogram, saved next to its input as ``*.mag.npy``.
      2) Run ``synth_wave`` (Griffin-Lim) on those ``*.mag.npy`` files and on
         every ``*.npy`` under ``hp.logdir + '-ssrn/validation_epoch_*/'``.

    Exits via ``sys.exit`` if no SSRN checkpoint can be found.
    """
    opts = _parse_args()
    hp = load_config(opts.config)
    _mels_to_mags(hp)
    _mags_to_waves(hp, opts.ncores)


def _parse_args():
    """Parse this script's command-line options."""
    a = ArgumentParser()
    a.add_argument('-c', dest='config', required=True, type=str)
    a.add_argument('-ncores', type=int, default=1)
    return a.parse_args()


def _mels_to_mags(hp):
    """Convert saved T2M validation mels to mags with the latest SSRN.

    Each input ``X.npy`` gets a sibling ``X.mag.npy``.
    """
    print('mel2mag: restore last saved SSRN')
    g = SSRNGraph(hp, mode="synthesize")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        ## TODO: use restore_latest_model_parameters from synthesize?
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'SSRN')
        saver2 = tf.train.Saver(var_list=var_list)
        savepath = hp.logdir + "-ssrn"
        latest_checkpoint = tf.train.latest_checkpoint(savepath)
        if latest_checkpoint is None:
            sys.exit('No SSRN at %s?' % (savepath))
        ssrn_epoch = latest_checkpoint.strip('/ ').split('/')[-1].replace('model_epoch_', '')
        saver2.restore(sess, latest_checkpoint)
        print("SSRN Restored from latest epoch %s" % (ssrn_epoch))

        # Exclude *.mag.npy written by earlier runs of this script so only
        # the coarse mels produced by T2M are converted.
        filelist = glob.glob(hp.logdir + '-t2m/validation_epoch_*/*.npy')
        filelist = [fname for fname in filelist if not fname.endswith('.mag.npy')]
        if not filelist:
            # Nothing to convert; let step 2 still handle the SSRN outputs.
            print('no T2M validation mels found; skipping mel2mag')
            return

        batch, lengths = make_mel_batch(hp, filelist, oracle=False)
        Z = synth_mel2mag(hp, batch, g, sess, batchsize=32)
        print('synthesised mags, now splitting batch:')
        maglist = split_batch(Z, lengths)
        for infname, outdata in tqdm.tqdm(zip(filelist, maglist), total=len(filelist)):
            # Swap only the trailing extension (every glob hit ends in
            # '.npy'); str.replace would also rewrite '.npy' elsewhere in
            # the path.
            np.save(infname[:-len('.npy')] + '.mag.npy', outdata)


def _mags_to_waves(hp, ncores):
    """Write a .wav for every T2M *.mag.npy and SSRN validation *.npy."""
    print('GL for T2M and SSRN validation')
    filelist = glob.glob(hp.logdir + '-t2m/validation_epoch_*/*.mag.npy') + \
        glob.glob(hp.logdir + '-ssrn/validation_epoch_*/*.npy')

    if ncores <= 1:
        for fname in tqdm.tqdm(filelist):
            synth_wave(hp, fname)
        return

    # The context manager shuts the pool down and joins its workers once
    # every job is done; result() re-raises any exception from a worker.
    with ProcessPoolExecutor(max_workers=ncores) as executor:
        futures = [executor.submit(synth_wave, hp, fpath) for fpath in filelist]
        for future in tqdm.tqdm(futures):
            future.result()
if __name__ == '__main__':
    # Only run the pipeline when executed as a script, not on import.
    main_work()