-
Notifications
You must be signed in to change notification settings - Fork 110
/
main_video_person_reid.py
executable file
·297 lines (250 loc) · 11.8 KB
/
main_video_person_reid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
from __future__ import print_function, absolute_import
import os
import sys
import time
import datetime
import argparse
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.optim import lr_scheduler
import data_manager
from video_loader import VideoDataset
import transforms as T
import models
from models import resnet3d
from losses import CrossEntropyLabelSmooth, TripletLoss
from utils import AverageMeter, Logger, save_checkpoint
from eval_metrics import evaluate
from samplers import RandomIdentitySampler
def _build_parser():
    """Create the command-line parser for this script.

    Options are grouped as dataset, optimization, architecture and
    miscellaneous settings. The choices for ``--dataset`` are taken from
    ``data_manager.get_names()``.
    """
    p = argparse.ArgumentParser(
        description='Train video model with cross entropy loss')

    # ---- Datasets ----
    p.add_argument(
        '-d', '--dataset', type=str, default='mars',
        choices=data_manager.get_names())
    p.add_argument(
        '-j', '--workers', default=4, type=int,
        help="number of data loading workers (default: 4)")
    p.add_argument(
        '--height', type=int, default=224,
        help="height of an image (default: 224)")
    p.add_argument(
        '--width', type=int, default=112,
        help="width of an image (default: 112)")
    p.add_argument(
        '--seq-len', type=int, default=4,
        help="number of images to sample in a tracklet")

    # ---- Optimization options ----
    p.add_argument(
        '--max-epoch', default=800, type=int,
        help="maximum epochs to run")
    p.add_argument(
        '--start-epoch', default=0, type=int,
        help="manual epoch number (useful on restarts)")
    p.add_argument(
        '--train-batch', default=32, type=int,
        help="train batch size")
    p.add_argument(
        '--test-batch', default=1, type=int,
        help="has to be 1")
    p.add_argument(
        '--lr', '--learning-rate', default=0.0003, type=float,
        help="initial learning rate, use 0.0001 for rnn, use 0.0003 for pooling and attention")
    p.add_argument(
        '--stepsize', default=200, type=int,
        help="stepsize to decay learning rate (>0 means this is enabled)")
    p.add_argument(
        '--gamma', default=0.1, type=float,
        help="learning rate decay")
    p.add_argument(
        '--weight-decay', default=5e-04, type=float,
        help="weight decay (default: 5e-04)")
    p.add_argument(
        '--margin', type=float, default=0.3,
        help="margin for triplet loss")
    p.add_argument(
        '--num-instances', type=int, default=4,
        help="number of instances per identity")
    p.add_argument(
        '--htri-only', action='store_true', default=False,
        help="if this is True, only htri loss is used in training")

    # ---- Architecture ----
    p.add_argument(
        '-a', '--arch', type=str, default='resnet50tp',
        help="resnet503d, resnet50tp, resnet50ta, resnetrnn")
    p.add_argument(
        '--pool', type=str, default='avg', choices=['avg', 'max'])

    # ---- Miscs ----
    p.add_argument(
        '--print-freq', type=int, default=80,
        help="print frequency")
    p.add_argument(
        '--seed', type=int, default=1,
        help="manual seed")
    p.add_argument(
        '--pretrained-model', type=str,
        default='/home/jiyang/Workspace/Works/video-person-reid/3dconv-person-reid/pretrained_models/resnet-50-kinetics.pth',
        help='need to be set for resnet3d models')
    p.add_argument(
        '--evaluate', action='store_true',
        help="evaluation only")
    p.add_argument(
        '--eval-step', type=int, default=50,
        help="run evaluation for every N epochs (set to -1 to test after training)")
    p.add_argument(
        '--save-dir', type=str, default='log')
    p.add_argument(
        '--use-cpu', action='store_true',
        help="use cpu")
    p.add_argument(
        '--gpu-devices', default='0', type=str,
        help='gpu device ids for CUDA_VISIBLE_DEVICES')
    return p


