From 36f77a18a5d0ddaf429f45b41addd4913361fd39 Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Thu, 27 Jun 2024 12:03:46 -0400 Subject: [PATCH] Fix 0 recall issue in `raft_cagra_hnswlib` ANN benchmark (#2369) The `raft_cagra` wrapper stopped including the dataset in the index to save memory, but this adversely affected the `raft_cagra_hnswlib` wrapper, which needs the dataset to be included in the index. The dataset must be included because it has to be serialized when the index is written in the `hnswlib` format. Authors: - Divye Gala (https://github.com/divyegala) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2369 --- cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h | 2 +- cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h index 1c4b847d1a..1d2a1076ab 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h @@ -31,7 +31,7 @@ class RaftCagraHnswlib : public ANN, public AnnGPU { RaftCagraHnswlib(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) : ANN(metric, dim), - cagra_build_{metric, dim, param, concurrent_searches}, + cagra_build_{metric, dim, param, concurrent_searches, true}, // HnswLib param values don't matter since we don't build with HnswLib hnswlib_search_{metric, dim, typename HnswLib::BuildParam{50, 100}} { diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 0b892dec35..b03f875a8e 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -72,11 +72,16 @@ class RaftCagra : public ANN, public AnnGPU { std::optional ivf_pq_search_params = std::nullopt; }; - RaftCagra(Metric metric, int dim, const BuildParam& param, int 
concurrent_searches = 1) + RaftCagra(Metric metric, + int dim, + const BuildParam& param, + int concurrent_searches = 1, + bool shall_include_dataset = false) : ANN(metric, dim), index_params_(param), dimension_(dim), need_dataset_update_(true), + shall_include_dataset_(shall_include_dataset), dataset_(std::make_shared>( std::move(make_device_matrix(handle_, 0, 0)))), graph_(std::make_shared>( @@ -135,6 +140,7 @@ class RaftCagra : public ANN, public AnnGPU { float refine_ratio_; BuildParam index_params_; bool need_dataset_update_; + bool shall_include_dataset_; raft::neighbors::cagra::search_params search_params_; std::shared_ptr> index_; int dimension_; @@ -161,7 +167,7 @@ void RaftCagra::build(const T* dataset, size_t nrow) auto& params = index_params_.cagra_params; // Do include the compressed dataset for the CAGRA-Q - bool shall_include_dataset = params.compression.has_value(); + bool include_dataset = params.compression.has_value() || shall_include_dataset_; index_ = std::make_shared>( std::move(raft::neighbors::cagra::detail::build(handle_, @@ -171,7 +177,7 @@ void RaftCagra::build(const T* dataset, size_t nrow) index_params_.ivf_pq_refine_rate, index_params_.ivf_pq_build_params, index_params_.ivf_pq_search_params, - shall_include_dataset))); + include_dataset))); } inline std::string allocator_to_string(AllocatorType mem_type)