support fieldwise locale tokenization for query inference
krunal1313 committed Dec 17, 2024
1 parent dfa744e commit 00ce97c
Showing 2 changed files with 65 additions and 17 deletions.
6 changes: 4 additions & 2 deletions include/collection.h
@@ -481,13 +481,15 @@ class Collection {
     void parse_search_query(const std::string &query, std::vector<std::string>& q_include_tokens, std::vector<std::string>& q_include_tokens_non_stemmed,
                             std::vector<std::vector<std::string>>& q_exclude_tokens,
                             std::vector<std::vector<std::string>>& q_phrases,
-                            const std::string& locale, const bool already_segmented, const std::string& stopword_set="", std::shared_ptr<Stemmer> stemmer = nullptr) const;
+                            const std::string& locale, const bool already_segmented, const std::string& stopword_set="", std::shared_ptr<Stemmer> stemmer = nullptr,
+                            const std::vector<char>& field_token_separators = {}, const std::vector<char>& field_symbols_to_index = {}) const;
 
     void process_tokens(std::vector<std::string>& tokens, std::vector<std::string>& q_include_tokens,
                         std::vector<std::vector<std::string>>& q_exclude_tokens,
                         std::vector<std::vector<std::string>>& q_phrases, bool& exclude_operator_prior,
                         bool& phrase_search_op_prior, std::vector<std::string>& phrase, const std::string& stopwords_set,
-                        const bool& already_segmented, const std::string& locale, std::shared_ptr<Stemmer> stemmer) const;
+                        const bool& already_segmented, const std::string& locale, std::shared_ptr<Stemmer> stemmer,
+                        const std::vector<char>& local_token_separators, const std::vector<char>& local_symbols_to_index) const;
 
     // PUBLIC OPERATIONS
 
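The two new parameters default to empty vectors, which the implementation below reads as "no field-level override" before falling back to the collection-wide settings. A minimal self-contained sketch of that fallback, with hypothetical names and values standing in for the Collection members:

```cpp
#include <iostream>
#include <vector>

// Stand-in for the collection-wide Collection::token_separators member
// (the values here are hypothetical).
static const std::vector<char> collection_token_separators = {'-', '_'};

// An empty field-level vector means "no override", so the collection-wide
// setting applies; binding a const reference avoids copying either vector,
// just as the commit's `const auto& local_token_separators = ...` does.
const std::vector<char>& effective_separators(const std::vector<char>& field_token_separators) {
    return !field_token_separators.empty() ? field_token_separators
                                           : collection_token_separators;
}

int main() {
    std::vector<char> field_override = {'/'};
    std::vector<char> no_override;
    std::cout << effective_separators(field_override).front() << '\n';  // '/': override wins
    std::cout << effective_separators(no_override).front() << '\n';     // '-': default applies
}
```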
76 changes: 61 additions & 15 deletions src/collection.cpp
Expand Up @@ -2410,14 +2410,49 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
} else {
field_query_tokens.emplace_back(query_tokens_t{});
auto most_weighted_field = search_schema.at(weighted_search_fields[0].name);
const std::string & field_locale = most_weighted_field.locale;

parse_search_query(query, q_include_tokens, q_unstemmed_tokens,
field_query_tokens[0].q_exclude_tokens,
field_query_tokens[0].q_phrases,
field_locale, pre_segmented_query, stopwords_set, most_weighted_field.get_stemmer());
for(int i = 0; i < weighted_search_fields.size(); ++i) {
auto weighted_field = search_schema.at(weighted_search_fields[i].name);
const std::string& field_locale = weighted_field.locale;
const auto& field_token_separators = weighted_field.token_separators;
const auto& field_symbols_to_index = weighted_field.symbols_to_index;

if(i > 0 && field_locale.empty() && field_token_separators.empty() && field_symbols_to_index.empty()) {
//if no different locale, token_separators, symbols_to_index then no need to process query again
continue;
}

parse_search_query(query, q_include_tokens, q_unstemmed_tokens,
field_query_tokens[0].q_exclude_tokens,
field_query_tokens[0].q_phrases,
field_locale, pre_segmented_query, stopwords_set, weighted_field.get_stemmer(),
field_token_separators, field_symbols_to_index);
}

if(weighted_search_fields.size() > 1) {
//remove duplicate tokens found during multiple fields
std::set<std::string> tokens_mapping;
for (auto it = q_include_tokens.begin(); it != q_include_tokens.end();) {
if (tokens_mapping.find(*it) == tokens_mapping.end()) {
tokens_mapping.insert(*it);
it++;
} else {
//duplicate
it = q_include_tokens.erase(it);
}
}

tokens_mapping.clear();
for (auto it = q_unstemmed_tokens.begin(); it != q_unstemmed_tokens.end();) {
if (tokens_mapping.find(*it) == tokens_mapping.end()) {
tokens_mapping.insert(*it);
it++;
} else {
//duplicate
it = q_unstemmed_tokens.erase(it);
}
}
}
// process filter overrides first, before synonyms (order is important)

// included_ids, excluded_ids
@@ -3534,12 +3569,13 @@ void Collection::process_tokens(std::vector<std::string>& tokens, std::vector<st
                                 std::vector<std::vector<std::string>>& q_exclude_tokens,
                                 std::vector<std::vector<std::string>>& q_phrases, bool& exclude_operator_prior,
                                 bool& phrase_search_op_prior, std::vector<std::string>& phrase, const std::string& stopwords_set,
-                                const bool& already_segmented, const std::string& locale, std::shared_ptr<Stemmer> stemmer) const{
+                                const bool& already_segmented, const std::string& locale, std::shared_ptr<Stemmer> stemmer,
+                                const std::vector<char>& local_token_separators, const std::vector<char>& local_symbols_to_index) const{
 
 
 
     auto symbols_to_index_has_minus =
-            std::find(symbols_to_index.begin(), symbols_to_index.end(), '-') != symbols_to_index.end();
+            std::find(local_symbols_to_index.begin(), local_symbols_to_index.end(), '-') != local_symbols_to_index.end();
 
     for(auto& token: tokens) {
         bool end_of_phrase = false;
@@ -3576,7 +3612,7 @@ void Collection::process_tokens(std::vector<std::string>& tokens, std::vector<st
             if(already_segmented) {
                 StringUtils::split(token, sub_tokens, " ");
             } else {
-                Tokenizer(token, true, false, locale, symbols_to_index, token_separators, stemmer).tokenize(sub_tokens);
+                Tokenizer(token, true, false, locale, local_symbols_to_index, local_token_separators, stemmer).tokenize(sub_tokens);
             }
 
             for(auto& sub_token: sub_tokens) {
@@ -3633,11 +3669,18 @@ void Collection::process_tokens(std::vector<std::string>& tokens, std::vector<st
 void Collection::parse_search_query(const std::string &query, std::vector<std::string>& q_include_tokens, std::vector<std::string>& q_unstemmed_tokens,
                                     std::vector<std::vector<std::string>>& q_exclude_tokens,
                                     std::vector<std::vector<std::string>>& q_phrases,
-                                    const std::string& locale, const bool already_segmented, const std::string& stopwords_set, std::shared_ptr<Stemmer> stemmer) const {
+                                    const std::string& locale, const bool already_segmented,
+                                    const std::string& stopwords_set, std::shared_ptr<Stemmer> stemmer,
+                                    const std::vector<char>& field_token_separators,
+                                    const std::vector<char>& field_symbols_to_index) const {
     if(query == "*") {
         q_exclude_tokens = {};
         q_include_tokens = {query};
     } else {
+
+        const auto& local_token_separators = !field_token_separators.empty() ? field_token_separators : token_separators;
+        const auto& local_symbols_to_index = !field_symbols_to_index.empty() ? field_symbols_to_index : symbols_to_index;
+
         std::vector<std::string> tokens;
         std::vector<std::string> tokens_non_stemmed;
         stopword_struct_t stopwordStruct;
@@ -3652,13 +3695,13 @@ void Collection::parse_search_query(const std::string &query, std::vector<std::s
         if(already_segmented) {
             StringUtils::split(query, tokens, " ");
         } else {
-            std::vector<char> custom_symbols = symbols_to_index;
+            std::vector<char> custom_symbols = local_symbols_to_index;
             custom_symbols.push_back('-');
             custom_symbols.push_back('"');
 
-            Tokenizer(query, true, false, locale, custom_symbols, token_separators, stemmer).tokenize(tokens);
+            Tokenizer(query, true, false, locale, custom_symbols, local_token_separators, stemmer).tokenize(tokens);
             if(stemmer) {
-                Tokenizer(query, true, false, locale, custom_symbols, token_separators, nullptr).tokenize(tokens_non_stemmed);
+                Tokenizer(query, true, false, locale, custom_symbols, local_token_separators, nullptr).tokenize(tokens_non_stemmed);
             }
         }
 
@@ -3671,7 +3714,8 @@ void Collection::parse_search_query(const std::string &query, std::vector<std::s
         bool phrase_search_op_prior = false;
         std::vector<std::string> phrase;
 
-        process_tokens(tokens, q_include_tokens, q_exclude_tokens, q_phrases, exclude_operator_prior, phrase_search_op_prior, phrase, stopwords_set, already_segmented, locale, stemmer);
+        process_tokens(tokens, q_include_tokens, q_exclude_tokens, q_phrases, exclude_operator_prior, phrase_search_op_prior,
+                       phrase, stopwords_set, already_segmented, locale, stemmer, local_token_separators, local_symbols_to_index);
 
         if(stemmer) {
             exclude_operator_prior = false;
@@ -3681,7 +3725,9 @@ void Collection::parse_search_query(const std::string &query, std::vector<std::s
             std::vector<std::vector<std::string>> q_exclude_tokens_dummy;
             std::vector<std::vector<std::string>> q_phrases_dummy;
 
-            process_tokens(tokens_non_stemmed, q_unstemmed_tokens, q_exclude_tokens_dummy, q_phrases_dummy, exclude_operator_prior, phrase_search_op_prior, phrase, stopwords_set, already_segmented, locale, nullptr);
+            process_tokens(tokens_non_stemmed, q_unstemmed_tokens, q_exclude_tokens_dummy, q_phrases_dummy, exclude_operator_prior,
+                           phrase_search_op_prior, phrase, stopwords_set, already_segmented, locale, nullptr, local_token_separators,
+                           local_symbols_to_index);
         }
     }
 }
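With these changes, search() re-tokenizes the query only for fields whose locale, token_separators or symbols_to_index differ from the defaults, so the query is parsed once per distinct tokenization configuration rather than once per field. A minimal runnable model of that skip condition, using a hypothetical FieldConfig in place of the real schema type:

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for a schema field; only the three
// tokenization-related members the commit inspects are modeled.
struct FieldConfig {
    std::string locale;
    std::vector<char> token_separators;
    std::vector<char> symbols_to_index;
};

int main() {
    std::vector<FieldConfig> weighted_fields = {
        {"", {}, {}},       // most weighted field: always parsed (i == 0)
        {"ja", {}, {}},     // locale differs: needs its own tokenization pass
        {"", {'-'}, {}},    // custom separators: needs its own pass
        {"", {}, {}},       // nothing field-specific: skipped
    };

    for (std::size_t i = 0; i < weighted_fields.size(); ++i) {
        const auto& f = weighted_fields[i];
        // Same condition as the commit: after the first field, skip any field
        // that declares no locale, token_separators or symbols_to_index.
        if (i > 0 && f.locale.empty() && f.token_separators.empty() &&
            f.symbols_to_index.empty()) {
            std::cout << "field " << i << ": skipped, first pass covers it\n";
            continue;
        }
        std::cout << "field " << i << ": parse_search_query(...)\n";
    }
}
```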
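Since each qualifying field appends into the same q_include_tokens and q_unstemmed_tokens vectors, the follow-up pass in search() drops repeats while keeping first-seen order. The same idiom in isolation, assuming nothing beyond the standard library:

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Order-preserving de-duplication: keep the first occurrence of each token,
// erase later repeats. This mirrors the loops the commit runs over
// q_include_tokens and q_unstemmed_tokens when more than one field is parsed.
void dedupe_tokens(std::vector<std::string>& tokens) {
    std::set<std::string> seen;
    for (auto it = tokens.begin(); it != tokens.end();) {
        if (seen.insert(*it).second) {
            ++it;                    // first occurrence: keep it
        } else {
            it = tokens.erase(it);   // repeat from another field's pass: drop it
        }
    }
}

int main() {
    // e.g. two fields' tokenization passes both produced "shoe"
    std::vector<std::string> tokens = {"running", "shoe", "run", "shoe"};
    dedupe_tokens(tokens);
    for (const auto& t : tokens) {
        std::cout << t << '\n';      // running, shoe, run
    }
}
```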
