-
Notifications
You must be signed in to change notification settings - Fork 0
/
riffusion.py
238 lines (181 loc) · 5.98 KB
/
riffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
from PIL import Image
import numpy as np
#from share_btn import community_icon_html, loading_icon_html, share_js
import torch
#from spectro import wav_bytes_from_spectrogram_image
from diffusers import StableDiffusionPipeline
from diffusers import EulerAncestralDiscreteScheduler
import io
#from os import path
from pydub import AudioSegment
import io
import typing as T
import numpy as np
from PIL import Image
import pydub
from scipy.io import wavfile
import torch
import torchaudio
import gc
def wav_bytes_from_spectrogram_image(image: "Image.Image") -> T.Tuple[io.BytesIO, float]:
    """
    Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.

    Args:
        image: Spectrogram image; decoded with ``spectrogram_from_image`` and inverted
            with ``waveform_from_spectrogram``.

    Returns:
        A tuple ``(wav_bytes, duration_s)`` where ``wav_bytes`` is an in-memory WAV file
        (int16 samples at 44100 Hz) rewound to position 0, and ``duration_s`` is the
        number of reconstructed samples divided by the sample rate.
    """
    # Image decoding parameters
    max_volume = 50
    power_for_image = 0.25
    Sxx = spectrogram_from_image(image, max_volume=max_volume, power_for_image=power_for_image)

    sample_rate = 44100  # [Hz]
    clip_duration_ms = 5000  # [ms]
    bins_per_image = 512
    n_mels = 512

    # FFT parameters
    window_duration_ms = 100  # [ms]
    padded_duration_ms = 400  # [ms]
    step_size_ms = 10  # [ms]

    # Derived parameters
    # The clip duration is in milliseconds, so convert to seconds before scaling by the
    # sample rate; otherwise the sample count comes out 1000x too large.
    num_samples = int(
        image.width / float(bins_per_image) * clip_duration_ms / 1000.0 * sample_rate
    )
    n_fft = int(padded_duration_ms / 1000.0 * sample_rate)
    hop_length = int(step_size_ms / 1000.0 * sample_rate)
    win_length = int(window_duration_ms / 1000.0 * sample_rate)

    samples = waveform_from_spectrogram(
        Sxx=Sxx,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        num_samples=num_samples,
        sample_rate=sample_rate,
        mel_scale=True,
        n_mels=n_mels,
        max_mel_iters=200,
        num_griffin_lim_iters=32,
    )

    # Saturate to the int16 range before casting: a bare astype() on out-of-range
    # floats wraps around instead of clipping, which turns loud peaks into noise.
    int16_range = np.iinfo(np.int16)
    pcm = np.clip(samples, int16_range.min, int16_range.max).astype(np.int16)

    wav_bytes = io.BytesIO()
    wavfile.write(wav_bytes, sample_rate, pcm)
    wav_bytes.seek(0)

    duration_s = float(len(samples)) / sample_rate

    return wav_bytes, duration_s
def spectrogram_from_image(
    image: "Image.Image", max_volume: float = 50, power_for_image: float = 0.25
) -> np.ndarray:
    """
    Compute a spectrogram magnitude array from a spectrogram image.

    The image is flipped vertically, reduced to a single channel, inverted
    (pixel value 255 maps to 0), rescaled so pixel value 0 maps to ``max_volume``,
    and finally raised to ``1 / power_for_image`` to undo the power curve.

    Args:
        image: Spectrogram image (anything ``np.array`` can convert). Multi-channel
            images use channel 0; single-channel (2-D) images are used directly.
        max_volume: Value that a pixel of 0 maps to before the power curve is reversed.
        power_for_image: Exponent that was applied when the image was produced.

    Returns:
        A 2-D float32 array with the same height and width as the image.

    TODO(hayk): Add image_from_spectrogram and call this out as the reverse.
    """
    # Convert to a numpy array of floats
    data = np.array(image).astype(np.float32)

    # Flip Y, take a single channel
    if data.ndim == 2:
        # Single-channel image: there is no channel axis to index into
        data = data[::-1, :]
    else:
        data = data[::-1, :, 0]

    # Invert
    data = 255 - data

    # Rescale to max volume
    data = data * max_volume / 255

    # Reverse the power curve
    data = np.power(data, 1 / power_for_image)

    return data
