-
Notifications
You must be signed in to change notification settings - Fork 0
/
simple_server.py
147 lines (131 loc) · 5.46 KB
/
simple_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# simple_server.py
import base64
import html
import http.server
import socketserver
import time
from io import BytesIO

import cv2
import numpy as np
from PIL import Image

from load_model.tensorflow_loader import load_tf_model, tf_inference
from utils.anchor_decode import decode_bbox
from utils.anchor_generator import generate_anchors
from utils.nms import single_class_non_max_suppression
# Load the frozen face-mask detection graph once, at import time; the
# resulting session/graph pair is reused by every inference() call.
sess, graph = load_tf_model('models/face_mask_detection.pb')

# --- anchor configuration ---------------------------------------------------
# One entry per detection feature map, from the 33x33 grid down to 3x3.
feature_map_sizes = [[33, 33], [17, 17], [9, 9], [5, 5], [3, 3]]
# Two anchor scales per feature map (presumably fractions of the model input
# size -- confirm against utils.anchor_generator).
anchor_sizes = [
    [0.04, 0.056],
    [0.08, 0.11],
    [0.16, 0.22],
    [0.32, 0.45],
    [0.64, 0.72],
]
# Every feature map uses the same three aspect ratios.
# NOTE: `* n` repeats one shared inner list, as the original setup did.
anchor_ratios = [[1, 0.62, 0.42]] * len(feature_map_sizes)

anchors = generate_anchors(feature_map_sizes, anchor_sizes, anchor_ratios)
# Inference always runs with batch size 1 and the model's box output is
# [1, N, 4], so give the anchors a matching leading batch axis:
# [anchor_num, 4] -> [1, anchor_num, 4].
anchors_exp = np.expand_dims(anchors, axis=0)

