-
Notifications
You must be signed in to change notification settings - Fork 9
/
generateAnchorDecode.cu
556 lines (451 loc) · 18.4 KB
/
generateAnchorDecode.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
#include "generateAnchorDecode.h"
#include <cstring>
using namespace nvinfer1;
using nvinfer1::GenerateAnchorDecodePlugin;
using nvinfer1::GenerateAnchorDecodePluginCreator;
using namespace std;
// Checks the result of a CUDA runtime call.
// The argument is evaluated exactly once and cached, so a failing call (or
// kernel launch wrapper) is never re-executed while the diagnostic is built.
// On a non-zero status the message is printed to std::cout and the process
// aborts. The do/while(0) wrapper makes the macro behave as one statement.
#define checkCudaErrors(status)                                                   \
    do                                                                            \
    {                                                                             \
        const auto cuda_status_ = (status);                                       \
        if (cuda_status_ != 0)                                                    \
        {                                                                         \
            std::cout << "Cuda failure: " << cudaGetErrorString(cuda_status_)     \
                      << " at line " << __LINE__                                  \
                      << " in file " << __FILE__                                  \
                      << " error status: " << cuda_status_                        \
                      << std::endl;                                               \
            abort();                                                              \
        }                                                                         \
    } while (0)
// Generic status check: evaluates `status` once; any non-zero value is
// reported on std::cerr and the process aborts.
#define CHECK(status)                                                    \
    do                                                                   \
    {                                                                    \
        auto check_result_ = (status);                                   \
        if (check_result_ == 0)                                          \
        {                                                                \
            break;                                                       \
        }                                                                \
        std::cerr << "Cuda failure: " << check_result_ << std::endl;     \
        abort();                                                         \
    } while (0)
