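"""Generates live action predictions for a video with a pre-trained classifier.

People are detected with OpenPose and tracked across frames; every 20th
frame, each recently updated track is chunked and classified, and the
resulting predictions are drawn onto the video and saved to disk.
"""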
import argparse
import copy
import logging
import os
from shutil import copyfile
from time import time

# sklearn.externals.joblib was deprecated in scikit-learn 0.21 and removed in
# 0.23; the standalone joblib package provides the same load/dump API.
import joblib
import numpy as np

from action_recognition.tracker import Tracker, TrackVisualiser
from action_recognition.detector import CaffeOpenpose
from action_recognition.analysis import PostProcessor, ChunkVisualiser
from action_recognition import transforms

def main(args):
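    """Tracks people in the video and, every 20th frame, predicts the
    action of each recently updated track, drawing and saving the
    predictions that pass the confidence threshold.
    """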
    os.makedirs(args.out_directory, exist_ok=True)

    _, video_ending = os.path.splitext(args.video)
    # Copy the video file so we can create multiple different videos
    # with it as a base simultaneously.
    tmp_video_file = "output/tmp" + video_ending
    copyfile(args.video, tmp_video_file)

    classifier = joblib.load(args.classifier)
    detector = CaffeOpenpose(args.model_path)
    tracker = Tracker(detector, out_dir=args.out_directory)

    logging.info("Classes: {}".format(classifier.classes_))

    valid_predictions = []
    track_people_start = time()
    for tracks, img, current_frame in tracker.video_generator(args.video, args.draw_frames):
        # Don't predict on every frame; not enough has changed between
        # consecutive frames for it to be valuable.
        if current_frame % 20 != 0 or len(tracks) <= 0:
            write_predictions(valid_predictions, img)
            continue

        # We only care about recently updated tracks.
        tracks = [track for track in tracks
                  if track.recently_updated(current_frame)]

        track_people_time = time() - track_people_start
        logging.debug("Number of tracks: {}".format(len(tracks)))

        predict_people_start = time()
        valid_predictions = predict(tracks, classifier, current_frame, args.confidence_threshold)
        predict_people_time = time() - predict_people_start

        write_predictions(valid_predictions, img)
        save_predictions(valid_predictions, args.video, tmp_video_file, args.out_directory)

        logging.info("Predict time: {:.3f}, Track time: {:.3f}".format(
            predict_people_time, track_people_time))
        track_people_start = time()

def predict(tracks, classifier, current_frame, confidence_threshold):
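    """Classifies the action of each track and returns the predictions
    that pass the confidence threshold, together with any positive
    "Has not stopped" predictions.
    """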
    # Only extract the latest frames, as we don't want to copy
    # too much data here, and we've already predicted for the rest.
    processor = PostProcessor()
    processor.tracks = [copy.deepcopy(t.copy(-50)) for t in tracks]
    processor.post_process_tracks()

    predictions = [predict_per_track(t, classifier) for t in processor.tracks]

    valid_predictions = filter_bad_predictions(
        predictions, confidence_threshold, classifier.classes_)
    save_predictions_to_track(predictions, classifier.classes_, tracks, current_frame)

    no_stop_predictions = [predict_no_stop(track, confidence_threshold)
                           for track in tracks]
    for t in [t for p, t in no_stop_predictions if p]:
        valid_predictions.append(t)

    log_predictions(predictions, no_stop_predictions, classifier.classes_)

    return valid_predictions

def predict_per_track(track, classifier):
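    """Predicts class probabilities for a single track by chunking its
    most recent frames at several window sizes and taking, per class,
    the highest probability among the chunks.
    """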
    all_chunks = []
    all_frames = []
    # Chunk the track at several (frames_per_chunk, overlap) window sizes
    # and keep only the most recent chunk of each size.
    divisions = [(50, 0), (30, 10), (25, 0), (20, 5)]
    for frames_per_chunk, overlap in divisions:
        chunks, chunk_frames = track.divide_into_chunks(frames_per_chunk, overlap)
        if len(chunks) > 0:
            all_chunks.append(chunks[-1])
            all_frames.append(chunk_frames[-1])

    if len(all_chunks) > 0:
        predictions = classifier.predict_proba(all_chunks)
        # Element-wise maximum over the chunk predictions, i.e. the
        # highest probability per class among the window sizes.
        best_prediction = np.amax(predictions, axis=0)
        return all_chunks[0], all_frames[0], best_prediction
    else:
        return None, None, [0] * len(classifier.classes_)

def write_predictions(valid_predictions, img):
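    """Draws each prediction's label and confidence onto the image at
    the predicted position.
    """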
    for label, confidence, position, _, _ in valid_predictions:
        TrackVisualiser().draw_text(img, "{}: {:.3f}".format(label, confidence), position)

def save_predictions(valid_predictions, video_name, video, out_directory):
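    """Writes a video scene to disk for each valid prediction."""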
    for i, (label, _, _, chunk, frames) in enumerate(valid_predictions):
        write_chunk_to_file(video_name, video, frames, chunk, label, out_directory, i)

def filter_bad_predictions(predictions, threshold, classes):
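    """Keeps only predictions whose best class confidence exceeds the
    threshold, attaching the position of the first keypoint in the
    chunk's last frame for drawing.
    """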
    valid_predictions = []
    for chunk, frames, prediction in predictions:
        label, confidence = get_best_pred(prediction, classes)
        if confidence > threshold:
            # np.int was removed in NumPy 1.24; the builtin int works here.
            position = tuple(chunk[-1, 0, :2].astype(int))
            prediction_tuple = (label, confidence, position, chunk, frames)
            valid_predictions.append(prediction_tuple)

    return valid_predictions

def save_predictions_to_track(predictions, classes, tracks, current_frame):
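    """Records each track's best label and confidence on the track
    itself so that classifier_predict_no_stop can inspect the history.
    """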
    for t, (_, _, prediction) in zip(tracks, predictions):
        label, confidence = get_best_pred(prediction, classes)
        t.add_prediction(label, confidence, current_frame)

def get_best_pred(prediction, classes):
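    """Returns the most probable class label and its confidence."""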
    best_pred_i = np.argmax(prediction)
    confidence = prediction[best_pred_i]
    label = classes[best_pred_i]
    return label, confidence

def write_chunk_to_file(video_name, video, frames, chunk, label, out_dir, i):
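    """Renders a chunk to an .avi file named after the source video,
    the chunk's last frame number, its index, and the predicted label.
    """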
    _, video_name = os.path.split(video_name)
    video_name, _ = os.path.splitext(video_name)
    file_name = "{}-{}-{}-{}.avi".format(video_name, frames[-1], i, label)
    out_file = os.path.join(out_dir, file_name)
    ChunkVisualiser().chunk_to_video_scene(video, chunk, out_file, frames, label)

def predict_no_stop(track, confidence_threshold):
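    """Predicts whether a person has not stopped moving. Returns a flag
    for whether the prediction passes the confidence threshold, together
    with a prediction tuple for drawing and saving.
    """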
    if len(track) < 50:
        return False, ()

    classifier_prediction = classifier_predict_no_stop(track, confidence_threshold)

    # Copy the last 200 frames to a single chunk for visualisation.
    track = track.copy(-200)
    chunks, chunk_frames = track.divide_into_chunks(len(track) - 1, 0)

    position = tuple(chunks[0, -1, 1, :2].astype(int))
    prediction_tuple = ("Has not stopped", classifier_prediction,
                        position, chunks[0], chunk_frames[0])

    return classifier_prediction > confidence_threshold, prediction_tuple

def classifier_predict_no_stop(track, confidence_threshold):
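    """Estimates how consistently a track is classified as 'moving' by
    counting confident 'moving' labels among the 20 most recent
    predictions and dividing by the total number of predictions, e.g.
    8 confident 'moving' labels out of 10 predictions gives 0.8.
    """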
    # If there haven't been that many predictions yet, we can't say anything.
    if len(track.predictions) < 5:
        return 0

    number_moving = sum(prediction['label'] == 'moving' and
                        prediction['confidence'] > confidence_threshold
                        for prediction in list(track.predictions.values())[-20:])

    return number_moving / len(track.predictions)

def log_predictions(predictions, no_stop_predictions, classes):
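    """Logs the best prediction for every track, plus any positive
    no-stop predictions.
    """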
    prints = []
    for _, _, prediction in predictions:
        prints.append(get_best_pred(prediction, classes))
    if no_stop_predictions:
        for label, confidence, _, _, _ in [t for p, t in no_stop_predictions if p]:
            prints.append((label, confidence))

    logging.info("Predictions: " + ", ".join(
        ["{}: {:.3f}".format(*t) for t in prints]))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description=('Generates action predictions live given a video and a pre-trained classifier. '
                     'Uses Tracker.video_generator, which yields the current tracks on every frame, '
                     'from which it predicts the class of action using the pre-trained classifier. '
                     'To get a better prediction, it takes the latest 50, 30, 25, and 20 frames '
                     'as chunks and selects, per class, the likeliest prediction among the four '
                     'chunk predictions. It also predicts whether a person has not stopped moving '
                     '(e.g. if they are moving through a self-checkout area without scanning '
                     'anything) by checking what proportion of the latest identified actions '
                     'for a track/person is "moving".'))
    parser.add_argument('--classifier', type=str,
                        help='Path to a .pkl file with a pre-trained action recognition classifier.')
    parser.add_argument('--video', type=str,
                        help='Path to the video file to predict actions for.')
    parser.add_argument('--model-path', type=str, default='../openpose/models/',
                        help='The model path for OpenPose.')
    parser.add_argument('--confidence-threshold', type=float, default=0.6,
                        help='Threshold for how confident the model should be in each prediction.')
    parser.add_argument('--draw-frames', action='store_true',
                        help='Flag for whether the frames with identified tracks should be drawn.')
    parser.add_argument('--out-directory', type=str, default='output/prediction',
                        help=('Output directory to where the processed video and identified '
                              'chunks are saved.'))
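
    # Example invocation (file paths are illustrative):
    #   python live_prediction.py \
    #       --classifier output/classifier.pkl \
    #       --video input/video.avi \
    #       --confidence-threshold 0.7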
    logging.basicConfig(level=logging.INFO)

    args = parser.parse_args()
    main(args)