Skip to content

Commit

Permalink
refactor repetitive code
Browse files Browse the repository at this point in the history
  • Loading branch information
krunal1313 committed Dec 10, 2024
1 parent fd2b56e commit 61bea97
Showing 1 changed file with 109 additions and 136 deletions.
245 changes: 109 additions & 136 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2870,6 +2870,108 @@ Option<bool> Index::search_infix(const std::string& query, const std::string& fi
return Option<bool>(true);
}

/**
 * Brute-force (flat) vector search over the documents yielded by a filter iterator.
 *
 * Walks `filter_result_iterator` for as long as it reports `valid`, looks up each
 * document's vector in the index by its sequence id and computes the distance between
 * that vector and the query vector. One `(distance, filter result)` entry is appended
 * to `dist_results` for every document whose vector lookup succeeds, in iteration order.
 *
 * @param filter_result_iterator Iterator over the filtered documents. It is advanced by
 *                               this function and each document's `reference` is moved
 *                               out of it.
 * @param vector_query           Query whose `values` are compared against the indexed vectors.
 * @param field_vector_index     Index providing the vectors, the distance function, the
 *                               distance type and the dimension count.
 * @param dist_results           Output vector; entries are appended, existing ones are kept.
 */
void process_results_bruteforce(filter_result_iterator_t* filter_result_iterator, const vector_query_t& vector_query,
                                hnsw_index_t* field_vector_index, std::vector<std::pair<float, single_filter_result_t>>& dist_results) {

    // The query vector is the same for every document, so for cosine distance it is
    // normalized once up front instead of being re-allocated and re-normalized per document.
    std::vector<float> normalized_q;
    const float* query_data = vector_query.values.data();
    if (field_vector_index->distance_type == cosine) {
        normalized_q.resize(vector_query.values.size());
        hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
        query_data = normalized_q.data();
    }

    while (filter_result_iterator->validity == filter_result_iterator_t::valid) {
        auto seq_id = filter_result_iterator->seq_id;
        // The reference must be captured before `next()` advances the iterator.
        auto filter_result = single_filter_result_t(seq_id, std::move(filter_result_iterator->reference));
        filter_result_iterator->next();
        std::vector<float> values;

        try {
            values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
        } catch (...) {
            // likely not found
            continue;
        }

        const float dist = field_vector_index->space->get_dist_func()(query_data, values.data(),
                                                                      &field_vector_index->num_dim);

        // `filter_result` is not used after this point, so hand it over instead of copying.
        dist_results.emplace_back(dist, std::move(filter_result));
    }
}

/**
 * Runs a k-nearest-neighbour search on the vector index and appends the hits to `dist_results`.
 *
 * The query vector is normalized first when the index uses cosine distance. What is
 * emitted depends on `is_wildcard_non_phrase_query`:
 *
 *  - `true` and the filter iterator carries references: every hit is re-validated against
 *    the iterator (in ascending sequence-id order) so that its reference can be attached.
 *    A timed-out iterator is reset with the timeout overridden and `search_cutoff` is set.
 *  - `true` without references: all hits are emitted in ascending sequence-id order with
 *    an empty reference.
 *  - `false`: hits whose distance exceeds `vector_query.distance_threshold` are dropped and
 *    the rest are emitted in ascending distance order with an empty reference.
 *
 * In the last two cases `search_cutoff` is additionally set when the iterator reports
 * `timed_out` after being reset.
 *
 * @param filter_result_iterator       Filter iterator; it is reset (and possibly advanced) here.
 * @param vector_query                 Supplies the query `values`, `ef` and `distance_threshold`.
 * @param field_vector_index           Vector index to search.
 * @param filterFunctor                Filter passed through to the index's KNN search.
 * @param k                            Number of neighbours requested from the index.
 * @param dist_results                 Output vector; entries are appended, existing ones are kept.
 * @param is_wildcard_non_phrase_query Selects the emission mode described above.
 */