// Byte alignment used by the workspace pointer/size helpers in this file.
#define CUDA_MEM_ALIGN 256
// Version and type strings reported by both the plugin and its creator.
static const char* PLUGIN_VERSION = "1";
static const char* PLUGIN_NAME = "GenerateAnchorDecodePlugin";
// Definitions of the creator's static members.
PluginFieldCollection GenerateAnchorDecodePluginCreator::mFC = {};
std::vector<PluginField> GenerateAnchorDecodePluginCreator::mPluginAttributes{};
// Serialization helper: copies the bytes of `val` to the position `buffer`
// points at, then advances `buffer` by sizeof(T).
// Uses memcpy instead of a cast-and-store so the write is valid even when the
// cursor is not suitably aligned for T. T is expected to be trivially copyable.
template <typename T>
void writeToBuffer(char*& buffer, const T& val)
{
    std::memcpy(buffer, &val, sizeof(T));
    buffer += sizeof(T);
}
// Deserialization helper: reads one T from the position `buffer` points at,
// advances `buffer` by sizeof(T) and returns the value.
// Uses memcpy instead of dereferencing a cast pointer so the read is valid
// even when the cursor is not suitably aligned for T. T is expected to be
// trivially copyable and default constructible.
template <typename T>
T readFromBuffer(const char*& buffer)
{
    T val;
    std::memcpy(&val, buffer, sizeof(T));
    buffer += sizeof(T);
    return val;
}
// Mimic np.round (round half to even), as used by the voxel generator in the
// spconv implementation.
// Works from the mathematical floor of x so that negative inputs round
// correctly as well (e.g. -1.5 -> -2, -1.2 -> -1); truncation alone rounds
// toward zero and gives wrong answers below zero.
// The value must fit in an int.
int np_round(float x) {
    // Truncate, then step down by one when truncation moved toward zero.
    int lower = static_cast<int>(x);
    if (static_cast<float>(lower) > x) {
        --lower;
    }
    // Fractional part in [0, 1); computed in double so the subtraction is exact.
    const double frac = static_cast<double>(x) - static_cast<double>(lower);
    if (frac > 0.5) {
        return lower + 1;
    }
    if (frac < 0.5) {
        return lower;
    }
    // Exactly halfway: choose whichever neighbour is even.
    return (lower % 2 == 0) ? lower : lower + 1;
}
// Rounds `ptr` up to the next address that is a multiple of `to`.
// An address that is already a multiple of `to` is returned unchanged.
int8_t* alignPtr(int8_t* ptr, uintptr_t to)
{
    const uintptr_t address = (uintptr_t) ptr;
    const uintptr_t remainder = address % to;
    const uintptr_t aligned = (remainder == 0) ? address : address + (to - remainder);
    return (int8_t*) aligned;
}
// Returns the start of the workspace that follows one of
// `previousWorkspaceSize` bytes beginning at `ptr`: the end address of the
// previous workspace rounded up to CUDA_MEM_ALIGN.
int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize)
{
    const uintptr_t end_of_previous = (uintptr_t) ptr + previousWorkspaceSize;
    return alignPtr((int8_t*) end_of_previous, CUDA_MEM_ALIGN);
}
// Sums `count` workspace sizes, rounding each one up to a multiple of
// CUDA_MEM_ALIGN before adding it to the total.
size_t calculateTotalWorkspaceSize(size_t* workspaces, int count)
{
    size_t padded_sum = 0;
    for (int idx = 0; idx < count; ++idx)
    {
        const size_t bytes = workspaces[idx];
        const size_t excess = bytes % CUDA_MEM_ALIGN;
        padded_sum += bytes;
        if (excess != 0)
        {
            // Pad this workspace up to the next alignment boundary.
            padded_sum += CUDA_MEM_ALIGN - excess;
        }
    }
    return padded_sum;
}
// Constructs the plugin from explicit parameter values; every argument is
// copied into the member of the same name (with a trailing underscore).
GenerateAnchorDecodePlugin::GenerateAnchorDecodePlugin(float min_x_range,float max_x_range,float min_y_range,float max_y_range,
    float min_z_range,float max_z_range, int feature_map_height, int feature_map_width, float car_length,
    float car_width, float car_height, float direction_angle_0, float direction_angle_1,int direction_angle_num)
    : min_x_range_(min_x_range)
    , max_x_range_(max_x_range)
    , min_y_range_(min_y_range)
    , max_y_range_(max_y_range)
    , min_z_range_(min_z_range)
    , max_z_range_(max_z_range)
    , feature_map_height_(feature_map_height)
    , feature_map_width_(feature_map_width)
    , car_length_(car_length)
    , car_width_(car_width)
    , car_height_(car_height)
    , direction_angle_0_(direction_angle_0)
    , direction_angle_1_(direction_angle_1)
    , direction_angle_num_(direction_angle_num)
{
}
// Constructs the plugin from a serialized byte stream.
// Fields are read back-to-back in a fixed order: six range floats, two
// feature-map ints, three car-size floats, two direction-angle floats and the
// direction-angle count. `length` is not consulted.
GenerateAnchorDecodePlugin::GenerateAnchorDecodePlugin(const void* data, size_t length)
{
    const char* cursor = reinterpret_cast<const char*>(data);

    // Point-cloud range.
    min_x_range_ = readFromBuffer<float>(cursor);
    max_x_range_ = readFromBuffer<float>(cursor);
    min_y_range_ = readFromBuffer<float>(cursor);
    max_y_range_ = readFromBuffer<float>(cursor);
    min_z_range_ = readFromBuffer<float>(cursor);
    max_z_range_ = readFromBuffer<float>(cursor);

    // Feature-map size.
    feature_map_height_ = readFromBuffer<int>(cursor);
    feature_map_width_ = readFromBuffer<int>(cursor);

    // Anchor box size.
    car_length_ = readFromBuffer<float>(cursor);
    car_width_ = readFromBuffer<float>(cursor);
    car_height_ = readFromBuffer<float>(cursor);

    // Anchor orientations.
    direction_angle_0_ = readFromBuffer<float>(cursor);
    direction_angle_1_ = readFromBuffer<float>(cursor);
    direction_angle_num_ = readFromBuffer<int>(cursor);
}
// Returns a newly allocated plugin carrying the same parameters and the same
// plugin namespace as this one.
IPluginV2DynamicExt* GenerateAnchorDecodePlugin::clone() const noexcept
{
    GenerateAnchorDecodePlugin* const copy = new GenerateAnchorDecodePlugin(
        min_x_range_, max_x_range_,
        min_y_range_, max_y_range_,
        min_z_range_, max_z_range_,
        feature_map_height_, feature_map_width_,
        car_length_, car_width_, car_height_,
        direction_angle_0_, direction_angle_1_, direction_angle_num_);
    copy->setPluginNamespace(mNamespace.c_str());
    return copy;
}
// Reports the shape of the requested output.
// Output 0 is 3-D and copies the first three dimension expressions of input 0.
// This plugin has a single output; for any other index an empty DimsExprs
// (nbDims == 0) is returned so that every path returns a defined value
// (previously control fell off the end of the function).
nvinfer1::DimsExprs GenerateAnchorDecodePlugin::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    nvinfer1::DimsExprs dims{};
    if (outputIndex != 0)
    {
        return dims;
    }

    dims.nbDims = 3;
    dims.d[0] = inputs[0].d[0];
    dims.d[1] = inputs[0].d[1];
    dims.d[2] = inputs[0].d[2];
    return dims;
}
// Accepts only FP32 tensors in linear layout at positions 0 and 1; every
// other position is rejected.
bool GenerateAnchorDecodePlugin::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept
{
    if (pos != 0 && pos != 1)
    {
        return false;
    }
    const PluginTensorDesc& desc = inOut[pos];
    const bool is_float = (desc.type == nvinfer1::DataType::kFLOAT);
    const bool is_linear = (desc.format == TensorFormat::kLINEAR);
    return is_float && is_linear;
}
// Configuration hook; this plugin does nothing here and ignores all arguments.
void GenerateAnchorDecodePlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept
{
}
// Reports the scratch workspace size in bytes; always 0, regardless of the
// tensor descriptions passed in.
size_t GenerateAnchorDecodePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept
{
return 0;
}
// Writes one 7-float anchor per thread into `output_features`.
//
// Launch layout: 1-D grid of 1-D blocks, one thread per anchor. The grid may
// be rounded up to a whole number of blocks; threads whose index is at or past
// feature_map_height * feature_map_width * direction_angle_num exit without
// touching memory (previously those tail threads wrote past the last anchor).
//
// Anchor index order: direction angle varies fastest, then the feature-map
// column (x), then the row (y).
//
// Per-anchor layout:
//   [0] x centre  = min_x_range + (col + 0.5) * x stride
//   [1] y centre  = min_y_range + (row + 0.5) * y stride
//   [2] fixed -1.0
//   [3] car_width, [4] car_length, [5] car_height
//   [6] direction_angle_0 for angle slot 0, direction_angle_1 for slot 1,
//       0 for any further slot
// min_z_range / max_z_range are accepted but not used.
__global__ void generate_anchor_kernel(float* output_features,float min_x_range, float max_x_range, float min_y_range,
    float max_y_range, float min_z_range, float max_z_range, int feature_map_height, int feature_map_width,
    float car_length, float car_width, float car_height, int direction_angle_num, float direction_angle_0,
    float direction_angle_1)
{
    const int line_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int anchors_per_row = feature_map_width * direction_angle_num;
    const int anchor_count = feature_map_height * anchors_per_row;
    // Guard the grid tail (also protects the divisions below when the count is 0).
    if (line_idx >= anchor_count)
    {
        return;
    }

    // Cell size of the feature map in point-cloud units; anchors sit at cell centres.
    const float stride_x = (max_x_range - min_x_range) / feature_map_width;
    const float stride_y = (max_y_range - min_y_range) / feature_map_height;
    const float x_start = min_x_range + stride_x / 2;
    const float y_start = min_y_range + stride_y / 2;

    const int col = (line_idx % anchors_per_row) / direction_angle_num;
    const int row = line_idx / anchors_per_row;
    const int angle_slot = line_idx % direction_angle_num;

    float angle = 0.0f;
    if (angle_slot == 0)
    {
        angle = direction_angle_0;
    }
    else if (angle_slot == 1)
    {
        angle = direction_angle_1;
    }

    float* anchor = output_features + line_idx * 7;
    anchor[0] = x_start + col * stride_x;
    anchor[1] = y_start + row * stride_y;
    anchor[2] = -1.0f;
    anchor[3] = car_width;   // w
    anchor[4] = car_length;  // l
    anchor[5] = car_height;  // h
    anchor[6] = angle;
}
// Launches generate_anchor_kernel on `stream` with THREADS_FOR_VOXEL threads
// per block and enough blocks to give every anchor
// (feature_map_height * feature_map_width * direction_angle_num) a thread.
// Returns the result of cudaGetLastError() taken right after the launch.
cudaError_t generate_anchor_launch(float* output_features,float min_x_range, float max_x_range, float min_y_range,
    float max_y_range, float min_z_range, float max_z_range, int feature_map_height, int feature_map_width,
    float car_length, float car_width, float car_height, int direction_angle_num, float direction_angle_0,
    float direction_angle_1, cudaStream_t stream)
{
    const int threads_per_block = THREADS_FOR_VOXEL;
    const int anchor_count = feature_map_height * feature_map_width * direction_angle_num;
    // Ceiling division so a partial last block is still launched.
    const dim3 grid((anchor_count + threads_per_block - 1) / threads_per_block);
    const dim3 block(threads_per_block);

    generate_anchor_kernel<<<grid, block, 0, stream>>>(
        output_features,
        min_x_range, max_x_range,
        min_y_range, max_y_range,
        min_z_range, max_z_range,
        feature_map_height, feature_map_width,
        car_length, car_width, car_height,
        direction_angle_num, direction_angle_0, direction_angle_1);

    return cudaGetLastError();
}
// Decodes box predictions against anchors, in place.
//
// `output_features` holds the anchors on entry (7 floats each:
// xa, ya, za, wa, la, ha, ra) and the decoded boxes on exit.
// `features` holds the predicted residuals in the same 7-float layout
// (xt, yt, zt, wt, lt, ht, rt).
//
//   diagonal = sqrt(la^2 + wa^2)
//   xg = xt * diagonal + xa
//   yg = yt * diagonal + ya
//   zg = zt * ha + za
//   wg = exp(wt) * wa,  lg = exp(lt) * la,  hg = exp(ht) * ha
//   rg = rt + ra
//
// Launch layout: 1-D grid of 1-D blocks, one thread per box. `num_lines` is
// the number of boxes; threads at or past it return immediately so a grid
// rounded up to whole blocks does not read or write out of bounds.
__global__ void decode_kernel(float* features, float* output_features, int num_lines)
{
    const int line_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (line_idx >= num_lines)
    {
        return;
    }

    float* anchor = output_features + line_idx * 7;
    const float* box = features + line_idx * 7;

    // Read the whole anchor before overwriting it with the decoded box.
    const float anchor_x = anchor[0];
    const float anchor_y = anchor[1];
    const float anchor_z = anchor[2];
    const float anchor_w = anchor[3];
    const float anchor_l = anchor[4];
    const float anchor_h = anchor[5];
    const float anchor_r = anchor[6];

    const float diagonal = sqrtf(anchor_l * anchor_l + anchor_w * anchor_w);

    anchor[0] = box[0] * diagonal + anchor_x;
    anchor[1] = box[1] * diagonal + anchor_y;
    anchor[2] = box[2] * anchor_h + anchor_z;
    anchor[3] = expf(box[3]) * anchor_w;
    anchor[4] = expf(box[4]) * anchor_l;
    anchor[5] = expf(box[5]) * anchor_h;
    anchor[6] = box[6] + anchor_r;
}