# Model class index -> human-readable label.
id2class = dict(enumerate(('Mask', 'NoMask')))
def inference(image,
              conf_thresh=0.5,
              iou_thresh=0.4,
              target_shape=(160, 160),
              draw_result=True,
              show_result=True
              ):
    '''
    Main function of detection inference: find faces in ``image`` and
    classify each one as Mask / NoMask.

    :param image: 3D numpy array of the image, indexed (height, width, channel).
        When ``draw_result`` is true, boxes and labels are drawn onto this
        array in place.
    :param conf_thresh: the min threshold of classification probability.
    :param iou_thresh: the IoU threshold of NMS.
    :param target_shape: the model input size, passed to ``cv2.resize``
        (OpenCV interprets it as (width, height)).
    :param draw_result: whether to draw bounding boxes and labels on the image.
    :param show_result: whether to display the image with PIL.
    :return: list with one ``[class_id, conf, xmin, ymin, xmax, ymax]`` entry
        per detection kept after NMS. ``class_id`` is a plain ``int`` key of
        ``id2class``, ``conf`` a ``float``, and the coordinates are ``int``
        positions scaled to the original image size.
    '''
    output_info = []
    height, width, _ = image.shape
    image_resized = cv2.resize(image, target_shape)
    image_np = image_resized / 255.0  # normalize pixel values to 0~1
    image_exp = np.expand_dims(image_np, axis=0)
    y_bboxes_output, y_cls_output = tf_inference(sess, graph, image_exp)
    # Remove the batch dimension; the batch is always 1 for inference.
    y_bboxes = decode_bbox(anchors_exp, y_bboxes_output)[0]
    y_cls = y_cls_output[0]
    # To speed up, run single-class NMS on each box's best class score
    # instead of a separate NMS pass per class.
    bbox_max_scores = np.max(y_cls, axis=1)
    bbox_max_score_classes = np.argmax(y_cls, axis=1)
    # keep_idxs are the indices of the boxes that survive NMS.
    keep_idxs = single_class_non_max_suppression(y_bboxes,
                                                 bbox_max_scores,
                                                 conf_thresh=conf_thresh,
                                                 iou_thresh=iou_thresh,
                                                 )
    for idx in keep_idxs:
        conf = float(bbox_max_scores[idx])
        # Use a plain int, not np.int64: inside a list, str() uses repr(),
        # which is "np.int64(0)" on NumPy >= 2 and would leak into the text
        # that do_POST builds from str(output_info).
        class_id = int(bbox_max_score_classes[idx])
        bbox = y_bboxes[idx]
        print(class_id)
        if class_id == 0:
            print("True this person has mask")
        # Scale by the original image size and clip the coordinates so they
        # do not exceed the image boundary.
        xmin = max(0, int(bbox[0] * width))
        ymin = max(0, int(bbox[1] * height))
        xmax = min(int(bbox[2] * width), width)
        ymax = min(int(bbox[3] * height), height)
        if draw_result:
            # Green for 'Mask', red for 'NoMask' (assuming RGB channel order).
            color = (0, 255, 0) if class_id == 0 else (255, 0, 0)
            cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
            cv2.putText(image, "%s: %.2f" % (id2class[class_id], conf),
                        (xmin + 2, ymin - 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, color)
        output_info.append([class_id, conf, xmin, ymin, xmax, ymax])

    if show_result:
        Image.fromarray(image).show()
    return output_info
# TCP port the HTTP server listens on.
PORT: int = 8000
class MyHandler(http.server.BaseHTTPRequestHandler):
    """HTTP handler exposing the face-mask detector.

    * ``GET`` returns a small HTML test page echoing the requested path.
    * ``HEAD`` returns only the status line and headers of that page.
    * ``POST`` takes an encoded image as the raw request body and replies
      with the detections from ``inference`` as plain text.
    """

    def do_HEAD(self):
        """Send the status line and headers of the GET test page, no body."""
        self.send_response(200)
        self.send_header("Content-type", "text/html")
        self.end_headers()

    def do_GET(self):
        """Serve a tiny HTML test page that echoes the requested path."""
        self.send_response(200)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        # If someone went to "http://something.somewhere.net/foo/bar/",
        # then self.path equals "/foo/bar/". The path is client-controlled,
        # so HTML-escape it before echoing it back (prevents reflected XSS).
        safe_path = html.escape(self.path).encode('utf8')
        self.wfile.write(b"<html><head><title>Title goes here.</title></head>")
        self.wfile.write(b"<body><p>This is a test.</p>")
        self.wfile.write(b"<p>You accessed path: %b</p>" % safe_path)
        self.wfile.write(b"</body></html>")

    def do_POST(self):
        """Run mask detection on an image sent as the raw request body.

        Replies 411 when Content-Length is missing, 400 when it is invalid or
        the body is not an image PIL can decode, and 500 when inference
        raises. On success the body is ``str()`` of the list returned by
        ``inference`` (one ``[class_id, conf, xmin, ymin, xmax, ymax]`` entry
        per detected face).
        """
        raw_length = self.headers.get('Content-Length')
        if raw_length is None:
            self.send_error(411, 'Content-Length header required')
            return
        try:
            content_length = int(raw_length)
        except ValueError:
            content_length = -1
        if content_length < 0:
            self.send_error(400, 'Invalid Content-Length header')
            return
        body = self.rfile.read(content_length)

        try:
            # Force 3-channel RGB: PNG uploads may be RGBA or palette images
            # and grayscale images decode to 2-D arrays, but inference()
            # unpacks image.shape into exactly three values.
            image = Image.open(BytesIO(body)).convert('RGB')
        except (OSError, SyntaxError, ValueError):
            self.send_error(400, 'Request body is not a readable image')
            return
        img = np.array(image)

        try:
            # NOTE(review): show_result=True opens an image viewer on the
            # server for every request; disable it when running headless.
            inf_out = inference(img, show_result=True, target_shape=(260, 260))
        except Exception:
            # Request boundary: answer the client, then re-raise so the
            # server's error handler logs the traceback.
            self.send_error(500, 'Inference failed')
            raise

        payload = bytes(str(inf_out), encoding='utf8')
        # Headers go out only after processing succeeded, so a failure can
        # no longer be reported to the client as an empty 200 response.
        self.send_response(200)
        self.send_header('Content-Type', 'text/plain; charset=utf-8')
        self.send_header('Content-Length', str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)
if __name__ == '__main__':
    # Bind on all interfaces so other machines can post images; use
    # ('localhost', PORT) instead to accept local connections only.
    server = http.server.HTTPServer(('0.0.0.0', PORT), MyHandler)
    print('Started http server')
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print('^C received, shutting down server')
    finally:
        # server_close() is the server's own cleanup hook; it releases the
        # listening socket on every exit path, not only on Ctrl-C.
        server.server_close()