-
Notifications
You must be signed in to change notification settings - Fork 0
/
bisenet_test_rect.py
64 lines (50 loc) · 2.15 KB
/
bisenet_test_rect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from model import bisenet
from utility.loss import seg_loss, dice_coef, iou_coef
import os
import glob
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import binarize
# Expose only GPU 0 to this process.
# NOTE(review): this runs after `import tensorflow`; it only takes effect if
# TF has not initialised CUDA yet at this point — consider moving it above
# the TensorFlow import to be safe.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
def npy_loader(maindir, seed: int = 2, shuffle: bool = False):
    """Yield ``(image, mask)`` pairs from every ``*.npy`` file in *maindir*.

    Each file is loaded with ``np.load`` and indexed as ``array[:, :, c]``,
    so it must be at least 3-D, with channel 0 holding the image and
    channel 1 holding the mask.

    * The image is rescaled with ``2 * x - 1``. This maps ``[0, 1]`` to
      ``[-1, 1]``; it assumes inputs are already normalised to ``[0, 1]``
      (TODO confirm against the preprocessing step).
    * The mask is binarised: values ``> 0`` become 1 and everything else
      becomes 0, kept in the array's own dtype. For finite input this
      matches ``sklearn.preprocessing.binarize`` with its default
      ``threshold=0.0``.

    Both are yielded with a trailing singleton channel axis.

    Args:
        maindir: Directory to scan. It is combined with ``os.path.join``,
            so a trailing separator is optional.
        seed: Seed for the file-order permutation. Used only when
            *shuffle* is true.
        shuffle: If true, visit files in a seeded random order instead of
            sorted order. The same seed gives the same order on every call.

    Yields:
        ``(img[..., np.newaxis], mask[..., np.newaxis])`` for each file.

    Raises:
        FileNotFoundError: If *maindir* contains no ``*.npy`` files. Because
            this is a generator, the error surfaces on first iteration.
    """
    # glob order depends on the filesystem; sort so runs are reproducible.
    paths = sorted(glob.glob(os.path.join(maindir, '*.npy')))
    if not paths:
        # A wrong path would otherwise produce an empty dataset silently.
        raise FileNotFoundError(f"no .npy files found in {maindir!r}")
    if shuffle:
        order = np.random.default_rng(seed).permutation(len(paths))
        paths = [paths[i] for i in order]
    for path in paths:
        array = np.load(path)
        img = array[:, :, 0] * 2 - 1
        mask = (array[:, :, 1] > 0).astype(array.dtype)
        yield img[..., np.newaxis], mask[..., np.newaxis]
def npy_dataset(maindir, shape_i, shape_m, seed: int = 2, batch: int = 1,
                shuffle_buffer: int = 0):
    """Build a batched ``tf.data.Dataset`` of ``(image, mask)`` pairs.

    Wraps ``npy_loader`` with ``tf.data.Dataset.from_generator``. A lambda
    (not a generator object) is passed so the dataset can be iterated again
    each epoch.

    Args:
        maindir: Directory containing the ``*.npy`` files; forwarded to
            ``npy_loader``.
        shape_i: Static shape of one image, e.g. ``(H, W, 1)``.
        shape_m: Static shape of one mask, e.g. ``(H, W, 1)``.
        seed: Forwarded to ``npy_loader``. Also seeds the sample shuffle when
            *shuffle_buffer* > 0.
        batch: Batch size.
        shuffle_buffer: If > 0, shuffle individual samples with a buffer of
            this many elements, reshuffling every epoch. ``0`` (the default)
            keeps the loader's order. Each buffered element holds a full
            image and mask, so size this with memory in mind.

    Returns:
        A dataset yielding ``(images, masks)`` batches, both ``tf.float16``,
        with one batch prefetched.
        NOTE(review): confirm the model is meant to receive float16 inputs.
    """
    # NOTE: output_types/output_shapes are kept instead of output_signature
    # (TF >= 2.4) so older TF 2.x installs keep working.
    ds = tf.data.Dataset.from_generator(
        lambda: npy_loader(maindir=maindir, seed=seed),
        output_types=(tf.float16, tf.float16),
        output_shapes=(shape_i, shape_m),
    )
    if shuffle_buffer > 0:
        # Shuffle samples before batching so batch composition changes per epoch.
        ds = ds.shuffle(shuffle_buffer, seed=seed, reshuffle_each_iteration=True)
    # Prefetch so the Python generator's disk reads overlap with the model step.
    return ds.batch(batch).prefetch(1)
# --- Data ------------------------------------------------------------------
batch_size = 16
# Single source of truth for the image/mask shape. The datasets and the
# model below must agree on it, so they all read this one tuple.
INPUT_SHAPE = (1024, 832, 1)
DATA_ROOT = '/home/careinfolab/unet_mammo/images/pos_norm/'

# os.path.join(..., '') keeps the trailing separator that npy_loader's
# directory argument is expected to carry.
train_df = npy_dataset(os.path.join(DATA_ROOT, 'train', ''),
                       INPUT_SHAPE, INPUT_SHAPE, seed=2, batch=batch_size)
val_df = npy_dataset(os.path.join(DATA_ROOT, 'val', ''),
                     INPUT_SHAPE, INPUT_SHAPE, seed=2, batch=batch_size)
# Held-out split; it is built here but not referenced again in this script.
test_df = npy_dataset(os.path.join(DATA_ROOT, 'test', ''),
                      INPUT_SHAPE, INPUT_SHAPE, seed=2, batch=batch_size)

# --- Model -----------------------------------------------------------------
# NOTE(review): the second argument is presumably the number of output
# classes/channels — confirm in model.bisenet.
bn = bisenet.BiSeNetV2(INPUT_SHAPE, 1)
model = bn.build_graph()
model.compile(optimizer=Adam(), loss=[seg_loss], metrics=[dice_coef, iou_coef])

# --- Training --------------------------------------------------------------
name = './saved_models/bisenetv2_test1_rect'
# Keep only the checkpoint with the lowest validation loss.
model_checkpoint = ModelCheckpoint(name, monitor='val_loss', save_best_only=True)
# `shuffle` is not passed: Keras ignores it when `x` is a tf.data.Dataset,
# so shuffling must be configured in the input pipeline instead.
history = model.fit(train_df, epochs=150, verbose=1,
                    validation_data=val_df,
                    callbacks=[model_checkpoint])