-
Notifications
You must be signed in to change notification settings - Fork 7
/
get_pretrained_model.py
32 lines (27 loc) · 1.02 KB
/
get_pretrained_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import requests
import sys
import os
def download_file(url, path, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``path``, streaming the response body to disk.

    The body is written to a sibling ``<path>.part`` file and moved onto
    ``path`` only after the whole response has been received. A failed or
    interrupted download therefore never leaves a truncated file at ``path``.

    Args:
        url: URL to fetch; redirects are followed.
        path: Destination file path (str or os.PathLike). Its parent
            directory must already exist.
        timeout: Timeout in seconds, passed to ``requests.get``.
        chunk_size: Maximum number of bytes read from the response per write.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        OSError: If the destination cannot be written.
    """
    name = os.path.basename(path)
    print(f"Downloading {name}...")
    tmp_path = os.fspath(path) + ".part"
    try:
        with requests.get(url, allow_redirects=True, stream=True,
                          timeout=timeout) as response:
            # Without this check an error page (e.g. a 404) would be saved
            # as if it were the requested file.
            response.raise_for_status()
            with open(tmp_path, 'wb') as file:
                # Stream in chunks instead of buffering the whole file in memory.
                for chunk in response.iter_content(chunk_size=chunk_size):
                    file.write(chunk)
        # Atomic on the same filesystem: `path` is either absent or complete.
        os.replace(tmp_path, path)
    finally:
        # After a successful replace tmp_path is gone. Otherwise remove the
        # partial file so it is not left behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Saved {name}.")
def get_model(model_type, base_dir="./models"):
    """Ensure the pretrained checkpoints for ``model_type`` exist locally.

    ``D_0.pth`` and ``G_0.pth`` from the PolyLangVITS v1.0 GitHub release are
    placed in ``<base_dir>/<model_type>/``. Only files that are not already
    present are downloaded. If both are present, a skip message is printed.

    Args:
        model_type: Name of the model sub-directory under ``base_dir``.
        base_dir: Root directory holding per-model directories. Defaults to
            ``./models``, which is relative to the current working directory.

    Raises:
        Whatever ``download_file`` raises if a download fails.
    """
    model_urls = {
        'D_0.pth': 'https://github.com/ORI-Muchim/PolyLangVITS/releases/download/v1.0/D_0.pth',
        'G_0.pth': 'https://github.com/ORI-Muchim/PolyLangVITS/releases/download/v1.0/G_0.pth',
    }
    directory = os.path.join(base_dir, model_type)

    # Check every checkpoint rather than only G_0.pth. A missing D_0.pth is
    # then fetched even when G_0.pth is present, and existing files are not
    # downloaded again.
    missing = {
        filename: url
        for filename, url in model_urls.items()
        if not os.path.isfile(os.path.join(directory, filename))
    }
    if not missing:
        print('Skipping Download... Model exists.')
        return

    os.makedirs(directory, exist_ok=True)
    for filename, url in missing.items():
        download_file(url, os.path.join(directory, filename))
if __name__ == "__main__":
    # The model name is read from the SECOND command-line argument. argv[1]
    # is not used by this script; presumably the caller passes it for another
    # purpose — confirm against the invoking script.
    if len(sys.argv) < 3:
        # Exit with a readable usage message instead of a bare IndexError.
        sys.exit(f"usage: {os.path.basename(sys.argv[0])} <unused> <model_type>")
    model_type = sys.argv[2]
    get_model(model_type)