Skip to content

Commit

Permalink
allow multiple plurals set
Browse files — browse the repository at this point in the history
  • Loading branch information
krunal1313 committed Nov 28, 2024
1 parent f76db44 commit c624195
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 52 deletions.
6 changes: 3 additions & 3 deletions include/field.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ struct field {
bool is_reference_helper = false;

bool stem = false;
bool stem_dictionary = false;
std::string stem_dictionary = "";
std::shared_ptr<Stemmer> stemmer;

nlohmann::json hnsw_params;
Expand All @@ -148,7 +148,7 @@ struct field {
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine,
std::string reference = "", const nlohmann::json& embed = nlohmann::json(), const bool range_index = false,
const bool store = true, const bool stem = false, const bool stem_dictionary = false, const nlohmann::json hnsw_params = nlohmann::json(),
const bool store = true, const bool stem = false, const std::string& stem_dictionary = "", const nlohmann::json hnsw_params = nlohmann::json(),
const bool async_reference = false) :
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference),
Expand All @@ -159,7 +159,7 @@ struct field {

auto const suffix = std::string(fields::REFERENCE_HELPER_FIELD_SUFFIX);
is_reference_helper = name.size() > suffix.size() && name.substr(name.size() - suffix.size()) == suffix;
if (stem || stem_dictionary) {
if (stem || !stem_dictionary.empty()) {
this->stem = true;
stemmer = StemmerManager::get_instance().get_stemmer(locale, stem_dictionary);
}
Expand Down
12 changes: 6 additions & 6 deletions include/stemmer_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ class Stemmer {
sb_stemmer * stemmer = nullptr;
LRU::Cache<std::string, std::string> cache;
std::mutex mutex;
bool use_dictionary = false;
std::string dictionary_name;
public:
Stemmer(const char * language, bool use_dictionary = false);
Stemmer(const char * language, const std::string& dictionary_name="");
~Stemmer();
std::string stem(const std::string & word);
};
Expand All @@ -29,7 +29,7 @@ class StemmerManager {
std::unordered_map<std::string, std::shared_ptr<Stemmer>> stemmers;
StemmerManager() {}
std::mutex mutex;
spp::sparse_hash_map<std::string, std::string> stem_dictionary;
spp::sparse_hash_map<std::string, spp::sparse_hash_map<std::string, std::string>> stem_dictionary;
public:
static StemmerManager& get_instance() {
static StemmerManager instance;
Expand All @@ -40,10 +40,10 @@ class StemmerManager {
StemmerManager(StemmerManager&&) = delete;
void operator=(StemmerManager&&) = delete;
~StemmerManager();
std::shared_ptr<Stemmer> get_stemmer(const std::string& language, bool stem_dictionary = false);
std::shared_ptr<Stemmer> get_stemmer(const std::string& language, const std::string& dictionary_name="");
void delete_stemmer(const std::string& language);
void delete_all_stemmers();
const bool validate_language(const std::string& language);
bool save_words(const std::vector<std::string> &json_lines);
spp::sparse_hash_map<std::string, std::string> get_dictionary();
bool save_words(const std::string& dictionary_name, const std::vector<std::string> &json_lines);
spp::sparse_hash_map<std::string, std::string> get_dictionary(const std::string& dictionary_name);
};
2 changes: 1 addition & 1 deletion src/collection_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
}

if(field_obj.count(fields::stem_dictionary) == 0) {
field_obj[fields::stem_dictionary] = false;
field_obj[fields::stem_dictionary] = "";
}

if(field_obj.count(fields::range_index) == 0) {
Expand Down
10 changes: 9 additions & 1 deletion src/core_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2920,11 +2920,19 @@ bool post_write_analytics_to_db(const std::shared_ptr<http_req>& req, const std:

bool post_import_plurals(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
const char *BATCH_SIZE = "batch_size";
const char *PLURALS_SET = "plurals_set";

if(req->params.count(BATCH_SIZE) == 0) {
req->params[BATCH_SIZE] = "40";
}

if(req->params.count(PLURALS_SET) == 0) {
res->final = true;
res->set_400("Parameter `" + std::string(PLURALS_SET) + "` must be provided while importing plurals.");
stream_response(req, res);
return false;
}

if(!StringUtils::is_uint32_t(req->params[BATCH_SIZE])) {
res->final = true;
res->set_400("Parameter `" + std::string(BATCH_SIZE) + "` must be a positive integer.");
Expand All @@ -2945,7 +2953,7 @@ bool post_import_plurals(const std::shared_ptr<http_req>& req, const std::shared
std::stringstream response_stream;

if(!single_partial_record_body) {
if(!StemmerManager::get_instance().save_words(json_lines)) {
if(!StemmerManager::get_instance().save_words(req->params.at(PLURALS_SET), json_lines)) {
res->set_400("Bad/malformed dictionary import.");
stream_response(req, res);
}
Expand Down
2 changes: 1 addition & 1 deletion src/field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
}

if(field_json.count(fields::stem_dictionary) == 0) {
field_json[fields::stem_dictionary] = false;
field_json[fields::stem_dictionary] = "";
}

if (field_json.count(fields::range_index) != 0) {
Expand Down
26 changes: 15 additions & 11 deletions src/stemmer_manager.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include "stemmer_manager.h"


Stemmer::Stemmer(const char * language, bool use_dictionary) {
if(!use_dictionary) {
Stemmer::Stemmer(const char * language, const std::string& dictionary_name) {
if(dictionary_name.empty()) {
this->stemmer = sb_stemmer_new(language, nullptr);
} else {
this->use_dictionary = true;
this->dictionary_name = dictionary_name;
}

this->cache = LRU::Cache<std::string, std::string>(20);
Expand All @@ -24,11 +24,11 @@ std::string Stemmer::stem(const std::string & word) {
return cache.lookup(word);
}

if(!use_dictionary) {
if(dictionary_name.empty()) {
auto stemmed = sb_stemmer_stem(stemmer, reinterpret_cast<const sb_symbol *>(word.c_str()), word.length());
stemmed_word = std::string(reinterpret_cast<const char *>(stemmed));
} else {
const auto& stem_dictionary = StemmerManager::get_instance().get_dictionary();
const auto& stem_dictionary = StemmerManager::get_instance().get_dictionary(dictionary_name);
auto it = stem_dictionary.find(word);
if(it == stem_dictionary.end()) {
stemmed_word = word;
Expand All @@ -46,12 +46,12 @@ StemmerManager::~StemmerManager() {
stem_dictionary.clear();
}

std::shared_ptr<Stemmer> StemmerManager::get_stemmer(const std::string& language, bool use_dictionary) {
std::shared_ptr<Stemmer> StemmerManager::get_stemmer(const std::string& language, const std::string& dictionary_name) {
std::unique_lock<std::mutex> lock(mutex);
// use english as default language
const std::string language_ = language.empty() ? "english" : language;
if (stemmers.find(language_) == stemmers.end()) {
stemmers[language] = std::make_shared<Stemmer>(language_.c_str(), use_dictionary);
stemmers[language] = std::make_shared<Stemmer>(language_.c_str(), dictionary_name);
}
return stemmers[language];
}
Expand All @@ -78,7 +78,7 @@ const bool StemmerManager::validate_language(const std::string& language) {
return true;
}

bool StemmerManager::save_words(const std::vector<std::string> &json_lines) {
bool StemmerManager::save_words(const std::string& dictionary_name, const std::vector<std::string> &json_lines) {
if(json_lines.empty()) {
return false;
}
Expand All @@ -94,11 +94,15 @@ bool StemmerManager::save_words(const std::vector<std::string> &json_lines) {
if(!json_line.contains("word") || !json_line.contains("root")) {
return false;
}
stem_dictionary.emplace(json_line["word"], json_line["root"]);
stem_dictionary[dictionary_name].emplace(json_line["word"], json_line["root"]);
}
return true;
}

spp::sparse_hash_map<std::string, std::string> StemmerManager::get_dictionary() {
return stem_dictionary;
spp::sparse_hash_map<std::string, std::string> StemmerManager::get_dictionary(const std::string& dictionary_name) {
if(stem_dictionary.count(dictionary_name) != 0) {
return stem_dictionary.at(dictionary_name);
}

return spp::sparse_hash_map<std::string, std::string>();
}
30 changes: 15 additions & 15 deletions test/collection_manager_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) {
"type":"string",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},
{
"facet":false,
Expand All @@ -152,7 +152,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) {
"type":"string",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},
{
"facet":true,
Expand All @@ -167,7 +167,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) {
"type":"string[]",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},
{
"facet":true,
Expand All @@ -182,7 +182,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) {
"type":"int32",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},
{
"facet":false,
Expand All @@ -197,7 +197,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) {
"type":"geopoint",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},
{
"facet":false,
Expand All @@ -212,7 +212,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) {
"type":"string",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},
{
"facet":false,
Expand All @@ -227,7 +227,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) {
"type":"int32",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},
{
"facet":false,
Expand All @@ -243,7 +243,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) {
"type":"object",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},
{
"facet":false,
Expand All @@ -260,7 +260,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) {
"vec_dist":"cosine",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},
{
"async_reference":true,
Expand All @@ -277,7 +277,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) {
"reference":"Products.product_id",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},
{
"facet":false,
Expand All @@ -292,7 +292,7 @@ TEST_F(CollectionManagerTest, CollectionCreation) {
"type":"int64",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
}
],
"id":0,
Expand Down Expand Up @@ -1681,7 +1681,7 @@ TEST_F(CollectionManagerTest, CollectionCreationWithMetadata) {
"type":"string",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},
{
"facet":true,
Expand All @@ -1697,7 +1697,7 @@ TEST_F(CollectionManagerTest, CollectionCreationWithMetadata) {
"type":"int32",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},{
"facet":true,
"index":true,
Expand All @@ -1712,7 +1712,7 @@ TEST_F(CollectionManagerTest, CollectionCreationWithMetadata) {
"type":"int32",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
},{
"facet":true,
"index":true,
Expand All @@ -1727,7 +1727,7 @@ TEST_F(CollectionManagerTest, CollectionCreationWithMetadata) {
"type":"int32",
"range_index":false,
"stem":false,
"stem_dictionary": false
"stem_dictionary": ""
}
],
"id":1,
Expand Down
6 changes: 3 additions & 3 deletions test/collection_specific_more_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3396,7 +3396,7 @@ TEST_F(CollectionSpecificMoreTest, StemmingPlurals) {
nlohmann::json schema = R"({
"name": "titles",
"fields": [
{"name": "title", "type": "string", "stem_dictionary": true }
{"name": "title", "type": "string", "stem_dictionary": "set1" }
]
})"_json;

Expand All @@ -3409,12 +3409,12 @@ TEST_F(CollectionSpecificMoreTest, StemmingPlurals) {
std::vector<std::string> json_lines;
json_lines.push_back(json_line);

ASSERT_TRUE(StemmerManager::get_instance().save_words(json_lines));
ASSERT_TRUE(StemmerManager::get_instance().save_words("set1", json_lines));

schema = R"({
"name": "titles_no_stem",
"fields": [
{"name": "title", "type": "string", "stem_dictionary": false }
{"name": "title", "type": "string" }
]
})"_json;

Expand Down
Loading

0 comments on commit c624195

Please sign in to comment.