From e9a4748ecd10fb73d1aedd3623a871be6e2cd145 Mon Sep 17 00:00:00 2001 From: yzq <58433399+yangzq50@users.noreply.github.com> Date: Tue, 12 Nov 2024 14:02:55 +0800 Subject: [PATCH 1/3] Update rag_analyzer.cpp --- src/common/analyzer/rag_analyzer.cpp | 103 ++++++++++++--------------- 1 file changed, 44 insertions(+), 59 deletions(-) diff --git a/src/common/analyzer/rag_analyzer.cpp b/src/common/analyzer/rag_analyzer.cpp index b936af6f7d..1f1d516c35 100644 --- a/src/common/analyzer/rag_analyzer.cpp +++ b/src/common/analyzer/rag_analyzer.cpp @@ -22,6 +22,7 @@ module; #include #include #include +#include #include "string_utils.h" @@ -132,7 +133,7 @@ String Join(const Vector &tokens, int start, int end, const String &deli oss << delim; oss << tokens[i]; } - return oss.str(); + return std::move(oss).str(); } String Join(const Vector &tokens, int start, const String &delim = " ") { return Join(tokens, start, tokens.size(), delim); } @@ -144,7 +145,7 @@ String Join(const TermList &tokens, int start, int end, const String &delim = " oss << delim; oss << tokens[i].text_; } - return oss.str(); + return std::move(oss).str(); } bool IsChinese(const String &str) { @@ -657,36 +658,23 @@ String RAGAnalyzer::RKey(const String &line) { } Pair, double> RAGAnalyzer::Score(const Vector> &token_freqs) { - const int B = 30; - double F = 0.0, L = 0.0; + constexpr i64 B = 30; + i64 F = 0, L = 0; Vector tokens; tokens.reserve(token_freqs.size()); - - for (const auto &tk_freq_tag : token_freqs) { - const String &token = tk_freq_tag.first; - const auto &freq_tag = tk_freq_tag.second; - i32 freq = DecodeFreq(freq_tag); - - F += freq; + for (const auto &[token, freq_tag] : token_freqs) { + F += DecodeFreq(freq_tag); L += (UTF8Length(token) < 2) ? 0 : 1; tokens.push_back(token); } - - if (!tokens.empty()) { - F /= tokens.size(); - L /= tokens.size(); - } - - double score = B / static_cast(tokens.size()) + L + F; - return {tokens, score}; + const auto score = (B + L + F) / static_cast(tokens.size()); + return {std::move(tokens), score}; } void RAGAnalyzer::SortTokens(const Vector>> &token_list, Vector, double>> &res) { for (const auto &tfts : token_list) { - auto [tks, score] = Score(tfts); - res.emplace_back(tks, score); + res.push_back(Score(tfts)); } - std::sort(res.begin(), res.end(), [](const auto &a, const auto &b) { return a.second > b.second; }); } @@ -711,9 +699,9 @@ Pair, double> RAGAnalyzer::MaxForward(const String &line) { int v = trie_->Get(Key(t)); if (v != -1) { - res.emplace_back(t, v); + res.emplace_back(std::move(t), v); } else { - res.emplace_back(t, 0); + res.emplace_back(std::move(t), 0); } s = e; @@ -727,7 +715,7 @@ Pair, double> RAGAnalyzer::MaxBackward(const String &line) { int s = UTF8Length(line) - 1; while (s >= 0) { - int e = s + 1; + const int e = s + 1; String t = UTF8Substr(line, s, e - s); while (s > 0 && trie_->HasKeysWithPrefix(RKey(t))) { s -= 1; @@ -739,10 +727,11 @@ Pair, double> RAGAnalyzer::MaxBackward(const String &line) { } int v = trie_->Get(Key(t)); - if (v != -1) - res.emplace_back(t, v); - else - res.emplace_back(t, 0); + if (v != -1) { + res.emplace_back(std::move(t), v); + } else { + res.emplace_back(std::move(t), 0); + } s -= 1; } @@ -752,27 +741,20 @@ Pair, double> RAGAnalyzer::MaxBackward(const String &line) { } int RAGAnalyzer::DFS(const String &chars, - int s, + const int s, Vector> &pre_tokens, Vector>> &token_list, Vector &best_tokens, double &max_score, - bool memo_all) { + const bool memo_all) { int res = s; - int len = UTF8Length(chars); + const int len = 
UTF8Length(chars); if (s >= len) { if (memo_all) { token_list.push_back(pre_tokens); - } else { - double current_score = Score(pre_tokens).second; - if (current_score > max_score) { - best_tokens.clear(); - best_tokens.reserve(pre_tokens.size()); - for (auto &t : pre_tokens) { - best_tokens.push_back(t.first); - } - max_score = current_score; - } + } else if (auto [vec_str ,current_score] = Score(pre_tokens); current_score > max_score) { + best_tokens = std::move(vec_str); + max_score = current_score; } return res; } @@ -802,13 +784,9 @@ int RAGAnalyzer::DFS(const String &chars, break; } - if (trie_->HasKeysWithPrefix(k)) { + if (const int v = trie_->Get(k); v != -1) { auto pretks = pre_tokens; - int v = trie_->Get(k); - if (v != -1) - pretks.emplace_back(t, v); - else - pretks.emplace_back(t, Encode(-12, 0)); + pretks.emplace_back(std::move(t), v); res = std::max(res, DFS(chars, e, pretks, token_list, best_tokens, max_score, memo_all)); } } @@ -818,12 +796,11 @@ int RAGAnalyzer::DFS(const String &chars, } String t = UTF8Substr(chars, s, 1); - String k = Key(t); - int v = trie_->Get(k); - if (v != -1) - pre_tokens.emplace_back(t, v); - else - pre_tokens.emplace_back(t, Encode(-12, 0)); + if (const int v = trie_->Get(Key(t)); v != -1) { + pre_tokens.emplace_back(std::move(t), v); + } else { + pre_tokens.emplace_back(std::move(t), Encode(-12, 0)); + } return DFS(chars, s + 1, pre_tokens, token_list, best_tokens, max_score, memo_all); } @@ -1106,12 +1083,13 @@ void RAGAnalyzer::FineGrainedTokenize(const String &tokens, Vector &resu } for (auto &token : tks) { - if (UTF8Length(token) < 3 || RE2::PartialMatch(token, pattern4_)) { //[0-9,\\.-]+$ + const auto token_len = UTF8Length(token); + if (token_len < 3 || RE2::PartialMatch(token, pattern4_)) { //[0-9,\\.-]+$ res.push_back(token); continue; } Vector>> token_list; - if (UTF8Length(token) > 10) { + if (token_len > 10) { Vector> tk; tk.emplace_back(token, Encode(-1, 0)); token_list.push_back(tk); @@ -1127,16 +1105,23 @@ void RAGAnalyzer::FineGrainedTokenize(const String &tokens, Vector &resu } Vector, double>> sorted_tokens; SortTokens(token_list, sorted_tokens); - auto stk = sorted_tokens[1].first; - if (stk.size() == token.length()) { + const auto &stk = sorted_tokens[1].first; + if (stk.size() == token_len) { res.push_back(token); } else if (RE2::PartialMatch(token, pattern5_)) { // [a-z\\.-]+ + bool need_append_stk = true; for (auto &t : stk) { if (UTF8Length(t) < 3) { res.push_back(token); + need_append_stk = false; break; } } + if (need_append_stk) { + for (auto &t : stk) { + res.push_back(t); + } + } } else { for (auto &t : stk) { res.push_back(t); @@ -1151,7 +1136,7 @@ void RAGAnalyzer::FineGrainedTokenize(const String &tokens, Vector &resu int RAGAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) { unsigned level = 0; Vector tokens; - String output = Tokenize(input.text_.c_str()); + String output = Tokenize(input.text_); if (fine_grained_) { FineGrainedTokenize(output, tokens); } else From 0861717015a4d15496bd3a1e8ed84ba128c50c00 Mon Sep 17 00:00:00 2001 From: yzq <58433399+yangzq50@users.noreply.github.com> Date: Thu, 14 Nov 2024 19:19:02 +0800 Subject: [PATCH 2/3] Add GetBestTokens --- src/common/analyzer/darts_trie.cpp | 30 +++--- src/common/analyzer/darts_trie.cppm | 6 +- src/common/analyzer/rag_analyzer.cpp | 134 +++++++++++++++++++++++--- src/common/analyzer/rag_analyzer.cppm | 26 ++--- src/common/analyzer/string_utils.h | 31 +++++- 5 files changed, 184 insertions(+), 43 deletions(-) diff --git 
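The rewritten Score() folds the old per-token averages into one integer accumulation: with B = 30, F the sum of decoded frequencies, and L the number of tokens at least two code points long, the score is (B + L + F) / n, which equals the previous B/n + L/n + F/n up to rounding while dropping the intermediate divisions and the empty-check. Below is a minimal stand-alone sketch of that arithmetic; the Utf8Length helper, the pair type, and the literal frequencies are simplified stand-ins, not the analyzer's real DecodeFreq or trie machinery.

// score_sketch.cpp: stand-alone sketch of the integer scoring used by the rewritten
// Score(): score = (B + L + F) / n, equal to the old B/n + L/n + F/n up to rounding.
// Types and frequency values are illustrative stand-ins, not the analyzer's own.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Counts Unicode code points by skipping UTF-8 continuation bytes (0b10xxxxxx).
std::size_t Utf8Length(const std::string &s) {
    std::size_t n = 0;
    for (unsigned char c : s)
        if ((c & 0xC0) != 0x80)
            ++n;
    return n;
}

// token_freqs holds (token, already-decoded frequency); assumes a non-empty list.
double ScoreTokens(const std::vector<std::pair<std::string, std::int64_t>> &token_freqs) {
    constexpr std::int64_t B = 30; // constant bias, as in the patch
    std::int64_t F = 0;            // summed decoded frequencies
    std::int64_t L = 0;            // tokens that are at least two code points long
    for (const auto &[token, freq] : token_freqs) {
        F += freq;
        L += Utf8Length(token) >= 2 ? 1 : 0;
    }
    return static_cast<double>(B + L + F) / static_cast<double>(token_freqs.size());
}

int main() {
    // (30 + 1 + 8) / 2 = 19.5
    std::cout << ScoreTokens({{"数据", 7}, {"库", 1}}) << '\n';
}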
a/src/common/analyzer/darts_trie.cpp b/src/common/analyzer/darts_trie.cpp index 19eceb0d04..30c8e7fcd4 100644 --- a/src/common/analyzer/darts_trie.cpp +++ b/src/common/analyzer/darts_trie.cpp @@ -89,22 +89,28 @@ void DartsTrie::Load(const String &file_name) { darts_->open(file_name.c_str()); void DartsTrie::Save(const String &file_name) { darts_->save(file_name.c_str()); } -bool DartsTrie::HasKeysWithPrefix(const String &key) { - std::size_t key_pos = 0; - DartsCore::value_type result = 0; - std::size_t id = 0; - for (std::size_t i = 0; i < key.length(); ++i) { - result = darts_->traverse(key.c_str(), id, key_pos, i + 1); - if (result == -2) - return false; +// string literal "" is null-terminated +constexpr std::string_view empty_null_terminated_sv = ""; + +bool DartsTrie::HasKeysWithPrefix(std::string_view key) const { + if (key.empty()) [[unlikely]] { + key = empty_null_terminated_sv; } + std::size_t id = 0; + std::size_t key_pos = 0; + const auto result = darts_->traverse(key.data(), id, key_pos, key.size()); return result != -2; } -int DartsTrie::Get(const String &key) { - DartsCore::value_type value; - darts_->exactMatchSearch(key.c_str(), value); - return value; +int DartsTrie::Traverse(const char *key, std::size_t &node_pos, std::size_t &key_pos, const std::size_t length) const { + return darts_->traverse(key, node_pos, key_pos, length); +} + +int DartsTrie::Get(std::string_view key) const { + if (key.empty()) [[unlikely]] { + key = empty_null_terminated_sv; + } + return darts_->exactMatchSearch(key.data(), key.size()); } } // namespace infinity \ No newline at end of file diff --git a/src/common/analyzer/darts_trie.cppm b/src/common/analyzer/darts_trie.cppm index 809ae7d4c8..4c6d014a50 100644 --- a/src/common/analyzer/darts_trie.cppm +++ b/src/common/analyzer/darts_trie.cppm @@ -64,9 +64,11 @@ public: void Save(const String &file_name); - bool HasKeysWithPrefix(const String &key); + bool HasKeysWithPrefix(std::string_view key) const; - int Get(const String &key); + int Traverse(const char *key, SizeT &node_pos, SizeT &key_pos, SizeT length) const; + + int Get(std::string_view key) const; }; } // namespace infinity \ No newline at end of file diff --git a/src/common/analyzer/rag_analyzer.cpp b/src/common/analyzer/rag_analyzer.cpp index 1f1d516c35..2fbcce4c10 100644 --- a/src/common/analyzer/rag_analyzer.cpp +++ b/src/common/analyzer/rag_analyzer.cpp @@ -126,7 +126,8 @@ String Replace(const RE2 &re, const String &replacement, const String &input) { return output; } -String Join(const Vector &tokens, int start, int end, const String &delim = " ") { +template +String Join(const Vector &tokens, int start, int end, const String &delim = " ") { std::ostringstream oss; for (int i = start; i < end; ++i) { if (i > start) @@ -136,7 +137,10 @@ String Join(const Vector &tokens, int start, int end, const String &deli return std::move(oss).str(); } -String Join(const Vector &tokens, int start, const String &delim = " ") { return Join(tokens, start, tokens.size(), delim); } +template +String Join(const Vector &tokens, int start, const String &delim = " ") { + return Join(tokens, start, tokens.size(), delim); +} String Join(const TermList &tokens, int start, int end, const String &delim = " ") { std::ostringstream oss; @@ -636,16 +640,18 @@ String RAGAnalyzer::StrQ2B(const String &input) { return output; } -i32 RAGAnalyzer::Freq(const String &key) { +i32 RAGAnalyzer::Freq(const std::string_view key) const { i32 v = trie_->Get(key); v = DecodeFreq(v); - return i32(std::exp(v) * DENOMINATOR + 0.5); 
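HasKeysWithPrefix() now issues a single traverse() call over the whole key instead of re-walking it byte by byte, and the new Traverse() exposes the node/key positions so a caller can extend a key incrementally without rescanning the prefix it already matched. The toy below is not the darts-clone API; it only mimics the three-way result contract the patch relies on (>= 0: the prefix is a key and this is its value, -1: valid prefix but not itself a key, -2: no key continues this way) and shows why a caller can keep resuming and stop at -2. All names here are illustrative.

// resumable_prefix_sketch.cpp: toy stand-in (not darts-clone) for the traversal contract
// that the new DartsTrie::Traverse() exposes to callers.
#include <iostream>
#include <map>
#include <string>

class ToyTrie {
public:
    void Insert(const std::string &key, int value) { values_[key] = value; }
    // Mimics a resumable traverse(): `matched` carries how many bytes were already accepted,
    // so only the newly appended suffix is examined on each call.
    int Traverse(const std::string &key, std::size_t &matched) const {
        for (; matched < key.size(); ++matched) {
            if (!HasPrefix(key.substr(0, matched + 1)))
                return -2; // dead end: no stored key continues this way
        }
        auto it = values_.find(key);
        return it != values_.end() ? it->second : -1; // -1: valid prefix, not a key
    }

private:
    bool HasPrefix(const std::string &p) const {
        auto it = values_.lower_bound(p);
        return it != values_.end() && it->first.compare(0, p.size(), p) == 0;
    }
    std::map<std::string, int> values_;
};

int main() {
    ToyTrie trie;
    trie.Insert("dat", 3);
    trie.Insert("data", 7);
    std::size_t matched = 0;
    std::string key;
    for (char c : std::string("datum")) {
        key.push_back(c);
        const int r = trie.Traverse(key, matched);
        std::cout << key << " -> " << r << '\n'; // d -> -1, da -> -1, dat -> 3, datu -> -2
        if (r == -2)
            break; // the same early exit GetBestTokens() relies on
    }
}

The -2 early exit is what lets GetBestTokens() break out of its inner key-growing loop as soon as no dictionary entry can continue the current prefix.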
+ return static_cast(std::exp(v) * DENOMINATOR + 0.5); } -String RAGAnalyzer::Key(const String &line) { return ToLowerString(line); } +String RAGAnalyzer::Key(const std::string_view line) { return ToLowerString(line); } -String RAGAnalyzer::RKey(const String &line) { +String RAGAnalyzer::RKey(const std::string_view line) { String reversed; + reversed.reserve(line.size() + 2); + reversed += "DD"; for (size_t i = line.size(); i > 0;) { size_t start = i - 1; while (start > 0 && (line[start] & 0xC0) == 0x80) { @@ -654,7 +660,8 @@ String RAGAnalyzer::RKey(const String &line) { reversed += line.substr(start, i - start); i = start; } - return "DD" + ToLowerString(reversed); + ToLower(reversed.data() + 2, reversed.size() - 2); + return reversed; } Pair, double> RAGAnalyzer::Score(const Vector> &token_freqs) { @@ -678,7 +685,7 @@ void RAGAnalyzer::SortTokens(const Vector>> &token_list std::sort(res.begin(), res.end(), [](const auto &a, const auto &b) { return a.second > b.second; }); } -Pair, double> RAGAnalyzer::MaxForward(const String &line) { +Pair, double> RAGAnalyzer::MaxForward(const String &line) const { Vector> res; std::size_t s = 0; std::size_t len = UTF8Length(line); @@ -710,7 +717,7 @@ Pair, double> RAGAnalyzer::MaxForward(const String &line) { return Score(res); } -Pair, double> RAGAnalyzer::MaxBackward(const String &line) { +Pair, double> RAGAnalyzer::MaxBackward(const String &line) const { Vector> res; int s = UTF8Length(line) - 1; @@ -746,7 +753,7 @@ int RAGAnalyzer::DFS(const String &chars, Vector>> &token_list, Vector &best_tokens, double &max_score, - const bool memo_all) { + const bool memo_all) const { int res = s; const int len = UTF8Length(chars); if (s >= len) { @@ -805,7 +812,108 @@ int RAGAnalyzer::DFS(const String &chars, return DFS(chars, s + 1, pre_tokens, token_list, best_tokens, max_score, memo_all); } -String RAGAnalyzer::Merge(const String &tks_str) { +struct BestTokenCandidate { + u32 token_num{}; + i64 score_sum{}; + Vector tokens{}; +}; + +struct GrowingBestTokenCandidates { + Vector candidates{}; + + void AddBestTokenCandidate(const u32 tn, const i64 ss, const Vector &tks_old_first, const std::string_view wait_append) { + const auto it = + std::lower_bound(candidates.begin(), + candidates.end(), + tn, + [](const BestTokenCandidate &a, const u32 x) { + return a.token_num < x; + }); + const bool it_tn_same = (it != candidates.end() && it->token_num == tn); + if (it_tn_same && it->score_sum >= ss) { + return; + } + BestTokenCandidate candidate = {tn, ss}; + candidate.tokens.reserve(tks_old_first.size() + 1); + candidate.tokens.insert(candidate.tokens.end(), tks_old_first.begin(), tks_old_first.end()); + candidate.tokens.push_back(wait_append); + if (it_tn_same) { + *it = std::move(candidate); + } else { + candidates.insert(it, std::move(candidate)); + } + } +}; + +constexpr i64 BASE_SCORE_SUM = 30; + +Pair, double> RAGAnalyzer::GetBestTokens(const std::string_view chars) const { + const auto utf8_len = UTF8Length(chars); + Vector dp_vec(utf8_len + 1); + dp_vec[0].candidates.resize(1); + const char *current_utf8_ptr = chars.data(); + u32 current_left_chars = chars.size(); + String growing_key; // in lower case + for (u32 i = 0; i < utf8_len; ++i) { + const std::string_view current_chars{current_utf8_ptr, current_left_chars}; + const u32 left_utf8_cnt = utf8_len - i; + growing_key.clear(); + const char *lookup_until = current_utf8_ptr; + u32 lookup_left_chars = current_left_chars; + std::size_t reuse_node_pos = 0; + std::size_t reuse_key_pos = 0; + for (u32 j = 1; j 
<= left_utf8_cnt; ++j) { + { + // handle growing_key + const auto next_one_utf8 = UTF8Substrview({lookup_until, lookup_left_chars}, 0, 1); + if (next_one_utf8.size() == 1 && next_one_utf8[0] >= 'A' && next_one_utf8[0] <= 'Z') { + growing_key.push_back(next_one_utf8[0] - 'A' + 'a'); + } else { + growing_key.append(next_one_utf8); + } + lookup_until += next_one_utf8.size(); + lookup_left_chars -= next_one_utf8.size(); + } + auto update_dp_vec = [&dp_vec, i, j, original_sv=std::string_view{current_utf8_ptr, growing_key.size()}](const i32 key_score) { + auto &target_dp = dp_vec[i + j]; + for (const auto &[tn, ss, v] : dp_vec[i].candidates) { + target_dp.AddBestTokenCandidate(tn + 1, ss + key_score, v, original_sv); + } + }; + if (const auto traverse_result = trie_->Traverse(growing_key.data(), reuse_node_pos, reuse_key_pos, growing_key.size()); + traverse_result >= 0) { + // in dictionary + const auto key_score = DecodeFreq(traverse_result) + static_cast(j >= 2); + update_dp_vec(key_score); + } else { + // not in dictionary + if (j == 1) { + // also give a score: -12 + update_dp_vec(-12); + } + if (traverse_result == -2) { + // no more results + break; + } + } + } + // update current_utf8_ptr and current_left_chars + const auto forward_cnt = UTF8Substrview(current_chars, 0, 1).size(); + current_utf8_ptr += forward_cnt; + current_left_chars -= forward_cnt; + } + Pair, double> result; + result.second = std::numeric_limits::lowest(); + for (auto &[token_num, score_sum, tokens] : dp_vec.back().candidates) { + if (const auto score = static_cast(BASE_SCORE_SUM + score_sum) / token_num; score > result.second) { + result.first = std::move(tokens); + result.second = score; + } + } + return result; +} + +String RAGAnalyzer::Merge(const String &tks_str) const { String tks = tks_str; tks = Replace(replace_space_pattern_, " ", tks); @@ -849,7 +957,7 @@ void RAGAnalyzer::EnglishNormalize(const Vector &tokens, Vector } } -void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) { +void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { auto [tks, s] = MaxForward(L); auto [tks1, s1] = MaxBackward(L); #if 0 @@ -950,7 +1058,7 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) { #endif } -void RAGAnalyzer::SplitLongText(const String &L, u32 length, Vector &sublines) { +void RAGAnalyzer::SplitLongText(const String &L, u32 length, Vector &sublines) const { u32 slice_count = length / MAX_SENTENCE_LEN + 1; sublines.reserve(slice_count); std::size_t last_sentence_start = 0; diff --git a/src/common/analyzer/rag_analyzer.cppm b/src/common/analyzer/rag_analyzer.cppm index 07ff9de072..640bba5632 100644 --- a/src/common/analyzer/rag_analyzer.cppm +++ b/src/common/analyzer/rag_analyzer.cppm @@ -57,21 +57,21 @@ protected: private: static constexpr float DENOMINATOR = 1000000; - String StrQ2B(const String &input); + static String StrQ2B(const String &input); - i32 Freq(const String &key); + i32 Freq(std::string_view key) const; - String Key(const String &line); + static String Key(std::string_view line); - String RKey(const String &line); + static String RKey(std::string_view line); - Pair, double> Score(const Vector> &token_freqs); + static Pair, double> Score(const Vector> &token_freqs); - void SortTokens(const Vector>> &token_list, Vector, double>> &res); + static void SortTokens(const Vector>> &token_list, Vector, double>> &res); - Pair, double> MaxForward(const String &line); + Pair, double> MaxForward(const String &line) const; - Pair, double> MaxBackward(const String &line); + Pair, 
double> MaxBackward(const String &line) const; int DFS(const String &chars, int s, @@ -79,17 +79,19 @@ private: Vector>> &token_list, Vector &best_tokens, double &max_score, - bool memo_all); + bool memo_all) const; - void TokenizeInner(Vector &res, const String &L); + void TokenizeInner(Vector &res, const String &L) const; - void SplitLongText(const String &L, u32 length, Vector &sublines); + void SplitLongText(const String &L, u32 length, Vector &sublines) const; - String Merge(const String &tokens); + String Merge(const String &tokens) const; void EnglishNormalize(const Vector &tokens, Vector &res); public: + Pair, double> GetBestTokens(std::string_view chars) const; + static const SizeT term_string_buffer_limit_ = 4096 * 3; String dict_path_; diff --git a/src/common/analyzer/string_utils.h b/src/common/analyzer/string_utils.h index 086d80852f..97a12bdfc6 100644 --- a/src/common/analyzer/string_utils.h +++ b/src/common/analyzer/string_utils.h @@ -71,8 +71,8 @@ inline void ToLower(const char *data, size_t len, char *out, size_t out_limit) { (*end) = '\0'; } -inline std::string ToLowerString(std::string const &s) { - std::string result = s; +inline std::string ToLowerString(std::string_view s) { + std::string result{s.data(), s.size()}; char *begin = result.data(); char *end = result.data() + s.size(); @@ -143,10 +143,10 @@ static const uint8_t UTF8_BYTE_LENGTH_TABLE[256] = { // invalid utf8 byte: 0b1111'1000~ 0b1111'1111 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1}; -inline uint32_t UTF8Length(std::string const &str) { +inline uint32_t UTF8Length(const std::string_view str) { uint32_t len = 0; for (uint32_t i = 0, char_size = 0; i < str.size(); i += char_size) { - char_size = UTF8_BYTE_LENGTH_TABLE[static_cast(str.data()[i])]; + char_size = UTF8_BYTE_LENGTH_TABLE[static_cast(str[i])]; ++len; } return len; @@ -175,4 +175,27 @@ static inline std::string UTF8Substr(const std::string &str, std::size_t start, return str.substr(start_byte, end_byte - start_byte); } +static inline std::string_view UTF8Substrview(const std::string_view str, const std::size_t start, const std::size_t len) { + const std::size_t str_len = str.length(); + std::size_t i = 0; + std::size_t byte_index = 0; + std::size_t start_byte = 0; + std::size_t end_byte = 0; + + while (byte_index < str_len && i < (start + len)) { + const std::size_t char_len = UTF8_BYTE_LENGTH_TABLE[static_cast(str[byte_index])]; + if (i >= start) { + if (i == start) { + start_byte = byte_index; + } + end_byte = byte_index + char_len; + } + + byte_index += char_len; + i += 1; + } + + return str.substr(start_byte, end_byte - start_byte); +} + } // namespace infinity From 4b3d70c5e77f260eb6b5eec9d89d9248bd698c81 Mon Sep 17 00:00:00 2001 From: yzq <58433399+yangzq50@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:27:52 +0800 Subject: [PATCH 3/3] Update test --- src/common/analyzer/rag_analyzer.cpp | 92 ++++++++++++++++++++++++++-- 1 file changed, 87 insertions(+), 5 deletions(-) diff --git a/src/common/analyzer/rag_analyzer.cpp b/src/common/analyzer/rag_analyzer.cpp index 2fbcce4c10..c058c72b53 100644 --- a/src/common/analyzer/rag_analyzer.cpp +++ b/src/common/analyzer/rag_analyzer.cpp @@ -23,6 +23,8 @@ module; #include #include #include +#include +#include #include "string_utils.h" @@ -957,6 +959,62 @@ void RAGAnalyzer::EnglishNormalize(const Vector &tokens, Vector } } +// TODO: for test +// #ifndef INFINITY_DEBUG +// #define INFINITY_DEBUG 1 +// #endif + +#ifdef INFINITY_DEBUG +template +String TestPrintTokens(const Vector &tokens) { + 
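The string_utils.h changes move the UTF-8 helpers to std::string_view so the hot path in GetBestTokens() can slice code points without allocating. Below is a self-contained demo of the iteration pattern those helpers enable: take a one-code-point view, then advance by its byte size, so multi-byte characters are never split. FirstCodePoint is a simplified local stand-in for UTF8Substrview(str, 0, 1), not the project's header, and it assumes valid UTF-8 input.

// utf8_iterate_demo.cpp: demo of the code-point iteration pattern used with UTF8Substrview().
#include <algorithm>
#include <iostream>
#include <string>
#include <string_view>

std::string_view FirstCodePoint(std::string_view str) {
    if (str.empty())
        return str;
    const unsigned char lead = static_cast<unsigned char>(str[0]);
    // Byte length of a UTF-8 sequence from its lead byte (1..4, assuming valid input).
    const std::size_t len = lead < 0x80 ? 1 : lead < 0xE0 ? 2 : lead < 0xF0 ? 3 : 4;
    return str.substr(0, std::min(len, str.size()));
}

int main() {
    const std::string line = "向量db";
    for (std::string_view rest = line; !rest.empty();) {
        const std::string_view cp = FirstCodePoint(rest); // zero-copy view into `line`
        std::cout << cp << " (" << cp.size() << " bytes)\n";
        rest.remove_prefix(cp.size()); // advance by exactly one code point
    }
}

This mirrors how GetBestTokens() advances current_utf8_ptr by UTF8Substrview(current_chars, 0, 1).size() on each outer-loop step.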
std::ostringstream oss; + for (std::size_t i = 0; i < tokens.size(); ++i) { + oss << (i ? " #" : "#"); + oss << tokens[i]; + oss << "#"; + } + return std::move(oss).str(); +} + +inline void CheckDP(const RAGAnalyzer *this_ptr, + const std::string_view input_str, + const Vector &dfs_tokens, + const double dfs_score, + const auto t0, + const auto t1) { + const auto [dp_vec, dp_score] = this_ptr->GetBestTokens(input_str); + const auto t2 = std::chrono::high_resolution_clock::now(); + const auto dfs_duration = std::chrono::duration_cast>(t1 - t0); + const auto dp_duration = std::chrono::duration_cast>(t2 - t1); + auto print_1 = [](const bool b) { + return b ? "✅✅✅" : "❌❌❌"; + }; + auto print_2 = [](const bool b) { + return b ? "" : " not"; + }; + const auto dp_faster = dp_duration < dfs_duration; + std::cerr << "\n!!! " << print_1(dp_faster) << "\nDFS duration: " << dfs_duration << " \nDP duration: " << dp_duration; + const auto b_score_eq = dp_score == dfs_score; + std::cerr << std::format("\n{} DFS and DP score{} equal:\nDFS: {}\nDP : {}\n", print_1(b_score_eq), print_2(b_score_eq), dfs_score, dp_score); + bool vec_equal = true; + if (dp_vec.size() != dfs_tokens.size()) { + vec_equal = false; + } else { + for (std::size_t k = 0; k < dp_vec.size(); ++k) { + if (dp_vec[k] != dfs_tokens[k]) { + vec_equal = false; + break; + } + } + } + std::cerr << std::format("{} DFS and DP result{} equal:\nDFS: {}\nDP : {}\n", + print_1(vec_equal), + print_2(vec_equal), + TestPrintTokens(dfs_tokens), + TestPrintTokens(dp_vec)); +} +#endif + void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { auto [tks, s] = MaxForward(L); auto [tks1, s1] = MaxBackward(L); @@ -992,7 +1050,15 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { Vector>> token_list; Vector best_tokens; double max_score = -100.0F; - DFS(Join(tks, _j, j, ""), 0, pre_tokens, token_list, best_tokens, max_score, false); + const auto str_for_dfs = Join(tks, _j, j, ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif res.push_back(Join(best_tokens, 0)); same = 1; @@ -1009,7 +1075,15 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { Vector>> token_list; Vector best_tokens; double max_score = -100.0F; - DFS(Join(tks, _j, tks.size(), ""), 0, pre_tokens, token_list, best_tokens, max_score, false); + const auto str_for_dfs = Join(tks, _j, tks.size(), ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif res.push_back(Join(best_tokens, 0)); } @@ -1048,7 +1122,15 @@ void RAGAnalyzer::TokenizeInner(Vector &res, const String &L) const { Vector>> token_list; Vector best_tokens; double max_score = -100.0F; - DFS(Join(tks, s, e < tks.size() ? e + 1 : e, ""), 0, pre_tokens, token_list, best_tokens, max_score, false); + const auto str_for_dfs = Join(tks, s, e < tks.size() ? 
e + 1 : e, ""); +#ifdef INFINITY_DEBUG + const auto t0 = std::chrono::high_resolution_clock::now(); +#endif + DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false); +#ifdef INFINITY_DEBUG + const auto t1 = std::chrono::high_resolution_clock::now(); + CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1); +#endif // Vector, double>> sorted_tokens; // SortTokens(token_list, sorted_tokens); // res.push_back(Join(sorted_tokens[0].first, 0)); @@ -1146,7 +1228,7 @@ String RAGAnalyzer::Tokenize(const String &line) { res.push_back(L); continue; } - +#if 1 if (length > MAX_SENTENCE_LEN) { Vector sublines; SplitLongText(L, length, sublines); @@ -1154,7 +1236,7 @@ String RAGAnalyzer::Tokenize(const String &line) { TokenizeInner(res, l); } } else - +#endif TokenizeInner(res, L); } Vector normalize_res;
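GetBestTokens() replaces the exponential DFS enumeration with a dynamic programme over code-point positions: dp[i] keeps, for every possible token count, the single best score sum of a segmentation of the first i code points, and the final answer maximises (30 + score_sum) / token_num, which matches the Score() formula. Keeping one candidate per token count, rather than one overall winner, is what makes the divide-by-count objective safe to optimise incrementally. The sketch below is a simplified, ASCII-only illustration of that recurrence; the dictionary, scores, and input are made up, and it stands in for the real code's resumable darts-trie traversal and UTF-8 walking (the real code also adds +1 to the decoded frequency for multi-code-point dictionary hits).

// best_tokens_dp_sketch.cpp: simplified sketch of the dynamic programme behind GetBestTokens().
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <vector>

struct Candidate {
    std::uint32_t token_num = 0;
    std::int64_t score_sum = 0;
    std::vector<std::string> tokens;
};

int main() {
    const std::map<std::string, int> dict = {{"data", 6}, {"base", 5}, {"database", 9}};
    const std::string chars = "databasex";
    constexpr std::int64_t B = 30;     // same base constant the analyzer uses
    constexpr int UNKNOWN_SCORE = -12; // penalty for a single out-of-dictionary character

    // dp[i]: candidates covering chars[0, i), at most one per distinct token count.
    std::vector<std::vector<Candidate>> dp(chars.size() + 1);
    dp[0].emplace_back();
    for (std::size_t i = 0; i < chars.size(); ++i) {
        for (std::size_t j = 1; i + j <= chars.size(); ++j) {
            const std::string piece = chars.substr(i, j);
            const auto it = dict.find(piece);
            if (it == dict.end() && j > 1)
                continue; // the real code stops early once the trie traverse returns -2
            const int piece_score = (it != dict.end()) ? it->second : UNKNOWN_SCORE;
            for (const Candidate &prev : dp[i]) {
                Candidate next{prev.token_num + 1, prev.score_sum + piece_score, prev.tokens};
                next.tokens.push_back(piece);
                // Keep only the best score sum per token count: the objective divides by the
                // count, so candidates with different counts must all stay alive.
                auto &bucket = dp[i + j];
                bool merged = false;
                for (Candidate &c : bucket) {
                    if (c.token_num == next.token_num) {
                        if (next.score_sum > c.score_sum)
                            c = next;
                        merged = true;
                        break;
                    }
                }
                if (!merged)
                    bucket.push_back(std::move(next));
            }
        }
    }

    double best = std::numeric_limits<double>::lowest();
    std::vector<std::string> best_tokens;
    for (const Candidate &c : dp.back()) {
        const double score = static_cast<double>(B + c.score_sum) / c.token_num;
        if (score > best) {
            best = score;
            best_tokens = c.tokens;
        }
    }
    for (const auto &t : best_tokens)
        std::cout << t << ' ';
    std::cout << "(score " << best << ")\n"; // prints: database x (score 13.5)
}

The CheckDP harness added in the final commit cross-checks exactly this relationship: the DP and the original DFS must agree on both the tokens and the score, with both paths timed for comparison.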