forked from chankeh/CardiacSeg-master
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
109 lines (67 loc) · 3.73 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
import SegNetCMR
WORKING_DIR = os.getcwd()
TRAINING_DIR = os.path.join(WORKING_DIR, 'Data', 'Training')
TEST_DIR = os.path.join(WORKING_DIR, 'Data', 'Test')
ROOT_LOG_DIR = os.path.join(WORKING_DIR, 'Output')
RUN_NAME = "Run3x3"
LOG_DIR = os.path.join(ROOT_LOG_DIR, RUN_NAME)
TRAIN_WRITER_DIR = os.path.join(LOG_DIR, 'Train')
TEST_WRITER_DIR = os.path.join(LOG_DIR, 'Test')
CHECKPOINT_FN = 'model.ckpt'
CHECKPOINT_FL = os.path.join(LOG_DIR, CHECKPOINT_FN)
BATCH_NORM_DECAY = 0.95 #Start off at 0.9, then increase.
MAX_STEPS = 500 #原先是20000步,这里我为了节省时间,先改成500
BATCH_SIZE = 6
SAVE_INTERVAL = 50
def main():
training_data = SegNetCMR.GetData(TRAINING_DIR)
test_data = SegNetCMR.GetData(TEST_DIR)
g = tf.Graph()
with g.as_default():
images, labels, is_training = SegNetCMR.placeholder_inputs(batch_size=BATCH_SIZE)
arg_scope = SegNetCMR.inference_scope(is_training=True, batch_norm_decay=BATCH_NORM_DECAY)
with slim.arg_scope(arg_scope):
logits = SegNetCMR.inference(images, class_inc_bg=2)
SegNetCMR.add_output_images(images=images, logits=logits, labels=labels)
loss = SegNetCMR.loss_calc(logits=logits, labels=labels)
train_op, global_step = SegNetCMR.training(loss=loss, learning_rate=1e-04)
accuracy = SegNetCMR.evaluation(logits=logits, labels=labels)
summary = tf.summary.merge_all()
init = tf.global_variables_initializer()
saver = tf.train.Saver([x for x in tf.global_variables() if 'Adam' not in x.name])
sm = tf.train.SessionManager()
with sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=LOG_DIR) as sess:
sess.run(tf.variables_initializer([x for x in tf.global_variables() if 'Adam' in x.name]))
train_writer = tf.summary.FileWriter(TRAIN_WRITER_DIR, sess.graph)
test_writer = tf.summary.FileWriter(TEST_WRITER_DIR)
global_step_value, = sess.run([global_step])
print("Last trained iteration was: ", global_step_value)
for step in range(global_step_value+1, global_step_value+MAX_STEPS+1):
print("Iteration: ", step)
images_batch, labels_batch = training_data.next_batch(BATCH_SIZE)
train_feed_dict = {images: images_batch,
labels: labels_batch,
is_training: True}
_, train_loss_value, train_accuracy_value, train_summary_str = sess.run([train_op, loss, accuracy, summary], feed_dict=train_feed_dict)
if step % SAVE_INTERVAL == 0:
print("Train Loss: ", train_loss_value)
print("Train accuracy: ", train_accuracy_value)
train_writer.add_summary(train_summary_str, step)
train_writer.flush()
images_batch, labels_batch = test_data.next_batch(BATCH_SIZE)
test_feed_dict = {images: images_batch,
labels: labels_batch,
is_training: False}
test_loss_value, test_accuracy_value, test_summary_str = sess.run([loss, accuracy, summary], feed_dict=test_feed_dict)
print("Test Loss: ", test_loss_value)
print("Test accuracy: ", test_accuracy_value)
test_writer.add_summary(test_summary_str, step)
test_writer.flush()
saver.save(sess, CHECKPOINT_FL, global_step=step)
print("Session Saved")
print("================")
if __name__ == '__main__':
main()