Skip to content

Commit

Permalink
acess dictionary word from stemmerManager directly
Browse files Browse the repository at this point in the history
  • Loading branch information
krunal1313 committed Nov 29, 2024
1 parent c624195 commit 096a7ae
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 19 deletions.
2 changes: 1 addition & 1 deletion include/core_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ bool del_analytics_rules(const std::shared_ptr<http_req>& req, const std::shared
bool post_write_analytics_to_db(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);

//plurals, nouns
bool post_import_plurals(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);
bool post_import_dictionary(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);

// Misc helpers

Expand Down
4 changes: 2 additions & 2 deletions include/stemmer_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class StemmerManager {
std::unordered_map<std::string, std::shared_ptr<Stemmer>> stemmers;
StemmerManager() {}
std::mutex mutex;
spp::sparse_hash_map<std::string, spp::sparse_hash_map<std::string, std::string>> stem_dictionary;
spp::sparse_hash_map<std::string, spp::sparse_hash_map<std::string, std::string>> stem_dictionaries;
public:
static StemmerManager& get_instance() {
static StemmerManager instance;
Expand All @@ -45,5 +45,5 @@ class StemmerManager {
void delete_all_stemmers();
const bool validate_language(const std::string& language);
bool save_words(const std::string& dictionary_name, const std::vector<std::string> &json_lines);
spp::sparse_hash_map<std::string, std::string> get_dictionary(const std::string& dictionary_name);
std::string get_normalized_word(const std::string& dictionary_name, const std::string& word);
};
10 changes: 5 additions & 5 deletions src/core_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2918,17 +2918,17 @@ bool post_write_analytics_to_db(const std::shared_ptr<http_req>& req, const std:
return true;
}

bool post_import_plurals(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
bool post_import_dictionary(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
const char *BATCH_SIZE = "batch_size";
const char *PLURALS_SET = "plurals_set";
const char *DICTIONARY_SET = "dictionary_set";

if(req->params.count(BATCH_SIZE) == 0) {
req->params[BATCH_SIZE] = "40";
}

if(req->params.count(PLURALS_SET) == 0) {
if(req->params.count(DICTIONARY_SET) == 0) {
res->final = true;
res->set_400("Parameter `" + std::string(PLURALS_SET) + "` must be provided while importing plurals.");
res->set_400("Parameter `" + std::string(DICTIONARY_SET) + "` must be provided while importing dictionary words.");
stream_response(req, res);
return false;
}
Expand All @@ -2953,7 +2953,7 @@ bool post_import_plurals(const std::shared_ptr<http_req>& req, const std::shared
std::stringstream response_stream;

if(!single_partial_record_body) {
if(!StemmerManager::get_instance().save_words(req->params.at(PLURALS_SET), json_lines)) {
if(!StemmerManager::get_instance().save_words(req->params.at(DICTIONARY_SET), json_lines)) {
res->set_400("Bad/malformed dictionary import.");
stream_response(req, res);
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/typesense_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void master_server_routes() {
server->post("/analytics/aggregate_events", post_write_analytics_to_db);

// for plurals, nouns
server->post("/stemming/plurals/import", post_import_plurals);
server->post("/stemming/dictionary/import", post_import_dictionary);

// meta
server->get("/metrics.json", get_metrics_json);
Expand Down
28 changes: 18 additions & 10 deletions src/stemmer_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@ std::string Stemmer::stem(const std::string & word) {
auto stemmed = sb_stemmer_stem(stemmer, reinterpret_cast<const sb_symbol *>(word.c_str()), word.length());
stemmed_word = std::string(reinterpret_cast<const char *>(stemmed));
} else {
const auto& stem_dictionary = StemmerManager::get_instance().get_dictionary(dictionary_name);
auto it = stem_dictionary.find(word);
if(it == stem_dictionary.end()) {
const auto& normalized_word = StemmerManager::get_instance().get_normalized_word(dictionary_name, word);
if(normalized_word.empty()) {
stemmed_word = word;
} else {
stemmed_word = it->second;
stemmed_word = normalized_word;
}
}

Expand All @@ -43,7 +42,7 @@ std::string Stemmer::stem(const std::string & word) {

StemmerManager::~StemmerManager() {
delete_all_stemmers();
stem_dictionary.clear();
stem_dictionaries.clear();
}

std::shared_ptr<Stemmer> StemmerManager::get_stemmer(const std::string& language, const std::string& dictionary_name) {
Expand Down Expand Up @@ -83,6 +82,8 @@ bool StemmerManager::save_words(const std::string& dictionary_name, const std::v
return false;
}

std::lock_guard<std::mutex> lock(mutex);

nlohmann::json json_line;
for(const auto& line_str : json_lines) {
try {
Expand All @@ -94,15 +95,22 @@ bool StemmerManager::save_words(const std::string& dictionary_name, const std::v
if(!json_line.contains("word") || !json_line.contains("root")) {
return false;
}
stem_dictionary[dictionary_name].emplace(json_line["word"], json_line["root"]);
stem_dictionaries[dictionary_name].emplace(json_line["word"], json_line["root"]);
}
return true;
}

spp::sparse_hash_map<std::string, std::string> StemmerManager::get_dictionary(const std::string& dictionary_name) {
if(stem_dictionary.count(dictionary_name) != 0) {
return stem_dictionary.at(dictionary_name);
std::string StemmerManager::get_normalized_word(const std::string &dictionary_name, const std::string &word) {
std::lock_guard<std::mutex> lock(mutex);

std::string normalized_word;
if(stem_dictionaries.count(dictionary_name) != 0) {
const auto& dictionary = stem_dictionaries.at(dictionary_name);
auto found = dictionary.find(word);
if(found != dictionary.end()) {
normalized_word = found->second;
}
}

return spp::sparse_hash_map<std::string, std::string>();
return normalized_word;
}

0 comments on commit 096a7ae

Please sign in to comment.