// Launches decode_kernel on `stream` with THREADS_FOR_VOXEL threads per block
// and enough blocks to cover
// feature_map_height * feature_map_width * direction_angle_num boxes.
// Returns the result of cudaGetLastError() taken right after the launch.
cudaError_t decode_launch(float * features, float* output_features, int feature_map_height, int feature_map_width,
    int direction_angle_num, cudaStream_t stream)
{
    const int threadNum = THREADS_FOR_VOXEL;
    const int num_lines = feature_map_height * feature_map_width * direction_angle_num;
    // Ceiling division; the kernel masks off the surplus threads of the last block.
    dim3 blocks((num_lines + threadNum - 1) / threadNum);
    dim3 threads(threadNum);
    decode_kernel<<<blocks, threads, 0, stream>>>(features, output_features, num_lines);
    return cudaGetLastError();
}
// Runs the plugin on `stream`: zeroes the output, writes the anchors into it
// and then decodes input 0 (box predictions) against those anchors in place.
//
// The launch helpers each handle one sample of
// feature_map_height_ * feature_map_width_ * direction_angle_num_ boxes, so
// they are invoked once per batch element with the pointers advanced by one
// sample. Previously only the first sample was processed and the remaining
// batch entries were left zeroed.
// NOTE(review): this assumes each sample occupies exactly
// height * width * angle_num * 7 contiguous floats in both tensors, which is
// the same assumption the memset size already makes - confirm against the
// network that feeds this plugin.
//
// Returns 0; CUDA failures abort inside checkCudaErrors.
int GenerateAnchorDecodePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream) noexcept
{
    const int batchSize = inputDesc[0].dims.d[0];

    // TRT input (box predictions) and TRT output (anchors, then decoded boxes).
    float* features = const_cast<float*>(static_cast<const float*>(inputs[0]));
    float* output_features = static_cast<float*>(outputs[0]);

    // Sizes are computed in size_t so large batches/feature maps cannot wrap a 32-bit value.
    const size_t floats_per_sample = static_cast<size_t>(feature_map_height_) * static_cast<size_t>(feature_map_width_)
        * static_cast<size_t>(direction_angle_num_) * 7;
    const size_t output_bytes = static_cast<size_t>(batchSize) * floats_per_sample * sizeof(float);

    checkCudaErrors(cudaMemsetAsync(output_features, 0, output_bytes, stream));

    for (int b = 0; b < batchSize; ++b)
    {
        float* sample_in = features + static_cast<size_t>(b) * floats_per_sample;
        float* sample_out = output_features + static_cast<size_t>(b) * floats_per_sample;

        checkCudaErrors(generate_anchor_launch(
            sample_out, min_x_range_, max_x_range_, min_y_range_, max_y_range_, min_z_range_, max_z_range_,
            feature_map_height_, feature_map_width_, car_length_, car_width_, car_height_, direction_angle_num_,
            direction_angle_0_, direction_angle_1_, stream));
        checkCudaErrors(decode_launch(
            sample_in, sample_out, feature_map_height_, feature_map_width_, direction_angle_num_, stream));
    }
    return 0;
}
// The output uses the same data type as input 0; `index` and `nbInputs` are
// not consulted.
nvinfer1::DataType GenerateAnchorDecodePlugin::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept
{
return inputTypes[0];
}
// Returns the plugin type string (PLUGIN_NAME).
const char* GenerateAnchorDecodePlugin::getPluginType() const noexcept
{
return PLUGIN_NAME;
}
// Returns the plugin version string (PLUGIN_VERSION).
const char* GenerateAnchorDecodePlugin::getPluginVersion() const noexcept
{
return PLUGIN_VERSION;
}
// The plugin produces exactly one output tensor.
int GenerateAnchorDecodePlugin::getNbOutputs() const noexcept
{
return 1;
}
// Initialization hook; nothing to set up, always returns 0.
int GenerateAnchorDecodePlugin::initialize() noexcept
{
return 0;
}
// Termination hook; intentionally empty.
void GenerateAnchorDecodePlugin::terminate() noexcept
{
}
// Number of bytes serialize() writes: 11 float fields plus 3 int fields.
size_t GenerateAnchorDecodePlugin::getSerializationSize() const noexcept
{
    const size_t int_bytes = 3 * sizeof(int);
    const size_t float_bytes = 11 * sizeof(float);
    return int_bytes + float_bytes;
}
// Writes the plugin parameters to `buffer` as a packed sequence of 11 floats
// and 3 ints. The field order must match the order in which the deserializing
// constructor reads them back.
// Fix: the sixth field is max_z_range_; max_y_range_ was previously written a
// second time in that slot, so a deserialized plugin received the wrong z range.
void GenerateAnchorDecodePlugin::serialize(void* buffer) const noexcept
{
    char* d = reinterpret_cast<char*>(buffer);

    // Point-cloud range.
    writeToBuffer<float>(d, min_x_range_);
    writeToBuffer<float>(d, max_x_range_);
    writeToBuffer<float>(d, min_y_range_);
    writeToBuffer<float>(d, max_y_range_);
    writeToBuffer<float>(d, min_z_range_);
    writeToBuffer<float>(d, max_z_range_);

    // Feature-map size.
    writeToBuffer<int>(d, feature_map_height_);
    writeToBuffer<int>(d, feature_map_width_);

    // Anchor box size.
    writeToBuffer<float>(d, car_length_);
    writeToBuffer<float>(d, car_width_);
    writeToBuffer<float>(d, car_height_);

    // Anchor orientations.
    writeToBuffer<float>(d, direction_angle_0_);
    writeToBuffer<float>(d, direction_angle_1_);
    writeToBuffer<int>(d, direction_angle_num_);
}
// Deletes this plugin object; the object must not be used after this call.
void GenerateAnchorDecodePlugin::destroy() noexcept
{
delete this;
}
// Stores a copy of `libNamespace` as this plugin's namespace.
void GenerateAnchorDecodePlugin::setPluginNamespace(const char* libNamespace) noexcept
{
mNamespace = libNamespace;
}
// Returns the stored plugin namespace as a C string owned by this object.
const char* GenerateAnchorDecodePlugin::getPluginNamespace() const noexcept
{
return mNamespace.c_str();
}
// Rebuilds the static attribute list describing the fields createPlugin()
// recognises, then points the static field collection at that list.
GenerateAnchorDecodePluginCreator::GenerateAnchorDecodePluginCreator()
{
    mPluginAttributes.clear();

    // Each entry: attribute name, no data, element type, length 1.
    mPluginAttributes.emplace_back("point_cloud_range", nullptr, PluginFieldType::kFLOAT32, 1);
    mPluginAttributes.emplace_back("feature_map_size", nullptr, PluginFieldType::kINT32, 1);
    mPluginAttributes.emplace_back("car_size", nullptr, PluginFieldType::kFLOAT32, 1);
    mPluginAttributes.emplace_back("direction_angle", nullptr, PluginFieldType::kFLOAT32, 1);
    mPluginAttributes.emplace_back("direction_angle_num", nullptr, PluginFieldType::kINT32, 1);

    mFC.fields = mPluginAttributes.data();
    mFC.nbFields = mPluginAttributes.size();
}
// Returns the name of the plugin this creator builds (PLUGIN_NAME).
const char* GenerateAnchorDecodePluginCreator::getPluginName() const noexcept
{
return PLUGIN_NAME;
}
// Returns the version of the plugin this creator builds (PLUGIN_VERSION).
const char* GenerateAnchorDecodePluginCreator::getPluginVersion() const noexcept
{
return PLUGIN_VERSION;
}
// Returns a pointer to the static field collection filled in by the constructor.
const PluginFieldCollection* GenerateAnchorDecodePluginCreator::getFieldNames() noexcept
{
return &mFC;
}
// Builds a plugin from a field collection.
//
// Recognised fields (unknown names are ignored):
//   "point_cloud_range"   6 floats: min_x, min_y, min_z, max_x, max_y, max_z
//   "feature_map_size"    2 ints:   height, width
//   "car_size"            3 floats: length, width, height
//   "direction_angle"     2 floats: angle 0, angle 1
//   "direction_angle_num" 1 int
// Fields that are absent keep the defaults declared below.
//
// Returns nullptr when `fc` is null. Entries with a null name or null data
// pointer are skipped instead of being dereferenced. The creator's namespace
// is copied onto the new plugin. `name` is not used.
// NOTE(review): the element counts above are what this function reads; the
// `length` member of each field is not checked, so callers must supply at
// least that many elements.
IPluginV2* GenerateAnchorDecodePluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept
{
    if (fc == nullptr)
    {
        return nullptr;
    }

    const PluginField* fields = fc->fields;
    // A collection without a field array is treated as empty.
    const int nbFields = (fields != nullptr) ? fc->nbFields : 0;

    float min_x_range = 0.0f;
    float max_x_range = 1000.0f;
    float min_y_range = 0.0f;
    float max_y_range = 1000.0f;
    float min_z_range = 0.0f;
    float max_z_range = 1000.0f;
    int feature_map_height = 0;
    int feature_map_width = 0;
    float car_length = 0.0f;
    float car_width = 0.0f;
    float car_height = 0.0f;
    float direction_angle_0 = 0.0f;
    float direction_angle_1 = 0.0f;
    int direction_angle_num = 0;

    for (int i = 0; i < nbFields; ++i)
    {
        const char* attr_name = fields[i].name;
        if (attr_name == nullptr || fields[i].data == nullptr)
        {
            continue;
        }

        if (!strcmp(attr_name, "point_cloud_range"))
        {
            const float* d = static_cast<const float*>(fields[i].data);
            min_x_range = d[0];
            min_y_range = d[1];
            min_z_range = d[2];
            max_x_range = d[3];
            max_y_range = d[4];
            max_z_range = d[5];
        }
        else if (!strcmp(attr_name, "feature_map_size"))
        {
            const int* d = static_cast<const int*>(fields[i].data);
            feature_map_height = d[0];
            feature_map_width = d[1];
        }
        else if (!strcmp(attr_name, "car_size"))
        {
            const float* d = static_cast<const float*>(fields[i].data);
            car_length = d[0];
            car_width = d[1];
            car_height = d[2];
        }
        else if (!strcmp(attr_name, "direction_angle"))
        {
            const float* d = static_cast<const float*>(fields[i].data);
            direction_angle_0 = d[0];
            direction_angle_1 = d[1];
        }
        else if (!strcmp(attr_name, "direction_angle_num"))
        {
            const int* d = static_cast<const int*>(fields[i].data);
            direction_angle_num = d[0];
        }
    }

    GenerateAnchorDecodePlugin* plugin = new GenerateAnchorDecodePlugin(min_x_range, max_x_range, min_y_range,
        max_y_range, min_z_range, max_z_range, feature_map_height, feature_map_width, car_length, car_width,
        car_height, direction_angle_0, direction_angle_1, direction_angle_num);
    // Keep the plugin's namespace in step with the creator that produced it.
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}
// Creates a plugin from serialized data by forwarding `serialData` and
// `serialLength` to the plugin's deserializing constructor. `name` is not used.
IPluginV2* GenerateAnchorDecodePluginCreator::deserializePlugin(
const char* name, const void* serialData, size_t serialLength) noexcept
{
return new GenerateAnchorDecodePlugin(serialData, serialLength);
}
// Stores a copy of `libNamespace` as this creator's namespace.
void GenerateAnchorDecodePluginCreator::setPluginNamespace(const char* libNamespace) noexcept
{
mNamespace = libNamespace;
}
// Returns the stored creator namespace as a C string owned by this object.
const char* GenerateAnchorDecodePluginCreator::getPluginNamespace() const noexcept
{
return mNamespace.c_str();
}
// Destructor; no explicit cleanup is performed.
GenerateAnchorDecodePluginCreator::~GenerateAnchorDecodePluginCreator()
{
}