Skip to content

Commit

Permalink
fix: success to build dynamic shape
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 committed Apr 26, 2024
1 parent f21bbd9 commit 3c1844a
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 47 deletions.
8 changes: 3 additions & 5 deletions include/mtr/agent.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ struct AgentData
// Return the number of classes.
size_t num_class() const { return num_class_; }

size_t num_attr() const { return num_timestamp_ + num_state_dim_ + num_class_ + 3; }

// Return the data shape which is meant to `(N, T, D)`.
std::tuple<size_t, size_t, size_t> shape() const
{
Expand All @@ -267,11 +269,7 @@ struct AgentData
size_t size() const { return num_agent_ * num_timestamp_ * num_state_dim_; }

// Return the number of elements of MTR input (B * N * T * A).
size_t input_size() const
{
return num_target_ * num_agent_ * num_timestamp_ *
(num_timestamp_ + num_state_dim_ + num_class_ + 3);
}
size_t input_size() const { return num_target_ * num_agent_ * num_timestamp_ * num_attr(); }

// Return the index number of ego.
int ego_index() const { return ego_index_; }
Expand Down
2 changes: 1 addition & 1 deletion include/mtr/builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class MTRBuilder

bool isDynamic() const;

bool setBindingDimensions(int32_t index, nvinfer1::Dims dimensions);
bool setBindingDimensions(int index, nvinfer1::Dims dimensions);

/**
* @brief A wrapper of `nvinfer1::IExecuteContext::enqueueV2`.
Expand Down
4 changes: 2 additions & 2 deletions include/mtr/mtr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ class TrtMTR

IntentionPoint intention_point_;

size_t num_target_, num_agent_, num_timestamp_, num_agent_dim_, num_agent_class_;
size_t num_polyline_, num_point_, num_point_dim_;
size_t num_target_, num_agent_, num_timestamp_, num_agent_dim_, num_agent_class_, num_agent_attr_;
size_t num_polyline_, num_point_, num_point_dim_, num_point_attr_;

// source data
cuda::unique_ptr<int[]> d_target_index_{nullptr};
Expand Down
5 changes: 2 additions & 3 deletions include/mtr/polyline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ struct PolylineData
// Return the number of point dimensions.
size_t num_state_dim() const { return num_state_dim_; }

size_t num_attr() const { return num_state_dim_ + 2; }

// Return the data shape which is meant to `(L, P, D)`.
std::tuple<size_t, size_t, size_t> shape() const
{
Expand All @@ -179,9 +181,6 @@ struct PolylineData
// Return the total number of elements which is meant to `L*P*D`.
size_t size() const { return num_polyline_ * num_point_ * num_state_dim_; }

// Return the input attribute size which is meant to `D+2`.
size_t input_attribute_size() const { return num_state_dim_ + 2; }

// Return the address pointer of data array.
const float * data_ptr() const noexcept { return data_.data(); }

Expand Down
40 changes: 20 additions & 20 deletions src/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ fs::path MTRBuilder::createEngineCachePath() const
std::string calibration_name = build_config_->precision == PrecisionType::INT8
? getCalibrationName(build_config_->calibration)
: "";
cache_engine_path.replace_extension(calibration_name + precision_name);
cache_engine_path.replace_extension(calibration_name + precision_name + ".engine");
return cache_engine_path;
}

Expand Down Expand Up @@ -299,24 +299,24 @@ bool MTRBuilder::buildEngineFromOnnx(
profile->setDimensions(
name, nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims3{batch_target.k_max, 64, 2});
}
// { // pred scores
// auto name = network->getOutput(0)->getName();
// profile->setDimensions(
// name, nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims2{batch_target.k_min, 6});
// profile->setDimensions(
// name, nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims2{batch_target.k_opt, 6});
// profile->setDimensions(
// name, nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims2{batch_target.k_max, 6});
// }
// { // pred trajs
// auto name = network->getOutput(1)->getName();
// profile->setDimensions(
// name, nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims4{batch_target.k_min, 6, 80, 7});
// profile->setDimensions(
// name, nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims4{batch_target.k_opt, 6, 80, 7});
// profile->setDimensions(
// name, nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims4{batch_target.k_max, 6, 80, 7});
// }
{ // pred scores
auto name = network->getOutput(0)->getName();
profile->setDimensions(
name, nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims2{batch_target.k_min, 6});
profile->setDimensions(
name, nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims2{batch_target.k_opt, 6});
profile->setDimensions(
name, nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims2{batch_target.k_max, 6});
}
{ // pred trajs
auto name = network->getOutput(1)->getName();
profile->setDimensions(
name, nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims4{batch_target.k_min, 6, 80, 7});
profile->setDimensions(
name, nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims4{batch_target.k_opt, 6, 80, 7});
profile->setDimensions(
name, nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims4{batch_target.k_max, 6, 80, 7});
}
config->addOptimizationProfile(profile);
}

Expand Down Expand Up @@ -376,7 +376,7 @@ bool MTRBuilder::isDynamic() const
return build_config_->is_dynamic();
}

bool MTRBuilder::setBindingDimensions(int32_t index, nvinfer1::Dims dimensions)
bool MTRBuilder::setBindingDimensions(int index, nvinfer1::Dims dimensions)
{
if (isDynamic()) {
return context_->setBindingDimensions(index, dimensions);
Expand Down
29 changes: 13 additions & 16 deletions src/mtr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@ void TrtMTR::initCudaPtr(const AgentData & agent_data, const PolylineData & poly
num_timestamp_ = agent_data.num_timestamp();
num_agent_dim_ = agent_data.num_state_dim();
num_agent_class_ = agent_data.num_class();
num_agent_attr_ = agent_data.num_attr();
num_polyline_ = polyline_data.num_polyline();
num_point_ = polyline_data.num_point();
num_point_dim_ = polyline_data.num_state_dim();
num_point_attr_ = polyline_data.num_attr();

// source data
d_target_index_ = cuda::make_unique<int[]>(num_target_);
Expand All @@ -92,14 +94,14 @@ void TrtMTR::initCudaPtr(const AgentData & agent_data, const PolylineData & poly
d_in_trajectory_mask_ = cuda::make_unique<bool[]>(num_target_ * num_agent_ * num_timestamp_);
d_in_last_pos_ = cuda::make_unique<float[]>(num_target_ * num_agent_ * 3);
d_in_polyline_ = cuda::make_unique<float[]>(
num_target_ * config_.max_num_polyline * num_point_ * polyline_data.input_attribute_size());
num_target_ * config_.max_num_polyline * num_point_ * num_point_attr_);
d_in_polyline_mask_ =
cuda::make_unique<bool[]>(num_target_ * config_.max_num_polyline * num_point_);
d_in_polyline_center_ = cuda::make_unique<float[]>(num_target_ * config_.max_num_polyline * 3);

if (config_.max_num_polyline < num_polyline_) {
d_tmp_polyline_ = cuda::make_unique<float[]>(
num_target_ * num_polyline_ * num_point_ * polyline_data.input_attribute_size());
d_tmp_polyline_ =
cuda::make_unique<float[]>(num_target_ * num_polyline_ * num_point_ * num_point_attr_);
d_tmp_polyline_mask_ = cuda::make_unique<bool[]>(num_target_ * num_polyline_ * num_point_);
d_tmp_distance_ = cuda::make_unique<float[]>(num_target_ * num_polyline_);
}
Expand All @@ -108,32 +110,27 @@ void TrtMTR::initCudaPtr(const AgentData & agent_data, const PolylineData & poly
// TODO(ktro2828): refactor
// obj_trajs
builder_->setBindingDimensions(
0, nvinfer1::Dims4{
agent_data.TargetNum, agent_data.AgentNum, agent_data.TimeLength, inAgentDim});
0, nvinfer1::Dims4{num_target_, num_agent_, num_timestamp_, num_agent_attr_});
// obj_trajs_mask
builder_->setBindingDimensions(
1, nvinfer1::Dims3{agent_data.TargetNum, agent_data.AgentNum, agent_data.TimeLength});
builder_->setBindingDimensions(1, nvinfer1::Dims3{num_target_, num_agent_, num_timestamp_});
// polylines
builder_->setBindingDimensions(
2, nvinfer1::Dims4{
agent_data.TargetNum, config_.max_num_polyline, polyline_data.PointNum, inPointDim});
2, nvinfer1::Dims4{num_target_, config_.max_num_polyline, num_point_, num_point_attr_});
// polyline mask
builder_->setBindingDimensions(
3, nvinfer1::Dims3{agent_data.TargetNum, config_.max_num_polyline, polyline_data.PointNum});
3, nvinfer1::Dims3{num_target_, config_.max_num_polyline, num_point_});
// polyline center
builder_->setBindingDimensions(
4, nvinfer1::Dims3{agent_data.TargetNum, config_.max_num_polyline, 3});
builder_->setBindingDimensions(4, nvinfer1::Dims3{num_target_, config_.max_num_polyline, 3});
// obj last pos
builder_->setBindingDimensions(
5, nvinfer1::Dims3{agent_data.TargetNum, agent_data.AgentNum, 3});
builder_->setBindingDimensions(5, nvinfer1::Dims3{num_target_, num_agent_, 3});
// track index to predict
nvinfer1::Dims targetIdxDim;
targetIdxDim.nbDims = 1;
targetIdxDim.d[0] = agent_data.TargetNum;
targetIdxDim.d[0] = num_target_;
builder_->setBindingDimensions(6, targetIdxDim);
// intention points
builder_->setBindingDimensions(
7, nvinfer1::Dims3{agent_data.TargetNum, config_.num_intention_point_cluster, 2});
7, nvinfer1::Dims3{num_target_, config_.num_intention_point_cluster, 2});
}

// outputs
Expand Down

0 comments on commit 3c1844a

Please sign in to comment.