Skip to content

Commit

Permalink
reload dictionaries on restart
Browse files Browse the repository at this point in the history
  • Loading branch information
krunal1313 committed Dec 3, 2024
1 parent da4a9ae commit 87abea6
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 25 deletions.
33 changes: 31 additions & 2 deletions include/stemmer_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "lru/lru.hpp"
#include "sparsepp.h"
#include "json.hpp"
#include "store.h"


class Stemmer {
Expand All @@ -30,24 +31,52 @@ class StemmerManager {
StemmerManager() {}
std::mutex mutex;
spp::sparse_hash_map<std::string, spp::sparse_hash_map<std::string, std::string>> stem_dictionaries;
Store* store;

std::string get_stemming_dictionary_key(const std::string& dictionary_name);

public:
// Returns the single process-wide StemmerManager.
// The instance is a function-local static, so it is constructed on the first
// call and the same object is returned by reference on every later call.
static StemmerManager& get_instance() {
static StemmerManager instance;
return instance;
}

static constexpr const char* STEMMING_DICTIONARY_PREFIX = "$SD";

StemmerManager(StemmerManager const&) = delete;

void operator=(StemmerManager const&) = delete;

StemmerManager(StemmerManager&&) = delete;

void operator=(StemmerManager&&) = delete;

~StemmerManager();

void init(Store* _store);

void dispose();

std::shared_ptr<Stemmer> get_stemmer(const std::string& language, const std::string& dictionary_name="");

void delete_stemmer(const std::string& language);

void delete_all_stemmers();

const bool validate_language(const std::string& language);
bool save_words(const std::string& dictionary_name, const std::vector<std::string> &json_lines);

Option<bool> upsert_stemming_dictionary(const std::string& dictionary_name, const std::vector<std::string> &json_lines,
bool write_to_store = true);

bool load_stemming_dictioary(const nlohmann::json& dictionary);

std::string get_normalized_word(const std::string& dictionary_name, const std::string& word);

void get_stemming_dictionaries(nlohmann::json& dictionaries);

bool get_stemming_dictionary(const std::string& id, nlohmann::json& dictionary);
void del_stemming_dictionary(const std::string& id);

Option<bool> del_stemming_dictionary(const std::string& id);

void delete_all_stemming_dictionaries();
};
22 changes: 20 additions & 2 deletions src/collection_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,6 @@ Option<bool> CollectionManager::load(const size_t collection_batch_size, const s

iter = store->scan(preset_prefix_key, &preset_upper_bound);
while(iter->Valid() && iter->key().starts_with(preset_prefix_key)) {
std::vector<std::string> parts;
std::string preset_name = iter->key().ToString().substr(preset_prefix_key.size());
nlohmann::json preset_obj = nlohmann::json::parse(iter->value().ToString(), nullptr, false);

Expand All @@ -451,7 +450,6 @@ Option<bool> CollectionManager::load(const size_t collection_batch_size, const s

iter = store->scan(stopword_prefix_key, &stopword_upper_bound);
while(iter->Valid() && iter->key().starts_with(stopword_prefix_key)) {
std::vector<std::string> parts;
std::string stopword_name = iter->key().ToString().substr(stopword_prefix_key.size());
nlohmann::json stopword_obj = nlohmann::json::parse(iter->value().ToString(), nullptr, false);

Expand All @@ -465,6 +463,26 @@ Option<bool> CollectionManager::load(const size_t collection_batch_size, const s
}
delete iter;

// load stemming dictionaries
std::string stemming_dictionary_prefix_key = std::string(StemmerManager::STEMMING_DICTIONARY_PREFIX) + "_";
std::string stemming_dictionary_upper_bound_key = std::string(StemmerManager::STEMMING_DICTIONARY_PREFIX) + "`";
rocksdb::Slice stemming_dictionary_upper_bound(stemming_dictionary_upper_bound_key);

iter = store->scan(stemming_dictionary_prefix_key, &stemming_dictionary_upper_bound);
while(iter->Valid() && iter->key().starts_with(stemming_dictionary_prefix_key)) {
std::string stemming_dictionary_name = iter->key().ToString().substr(stemming_dictionary_prefix_key.size());
nlohmann::json stemming_dictionary_obj = nlohmann::json::parse(iter->value().ToString(), nullptr, false);

if(!stemming_dictionary_obj.is_discarded() && stemming_dictionary_obj.is_object()) {
StemmerManager::get_instance().load_stemming_dictioary(stemming_dictionary_obj);
} else {
LOG(INFO) << "Invalid object for stemming dictionary " << stemming_dictionary_name;
}

iter->Next();
}
delete iter;

// restore query suggestions configs
std::vector<std::string> analytics_config_jsons;
store->scan_fill(AnalyticsManager::ANALYTICS_RULE_PREFIX,
Expand Down
11 changes: 8 additions & 3 deletions src/core_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2953,8 +2953,9 @@ bool post_import_stemming_dictionary(const std::shared_ptr<http_req>& req, const
std::stringstream response_stream;

if(!single_partial_record_body) {
if(!StemmerManager::get_instance().save_words(req->params.at(ID), json_lines)) {
res->set_400("Bad/malformed dictionary import.");
auto op = StemmerManager::get_instance().upsert_stemming_dictionary(req->params.at(ID), json_lines);
if(!op.ok()) {
res->set(op.code(), op.error());
stream_response(req, res);
}

Expand Down Expand Up @@ -3009,7 +3010,11 @@ bool del_stemming_dictionary(const std::shared_ptr<http_req>& req, const std::sh
const std::string& id = req->params["id"];
nlohmann::json dictionary;

StemmerManager::get_instance().del_stemming_dictionary(id);
auto delete_op = StemmerManager::get_instance().del_stemming_dictionary(id);

if(!delete_op.ok()) {
res->set(delete_op.code(), delete_op.error());
}

nlohmann::json res_json;
res_json["id"] = id;
Expand Down
62 changes: 56 additions & 6 deletions src/stemmer_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,20 @@ std::string Stemmer::stem(const std::string & word) {
return stemmed_word;
}

// Remembers the store that stemming dictionaries are written to and removed from.
// The pointer is only stored here; nothing is read from the store by this call.
void StemmerManager::init(Store* _store) {
    this->store = _store;
}

// Releases all cached stemmers and in-memory dictionaries on destruction.
// dispose() performs exactly this teardown, so it is reused instead of being duplicated.
StemmerManager::~StemmerManager() {
    dispose();
}

void StemmerManager::dispose() {
delete_all_stemmers();
stem_dictionaries.clear();
}

std::shared_ptr<Stemmer> StemmerManager::get_stemmer(const std::string& language, const std::string& dictionary_name) {
std::unique_lock<std::mutex> lock(mutex);
// use english as default language
Expand Down Expand Up @@ -77,26 +86,53 @@ const bool StemmerManager::validate_language(const std::string& language) {
return true;
}

bool StemmerManager::save_words(const std::string& dictionary_name, const std::vector<std::string> &json_lines) {
// Adds the given word -> root mappings to the dictionary `dictionary_name`,
// creating the dictionary when it does not exist yet.
//
// json_lines:     one JSON object per entry, each with string `word` and `root` values.
// write_to_store: when true, the complete (merged) dictionary is persisted so that it
//                 can be reloaded on restart; pass false when replaying from the store.
//
// Returns Option<bool>(true) on success, a 400 error for empty/malformed input and a
// 500 error when the dictionary cannot be persisted. On any error the in-memory
// dictionary is left exactly as it was before the call.
Option<bool> StemmerManager::upsert_stemming_dictionary(const std::string& dictionary_name, const std::vector<std::string> &json_lines,
                                                        bool write_to_store) {
    if(json_lines.empty()) {
        return Option<bool>(400, "Invalid dictionary format.");
    }

    // Parse and validate every line before touching shared state, so that a bad
    // line in the middle of the batch cannot leave a half-applied dictionary.
    std::vector<std::pair<std::string, std::string>> word_roots;
    word_roots.reserve(json_lines.size());

    for(const auto& line_str : json_lines) {
        const nlohmann::json json_line = nlohmann::json::parse(line_str, nullptr, false);
        if(json_line.is_discarded()) {
            return Option<bool>(400, "Invalid dictionary format.");
        }

        // Non-string values are rejected here; converting them to std::string
        // below would otherwise throw.
        if(!json_line.is_object() || !json_line.contains("word") || !json_line.contains("root") ||
           !json_line["word"].is_string() || !json_line["root"].is_string()) {
            return Option<bool>(400, "dictionary lines should contain `word` and `root` values.");
        }

        word_roots.emplace_back(json_line["word"].get<std::string>(), json_line["root"].get<std::string>());
    }

    std::lock_guard<std::mutex> lock(mutex);

    // Build the merged dictionary on a copy; it replaces the live one only after
    // the store write (if requested) has succeeded.
    spp::sparse_hash_map<std::string, std::string> merged_words;
    const auto existing_it = stem_dictionaries.find(dictionary_name);
    if(existing_it != stem_dictionaries.end()) {
        merged_words = existing_it->second;
    }

    for(auto& word_root : word_roots) {
        // emplace keeps the mapping of a word that is already present.
        merged_words.emplace(std::move(word_root.first), std::move(word_root.second));
    }

    if(write_to_store) {
        if(store == nullptr) {
            return Option<bool>(500, "Unable to insert into store.");
        }

        // Persist the whole merged dictionary, not just this request's lines:
        // the stored value is overwritten, so anything missing from it would be
        // lost on the next restart.
        nlohmann::json dictionary_json;
        dictionary_json["id"] = dictionary_name;
        dictionary_json["words"] = nlohmann::json::array();

        for(const auto& kv : merged_words) {
            nlohmann::json word_json;
            word_json["word"] = kv.first;
            word_json["root"] = kv.second;
            dictionary_json["words"].push_back(word_json);
        }

        if(!store->insert(get_stemming_dictionary_key(dictionary_name), dictionary_json.dump())) {
            return Option<bool>(500, "Unable to insert into store.");
        }
    }

    stem_dictionaries[dictionary_name] = std::move(merged_words);

    return Option<bool>(true);
}

// Restores one dictionary from its persisted JSON form into memory.
//
// dictionary_json: expected to be an object with a string `id` and an array `words`
//                  whose elements are the individual dictionary lines.
//
// Returns true when the dictionary was loaded, false when the JSON does not have the
// expected shape or its lines are rejected. Nothing is written back to the store.
bool StemmerManager::load_stemming_dictioary(const nlohmann::json &dictionary_json) {
    // Check the shape first: indexing a const json with a missing key, or
    // converting a non-string `id`, is not safe.
    if(!dictionary_json.is_object() ||
       !dictionary_json.contains("id") || !dictionary_json["id"].is_string() ||
       !dictionary_json.contains("words") || !dictionary_json["words"].is_array()) {
        return false;
    }

    const std::string dictionary_name = dictionary_json["id"].get<std::string>();
    const auto& words = dictionary_json["words"];

    std::vector<std::string> json_lines;
    json_lines.reserve(words.size());

    for(const auto& line : words) {
        json_lines.push_back(line.dump());
    }

    // write_to_store=false: the data came from the store, so only memory is updated.
    return upsert_stemming_dictionary(dictionary_name, json_lines, false).ok();
}

Expand Down Expand Up @@ -151,16 +187,30 @@ bool StemmerManager::get_stemming_dictionary(const std::string &id, nlohmann::js
return false;
}

void StemmerManager::del_stemming_dictionary(const std::string &id) {
// Deletes the dictionary `id` from the store and from memory.
//
// Returns Option<bool>(true) when the dictionary was deleted or was not present,
// and a 500 error when it could not be removed from the store. In the error case
// the in-memory dictionary is kept, so memory and store stay consistent.
Option<bool> StemmerManager::del_stemming_dictionary(const std::string &id) {
    std::lock_guard<std::mutex> lock(mutex);

    auto found = stem_dictionaries.find(id);
    if(found == stem_dictionaries.end()) {
        return Option<bool>(true);
    }

    // Remove the persisted copy first: if only the in-memory entry were dropped
    // and this step failed, the dictionary would come back on the next restart.
    if(store == nullptr || !store->remove(get_stemming_dictionary_key(id))) {
        return Option<bool>(500, "Unable to delete from store.");
    }

    stem_dictionaries.erase(found);

    return Option<bool>(true);
}

// Removes every known dictionary from the store and then empties the in-memory map.
// The result of each store removal is not inspected.
void StemmerManager::delete_all_stemming_dictionaries() {
    std::lock_guard<std::mutex> lock(mutex);

    for(auto dict_it = stem_dictionaries.begin(); dict_it != stem_dictionaries.end(); ++dict_it) {
        const std::string store_key = get_stemming_dictionary_key(dict_it->first);
        store->remove(store_key);
    }

    stem_dictionaries.clear();
}
}

// Builds the store key for a dictionary: STEMMING_DICTIONARY_PREFIX, an underscore,
// then the dictionary name.
std::string StemmerManager::get_stemming_dictionary_key(const std::string &dictionary_name) {
    std::string store_key(STEMMING_DICTIONARY_PREFIX);
    store_key += "_";
    store_key += dictionary_name;
    return store_key;
}
4 changes: 4 additions & 0 deletions src/typesense_server_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "stopwords_manager.h"
#include "conversation_manager.h"
#include "vq_model_manager.h"
#include "stemmer_manager.h"

#ifndef ASAN_BUILD
#include "jemalloc.h"
Expand Down Expand Up @@ -476,6 +477,9 @@ int run_server(const Config & config, const std::string & version, void (*master
StopwordsManager& stopwordsManager = StopwordsManager::get_instance();
stopwordsManager.init(&store);

StemmerManager& stemmerManager = StemmerManager::get_instance();
stemmerManager.init(&store);

RateLimitManager *rateLimitManager = RateLimitManager::getInstance();
auto rate_limit_manager_init = rateLimitManager->init(&meta_store);

Expand Down
65 changes: 53 additions & 12 deletions test/collection_specific_more_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class CollectionSpecificMoreTest : public ::testing::Test {
protected:
Store *store;
CollectionManager & collectionManager = CollectionManager::get_instance();
StemmerManager& stemmerManager = StemmerManager::get_instance();
std::atomic<bool> quit = false;

std::vector<std::string> query_fields;
Expand All @@ -21,6 +22,7 @@ class CollectionSpecificMoreTest : public ::testing::Test {
system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());

store = new Store(state_dir_path);
stemmerManager.init(store);
collectionManager.init(store, 1.0, "auth_key", quit);
collectionManager.load(8, 1000);
}
Expand Down Expand Up @@ -3409,7 +3411,7 @@ TEST_F(CollectionSpecificMoreTest, StemmingDictionary) {
std::vector<std::string> json_lines;
json_lines.push_back(json_line);

ASSERT_TRUE(StemmerManager::get_instance().save_words("set1", json_lines));
ASSERT_TRUE(stemmerManager.upsert_stemming_dictionary("set1", json_lines).ok());

schema = R"({
"name": "titles_no_stem",
Expand Down Expand Up @@ -3437,18 +3439,18 @@ TEST_F(CollectionSpecificMoreTest, StemmingDictionary) {
}

TEST_F(CollectionSpecificMoreTest, StemmingDictionaryBasics) {
StemmerManager::get_instance().delete_all_stemming_dictionaries();
stemmerManager.delete_all_stemming_dictionaries();

std::string json_line = "{\"word\": \"people\", \"root\":\"person\"}";
std::vector<std::string> json_lines;
json_lines.push_back(json_line);

//add dictionary set
ASSERT_TRUE(StemmerManager::get_instance().save_words("set1", json_lines));
ASSERT_TRUE(stemmerManager.upsert_stemming_dictionary("set1", json_lines).ok());

//get dictionary set
nlohmann::json dictionary;
ASSERT_TRUE(StemmerManager::get_instance().get_stemming_dictionary("set1", dictionary));
ASSERT_TRUE(stemmerManager.get_stemming_dictionary("set1", dictionary));
ASSERT_EQ("set1", dictionary["id"]);
ASSERT_EQ(1, dictionary["words"].size());
ASSERT_EQ("people", dictionary["words"][0]["word"]);
Expand All @@ -3458,9 +3460,9 @@ TEST_F(CollectionSpecificMoreTest, StemmingDictionaryBasics) {
json_lines.clear();
json_line = "{\"word\": \"qualities\", \"root\":\"quality\"}";
json_lines.push_back(json_line);
ASSERT_TRUE(StemmerManager::get_instance().save_words("set2", json_lines));
ASSERT_TRUE(stemmerManager.upsert_stemming_dictionary("set2", json_lines).ok());

ASSERT_TRUE(StemmerManager::get_instance().get_stemming_dictionary("set2", dictionary));
ASSERT_TRUE(stemmerManager.get_stemming_dictionary("set2", dictionary));
ASSERT_EQ("set2", dictionary["id"]);
ASSERT_EQ(1, dictionary["words"].size());
ASSERT_EQ("qualities", dictionary["words"][0]["word"]);
Expand All @@ -3470,9 +3472,9 @@ TEST_F(CollectionSpecificMoreTest, StemmingDictionaryBasics) {
json_lines.clear();
json_line = "{\"word\": \"mangoes\", \"root\":\"mango\"}";
json_lines.push_back(json_line);
ASSERT_TRUE(StemmerManager::get_instance().save_words("set2", json_lines));
ASSERT_TRUE(stemmerManager.upsert_stemming_dictionary("set2", json_lines).ok());

ASSERT_TRUE(StemmerManager::get_instance().get_stemming_dictionary("set2", dictionary));
ASSERT_TRUE(stemmerManager.get_stemming_dictionary("set2", dictionary));
ASSERT_EQ("set2", dictionary["id"]);
ASSERT_EQ(2, dictionary["words"].size());
ASSERT_EQ("qualities", dictionary["words"][0]["word"]);
Expand All @@ -3482,15 +3484,54 @@ TEST_F(CollectionSpecificMoreTest, StemmingDictionaryBasics) {

//get all dictionary sets
nlohmann::json dictionary_sets;
StemmerManager::get_instance().get_stemming_dictionaries(dictionary_sets);
stemmerManager.get_stemming_dictionaries(dictionary_sets);
ASSERT_EQ(2, dictionary_sets["dictionaries"].size());
ASSERT_EQ("set1", dictionary_sets["dictionaries"][0].get<std::string>());
ASSERT_EQ("set2", dictionary_sets["dictionaries"][1].get<std::string>());

//del dictionary set and get
dictionary_sets.clear();
StemmerManager::get_instance().del_stemming_dictionary("set2");
StemmerManager::get_instance().get_stemming_dictionaries(dictionary_sets);
stemmerManager.del_stemming_dictionary("set2");
stemmerManager.get_stemming_dictionaries(dictionary_sets);
ASSERT_EQ(1, dictionary_sets["dictionaries"].size());
ASSERT_EQ("set1", dictionary_sets["dictionaries"][0].get<std::string>());
}
}

// Verifies that a stemming dictionary written through the manager is available again
// after the managers and the store are torn down and re-created on the same state dir.
TEST_F(CollectionSpecificMoreTest, ReloadStemmingDictionaryOnRestart) {
    // start without any dictionaries left over from other tests
    stemmerManager.delete_all_stemming_dictionaries();

    // add dictionary set with a single word -> root line
    const std::vector<std::string> json_lines = {R"({"word": "people", "root":"person"})"};
    ASSERT_TRUE(stemmerManager.upsert_stemming_dictionary("set1", json_lines).ok());

    // the dictionary is visible right after the upsert
    nlohmann::json dictionary;
    ASSERT_TRUE(stemmerManager.get_stemming_dictionary("set1", dictionary));
    ASSERT_EQ("set1", dictionary["id"]);
    ASSERT_EQ(1, dictionary["words"].size());
    ASSERT_EQ("people", dictionary["words"][0]["word"]);
    ASSERT_EQ("person", dictionary["words"][0]["root"]);

    // simulate a restart: drop the in-memory state and reopen the store
    collectionManager.dispose();
    stemmerManager.dispose();
    delete store;

    const std::string state_dir_path = "/tmp/typesense_test/collection_specific_more";
    store = new Store(state_dir_path);

    stemmerManager.init(store);
    collectionManager.init(store, 1.0, "auth_key", quit);
    collectionManager.load(8, 1000);

    // the dictionary must have been reloaded from the store
    ASSERT_TRUE(stemmerManager.get_stemming_dictionary("set1", dictionary));
    ASSERT_EQ("set1", dictionary["id"]);
    ASSERT_EQ(1, dictionary["words"].size());
    ASSERT_EQ("people", dictionary["words"][0]["word"]);
    ASSERT_EQ("person", dictionary["words"][0]["root"]);

    collectionManager.drop_collection("coll1");
}

0 comments on commit 87abea6

Please sign in to comment.