def spectrogram_from_waveform(
    waveform: np.ndarray,
    sample_rate: int,
    n_fft: int,
    hop_length: int,
    win_length: int,
    mel_scale: bool = True,
    n_mels: int = 512,
) -> np.ndarray:
    """
    Compute a magnitude spectrogram from a waveform.

    The waveform is cast to float32, reshaped to a single row and passed through a
    torchaudio ``Spectrogram`` transform configured with ``power=None``; the absolute
    value of the result is taken. When ``mel_scale`` is true the magnitudes are then
    mapped through a torchaudio ``MelScale`` (HTK, 0-10000 Hz, ``n_mels`` bins).

    Returns:
        The magnitude array as a numpy array (mel-scaled when requested).
    """
    stft = torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        power=None,
        hop_length=hop_length,
        win_length=win_length,
    )

    # Single-row tensor in, first (only) row of the transform output back out
    signal = torch.from_numpy(waveform.astype(np.float32)).reshape(1, -1)
    magnitudes = np.abs(stft(signal).numpy()[0])

    if not mel_scale:
        return magnitudes

    to_mel = torchaudio.transforms.MelScale(
        n_mels=n_mels,
        sample_rate=sample_rate,
        f_min=0,
        f_max=10000,
        n_stft=n_fft // 2 + 1,
        norm=None,
        mel_scale="htk",
    )
    return to_mel(torch.from_numpy(magnitudes)).numpy()
def waveform_from_spectrogram(
    Sxx: np.ndarray,
    n_fft: int,
    hop_length: int,
    win_length: int,
    num_samples: int,
    sample_rate: int,
    mel_scale: bool = True,
    n_mels: int = 512,
    max_mel_iters: int = 200,
    num_griffin_lim_iters: int = 32,
    device: str = "cuda:0",
) -> np.ndarray:
    """
    Reconstruct a waveform from a spectrogram.

    This is an approximate inverse of spectrogram_from_waveform, using the Griffin-Lim algorithm
    to approximate the phase.

    Args:
        Sxx: Spectrogram magnitudes (mel-scaled when ``mel_scale`` is true).
        n_fft: FFT size passed to the Griffin-Lim transform.
        hop_length: Hop length passed to the Griffin-Lim transform.
        win_length: Window length passed to the Griffin-Lim transform.
        num_samples: Accepted for interface compatibility; currently not used.
        sample_rate: Sample rate used to build the inverse mel filter bank.
        mel_scale: Whether ``Sxx`` must first be mapped back from the mel scale.
        n_mels: Number of mel bins in ``Sxx``.
        max_mel_iters: Iteration limit for the inverse mel transform.
        num_griffin_lim_iters: Number of Griffin-Lim iterations.
        device: Torch device to run on. A CUDA device is replaced by the CPU when
            CUDA is not available, so the function still works on CPU-only hosts.

    Returns:
        The reconstructed waveform as a numpy array on the CPU.
    """
    # Fall back to the CPU instead of failing when the requested GPU does not exist
    target = torch.device(device)
    if target.type == "cuda" and not torch.cuda.is_available():
        target = torch.device("cpu")

    Sxx_torch = torch.from_numpy(Sxx).to(target)

    # TODO(hayk): Make this a class that caches the two things

    if mel_scale:
        mel_inv_scaler = torchaudio.transforms.InverseMelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=0,
            f_max=10000,
            n_stft=n_fft // 2 + 1,
            norm=None,
            mel_scale="htk",
            max_iter=max_mel_iters,
        ).to(target)

        Sxx_torch = mel_inv_scaler(Sxx_torch)

    griffin_lim = torchaudio.transforms.GriffinLim(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        power=1.0,
        n_iter=num_griffin_lim_iters,
    ).to(target)

    waveform = griffin_lim(Sxx_torch).cpu().numpy()

    return waveform
