Skip to content

Commit

Permalink
hybrid search flat_search_cutoff
Browse files Browse the repository at this point in the history
  • Loading branch information
krunal1313 committed Dec 9, 2024
1 parent f6bdba1 commit fd2b56e
Showing 1 changed file with 92 additions and 36 deletions.
128 changes: 92 additions & 36 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3542,46 +3542,102 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
VectorFilterFunctor filterFunctor(filter_result_iterator, excluded_result_ids, excluded_result_ids_size);
auto& field_vector_index = vector_index.at(vector_query.field_name);

std::vector<std::pair<float, size_t>> dist_labels;
// when k is not specified, use at least 100 so that results stay stable across pagination
size_t default_k = 100;
auto k = vector_query.k == 0 ? std::max<size_t>(fetch_size, default_k) : vector_query.k;
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, vector_query.ef, &filterFunctor);
} else {
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, vector_query.ef, &filterFunctor);
uint32_t filter_id_count = filter_result_iterator->approx_filter_ids_length;
std::vector<std::pair<float, single_filter_result_t>> dist_results;
if (filter_by_provided && filter_id_count < vector_query.flat_search_cutoff) {
while (filter_result_iterator->validity == filter_result_iterator_t::valid) {
auto seq_id = filter_result_iterator->seq_id;
auto filter_result = single_filter_result_t(seq_id, std::move(
filter_result_iterator->reference));
filter_result_iterator->next();
std::vector<float> values;

try {
values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
} catch (...) {
// likely not found
continue;
}

float dist;
if (field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
dist = field_vector_index->space->get_dist_func()(normalized_q.data(),
values.data(),
&field_vector_index->num_dim);
} else {
dist = field_vector_index->space->get_dist_func()(
vector_query.values.data(), values.data(),
&field_vector_index->num_dim);
}

dist_results.emplace_back(dist, filter_result);
}
}

filter_result_iterator->reset();
search_cutoff = search_cutoff || filter_result_iterator->validity == filter_result_iterator_t::timed_out;

std::vector<std::pair<uint32_t,float>> vec_results;
for (const auto& dist_label : dist_labels) {
uint32_t seq_id = dist_label.second;
if (!filter_by_provided || (filter_id_count >= vector_query.flat_search_cutoff &&
filter_result_iterator->validity ==
filter_result_iterator_t::valid)) {

auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) :
dist_label.first;
if(vec_dist_score > vector_query.distance_threshold) {
continue;
dist_results.clear();
std::vector<std::pair<float, size_t>> dist_labels;
// when k is not specified, use at least 100 so that results stay stable across pagination
size_t default_k = 100;
auto k = vector_query.k == 0 ? std::max<size_t>(fetch_size, default_k)
: vector_query.k;
if (field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(
normalized_q.data(), k, vector_query.ef, &filterFunctor);
} else {
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(
vector_query.values.data(), k, vector_query.ef, &filterFunctor);
}
vec_results.emplace_back(seq_id, vec_dist_score);
}
filter_result_iterator->reset();
search_cutoff = search_cutoff || filter_result_iterator->validity ==
filter_result_iterator_t::timed_out;

// iteration must happen in sorted sequence ID order, but a score-wise sort is needed first to compute ranks for rank fusion
std::sort(vec_results.begin(), vec_results.end(), [](const auto& a, const auto& b) {
return a.second < b.second;
});
std::vector<std::pair<uint32_t, float>> vec_results;
for (const auto &dist_label: dist_labels) {
uint32_t seq_id = dist_label.second;

auto vec_dist_score = (field_vector_index->distance_type == cosine)
? std::abs(dist_label.first) :
dist_label.first;
if (vec_dist_score > vector_query.distance_threshold) {
continue;
}
vec_results.emplace_back(seq_id, vec_dist_score);
}


// iteration must happen in sorted sequence ID order, but a score-wise sort is needed first to compute ranks for rank fusion
std::sort(vec_results.begin(), vec_results.end(),
[](const auto &a, const auto &b) {
return a.second < b.second;
});

for (const auto &pair: vec_results) {
auto filter_result = single_filter_result_t(pair.first, {});
dist_results.emplace_back(pair.second, filter_result);
}
}

std::unordered_map<uint32_t, uint32_t> seq_id_to_rank;

for(size_t vec_index = 0; vec_index < vec_results.size(); vec_index++) {
seq_id_to_rank.emplace(vec_results[vec_index].first, vec_index);
for (size_t vec_index = 0; vec_index < dist_results.size(); vec_index++) {
seq_id_to_rank.emplace(dist_results[vec_index].second.seq_id, vec_index);
}

std::sort(vec_results.begin(), vec_results.end(), [](const auto& a, const auto& b) {
return a.first < b.first;
});

std::sort(dist_results.begin(), dist_results.end(),
[](const auto &a, const auto &b) {
return a.second.seq_id < b.second.seq_id;
});

std::vector<KV*> kvs;
if(group_limit != 0) {
Expand Down Expand Up @@ -3618,10 +3674,10 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
}

for(size_t res_index = 0; res_index < vec_results.size() &&
for(size_t res_index = 0; res_index < dist_results.size() &&
filter_result_iterator->validity != filter_result_iterator_t::timed_out; res_index++) {
auto& vec_result = vec_results[res_index];
auto seq_id = vec_result.first;
auto& dist_result = dist_results[res_index];
auto seq_id = dist_result.second.seq_id;

if (filter_by_provided && filter_result_iterator->is_valid(seq_id) != 1) {
continue;
Expand Down Expand Up @@ -3651,7 +3707,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
// result overlaps with keyword search: we have to combine the scores

// old_score + ((1 / rank_of_document) * WEIGHT)
found_kv->vector_distance = vec_result.second;
found_kv->vector_distance = dist_result.first;
int64_t match_score = float_to_int64_t(
(int64_t_to_float(found_kv->scores[found_kv->match_score_index])) +
((1.0 / (seq_id_to_rank[seq_id] + 1)) * VECTOR_SEARCH_WEIGHT));
Expand All @@ -3662,7 +3718,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
geopoint_indices, seq_id, references, eval_filter_indexes,
match_score, scores, match_score_index, should_skip,
vec_result.second, collection_name);
dist_result.first, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}
Expand All @@ -3688,7 +3744,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
geopoint_indices, seq_id, references, eval_filter_indexes,
match_score, scores, match_score_index, should_skip,
vec_result.second, collection_name);
dist_result.first, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}
Expand All @@ -3711,7 +3767,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
}
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, std::move(references));
kv.text_match_score = 0;
kv.vector_distance = vec_result.second;
kv.vector_distance = dist_result.first;

auto ret = topster->add(&kv);
vec_search_ids.push_back(seq_id);
Expand Down

0 comments on commit fd2b56e

Please sign in to comment.