# Parsed once at module load; main(), train() and test() read the
# resulting namespace through the module-level ``args``.
parser = _build_parser()
args = parser.parse_args()
def _load_pretrained_3d_weights(model, checkpoint_path):
    """Copy pretrained weights from ``checkpoint_path`` into ``model``.

    The checkpoint is expected to hold a ``'state_dict'`` entry. Keys
    containing ``'fc'`` are skipped, and a leading ``'module.'`` prefix is
    removed from the remaining keys when present. Loading is non-strict,
    so parameters missing on either side are ignored.

    Raises:
        IOError: if ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise IOError("Can't find pretrained model: {}".format(checkpoint_path))
    print("Loading checkpoint from '{}'".format(checkpoint_path))
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    # map_location='cpu' keeps GPU-saved checkpoints loadable on CPU-only
    # hosts; the model is moved to the GPU later if one is used.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    prefix = 'module.'
    state_dict = {}
    for key, value in checkpoint['state_dict'].items():
        if 'fc' in key:
            continue
        # Strip the prefix only when it is really there; a blind
        # partition() would map every un-prefixed key to ''.
        name = key[len(prefix):] if key.startswith(prefix) else key
        state_dict[name] = value
    model.load_state_dict(state_dict, strict=False)


def main():
    """Train and/or evaluate a video re-ID model as configured by ``args``.

    Steps: seed torch, select the device, redirect stdout to a log file
    under ``args.save_dir``, build the train/query/gallery loaders, create
    the model, losses, Adam optimizer and optional StepLR scheduler. With
    ``args.evaluate`` set, a single evaluation pass is run and the function
    returns. Otherwise the model is trained from ``args.start_epoch`` to
    ``args.max_epoch``; it is evaluated and a checkpoint is saved every
    ``args.eval_step`` epochs (when positive) and after the final epoch.
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    log_name = 'log_test.txt' if args.evaluate else 'log_train.txt'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_dataset(name=args.dataset)

    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = bool(use_gpu)

    trainloader = DataLoader(
        VideoDataset(dataset.train, seq_len=args.seq_len, sample='random',
                     transform=transform_train),
        sampler=RandomIdentitySampler(dataset.train, num_instances=args.num_instances),
        batch_size=args.train_batch, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=True,
    )
    queryloader = DataLoader(
        VideoDataset(dataset.query, seq_len=args.seq_len, sample='dense',
                     transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )
    galleryloader = DataLoader(
        VideoDataset(dataset.gallery, seq_len=args.seq_len, sample='dense',
                     transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    if args.arch == 'resnet503d':
        model = resnet3d.resnet50(
            num_classes=dataset.num_train_pids, sample_width=args.width,
            sample_height=args.height, sample_duration=args.seq_len)
        _load_pretrained_3d_weights(model, args.pretrained_model)
    else:
        model = models.init_model(
            name=args.arch, num_classes=dataset.num_train_pids,
            loss={'xent', 'htri'})
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion_xent = CrossEntropyLabelSmooth(
        num_classes=dataset.num_train_pids, use_gpu=use_gpu)
    criterion_htri = TripletLoss(margin=args.margin)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Learning-rate decay is optional: a non-positive stepsize disables it.
    scheduler = None
    if args.stepsize > 0:
        scheduler = lr_scheduler.StepLR(
            optimizer, step_size=args.stepsize, gamma=args.gamma)
    start_epoch = args.start_epoch

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        test(model, queryloader, galleryloader, args.pool, use_gpu)
        return

    start_time = time.time()
    best_rank1 = -np.inf
    if args.arch == 'resnet503d':
        # The cudnn autotuner enabled above is switched off for the 3D model.
        cudnn.benchmark = False

    for epoch in range(start_epoch, args.max_epoch):
        print("==> Epoch {}/{}".format(epoch + 1, args.max_epoch))

        train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu)

        if scheduler is not None:
            scheduler.step()

        # Evaluate on the periodic schedule, and always after the last epoch.
        periodic_eval = args.eval_step > 0 and (epoch + 1) % args.eval_step == 0
        last_epoch = (epoch + 1) == args.max_epoch
        if periodic_eval or last_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu)
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1

            # On GPU the model is wrapped in DataParallel; save the inner
            # module so checkpoint keys carry no 'module.' prefix.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint({
                'state_dict': state_dict,
                'rank1': rank1,
                'epoch': epoch,
            }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
    """Run one training epoch over ``trainloader``.

    Args:
        model: network called as ``model(imgs)`` and returning
            ``(outputs, features)`` in training mode.
        criterion_xent: classification loss applied to ``(outputs, pids)``.
        criterion_htri: triplet loss applied to ``(features, pids)``.
        optimizer: optimizer stepped once per batch.
        trainloader: iterable yielding ``(imgs, pids, _)`` batches.
        use_gpu: when true, inputs and labels are moved to CUDA.

    Reads the module-level ``args``: ``args.htri_only`` selects the
    triplet-only objective, and ``args.print_freq`` sets how often (in
    batches) the running loss is printed.
    """
    model.train()
    losses = AverageMeter()

    for batch_idx, (imgs, pids, _) in enumerate(trainloader):
        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()
        outputs, features = model(imgs)

        if args.htri_only:
            # only use hard triplet loss to train the network
            loss = criterion_htri(features, pids)
        else:
            # combine hard triplet loss with cross entropy loss
            xent_loss = criterion_xent(outputs, pids)
            htri_loss = criterion_htri(features, pids)
            loss = xent_loss + htri_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The loss is a 0-dim tensor; .item() extracts the Python float
        # (indexing it with [0] is rejected by current PyTorch).
        losses.update(loss.item(), pids.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            print("Batch {}/{}\t Loss {:.6f} ({:.6f})".format(
                batch_idx + 1, len(trainloader), losses.val, losses.avg))
def _extract_features(model, loader, pool, use_gpu):
    """Compute one pooled feature vector per batch of ``loader``.

    Each batch is unpacked as a 6-D tensor ``(b, n, s, c, h, w)`` with
    ``b == 1``. The ``n`` clips are passed through the model together and
    their features are reduced over the clip axis: mean when
    ``pool == 'avg'``, element-wise max otherwise.

    Returns:
        ``(features, pids, camids)``: a stacked CPU tensor with one row per
        batch, and two numpy arrays built from the loader's labels.

    Raises:
        ValueError: if a batch holds more than one tracklet (``b != 1``).
    """
    feats, all_pids, all_camids = [], [], []
    # Inference only: skip autograd bookkeeping to keep memory bounded.
    with torch.no_grad():
        for imgs, pids, camids in loader:
            if use_gpu:
                imgs = imgs.cuda()
            # b=1, n=number of clips, s=frames per clip
            b, n, s, c, h, w = imgs.size()
            if b != 1:
                # The reshape to (n, -1) below would silently mix tracklets.
                raise ValueError(
                    "test batch size has to be 1, got {}".format(b))
            clip_feats = model(imgs.view(b * n, s, c, h, w)).view(n, -1)
            if pool == 'avg':
                pooled = torch.mean(clip_feats, 0)
            else:
                pooled, _ = torch.max(clip_feats, 0)
            feats.append(pooled.cpu())
            all_pids.extend(pids)
            all_camids.extend(camids)
    return torch.stack(feats), np.asarray(all_pids), np.asarray(all_camids)


def test(model, queryloader, galleryloader, pool, use_gpu, ranks=(1, 5, 10, 20)):
    """Evaluate ``model`` on the query/gallery split and print the metrics.

    Query and gallery tracklets are embedded with the same temporal pooling
    (``pool``: ``'avg'`` for mean, anything else for max). Squared Euclidean
    distances between all query/gallery pairs are passed to ``evaluate``
    together with the person and camera ids.

    Args:
        model: network returning clip features in eval mode.
        queryloader: loader over query tracklets (batch size 1).
        galleryloader: loader over gallery tracklets (batch size 1).
        pool: temporal pooling mode, ``'avg'`` or ``'max'``.
        use_gpu: when true, inputs are moved to CUDA before the forward pass.
        ranks: CMC ranks to print (1-based).

    Returns:
        The rank-1 CMC score, ``cmc[0]``.
    """
    model.eval()

    qf, q_pids, q_camids = _extract_features(model, queryloader, pool, use_gpu)
    print("Extracted features for query set, obtained {}-by-{} matrix".format(
        qf.size(0), qf.size(1)))

    gf, g_pids, g_camids = _extract_features(model, galleryloader, pool, use_gpu)
    print("Extracted features for gallery set, obtained {}-by-{} matrix".format(
        gf.size(0), gf.size(1)))

    print("Computing distance matrix")
    m, n = qf.size(0), gf.size(0)
    # ||q - g||^2 = ||q||^2 + ||g||^2 - 2 * q.g, for all pairs at once.
    distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
    distmat = distmat.numpy()

    print("Computing CMC and mAP")
    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)

    print("Results ----------")
    print("mAP: {:.1%}".format(mAP))
    print("CMC curve")
    for r in ranks:
        print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
    print("------------------")

    return cmc[0]
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()