forked from marcbelmont/cnn-watermark-removal
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tests.py
87 lines (69 loc) · 2.75 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import sys
from datetime import datetime
from glob import glob

import numpy as np
import tensorflow as tf

import watermarks as w
# Set the DEBUG attribute on the watermarks module before any test runs.
# NOTE(review): presumably this switches the module into a lighter debug
# configuration for testing -- confirm against watermarks.py.
w.DEBUG = True
class WatermarkTest(tf.test.TestCase):
    """Smoke tests for the ``watermarks`` datasets and train/inference entry points.

    Every test runs inside a TensorFlow test session and, where it asserts
    anything, checks only the shape of the produced batch. The tests read
    files under ``assets/`` and whatever data the ``watermarks`` dataset
    helpers load, so they are meant to be run from the project root.
    """

    def setUp(self):
        """Reset the shared ``watermarks`` flags and clear old event files.

        Runs before each test. Individual tests may override
        ``w.FLAGS.batch_size`` afterwards.
        """
        # NOTE(review): this override does not chain to
        # tf.test.TestCase.setUp(); confirm whether the base-class setup
        # should also run here.
        #
        # A timestamped log directory used to be assigned here first, but it
        # was overwritten immediately by the fixed path below, so only the
        # fixed path ever took effect. The cleanup loop relies on that fixed
        # path to find the event files left by previous runs.
        w.FLAGS.logdir = '/tmp/tensorflow_log'
        w.FLAGS.batch_size = 4
        w.FLAGS.learning_rate = 1e-1
        # Plain loop: the removal is a side effect, not a value to collect.
        for event_file in glob(w.FLAGS.logdir + '/events*'):
            os.remove(event_file)

    ########
    # Data #
    ########

    def test_dataset_paths(self):
        """``dataset_paths`` on one image yields a (1, 300, 300, 4) batch."""
        w.FLAGS.batch_size = 2
        with self.test_session() as sess:
            # Print arrays in full when debugging. sys.maxsize is the
            # documented way to disable summarization; a NaN threshold is
            # rejected with ValueError by current NumPy.
            np.set_printoptions(threshold=sys.maxsize)
            x, _ = w.dataset_paths(['assets/cat.png'])
            sess.run(tf.tables_initializer())
            self.assertTupleEqual(x.eval().shape, (1, 300, 300, 4))

    def test_dataset_mask(self):
        """``batch_masks`` yields one single-channel 32x32 mask per batch item."""
        w.FLAGS.batch_size = 2
        with self.test_session():
            # See test_dataset_paths for why sys.maxsize is used here.
            np.set_printoptions(threshold=sys.maxsize)
            mask = w.batch_masks(None, 32, 32, .1, .4)
            self.assertTupleEqual(mask.eval().shape, (2, 32, 32, 1))

    def test_dataset_cifar(self):
        """``dataset_cifar`` yields a (2, 32, 32, 3) batch once initialized."""
        w.FLAGS.batch_size = 2
        with self.test_session() as sess:
            # See test_dataset_paths for why sys.maxsize is used here.
            np.set_printoptions(threshold=sys.maxsize)
            x, init = w.dataset_cifar()
            # The returned initializer is a callable that takes the session.
            init(sess)
            self.assertTupleEqual(x.eval().shape, (2, 32, 32, 3))

    def test_dataset_split(self):
        """``dataset_split`` over the VOC2012 records yields (2, 120, 120, 3)."""
        w.FLAGS.batch_size = 2
        with self.test_session() as sess:
            x, iterator_inits = w.dataset_split(w.dataset_voc2012_rec, .8)
            # Only the first of the returned iterator initializers is run.
            sess.run(iterator_inits[0])
            self.assertTupleEqual(x.eval().shape, (2, 120, 120, 3))

    def test_dataset_voc2012(self):
        """``dataset_voc2012`` yields a (2, 120, 120, 3) batch."""
        w.FLAGS.batch_size = 2
        with self.test_session() as sess:
            # The second return value is not used by this test.
            x, _ = w.dataset_voc2012()
            sess.run(tf.tables_initializer())
            self.assertTupleEqual(x.eval().shape, (2, 120, 120, 3))

    ############
    # Pipeline #
    ############

    def test_training(self):
        """``train`` runs on the VOC2012 record dataset without raising."""
        with self.test_session() as sess:
            w.train(sess, w.dataset_voc2012_rec)

    def test_inference_voc(self):
        """``inference`` on VOC2012 returns a first result of (4, 120, 120, 3)."""
        with self.test_session() as sess:
            results = w.inference(sess, w.dataset_voc2012, 1)
            self.assertTupleEqual(results[0].shape, (4, 120, 120, 3))

    def test_inference_other(self):
        """``inference`` on file-backed datasets returns (1, 300, 300, 3)."""
        with self.test_session() as sess:
            # inference() is handed dataset factories, so each dataset is
            # wrapped in a zero-argument callable.
            def d_cherry():
                return w.dataset_paths(['assets/cat.png'])

            def dm():
                return w.dataset_paths(['assets/empty.png'])

            def ds():
                return w.dataset_paths(['assets/cat-selection.png'])

            results = w.inference(sess, d_cherry, 1, dm, ds)
            self.assertTupleEqual(results[0].shape, (1, 300, 300, 3))
# Script entry point: when this file is executed directly, hand control to
# TensorFlow's test runner.
if __name__ == '__main__':
    tf.test.main()