forked from probml/pyprobml
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cnn_filters.py
78 lines (59 loc) · 2.05 KB
/
cnn_filters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# https://github.com/ageron/handson-ml2/blob/master/14_deep_computer_vision_with_cnns.ipynb
import re
import sys


def _version_tuple(version):
    """Return the leading ``(major, minor)`` numbers of *version* as ints.

    Comparing version strings directly is lexicographic ("0.9" >= "0.20" is
    True and "10.0" >= "2.0" is False), so the checks below compare integer
    tuples instead.

    Raises:
        ValueError: if *version* does not start with ``<int>.<int>``.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        # str.format rather than an f-string: the file still targets Python 3.5.
        raise ValueError("unrecognised version string: {!r}".format(version))
    return int(match.group(1)), int(match.group(2))


# Python ≥3.5 is required
assert sys.version_info >= (3, 5)
# Scikit-Learn ≥0.20 is required
import sklearn
assert _version_tuple(sklearn.__version__) >= (0, 20)
# TensorFlow ≥2.0-preview is required
import tensorflow as tf
from tensorflow import keras
assert _version_tuple(tf.__version__) >= (2, 0)
# Common imports
import numpy as np
import os
import matplotlib.pyplot as plt
figdir = "../figures"


def save_fig(fname, *, tight_layout=False):
    """Save the current matplotlib figure as *fname* inside ``figdir``.

    Does nothing when ``figdir`` is empty or None, so saving can be switched
    off by clearing that setting.

    Args:
        fname: File name, joined onto ``figdir``.
        tight_layout: If True, call ``plt.tight_layout()`` before saving.
            Defaults to False so one-argument calls behave exactly as before;
            the keyword exists because callers pass ``tight_layout=...``.
    """
    if not figdir:
        return
    if tight_layout:
        plt.tight_layout()
    # savefig does not create missing directories, so make sure ours exists.
    os.makedirs(figdir, exist_ok=True)
    plt.savefig(os.path.join(figdir, fname))
# Make this notebook's output stable across runs: seed both the NumPy and
# the TensorFlow global random number generators with the same value.
_SEED = 42
np.random.seed(_SEED)
tf.random.set_seed(_SEED)
def plot_image(image):
    """Draw *image* with the gray colormap and no axis decorations.

    Nearest-neighbour interpolation is used so individual pixels are not
    smoothed, which keeps small feature maps and kernels readable.
    """
    imshow_options = {"cmap": "gray", "interpolation": "nearest"}
    plt.imshow(image, **imshow_options)
    plt.axis("off")
def plot_color_image(image):
    """Draw *image* without an explicit colormap, pixel-sharp and without axes.

    Presumably intended for RGB images (it is not called in this file).
    """
    plt.imshow(
        image,
        interpolation="nearest",  # no smoothing between neighbouring pixels
    )
    plt.axis("off")  # hide ticks, tick labels and the frame
import numpy as np
from sklearn.datasets import load_sample_image

# Load scikit-learn's two bundled sample photos, scale pixel values by 1/255,
# and stack them into a single batch array.
sample_files = ("china.jpg", "flower.jpg")
china, flower = (load_sample_image(name) / 255 for name in sample_files)
images = np.array([china, flower])
batch_size, height, width, channels = images.shape

# Two hand-made 7x7 filters (one per output feature map), each spanning all
# input channels: filter 0 is ones down the centre column, filter 1 is ones
# along the centre row.
kernel_size = 7
centre = kernel_size // 2
filters = np.zeros((kernel_size, kernel_size, channels, 2), dtype=np.float32)
filters[:, centre, :, 0] = 1  # vertical line
filters[centre, :, :, 1] = 1  # horizontal line

outputs = tf.nn.conv2d(images, filters, strides=1, padding="SAME")

# 2nd feature map (the horizontal-line filter) of the 1st image.
plt.imshow(outputs[0, :, :, 1], cmap="gray")
plt.axis("off")  # Not shown in the book
plt.show()

# 2x2 grid of all feature maps: row = input image, column = feature map.
panels = [(img, fmap) for img in (0, 1) for fmap in (0, 1)]
for subplot_index, (image_index, feature_map_index) in enumerate(panels, start=1):
    plt.subplot(2, 2, subplot_index)
    plot_image(outputs[image_index, :, :, feature_map_index])
plt.show()
def crop(images):
    """Return the fixed window rows 150:220, columns 130:250 of *images*.

    The slices apply to the first two axes, so callers here pass a single
    2-D map (despite the parameter name); any trailing axes are kept whole.
    """
    row_window = slice(150, 220)
    col_window = slice(130, 250)
    return images[row_window, col_window]
# Zoom into the fixed window selected by crop() on channel 0 of the first
# input image, and save that figure.
plot_image(crop(images[0, :, :, 0]))
save_fig("china_original", tight_layout=False)
plt.show()

# Same window on each of the first image's two feature maps, one saved figure
# per map (map 0 -> "china_vertical", map 1 -> "china_horizontal").
zoom_filenames = ("china_vertical", "china_horizontal")
for feature_map_index, filename in enumerate(zoom_filenames):
    zoomed_map = crop(outputs[0, :, :, feature_map_index])
    plot_image(zoomed_map)
    save_fig(filename)
    plt.show()

# Finally show the kernels themselves: channel 0 of filter 0, then filter 1.
for filter_index in range(2):
    plot_image(filters[:, :, 0, filter_index])
    plt.show()