def mp3_bytes_from_wav_bytes(wav_bytes: io.BytesIO) -> io.BytesIO:
    """
    Re-encode an in-memory WAV file as MP3.

    Returns:
        A new ``io.BytesIO`` holding the exported MP3 data, rewound to position 0.
    """
    encoded = io.BytesIO()
    pydub.AudioSegment.from_wav(wav_bytes).export(encoded, format="mp3")
    encoded.seek(0)
    return encoded
# Identifier of the pretrained model passed to from_pretrained() below.
model_id2 = "riffusion/riffusion-model-v1"
# Build the pipeline once, when this module is imported, requesting float16 weights.
# get_music() reuses (and rebinds) this module-level object on every call.
pipe2 = StableDiffusionPipeline.from_pretrained(model_id2, torch_dtype=torch.float16)
def safety_checker(images, clip_input):
    """
    Pass-through replacement for the pipeline's safety checker.

    Args:
        images: Batch of generated images; returned unchanged.
        clip_input: Ignored; accepted only to match the checker call signature.

    Returns:
        A tuple ``(images, flags)`` where ``flags`` holds one ``False`` per image.
        The checker contract is one flag per image; a bare ``False`` is not
        iterable and breaks callers that loop over the flags.
    """
    return images, [False] * len(images)
# Replace the pipeline's safety checker with the pass-through safety_checker()
# defined above, so generated images are returned unmodified.
pipe2.safety_checker=safety_checker
# The pipeline is not moved to the GPU at import time; get_music() moves it to
# "cuda" for each call and back to "cpu" afterwards.
#pipe2 = pipe2.to("cuda")
# NOTE(review): presumably enabled to lower peak memory during generation -
# confirm against the diffusers documentation.
pipe2.enable_attention_slicing()
#pipe2.enable_xformers_memory_efficient_attention()
def get_music(prompt, duration, wavfile_name="audio.wav", mp3file_name="audio.mp3"):
    """
    Generate a music clip for ``prompt`` and save it as WAV and MP3 files.

    The module-level pipeline is moved to the GPU for the duration of the call and
    always moved back to the CPU afterwards (even on failure), after which the CUDA
    cache is emptied.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        duration: Requested length; converted with ``int()``. A value of 5 yields a
            512-pixel-wide spectrogram and each additional unit adds 128 pixels.
        wavfile_name: Path the intermediate WAV file is written to.
        mp3file_name: Path the MP3 file is written to.

    Returns:
        A tuple ``(spec, mp3file_name)``: the generated spectrogram image and the
        path of the exported MP3 file.

    Raises:
        ValueError: If ``duration`` maps to a non-positive spectrogram width.
    """
    global pipe2

    # 512 px for a duration of 5, plus 128 px per extra unit (this also covers
    # duration == 5, so no special case is needed).
    # NOTE(review): 128 px per extra second looks larger than the ~102 px/s implied
    # by the 512 px / 5 s base used when decoding - confirm this is intended.
    width_duration = 512 + (int(duration) - 5) * 128
    if width_duration <= 0:
        # Fail early with a clear message, before the model is moved to the GPU
        raise ValueError(
            f"duration {duration!r} is too short: spectrogram width would be {width_duration}"
        )

    pipe2 = pipe2.to("cuda")
    try:
        spec = pipe2(prompt, height=512, width=width_duration).images[0]

        wav_bytes, _duration_s = wav_bytes_from_spectrogram_image(spec)
        with open(wavfile_name, "wb") as f:
            f.write(wav_bytes.getbuffer())

        # Convert to mp3, for video merging function
        AudioSegment.from_wav(wavfile_name).export(mp3file_name, format="mp3")
    finally:
        # Always release the GPU, even when generation or encoding fails
        pipe2 = pipe2.to("cpu")
        gc.collect()
        torch.cuda.empty_cache()

    return spec, mp3file_name