Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed Nov 2, 2023
1 parent fe449f9 commit b4a401b
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/models/mono_depth_estimation/metric3d.inl
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ private:
std::string model_file_path;
// trt context
TrtLogger logger;
std::unique_ptr<nvinfer1::IRuntime> runtime;
std::unique_ptr<nvinfer1::ICudaEngine> engine;
std::unique_ptr<nvinfer1::IExecutionContext> execution_context;
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IExecutionContext* execution_context = nullptr;
// trt bindings
EngineBinding input_image_binding;
EngineBinding output_depth_binding;
Expand Down Expand Up @@ -669,12 +669,11 @@ template <typename INPUT, typename OUTPUT>
StatusCode Metric3D<INPUT, OUTPUT>::Impl::init_trt(const toml::value& cfg) {
// init trt runtime
_m_trt_params.logger = TrtLogger();
auto* trt_runtime = nvinfer1::createInferRuntime(_m_trt_params.logger);
if(nullptr == trt_runtime) {
_m_trt_params.runtime = nvinfer1::createInferRuntime(_m_trt_params.logger);
if(nullptr == _m_trt_params.runtime) {
LOG(ERROR) << "init tensorrt runtime failed";
return StatusCode::MODEL_INIT_FAILED;
}
_m_trt_params.runtime = std::unique_ptr<nvinfer1::IRuntime>(trt_runtime);

// init trt engine
if (!cfg.contains("model_file_path")) {
Expand All @@ -693,15 +692,14 @@ StatusCode Metric3D<INPUT, OUTPUT>::Impl::init_trt(const toml::value& cfg) {
return StatusCode::MODEL_INIT_FAILED;
}
auto model_content_length = sizeof(model_file_content[0]) * model_file_content.size();
_m_trt_params.engine = std::unique_ptr<nvinfer1::ICudaEngine>(
_m_trt_params.runtime->deserializeCudaEngine(model_file_content.data(), model_content_length));
_m_trt_params.engine = _m_trt_params.runtime->deserializeCudaEngine(model_file_content.data(), model_content_length);
if (nullptr == _m_trt_params.engine) {
LOG(ERROR) << "deserialize trt engine failed";
return StatusCode::MODEL_INIT_FAILED;
}

// init trt execution context
_m_trt_params.execution_context = std::unique_ptr<nvinfer1::IExecutionContext>(_m_trt_params.engine->createExecutionContext());
_m_trt_params.execution_context = _m_trt_params.engine->createExecutionContext();
if (nullptr == _m_trt_params.execution_context) {
LOG(ERROR) << "create trt engine failed";
return StatusCode::MODEL_INIT_FAILED;
Expand Down

0 comments on commit b4a401b

Please sign in to comment.