Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed May 9, 2024
1 parent cfed5cd commit c9dc21f
Showing 1 changed file with 26 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
/************************************************
* Copyright MaybeShewill-CV. All Rights Reserved.
* Author: MaybeShewill-CV
* File: ddim_sampler_benchmark.cpp
* File: cls_cond_ddim_sampler_benchmark.cpp
* Date: 24-4-28
************************************************/

// ddim-sampler benchmark tool
// cls-cond-ddim-sampler benchmark tool

#include <random>

Expand All @@ -14,20 +14,20 @@

#include "common/time_stamp.h"
#include "models/model_io_define.h"
#include "models/diffussion/ddim_sampler.h"
#include "models/diffussion/cls_cond_ddim_sampler.h"

using jinq::common::CvUtils;
using jinq::common::Timestamp;
using jinq::common::FilePathUtil;
using jinq::models::io_define::diffusion::std_ddim_input;
using jinq::models::io_define::diffusion::std_ddim_output;
using jinq::models::diffusion::DDIMSampler;
using jinq::models::io_define::diffusion::std_cls_cond_ddim_input;
using jinq::models::io_define::diffusion::std_cls_cond_ddim_output;
using jinq::models::diffusion::ClsCondDDIMSampler;

int main(int argc, char** argv) {

if (argc != 2 && argc != 3 && argc != 4 && argc != 5) {
if (argc != 2 && argc != 3 && argc != 4 && argc != 5 && argc != 6) {
LOG(ERROR) << "wrong usage";
LOG(INFO) << "exe config_file_path [sample_size(default: 128)] [sample_steps(default: 10)] "
LOG(INFO) << "exe config_file_path cls_id [sample_size(default: 128)] [sample_steps(default: 10)] "
"[save_all_mid_results(default: true)]";
return -1;
}
Expand All @@ -41,43 +41,49 @@ int main(int argc, char** argv) {
}
auto cfg = toml::parse(cfg_file_path);

int sample_size = 128;
int cls_id = 0;
if (argc >= 3) {
sample_size = std::stoi(argv[2]);
cls_id = std::stoi(argv[2]);
}

int sample_steps = 10;
int sample_size = 128;
if (argc >= 4) {
sample_steps = std::stoi(argv[3]);
sample_size = std::stoi(argv[3]);
}

bool save_all_mid_results = true;
int sample_steps = 10;
if (argc >= 5) {
save_all_mid_results = std::stoi(argv[4]) == 1;
sample_steps = std::stoi(argv[4]);
}

bool save_all_mid_results = true;
if (argc >= 6) {
save_all_mid_results = std::stoi(argv[5]) == 1;
}

// construct model input
std_ddim_output model_output;
std_ddim_input model_input;
std_cls_cond_ddim_output model_output;
std_cls_cond_ddim_input model_input;
model_input.sample_size = cv::Size(sample_size, sample_size);
model_input.total_steps = 1000;
model_input.sample_steps = sample_steps;
model_input.channels = 3;
model_input.save_all_mid_results = save_all_mid_results;
model_input.eta = 1.0f;
model_input.cls_id = cls_id;

// construct ddpm unet
auto sampler = std::make_unique<DDIMSampler<std_ddim_input, std_ddim_output > >();
auto sampler = std::make_unique<ClsCondDDIMSampler<std_cls_cond_ddim_input, std_cls_cond_ddim_output > >();
sampler->init(cfg);
if (!sampler->is_successfully_initialized()) {
LOG(INFO) << "ddim sampler model init failed";
LOG(INFO) << "class cond ddim sampler model init failed";
return -1;
}

// run benchmark
int loop_times = 1;
LOG(INFO) << "ddim sampler run loop times: " << loop_times;
LOG(INFO) << "start ddim sampler benchmark at: " << Timestamp::now().to_format_str();
LOG(INFO) << "class cond ddim sampler run loop times: " << loop_times;
LOG(INFO) << "start class cond ddim sampler benchmark at: " << Timestamp::now().to_format_str();
auto ts = Timestamp::now();
for (int i = 0; i < loop_times; ++i) {
sampler->run(model_input, model_output);
Expand Down

0 comments on commit c9dc21f

Please sign in to comment.