ops.py
import tensorflow as tf
def batchnorm(x, train_phase, scope_bn):
    # Batch Normalization
    # Ioffe, S. & Szegedy, C., "Batch Normalization: Accelerating Deep Network
    # Training by Reducing Internal Covariate Shift", ICML 2015, pp. 448-456.
    with tf.variable_scope(scope_bn, reuse=tf.AUTO_REUSE):
        beta = tf.get_variable(name='beta', shape=[x.shape[-1]], initializer=tf.constant_initializer([0.]), trainable=True)
        gamma = tf.get_variable(name='gamma', shape=[x.shape[-1]], initializer=tf.constant_initializer([1.]), trainable=True)
        # Batch statistics over batch, height, and width (NHWC layout assumed).
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=0.5)

        def mean_var_with_update():
            # Update the moving averages, then return the batch statistics.
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        # Training: use (and accumulate) batch statistics;
        # inference: use the exponential moving averages.
        mean, var = tf.cond(train_phase, mean_var_with_update,
                            lambda: (ema.average(batch_mean), ema.average(batch_var)))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
    return normed
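
# Usage note: train_phase is a boolean tensor, so the same graph serves both
# modes at feed time; e.g. (illustrative feed names, not from this file):
#   sess.run(train_op, {is_train: True})    # updates the moving averages
#   sess.run(output,   {is_train: False})   # uses the moving averages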

def InstanceNorm(inputs, name):
    # Instance Normalization: per-sample, per-channel statistics over the
    # spatial dimensions (Ulyanov et al., 2016).
    with tf.variable_scope(name):
        mean, var = tf.nn.moments(inputs, axes=[1, 2], keep_dims=True)
        scale = tf.get_variable("scale", shape=mean.shape[-1], initializer=tf.constant_initializer([1.]))
        shift = tf.get_variable("shift", shape=mean.shape[-1], initializer=tf.constant_initializer([0.]))
        return (inputs - mean) * scale / tf.sqrt(var + 1e-10) + shift

def conv(name, inputs, nums_out, ksize, strides, padding="SAME", is_SN=False):
    # 2-D convolution; optionally applies spectral normalization to the kernel.
    with tf.variable_scope(name):
        W = tf.get_variable("W", shape=[ksize, ksize, int(inputs.shape[-1]), nums_out], initializer=tf.truncated_normal_initializer(stddev=0.02))
        b = tf.get_variable("b", shape=[nums_out], initializer=tf.constant_initializer(0.))
        if is_SN:
            return tf.nn.conv2d(inputs, spectral_norm(name, W), [1, strides, strides, 1], padding) + b
        else:
            return tf.nn.conv2d(inputs, W, [1, strides, strides, 1], padding) + b
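
# Example sketch (hypothetical layer and tensor names): a spectrally
# normalized discriminator layer that downsamples by 2:
#   d1 = leaky_relu(conv("d_conv1", imgs, 64, 4, 2, is_SN=True))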

def uconv(name, inputs, nums_out, ksize, strides, padding="SAME"):
    # Transposed convolution ("deconvolution") that upsamples by `strides`.
    with tf.variable_scope(name):
        w = tf.get_variable("W", shape=[ksize, ksize, nums_out, int(inputs.shape[-1])], initializer=tf.truncated_normal_initializer(stddev=0.02))
        b = tf.get_variable("b", [nums_out], initializer=tf.constant_initializer(0.))
        # Alternative kept from the original: nearest-neighbor resize followed
        # by a stride-1 convolution (helps avoid checkerboard artifacts):
        # inputs = tf.image.resize_nearest_neighbor(inputs, [H*strides, W*strides])
        # return tf.nn.conv2d(inputs, w, [1, 1, 1, 1], padding) + b
        return tf.nn.conv2d_transpose(inputs, w, [tf.shape(inputs)[0], int(inputs.shape[1])*strides, int(inputs.shape[2])*strides, nums_out], [1, strides, strides, 1], padding=padding) + b
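
# Example sketch (hypothetical names): upsample a [N, 8, 8, 128] feature map
# to [N, 16, 16, 64]:
#   g1 = tf.nn.relu(uconv("g_up1", g0, 64, 4, 2))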

def fully_connected(name, inputs, nums_out):
    # Dense layer; `inputs` is assumed to be flattened to [batch, features].
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        W = tf.get_variable("W", [int(inputs.shape[-1]), nums_out], initializer=tf.truncated_normal_initializer(stddev=0.02))
        b = tf.get_variable("b", [nums_out], initializer=tf.constant_initializer(0.))
        return tf.matmul(inputs, W) + b

def spectral_norm(name, w, iteration=1):
    # Spectral Normalization (Miyato et al., ICLR 2018):
    # https://www.researchgate.net/publication/318572189_Spectral_Normalization_for_Generative_Adversarial_Networks
    # Adapted from https://github.com/taki0112/Spectral_Normalization-Tensorflow
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        # Persistent estimate of the dominant left-singular vector; AUTO_REUSE
        # (instead of the original reuse=False) lets the same layer be built
        # more than once without a scope error.
        u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)
    u_hat = u
    v_hat = None

    def l2_norm(v, eps=1e-12):
        return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)

    # Power iteration: estimate the largest singular value of w.
    for i in range(iteration):
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = l2_norm(v_)
        u_ = tf.matmul(v_hat, w)
        u_hat = l2_norm(u_)
    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
    w_norm = w / sigma
    # Persist the updated singular-vector estimate before returning.
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_norm, w_shape)
    return w_norm
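
# Note: sigma above approximates the largest singular value of the reshaped
# kernel, so dividing by it constrains the layer to be roughly 1-Lipschitz.
# Standalone sketch (hypothetical variable names):
#   W = tf.get_variable("W", [3, 3, 64, 128])
#   W_sn = spectral_norm("sn_W", W)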

def leaky_relu(x, slope=0.2):
    # LeakyReLU: identity for x > 0, slope * x otherwise.
    return tf.maximum(x, slope * x)
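
# ---------------------------------------------------------------------------
# A minimal smoke-test sketch (not part of the original file). It wires the
# ops into a tiny conv -> batchnorm -> leaky_relu -> uconv stack; all names,
# shapes, and feeds below are illustrative assumptions, TF 1.x API assumed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    x = tf.placeholder(tf.float32, [None, 32, 32, 3], name="x")
    is_train = tf.placeholder(tf.bool, [], name="is_train")
    h = leaky_relu(batchnorm(conv("c1", x, 64, 4, 2, is_SN=True), is_train, "bn1"))
    y = uconv("u1", h, 3, 4, 2)  # upsample back to the input resolution
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(y, {x: np.random.rand(2, 32, 32, 3), is_train: True})
        print(out.shape)  # expected: (2, 32, 32, 3)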