mixture_mlp.py

import argparse
import sys
from os.path import join

import numpy as np
import yaml
from sklearn.metrics import balanced_accuracy_score

from utils import (build_model, prepare_mixture_dataset,
                   set_classification_targets)

def classify(**args):
"""
Main method that prepares dataset, builds model, executes training and displays results.
:param args: keyword arguments passed from cli parser
"""
    with open('config/datasets.yaml') as cnf:
        dataset_configs = yaml.safe_load(cnf)

    try:
        repo_path = dataset_configs['repo_path']
    except KeyError as e:
        print(f'Missing dataset config key: {e}')
        sys.exit(1)

    repetitions = args['repetitions']
    # determine classification targets and parameters to construct datasets properly
    cls_target, cls_str = set_classification_targets(args['cls_choice'])
    # list of 5% increments ranging from 0% to 100%
    mixture_range = np.arange(0, 1.01, .05)
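    # one row per mixture level, one column per repetition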
    results = np.zeros((len(mixture_range), repetitions))

    for i, cut in enumerate(mixture_range):
        print(f'cut: {cut}')
        d = prepare_mixture_dataset(
            cls_target,
            args['batch_size'],
            mixture_pct=cut,
            normalisation=args['norm_choice'])
        # perform `repetitions` runs per 5% dataset mixture
        for j in range(repetitions):
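            # a fresh model is built for every run so no weights carry over between repetitions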
            model = build_model(0, d['num_classes'], name='baseline_mlp', new_input=True)
            model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

            # train and evaluate
            model.fit(
                d['train_data'],
                steps_per_epoch=d['train_steps'],
                epochs=args['epochs'],
                verbose=0,
                class_weight=d['class_weights'])
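            # balanced accuracy averages per-class recall, so imbalanced test classes don't skew the score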
            pred = model.predict(d['test_data'](), steps=d['test_steps']).argmax(axis=1)
            results[i, j] = balanced_accuracy_score(d['test_labels'], pred)
    print(results)
    np.save(join(repo_path, f'data/synthetic_influence_target_{cls_target}'), results)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-r', '--repetitions',
        type=int,
        default=1,
        help='Number of times to repeat experiment',
        dest='repetitions'
    )
    parser.add_argument(
        '-b', '--batchsize',
        type=int,
        default=64,
        help='Target batch size of dataset preprocessing',
        dest='batch_size'
    )
    parser.add_argument(
        '-c', '--classification',
        type=int,
        choices=[0, 1, 2],
        default=2,
        help='Which classification target to pursue. 0=classes, 1=subgroups, 2=minerals',
        dest='cls_choice'
    )
    parser.add_argument(
        '-e', '--epochs',
        type=int,
        default=10,
        help='How many epochs to train for',
        dest='epochs'
    )
    parser.add_argument(
        '-n', '--normalisation',
        type=int,
        choices=[0, 1, 2],
        default=2,
        help='Which normalisation to use. 0=None, 1=snv, 2=minmax',
        dest='norm_choice'
    )
    args = parser.parse_args()

    classify(**vars(args))
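
# example invocation (assumes config/datasets.yaml provides a valid 'repo_path'):
#   python mixture_mlp.py -r 5 -e 10 -c 2 -n 2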