void process_results_hnsw_index(filter_result_iterator_t* filter_result_iterator, const vector_query_t& vector_query,
                                hnsw_index_t* field_vector_index, VectorFilterFunctor& filterFunctor, size_t k,
                                std::vector<std::pair<float, single_filter_result_t>>& dist_results, bool is_wildcard_non_phrase_query = false) {

    std::vector<std::pair<float, size_t>> pairs;
    if (field_vector_index->distance_type == cosine) {
        std::vector<float> normalized_q(vector_query.values.size());
        hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
        pairs = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, vector_query.ef, &filterFunctor);
    } else {
        pairs = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, vector_query.ef, &filterFunctor);
    }

    // Order hits by sequence id: the reference lookup below queries the filter iterator
    // with ascending ids.
    std::sort(pairs.begin(), pairs.end(), [](auto& x, auto& y) {
        return x.second < y.second;
    });

    filter_result_iterator->reset();

    if (!filter_result_iterator->reference.empty() && is_wildcard_non_phrase_query) {
        // We'll have to get the references of each document.
        for (const auto& pair: pairs) {
            if (filter_result_iterator->validity == filter_result_iterator_t::timed_out) {
                // Overriding timeout since we need to get the references of matched docs.
                filter_result_iterator->reset(true);
                search_cutoff = true;
            }

            auto const& seq_id = pair.second;
            if (filter_result_iterator->is_valid(seq_id, search_cutoff) != 1) {
                continue;
            }
            // The seq_id must be valid otherwise it would've been filtered out upstream.
            auto filter_result = single_filter_result_t(seq_id,
                                                        std::move(filter_result_iterator->reference));
            dist_results.emplace_back(pair.first, std::move(filter_result));
        }
    } else {

        search_cutoff = search_cutoff || filter_result_iterator->validity ==
                                         filter_result_iterator_t::timed_out;

        if (!is_wildcard_non_phrase_query) {
            std::vector<std::pair<float, size_t>> vec_results;
            vec_results.reserve(pairs.size());

            for (const auto& pair: pairs) {
                const float vec_dist_score = (field_vector_index->distance_type == cosine)
                                             ? std::abs(pair.first) :
                                             pair.first;
                if (vec_dist_score > vector_query.distance_threshold) {
                    continue;
                }
                // Keep the same score that was checked against the threshold (absolute value
                // for cosine), so that the ordering and the emitted distance agree with it.
                vec_results.emplace_back(vec_dist_score, pair.second);
            }

            // iteration needs to happen on sorted sequence ID but score wise sort needed for compute rank fusion
            std::sort(vec_results.begin(), vec_results.end(),
                      [](const auto &a, const auto &b) {
                          return a.first < b.first;
                      });

            pairs = std::move(vec_results);
        }

        for (const auto &pair: pairs) {
            auto filter_result = single_filter_result_t(pair.second, {});
            dist_results.emplace_back(pair.first, std::move(filter_result));
        }
    }
}

Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
const text_match_type_t match_type,
filter_node_t*& filter_tree_root, std::vector<facet>& facets, facet_query_t& facet_query,
Expand Down Expand Up @@ -3102,79 +3204,16 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons

uint32_t filter_id_count = filter_result_iterator->approx_filter_ids_length;
if (filter_by_provided && filter_id_count < vector_query.flat_search_cutoff) {
while (filter_result_iterator->validity == filter_result_iterator_t::valid) {
auto seq_id = filter_result_iterator->seq_id;
auto filter_result = single_filter_result_t(seq_id, std::move(filter_result_iterator->reference));
filter_result_iterator->next();
std::vector<float> values;

try {
values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
} catch (...) {
// likely not found
continue;
}

float dist;
if (field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
dist = field_vector_index->space->get_dist_func()(normalized_q.data(), values.data(),
&field_vector_index->num_dim);
} else {
dist = field_vector_index->space->get_dist_func()(vector_query.values.data(), values.data(),
&field_vector_index->num_dim);
}

dist_results.emplace_back(dist, filter_result);
}
process_results_bruteforce(filter_result_iterator, vector_query, field_vector_index, dist_results);
}

filter_result_iterator->reset();
search_cutoff = search_cutoff || filter_result_iterator->validity == filter_result_iterator_t::timed_out;

if(!filter_by_provided ||
(filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator->validity == filter_result_iterator_t::valid)) {
dist_results.clear();

std::vector<std::pair<float, size_t>> pairs;
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
pairs = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, vector_query.ef, &filterFunctor);
} else {
pairs = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, vector_query.ef, &filterFunctor);
}

std::sort(pairs.begin(), pairs.end(), [](auto& x, auto& y) {
return x.second < y.second;
});

