# features.py
import glob
import os
import warnings
from multiprocessing import cpu_count

import numpy as np
import torch
import torch.multiprocessing
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import utils

# This is really important: without it the program fails with a
# "too many open files" error, at least on UNIX systems.
torch.multiprocessing.set_sharing_strategy('file_system')
warnings.filterwarnings("ignore")
# Paths into the FMA medium dataset; the precomputed targets live in a
# subdirectory of the audio directory.
audio_dir = os.path.join(os.path.curdir, 'fma_medium')
target_dir = os.path.join(audio_dir, 'targets')
os.environ['AUDIO_DIR'] = audio_dir

# FMA metadata.
tracks = utils.load('tracks.csv')
genres = utils.load('genres.csv')

# Recover each track ID from its target filename, e.g. '012345_targets.npz' -> 12345.
target_paths = [*glob.iglob(os.path.join(target_dir, '*_targets.npz'), recursive=True)]
tids = [int(os.path.splitext(os.path.basename(p).replace('_targets', ''))[0])
        for p in target_paths]

# Restrict the metadata to the tracks that actually have targets on disk.
tracks_subset = tracks['track'].loc[tids]
genres_subset = tracks_subset['genre_top']
artists_subset = tracks['artist'].loc[tids]
# Map each genre that occurs in the subset to an integer class label.
genre_counts = genres_subset.value_counts()
genre_counts = genre_counts[genre_counts > 0]
print(genre_counts)
coded_genres = {genre: k for k, genre in enumerate(genre_counts.index)}
coded_genres_reverse = {k: genre for genre, k in coded_genres.items()}
print(coded_genres)
# Splitting a spectrogram into num_frames frames with 50% overlap yields
# 2 * num_frames - 1 frames in total.
num_frames = 4
total_frames = 2 * num_frames - 1
frame_size = 1290 // num_frames  # 1290 = time dimension of the mel spectrograms
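
# Worked example of the frame arithmetic (illustrative; the numbers follow from
# the defaults above and are not part of the original source): frame f starts
# at (f // 2) * frame_size + (f % 2) * frame_size // 2, so consecutive frames
# overlap by half a frame.
if __debug__:
    _frame_starts = [(f // 2) * frame_size + (f % 2) * frame_size // 2
                     for f in range(total_frames)]
    assert _frame_starts == [0, 161, 322, 483, 644, 805, 966]
    assert _frame_starts[-1] + frame_size <= 1290  # the last frame still fits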
class FeatureDataset(Dataset):
    """Yields one (mel spectrogram, feature dict) pair per track."""

    def __len__(self):
        return len(target_paths)

    def __getitem__(self, idx):
        path = target_paths[idx]
        tid = tids[idx]
        # Turn each stored feature vector into a probability distribution to
        # use as a soft target. (Taking the argmax instead would give a hard
        # label: features[k] = data[k].argmax().)
        names = ['subgenres', 'mfcc', 'chroma', 'spectral_contrast']
        features = {}
        with np.load(path) as data:
            for k in names:
                features[k] = F.softmax(torch.from_numpy(data[k]), dim=0)
            mel = data['mel']
        features['genre'] = coded_genres[tracks_subset['genre_top'][tid]]
        return mel, features
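
# Expected layout of each *_targets.npz file, inferred from the reads above
# (not documented in the original source): 'mel' is a 2-D mel spectrogram of
# shape (128, 1290), and 'subgenres', 'mfcc', 'chroma' and 'spectral_contrast'
# are 1-D score vectors that get soft-maxed into target distributions.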
class FramedFeatureDataset(FeatureDataset):
    """Like FeatureDataset, but yields one half-overlapping frame per item."""

    def __len__(self):
        return total_frames * super().__len__()

    def __getitem__(self, idx):
        if isinstance(idx, torch.Tensor):
            idx = idx.item()
        song_idx, frame = divmod(idx, total_frames)
        mel, features = super().__getitem__(song_idx)
        # Even frames start on a frame boundary; odd frames are shifted by
        # half a frame, which gives the 50% overlap.
        shift, half_shift = divmod(frame, 2)
        i = shift * frame_size + half_shift * frame_size // 2
        # Add a channel dimension so it's 1 x 128 x frame_size.
        mel_frame = np.expand_dims(mel[:, i:i + frame_size], axis=0)
        return mel_frame, features
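
# A minimal usage sketch (illustrative, with the shapes assumed above):
#
#     dataset = FramedFeatureDataset()
#     mel_frame, features = dataset[0]   # first frame of the first track
#     mel_frame.shape                    # (1, 128, frame_size)
#     features['genre']                  # integer label from coded_genres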
def get_data_loaders(dataset, batch_size, valid_split):
    """Randomly split `dataset` and return (train_loader, valid_loader)."""
    dataset_len = len(dataset)
    valid_len = int(dataset_len * valid_split)
    train_len = dataset_len - valid_len
    train_dataset, valid_dataset = random_split(dataset, [train_len, valid_len])
    # pin_memory speeds up host-to-GPU transfers; disable it if it causes problems.
    pin_memory = True
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=cpu_count(),
                              shuffle=True,
                              pin_memory=pin_memory)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              num_workers=cpu_count(),
                              shuffle=False,
                              pin_memory=pin_memory)
    return train_loader, valid_loader
if __name__ == '__main__':
    # Smoke test: build the framed dataset and iterate over the training
    # loader twice.
    dataset = FramedFeatureDataset()
    print(len(dataset))
    train_loader, valid_loader = get_data_loaders(dataset, 64, 0.15)
    print(len(train_loader))
    for _ in range(2):
        for i, batch in enumerate(train_loader):
            if i % 30 == 0:
                print(i, '/', len(train_loader))