-
Notifications
You must be signed in to change notification settings - Fork 0
/
imagenet_utils.py
131 lines (122 loc) · 5.28 KB
/
imagenet_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import numpy as np
import json
from keras.utils.data_utils import get_file
from keras import backend as K
# Lazily-populated cache of the ImageNet class index. It stays `None` until
# `decode_predictions` first needs it, at which point the JSON file is loaded
# and stored here. Entries are looked up by the class index rendered as a
# string (`CLASS_INDEX[str(i)]`).
CLASS_INDEX = None
# URL of the class-index JSON file; handed to `get_file` by
# `decode_predictions` when the cache above is still empty.
CLASS_INDEX_PATH = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json'
def preprocess_input(x, dim_ordering='default'):
    """Preprocesses a tensor encoding a batch of images.

    Reverses the channel axis ('RGB'->'BGR') and subtracts a fixed
    per-channel mean pixel. The caller's array is never modified: the
    work is done on a copy, which is returned.

    # Arguments
        x: input Numpy tensor, 4D. Floating-point inputs keep their dtype;
            any other dtype (e.g. uint8 image data) is converted to float32
            so that the mean subtraction neither fails nor truncates.
        dim_ordering: data format of the image tensor. 'th' means the
            channels are on axis 1, 'tf' means they are on axis 3, and
            'default' uses the value of `K.image_dim_ordering()`.

    # Returns
        Preprocessed tensor (a new array).

    # Raises
        ValueError: if the resolved `dim_ordering` is neither 'tf'
            nor 'th'.
    """
    if dim_ordering == 'default':
        dim_ordering = K.image_dim_ordering()
    # Explicit raise instead of `assert`: assertions vanish under `python -O`.
    if dim_ordering not in {'tf', 'th'}:
        raise ValueError('`dim_ordering` must be "tf", "th" or "default"; '
                         'got ' + repr(dim_ordering))

    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.floating):
        out_dtype = x.dtype
    else:
        out_dtype = np.float32

    # Slicing with `::-1` only yields a view on the caller's buffer, so the
    # in-place subtractions below would otherwise write through to it.
    # `astype` (copy=True by default) gives us an independent array.
    if dim_ordering == 'th':
        # 'RGB'->'BGR'
        x = x[:, ::-1, :, :].astype(out_dtype)
        # Zero-center by mean pixel
        x[:, 0, :, :] -= 103.939
        x[:, 1, :, :] -= 116.779
        x[:, 2, :, :] -= 123.68
    else:
        # 'RGB'->'BGR'
        x = x[:, :, :, ::-1].astype(out_dtype)
        # Zero-center by mean pixel
        x[:, :, :, 0] -= 103.939
        x[:, :, :, 1] -= 116.779
        x[:, :, :, 2] -= 123.68
    return x
def decode_predictions(preds, top=5):
    """Decodes the prediction of an ImageNet model.

    On first use the class index JSON is fetched via `get_file` from
    `CLASS_INDEX_PATH` and cached in the module-level `CLASS_INDEX`;
    later calls reuse the cached mapping.

    # Arguments
        preds: Numpy tensor encoding a batch of predictions,
            of shape `(samples, 1000)`.
        top: integer, how many top-guesses to return.
            `top=0` yields an empty list for every sample.

    # Returns
        A list of lists of top class prediction tuples
        `(class_name, class_description, score)`, ordered by
        decreasing score. One list of tuples per sample in batch input.

    # Raises
        ValueError: in case of invalid shape of the `pred` array
            (must be 2D).
    """
    global CLASS_INDEX
    if len(preds.shape) != 2 or preds.shape[1] != 1000:
        raise ValueError('`decode_predictions` expects '
                         'a batch of predictions '
                         '(i.e. a 2D array of shape (samples, 1000)). '
                         'Found array with shape: ' + str(preds.shape))
    if CLASS_INDEX is None:
        fpath = get_file('imagenet_class_index.json',
                         CLASS_INDEX_PATH,
                         cache_subdir='models')
        # Context manager so the file handle is closed deterministically.
        with open(fpath) as f:
            CLASS_INDEX = json.load(f)
    results = []
    for pred in preds:
        # Reverse first, then truncate: `argsort()[-top:]` would turn
        # `top=0` into `[0:]` and return every class instead of none.
        top_indices = pred.argsort()[::-1][:top]
        # Indices are already in descending score order, so no extra sort
        # is needed.
        results.append([tuple(CLASS_INDEX[str(i)]) + (pred[i],)
                        for i in top_indices])
    return results
def _obtain_input_shape(input_shape,
default_size,
min_size,
dim_ordering,
include_top):
"""Internal utility to compute/validate an ImageNet model's input shape.
# Arguments
input_shape: either None (will return the default network input shape),
or a user-provided shape to be validated.
default_size: default input width/height for the model.
min_size: minimum input width/height accepted by the model.
dim_ordering: image data format to use.
include_top: whether the model is expected to
be linked to a classifier via a Flatten layer.
# Returns
An integer shape tuple (may include None entries).
# Raises
ValueError: in case of invalid argument values.
"""
if dim_ordering == 'th':
default_shape = (3, default_size, default_size)
else:
default_shape = (default_size, default_size, 3)
if include_top:
if input_shape is not None:
if input_shape != default_shape:
raise ValueError('When setting`include_top=True`, '
'`input_shape` should be ' + str(default_shape) + '.')
input_shape = default_shape
else:
if dim_ordering == 'th':
if input_shape is not None:
if len(input_shape) != 3:
raise ValueError('`input_shape` must be a tuple of three integers.')
if input_shape[0] != 3:
raise ValueError('The input must have 3 channels; got '
'`input_shape=' + str(input_shape) + '`')
if ((input_shape[1] is not None and input_shape[1] < min_size) or
(input_shape[2] is not None and input_shape[2] < min_size)):
raise ValueError('Input size must be at least ' +
str(min_size) + 'x' + str(min_size) + ', got '
'`input_shape=' + str(input_shape) + '`')
else:
input_shape = (3, None, None)
else:
if input_shape is not None:
if len(input_shape) != 3:
raise ValueError('`input_shape` must be a tuple of three integers.')
if input_shape[-1] != 3:
raise ValueError('The input must have 3 channels; got '
'`input_shape=' + str(input_shape) + '`')
if ((input_shape[0] is not None and input_shape[0] < min_size) or
(input_shape[1] is not None and input_shape[1] < min_size)):
raise ValueError('Input size must be at least ' +
str(min_size) + 'x' + str(min_size) + ', got '
'`input_shape=' + str(input_shape) + '`')
else:
input_shape = (None, None, 3)
return input_shape