diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h index 1c4b847d1a..1d2a1076ab 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h @@ -31,7 +31,7 @@ class RaftCagraHnswlib : public ANN, public AnnGPU { RaftCagraHnswlib(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) : ANN(metric, dim), - cagra_build_{metric, dim, param, concurrent_searches}, + cagra_build_{metric, dim, param, concurrent_searches, true}, // HnswLib param values don't matter since we don't build with HnswLib hnswlib_search_{metric, dim, typename HnswLib::BuildParam{50, 100}} { diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 0b892dec35..b03f875a8e 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -72,11 +72,16 @@ class RaftCagra : public ANN, public AnnGPU { std::optional ivf_pq_search_params = std::nullopt; }; - RaftCagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) + RaftCagra(Metric metric, + int dim, + const BuildParam& param, + int concurrent_searches = 1, + bool shall_include_dataset = false) : ANN(metric, dim), index_params_(param), dimension_(dim), need_dataset_update_(true), + shall_include_dataset_(shall_include_dataset), dataset_(std::make_shared>( std::move(make_device_matrix(handle_, 0, 0)))), graph_(std::make_shared>( @@ -135,6 +140,7 @@ class RaftCagra : public ANN, public AnnGPU { float refine_ratio_; BuildParam index_params_; bool need_dataset_update_; + bool shall_include_dataset_; raft::neighbors::cagra::search_params search_params_; std::shared_ptr> index_; int dimension_; @@ -161,7 +167,7 @@ void RaftCagra::build(const T* dataset, size_t nrow) auto& params = index_params_.cagra_params; // Do include the compressed dataset for the CAGRA-Q - bool shall_include_dataset = params.compression.has_value(); + bool include_dataset = params.compression.has_value() || shall_include_dataset_; index_ = std::make_shared>( std::move(raft::neighbors::cagra::detail::build(handle_, @@ -171,7 +177,7 @@ void RaftCagra::build(const T* dataset, size_t nrow) index_params_.ivf_pq_refine_rate, index_params_.ivf_pq_build_params, index_params_.ivf_pq_search_params, - shall_include_dataset))); + include_dataset))); } inline std::string allocator_to_string(AllocatorType mem_type)