filter_result_iterator->reset();

if (!filter_result_iterator->reference.empty()) {
// We'll have to get the references of each document.
for (auto pair: pairs) {
if (filter_result_iterator->validity == filter_result_iterator_t::timed_out) {
// Overriding timeout since we need to get the references of matched docs.
filter_result_iterator->reset(true);
search_cutoff = true;
}

auto const& seq_id = pair.second;
if (filter_result_iterator->is_valid(seq_id, search_cutoff) != 1) {
continue;
}
// The seq_id must be valid otherwise it would've been filtered out upstream.
auto filter_result = single_filter_result_t(seq_id,
std::move(filter_result_iterator->reference));
dist_results.emplace_back(pair.first, filter_result);
}
} else {
for (const auto &pair: pairs) {
auto filter_result = single_filter_result_t(pair.second, {});
dist_results.emplace_back(pair.first, filter_result);
}
}
process_results_hnsw_index(filter_result_iterator, vector_query, field_vector_index, filterFunctor, k, dist_results, true);
}

std::vector<uint32_t> nearest_ids;
Expand Down Expand Up @@ -3544,87 +3583,21 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons

uint32_t filter_id_count = filter_result_iterator->approx_filter_ids_length;
std::vector<std::pair<float, single_filter_result_t>> dist_results;
if (filter_by_provided && filter_id_count < vector_query.flat_search_cutoff) {
while (filter_result_iterator->validity == filter_result_iterator_t::valid) {
auto seq_id = filter_result_iterator->seq_id;
auto filter_result = single_filter_result_t(seq_id, std::move(
filter_result_iterator->reference));
filter_result_iterator->next();
std::vector<float> values;

try {
values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
} catch (...) {
// likely not found
continue;
}

float dist;
if (field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
dist = field_vector_index->space->get_dist_func()(normalized_q.data(),
values.data(),
&field_vector_index->num_dim);
} else {
dist = field_vector_index->space->get_dist_func()(
vector_query.values.data(), values.data(),
&field_vector_index->num_dim);
}

dist_results.emplace_back(dist, filter_result);
}
if (filter_by_provided && filter_id_count < vector_query.flat_search_cutoff) {
process_results_bruteforce(filter_result_iterator, vector_query, field_vector_index, dist_results);
}

filter_result_iterator->reset();

if (!filter_by_provided || (filter_id_count >= vector_query.flat_search_cutoff &&
filter_result_iterator->validity ==
filter_result_iterator_t::valid)) {

if (!filter_by_provided || (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator->validity == filter_result_iterator_t::valid)) {
dist_results.clear();
std::vector<std::pair<float, size_t>> dist_labels;
// use k as 100 by default for ensuring results stability in pagination
size_t default_k = 100;
auto k = vector_query.k == 0 ? std::max<size_t>(fetch_size, default_k)
: vector_query.k;
if (field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(
normalized_q.data(), k, vector_query.ef, &filterFunctor);
} else {
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(
vector_query.values.data(), k, vector_query.ef, &filterFunctor);
}
filter_result_iterator->reset();
search_cutoff = search_cutoff || filter_result_iterator->validity ==
filter_result_iterator_t::timed_out;

std::vector<std::pair<uint32_t, float>> vec_results;
for (const auto &dist_label: dist_labels) {
uint32_t seq_id = dist_label.second;

auto vec_dist_score = (field_vector_index->distance_type == cosine)
? std::abs(dist_label.first) :
dist_label.first;
if (vec_dist_score > vector_query.distance_threshold) {
continue;
}
vec_results.emplace_back(seq_id, vec_dist_score);
}


// iteration needs to happen on sorted sequence ID but score wise sort needed for compute rank fusion
std::sort(vec_results.begin(), vec_results.end(),
[](const auto &a, const auto &b) {
return a.second < b.second;
});

for (const auto &pair: vec_results) {
auto filter_result = single_filter_result_t(pair.first, {});
dist_results.emplace_back(pair.second, filter_result);
}
process_results_hnsw_index(filter_result_iterator, vector_query, field_vector_index, filterFunctor, k, dist_results);
}

std::unordered_map<uint32_t, uint32_t> seq_id_to_rank;
Expand Down

0 comments on commit 61bea97

Please sign in to comment.